给定一对Variable
i
andj
和 a Tensor
A
,如何得到一个新Tensor
的A
与i
第 th 元素和j
第 th 元素交换的?N
成对的i
和j
我们需要全部交换吗?如何在批量训练中做到这一点?
我发现交换一对很容易:
if i < j:
return tf.concat((A[:i], A[j], A[i+1:j], A[i], A[j+1:]), 0)
else:
return tf.concat((A[:j], A[i], A[j+1:i], A[j], A[i+1:]), 0)
交换多对需要显式循环/减少:
for i, j in ijpairs:
A = swap_one_pair(A, i, j)
return A
批量交换多个样本需要显式映射:
return tf.map_fn(swap_one_pair, (As, is, js))
我发现如果我将交换分为“读取”和“写入”,则可以通过以下方式轻松矢量化读取gather_nd
:
indexes = tf.stack((tf.tile(tf.range(batch_size)[:,None], (1,read_size)), all_i), 2)
index_results = tf.gather_nd(A, indexes)
有“写”的版本gather_nd
吗?或者有没有更优雅的方式来实现利用矢量化/广播的交换?
编辑
我找到了一个“写”版本scatter_nd_update
。不幸的是,它只适用于Variable
. 我尝试Tensor
使用但不工作Variable
。tf.Variable(myTensor)
我目前使用:
def scatter_nd_update(ref, indices, updates):
original_part = ref * (1 - tf.scatter_nd(indices, tf.ones(updates.shape), ref.shape))
update_part = tf.scatter_nd(index, update, ref.shape)
return original_part + update_part
这种方法使用了大量无意义的计算。有更好的想法吗?