HTML5技术

程序员带你一步步分析AI如何玩FlappyBird - yhthu(4)

字号+ 作者:H5之家 来源:H5之家 2017-04-13 09:03 我要评论( )

上述表格演示了具有4种状态/4种行为的系统,然而在实际应用中,以本文讲到的Flappy Bird游戏为例,界面为80*80个像素点,每个像素点的色值有256种可能。那么实际的状态总数为256的80*80次方,这是一个很大的数字,

上述表格演示了具有4种状态/4种行为的系统,然而在实际应用中,以本文讲到的Flappy Bird游戏为例,界面为80*80个像素点,每个像素点的色值有256种可能。那么实际的状态总数为256的80*80次方,这是一个很大的数字,直接导致无法通过表格的思路进行计算。

因此,为了实现降维,这里引入了一个价值函数近似的方法,通过一个函数表近似表达价值函数:

价值函数近似

其中,ω 与 b 分别为参数。看到这里,终于可以联系到前面提到的神经网络了,上面的表达式不就是神经元的函数吗?

  • Q-network
  • 下面这张图来自论文《Human-level Control through Deep Reinforcement Learning》,其中详细介绍了上述将Q值神经网络化的过程。(感兴趣的可以点之前的链接了解原文~)

    Q-network

    以本文为例,输入是经过处理的4个连续的80x80图像,然后经过三个卷积层,一个池化层,两个全连接层,最后输出包含每一个动作Q值的向量。

    现在已经将Q-learning神经网络化为Q-network了,接下来的问题是如何训练这个神经网络。神经网络训练的过程其实就是一个最优化方程求解的过程,定义系统的损失函数,然后让损失函数最小化的过程。

    训练过程依赖于上述提到的DQN算法,以目标Q值作为标签,因此,损失函数可以定义为:

    DQN损失函数(来源于论文)

    上面公式是s',a'即下一个状态和动作。确定了损失函数,确定了获取样本的方式,DQN的整个算法也就成型了!

    DQN算法(来源于论文)

    值得注意的是这里的D—Experience Replay,也就是经验池,就是如何存储样本及采样的问题。

    由于玩Flappy Bird游戏,采集的样本是一个时间序列,样本之间具有连续性,如果每次得到样本就更新Q值,受样本分布影响,效果会不好。因此,一个很直接的想法就是把样本先存起来,然后随机采样如何?这就是Experience Replay的思想。

    算法实现上,先反复实验,并且将实验数据存储在D中;存储到一定程度,就从中随机抽取数据,对损失函数进行梯度下降。

    四、代码:TensorFlow实现

    终于到了看代码的时候。首先申明下,当笔者从Deep Mind的论文入手,试图用TensorFlow实现对Flappy Bird游戏进行实现时,发现github已有大神完成demo。思路相同,所以直接以公开代码为例进行分析说明了。

    如有源码需要,请移步github:Using Deep Q-Network to Learn How To Play Flappy Bird。

    代码从结构上来讲,主要分为以下几部分:

    1. GameState游戏类及frame_step方法

    通过Python实现游戏必然要用pygame库,其包含时钟、基本的显示控制、各种游戏控件、触发事件等,对此有兴趣的,可以详细了解pygame。frame_step方法的入参为shape为 (2,) 的ndarray,值域: [1,0]:什么都不做; [0,1]:提升Bird。来看下代码实现:

    if input_actions[1] == 1: if self.playery > -2 * PLAYER_HEIGHT: self.playerVelY = self.playerFlapAcc self.playerFlapped = True # SOUNDS['wing'].play()

    后续操作包括检查得分、设置界面、检查是否碰撞等,这里不再详细展开。
    frame_step方法的返回值是:

    return image_data, reward, terminal

    分别表示界面图像数据,得分以及是否结束游戏。对应前面强化学习模型,界面图像数据表示环境状态 s,得分表示环境给予学习系统的反馈 r。

    2. CNN模型构建

    该Demo中包含三个卷积层,一个池化层,两个全连接层,最后输出包含每一个动作Q值的向量。因此,首先定义权重、偏置、卷积和池化函数:

    # 权重 def weight_variable(shape): initial = tf.truncated_normal(shape, stddev=0.01) return tf.Variable(initial) # 偏置 def bias_variable(shape): initial = tf.constant(0.01, shape=shape) return tf.Variable(initial) # 卷积 def conv2d(x, W, stride): return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME") # 池化 def max_pool_2x2(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

    然后,通过上述函数构建卷积神经网络模型(对代码中参数不解的,可直接往前翻,看上面那张手画的图)。

    def createNetwork(): # 第一层卷积 W_conv1 = weight_variable([8, 8, 4, 32]) b_conv1 = bias_variable([32]) # 第二层卷积 W_conv2 = weight_variable([4, 4, 32, 64]) b_conv2 = bias_variable([64]) # 第三层卷积 W_conv3 = weight_variable([3, 3, 64, 64]) b_conv3 = bias_variable([64]) # 第一层全连接 W_fc1 = weight_variable([1600, 512]) b_fc1 = bias_variable([512]) # 第二层全连接 W_fc2 = weight_variable([512, ACTIONS]) b_fc2 = bias_variable([ACTIONS]) # 输入层 s = tf.placeholder("float", [None, 80, 80, 4]) # 第一层隐藏层+池化层 h_conv1 = tf.nn.relu(conv2d(s, W_conv1, 4) + b_conv1) h_pool1 = max_pool_2x2(h_conv1) # 第二层隐藏层(这里只用了一层池化层) h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2, 2) + b_conv2) # h_pool2 = max_pool_2x2(h_conv2) # 第三层隐藏层 h_conv3 = tf.nn.relu(conv2d(h_conv2, W_conv3, 1) + b_conv3) h_conv3_flat = tf.reshape(h_conv3, [-1, 1600]) # 全连接层 h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1) # 输出层 # readout layer readout = tf.matmul(h_fc1, W_fc2) + b_fc2 return s, readout, h_fc1

    3. OpenCV-Python图像预处理方法

    在Ubuntu中安装opencv的步骤比较麻烦,当时也踩了不少坑,各种Google解决。建议安装opencv3。

    这部分主要对frame_step方法返回的数据进行了灰度化和二值化,也就是最基本的图像预处理方法。

     

    1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,请转载时务必注明文章作者和来源,不尊重原创的行为我们将追究责任;3.作者投稿可能会经我们编辑修改或补充。

    相关文章
    • net.sz.framework 框架 登录服务器架构 单服2 万 TPS(QPS) - 失足程序员

      net.sz.framework 框架 登录服务器架构 单服2 万 TPS(QPS) - 失足

      2017-04-13 11:05

    • [.NET] 一步步打造一个简单的 MVC 电商网站 - BooksStore(三) - 反骨仔(二五仔)

      [.NET] 一步步打造一个简单的 MVC 电商网站 - BooksStore(三) - 反

      2017-04-02 11:00

    • net.sz.framework 框架 轻松搭建服务---让你更专注逻辑功能---初探 - 失足程序员

      net.sz.framework 框架 轻松搭建服务---让你更专注逻辑功能---初探 -

      2017-04-02 10:11

    • 微服务--webapi实现,脱离iis,脱离tomcat - 失足程序员

      微服务--webapi实现,脱离iis,脱离tomcat - 失足程序员

      2017-03-30 11:00

    网友点评