Q-learning算法实现

来源:互联网 发布:身份证合成软件 编辑:程序博客网 时间:2024/05/18 04:40

1、算法: 
这里写图片描述 
整个算法就是一直不断更新 Q table 里的值, 然后再根据新的值来判断要在某个 state 采取怎样的 action. Qlearning 是一个 off-policy 的算法, 因为里面的 max action 让 Q table 的更新可以不基于正在经历的经验(可以是现在学习着很久以前的经验,甚至是学习他人的经验). 不过这一次的例子, 我们没有运用到 off-policy, 而是把 Qlearning 用在了 on-policy 上, 也就是现学现卖, 将现在经历的直接当场学习并运用. On-policy 和 off-policy 的差别我们会在之后的 Deep Q network (off-policy) 学习中见识到. 而之后的教程也会讲到一个 on-policy (Sarsa) 的形式, 我们之后再对比. 
2、代码实现: 
maze_env :环境模块, maze_env 模块我们可以不深入研究, 可以去看看如何使用 python 自带的简单 GUI 模块 tkinter 来编写虚拟环境. 
RL_brain: 这个模块是 Reinforment Learning 的大脑部分。

from maze_env import Mazefrom RL_brain import QLearningTable`
  • 1
  • 2

算法主要部分:

def update():    # 学习 100 回合    for episode in range(100):        # 初始化 state 的观测值        observation = env.reset()        while True:            # 更新可视化环境            env.render()            # RL 大脑根据 state 的观测值挑选 action            action = RL.choose_action(str(observation))            # 探索者在环境中实施这个 action, 并得到环境返回的下一个 state 观测值, reward 和 done (是否是掉下地狱或者升上天堂)            observation_, reward, done = env.step(action)            # RL 从这个序列 (state, action, reward, state_) 中学习            RL.learn(str(observation), action, reward, str(observation_))            # 将下一个 state 的值传到下一次循环            observation = observation_            # 如果掉下地狱或者升上天堂, 这回合就结束了            if done:                break    # 结束游戏并关闭窗口    print('game over')    env.destroy()if __name__ == "__main__":    # 定义环境 env 和 RL 方式    env = Maze()    RL = QLearningTable(actions=list(range(env.n_actions)))    # 开始可视化环境 env    env.after(100, update)    env.mainloop()3、QLearningTable:3.1、主结构
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
class QLearningTable:    # 初始化    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):    # 选行为    def choose_action(self, observation):    # 学习更新参数    def learn(self, s, a, r, s_):    # 检测 state 是否存在    def check_state_exist(self, state):
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

3.2、预设值:

import numpy as npimport pandas as pdclass QLearningTable:    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        self.actions = actions  # a list        self.lr = learning_rate # 学习率        self.gamma = reward_decay   # 奖励衰减        self.epsilon = e_greedy     # 贪婪度        self.q_table = pd.DataFrame(columns=self.actions)   # 初始 q_table
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

3.3、决定行为: 
这里是定义如何根据所在的 state, 或者是在这个 state 上的 观测值 (observation) 来决策.

def choose_action(self, observation):        self.check_state_exist(observation) # 检测本 state 是否在 q_table 中存在(见后面标题内容)        # 选择 action        if np.random.uniform() < self.epsilon:  # 选择 Q value 最高的 action            state_action = self.q_table.ix[observation, :]            # 同一个 state, 可能会有多个相同的 Q action value, 所以我们乱序一下            state_action = state_action.reindex(np.random.permutation(state_action.index))            action = state_action.argmax()        else:   # 随机选择 action            action = np.random.choice(self.actions)        return action
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

3.4、学习: 
根据是否是 terminal state (回合终止符) 来判断应该如何更行 q_table. 更新的方式是不是很熟悉呢:

update = self.lr * (q_target - q_predict)

这可以理解成神经网络中的更新方式, 学习率 * (真实值 - 预测值). 将判断误差传递回去, 有着和神经网络更新的异曲同工之处.

def learn(self, s, a, r, s_):        self.check_state_exist(s_)  # 检测 q_table 中是否存在 s_ (见后面标题内容)        q_predict = self.q_table.ix[s, a]        if s_ != 'terminal':            q_target = r + self.gamma * self.q_table.ix[s_, :].max()  # 下个 state 不是 终止符        else:            q_target = r  # 下个 state 是终止符        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # 更新对应的 state-action 值
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

3.5、检测state是否存在: 
这个功能就是检测 q_table 中有没有当前 state 的步骤了, 如果还没有当前 state, 那我我们就插入一组全 0 数据, 当做这个 state 的所有 action 初始 values.

 def check_state_exist(self, state):        if state not in self.q_table.index:            # append new state to q table            self.q_table = self.q_table.append(                pd.Series(                    [0]*len(self.actions),                    index=self.q_table.columns,                    name=state,                )            )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

4、附加环境env:

"""import numpy as npnp.random.seed(1)import tkinter as tkimport timeUNIT = 40   # pixelsMAZE_H = 4  # grid heightMAZE_W = 4  # grid widthclass Maze(tk.Tk, object):    def __init__(self):        super(Maze, self).__init__()        self.action_space = ['u', 'd', 'l', 'r']        self.n_actions = len(self.action_space)        self.title('maze')        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))        self._build_maze()    def _build_maze(self):        self.canvas = tk.Canvas(self, bg='white',                           height=MAZE_H * UNIT,                           width=MAZE_W * UNIT)        # create grids        for c in range(0, MAZE_W * UNIT, UNIT):            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT            self.canvas.create_line(x0, y0, x1, y1)        for r in range(0, MAZE_H * UNIT, UNIT):            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r            self.canvas.create_line(x0, y0, x1, y1)        # create origin        origin = np.array([20, 20])        # hell        hell1_center = origin + np.array([UNIT * 2, UNIT])        self.hell1 = self.canvas.create_rectangle(            hell1_center[0] - 15, hell1_center[1] - 15,            hell1_center[0] + 15, hell1_center[1] + 15,            fill='black')        # hell        hell2_center = origin + np.array([UNIT, UNIT * 2])        self.hell2 = self.canvas.create_rectangle(            hell2_center[0] - 15, hell2_center[1] - 15,            hell2_center[0] + 15, hell2_center[1] + 15,            fill='black')        # create oval        oval_center = origin + UNIT * 2        self.oval = self.canvas.create_oval(            oval_center[0] - 15, oval_center[1] - 15,            oval_center[0] + 15, oval_center[1] + 15,            fill='yellow')        # create red rect        self.rect = self.canvas.create_rectangle(            origin[0] - 15, origin[1] - 15,            origin[0] + 15, origin[1] + 15,            fill='red')        # pack all        self.canvas.pack()    def reset(self):        self.update()        time.sleep(0.5)        self.canvas.delete(self.rect)        origin = np.array([20, 20])        self.rect = self.canvas.create_rectangle(            origin[0] - 15, origin[1] - 15,            origin[0] + 15, origin[1] + 15,            fill='red')        # return observation        return self.canvas.coords(self.rect)    def step(self, action):        s = self.canvas.coords(self.rect)        base_action = np.array([0, 0])        if action == 0:   # up            if s[1] > UNIT:                base_action[1] -= UNIT        elif action == 1:   # down            if s[1] < (MAZE_H - 1) * UNIT:                base_action[1] += UNIT        elif action == 2:   # right            if s[0] < (MAZE_W - 1) * UNIT:                base_action[0] += UNIT        elif action == 3:   # left            if s[0] > UNIT:                base_action[0] -= UNIT        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent        s_ = self.canvas.coords(self.rect)  # next state        # reward function        if s_ == self.canvas.coords(self.oval):            reward = 1            done = True        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:            reward = -1            done = True        else:            reward = 0            done = False        return s_, reward, done    def render(self):        time.sleep(0.1)        self.update()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
原创粉丝点击