在这个 pyTorch conv 网络中如何调用 forward 方法?

数据挖掘 神经网络 火炬
2021-09-17 08:33:19

在这个来自pyTorch 教程的示例网络中

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

net = Net()
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

为什么没有明确调用 forward() 方法?我的意思是如何调用 net(output) 调用 forward() ?(据我所知,这是发生的事情)顺便说一句,我不明白这条线的含义:

super(Net, self).__init__()

我可以想象 super() 正在调用父类的构造函数但是......?

1个回答

如果你查看 pyTorch 的Module 实现,你会发现 forward 是一个在特殊方法中调用的方法__call__

class Module(object):
   ...
   def __call__(self, *input, **kwargs):
      ...
      result = self.forward(*input, **kwargs)

当您Net通过从类继承来构造一个类Module并覆盖__init__构造函数的默认行为时,您还需要显式调用父级的super(Net, self).__init__().