Keras 中 vanilla DQN 损失的优化实现

数据挖掘 喀拉斯 强化学习 执行
2021-10-10 21:32:23

我已经为 keras 中的连续/非图像(无 CNN)状态实现了 vanilla DQN。但是,我不确定我的损失计算实现是否是最优的。

提醒一下,损失定义为: loss=(r+γmaxaQ(s,a)Q(s,a))212

这是我对网络+损失函数的实现:

self.network = Sequential()
self.network.add(Dense(256, activation='relu', input_dim=input_dim))
self.network.add(Dense(32, activation='relu'))
self.network.add(Dense(output_dim))

def q_loss(data, y_pred):
    # Extract the concatenated data tensor
    action, reward, next_state, done = K.cast(data[:, 0], 'int32'), data[:, 1], data[:, 2:-1], K.cast(data[:, -1], 'bool')
    # Compute Q(s,a)
    mask = tf.one_hot(action, depth=y_pred.shape[1], dtype=tf.bool, on_value=True, off_value=False)
    q_action = tf.boolean_mask(y_pred, mask)
    # Compute the max of values at next state except if done=True
    max_q_next = K.max(self.network(next_state), axis=1) * K.cast(tf.logical_not(done), 'float32')
    # Compute the TD-error, do not propagate the gradient into the next state value
    td_error = reward + 0.95 * K.stop_gradient(max_q_next) - q_action
    # Compute the MSE
    loss = K.square(td_error) / 2
    return loss

self.network.compile(loss=q_loss, optimizer=RMSprop(lr=self.learning_rate))

这是我的火车功能:

def train(self):
    # Sample a batch (a tuple of narray) from the replay buffer
    # States (B*S), actions (B), rewards (B), next_states(B*S), dones (B)
    # B=batch_size and S=state_size
    states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
    # Concatenate actions, rewards, next_states, dones together because keras loss only accept one tensor
    data = np.concatenate([np.expand_dims(actions, axis=1), np.expand_dims(rewards, axis=1), next_states, np.expand_dims(dones, axis=1)], axis=1)
    # Train on a batch
    self.network.train_on_batch(states, data)

我发现我计算 DQN td-error 和 loss 的方式很丑陋,而且可能不是最优的。你有更好的解决方法(也许结合 Keras 和 Tensorflow)?

我已经检查了多个现有的 Keras 实现(如link),但遗憾的是,它们主要在完整的 python/numpy 中计算 keras 之外的 td-error,witch 是 IMO 次优的。

0个回答
没有发现任何回复~