如何使用 Python 从数据集中的预定义集群中找到最远的数据点?

数据挖掘 scikit-学习 聚类
2022-02-19 08:29:22

我有一个数据集,其中某些行被标记为一个类(并被解释为不同的集群 #1),但其他点要么未标记,要么不明确。因此,我想通过按它们与集群 #1 的各自距离(更准确地说,从集群 #1 的最近点到各自的未标记点)对它们进行排序,来确定哪些未标记的数据点距离集群 #1 最远。

我的第一个想法是创建一个相似矩阵并计算每个未标记点的最近距离,但这似乎有点笨拙,有没有更优雅/有效的方法?

(我曾经将 sklearn 用于类似的任务,但据我所知,无监督聚类算法并未明确提供此类特定信息。)

1个回答

您想知道标记集群中未标记数据的最近邻居。使用sklearn,您可以NearestNeighbors()使用给定度量、算法(Ball-tree、KD-tree...)和所有其他参数(请参见此处)拟合一个类。

kneighbors()然后通过使用方法从未标记的数据点及其距离中获取标记的最近邻居。

这是一个示例代码:

import numpy as np
from sklearn.neighbors import NearestNeighbors

# Fake data
labeled_samples = [[0, 1.2], [0, 1.3], [0, 1.4]]
unlabeled_samples = [[0, 1.7], [0.5, 0.5], [1, 1]]

# Create your class with your labeled cluster
neigh = NearestNeighbors(n_neighbors=1)
neigh.fit(labeled_samples)

# get the distance/index to the nearest neighbor of you unlabeled data
distances, indexes = neigh.kneighbors(unlabeled_samples, 1, return_distance=True)

然后你只需要对结果进行排序。

注意:使用这种方法比计算所有标记数据点的所有距离然后对它们进行排序更优化。有关更多信息,请参阅此说明