对于 Unet 中的 SpatialAveragePooling,输出大小太小

数据挖掘 Python 深度学习 PyTorch
2022-03-04 03:57:22

我尝试使用此处的 ResNeXt 作为 U-Net 中的编码器,但不断收到如下错误:RuntimeError: Given input size: (4320x4x4). Calculated output size: (4320x-6x-6). Output size is too small at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/THCUNN/generic/SpatialAveragePooling.cu:63

(输入图像 128*128 批量大小 32)

class UNetResNext(nn.Module):
    """U-Net style network that uses a ResNeXt model as its encoder.

    The encoder's stem and its four residual stages (``layer1``..``layer4``)
    produce the feature maps ``conv1``..``conv5``.  A centre block and a chain
    of ``DecoderBlockV2`` stages then decode them, concatenating the matching
    encoder feature map (skip connection) along the channel axis before each
    of ``dec5``..``dec2``.
    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        """Build the encoder, the decoder stages and the final 1x1 classifier.

        Args:
            encoder_depth: Which ResNeXt variant to use; one of 34, 101, 152.
            num_classes: Number of output channels of the final convolution.
            num_filters: Base channel width used to size the decoder blocks.
            dropout_2d: Drop probability passed to ``F.dropout2d`` just before
                the final convolution.
            pretrained: Accepted for interface compatibility; this constructor
                does not forward it to the encoder factories.
            is_deconv: Forwarded to every ``DecoderBlockV2`` to choose between
                transposed convolution and bilinear upsampling.

        Raises:
            NotImplementedError: If ``encoder_depth`` is not 34, 101 or 152.
        """
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        # bottom_channel_nr is the channel count of the deepest encoder
        # feature map (the output of encoder.layer4) for the chosen variant.
        if encoder_depth == 34:
            self.encoder = resnext34()
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = resnext101()
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = resnext152()
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of Resnext are implemented')

        self.pool = nn.MaxPool2d(2, 2)

        self.relu = nn.ReLU(inplace=True)

        # Encoder stem: the encoder's first conv/bn/relu followed by a 2x2
        # max-pool defined here (not the encoder's own pooling layer).
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        # Each decoder input width = channels of the previous decoder output
        # plus channels of the encoder feature map concatenated in forward().
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2,
                                   num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2,
                                   num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2,
                                   num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2,
                                   num_filters * 2 * 2, is_deconv)
        # dec1 has no skip connection: it consumes dec2's output directly.
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Args:
            x: Input batch fed to the encoder stem.

        Returns:
            Output of the final 1x1 convolution (``num_classes`` channels);
            no activation is applied here.
        """
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(conv5)

        # Concatenate along dim 1 (channels) with the matching encoder output.
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)

        # Tie dropout to the module's train/eval state; the functional form
        # does not track it on its own.
        return self.final(F.dropout2d(dec0, p=self.dropout_2d, training=self.training))

class DecoderBlockV2(nn.Module):
    """Decoder stage: two convolutional steps plus a 2x upsampling step.

    With ``is_deconv`` set, the block is ``ConvRelu`` followed by a stride-2
    transposed convolution, batch norm and ReLU.  Otherwise it is a bilinear
    ``nn.Upsample`` with scale factor 2 followed by two ``ConvRelu`` layers.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        """Assemble the layer sequence for the selected upsampling mode.

        Args:
            in_channels: Channels of the incoming feature map.
            middle_channels: Channels produced by the first ``ConvRelu``.
            out_channels: Channels produced by the block.
            is_deconv: Use a transposed convolution when true, bilinear
                upsampling otherwise.
        """
        super().__init__()
        self.in_channels = in_channels

        if not is_deconv:
            layers = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        else:
            # Parameters for the deconvolution were chosen to avoid
            # checkerboard artifacts, following
            # https://distill.pub/2016/deconv-checkerboard/
            layers = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),  # added by the author on top of the reference block
                nn.ReLU(inplace=True),
            ]

        # Layer order fixes the submodule indices (block.0, block.1, ...).
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the assembled layer sequence to ``x`` and return the result."""
        decoded = self.block(x)
        return decoded

我想知道应该如何正确设置,才能让它正常运行?也许是输入图像尺寸(128*128)太小了?

0个回答
没有发现任何回复~