快速且数值稳定的成对距离算法

计算科学 线性代数 表现 稳定 浮点 距离测量
2021-12-06 22:14:05

我正在寻找有关快速、数值稳定的成对欧几里德距离算法的资源。特别是,假设ARM×DBRN×D是两组行向量。我想计算矩阵,

XRM×N,Xi,j=AiBj22.

到目前为止,我找到了两种方法:

方法 1 - 简单的方法是遍历每个向量Ai和每个向量Bj并用相应的平方和填充X

方法 2 - 利用以下事实,

AiBj22=AiBj,AiBj=Ai22+Bj222Ai,Bj.

利用高效的代码来计算这三个项中的每一项,方法 2 几乎快了一个数量级。但是,它在数值上不如方法 1 稳定。例如,方法 2 可以输出负距离。

是否有替代方法来计算比方法 1 更快但保证(至少)所有非负距离的成对距离?有更好的数值稳定性?

代码演示- 下面是一个用 Python 编写的示例比较。Scipy 的cdist()功能实际上是方法 1 的实现,而cdist_fast()下面是方法 2 的实现:

# experiment.py
import numpy as np
import time
from scipy.spatial.distance import cdist


def cdist_fast(XA, XB):
    XA_norm = np.sum(XA**2, axis=1)
    XB_norm = np.sum(XB**2, axis=1)
    XA_XB_T = np.dot(XA, XB.T)
    distances = XA_norm.reshape(-1,1) + XB_norm - 2*XA_XB_T
    return distances


def main():
    M,N = 5000, 128
    XA = np.random.randn(M,N)

    t = time.time()
    distances_cdist = cdist(XA, XA, metric='sqeuclidean')
    time_cdist = time.time() - t

    t = time.time()
    distances_cdist_fast = cdist_fast(XA, XA)
    time_cdist_fast = time.time() - t

    print(f'time_cdist = {time_cdist:.3f} s')
    print(f'time_cdist_fast = {time_cdist_fast:.3f} s')

    # check validity of results
    assert np.allclose(distances_cdist, distances_cdist_fast)

    # check that the results are non-negative
    try:
        assert (distances_cdist >= 0.0).all()
    except AssertionError:
        print('Numerical instability in cdist()')

    try:
        assert (distances_cdist_fast >= 0.0).all()
    except AssertionError:
        print('Numerical instability in cdist_fast()')


if __name__ == '__main__':
    main()

3.1 GHz Intel Core i7 上的脚本输出:

$ python experiment.py
time_cdist = 3.457 s
time_cdist_fast = 0.625 s
Numerical instability in cdist_fast()
2个回答

我强烈建议不要使用方法 2: 无论您是否使用绝对值。因为您将获得积极(但错误)的距离;但是,您仍然缺乏任何数值稳定性。

AiBj22=AiBj,AiBj=Ai22+Bj222Ai,Bj.

现在,对于中的 5000 个向量,您的 (scipy) 稳定版本比快速版本慢约 6 倍。我认为您找不到比. 我会确保它实际上是您代码中的瓶颈。特别是由于方法 1自然只对正条目求和,这是一个不处理稳定性的理想且廉价的方案!R128scipy.cdist

您绝对可以尝试使用KahanNeumaier求和,但我会感到惊讶的是,如果使用它们(请注意,Neumaier 必须单独应用于每个条目),上述方法 2 将比scipy.cdist,最有可能至少几十时间较慢。

但是,如果距离的正数是您唯一关心的事情,那么您当然可以在函数之上使用绝对值cdist_fast(.,.)

时,计算值为负,其中是机器精度。负值是舍入误差。Ai=Bj+O(ε)ε

如果返回计算的绝对值, ,你得到加速和正数。|||Ai||22+||Bj||222<Ai,Bj>|