计算t r (一种吨乙丙)tr(ATBC)在 Python 中

计算科学 Python 矩阵
2021-12-08 22:44:07

我必须在 python 中计算三个矩阵之间的乘积轨迹,即我必须计算,我想知道在 Python 中做它的好方法是什么(这里是 ) 的转置。如果我只有 2 个矩阵并且我想计算我知道我可以做一些事情A,B,Ctr(ATBC)ATAA,Btr(ATB)

numpy.tensordot(A,B,axis=2)

但是有 3 个矩阵的情况呢?

2个回答

我还不太明白你的矩阵的维度是什么(见评论)。但下面的代码可能会给你一个开始。效率不高或特别优雅,但您可以从那里改进:

   import numpy as np
   A = np.matrix('1 2; 3 4')
   B = np.matrix('1 2; 3 4')
   C = np.matrix('1 2; 3 4')
   result = (np.matmul(A.transpose(),np.matmul(B,C)))
   print(result)
   print(np.trace(result))

一般来说,如果我们为您的问题获得所有信息和一些背景信息,例如小样本问题等,这里的人们将能够提供更好的帮助。

做这种事情的另一个好方法是numpy.einsum,它允许你用索引表示法表达乘法:

>>> A=np.array([[1,2,3],[4,5,6]])
>>> B=np.array([[1,2],[3,4]])
>>> C=np.array([[1,2,3],[4,5,6]])
>>> A.T@B@C
array([[ 85, 116, 147],
       [113, 154, 195],
       [141, 192, 243]])
>>> np.einsum('ij,jk,kl->il',A.T,B,C)
array([[ 85, 116, 147],
       [113, 154, 195],
       [141, 192, 243]])
>>> np.einsum('ij,jk,ki->',A.T,B,C)
482

这与我们在数学上如何写这个方程有很好的对应关系: Einsum 也有一个关键字,它可以搜索最快的收缩顺序。根据矩阵的尺寸,这可以使乘法更快。

Tr(ATBC)=ijk(AT)ijBjkCki
optimize