计算diag(u)-uu'的平方根?

计算科学 线性代数 数字 统计数据
2021-12-07 20:40:21

我需要一种有效的方法来取矩阵的平方根,该矩阵是对角矩阵和 rank-1 矩阵的总和。

更具体地说,它是以下矩阵

A=Duu=diag(u)uu

其中的条目u是非负的并且idui=1. 这也称为多项式分布与均值的协方差u

我的第一种方法是使用特征值分解,但它对我的应用程序来说太慢了大约 50 倍。

A1/2=US1/2U
第二种方法是忽略uu学期。这足够快,但会产生很大的错误。

A1/2D1/2

我的第三种方法是对误差项进行一些代数重排,并得出以下 rank-1 近似值

A1/2D1/2+1+uD1u1uD1uD14uuD14

这足够快,并且在几个随机示例(笔记本)上有一个小的近似误差。但是,我不确定是否存在这种近似值失效的情况。

在这一点上,我觉得我可能正在重新发明轮子——这个问题有已知的结果吗?具体来说,一种精确计算它的廉价方法,或者在保证小的近似误差的情况下近似计算它。

的价值观u在我的情况下高度偏斜,它们至少衰减得一样快O(1/k)并且可能成倍增长。

1个回答

会分解形式A=XXT够了吗?这已经足够了,例如,如果最终目标是从具有给定协方差的高斯分布中采样。

如果是这样,您可以使用以下公式,该公式与您的近似值非常相似:

X=D1/2+uTD1u+11uTD1uuuTD1/2
这是从白化对角项得出的
D+uuT=D1/2(I+D1/2uuTD1/2)D1/2,
然后找到中间因子的平方根,这是身份的 rank-1 更新。


如果你真的需要平方根,你可以考虑使用有理近似值,

(D+uuT)1/2c0+c1(σ1I+D+uuT)1+c2(σ2I+D+uuT)1+.
然后用 Sherman-Morrison 公式反转这些项。一个非常好的近似值所需的项数随着
O(logκ)
在哪里κ是条件数D+uuT. 因此,这不是对角线加上秩为 1 的矩阵,而是变成少量对角线加上秩为 1 的矩阵的总和。

有关有理逼近的更多详细信息,这是一篇精彩的论文:

Hale、Nicholas、Nicholas J. Higham 和 Lloyd N. Trefethen。“通过等高线积分计算 A^α、\log(A) 和相关矩阵函数。” SIAM 数值分析杂志 46.5 (2008): 2505-2523。http://eprints.maths.manchester.ac.uk/834/1/hale_higham_trefethen.pdf

具体参见该论文中的方程(4.4)和方法 3。


编辑:在这里我写了一些 Python 代码来执行理性的方法:

# Adaptation of Method 3 from Hale, Higham, and Trefethen, Computing f(A)b by contour integrals. SIAM 2008
import numpy as np
import scipy.linalg as sla
from scipy.special import *


def hht_isqrt_weights_and_poles(min_eigenvalue_m, max_eigenvalue_M, number_of_rational_terms_N):
    # 1/sqrt(z) = w0/(z - p0) + w1/(z - p1) + ...
    m = min_eigenvalue_m
    M = max_eigenvalue_M
    N = number_of_rational_terms_N
    k2 = m/M
    Kp = ellipk(1-k2)
    t = 1j * np.arange(0.5, N) * Kp/N
    sn, cn, dn, ph = ellipj(t.imag,1-k2)
    cn = 1./cn
    dn = dn * cn
    sn = 1j * sn * cn
    w = np.sqrt(m) * sn
    dzdt = cn * dn

    poles = (w**2).real
    weights = (2 * Kp * np.sqrt(m) / (np.pi*N)) * dzdt
    rational_function = lambda zz: np.dot(1. / (zz.reshape((-1,1)) - poles), weights)
    return weights, poles, rational_function


def diagonal_smw(dd, u):
    # (diag(dd)+uu^T)^-1 = diag(dd) - vv^T
    v = (u / dd)/np.sqrt(1. + np.dot(u, u / dd))
    return v


def diagonal_plus_rank_one_inverse_sqrt(dd, u, num_terms):
    # inv(sqrtm(diag(dd) + u*u')) = diag(ss) - V*V' + small error
    lambda_min_bound = np.min(dd) # These bounds could be improved
    lambda_max_bound = np.max(dd) + np.linalg.norm(u)**2
    ww, pp, _ = hht_isqrt_weights_and_poles(lambda_min_bound, lambda_max_bound, num_terms)
    dd_shift = [dd - p for p in pp]
    vv0 = [diagonal_smw(d_shift, u) for d_shift in dd_shift]
    ss = np.sum([w*(1./d_shift) for w, d_shift in zip(ww, dd_shift)], axis=0)
    V = np.vstack([np.sqrt(w)*v0 for w, v0 in zip(ww, vv0)]).T
    return ss, V

def diagonal_woodburyish(ss, V):
    # (diag(ss) - VV^T)^-1 = diag(1./ss) + W*M*W^T
    r = V.shape[1]
    W = V / ss.reshape([-1, 1]) # inv(diag(dd))*V
    capacitance_mtx = np.eye(r)-np.dot(V.T, W)
    M = np.linalg.inv(capacitance_mtx) # inverse of very small matrix, e.g., 5x5. Could do Cholesky if desired..
    return M, W

def diagonal_plus_rank_one_sqrt(dd, u, num_terms):
    # inv(sqrtm(diag(dd) + u*u')) = diag(iss) + W*M*W^T + small error
    ss, V = diagonal_plus_rank_one_inverse_sqrt(dd, u, num_terms)
    M, W = diagonal_woodburyish(ss, V)
    iss = 1. / ss
    return iss, M, W

n = 500
kappa_ish = 1e3
dd = kappa_ish * np.random.rand(n)
u = np.random.randn(n)

A = np.diag(dd) + np.outer(u, u)
sqrtA = sla.sqrtm(A)

for num_terms in 1+np.arange(10):
    iss, M, W = diagonal_plus_rank_one_sqrt(dd, u, num_terms)
    sqrtA2 = np.diag(iss) + np.dot(W, np.dot(M, W.T))
    err = np.linalg.norm(sqrtA - sqrtA2) / np.linalg.norm(sqrtA)
    print('num_terms=', num_terms, ', err=', err)

它工作得很好,这是有理近似中 1 到 10 项的输出误差:

num_terms= 1 , err= 0.28856754578036703
num_terms= 2 , err= 0.04084537953169016
num_terms= 3 , err= 0.005770763245002039
num_terms= 4 , err= 0.000624851647505238
num_terms= 5 , err= 7.234944285553648e-05
num_terms= 6 , err= 1.0059955401611735e-05
num_terms= 7 , err= 1.22949322111201e-06
num_terms= 8 , err= 1.3750242705729841e-07
num_terms= 9 , err= 1.7564866678109768e-08
num_terms= 10 , err= 2.289881883503377e-09