使用 KMeans 聚类查找聚类直径和相关聚类点 (scikit learn)

数据挖掘 Python scikit-学习
2022-03-02 12:03:20

我已经使用 sklearn 使用 Kmeans 完成了聚类。虽然它有一种打印质心的方法,但我发现 scikit-learn 没有一种方法来找出簇直径(或者我到目前为止还没有看到它),这很奇怪。是否有一种巧妙的方法可以为每个集群以及与集群关联的点获取此信息?

我目前有这个相当笨拙的代码来做

import numpy as np
from sklearn.cluster import KMeans
from sklearn import datasets

iris = datasets.load_iris()
X = iris.data
y = iris.target

estimator = KMeans(n_clusters=3)
estimator.fit(X)
print({i: np.where(estimator.labels_ == i)[0] for i in range(estimator.n_clusters)}) #get the indices of points for each cluster
1个回答

目前,我已经准备了一个解决方案:

#iris example
iris = datasets.load_iris()
x = iris.data
y = iris.target


estimator = KMeans(n_clusters=3)
y_kmeans = estimator.fit_predict(x)

要获取集群的半径,您可以使用以下代码片段:

#empty dictionaries

clusters_centroids=dict()
clusters_radii= dict()

'''looping over clusters and calculate Euclidian distance of 
each point within that cluster from its centroid and 
pick the maximum which is the radius of that cluster'''

for cluster in list(set(y)):

    clusters_centroids[cluster]=list(zip(estimator.cluster_centers_[:, 0],estimator.cluster_centers_[:,1]))[cluster]
    clusters_radii[cluster] = max([np.linalg.norm(np.subtract(i,clusters_centroids[cluster])) for i in zip(x[y_kmeans == cluster, 0],x[y_kmeans == cluster, 1])])

它会给你这个: 在此处输入图像描述

请注意 K 均值:

  • 隐式假设所有集群具有相同的半径
  • 将数据分成 Voronoi 单元(也可以从这里看到)。
  • 聚类点(圆)可以重叠(这是它的定义方式)。

如果你想放松集群的形状(不是严格的球形或像 K-means 那样的圆形),你应该执行高斯混合模型

附录(重现上述可视化):

#Visualising the clusters and cluster circles

fig, ax = plt.subplots(1,figsize=(7,5))

plt.scatter(x[y_kmeans == 0, 0], x[y_kmeans == 0, 1], s = 100, c = 'red', label = 'Iris-setosa')
art = mpatches.Circle(clusters_centroids[0],clusters_radii[0], edgecolor='r',fill=False)
ax.add_patch(art)

plt.scatter(x[y_kmeans == 1, 0], x[y_kmeans == 1, 1], s = 100, c = 'blue', label = 'Iris-versicolour')
art = mpatches.Circle(clusters_centroids[1],clusters_radii[1], edgecolor='b',fill=False)
ax.add_patch(art)

plt.scatter(x[y_kmeans == 2, 0], x[y_kmeans == 2, 1], s = 100, c = 'green', label = 'Iris-virginica')
art = mpatches.Circle(clusters_centroids[2],clusters_radii[2], edgecolor='g',fill=False)
ax.add_patch(art)

#Plotting the centroids of the clusters
plt.scatter(estimator.cluster_centers_[:, 0], estimator.cluster_centers_[:,1], s = 100, c = 'yellow', label = 'Centroids')

plt.legend()
plt.tight_layout()
plt.savefig('kmeans.jpg',dpi=300)