我想知道在我的转发功能中使用 torch.cat 是否可以。我这样做是因为我希望输入的前两列跳过中间隐藏层并直接进入最后一层。
这是我的代码:你可以看到我在最后一刻使用了 torch.cat 来制作 xcat。
梯度会向后传播吗?还是 torch.cat 掩盖了我的隐藏变量发生了什么?
class LinearRegressionForce(nn.Module):
def __init__(self, focus_input_size, rest_input_size, hidden_size_1, hidden_size_2, output_size):
super(LinearRegressionForce, self).__init__()
self.in1 = nn.Linear(rest_input_size, hidden_size_1)
self.middle1 = nn.Linear(hidden_size_1,hidden_size_2)
self.out4 = nn.Linear(focus_input_size + hidden_size_2,output_size)
def forward(self, inputs):
focus_inputs = inputs[:,0:focus_input_size]
rest_inputs = inputs[:,focus_input_size:(rest_input_size+focus_input_size)]
x = self.in1(rest_inputs).clamp(min=0)
x = self.middle1(x).clamp(min=0)
xcat = torch.cat((focus_inputs,x),1)
out = self.out4(xcat).clamp(min=0)
return out
我这样称呼它:
rest_inputs = Variable(torch.from_numpy(rest_x_train))
focus_x_train_ones = np.concatenate((focus_x_train, np.ones((n,1))), axis=1)
focus_inputs = Variable(torch.from_numpy(focus_x_train_ones)).float()
inputs = torch.cat((focus_inputs,rest_inputs),1)
predicted = model(inputs).data.numpy()
