算法 源码 A3C
来源:互联网 发布:王诺诺知乎女神扒皮 编辑:程序博客网 时间:2024/05/18 04:39
A3C 源码解析
标签(空格分隔): 增强学习算法 源码
该代码使用 A3C(Asynchronous Advantage Actor-Critic)算法,在连续动作空间(Pendulum-v0 环境)下实现策略控制。
"""
Asynchronous Advantage Actor Critic (A3C) with continuous action space, Reinforcement Learning.

Using:
tensorflow r1.3
gym 0.8.0
"""

import multiprocessing
import threading
import tensorflow as tf
import numpy as np
import gym
import os
import shutil
import matplotlib.pyplot as plt

GAME = 'Pendulum-v0'
OUTPUT_GRAPH = True
LOG_DIR = './log'
N_WORKERS = multiprocessing.cpu_count()
MAX_EP_STEP = 200
MAX_GLOBAL_EP = 2000
GLOBAL_NET_SCOPE = 'Global_Net'  # variable scope of the global network
UPDATE_GLOBAL_ITER = 10
GAMMA = 0.9
ENTROPY_BETA = 0.01
LR_A = 0.0001    # learning rate for actor
LR_C = 0.001    # learning rate for critic
GLOBAL_RUNNING_R = []
GLOBAL_EP = 0

env = gym.make(GAME)

N_S = env.observation_space.shape[0]
N_A = env.action_space.shape[0]
A_BOUND = [env.action_space.low, env.action_space.high]  # lower / upper bound of the continuous action


