我正在创建一个消息传递神经网络,并且在创建数据集时遇到了一些问题。在 pytorch (geometric) 中,建议使用以下类创建数据集。我想知道在 process 方法结束时调用的 collate 函数是什么意思?在什么情况下我应该使用自己的整理功能?我的图表大多有不同的大小。
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
我得到的错误是:
RuntimeError:无效参数 0:张量的大小必须匹配,但维度 0 除外。在 /opt/conda/conda-bld/pytorch_1573049304260/work/aten/src/THC/generic/THCTensorMath.cu:71 的维度 1 中得到 4422 和 4032