在pytorch中实现时空卷积

数据挖掘 深度学习 计算机视觉 Python
2021-09-17 23:37:55

我正在尝试实现一个层来执行本文中描述的 (2+1)D 卷积:https ://arxiv.org/pdf/1711.11248.pdf

基本思想如下:假设我有一个 3D 卷积层,它接受一个输入Ni1渠道,并执行Ni3D 张量上的 3D 卷积表示T2D 帧被接管T时间步长。

该方法不是在张量上进行 3D 卷积,而是将 3D 卷积替换为 2D 卷积,然后沿时间轴进行 1D 卷积。特别是,如果你想表演Ni具有大小内核的 3D 卷积Ni1×t×d×d, 你改为执行Mi具有大小过滤器的 2D 卷积Ni1×1×d×d其次是Ni沿时间轴大小的一维卷积Mi×t×1×1.

下面是这种层的tensorflow实现:https ://github.com/facebookresearch/R2Plus1D/blob/master/lib/models/video_model.py

但是,我在 pytorch 中实现这种类型的层时遇到了麻烦。这就是我现在所拥有的,我显然做错了什么。如果有人能指出我将如何进行此类卷积的正确方向,我将不胜感激。我正在设置中间通道维度Mi在这个片段中到 20:

class hybrid3d(nn.Module):

def __init__(self, n_classes):
    super(hybrid3d, self).__init__()
    t, d, d = 3, 3, 3
    self.conv1_1 = nn.Conv2d(3, 20, kernel_size=(1, d, d), stride=[1,1,1] , padding=[0,0,0])
    self.relu1_1 = nn.PReLU()
    self.conv1_2 = nn.Conv1d(20, 12, kernel_size=(t, 1, 1), stride=[1,1,1], padding=[0,0,0])
    self.relu1_2 = nn.PReLU()


def forward(self, x):
    x = self.relu1_1(self.conv1_1(x))
    x = self.relu1_2(self.conv1_2(x))
    return x
1个回答
import torch as t

class Net_1(t.nn.Module):
    def __init__(self):
        super(Net_1, self).__init__()
        self.conv3d = t.nn.Conv3d(3, 8, (2, 3, 3))
        self.relu = t.nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.conv3d(x))
        return x
   
net_1 = Net_1()

video_1 = t.empty(1, 3, 6, 72, 108).normal_()
%%时间
net_1(video_1).shape

挂墙时间:18 毫秒

火炬尺寸([1, 8, 5, 70, 106])

class Net_2(t.nn.Module):
    def __init__(self):
        super(Net_2, self).__init__()
        self.conv2d = t.nn.Conv2d(3, 8, (3, 3))
        self.conv1d = t.nn.Conv1d(8, 8, (2))
        self.flatten = t.nn.Flatten(start_dim=2)
        self.relu = t.nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.conv2d(x))
        x = self.flatten(x)
        x.transpose_(0, 2)
        x = self.relu(self.conv1d(x))
        x.transpose_(0, 2)
        x.transpose_(0, 1)
        return x.view(5, 8, 70, 106)
        
net_2 = Net_2()

video_2 = t.empty(6, 3, 72, 108).normal_()
%%时间
net_2(video_2).shape

挂墙时间:58 毫秒

火炬大小([8, 5, 70, 106])