class ACNet(object):
    """Actor-critic network, used either as the global net or as a worker-local net.

    With ``scope == GLOBAL_NET_SCOPE`` only the state placeholder and the
    actor/critic parameters are created. With any other scope, the losses,
    the action-sampling op, the local gradients and the pull/push ops that
    link this net to ``globalAC`` are built as well.

    The local branch reads the module-level ``OPT_A`` / ``OPT_C`` optimizers,
    and the run helpers read the module-level ``SESS``; both are created in
    the ``__main__`` section before any network is built.
    """

    def __init__(self, scope, globalAC=None):
        """Build the graph for this network under variable scope ``scope``.

        Args:
            scope: variable-scope name; ``GLOBAL_NET_SCOPE`` selects the global net.
            globalAC: the global ``ACNet``; required for a local net, unused otherwise.
        """
        if scope == GLOBAL_NET_SCOPE:   # get global network
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                # Only the parameter lists of the actor-critic graph are kept.
                self.a_params, self.c_params = self._build_net(scope)[-2:]
        else:   # local net, calculate losses
            with tf.variable_scope(scope):
                self.s = tf.placeholder(tf.float32, [None, N_S], 'S')
                self.a_his = tf.placeholder(tf.float32, [None, N_A], 'A')              # actions taken
                self.v_target = tf.placeholder(tf.float32, [None, 1], 'Vtarget')       # target state values

                # mu, sigma, state value and the trainable parameters of both sub-networks
                mu, sigma, self.v, self.a_params, self.c_params = self._build_net(scope)

                td = tf.subtract(self.v_target, self.v, name='TD_error')
                with tf.name_scope('c_loss'):
                    self.c_loss = tf.reduce_mean(tf.square(td))   # critic loss: mean squared TD error

                with tf.name_scope('wrap_a_out'):
                    # Scale mu by the upper action bound; add a small constant to sigma.
                    mu, sigma = mu * A_BOUND[1], sigma + 1e-4

                # Actions are sampled from this distribution rather than taken directly from mu.
                normal_dist = tf.distributions.Normal(mu, sigma)

                with tf.name_scope('a_loss'):
                    log_prob = normal_dist.log_prob(self.a_his)
                    exp_v = log_prob * td
                    entropy = normal_dist.entropy()  # encourage exploration
                    self.exp_v = ENTROPY_BETA * entropy + exp_v
                    self.a_loss = tf.reduce_mean(-self.exp_v)   # actor loss

                with tf.name_scope('choose_a'):  # use local params to choose action
                    self.A = tf.clip_by_value(tf.squeeze(normal_dist.sample(1), axis=0), A_BOUND[0], A_BOUND[1])
                with tf.name_scope('local_grad'):
                    # Gradients of each loss with respect to its own local parameters.
                    self.a_grads = tf.gradients(self.a_loss, self.a_params)
                    self.c_grads = tf.gradients(self.c_loss, self.c_params)

            with tf.name_scope('sync'):
                with tf.name_scope('pull'):
                    # Copy the global parameters into the local network.
                    self.pull_a_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.a_params, globalAC.a_params)]
                    self.pull_c_params_op = [l_p.assign(g_p) for l_p, g_p in zip(self.c_params, globalAC.c_params)]
                with tf.name_scope('push'):
                    # Apply the local gradients to the global parameters.
                    self.update_a_op = OPT_A.apply_gradients(zip(self.a_grads, globalAC.a_params))
                    self.update_c_op = OPT_C.apply_gradients(zip(self.c_grads, globalAC.c_params))

    def _build_net(self, scope):
        """Create the actor and critic sub-networks on top of ``self.s``.

        Returns:
            ``(mu, sigma, v, a_params, c_params)``: the actor outputs, the
            critic's state value, and the trainable variables collected under
            ``scope + '/actor'`` and ``scope + '/critic'``.
        """
        w_init = tf.random_normal_initializer(0., .1)
        with tf.variable_scope('actor'):
            l_a = tf.layers.dense(self.s, 200, tf.nn.relu6, kernel_initializer=w_init, name='la')
            mu = tf.layers.dense(l_a, N_A, tf.nn.tanh, kernel_initializer=w_init, name='mu')
            sigma = tf.layers.dense(l_a, N_A, tf.nn.softplus, kernel_initializer=w_init, name='sigma')
        with tf.variable_scope('critic'):
            l_c = tf.layers.dense(self.s, 100, tf.nn.relu6, kernel_initializer=w_init, name='lc')
            v = tf.layers.dense(l_c, 1, kernel_initializer=w_init, name='v')  # state value
        a_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/actor')
        c_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope + '/critic')
        return mu, sigma, v, a_params, c_params

    def update_global(self, feed_dict):  # run by a local
        """Run the push ops: local grads are applied to the global net."""
        SESS.run([self.update_a_op, self.update_c_op], feed_dict)

    def pull_global(self):  # run by a local
        """Run the pull ops: overwrite local params with the global ones."""
        SESS.run([self.pull_a_params_op, self.pull_c_params_op])

    def choose_action(self, s):  # run by a local
        """Sample one action for the single state ``s`` using the local params."""
        s = s[np.newaxis, :]
        return SESS.run(self.A, {self.s: s})[0]


class Worker(object):
    """One training thread: owns an environment and a local ``ACNet``.

    The worker collects experience with its local net, pushes gradients to
    the global net and then pulls the updated global parameters back.
    """

    def __init__(self, name, globalAC):
        """Create the environment and the local net linked to ``globalAC``."""
        self.env = gym.make(GAME).unwrapped
        self.name = name
        self.AC = ACNet(name, globalAC)   # ties the local net to the global net

    def work(self):
        """Run episodes until the coordinator stops or MAX_GLOBAL_EP is reached."""
        global GLOBAL_RUNNING_R, GLOBAL_EP
        total_step = 1
        buffer_s, buffer_a, buffer_r = [], [], []   # states, actions, rewards since the last update
        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
            s = self.env.reset()
            ep_r = 0
            for ep_t in range(MAX_EP_STEP):
                if self.name == 'W_0':
                    self.env.render()
                a = self.AC.choose_action(s)
                s_, r, done, info = self.env.step(a)
                # The environment's own flag is ignored; an episode ends only
                # at the step limit.
                done = ep_t == MAX_EP_STEP - 1

                ep_r += r
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append((r + 8) / 8)    # normalize

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:   # update global and assign to local net
                    if done:
                        v_s_ = 0   # terminal
                    else:
                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[0, 0]
                    buffer_v_target = []
                    # Walk the rewards backwards to build discounted value targets.
                    # A separate loop name keeps the step reward ``r`` intact.
                    for reward in buffer_r[::-1]:
                        v_s_ = reward + GAMMA * v_s_
                        buffer_v_target.append(v_s_)
                    buffer_v_target.reverse()

                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(buffer_v_target)
                    feed_dict = {
                        self.AC.s: buffer_s,
                        self.AC.a_his: buffer_a,
                        self.AC.v_target: buffer_v_target,
                    }
                    self.AC.update_global(feed_dict)
                    buffer_s, buffer_a, buffer_r = [], [], []
                    self.AC.pull_global()

                s = s_
                total_step += 1
                if done:
                    if len(GLOBAL_RUNNING_R) == 0:  # record running episode reward
                        GLOBAL_RUNNING_R.append(ep_r)
                    else:
                        GLOBAL_RUNNING_R.append(0.9 * GLOBAL_RUNNING_R[-1] + 0.1 * ep_r)
                    print(
                        self.name,
                        "Ep:", GLOBAL_EP,
                        "| Ep_r: %i" % GLOBAL_RUNNING_R[-1],
                    )
                    GLOBAL_EP += 1
                    break


