【增强学习】Sarsa

来源:互联网 发布:线刷用什么软件 编辑:程序博客网 时间:2024/06/16 12:16

增强学习之Sarsa


代码下载: github https://github.com/gutouyu/ReinforcementLearning


增强学习算法Sarsa(state-action-reward-state_-action_)。和QLearning非常相似。Sarsa是一种On policy的算法(整个循环都是在一个路径上),state_ action_ 在Q_table更新的时候就已经确定了。QLearning是off policy,state_,action_在Q_Table更新的时候,还没有确定。

1. 算法思想

Sarsa和QLearning非常相似,不同的只是Sarsa在更新Q_Table的时候,说到做到,它总是会选择比较大的Q(s2,a2),不但更新s1的时候是这样,选择的时候也是这样。

也就是说,Sarsa同样是在更新Q_table,与QLearning不同之处在于:

  • Sarsa在当前state已经想好了对应的action,而且也想好了下一个state_action_.而QLearning必须要等到进入了下一个状态state_才能想好对应的action_
  • Sarsa更新Q(s,a)的时候是根据下一个Q(s_ ,a_) 而QLearning是依据maxQ(s_ ,a_ )

2. 实现

所有代码,github可以下载

  • 主循环
# 有100条命来让Agent学习,如果还没有学会,game over    for episode in xrange(100):        # Init observation/state        observation = env.reset()        # Sarsa根据observation选取一个action        action = RL.choose_action(str(observation))        while True:            # Sarsa执行action,得到下一个observation observation_            observation_, reward, done = env.step(action)            # 执行了这一步的action之后,还要再选出下一步要执行的action,才能用于上一步的学习            action_ = RL.choose_action(str(observation_))            # 学习            RL.learn(str(observation), action, reward, str(observation_), action_)            # 更新observation action            observation = observation_            action = action_            # 死掉了(掉进黑块)就重新来            if done:                break
  • 状态、表格初始化
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):        self.actions = actions # a list        self.lr = learning_rate        self.gamma = reward_decay        self.epsilon = e_greedy        self.q_table = pd.DataFrame(columns=self.actions)
  • 根据observation选择action,采用的是epsilon-greedy。也就是说会以小概率的情况去进行探索,大部分的情况会选择最好的action
    def choose_action(self, observation):        self.check_state_exist(observation)        # action selection        if np.random.uniform() < self.epsilon:            # choose best action            state_action = self.q_table.ix[observation, :]            # 防止两个值相等的action总是选择其中一个,要随机选择            state_action = state_action.reindex(np.random.permutation(state_action.index))            action = state_action.argmax()        else:            # choose random action            action = np.random.choice(self.actions)        return action
  • 根据当前的observation action和下一步即将采取的observation_ action_ 以及得到的reward来进行学习
    def learn(self, s, a, r, s_, a_):        """        Sarsa是已经知道了下一步要采取的action,而且他也肯定会采取这个action。所以他的学习是直接基于下一次的action的。        """        self.check_state_exist(s_)        q_predict = self.q_table.ix[s, a]        if s_ != 'terminal':            q_target = r + self.gamma * self.q_table.ix[s_,a_] # Sarsa使用下一次采取的action来更新        else:            q_target = r # next state is terminal        self.q_table.ix[s, a] += self.lr * (q_target - q_predict) # update
  • 检查状态state是否已经存在,如果之前不存在就加入到SarsaTable中
    def check_state_exist(self, state):        if state not in self.q_table.index:            # append new state to q table            self.q_table = self.q_table.append(                pd.Series(                    [0] * len(self.actions),                    index=self.q_table.columns,                    name=state                )            )

3. 总结

  • Sarsa相比于QLearning要胆小很多。他每走一步都要特别怕,特别小心,都尽可能的避免掉进坑里。但是他也会以小概率采取不是那么好的action,因为我们是epislon-greedy.属于On policy。
  • QLearning要大胆很多,他为了达到目的,中间不管才多少坑,都不在乎。属于Off policy。