我正在尝试通过使用状态和动作表(即没有神经网络)来学习表格 Q 学习。我在FrozenLake环境中进行了尝试。这是一个非常简单的环境,其中的任务是G
从源头开始,S
避免漏洞H
,并遵循冻结路径,即F
. 这FrozenLake 网格看起来像这样
SFFF
FHFH
FFFH
HFFG
我正在使用滑滑版本,其中代理,如果它采取一步,有相同的概率沿着它打算的方向前进或垂直于原始方向横向滑动(如果该位置在网格中)。空洞是终态,目标是终态。
现在我首先尝试了值迭代,它收敛到以下状态值集
[0.0688909 0.06141457 0.07440976 0.05580732 0.09185454 0. 0.11220821 0. 0.14543635 0.24749695 0.29961759 0. 0. 0.3799359 0.63902015 0. ]
我还对策略迭代进行了编码,它也给了我相同的结果。所以我很有信心这个价值函数是正确的。
现在,我尝试编写 Q 学习算法,这是我的 Q 学习算法代码
def get_action(Q_table, state, epsilon):
"""
Uses e-greedy to policy to return an action corresponding to state
Args:
Q_table: numpy array containing the q values
state: current state
epsilon: value of epsilon in epsilon greedy strategy
env: OpenAI gym environment
"""
return env.action_space.sample() if np.random.random() < epsilon else np.argmax(Q_table[state])
def tabular_Q_learning(env):
"""
Returns the optimal policy by using tabular Q learning
Args:
env: OpenAI gym environment
Returns:
(policy, Q function, V function)
"""
# initialize the Q table
#
# Implementation detail:
# A numpy array of |x| * |a| values
Q_table = np.zeros((env.nS, env.nA))
# hyperparameters
epsilon = 0.9
episodes = 500000
lr = 0.81
for _ in tqdm_notebook(range(episodes)):
# initialize the state
state = env.reset()
if episodes / 1000 > 21:
epsilon = 0.1
t = 0
while True: # for each step of the episode
# env.render()
# print(observation)
# choose a from s using policy derived from Q
action = get_action(Q_table, state, epsilon)
# take action a, observe r, s_dash
s_dash, r, done, info = env.step(action)
# Q table update
Q_table[state][action] += lr * (r + gamma * np.max(Q_table[s_dash]) - Q_table[state][action])
state = s_dash
t += 1
if done:
# print("Episode finished after {} timesteps".format(t+1))
break
# print(Q_table)
policy = np.argmax(Q_table, axis=1)
V = np.max(Q_table, axis=1)
return policy, Q_table, V
我尝试运行它,它收敛到一组不同的值,如下[0.26426802 0.03656142 0.12557195 0.03075882 0.35018374 0. 0.02584052 0. 0.37657211 0.59209091 0.15439031 0. 0. 0.60367728 0.79768863 0. ]
我不明白,出了什么问题。Q 学习的实现非常简单。我检查了我的代码,似乎是正确的。
任何指针都会有所帮助。