有监督的聚类

数据挖掘 聚类 无监督学习 监督学习 半监督学习
2022-02-28 02:49:54

我正在研究一个聚类问题。我有一个训练集,由已知集群的点集组成,我想在测试数据集上找到好的集群。这是一种有监督的聚类。

我查找了有关监督聚类的文章,但没有找到很多信息。有“半监督聚类”,包括使用关于几个点的信息(必须链接或不链接关系),但在我的任务中,我没有这种信息。还有某种“度量学习监督聚类”,它使用标记的集群来估计一个度量,该度量将使用 k-means 生成给定的集群。这种技术可以帮助我,但没有太多关于它的文章,我想知道我是否没有找到好的关键字或其他东西。

使用标记数据(具有已知集群的训练点)对数据点进行聚类的技术/算法是什么?

4个回答

您要查找的内容称为KNN algorithm,也称为 k 近邻。这是一种有监督的算法,您可以在其中给出点及其集群,然后使用它们来学习测试点的模式。

这是分类,不是吗?

您已标记训练数据。您想相应地标记您的测试集。使用分类器...

KNN algorithm是一个解决方案,但它不能很好地扩展。KNN的成本O(n.log(n))如果您提出请求,请按要求n它变成了O(n2.log(n))并且根本不扩展。如果您的集群不是太大并且您没有太多的申请要求,它适用。您可以想象一些更快但通常不如ANN(近似最近邻)准确的东西,甚至更快但也更容易出错K-手段/模式/原型模型方法。

原理如下,如果你有你的集群,你可以使用每个集群的代表,称为原型,可以通过各种方式找到它。

在真实空间中,您可以使用平均值或中位数作为原型。对于二进制空间,请使用多数票。对于混合空间,例如均值/媒体和多数投票。

您还可以随机选取集群中的一个点,这是中心点方法,这种方法具有在任何度量空间上工作的优势。一旦每个集群都有一个原型,您只需要将未知点与其最近原型的集群相关联。

这些方法的复杂性下降到O(c)每个请求c是集群的数量,如果您没有数百万个集群,这种技术将很容易实现(取决于您的度量空间)并且工作速度很快。

老实说,直到现在我才听说过“半监督聚类”。那里有很多聚类技术。这里有 7 个流行的 tequines 聚类。我为您整理了一些示例代码(如下)。我让它尽可能自动化(只需复制/粘贴)。希望这会让你指出正确的方向。只需将您自己的数据输入X变量(确保它是一个数组)。

所以:X = df['A'].to_numpy()

from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
#%matplotlib inline
from sklearn import datasets#Iris Dataset
iris = datasets.load_iris()
X = iris.data#KMeans
km = KMeans(n_clusters=3)
km.fit(X)
km.predict(X)
labels = km.labels_#Plotting
fig = plt.figure(1, figsize=(7,7))
ax = Axes3D(fig, rect=[0, 0, 0.95, 1], elev=48, azim=134)
ax.scatter(X[:, 3], X[:, 0], X[:, 2],
          c=labels.astype(np.float), edgecolor="k", s=50)
ax.set_xlabel("Petal width")
ax.set_ylabel("Sepal length")
ax.set_zlabel("Petal length")
plt.title("K Means", fontsize=14)

在此处输入图像描述

########################################

from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
%matplotlib inline
from sklearn import datasets#Iris Dataset
iris = datasets.load_iris()
X = iris.data#Gaussian Mixture Model
gmm = GaussianMixture(n_components=3)
gmm.fit(X)
proba_lists = gmm.predict_proba(X)#Plotting
colored_arrays = np.matrix(proba_lists)
colored_tuples = [tuple(i.tolist()[0]) for i in colored_arrays]
fig = plt.figure(1, figsize=(7,7))
ax = Axes3D(fig, rect=[0, 0, 0.95, 1], elev=48, azim=134)
ax.scatter(X[:, 3], X[:, 0], X[:, 2],
          c=colored_tuples, edgecolor="k", s=50)
ax.set_xlabel("Petal width")
ax.set_ylabel("Sepal length")
ax.set_zlabel("Petal length")
plt.title("Gaussian Mixture Model", fontsize=14)

在此处输入图像描述

########################################

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.cluster as cluster
import time
#%matplotlib inline
sns.set_context('poster')
sns.set_color_codes()
plot_kwds = {'alpha' : 0.25, 's' : 80, 'linewidths':0}

data = X

plt.scatter(data.T[0], data.T[1], c='b', **plot_kwds)
frame = plt.gca()
frame.axes.get_xaxis().set_visible(False)
frame.axes.get_yaxis().set_visible(False)


def plot_clusters(data, algorithm, args, kwds):
    start_time = time.time()
    labels = algorithm(*args, **kwds).fit_predict(data)
    end_time = time.time()
    palette = sns.color_palette('deep', np.unique(labels).max() + 1)
    colors = [palette[x] if x >= 0 else (0.0, 0.0, 0.0) for x in labels]
    plt.scatter(data.T[0], data.T[1], c=colors, **plot_kwds)
    frame = plt.gca()
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)
    plt.title('Clusters found by {}'.format(str(algorithm.__name__)), fontsize=24)
    plt.text(-0.5, 0.7, 'Clustering took {:.2f} s'.format(end_time - start_time), fontsize=14)


plot_clusters(data, cluster.KMeans, (), {'n_clusters':5})

在此处输入图像描述

plot_clusters(data, cluster.AffinityPropagation, (), {'preference':-5.0, 'damping':0.95})

在此处输入图像描述

plot_clusters(data, cluster.MeanShift, (0.175,), {'cluster_all':False})

在此处输入图像描述

plot_clusters(data, cluster.SpectralClustering, (), {'n_clusters':6})

在此处输入图像描述

plot_clusters(data, cluster.AgglomerativeClustering, (), {'n_clusters':6, 'linkage':'ward'})

在此处输入图像描述

plot_clusters(data, cluster.DBSCAN, (), {'eps':0.025})

在此处输入图像描述