class AtariA2C(nn.Module):
def __init__(self, input_shape, n_actions):
super(AtariA2C, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(),
)
conv_output_size = self. _get_conv_out(input_shape)
self.policy = nn.Sequential(
nn.Linear(conv_output_size, 512),
nn.ReLU(),
nn.Linear(512, n_actions),
)
self.value = nn.Sequential(
nn.Linear(conv_output_size, 512),
nn.ReLU(),
nn.Linear(512, 1),
)
def _get_conv_out(self, shape):
o = self.conv(T.zeros(1, *shape))
return int(np.prod(o.shape))
def forward(self, x):
x = x.float() / 256
conv_out = self.conv(x).view(x.size()[0], -1)
return self.policy(conv_out), self.value(conv_out)
在 Maxim Lapan 的书中Deep Reinforcement Learning Hands-on,在实现上述网络模型后,它说
通过网络的前向传递返回两个张量的元组:策略和值。现在我们有一个大而重要的函数,它接受一批环境转换并返回三个张量:一批状态、一批采取的行动和一批使用公式计算的 Q 值
这个 Q_value 将用于两个地方:计算均方误差 (MSE) 损失以改进值近似,与 DQN 相同,以及计算动作的优势。
我对一件事情感到非常困惑。我们如何以及为什么计算均方误差损失以改进Advantage Actor-Critic 算法中的值逼近?