K-fold cross-validation

Data mining, Python
2022-02-15 10:13:56

I want to use KFold from sklearn.model_selection instead of the old cross_validation module, but my loop does not work with the KFold object:

    import numpy as np
    from sklearn.model_selection import KFold
    import xgboost as xgb
    # Some useful parameters which will come in handy later on
    ntrain = X_train.shape[0]
    ntest = X_test.shape[0]
    SEED = 123 # for reproducibility
    NFOLDS = 10 # set folds for out-of-fold prediction
    kf = KFold(shuffle=False, n_splits= NFOLDS, random_state=SEED)
    def get_oof(clf, x_train, y_train, x_test):
        oof_train = np.zeros((ntrain,))
        oof_test = np.zeros((ntest,))
        oof_test_skf = np.empty((NFOLDS, ntest))

        for i, (train_index, test_index) in enumerate(kf):
            x_tr = x_train[train_index]
            y_tr = y_train[train_index]
            x_te = x_train[test_index]

            clf.train(x_tr, y_tr)

            oof_train[test_index] = clf.predict(x_te)
            oof_test_skf[i, :] = clf.predict(x_test)

        oof_test[:] = oof_test_skf.mean(axis=0)
        return oof_train.reshape(-1, 1), oof_test.reshape(-1, 1)
    xgb_oof_train, xgb_oof_test = get_oof(xgb,x_train, y_train, x_test)

I get this error: TypeError: 'KFold' object is not iterable

1 Answer

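Pass the result of the split method to enumerate instead of kf itself. In sklearn.model_selection the KFold object is only a splitter and is not iterable on its own; calling kf.split(X) yields the (train_index, test_index) pairs. For example: for i, (train_index, test_index) in enumerate(kf.split(X)): (inside get_oof, X would be x_train).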
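As a rough illustration of that fix, here is a minimal sketch of get_oof rewritten around kf.split. A few things in it are assumptions rather than part of the question: the estimator is called through fit/predict (the scikit-learn convention) instead of the question's clf.train, a plain LinearRegression stands in for an xgboost model so the snippet runs without extra dependencies, the data is synthetic, and shuffle=True is set because recent scikit-learn versions, as far as I know, reject a random_state when shuffle=False.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold

SEED = 123   # for reproducibility
NFOLDS = 5   # number of folds for out-of-fold prediction


def get_oof(clf, x_train, y_train, x_test, n_splits=NFOLDS, seed=SEED):
    """Return out-of-fold predictions for x_train and fold-averaged predictions for x_test."""
    # Assumption: shuffle=True so that passing random_state is valid.
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    oof_train = np.zeros(x_train.shape[0])
    oof_test_skf = np.empty((n_splits, x_test.shape[0]))

    # kf.split(x_train) yields (train_index, test_index) pairs of integer arrays.
    for i, (train_index, test_index) in enumerate(kf.split(x_train)):
        # Assumption: clf follows the scikit-learn fit/predict interface.
        clf.fit(x_train[train_index], y_train[train_index])
        oof_train[test_index] = clf.predict(x_train[test_index])
        oof_test_skf[i, :] = clf.predict(x_test)

    # Average the per-fold test predictions.
    oof_test = oof_test_skf.mean(axis=0)
    return oof_train.reshape(-1, 1), oof_test.reshape(-1, 1)


# Synthetic data, purely for illustration.
rng = np.random.default_rng(SEED)
x_train = rng.normal(size=(100, 3))
y_train = x_train @ np.array([1.0, -2.0, 0.5])
x_test = rng.normal(size=(20, 3))

oof_train, oof_test = get_oof(LinearRegression(), x_train, y_train, x_test)
print(oof_train.shape, oof_test.shape)
```

The indexing x_train[train_index] assumes NumPy arrays; with pandas DataFrames you would likely need .iloc[train_index] instead. An xgboost estimator that exposes fit and predict could presumably be passed in place of LinearRegression.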

Hope this helps!