Sarsa

来源:互联网 发布:java单机游戏免费下载 编辑:程序博客网 时间:2024/06/03 20:20

1、算法:
整个算法还是一直不断更新 Q table 里的值, 然后再根据新的值来判断要在某个 state 采取怎样的 action. 不过于 Qlearning 不同之处:
Sarsa在当前 state 已经想好了 state 对应的 action, 而且想好了 下一个 state_ 和下一个 action_ (Qlearning 还没有想好下一个 action_)
更新 Q(s,a) 的时候基于的是下一个 Q(s_, a_) (Qlearning 是基于 maxQ(s_)),这种不同之处使得 Sarsa 相对于 Qlearning, 更加的胆小. 因为 Qlearning 永远都是想着 maxQ 最大化, 因为这个 maxQ 而变得贪婪, 不考虑其他非 maxQ 的结果. 我们可以理解成 Qlearning 是一种贪婪, 大胆, 勇敢的算法. 而 Sarsa 是一种保守的算法, 他在乎每一步决策, 对于错误和死亡比较铭感. 这一点我们会在可视化的部分看出他们的不同. 两种算法都有他们的好处, 比如在实际中, 你比较在乎机器的损害, 用一种保守的算法, 在训练时就能减少损坏的次数.
这里写图片描述

2、代码实现:
maze_env: 环境模块,;
RL_brain:RL 的大脑部分

from maze_env import Mazefrom RL_brain import SarsaTable

2.1、迭代部分:

def update():    for episode in range(100):        # 初始化环境        observation = env.reset()        # Sarsa 根据 state 观测选择行为        action = RL.choose_action(str(observation))        while True:            # 刷新环境            env.render()            # 在环境中采取行为, 获得下一个 state_ (obervation_), reward, 和是否终止            observation_, reward, done = env.step(action)            # 根据下一个 state (obervation_) 选取下一个 action_            action_ = RL.choose_action(str(observation_))            # 从 (s, a, r, s, a) 中学习, 更新 Q_tabel 的参数 ==> Sarsa            RL.learn(str(observation), action, reward, str(observation_), action_)            # 将下一个当成下一步的 state (observation) and action            observation = observation_            action = action_            # 终止时跳出循环            if done:                break    # 大循环完毕    print('game over')    env.destroy()if __name__ == "__main__":    env = Maze()    RL = SarsaTable(actions=list(range(env.n_actions)))    env.after(100, update)    env.mainloop()

2.2、主结构(1):

class SarsaTable:    # 初始化 (与之前一样)    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):    # 选行为 (与之前一样)    def choose_action(self, observation):    # 学习更新参数 (有改变)    def learn(self, s, a, r, s_):    # 检测 state 是否存在 (与之前一样)    def check_state_exist(self, state):

主结构(2):继承的思想:
2.2.1、父类:

import numpy as npimport pandas as pdclass RL(object):    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        ... # 和 QLearningTable 中的代码一样    def check_state_exist(self, state):        ... # 和 QLearningTable 中的代码一样    def choose_action(self, observation):        ... # 和 QLearningTable 中的代码一样    def learn(self, *args):        pass # 每种的都有点不同, 所以用 pass

2.2.2、Q-Learning子类:

class QLearningTable(RL):   # 继承了父类 RL    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    # 表示继承关系    def learn(self, s, a, r, s_):   # learn 的方法在每种类型中有不一样, 需重新定义        self.check_state_exist(s_)        q_predict = self.q_table.ix[s, a]        if s_ != 'terminal':            q_target = r + self.gamma * self.q_table.ix[s_, :].max()        else:            q_target = r        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)

2.2.3、Sarsa子类:

class SarsaTable(RL):   # 继承 RL class    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    # 表示继承关系    def learn(self, s, a, r, s_, a_):        self.check_state_exist(s_)        q_predict = self.q_table.ix[s, a]        if s_ != 'terminal':            q_target = r + self.gamma * self.q_table.ix[s_, a_]  # q_target 基于选好的 a_ 而不是 Q(s_) 的最大值        else:            q_target = r  # 如果 s_ 是终止符        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # 更新 q_table

2.3、环境:

