KNeighborsClassifier 传递给用户定义的距离函数的值是错误的

数据挖掘 scikit-学习 距离 k-nn
2022-03-02 20:33:02

我有一个数据集,其中所有特征都是二进制的,每个数据点的类也是二进制的。我正在尝试将 KNearestClassifier 与用户定义的距离函数一起使用,如下所示:

KNN = KNeighborsClassifier(n_neighbors=3,
                           algorithm='ball_tree',
                           metric='pyfunc',
                           metric_params={"func": lev_metric})
x_train, x_test, y_train, y_test = train_test_split(df_sum,
                                                    y,
                                                    test_size=0.1,
                                                    random_state=0)
KNN.fit(x_train, y_train)

我的自定义指标函数如下:

def lev_metric(a, b):
    print(a)
    print(b)
    return levenshtein(a, b)

度量函数需要两个 ndarrays 的二进制值 0 和 1。knn.fit调用度量函数时,“b”看起来像预期的那样(例如[0 1 1 0 0 1 0 1 ...),但“a”看起来像乱码,并且是一个具有0到1之间实值元素的ndarray,例如:

[0.32222222 0.42222222 0.34444444 0.47777778 0.41111111 0.38888889
 0.4        0.31111111 0.35555556 0.35555556 0.42222222 0.46666667
 0.36666667 0.32222222 0.41111111 0.32222222 0.36666667 0.35555556
 0.41111111 0.33333333 0.4        0.42222222 0.3        0.37777778
 0.38888889 0.48888889 0.41111111 0.43333333 0.34444444 0.35555556
 0.43333333 0.38888889 0.43333333 0.32222222 0.47777778 0.34444444...

我错过了什么?我还检查了“x_train”是否正确。另外,knn 不是基于实例的学习器吗?为什么它无论如何都要调用距离函数?不应该只记住训练示例吗?谢谢。

1个回答

基于邻居的方法被称为非泛化机器学习方法,因为它们只是“记住”其所有训练数据(可能转换为快速索引结构,例如 Ball Tree 或 KD Tree)

(强调我的。来源:https ://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors )

您正在使用algorithm='ball_tree',粗略地说,它使用欧几里得球对点进行聚类,利用球距离作为实际数据点之间距离的界限(wiki)。所以,我怀疑你看到的“a”值实际上是球的中心之一。

对于 Levenshtein 距离,您可能应该使用该'brute'算法。