We have the following implementation of the Kullback-Leibler divergence (KLD):
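import numpy as np
import pandas as pd
from scipy.stats import entropy

def KL_divergence(a, b):
    # Bin both datasets into 100 equal-width bins on [0, 1].
    hist_a = np.histogram(a, bins=100, range=(0, 1.0))[0]
    hist_b = np.histogram(b, bins=100, range=(0, 1.0))[0]
    # Replace empty bins in b with a tiny count so the KLD stays finite.
    hist_b = np.where(hist_b == 0.0, 1e-6, hist_b)
    # entropy() normalizes both count vectors before computing the KLD.
    return entropy(hist_a, hist_b)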
It takes two datasets (with values in the range 0-1), discretizes each into 100 equal-width bins, and computes the KLD on the resulting histograms.
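To make explicit what `entropy(hist_a, hist_b)` computes, here is a minimal sketch that reproduces it by hand. The beta-distributed inputs `a` and `b` and the name `manual_kld` are assumptions made purely for illustration; they are not part of the original setup.

```python
import numpy as np
from scipy.stats import entropy

rng = np.random.default_rng(0)
# Hypothetical data in [0, 1]; any two datasets in that range would do.
a = rng.beta(2, 5, size=5000)
b = rng.beta(2, 5, size=5000)

# Same binning and zero-replacement as in KL_divergence.
hist_a = np.histogram(a, bins=100, range=(0, 1.0))[0]
hist_b = np.histogram(b, bins=100, range=(0, 1.0))[0]
hist_b = np.where(hist_b == 0, 1e-6, hist_b)

# entropy() normalizes the counts to probabilities, then sums p * log(p / q).
p = hist_a / hist_a.sum()
q = hist_b / hist_b.sum()
mask = p > 0  # terms with p == 0 contribute nothing to the sum
manual_kld = np.sum(p[mask] * np.log(p[mask] / q[mask]))

scipy_kld = entropy(hist_a, hist_b)
print(manual_kld, scipy_kld)
```

The two values should agree, which shows that only the bin proportions enter the formula; the sample size affects the result only through the noise in those proportions.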
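In practice this does not work at all, because the distance scales strongly with the size of the datasets (smaller datasets = larger distances). Here I wrote a simple script that simulates many histograms from datasets of different sizes (1,000, 10,000, and 100,000 samples), evaluates the KLD between pairs of them, and plots a histogram of the distances for each size. The "underlying probability" is an example distribution that these datasets might follow.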
import numpy as np
import pandas as pd
from scipy.stats import entropy
import matplotlib.pyplot as plt
%matplotlib inline

def KL_divergence(hist_a, hist_b):
    return entropy(hist_a, hist_b)

actual_bin_counts = np.array([7805, 436, 396, 456, 559, 809, 1139, 1928, 4618, 60948])
underlying_probability = actual_bin_counts / actual_bin_counts.sum()

def generate_histogram(n_samples, true_probs=underlying_probability):
    # Draw bin indices by inverse-CDF sampling from true_probs.
    uniform_random = np.random.uniform(0, 1, size=n_samples)
    bins_counts = np.digitize(uniform_random, true_probs.cumsum())
    # Note: np.unique drops bins that received no samples.
    return np.unique(bins_counts, return_counts=True)[1]

distances_1000 = []
for repeat in range(10_000):
    try:
        sampled_a = generate_histogram(1000)
        sampled_b = generate_histogram(1000)
        distances_1000.append(KL_divergence(sampled_a, sampled_b))
    except ValueError:
        # One histogram had only 9 bins (an empty bin was dropped). I don't care enough to fix it.
        pass

distances_10_000 = []
for repeat in range(10_000):
    try:
        sampled_a = generate_histogram(10_000)
        sampled_b = generate_histogram(10_000)
        distances_10_000.append(KL_divergence(sampled_a, sampled_b))
    except ValueError:
        # One histogram had only 9 bins (an empty bin was dropped). I don't care enough to fix it.
        pass

distances_100_000 = []
for repeat in range(10_000):
    try:
        sampled_a = generate_histogram(100_000)
        sampled_b = generate_histogram(100_000)
        distances_100_000.append(KL_divergence(sampled_a, sampled_b))
    except ValueError:
        # One histogram had only 9 bins (an empty bin was dropped). I don't care enough to fix it.
        pass

plt.xscale('log')
plt.hist(distances_1000, bins=100);
plt.hist(distances_10_000, bins=100);
plt.hist(distances_100_000, bins=100);
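As you can see, although the underlying distribution is the same, the distances are not comparable. How can I correct for the size of the datasets?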
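To put a number on the size dependence, here is a rough sketch under a few assumptions. It draws the bin counts directly with `Generator.multinomial` (a substitution for the `np.digitize` approach, chosen so every histogram always has all 10 bins), and the helper `mean_kld` is a hypothetical name. A second-order Taylor expansion of the KLD suggests that, for two independent samples of equal size n from the same k-bin distribution, the expected KLD is approximately (k - 1) / n, so n times the mean KLD should stay roughly constant near k - 1.

```python
import numpy as np
from scipy.stats import entropy

rng = np.random.default_rng(0)
counts = np.array([7805, 436, 396, 456, 559, 809, 1139, 1928, 4618, 60948])
probs = counts / counts.sum()
k = len(probs)

def mean_kld(n_samples, n_repeats=2000):
    """Average KLD between pairs of histograms drawn from the same distribution."""
    # Each row is one simulated histogram with exactly k bins (zeros are kept).
    hist_a = rng.multinomial(n_samples, probs, size=n_repeats)
    hist_b = rng.multinomial(n_samples, probs, size=n_repeats)
    # axis=1 computes one KLD per pair of rows.
    return entropy(hist_a, hist_b, axis=1).mean()

# Sizes are chosen large enough that empty bins are practically impossible here.
scaled = {n: n * mean_kld(n) for n in (10_000, 100_000)}
for n, value in scaled.items():
    print(f"n = {n}: n * mean KLD = {value:.2f} (k - 1 = {k - 1})")
```

If the approximation holds, the raw distances differ by about a factor of 10 between the two sizes while the rescaled values stay close to each other. Whether multiplying by n is the right correction for a given application is a separate question; this only illustrates where the dependence comes from.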