import numpy as npnp.random.seed(1)import tkinter as tkimport timeUNIT = 40   # pixelsMAZE_H = 4  # grid heightMAZE_W = 4  # grid widthclass Maze(tk.Tk):    def __init__(self):        super(Maze, self).__init__()        self.action_space = ['u', 'd', 'l', 'r']        self.n_actions = len(self.action_space)        self.title('maze')        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))        self._build_maze()    def _build_maze(self):        self.canvas = tk.Canvas(self, bg='white',                           height=MAZE_H * UNIT,                           width=MAZE_W * UNIT)        # create grids        for c in range(0, MAZE_W * UNIT, UNIT):            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT            self.canvas.create_line(x0, y0, x1, y1)        for r in range(0, MAZE_H * UNIT, UNIT):            x0, y0, x1, y1 = 0, r, MAZE_H * UNIT, r            self.canvas.create_line(x0, y0, x1, y1)        # create origin        origin = np.array([20, 20])        # hell        hell1_center = origin + np.array([UNIT * 2, UNIT])        self.hell1 = self.canvas.create_rectangle(            hell1_center[0] - 15, hell1_center[1] - 15,            hell1_center[0] + 15, hell1_center[1] + 15,            fill='black')        # hell        hell2_center = origin + np.array([UNIT, UNIT * 2])        self.hell2 = self.canvas.create_rectangle(            hell2_center[0] - 15, hell2_center[1] - 15,            hell2_center[0] + 15, hell2_center[1] + 15,            fill='black')        # create oval        oval_center = origin + UNIT * 2        self.oval = self.canvas.create_oval(            oval_center[0] - 15, oval_center[1] - 15,            oval_center[0] + 15, oval_center[1] + 15,            fill='yellow')        # create red rect        self.rect = self.canvas.create_rectangle(            origin[0] - 15, origin[1] - 15,            origin[0] + 15, origin[1] + 15,            fill='red')        # pack all        self.canvas.pack()    def reset(self):        self.update()        time.sleep(0.5)        self.canvas.delete(self.rect)        origin = np.array([20, 20])        self.rect = self.canvas.create_rectangle(            origin[0] - 15, origin[1] - 15,            origin[0] + 15, origin[1] + 15,            fill='red')        # return observation        return self.canvas.coords(self.rect)    def step(self, action):        s = self.canvas.coords(self.rect)        base_action = np.array([0, 0])        if action == 0:   # up            if s[1] > UNIT:                base_action[1] -= UNIT        elif action == 1:   # down            if s[1] < (MAZE_H - 1) * UNIT:                base_action[1] += UNIT        elif action == 2:   # right            if s[0] < (MAZE_W - 1) * UNIT:                base_action[0] += UNIT        elif action == 3:   # left            if s[0] > UNIT:                base_action[0] -= UNIT        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent        s_ = self.canvas.coords(self.rect)  # next state        # reward function        if s_ == self.canvas.coords(self.oval):            reward = 1            done = True        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:            reward = -1            done = True        else:            reward = 0            done = False        return s_, reward, done    def render(self):        time.sleep(0.1)        self.update()

2.4、具体实现:

import numpy as npimport pandas as pdclass RL(object):    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        self.actions = action_space  # a list        self.lr = learning_rate        self.gamma = reward_decay        self.epsilon = e_greedy        self.q_table = pd.DataFrame(columns=self.actions)    def check_state_exist(self, state):        if state not in self.q_table.index:            # append new state to q table            self.q_table = self.q_table.append(                pd.Series(                    [0]*len(self.actions),                    index=self.q_table.columns,                    name=state,                )            )    def choose_action(self, observation):        self.check_state_exist(observation)        # action selection        if np.random.rand() < self.epsilon:            # choose best action            state_action = self.q_table.ix[observation, :]            state_action = state_action.reindex(np.random.permutation(state_action.index))     # some actions have same value            action = state_action.argmax()        else:            # choose random action            action = np.random.choice(self.actions)        return action    def learn(self, *args):        pass# off-policyclass QLearningTable(RL):    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    def learn(self, s, a, r, s_):        self.check_state_exist(s_)        q_predict = self.q_table.ix[s, a]        if s_ != 'terminal':            q_target = r + self.gamma * self.q_table.ix[s_, :].max()  # next state is not terminal        else:            q_target = r  # next state is terminal        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update# on-policyclass SarsaTable(RL):    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)    def learn(self, s, a, r, s_, a_):        self.check_state_exist(s_)        q_predict = self.q_table.ix[s, a]        if s_ != 'terminal':            q_target = r + self.gamma * self.q_table.ix[s_, a_]  # next state is not terminal        else:            q_target = r  # next state is terminal        self.q_table.ix[s, a] += self.lr * (q_target - q_predict)  # update
1 0