为什么 GridSearchCV (sklearn) 会改变 n_samples 的值?

数据挖掘 Python scikit-学习
2022-02-25 12:19:29

我认为n_samples是训练示例的数量。但是当使用 GridSearchCV 时,n_samples 变为 32 而不是 50。

使用 GridSearchCV 时出错:

预期 n_neighbors <= n_samples,但 n_samples = 32,n_neighbors = 50

训练示例:

print(X_train.shape[0])=> 50

print(len(y_train))=> 50

这有效:

neigh = KNeighborsClassifier(n_neighbors=50)
neigh.fit(X_train, y_train) 
result = neigh.predict(X_test)

这失败了:

from sklearn.model_selection import GridSearchCV

grid_params = { 
    "n_neighbors" : [50]
}

g = GridSearchCV(KNeighborsClassifier(), grid_params)
g.fit(X_train, y_train)

我很困惑为什么n_samples在使用 GridSearchCV 时会变成 32。

1个回答

CV 代表 CrossValidation,这意味着它将您的训练集拆分为多个折叠(在本例中为 3),在其中的 n-1个折叠上进行训练并在剩余的折叠上进行测试。这就是为什么您的训练现在是在 32 个而不是 50 个样本上完成的。交叉验证对于估计模型(包括特定的超参数)在未见数据上的表现非常有用。