Sarsa
1. Algorithm:
The algorithm still works by continuously updating the values in the Q table and then using the new values to decide which action to take in each state. The differences from Q-learning are:

Sarsa has already decided on the action for the current state, and it has also already committed to the next state_ and the next action_ before updating (Q-learning has not yet committed to the next action_).

The update of Q(s, a) is based on the next Q(s_, a_) (Q-learning is based on maxQ(s_)). This difference makes Sarsa more timid than Q-learning. Q-learning always chases the maximal Q value, and this greed means it never considers the non-max outcomes. We can think of Q-learning as a greedy, bold, adventurous algorithm, while Sarsa is a conservative one: it cares about every single step of the decision and is sensitive to mistakes and death. We will see this difference in the visualization part. Both algorithms have their advantages; in practice, for example, if you care about damage to a physical machine, a conservative algorithm can reduce the number of failures during training.
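For reference, the two update rules differ only in the target term (lr is the learning rate and gamma the discount factor, matching the variable names in the code below):

Sarsa:      Q(s, a) += lr * (r + gamma * Q(s_, a_) - Q(s, a))
Q-learning: Q(s, a) += lr * (r + gamma * maxQ(s_)  - Q(s, a))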
2. Code implementation:
maze_env: the environment module;
RL_brain: the RL agent's brain (the decision-making part).
from maze_env import Maze
from RL_brain import SarsaTable
2.1. The iteration part:
def update():
    for episode in range(100):
        # initialize the environment
        observation = env.reset()

        # Sarsa chooses an action based on the current observation (state)
        action = RL.choose_action(str(observation))

        while True:
            # refresh the environment
            env.render()

            # take the action in the environment; get the next state (observation_), the reward, and whether the episode terminated
            observation_, reward, done = env.step(action)

            # choose the next action_ based on the next state (observation_)
            action_ = RL.choose_action(str(observation_))

            # learn from (s, a, r, s_, a_) and update the Q table ==> Sarsa
            RL.learn(str(observation), action, reward, str(observation_), action_)

            # the next state and action become the current state and action
            observation = observation_
            action = action_

            # break out of the loop when the episode terminates
            if done:
                break

    # all episodes finished
    print('game over')
    env.destroy()

if __name__ == "__main__":
    env = Maze()
    RL = SarsaTable(actions=list(range(env.n_actions)))
    env.after(100, update)
    env.mainloop()
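For comparison, here is a minimal sketch (not from the original post) of how the same loop would look with RL = QLearningTable(...): the action is chosen inside the while loop, one per step, and learn() does not receive a_:

def update_qlearning():
    for episode in range(100):
        observation = env.reset()
        while True:
            env.render()
            # Q-learning only commits to the action for the current state
            action = RL.choose_action(str(observation))
            observation_, reward, done = env.step(action)
            # learn from (s, a, r, s_); the next action is implicit in maxQ(s_)
            RL.learn(str(observation), action, reward, str(observation_))
            observation = observation_
            if done:
                break
    print('game over')
    env.destroy()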
2.2. The main structure (1):
class SarsaTable:
    # initialization (same as before)
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        ...

    # choose an action (same as before)
    def choose_action(self, observation):
        ...

    # learn and update the parameters (changed: Sarsa also takes the next action a_)
    def learn(self, s, a, r, s_, a_):
        ...

    # check whether a state exists in the table (same as before)
    def check_state_exist(self, state):
        ...
The main structure (2): the idea of inheritance:
2.2.1. The parent class:
import numpy as np
import pandas as pd

class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        ...  # same as in QLearningTable

    def check_state_exist(self, state):
        ...  # same as in QLearningTable

    def choose_action(self, observation):
        ...  # same as in QLearningTable

    def learn(self, *args):
        pass  # each algorithm learns a bit differently, so leave it as pass here
2.2.2. The Q-Learning subclass:
class QLearningTable(RL):   # inherits from the parent class RL
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):   # learn() differs per algorithm, so it is redefined here
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()
        else:
            q_target = r
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)
2.2.3. The Sarsa subclass:
class SarsaTable(RL):   # inherits from the RL class
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]   # q_target is based on the chosen a_, not on the max over Q(s_)
        else:
            q_target = r   # s_ is the terminal state
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)   # update the q_table
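A tiny worked example with hypothetical numbers shows the difference. Suppose the table currently holds Q(s_, a1) = 0.5 and Q(s_, a2) = 0.1, with r = 0 and gamma = 0.9. Q-learning's target is 0 + 0.9 * 0.5 = 0.45 no matter what the agent actually does next. If Sarsa's epsilon-greedy policy happened to pick the exploratory action a2, its target is 0 + 0.9 * 0.1 = 0.09, so the value of risky states is pulled down by the exploration the agent really performs.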
2.3. The environment:
import numpy as np
np.random.seed(1)
import tkinter as tk
import time

UNIT = 40    # pixels
MAZE_H = 4   # grid height
MAZE_W = 4   # grid width

class Maze(tk.Tk):
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze')
        self.geometry('{0}x{1}'.format(MAZE_W * UNIT, MAZE_H * UNIT))
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)

        # create grids
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin
        origin = np.array([20, 20])

        # hell
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        # hell
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black')

        # create oval
        oval_center = origin + UNIT * 2
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')

        # create red rect
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')

        # pack all
        self.canvas.pack()

    def reset(self):
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        # return observation
        return self.canvas.coords(self.rect)

    def step(self, action):
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:     # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:   # down
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:   # right
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:   # left
            if s[0] > UNIT:
                base_action[0] -= UNIT

        self.canvas.move(self.rect, base_action[0], base_action[1])   # move agent

        s_ = self.canvas.coords(self.rect)   # next state

        # reward function
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
            reward = -1
            done = True
        else:
            reward = 0
            done = False

        return s_, reward, done

    def render(self):
        time.sleep(0.1)
        self.update()
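As a quick sanity check of the environment on its own, one could drive it with random actions (a minimal sketch, not part of the original post):

if __name__ == "__main__":
    env = Maze()
    def random_walk():
        observation = env.reset()
        while True:
            env.render()
            action = np.random.choice(env.n_actions)   # pick a random action index
            observation, reward, done = env.step(action)
            if done:
                break
        env.destroy()
    env.after(100, random_walk)
    env.mainloop()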
2.4. The full implementation:
import numpy as np
import pandas as pd

class RL(object):
    def __init__(self, action_space, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = action_space  # a list
        self.lr = learning_rate
        self.gamma = reward_decay
        self.epsilon = e_greedy
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # append the new state to the q table, initialized with zeros
            self.q_table.loc[state] = [0] * len(self.actions)

    def choose_action(self, observation):
        self.check_state_exist(observation)
        # action selection
        if np.random.rand() < self.epsilon:
            # choose the best action
            state_action = self.q_table.loc[observation, :]
            # shuffle first: some actions may have the same value
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        else:
            # choose a random action
            action = np.random.choice(self.actions)
        return action

    def learn(self, *args):
        pass

# off-policy
class QLearningTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(QLearningTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update

# on-policy
class SarsaTable(RL):
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        super(SarsaTable, self).__init__(actions, learning_rate, reward_decay, e_greedy)

    def learn(self, s, a, r, s_, a_):
        self.check_state_exist(s_)
        q_predict = self.q_table.loc[s, a]
        if s_ != 'terminal':
            q_target = r + self.gamma * self.q_table.loc[s_, a_]  # next state is not terminal
        else:
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # update
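One detail worth noting in choose_action: idxmax on a pandas Series always returns the first maximal entry, so reindexing with a random permutation first means that ties between equally valued actions (for example in a freshly added, all-zero state row) are broken randomly instead of always favoring the same action.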