您可以编写自己的算法。我很快草拟了一些东西。它可以显着优化。
让我们做一些随机数据
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
n = 9
x = np.random.rand(n,2)
y = np.zeros((n,))
y[n//3:2*n//3] = 1
y[2*n//3::] = 2
plt.scatter(x[:,0], x[:,1], c=y)
plt.show()

现在让我们得到每个点的间距。
dists = np.asarray([np.linalg.norm(i-j) for i in x for j in x]).reshape(n,n)
plt.imshow(dists)
plt.show()

我们将制作一个列表,其中包含每个点的每个圆的半径。对于每个点,我们将通过它们的相对接近度来迭代其他点。如果尚未看到关联的标签,请将其添加到临时列表中。否则,我们结束函数并取最后距离和当前非法点的平均值。
radii = []
for row in dists:
labels = []
dis = []
for i in np.argsort(row):
if y[i] not in labels:
labels.append(y[i])
dis = row[i]
else:
dis = (dis + row[i])/2
break
radii.append(dis)
现在我们可以绘制这些圆圈
fig, ax = plt.subplots(figsize=(10,10))
for ix, i in enumerate(radii):
circle = plt.Circle((x[ix, 0], x[ix, 1]), i, color='b', fill=False, alpha = 0.5)
ax.add_artist(circle)
plt.scatter(x[:,0], x[:,1], c=y)
plt.xlim([-0.2,1.4])
plt.ylim([-0.2,1.4])
plt.show()

如果我们增加数据点的数量n我们得到
