在 Python 中计算 KL 散度

数据挖掘 Python 聚类 scikit-学习
2021-10-11 21:03:47

我对此很陌生,不能说我完全理解这背后的理论概念。我正在尝试计算 Python 中几个点列表之间的 KL Divergence。我正在使用来尝试执行此操作。我遇到的问题是返回的值对于任何 2 个数字列表(其 1.3862943611198906)都是相同的。我有一种感觉,我在这里犯了某种理论上的错误,但无法发现它。

values1 = [1.346112,1.337432,1.246655]
values2 = [1.033836,1.082015,1.117323]
metrics.mutual_info_score(values1,values2)

这是我正在运行的一个示例 - 只是我为任何 2 个输入获得了相同的输出。任何建议/帮助将不胜感激!

4个回答

首先,sklearn.metrics.mutual_info_score实现互信息评估聚类结果,而不是纯粹的 Kullback-Leibler 散度!

这等于联合分布与边际乘积分布的 Kullback-Leibler 散度。

KL 散度(以及任何其他此类度量)期望输入数据的总和为 1否则,它们不是适当的概率分布如果您的数据总和不为 1,则很可能通常不适合使用 KL 散度!(在某些情况下,总和小于 1 是可以接受的,例如在缺失数据的情况下。)

另请注意,通常使用以 2 为底的对数。这只会在差异中产生一个恒定的比例因子,但以 2 为底的对数更容易解释并且具有更直观的比例(0 到 1 而不是 0 到 log2=0.69314...,以位而不是 nat 测量信息)。

> sklearn.metrics.mutual_info_score([0,1],[1,0])
0.69314718055994529

我们可以清楚地看到,sklearn 的 MI 结果使用自然对数而不是 log2 进行缩放。如上所述,这是一个不幸的选择。

不幸的是,Kullback-Leibler 分歧是脆弱的。在上面的示例中,它没有明确定义:KL([0,1],[1,0])导致除以零,并趋于无穷大。它也是不对称的。

如果输入两个向量 p 和 q, Scipy 的熵函数将计算 KL 散度,每个向量代表一个概率分布。如果这两个向量不是pdf,它将首先标准化。

互信息与KL Divergence相关,但并不相同。

“这种加权互信息是加权 KL-Divergence 的一种形式,已知它对某些输入取负值,并且有些例子中加权互信息也取负值”

我不确定 scikit-learn 的实现,但这里是 Python 中 KL 散度的快速实现:

import numpy as np

def KL(a, b):
    a = np.asarray(a, dtype=np.float)
    b = np.asarray(b, dtype=np.float)

    return np.sum(np.where(a != 0, a * np.log(a / b), 0))


values1 = [1.346112,1.337432,1.246655]
values2 = [1.033836,1.082015,1.117323]

print KL(values1, values2)

输出: 0.775279624079

某些库中可能存在实现冲突,因此请确保在使用前阅读他们的文档。

这个技巧避免了条件代码,因此可以提供更好的性能。

import numpy as np

def KL(P,Q):
""" Epsilon is used here to avoid conditional code for
checking that neither P nor Q is equal to 0. """
     epsilon = 0.00001

     # You may want to instead make copies to avoid changing the np arrays.
     P = P+epsilon
     Q = Q+epsilon

     divergence = np.sum(P*np.log(P/Q))
     return divergence

# Should be normalized though
values1 = np.asarray([1.346112,1.337432,1.246655])
values2 = np.asarray([1.033836,1.082015,1.117323])

# Note slight difference in the final result compared to Dawny33
print KL(values1, values2) # 0.775278939433