我有两个张量a:[batch_size, dim]
b:[batch_size, dim]
。我想为批次中的每一对做内积,生成c:[batch_size, 1]
,在哪里c[i,0]=a[i,:].T*b[i,:]
。如何?
如何在 Tensorflow 中做批量内积?
数据挖掘
张量流
2021-09-14 03:54:44
3个回答
没有本地.dot_product
方法。但是,两个向量之间的点积只是按元素相乘相加,因此以下示例有效:
import tensorflow as tf
# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))
c = tf.reduce_sum( tf.multiply( a, b ), 1, keep_dims=True )
with tf.Session() as session:
print( c.eval(
feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
) )
输出是:
[[ 20.]
[ 92.]]
另一个值得一试的选项是- 它本质上是Einstein Notation[tf.einsum][1]
的简化版本。
以下是 Neil 和 dumkar 的示例:
import tensorflow as tf
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))
c = tf.einsum('ij,ij->i', a, b)
with tf.Session() as session:
print( c.eval(
feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
) )
的第一个参数einsum
是一个方程,表示要相乘和相加的轴。方程的基本规则是:
- 输入张量由逗号分隔的维度标签字符串描述
- 重复的标签表示对应的维度会成倍增加
- 输出张量由代表相应输入(或产品)的另一串维度标签描述
- 将输出字符串中缺少的标签相加
在我们的例子中,ij,ij->i
意味着我们的输入将是 2 个形状相等的矩阵(i,j)
,我们的输出将是一个形状向量(i,)
。
一旦你掌握了它的窍门,你会发现它einsum
概括了大量的其他操作:
X = [[1, 2]]
Y = [[3, 4], [5, 6]]
einsum('ab->ba', X) == [[1],[2]] # transpose
einsum('ab->a', X) == [3] # sum over last dimension
einsum('ab->', X) == 3 # sum over both dimensions
einsum('ab,bc->ac', X, Y) == [[13,16]] # matrix multiply
einsum('ab,bc->abc', X, Y) == [[[3,4],[10,12]]] # multiply and broadcast
不幸的是,einsum
与手动乘法+减法相比,它的性能受到了相当大的打击。在性能至关重要的地方,我绝对建议坚持使用 Neil 的解决方案。
如果将轴设置为例如,取tf.tensordot的对角线也可以满足您的要求
[[1], [1]]
我改编了尼尔斯莱特的例子:
import tensorflow as tf
# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))
c = tf.diag_part(tf.tensordot( a, b, axes=[[1],[1]]))
with tf.Session() as session:
print( c.eval(
feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
) )
现在还给出了:
[ 20. 92.]
不过,这对于大型矩阵可能不是最理想的(请参阅此处的讨论)
其它你可能感兴趣的问题