Sklearn StratifiedKFold代码解释

数据挖掘 scikit-学习 采样 卡格尔
2022-02-07 16:53:30

在浏览以下博客时,我遇到了以下代码片段

from sklearn.cross_validation import StratifiedKFold
eval_size = 0.10
kf = StratifiedKFold(y,round(1./eval_size))
train_indices, valid_indices = next(iter(kf))
X_train, y_train = X[train_indices], y[train_indices]
X_valid, y_valid = X[valid_indices], y[valid_indices]

我无法理解它是如何工作的。谁能帮我解释一下?谢谢

1个回答

KFold 拆分将获取数据并拆分您指定的多次。StratifiedKFold 用于确保您的训练和验证数据集均包含相同百分比的类(有关更多信息,请参阅 sklearn 文档)。函数 StratifiedKFold 有两个参数,标签数组(对于二进制分类,这将是 1 和 0 的数组)和折叠数。他们将折叠数指定为 1./eval_size,其中 eval_size = 0.10。所以这是一个 10 倍的验证。

train_indices, valid_indices = next(iter(kf))

该行导出索引以便将数据拆分为训练/验证数据集。

现在我们有了索引,我们使用这些索引来实际分割数据。

有趣的是,他们编写了 next(iter(kf)),然后在可以使用 sklearn.cross_validation.train_test_split 时将索引输入到数据集中。train_test_split 只是 next(iter(kf)) 的一个包装器,但它更具可读性,并且它已经是 sklearn 中的一个函数。

希望这可以帮助!