在 tf.Tensor 中交换 2 个数字

数据挖掘 Python 张量流
2022-03-14 05:32:47

给定一对Variable iandj和 a Tensor A,如何得到一个新TensorAi第 th 元素和j第 th 元素交换的?N成对的ij我们需要全部交换吗如何在批量训练中做到这一点?

我发现交换一对很容易:

if i < j:
    return tf.concat((A[:i], A[j], A[i+1:j], A[i], A[j+1:]), 0)
else:
    return tf.concat((A[:j], A[i], A[j+1:i], A[j], A[i+1:]), 0)

交换多对需要显式循环/减少:

for i, j in ijpairs:
    A = swap_one_pair(A, i, j)
return A

批量交换多个样本需要显式映射:

return tf.map_fn(swap_one_pair, (As, is, js))

我发现如果我将交换分为“读取”和“写入”,则可以通过以下方式轻松矢量化读取gather_nd

indexes = tf.stack((tf.tile(tf.range(batch_size)[:,None], (1,read_size)), all_i), 2)
index_results = tf.gather_nd(A, indexes)

有“写”的版本gather_nd吗?或者有没有更优雅的方式来实现利用矢量化/广播的交换?

编辑

我找到了一个“写”版本scatter_nd_update不幸的是,它只适用于Variable. 我尝试Tensor使用但不工作Variabletf.Variable(myTensor)我目前使用:

def scatter_nd_update(ref, indices, updates):
    original_part = ref * (1 - tf.scatter_nd(indices, tf.ones(updates.shape), ref.shape))
    update_part   = tf.scatter_nd(index, update, ref.shape)
    return original_part + update_part

这种方法使用了大量无意义的计算。有更好的想法吗?

0个回答
没有发现任何回复~