使用 Jax 的粗粒分子动力学中的力匹配 - 忽略能量损失时力不匹配

数据挖掘 机器学习 神经网络 收敛 图神经网络
2022-03-10 21:30:30

我目前正在探索分子动力学模拟的力匹配方法。由于我还处于探索状态,所以我尝试过调查

强制匹配神经网络 Colab Notebook

对应于 揭示玻璃态系统中静态结构的预测能力

他们训练一个图神经网络来匹配以估计来自位置的力。

因此,他们计算匹配能量和力的损失。

损失=(energypredictedenergytarget)2+(ForcespredictedForcestarget)2

其中能量定义为U(x,ϕ)并且力定义为dUdx=F.

当忽略能量损失和纯匹配力时,预测似乎收敛到 0。这并不直观,因为带有图神经网络的粗粒度分子动力学似乎只使用力匹配来训练他们的 GNN。问题是:有谁知道神经网络为什么会这样。

重现我的观察:

将1中的损失函数更改为:

@jit
def loss(params, R, targets):
  return force_loss(params, R, targets[1]) 

和培训:

train_epochs = 20

opt_state = opt.init(params)

train_energy_error = []
test_energy_error = []

for iteration in range(train_epochs):
  train_energy_error += [float(np.sqrt(force_loss(params, batch_Rs[0], batch_Fs[0])))]
  test_energy_error += [float(np.sqrt(force_loss(params, test_positions, test_forces)))]
 
  draw_training(params)

  params, opt_state = update_epoch((params, opt_state), 
                                   (batch_Rs, (batch_Es, batch_Fs)))

  onp.random.shuffle(lookup)
  batch_Rs, batch_Es, batch_Fs = make_batches(lookup)

以及培训的可视化:

 def draw_training(params):
  display.clear_output(wait=True)
  display.display(plt.gcf())
  plt.subplot(1, 2, 1)
  plt.semilogy(train_energy_error)
  plt.semilogy(test_energy_error)
  plt.xlim([0, train_epochs])
  format_plot('Epoch', '$L$')
  plt.subplot(1, 2, 2)
  predicted = vmap(force_fn, (None, 0))(params, example_positions).reshape((-1,))
  plt.plot(example_forces.reshape((-1,)), predicted, 'o')
  #plt.plot(np.linspace(-400, -300, 10), np.linspace(-400, -300, 10), '--')
  format_plot('$E_{label}$', '$E_{prediction}$')
  finalize_plot((2, 1))
  plt.show()

重现我的观察。

0个回答
没有发现任何回复~