如何在验证数据集的每一折中打印准确性?并为数据框中的每一行分配折叠号?

数据挖掘 网格搜索
2022-02-14 08:23:59

如何打印验证数据集每一折的准确性?以及如何将折叠号分配给数据框中的每一行?

classifier = RandomForestRegressor(n_jobs = -1, criterion='mse')
clf = GridSearchCV(classifier, param_grid = tunedParameters,cv=10)
all_accuracies = cross_val_score(classifier, X=X_train, y=y_train, cv=10)
print(all_accuracies) 
1个回答

你在混淆GridSearchCVcross_val_score; 你应该只需要运行其中一个。

GridSearchCV将在您的超参数空间中搜索每个组合,使用交叉验证并产生分数。您可以通过属性访问这些分数cv_results_

cross_val_score没有超参数搜索;它只是使用交叉验证得分。输出是单个折叠分数的列表。

如果您已经使用过GridSearchCV,则可能没有理由使用cross_val_score. (在超参数搜索之后,您已经看到并使用了该集合中的所有数据,因此 中的分数和 中的分数cv_results_都有偏差cross_val_score;如果您需要对性能进行无偏估计,则需要另一个测试集(或首先是嵌套交叉验证)。)

如果您想跟踪哪些样本进入哪个折叠,我认为您需要使用交叉验证生成器或可迭代的cv参数而不是整数。然后您可以使用该生成器/迭代器来告诉您哪些样本在哪个折叠中。