用DQN玩flappy bird(TensorFlow学习框架)

来源:互联网 发布:川大生活服务 网络 编辑:程序博客网 时间:2024/05/18 18:18

一.引言

增强学习挺有意思的。今天楼主就用TensorFlow实现最基本的DQN,让神经网络玩flappy bird。

增强学习理论楼主目前也在学习之中,不太了解。知乎上有一个叫“智能单元”的专栏写了一个叫“DQN从入门到放弃”的系列,挺不错的,大家可以参考一下:https://zhuanlan.zhihu.com/intelligentunit?topic=%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0%20(Reinforcement%20Learning)



二.程序介绍

直接进入主题了。


代码:https://github.com/Xiong-Da/flyRL

工程下有3个python源文件,分别是main.py、simulator.py、agent.py。分别是干什么的我想大家看名字就知道了,main负责界面交互,simulator实现flappy bird游戏机制,agent实现DQN。

启动main.py可以看到如下所示的界面:


play菜单是手动玩游戏,train菜单是从头开始训练模型,trainMoudle是已经保存的模型开始训练,save菜单表示保存正在训练的模型(同时结束训练),select表示选择一个模型用来演示DQN玩游戏,show是开始让DQN玩游戏。

游戏界面如下:


由于是手动实现flappy bird,所以做了很大的简化。水管直接退化为线段,鸟也是方形,游戏难度也相对小了很多。

代码目录下有很多楼主已经训练好的模型,大家可以直接载入后看看效果。只能说差强人意,勉强能玩,有些情况还是不能很好应对。


三.代码介绍

代码上写的很清楚,虽然没有注释。就挑agent.py里面的几点说一下。

首先是DQN模型的定义:

#[speed,birdPosY,tube-1X,tube-1height,tube-1gap,tube0...,tube1...]input_state=tf.placeholder(tf.float32,[None,1+1+3*3])input_value=tf.placeholder(tf.float32,[None])weight1 = tf.Variable(tf.truncated_normal([1+1+3*3,128], stddev=0.1))bias1 = tf.Variable(tf.zeros([128]))output1=tf.nn.relu(tf.matmul(input_state,weight1)+bias1)weight2 = tf.Variable(tf.truncated_normal([128,128], stddev=0.1))bias2 = tf.Variable(tf.zeros([128]))output2=tf.nn.relu(tf.matmul(output1,weight2)+bias2)weight3 = tf.Variable(tf.truncated_normal([128,2], stddev=0.1))bias3 = tf.Variable(tf.zeros([2]))output3=tf.matmul(output2,weight3)+bias3value_array=output3

楼主这里用的全连接网络。

输入层输入11维的特征向量,为鸟的垂直速度,垂直方向坐标,以及最近3个水管的数据。中间有两个隐藏层都是128个节点,输出层有2个节点,分别输出在当前状态下飞或不飞对应的价值。

刚开始楼主也没有用这么复杂的模型,都是后期调参数慢慢调复杂的,有兴趣的同学可以尝试降低模型复杂度多训练几次试试。


还有一点就是每一帧采取随机动作的概率,这里楼主采用了动态设置的办法,代码如下:

    greedyFactor=(200-liveCount/playCount)*0.001    if greedyFactor<0.05:        greedyFactor=0.05

这里主要是靠最近一次运行的平均存货帧数来设置,最小概率0.05。这样做的原因是希望在模型玩的不太好是更多的采取随机策略。


最后一点就是训练数据的收集了,看代码:

recordLength=40if len(_actions)>=recordLength and random.uniform(0,1)<=0.5:actions=actions+_actions[-recordLength:]rawNewStates=rawNewStates+_rawNewStates[-recordLength:]rawOldStates=rawOldStates+_rawOldStates[-recordLength:]
else:actions = actions + _actionsrawNewStates = rawNewStates + _rawNewStatesrawOldStates = rawOldStates + _rawOldStates

可以看出,楼主采集数据时有0.5的机会只采用后40帧数据,这是因为楼主希望模型更多的更新不能应对的情况下的策略。改为百分之百信不信?楼主试过,貌似会影响拟合。


以上就是楼主拟合模型时脑洞出来的结果,是不是真的有用不得而知。楼主明显感觉DNQ拟合比较麻烦一些,写代码用了不到一天,调模型拟合模型用了差不多一个星期,实际效果还并不完美,有几个情况楼主训练的模型还不能很好应对,特别是水管缝隙比较小的情况。


四.总结

没什么好总结的。这玩意儿就是要多动手调教调教。

至于理论的学习楼主应该还会再深入一些,也就是下一篇博客应该还是强化学习,但如果还是最基础的DQN就说不过去了。