反射填充作为纯 keras 版本

数据挖掘 喀拉斯
2022-02-12 22:39:06

我正在使用带有 plaidML 后端的 keras,并且需要实现反射填充。使用一个简单的tf.pad
的 tensorflow 后端,设置为 REFLECT。 mode

如何使用 K.functions 或 plaidml tile 函数实现该功能?
或者在我可以使用的地方有一个实现吗?

使用 K. 函数,简单地将值切片、反转和连接在一起可能是可能的,但我所有的尝试都朝那个方向发展,结果一团糟,并没有真正奏效。

1个回答

这是一个将反射填充作为纯 K 函数的版本,它应该(但未经测试)适用于每个后端:

def reflection_padding(inp, paddings):
    paddings = [(x, x) if isinstance(x, int) else x for x in paddings]
    ishape = inp.shape.dims
    ndims = inp.shape.ndims
    if len(ishape) != len(paddings):
        raise ValueError("Padding dims != input dims")
    last = inp
    _all_slice = slice(None, None, None)

    def _get_slices(ndims, axis, slice_):
        ret = [_all_slice for _ in range(ndims)]
        ret[axis] = slice_
        return tuple(ret)

    for axis, pads in ((i, x) for i, x in enumerate(paddings) if x[0]+x[1] != 0):
        pad_data = []
        if pads[0]:
            pre = last[_get_slices(ndims, axis, slice(pads[0], 0, -1))]
            pad_data.append(pre)
        pad_data.append(last)
        if pads[1]:
            post = last[_get_slices(ndims, axis, slice(-2, -pads[1]-2, -1))]
            pad_data.append(post)
        last = K.concatenate(pad_data, axis)
        ishape = last.shape.dims
    return last

# USAGE: reflection_padding(image_batch, [0, [2,2], [2,2], 0])

我将接受我自己的答案。如果有人有更好的答案,我很乐意将其转换为他们的