卷积如何表示为矩阵乘法(矩阵形式)?

信息处理 图像处理 卷积 过滤 边缘检测 坡度
2022-01-15 07:08:28

我知道这个问题可能与编程不太相关,但如果我不了解图像处理背后的理论,我将永远无法在实践中实现某些东西。

如果我做对了,高斯滤波器与图像进行卷积以降低噪声,因为它们计算像素邻域的加权平均值,并且它们在边缘检测中非常有用,因为您可以应用模糊并同时导出图像简单地与高斯函数的导数进行卷积。

但是谁能解释我,或者给我一些关于它们是如何计算的参考?

例如, Canny 的边缘检测器谈到了 5x5 高斯滤波器,但他们是如何获得这些特定数字的?它们是如何从连续卷积变为矩阵乘法的?

4个回答

要使此操作起作用,您需要想象您的图像被重新塑造为矢量。然后,该向量在其左侧乘以卷积矩阵,以获得模糊图像。请注意,结果也是与输入大小相同的向量,即大小相同的图像。

卷积矩阵的每一行对应于输入图像中的一个像素。它包含图像中所有其他像素对所考虑像素的模糊对应部分的贡献的权重。

举个例子:box blur of size3×3大小图像上的像素6×6像素。重塑后的图像是一列 36 个选区,而模糊矩阵有大小36×36.

  • 让我们将此矩阵初始化为 0。
  • 现在,考虑坐标的像素(一世,j)在输入图像中(为简单起见,不在其边界上)。它的模糊对应物是通过应用权重获得的1/9对它自己和它的每个邻居在位置上(一世-1,j-1);(一世-1,j),(一世-1,j+1),,(一世+1,j+1).
  • 在列向量中,像素(一世,j)有地位6*一世+j(假设字典顺序)。我们报告重量1/9在里面(6一世+j)- 模糊矩阵的第 行。
  • 对所有其他像素执行相同操作。

可以在我的个人博客Computers Don't See (Yet) - Linearized model of LBPs上找到密切相关的过程(卷积 + 减法)的可视化说明。

我在我的StackOverflow Q2080835 GitHub 存储库中编写了一个函数来解决这个问题(看看CreateImageConvMtx())。
实际上,该函数可以支持您想要的任何卷积形状-full.samevalid

代码如下:

function [ mK ] = CreateImageConvMtx( mH, numRows, numCols, convShape )

CONVOLUTION_SHAPE_FULL  = 1;
CONVOLUTION_SHAPE_SAME  = 2;
CONVOLUTION_SHAPE_VALID = 3;

switch(convShape)
    case(CONVOLUTION_SHAPE_FULL)
        % Code for the 'full' case
        convShapeString = 'full';
    case(CONVOLUTION_SHAPE_SAME)
        % Code for the 'same' case
        convShapeString = 'same';
    case(CONVOLUTION_SHAPE_VALID)
        % Code for the 'valid' case
        convShapeString = 'valid';
end

mImpulse = zeros(numRows, numCols);

for ii = numel(mImpulse):-1:1
    mImpulse(ii)    = 1; %<! Create impulse image corresponding to i-th output matrix column
    mTmp            = sparse(conv2(mImpulse, mH, convShapeString)); %<! The impulse response
    cColumn{ii}     = mTmp(:);
    mImpulse(ii)    = 0;
end

mK = cell2mat(cColumn);


end

我还创建了一个函数来创建一个用于图像过滤的矩阵(类似于 MATLAB 的想法imfilter()):

function [ mK ] = CreateImageFilterMtx( mH, numRows, numCols, operationMode, boundaryMode )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

OPERATION_MODE_CONVOLUTION = 1;
OPERATION_MODE_CORRELATION = 2;

BOUNDARY_MODE_ZEROS         = 1;
BOUNDARY_MODE_SYMMETRIC     = 2;
BOUNDARY_MODE_REPLICATE     = 3;
BOUNDARY_MODE_CIRCULAR      = 4;

switch(operationMode)
    case(OPERATION_MODE_CONVOLUTION)
        mH = mH(end:-1:1, end:-1:1);
    case(OPERATION_MODE_CORRELATION)
        % mH = mH; %<! Default Code is correlation
end

switch(boundaryMode)
    case(BOUNDARY_MODE_ZEROS)
        mK = CreateConvMtxZeros(mH, numRows, numCols);
    case(BOUNDARY_MODE_SYMMETRIC)
        mK = CreateConvMtxSymmetric(mH, numRows, numCols);
    case(BOUNDARY_MODE_REPLICATE)
        mK = CreateConvMtxReplicate(mH, numRows, numCols);
    case(BOUNDARY_MODE_CIRCULAR)
        mK = CreateConvMtxCircular(mH, numRows, numCols);
