机器学习之Grid World的Deep SARSA算法解析
来源:互联网 发布:淘宝用的什么系统 编辑:程序博客网 时间:2024/06/18 08:36
Github上某开源项目的Deep SARSA算法实现代码地址:
https://github.com/rlcode/reinforcement-learning/tree/a497d719e3ecdd254e6620cf4f4b9afb0524b099/1-grid-world/6-deep-sarsa
Deep SARSA
Deep SARSA算法是基于SARSA算法的,不同之处在于SARSA算法用一个q_table记录了Agent在所有状态下采取所有Action的概率,用SARSA这5个变量去更新这些概率,让Agent能够以更大的概率采取正确的动作。具体分析可以看一下我之前的一篇文章:机器学习之Grid World的SARSA算法解析。
而Deep SARSA算法则采用了神经网络算法代替q_table的记录和查询。用神经网络算法代替q_table有什么好处呢?我自己也有问过自己这样的问题,看了这个开源项目的实现之后我自己有一点想法,不知道正不正确:
当环境有非常多种状态的时候,用q_table记录所有状态并不是很现实。就像围棋,大家都知道,电脑不可能像记录象棋所有变化(即状态)那样子记录围棋的所有变化,因为围棋的所有状态加起来比人类目前能观测到的宇宙的所有原子数量还要多(这点大家可以百度)。所以当q_table不能记录足够多的变化的时候,增强学习是无法战胜人类的。所以结合神经网络算法的增强学习算法才由此诞生,目的就是让Agent经过有限的学习去应对无限的状态空间。
使用q_table的时候,如果Agent没有到达过某种状态,那么Agent在这个状态下选择action可能是随机的,就是瞎选,而神经网络算法有时候能避免这种瞎选,这就是为什么神经网络算法在学习过10000张猫的照片后能分辨出第10001张跟之前照片不同的照片是不是猫的原因。
以上就是我自己的一点想法,其实并不完全是自己想出来的,大神的文章看多了,摸到大象鼻子而已。至于神经网络算法怎么会这么神奇,推荐大家去学习一下神经网络算法,大神们都推荐UFLDL,我这里就不讲这个了,下面我们一起看看Deep SARSA算法的实现。
代码实现
入口代码
for e in range(EPISODES): done = False score = 0 state = env.reset() state = np.reshape(state, [1, 15]) while not done: # fresh env global_step += 1 # get action for the current state and go one step in environment action = agent.get_action(state) next_state, reward, done = env.step(action) next_state = np.reshape(next_state, [1, 15]) next_action = agent.get_action(next_state) agent.train_model(state, action, reward, next_state, next_action, done) state = next_state # every time step we do training score += reward state = copy.deepcopy(next_state)
上面的代码与SARSA算法的区别就是SARSA算法学习过程调用的是:
# with sample <s,a,r,s',a'>, agent learns new q function agent.learn(str(state), action, reward, str(next_state), next_action)
而Deep SARSA算法学习调用的是:
agent.train_model(state, action, reward, next_state, next_action, done)
SARSA算法中的agent.learn函数之前看过了,就是将SARSA 5个变量经过一定的计算算出来的结果去更新q_table的内容,以便做出更正确的决策。下面我们看看agent.train_modle函数做了什么:
def train_model(self, state, action, reward, next_state, next_action, done): if self.epsilon > self.epsilon_min: self.epsilon *= self.epsilon_decay state = np.float32(state) next_state = np.float32(next_state) target = self.model.predict(state)[0] # like Q Learning, get maximum Q value at s' # But from target model if done: target[action] = reward else: target[action] = (reward + self.discount_factor * self.model.predict(next_state)[0][next_action]) target = np.reshape(target, [1, 5]) # make minibatch which includes target q value and predicted q value # and do the model fit! self.model.fit(state, target, epochs=1, verbose=0)
train_model函数先是调用了model.predict函数获得一个target数组,然后使用跟SARSA算法一样的计算更新了target数组下标为action的位置的值,然后调用model.fit传入state和target数组。这样就完了吗,这样agent就学到东西了?我们再看看model是什么:
from keras.layers import Dense from keras.optimizers import Adam from keras.models import Sequential # approximate Q function using Neural Network # state is input and Q Value of each action is output of network def build_model(self): model = Sequential() model.add(Dense(30, input_dim=self.state_size, activation='relu')) model.add(Dense(30, activation='relu')) model.add(Dense(self.action_size, activation='linear')) model.summary() model.compile(loss='mse', optimizer=Adam(lr=self.learning_rate)) return model
model就是一个使用keras API搭建的神经网络模型。关于keras大家可以上它的官网看看,我这里就不多说了。实际上model.predict函数就是神经网络模型对输入state的一个预测过程,给出的预测结果是一个action_size的数组,表示state状态下各个action的概率。而model.fit函数是训练神经网络的过程,epochs参数表示训练神经网络的代数,verbose控制神经网络训练过程的日志输出,0表示不要任何输出。
动作策略
# get action from model using epsilon-greedy policy def get_action(self, state): if np.random.rand() <= self.epsilon: # The agent acts randomly return random.randrange(self.action_size) else: # Predict the reward value based on the given state state = np.float32(state) q_values = self.model.predict(state) return np.argmax(q_values[0])
这里的代码实现可以看出,Deep SARSA的动作策略就是一定概率的随机动作或者神经网络模型的一个预测结果。
总结
Deep SARSA算法是我接触的第一个结合神经网络的增强学习算法,让我更深刻地认识了神经网络算法的作用。太神奇太强大,难怪各大科学家都在大肆宣传神经网络,百度甚至all in,虽然多少有炒作的成分,但是强大的神经网络算法还是功不可没。Deep SARSA的实现代码还是容易看懂的,难懂的是model的实现。大家可以看看我前面推荐的网站学习,我也正在学习中,大家一起努力!这里再发一遍:UFLDL
- 机器学习之Grid World的Deep SARSA算法解析
- 机器学习之Grid World的SARSA算法解析
- 机器学习之Grid World的Monte Carlo算法解析
- 机器学习之Grid World的Q-Learning算法解析
- 机器学习之EM算法解析
- 机器学习之Policy Iteration算法解析
- 深度学习中sarsa算法和Q-learning算法的区别
- 机器学习之深度学习(Deep Learning)
- 机器学习算法实现解析——libFM之libFM的训练过程之Adaptive Regularization
- 强化学习简单示例——SARSA算法
- 机器学习之Hello World kNN
- 机器学习算法实现解析——libFM之libFM的模型处理部分
- 机器学习算法实现解析——libFM之libFM的训练过程概述
- 【增强学习】Sarsa
- 基于table的Q learning和Sarsa算法
- Deep Learning(深度学习)之(七)高维数据的机器学习
- 机器学习 -- Deep Learning
- 【机器学习】Deep Learning
- solver.prototxt参数解析
- tar--文件打包命令
- 剑指offer-用两个栈实现队列
- Cocos Creator使用VS Code调试方法
- 虚拟机如何安装CentOS
- 机器学习之Grid World的Deep SARSA算法解析
- linux的dd命令:文件复制与备份、快速生成大文件、大小写转换
- 高精度计算模板
- 持续集成~Jenkins构建dotnetCore的项目
- ArrayList源码解析(基于JDK1.7)
- Java NIO理解与使用
- VLAN的划分方式及其优缺点
- 开平方的七种算法
- Unity3D学习之路