if __name__ == "__main__":
    SESS = tf.Session()

    with tf.device("/cpu:0"):
        OPT_A = tf.train.RMSPropOptimizer(LR_A, name='RMSPropA')
        OPT_C = tf.train.RMSPropOptimizer(LR_C, name='RMSPropC')
        GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # we only need its params
        workers = []
        # Create worker
        for i in range(N_WORKERS):
            i_name = 'W_%i' % i   # worker name
            workers.append(Worker(i_name, GLOBAL_AC))

    COORD = tf.train.Coordinator()
    SESS.run(tf.global_variables_initializer())

    if OUTPUT_GRAPH:
        if os.path.exists(LOG_DIR):
            shutil.rmtree(LOG_DIR)
        tf.summary.FileWriter(LOG_DIR, SESS.graph)

    worker_threads = []
    for worker in workers:
        # Pass the bound method itself. A ``lambda: worker.work()`` here would
        # look up ``worker`` only when the thread runs, so a thread could end
        # up running a later worker's loop.
        t = threading.Thread(target=worker.work)
        t.start()
        worker_threads.append(t)
    COORD.join(worker_threads)

    plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
    plt.xlabel('step')
    plt.ylabel('Total moving reward')
    plt.show()
阅读全文
0 0
- 算法 源码 A3C
- A3C经典源码
- 强化学习A3C与UNREAL算法
- 深度增强学习前沿算法思想【DQN、A3C、UNREAL,简介】
- 深度增强学习前沿算法思想【DQN、A3C、UNREAL,简介】
- 深度增强学习前沿算法思想【DQN、A3C、UNREAL,简介】
- A3C代码详解
- 深度强化学习——A3C
- 强化学习——A3C,GA3C
- 算法源码
- 强化学习系列<8>Asynchronous Advantage Actor-Critic(A3C)
- Asynchronous Advantage Actor-Critic (A3C)实现cart-pole
- 一个算法源码
- CRC8算法DELPHI源码
- CANNY算法源码
- 快速排序算法源码
- 几个排序算法源码
- A*算法 路径源码
- ASP.NET Core缓存静态资源
- [52ABP实战系列] .NET CORE实战入门视频课程出来啦
- 一些常见的UI主题框架
- 欢迎使用CSDN-markdown编辑器
- 集成,就不用造轮子了
- 算法 源码 A3C
- 设计模式之简单工厂、工厂方法、抽象工厂
- Unity利用粒子系统模拟下雪积雪效果
- spring boot集成swagger2
- ETH挖矿软件之挖矿专家
- 微软发行的SQL Server 2017候选版本可支持Linux
- Linux (CentOS)安装VNC+XFCE可视化桌面环境 附安装FireFox浏览器
- EDID 解析
- unity5.6.3发布Android