end


end


function [ mK ] = CreateConvMtxZeros( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if((ii + kk <= numRows) && (ii + kk >= 1) && (jj + ll <= numCols) && (jj + ll >= 1))
                    vCols(elmntIdx) = pxIdx + pxShift;
                    vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);
                else
                    vCols(elmntIdx) = pxIdx;
                    vVals(elmntIdx) = 0; % See the accumulation property of 'sparse()'.
                end
            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end


function [ mK ] = CreateConvMtxSymmetric( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if(ii + kk > numRows)
                    pxShift = pxShift - (2 * (ii + kk - numRows) - 1);
                end

                if(ii + kk < 1)
                    pxShift = pxShift + (2 * (1 -(ii + kk)) - 1);
                end

                if(jj + ll > numCols)
                    pxShift = pxShift - ((2 * (jj + ll - numCols) - 1) * numCols);
                end

                if(jj + ll < 1)
                    pxShift = pxShift + ((2 * (1 - (jj + ll)) - 1) * numCols);
                end

                vCols(elmntIdx) = pxIdx + pxShift;
                vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);

            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end


function [ mK ] = CreateConvMtxReplicate( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if(ii + kk > numRows)
                    pxShift = pxShift - (ii + kk - numRows);
                end

                if(ii + kk < 1)
                    pxShift = pxShift + (1 -(ii + kk));
                end

                if(jj + ll > numCols)
                    pxShift = pxShift - ((jj + ll - numCols) * numCols);
                end

                if(jj + ll < 1)
                    pxShift = pxShift + ((1 - (jj + ll)) * numCols);
                end

                vCols(elmntIdx) = pxIdx + pxShift;
                vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);

            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end


function [ mK ] = CreateConvMtxCircular( mH, numRows, numCols )
%UNTITLED6 Summary of this function goes here
%   Detailed explanation goes here

numElementsImage    = numRows * numCols;
numRowsKernel       = size(mH, 1);
numColsKernel       = size(mH, 2);
numElementsKernel   = numRowsKernel * numColsKernel;

vRows = reshape(repmat(1:numElementsImage, numElementsKernel, 1), numElementsImage * numElementsKernel, 1);
vCols = zeros(numElementsImage * numElementsKernel, 1);
vVals = zeros(numElementsImage * numElementsKernel, 1);

kernelRadiusV = floor(numRowsKernel / 2);
kernelRadiusH = floor(numColsKernel / 2);

pxIdx       = 0;
elmntIdx    = 0;

for jj = 1:numCols
    for ii = 1:numRows
        pxIdx = pxIdx + 1;
        for ll = -kernelRadiusH:kernelRadiusH
            for kk = -kernelRadiusV:kernelRadiusV
                elmntIdx = elmntIdx + 1;

                pxShift = (ll * numCols) + kk;

                if(ii + kk > numRows)
                    pxShift = pxShift - numRows;
                end

                if(ii + kk < 1)
                    pxShift = pxShift + numRows;
                end

                if(jj + ll > numCols)
                    pxShift = pxShift - (numCols * numCols);
                end

                if(jj + ll < 1)
                    pxShift = pxShift + (numCols * numCols);
                end

                vCols(elmntIdx) = pxIdx + pxShift;
                vVals(elmntIdx) = mH(kk + kernelRadiusV + 1, ll + kernelRadiusH + 1);

            end
        end
    end
end

mK = sparse(vRows, vCols, vVals, numElementsImage, numElementsImage);


end

该代码已针对 MATLAB 进行了验证imfilter()

我的StackOverflow Q2080835 GitHub 存储库中提供了完整代码

时域中的卷积等于频域中的矩阵乘法,反之亦然。

滤波相当于时域中的卷积,因此相当于频域中的矩阵乘法。

至于 5x5 地图或蒙版,它们来自于对 canny/sobel 算子进行离散化。

对于图像或卷积网络的应用,为了更有效地使用现代 GPU 中的矩阵乘法器,输入通常被重新整形为激活矩阵的列,然后可以一次与多个过滤器/内核相乘。

查看斯坦福 CS231n 中的此链接,并向下滚动到“作为矩阵乘法的实现”部分了解详细信息。

该过程通过获取输入图像或激活图上的所有局部补丁来工作,这些补丁将与内核相乘,并通过通常称为 im2col 的操作将它们拉伸到新矩阵 X 的列中。内核也被拉伸以填充权重矩阵 W 的行,以便在执行矩阵运算 W*X 时,得到的矩阵 Y 具有卷积的所有结果。最后,必须通过通常称为 cal2im 的操作将列转换回图像来再次重塑 Y 矩阵。