如何在 Tensorflow 中做批量内积?

数据挖掘 张量流
2021-09-14 03:54:44

我有两个张量a:[batch_size, dim] b:[batch_size, dim]我想为批次中的每一对做内积,生成c:[batch_size, 1],在哪里c[i,0]=a[i,:].T*b[i,:]如何?

3个回答

没有本地.dot_product方法。但是,两个向量之间的点积只是按元素相乘相加,因此以下示例有效:

import tensorflow as tf

# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))

c = tf.reduce_sum( tf.multiply( a, b ), 1, keep_dims=True )

with tf.Session() as session:
    print( c.eval(
        feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
    ) )

输出是:

[[ 20.]
 [ 92.]]

另一个值得一试的选项是- 它本质上是Einstein Notation[tf.einsum][1]的简化版本

以下是 Neil 和 dumkar 的示例:

import tensorflow as tf

a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))

c = tf.einsum('ij,ij->i', a, b)

with tf.Session() as session:
    print( c.eval(
        feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
    ) )

的第一个参数einsum是一个方程,表示要相乘和相加的轴。方程的基本规则是:

  1. 输入张量由逗号分隔的维度标签字符串描述
  2. 重复的标签表示对应的维度会成倍增加
  3. 输出张量由代表相应输入(或产品)的另一串维度标签描述
  4. 将输出字符串中缺少的标签相加

在我们的例子中,ij,ij->i意味着我们的输入将是 2 个形状相等的矩阵(i,j),我们的输出将是一个形状向量(i,)

一旦你掌握了它的窍门,你会发现它einsum概括了大量的其他操作:

X = [[1, 2]]
Y = [[3, 4], [5, 6]]

einsum('ab->ba', X) == [[1],[2]]   # transpose
einsum('ab->a',  X) ==  [3]        # sum over last dimension
einsum('ab->',   X) ==   3         # sum over both dimensions

einsum('ab,bc->ac',  X, Y) == [[13,16]]          # matrix multiply
einsum('ab,bc->abc', X, Y) == [[[3,4],[10,12]]]  # multiply and broadcast

不幸的是,einsum与手动乘法+减法相比,它的性能受到了相当大的打击。在性能至关重要的地方,我绝对建议坚持使用 Neil 的解决方案。

如果将轴设置为例如,取tf.tensordot的对角线也可以满足您的要求

[[1], [1]]

我改编了尼尔斯莱特的例子:

import tensorflow as tf

# Arbitrarity, we'll use placeholders and allow batch size to vary,
# but fix vector dimensions.
# You can change this as you see fit
a = tf.placeholder(tf.float32, shape=(None, 3))
b = tf.placeholder(tf.float32, shape=(None, 3))

c = tf.diag_part(tf.tensordot( a, b, axes=[[1],[1]]))

with tf.Session() as session:
    print( c.eval(
        feed_dict={ a: [[1,2,3],[4,5,6]], b: [[2,3,4],[5,6,7]] }
    ) )

现在还给出了:

[ 20.  92.]

不过,这对于大型矩阵可能不是最理想的(请参阅此处的讨论)