pytorch(几何)中的整理功能是什么?

数据挖掘 火炬 pytorch-几何
2022-02-25 07:44:46

我正在创建一个消息传递神经网络,并且在创建数据集时遇到了一些问题。在 pytorch (geometric) 中,建议使用以下类创建数据集。我想知道在 process 方法结束时调用的 collat​​e 函数是什么意思?在什么情况下我应该使用自己的整理功能?我的图表大多有不同的大小。

import torch
from torch_geometric.data import InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

我得到的错误是:

RuntimeError:无效参数 0:张量的大小必须匹配,但维度 0 除外。在 /opt/conda/conda-bld/pytorch_1573049304260/work/aten/src/THC/generic/THCTensorMath.cu:71 的维度 1 中得到 4422 和 4032

1个回答

collat​​e 的作用和原因:

因为保存一个巨大的 python 列表真的很慢,我们在保存之前通过 torch_geometric.data.InMemoryDataset.collat​​e() 将列表整理成一个巨大的 torch_geometric.data.Data 对象。整理后的数据对象将所有示例连接到一个大数据对象中,此外,还返回一个切片字典以从该对象重构单个示例。最后,我们需要在构造函数中将这两个对象加载到 self.data 和 self.slices 属性中。

来源