为什么 Pytorch 仅对批处理是不确定的?

数据挖掘 Python 火炬
2022-03-16 19:22:37

我正在 CPU 上训练 LSTM 网络,并且在不使用数据加载器时可以获得确定性结果。但是当我使用 Pytorchs 数据加载器时,我得到了非确定性的训练错误结果,尽管从数据加载器加载的实际批次是确定性的。

我在这里种下了几乎所有我能想到的种子

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  
torch.backends.cudnn.enabled   = False

这两段代码是:

optimiser = torch.optim.Adam(model_test.parameters(), learning_rate)
set_seed(42)

for t in range(num_epochs):

    for batch_idx, (X_train, y_train) in enumerate(train_loader):

        # Zero out gradients
        optimiser.zero_grad()

        # Forward pass
        y_pred = model_test(X_train)

        # Loss Function
        loss = loss_fn(y_pred, y_train)

        # Backward pass
        loss.backward()

        # Update parameters
        optimiser.step()

    if t % 100 == 0: print("Epoch ", t, "MSE: ", loss.item())

optimiser = torch.optim.Adam(model_test.parameters(), learning_rate)
set_seed(42)

for t in range(num_epochs):

    # Zero out gradients
    optimiser.zero_grad()

    # Forward pass
    y_pred = model_test(X_train)

    # Loss Function
    loss = loss_fn(y_pred, y_train)

    # Backward pass
    loss.backward()

    # Update parameters
    optimiser.step()

    if t % 100 == 0: print("Epoch ", t, "MSE: ", loss.item())

我在 Github 上看到一些帖子谈到 GPU 上存在确定性问题,但这只是在 CPU 上。

2个回答

LSTM 的 cudnn 实现存在确定性问题,这些问题似乎在 7.6.1 版本中得到修复。检查你的 cudnn 版本。

https://github.com/pytorch/pytorch/issues/18110

通过在 optimiser.zero_grad() 之后添加 set_seed(42) 可以使其具有确定性。不确定在 optimiser.zero_grad() 中会发生什么来弄乱每批的播种。