Keras 中的交叉验证

数据挖掘 机器学习 Python 深度学习 喀拉斯 交叉验证
2021-09-21 08:01:58

假设我想在 Keras 中训练和测试 MNIST 数据集。

所需数据可按如下方式加载:

from keras.datasets import mnist

digits_data = mnist.load_data()

keras 有没有办法将这些数据分成三组,即:training_data、、test_datacross_validation_data

2个回答

从 Keras 文档中,您可以将数据加载到训练和测试集中,如下所示:

(X_train, y_train), (X_test, y_test) = mnist.load_data()

至于交叉验证,您可以从此处遵循此示例。

from sklearn.model_selection import StratifiedKFold

def load_data():
    # load your data using this function

def create model():
    # create your model using this function

def train_and_evaluate__model(model, data_train, labels_train, data_test, labels_test):
    model.fit...
    # fit and evaluate here.

if __name__ == "__main__":
    n_folds = 10
    data, labels, header_info = load_data()
    skf = StratifiedKFold(labels, n_folds=n_folds, shuffle=True)

    for i, (train, test) in enumerate(skf):
        print "Running Fold", i+1, "/", n_folds
        model = None # Clearing the NN.
        model = create_model()
        train_and_evaluate_model(model, data[train], labels[train], data[test], labels[test])

不在喀拉斯。我通常只使用sklearn的train_test_split功能:

from sklearn.model_selection import train_test_split

train, test = train_test_split(data, train_size=0.8)

Keras也有sklearn 包装器,以后可能会有用。