DQN_tensorflow Source Code Walkthrough
Recently, for a research project, I have been studying the paper "Playing Atari with Deep Reinforcement Learning", i.e. DeepMind's original DQN algorithm. There are many open-source implementations of this paper; here I take the GitHub project https://github.com/gliese581gg/DQN_tensorflow as an example to walk through the concrete training and learning process of deep reinforcement learning. The code is based on TensorFlow and OpenCV. I have annotated it in detail and hope it is helpful.
The main script defines the deep_atari class, which provides the interface for training and evaluation. The params dictionary holds the network and training configuration; see the code below for the individual settings.
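Before the code itself, here is a minimal sketch of what the __main__ block at the bottom of the script boils down to for the default 'nips' configuration. It is only an illustration: these lines would have to live in the same script (build_net looks up the DQN class in the module namespace), and they assume the repository layout (DQN_nips.py, roms/breakout.bin, a ckpt/ directory) is in place.

    # hedged sketch: the essence of the __main__ block below, for the 'nips' settings
    from DQN_nips import *           # provides the DQN class used by build_net()

    params['network_type'] = 'nips'
    params['visualize'] = False      # run without the OpenCV preview window
    params['gpu_fraction'] = 0.25

    da = deep_atari(params)          # builds qnet/targetnet and the ALE emulator
    da.start()                       # fills the replay memory, then trains and evaluates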
from database import *
from emulator import *
import tensorflow as tf
import numpy as np
import time
from ale_python_interface import ALEInterface
import cv2
from scipy import misc
import gc   # garbage collector
import thread
import sys  # needed for sys.argv / sys.stdout below
gc.enable()

# network and training configuration
params = {
    'visualize' : True,
    'network_type' : 'nips',
    'ckpt_file' : None,
    'steps_per_epoch' : 50000,
    'num_epochs' : 100,
    'eval_freq' : 50000,
    'steps_per_eval' : 10000,
    'copy_freq' : 10000,
    'disp_freq' : 10000,
    'save_interval' : 10000,
    'db_size' : 1000000,
    'batch' : 32,
    'num_act' : 0,
    'input_dims' : [210, 160, 3],
    'input_dims_proc' : [84, 84, 4],
    'learning_interval' : 1,
    'eps' : 1.0,
    'eps_step' : 1000000,
    'eps_min' : 0.1,
    'eps_eval' : 0.05,
    'discount' : 0.95,
    'lr' : 0.0002,
    'rms_decay' : 0.99,
    'rms_eps' : 1e-6,
    'train_start' : 100,
    'img_scale' : 255.0,
    'clip_delta' : 0,  # nature : 1
    'gpu_fraction' : 0.25,
    'batch_accumulator' : 'mean',
    'record_eval' : True,
    'only_eval' : 'n'
}

class deep_atari:
    def __init__(self,params):
        print 'Initializing Module...'
        self.params = params
        self.gpu_config = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.params['gpu_fraction']))
        self.sess = tf.Session(config=self.gpu_config)
        self.DB = database(self.params)  # initialize the replay memory
        self.engine = emulator(rom_name='breakout.bin', vis=self.params['visualize'],windowname=self.params['network_type']+'_preview')
        self.params['num_act'] = len(self.engine.legal_actions)  # number of legal actions for this game
        self.build_net()  # build the qnet and targetnet networks
        self.training = True

    def build_net(self):
        print 'Building QNet and targetnet...'
        '''qnet is the network being trained and targetnet provides the (periodically frozen) targets.
        The whole model can be thought of as a beginner who learns to play the game from its
        accumulated experience, and that experience lives in DB (the replay memory).'''
        self.qnet = DQN(self.params,'qnet')            # define qnet
        self.targetnet = DQN(self.params,'targetnet')  # define targetnet
        self.sess.run(tf.initialize_all_variables())
        saver_dict = {'qw1':self.qnet.w1,'qb1':self.qnet.b1,
                      'qw2':self.qnet.w2,'qb2':self.qnet.b2,
                      'qw3':self.qnet.w3,'qb3':self.qnet.b3,
                      'qw4':self.qnet.w4,'qb4':self.qnet.b4,
                      'qw5':self.qnet.w5,'qb5':self.qnet.b5,
                      'tw1':self.targetnet.w1,'tb1':self.targetnet.b1,
                      'tw2':self.targetnet.w2,'tb2':self.targetnet.b2,
                      'tw3':self.targetnet.w3,'tb3':self.targetnet.b3,
                      'tw4':self.targetnet.w4,'tb4':self.targetnet.b4,
                      'tw5':self.targetnet.w5,'tb5':self.targetnet.b5,
                      'step':self.qnet.global_step}  # weights, biases and global step to checkpoint
        self.saver = tf.train.Saver(saver_dict)
        #self.saver = tf.train.Saver()
        # ops that copy qnet's weights and biases into targetnet
        self.cp_ops = [
            self.targetnet.w1.assign(self.qnet.w1),self.targetnet.b1.assign(self.qnet.b1),
            self.targetnet.w2.assign(self.qnet.w2),self.targetnet.b2.assign(self.qnet.b2),
            self.targetnet.w3.assign(self.qnet.w3),self.targetnet.b3.assign(self.qnet.b3),
            self.targetnet.w4.assign(self.qnet.w4),self.targetnet.b4.assign(self.qnet.b4),
            self.targetnet.w5.assign(self.qnet.w5),self.targetnet.b5.assign(self.qnet.b5)]
        self.sess.run(self.cp_ops)
        if self.params['ckpt_file'] is not None:  # restore the state of a previous run
            print 'loading checkpoint : ' + self.params['ckpt_file']
            self.saver.restore(self.sess,self.params['ckpt_file'])
            temp_train_cnt = self.sess.run(self.qnet.global_step)
            temp_step = temp_train_cnt * self.params['learning_interval']
            print 'Continue from'
            print '   -> Steps : ' + str(temp_step)
            print '   -> Minibatch update : ' + str(temp_train_cnt)

    def start(self):  # main training / evaluation loop
        self.reset_game()             # start a new game
        self.step = 0                 # current step count
        self.reset_statistics('all')  # reset all statistics
        self.train_cnt = self.sess.run(self.qnet.global_step)

        # when resuming from a checkpoint, append to the existing log files
        if self.train_cnt > 0 :
            self.step = self.train_cnt * self.params['learning_interval']
            try:
                self.log_train = open('log_training_'+self.params['network_type']+'.csv','a')
            except:
                self.log_train = open('log_training_'+self.params['network_type']+'.csv','w')
                self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')
            try:
                self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','a')
            except:
                self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w')
                self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')
        else:
            self.log_train = open('log_training_'+self.params['network_type']+'.csv','w')
            self.log_train.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')
            self.log_eval = open('log_eval_'+self.params['network_type']+'.csv','w')
            self.log_eval.write('step,epoch,train_cnt,avg_reward,avg_q,epsilon,time\n')
        self.s = time.time()
        # print the configuration
        print self.params
        print 'Start training!'
        print 'Collecting replay memory for ' + str(self.params['train_start']) + ' steps'

        # main loop; for the first params['train_start'] steps the game is played with the
        # untrained network just to fill the initial replay memory
        while self.step < (self.params['steps_per_epoch'] * self.params['num_epochs'] * self.params['learning_interval'] + self.params['train_start']):
            if self.training :
                if self.DB.get_size() >= self.params['train_start'] : self.step += 1 ; self.steps_train += 1
            else : self.step_eval += 1

            # store the previous grayscale frame, the clipped reward, the chosen action index
            # and the terminal flag in DB (the replay memory)
            if self.state_gray_old is not None and self.training:
                self.DB.insert(self.state_gray_old[26:110,:],self.reward_scaled,self.action_idx,self.terminal)

            # every params['copy_freq'] steps, copy the trained qnet parameters into targetnet
            if self.training and self.params['copy_freq'] > 0 and self.step % self.params['copy_freq'] == 0 and self.DB.get_size() > self.params['train_start']:
                print '&&& Copying Qnet to targetnet\n'
                self.sess.run(self.cp_ops)  # run the copy ops defined in build_net()

            # update the weights every params['learning_interval'] steps; with learning_interval = 1
            # a minibatch update follows every single action
            if self.training and self.step % self.params['learning_interval'] == 0 and self.DB.get_size() > self.params['train_start'] :
                # sample a random minibatch from DB (the replay memory): states s, action indices a,
                # terminal flags t, next states n and rewards r
                bat_s,bat_a,bat_t,bat_n,bat_r = self.DB.get_batches()
                bat_a = self.get_onehot(bat_a)  # one-hot matrix of shape [batch, num_act]; the chosen action is 1, the rest 0

                # feed the next states through targetnet and take the per-sample maximum as q_t,
                # the bootstrapped estimate of the future return
                if self.params['copy_freq'] > 0 :
                    feed_dict={self.targetnet.x: bat_n}
                    q_t = self.sess.run(self.targetnet.y,feed_dict=feed_dict)
                else:
                    feed_dict={self.qnet.x: bat_n}
                    q_t = self.sess.run(self.qnet.y,feed_dict=feed_dict)
                q_t = np.amax(q_t,axis=1)

                # feed the sampled transitions (the "experience") into qnet
                feed_dict={self.qnet.x: bat_s, self.qnet.q_t: q_t, self.qnet.actions: bat_a, self.qnet.terminals:bat_t, self.qnet.rewards: bat_r}
                # one optimization step on the loss defined in the DQN class
                _,self.train_cnt,self.cost = self.sess.run([self.qnet.rmsprop,self.qnet.global_step,self.qnet.cost],feed_dict=feed_dict)

                # accumulate the loss for the progress display
                self.total_cost_train += np.sqrt(self.cost)
                self.train_cnt_for_disp += 1

            # anneal epsilon linearly during training; use a fixed small epsilon during evaluation
            if self.training :
                self.params['eps'] = max(self.params['eps_min'],1.0 - float(self.train_cnt * self.params['learning_interval'])/float(self.params['eps_step']))
            else:
                self.params['eps'] = 0.05

            # every params['save_interval'] steps, save the weights (comparable to a snapshot in Caffe);
            # most of the remaining loop body is bookkeeping: logging, evaluation switching, and so on
            if self.DB.get_size() > self.params['train_start'] and self.step % self.params['save_interval'] == 0 and self.training:
                save_idx = self.train_cnt
                self.saver.save(self.sess,'ckpt/model_'+self.params['network_type']+'_'+str(save_idx))
                sys.stdout.write('$$$ Model saved : %s\n\n' % ('ckpt/model_'+self.params['network_type']+'_'+str(save_idx)))
                sys.stdout.flush()

            # periodically print and log the training statistics
            if self.training and self.step > 0 and self.step % self.params['disp_freq'] == 0 and self.DB.get_size() > self.params['train_start'] :
                self.write_log_train()

            # switch to evaluation: weight updates are paused and the agent just plays the game
            # (with the small fixed epsilon of 0.05 set above)
            if self.training and self.step > 0 and self.step % self.params['eval_freq'] == 0 and self.DB.get_size() > self.params['train_start'] :
                self.reset_game()
                if self.step % self.params['steps_per_epoch'] == 0 : self.reset_statistics('all')
                else: self.reset_statistics('eval')
                self.training = False
                #TODO : add video recording
                continue

            # at the end of every epoch (params['steps_per_epoch'] steps), restart the game and reset
            # the statistics; because the Bellman backup is discounted by gamma, rewards far in the
            # future contribute little to the current value estimate anyway (see the paper)
            if self.training and self.step > 0 and self.step % self.params['steps_per_epoch'] == 0 and self.DB.get_size() > self.params['train_start']:
                self.reset_game()
                self.reset_statistics('all')
                #self.training = False
                continue

            # evaluation finished: log the results and switch back to training
            if not self.training and self.step_eval >= self.params['steps_per_eval'] :
                self.write_log_eval()
                self.reset_game()
                self.reset_statistics('eval')
                self.training = True
                continue

            # check whether the game is over
            if self.terminal :
                self.reset_game()
                if self.training :
                    self.num_epi_train += 1
                    self.total_reward_train += self.epi_reward_train
                    self.epi_reward_train = 0
                else :
                    self.num_epi_eval += 1
                    self.total_reward_eval += self.epi_reward_eval
                    self.epi_reward_eval = 0
                continue

            # choose the next action; select_action() implements an epsilon-greedy policy
            self.action_idx,self.action, self.maxQ = self.select_action(self.state_proc)
            # execute the selected action in the emulator; it returns the resulting state, the reward
            # and the terminal flag (one new transition of the Markov chain)
            self.state, self.reward, self.terminal = self.engine.next(self.action)
            self.reward_scaled = self.reward // max(1,abs(self.reward))  # clip the reward to -1 / 0 / +1
            if self.training : self.epi_reward_train += self.reward ; self.total_Q_train += self.maxQ  # accumulate reward and Q
            else : self.epi_reward_eval += self.reward ; self.total_Q_eval += self.maxQ

            # preprocess the new frame and shift it into the 4-frame state, so that the old frame can
            # later be inserted into DB (the replay memory)
            self.state_gray_old = np.copy(self.state_gray)
            self.state_proc[:,:,0:3] = self.state_proc[:,:,1:4]
            self.state_resized = cv2.resize(self.state,(84,110))
            self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY)
            self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale']

            #TODO : add video recording

    def reset_game(self):
        self.state_proc = np.zeros((84,84,4)); self.action = -1; self.terminal = False; self.reward = 0
        self.state = self.engine.newGame()
        self.state_resized = cv2.resize(self.state,(84,110))
        self.state_gray = cv2.cvtColor(self.state_resized, cv2.COLOR_BGR2GRAY)
        self.state_gray_old = None
        self.state_proc[:,:,3] = self.state_gray[26:110,:]/self.params['img_scale']

    def reset_statistics(self,mode):
        if mode == 'all':
            self.epi_reward_train = 0
            self.epi_Q_train = 0
            self.num_epi_train = 0
            self.total_reward_train = 0
            self.total_Q_train = 0
            self.total_cost_train = 0
            self.steps_train = 0
            self.train_cnt_for_disp = 0
        self.step_eval = 0
        self.epi_reward_eval = 0
        self.epi_Q_eval = 0
        self.num_epi_eval = 0
        self.total_reward_eval = 0
        self.total_Q_eval = 0

    def write_log_train(self):
        sys.stdout.write('### Training (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] ))
        sys.stdout.write('    Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f, Avg.loss : %.3f\n' % (self.num_epi_train,float(self.total_reward_train)/max(1,self.num_epi_train),float(self.total_Q_train)/max(1,self.steps_train),self.total_cost_train/max(1,self.train_cnt_for_disp)))
        sys.stdout.write('    Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s))
        sys.stdout.flush()
        self.log_train.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + ',' + str(self.train_cnt) + ',')
        self.log_train.write(str(float(self.total_reward_train)/max(1,self.num_epi_train)) +','+ str(float(self.total_Q_train)/max(1,self.steps_train)) +',')
        self.log_train.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n')
        self.log_train.flush()

    def write_log_eval(self):
        sys.stdout.write('@@@ Evaluation (Step : %d , Minibatch update : %d , Epoch %d)\n' % (self.step,self.train_cnt,self.step//self.params['steps_per_epoch'] ))
        sys.stdout.write('    Num.Episodes : %d , Avg.reward : %.3f , Avg.Q : %.3f\n' % (self.num_epi_eval,float(self.total_reward_eval)/max(1,self.num_epi_eval),float(self.total_Q_eval)/max(1,self.params['steps_per_eval'])))
        sys.stdout.write('    Epsilon : %.3f , Elapsed time : %.1f\n\n' % (self.params['eps'],time.time()-self.s))
        sys.stdout.flush()
        self.log_eval.write(str(self.step) + ',' + str(self.step//self.params['steps_per_epoch']) + ',' + str(self.train_cnt) + ',')
        self.log_eval.write(str(float(self.total_reward_eval)/max(1,self.num_epi_eval)) +','+ str(float(self.total_Q_eval)/max(1,self.params['steps_per_eval'])) +',')
        self.log_eval.write(str(self.params['eps']) +','+ str(time.time()-self.s) + '\n')
        self.log_eval.flush()

    def select_action(self,st):
        # draw a random number; if it exceeds eps, let qnet (not targetnet) pick the greedy action
        if np.random.rand() > self.params['eps']:
            #greedy with random tie-breaking
            Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0]
            a_winner = np.argwhere(Q_pred == np.amax(Q_pred))
            if len(a_winner) > 1:
                act_idx = a_winner[np.random.randint(0, len(a_winner))][0]
                return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred)
            else:
                act_idx = a_winner[0][0]
                return act_idx,self.engine.legal_actions[act_idx], np.amax(Q_pred)
        # otherwise pick a random action
        else:
            #random
            act_idx = np.random.randint(0,len(self.engine.legal_actions))
            Q_pred = self.sess.run(self.qnet.y, feed_dict = {self.qnet.x: np.reshape(st, (1,84,84,4))})[0]
            return act_idx,self.engine.legal_actions[act_idx], Q_pred[act_idx]

    def get_onehot(self,actions):
        actions_onehot = np.zeros((self.params['batch'], self.params['num_act']))
        for i in range(self.params['batch']):
            actions_onehot[i,actions[i]] = 1
        return actions_onehot

if __name__ == "__main__":
    dict_items = params.items()
    for i in range(1,len(sys.argv),2):
        if sys.argv[i] == '-weight' :params['ckpt_file'] = sys.argv[i+1]
        elif sys.argv[i] == '-network_type' :params['network_type'] = sys.argv[i+1]
        elif sys.argv[i] == '-visualize' :
            if sys.argv[i+1] == 'y' : params['visualize'] = True
            elif sys.argv[i+1] == 'n' : params['visualize'] = False
            else:
                print 'Invalid visualization argument!!! Available arguments are'
                print '    y or n'
                raise ValueError()
        elif sys.argv[i] == '-gpu_fraction' : params['gpu_fraction'] = float(sys.argv[i+1])
        elif sys.argv[i] == '-db_size' : params['db_size'] = int(sys.argv[i+1])
        elif sys.argv[i] == '-only_eval' : params['only_eval'] = sys.argv[i+1]
        else :
            print 'Invalid arguments!!! Available arguments are'
            print '    -weight (filename)'
            print '    -network_type (nips or nature)'
            print '    -visualize (y or n)'
            print '    -gpu_fraction (0.1~0.9)'
            print '    -db_size (integer)'
            raise ValueError()

    if params['network_type'] == 'nips':
        from DQN_nips import *
    elif params['network_type'] == 'nature':
        from DQN_nature import *
        params['steps_per_epoch']= 200000
        params['eval_freq'] = 100000
        params['steps_per_eval'] = 10000
        params['copy_freq'] = 10000
        params['disp_freq'] = 20000
        params['save_interval'] = 20000
        params['learning_interval'] = 1
        params['discount'] = 0.99
        params['lr'] = 0.00025
        params['rms_decay'] = 0.95
        params['rms_eps']=0.01
        params['clip_delta'] = 1.0
        params['train_start']=50000
        params['batch_accumulator'] = 'sum'
        params['eps_step'] = 1000000
        params['num_epochs'] = 250
        params['batch'] = 32
    else :
        print 'Invalid network type! Available network types are'
        print '    nips or nature'
        raise ValueError()

    if params['only_eval'] == 'y' : only_eval = True
    elif params['only_eval'] == 'n' : only_eval = False
    else :
        print 'Invalid only_eval option! Available options are'
        print '    y or n'
        raise ValueError()

    if only_eval:
        params['eval_freq'] = 1
        params['train_start'] = 100

    da = deep_atari(params)
    da.start()
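The exploration schedule in start() anneals epsilon linearly from 1.0 down to eps_min over eps_step minibatch updates. A small sketch with the 'nips' defaults (illustrative values only):

    # sketch of the linear epsilon annealing used in start()
    # ('nips' defaults: eps_min = 0.1, eps_step = 1e6, learning_interval = 1)
    eps_min, eps_step, learning_interval = 0.1, 1000000, 1

    def epsilon(train_cnt):
        return max(eps_min, 1.0 - float(train_cnt * learning_interval) / float(eps_step))

    for cnt in (0, 250000, 500000, 900000, 2000000):
        print cnt, epsilon(cnt)    # 1.0, 0.75, 0.5, 0.1, 0.1 (clamped at eps_min)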
The database class implements the replay memory described in the paper; a small usage sketch follows the class below.
import numpy as np
import gc
import time
import cv2

class database:
    def __init__(self, params):
        self.size = params['db_size']
        self.img_scale = params['img_scale']
        self.states = np.zeros([self.size,84,84],dtype='uint8')  # image dimensions
        self.actions = np.zeros(self.size,dtype='float32')
        self.terminals = np.zeros(self.size,dtype='float32')
        self.rewards = np.zeros(self.size,dtype='float32')
        self.bat_size = params['batch']
        self.bat_s = np.zeros([self.bat_size,84,84,4])
        self.bat_a = np.zeros([self.bat_size])
        self.bat_t = np.zeros([self.bat_size])
        self.bat_n = np.zeros([self.bat_size,84,84,4])
        self.bat_r = np.zeros([self.bat_size])

        self.counter = 0  # keep track of next empty state
        self.flag = False
        return

    def get_batches(self):  # sample a random minibatch from the replay memory
        for i in range(self.bat_size):  # draw batch (32) transitions from the replay memory
            idx = 0
            while idx < 3 or (idx > self.counter-2 and idx < self.counter+3):
                idx = np.random.randint(3,self.get_size()-1)  # get_size() is the number of stored transitions; pick a random index
            # assemble the 4-frame stacks and the corresponding action, terminal flag and reward for this index
            self.bat_s[i] = np.transpose(self.states[idx-3:idx+1,:,:],(1,2,0))/self.img_scale
            self.bat_n[i] = np.transpose(self.states[idx-2:idx+2,:,:],(1,2,0))/self.img_scale
            self.bat_a[i] = self.actions[idx]
            self.bat_t[i] = self.terminals[idx]
            self.bat_r[i] = self.rewards[idx]
        #self.bat_s[0] = np.transpose(self.states[10:14,:,:],(1,2,0))/self.img_scale
        #self.bat_n[0] = np.transpose(self.states[11:15,:,:],(1,2,0))/self.img_scale
        #self.bat_a[0] = self.actions[13]
        #self.bat_t[0] = self.terminals[13]
        #self.bat_r[0] = self.rewards[13]
        return self.bat_s,self.bat_a,self.bat_t,self.bat_n,self.bat_r

    def insert(self, prevstate_proc,reward,action,terminal):  # append one transition to the circular memory
        self.states[self.counter] = prevstate_proc
        self.rewards[self.counter] = reward
        self.actions[self.counter] = action
        self.terminals[self.counter] = terminal
        #update counter
        self.counter += 1
        if self.counter >= self.size:
            self.flag = True
            self.counter = 0
        return

    def get_size(self):  # current number of transitions stored in the replay memory
        if self.flag == False:
            return self.counter
        else:
            return self.size
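To make the indexing in get_batches concrete, here is a hedged usage sketch; the frames are random noise and db_size is scaled down (the real code uses 1,000,000):

    # hedged usage sketch of the database class with made-up data
    import numpy as np
    from database import database

    params = {'db_size': 1000, 'img_scale': 255.0, 'batch': 32}
    DB = database(params)

    # insert a few fake transitions: an 84x84 grayscale frame, a clipped reward,
    # an action index and a terminal flag
    for t in range(100):
        frame = np.random.randint(0, 256, (84, 84)).astype('uint8')
        DB.insert(frame, 0, np.random.randint(4), False)

    # each sampled state is a stack of 4 consecutive frames scaled to [0,1];
    # bat_n is the same stack shifted forward by one frame
    bat_s, bat_a, bat_t, bat_n, bat_r = DB.get_batches()
    print bat_s.shape, bat_n.shape    # (32, 84, 84, 4) (32, 84, 84, 4)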
The DQN class is the core of the code: it defines the network architecture, the Bellman target and the loss function.
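Before reading the TensorFlow graph, it may help to see what the DQN class below computes for one minibatch, written out in plain NumPy. This is only a sketch with made-up numbers; in the real code q_t comes from targetnet and y from qnet:

    # hedged NumPy sketch of the per-minibatch computation in the DQN class below
    import numpy as np

    gamma     = 0.95
    rewards   = np.array([1., 0., 0.])           # clipped rewards r_j
    terminals = np.array([0., 0., 1.])           # 1.0 if the episode ended at step j
    q_t       = np.array([2., 3., 4.])           # max_a' Q_target(s_{j+1}, a') from targetnet

    # Bellman target (self.yj): terminal transitions keep only the reward
    yj = rewards + (1.0 - terminals) * gamma * q_t        # -> [2.9, 2.85, 0.]

    # Q-value predicted by qnet for the action actually taken (self.Q_pred):
    # the one-hot `actions` mask zeroes out every other column of y
    y       = np.array([[0.5, 2.8, 1.0],
                        [1.0, 0.2, 2.9],
                        [0.3, 0.1, 0.4]])        # qnet output, one row per sample
    actions = np.array([[0., 1., 0.],
                        [0., 0., 1.],
                        [1., 0., 0.]])           # one-hot encoding of the chosen actions
    Q_pred = np.max(y * actions, axis=1)         # -> [2.8, 2.9, 0.3]

    diff = yj - Q_pred
    cost = np.mean(0.5 * diff ** 2)              # 'mean' batch_accumulator, clip_delta == 0

One quirk worth noting: the code recovers the chosen action's Q-value with reduce_max over the masked outputs, which only matches that entry when it is non-negative; a reduce_sum over the masked outputs would be the more common formulation.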
import numpy as np
import tensorflow as tf
import cv2

class DQN:
    def __init__(self,params,name):
        # placeholders for the elements of a transition; the input has shape [batch,84,84,4] (params['batch'] = 32)
        self.network_type = 'nature'
        self.params = params
        self.network_name = name
        self.x = tf.placeholder('float32',[None,84,84,4],name=self.network_name + '_x')
        self.q_t = tf.placeholder('float32',[None],name=self.network_name + '_q_t')
        self.actions = tf.placeholder("float32", [None, params['num_act']],name=self.network_name + '_actions')
        self.rewards = tf.placeholder("float32", [None],name=self.network_name + '_rewards')
        self.terminals = tf.placeholder("float32", [None],name=self.network_name + '_terminals')

        # conv1: [batch,84,84,4] --> [batch,w1,h1,32]
        # (with 'VALID' padding the output size is (input_size - kernel_size)/stride + 1)
        layer_name = 'conv1' ; size = 8 ; channels = 4 ; filters = 32 ; stride = 4
        self.w1 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b1 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases')
        self.c1 = tf.nn.conv2d(self.x, self.w1, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs')
        self.o1 = tf.nn.relu(tf.add(self.c1,self.b1),name=self.network_name + '_'+layer_name+'_activations')
        #self.n1 = tf.nn.lrn(self.o1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        # conv2: [batch,w1,h1,32] --> [batch,w2,h2,64]
        layer_name = 'conv2' ; size = 4 ; channels = 32 ; filters = 64 ; stride = 2
        self.w2 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b2 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases')
        self.c2 = tf.nn.conv2d(self.o1, self.w2, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs')
        self.o2 = tf.nn.relu(tf.add(self.c2,self.b2),name=self.network_name + '_'+layer_name+'_activations')
        #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        # conv3: [batch,w2,h2,64] --> [batch,w3,h3,64]
        layer_name = 'conv3' ; size = 3 ; channels = 64 ; filters = 64 ; stride = 1
        self.w3 = tf.Variable(tf.random_normal([size,size,channels,filters], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b3 = tf.Variable(tf.constant(0.1, shape=[filters]),name=self.network_name + '_'+layer_name+'_biases')
        self.c3 = tf.nn.conv2d(self.o2, self.w3, strides=[1, stride, stride, 1], padding='VALID',name=self.network_name + '_'+layer_name+'_convs')
        self.o3 = tf.nn.relu(tf.add(self.c3,self.b3),name=self.network_name + '_'+layer_name+'_activations')
        #self.n2 = tf.nn.lrn(self.o2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        # flatten the conv feature maps into a vector
        o3_shape = self.o3.get_shape().as_list()

        # fc4: [batch, w3*h3*64] --> [batch, 512]
        layer_name = 'fc4' ; hiddens = 512 ; dim = o3_shape[1]*o3_shape[2]*o3_shape[3]
        self.o3_flat = tf.reshape(self.o3, [-1,dim],name=self.network_name + '_'+layer_name+'_input_flat')
        self.w4 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b4 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases')
        self.ip4 = tf.add(tf.matmul(self.o3_flat,self.w4),self.b4,name=self.network_name + '_'+layer_name+'_ips')
        self.o4 = tf.nn.relu(self.ip4,name=self.network_name + '_'+layer_name+'_activations')

        # fc5: [batch, 512] --> [batch, num_act]
        layer_name = 'fc5' ; hiddens = params['num_act'] ; dim = 512
        self.w5 = tf.Variable(tf.random_normal([dim,hiddens], stddev=0.01),name=self.network_name + '_'+layer_name+'_weights')
        self.b5 = tf.Variable(tf.constant(0.1, shape=[hiddens]),name=self.network_name + '_'+layer_name+'_biases')
        '''One of the core ideas of deep Q-learning is to act so as to maximize the expected future
        reward at every step. In practice this cannot be computed by summing rewards back from the
        future, so the Bellman equation is used instead: assume the maximum future reward of the
        *next* state is already known and obtain it from the target network; that value serves as
        the regression target (the "label") for training qnet. The idea is analogous to recursion in
        algorithms or induction in mathematics. The network itself takes the current state (a stack
        of 4 consecutive frames) as input and outputs, for each action, the predicted future return;
        the maximum over actions is the predicted value of that state. Interestingly, both the target
        and the prediction come out of neural networks (two of them). Training then proceeds like an
        ordinary convolutional network: define a loss, backpropagate the error and update the weights.'''
        self.y = tf.add(tf.matmul(self.o4,self.w5),self.b5,name=self.network_name + '_'+layer_name+'_outputs')

        #Q,Cost,Optimizer
        self.discount = tf.constant(self.params['discount'])  # the discount factor gamma of the Bellman backup
        # Bellman target: reward plus the discounted future return (zeroed for terminal transitions)
        self.yj = tf.add(self.rewards, tf.mul(1.0-self.terminals, tf.mul(self.discount, self.q_t)))
        self.Qxa = tf.mul(self.y,self.actions)
        self.Q_pred = tf.reduce_max(self.Qxa, reduction_indices=1)  # Q-value of the chosen action
        #self.yjr = tf.reshape(self.yj,(-1,1))
        #self.yjtile = tf.concat(1,[self.yjr,self.yjr,self.yjr,self.yjr])
        #self.yjax = tf.mul(self.yjtile,self.actions)
        #half = tf.constant(0.5)
        self.diff = tf.sub(self.yj, self.Q_pred)

        if self.params['clip_delta'] > 0 :
            # Huber-style loss: quadratic for |diff| <= clip_delta, linear beyond it
            self.quadratic_part = tf.minimum(tf.abs(self.diff), tf.constant(self.params['clip_delta']))
            self.linear_part = tf.sub(tf.abs(self.diff),self.quadratic_part)
            self.diff_square = 0.5 * tf.pow(self.quadratic_part,2) + self.params['clip_delta']*self.linear_part
        else:
            self.diff_square = tf.mul(tf.constant(0.5),tf.pow(self.diff, 2))

        if self.params['batch_accumulator'] == 'sum':
            self.cost = tf.reduce_sum(self.diff_square)
        else:
            self.cost = tf.reduce_mean(self.diff_square)

        self.global_step = tf.Variable(0, name='global_step', trainable=False)
        self.rmsprop = tf.train.RMSPropOptimizer(self.params['lr'],self.params['rms_decay'],0.0,self.params['rms_eps']).minimize(self.cost,global_step=self.global_step)
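The clip_delta branch above implements a Huber-style loss: quadratic for TD errors up to clip_delta, linear beyond it, which keeps the gradient bounded. A small worked example, assuming clip_delta = 1.0:

    # worked example of the clip_delta branch above, assuming clip_delta = 1.0
    clip_delta = 1.0
    diff = 2.5                                      # TD error yj - Q_pred for one sample

    quadratic_part = min(abs(diff), clip_delta)     # -> 1.0
    linear_part    = abs(diff) - quadratic_part     # -> 1.5
    loss = 0.5 * quadratic_part ** 2 + clip_delta * linear_part   # -> 2.0

    # compare with the unclipped squared loss 0.5 * diff**2 = 3.125:
    # beyond |diff| = clip_delta the loss grows only linearly, so the gradient is bounded

This is why the 'nature' settings in the main script set clip_delta = 1.0 while the 'nips' settings leave it at 0 (plain squared error).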
The emulator class wraps the ALE (Arcade Learning Environment) interface to the Atari games: it returns the current screen image and the reward, and provides functions to reset the game and start a new one. A short usage sketch follows the class.
import numpy as np
import copy
import sys
from ale_python_interface import ALEInterface
import cv2
import time
#import scipy.misc

class emulator:
    def __init__(self, rom_name, vis, windowname='preview'):
        self.ale = ALEInterface()
        self.max_frames_per_episode = self.ale.getInt("max_num_frames_per_episode");
        self.ale.setInt("random_seed",123)
        self.ale.setInt("frame_skip",4)
        self.ale.loadROM('roms/' + rom_name )
        self.legal_actions = self.ale.getMinimalActionSet()
        self.action_map = dict()
        self.windowname = windowname
        for i in range(len(self.legal_actions)):
            self.action_map[self.legal_actions[i]] = i
        # print(self.legal_actions)
        self.screen_width,self.screen_height = self.ale.getScreenDims()
        print("width/height: " +str(self.screen_width) + "/" + str(self.screen_height))
        self.vis = vis
        if vis:
            cv2.startWindowThread()
            cv2.namedWindow(self.windowname)

    def get_image(self):  # grab the current screen as an RGB image
        numpy_surface = np.zeros(self.screen_height*self.screen_width*3, dtype=np.uint8)
        self.ale.getScreenRGB(numpy_surface)
        image = np.reshape(numpy_surface, (self.screen_height, self.screen_width, 3))
        return image

    def newGame(self):
        self.ale.reset_game()    # start a new game
        return self.get_image()  # return the initial screen

    def next(self, action_indx):
        reward = self.ale.act(action_indx)
        nextstate = self.get_image()
        # scipy.misc.imsave('test.png',nextstate)
        if self.vis:
            cv2.imshow(self.windowname,nextstate)
        return nextstate, reward, self.ale.game_over()

if __name__ == "__main__":
    engine = emulator('breakout.bin',True)
    engine.next(0)
    time.sleep(5)
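As a quick check of the interface, the emulator can also be driven with random actions; a hedged sketch (it assumes roms/breakout.bin is present, as in the repository):

    # hedged sketch: play one episode of Breakout with random actions
    import numpy as np
    from emulator import emulator

    engine = emulator(rom_name='breakout.bin', vis=False)
    state = engine.newGame()        # 210x160x3 RGB frame for Breakout
    total_reward = 0
    terminal = False
    while not terminal:
        action = engine.legal_actions[np.random.randint(len(engine.legal_actions))]
        state, reward, terminal = engine.next(action)
        total_reward += reward
    print 'episode reward :', total_reward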
Paper: https://arxiv.org/pdf/1312.5602.pdf
Code: https://github.com/gliese581gg/DQN_tensorflow
I will post my reading notes on the paper later. If anything here is inaccurate or incomplete, corrections and comments are welcome.