视频源:

DQN 算法更新 (Tensorflow)

作者: UnityTutorial 编辑: UnityTutorial 2017-02-26

学习资料:

要点

Deep Q Network 的简称叫 DQN, 是将 Q learning 的优势 和 Neural networks 结合了. 如果我们使用 tabular Q learning, 对于每一个 state, action 我们都需要存放在一张 q_table 的表中. 如果像显示生活中, 情况可就比那个迷宫的状况复杂多了, 我们有千千万万个 state, 如果将这千万个 state 的值都放在表中, 受限于我们计算机硬件, 这样从表中获取数据, 更新数据是没有效率的. 这就是 DQN 产生的原因了. 我们可以使用神经网络来 估算 这个 state 的值, 这样就不需要一张表了.

这次的教程我们还是基于熟悉的 迷宫 环境, 重点在实现 DQN 算法, 之后我们再拿着做好的 DQN 算法去跑其他更有意思的环境.

算法

DQN 算法更新 (Tensorflow)

整个算法乍看起来很复杂, 不过我们拆分一下, 就变简单了. 也就是个 Q learning 主框架上加了些装饰.

这些装饰包括:

  • 记忆库 (用于重复学习)
  • 神经网络计算 Q 值
  • 暂时冻结 q_target 参数 (切断相关性)

算法的代码形式

接下来我们对应上面的算法, 来实现主循环. 首先 import 所需模块.

from maze_env import Maze
from RL_brain import DeepQNetwork

下面的代码, 就是 DQN 于环境交互最重要的部分.

def run_maze():
    step = 0    # 用来控制什么时候学习
    for episode in range(300):
        # 初始化环境
        observation = env.reset()

        while True:
            # 刷新环境
            env.render()

            # DQN 根据观测值选择行为
            action = RL.choose_action(observation)

            # 环境根据行为给出下一个 state, reward, 是否终止
            observation_, reward, done = env.step(action)

            # DQN 存储记忆
            RL.store_transition(observation, action, reward, observation_)

            # 控制学习起始时间和频率 (先累积一些记忆再开始学习)
            if (step > 200) and (step % 5 == 0):
                RL.learn()

            # 将下一个 state_ 变为 下次循环的 state
            observation = observation_

            # 如果终止, 就跳出循环
            if done:
                break
            step += 1   # 总步数

    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = DeepQNetwork(env.n_actions, env.n_features,
                      learning_rate=0.01,
                      reward_decay=0.9,
                      e_greedy=0.9,
                      replace_target_iter=200,  # 每 200 步替换一次 target_net 的参数
                      memory_size=2000, # 记忆上限
                      # output_graph=True   # 是否输出 tensorboard 文件
                      )
    env.after(100, run_maze)
    env.mainloop()
    RL.plot_cost()  # 观看神经网络的误差曲线

下一节我们会来讲解 DeepQNetwork 这种算法具体要怎么编.

如果想一次性看到全部代码, 请去我的 Github

分享到: Facebook 微博 微信 Twitter
如果你觉得这篇文章或视频对你的学习很有帮助, 请你也分享它, 让它能再次帮助到更多的需要学习的人. UnityTutorial没有正式的经济来源, 如果你也想支持 UnityTutorial 并看到更好的教学内容, 赞助他一点点, 作为鼓励他继续开源的动力.

支持 让教学变得更优秀