torch.cat 是否适用于反向传播?

数据挖掘 火炬 火炬
2021-10-05 11:04:51

我想知道在我的转发功能中使用 torch.cat 是否可以。我这样做是因为我希望输入的前两列跳过中间隐藏层并直接进入最后一层。

这是我的代码:你可以看到我在最后一刻使用了 torch.cat 来制作 xcat。

梯度会向后传播吗?还是 torch.cat 掩盖了我的隐藏变量发生了什么?

class LinearRegressionForce(nn.Module):

    def __init__(self, focus_input_size, rest_input_size, hidden_size_1, hidden_size_2, output_size):
        super(LinearRegressionForce, self).__init__()

        self.in1 = nn.Linear(rest_input_size, hidden_size_1) 
        self.middle1 = nn.Linear(hidden_size_1,hidden_size_2)
        self.out4 = nn.Linear(focus_input_size + hidden_size_2,output_size)

    def forward(self, inputs):
        focus_inputs = inputs[:,0:focus_input_size]
        rest_inputs = inputs[:,focus_input_size:(rest_input_size+focus_input_size)]
        x = self.in1(rest_inputs).clamp(min=0)
        x = self.middle1(x).clamp(min=0)
        xcat = torch.cat((focus_inputs,x),1)
        out = self.out4(xcat).clamp(min=0)
        return out

我这样称呼它:

rest_inputs = Variable(torch.from_numpy(rest_x_train))
focus_x_train_ones = np.concatenate((focus_x_train, np.ones((n,1))), axis=1)
focus_inputs =  Variable(torch.from_numpy(focus_x_train_ones)).float()
inputs = torch.cat((focus_inputs,rest_inputs),1)

predicted = model(inputs).data.numpy()
1个回答

是的,torch.cat适用于backward操作。

在此处输入图像描述