HTML5技术

程序员带你一步步分析AI如何玩FlappyBird - yhthu(6)

字号+ 作者:H5之家 来源:H5之家 2017-04-13 09:03 我要评论( )

在该Demo训练时,也采用了Saver进行参数保存。 # saving and loading networkssaver = tf.train.Saver()checkpoint = tf.train.get_checkpoint_state("saved_networks")if checkpoint and checkpoint.model_checkpo

在该Demo训练时,也采用了Saver进行参数保存。

# saving and loading networks saver = tf.train.Saver() checkpoint = tf.train.get_checkpoint_state("saved_networks") if checkpoint and checkpoint.model_checkpoint_path: saver.restore(sess, checkpoint.model_checkpoint_path) print("Successfully loaded:", checkpoint.model_checkpoint_path) else: print("Could not find old network weights")

首先加载CheckPointState文件,然后采用saver.restore对已存在参数进行恢复。
在该Demo中,每隔10000步,就对参数进行保存:

# save progress every 10000 iterations if t % 10000 == 0: saver.save(sess, 'saved_networks/' + GAME + '-dqn', global_step=t) iv. 实验及样本存储

首先,根据ε 概率选择一个Action。

# choose an action epsilon greedily readout_t = readout.eval(feed_dict={s: [s_t]})[0] a_t = np.zeros([ACTIONS]) action_index t % FRAME_PER_ACTION == 0: if random.random() <= epsilon: print("----------Random Action----------") action_index = random.randrange(ACTIONS) a_t[random.randrange(ACTIONS)] : action_index = np.argmax(readout_t) a_t[action_index] : a_t[

这里,readout_t是训练数据为之前提到的四通道图像的模型输出。a_t是根据ε 概率选择的Action。

其次,执行选择的动作,并保存返回的状态、得分。

# run the selected action and observe next state and reward x_t1_colored, r_t, terminal = game_state.frame_step(a_t) x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY) ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY) x_t1 = np.reshape(x_t1, (80, 80, 1)) # s_t1 = np.append(x_t1, s_t[:,:,1:], axis = 2) s_t1 = np.append(x_t1, s_t[:, :, :3], axis=2) # store the transition in D D.append((s_t, a_t, r_t, s_t1, terminal))

经验池D保存的是一个马尔科夫序列。(s_t, a_t, r_t, s_t1, terminal)分别表示t时的状态s_t,执行的动作a_t,得到的反馈r_t,以及得到的下一步的状态s_t1和游戏是否结束的标志terminal。

在下一训练过程中,更新当前状态及步数:

# update the old values s_t = s_t1 t += 1

重复上述过程,实现反复实验及样本存储。

v. 通过梯度下降进行模型训练

在实验一段时间后,经验池D中已经保存了一些样本数据后,就可以从这些样本数据中随机抽样,进行模型训练了。这里设置样本数为OBSERVE = 100000.。随机抽样的样本数为BATCH = 32。

if t > OBSERVE: # sample a minibatch to train on minibatch = random.sample(D, BATCH) # get the batch variables s_j_batch = [d[0] for d in minibatch] a_batch = [d[1] for d in minibatch] r_batch = [d[2] for d in minibatch] s_j1_batch = [d[3] for d in minibatch] y_batch = [] readout_j1_batch = readout.eval(feed_dict={s: s_j1_batch}) for i in range(0, len(minibatch)): terminal = minibatch[i][4] # if terminal, only equals reward if terminal: y_batch.append(r_batch[i]) else: y_batch.append(r_batch[i] + GAMMA * np.max(readout_j1_batch[i])) # perform gradient step train_step.run(feed_dict={ y: y_batch, a: a_batch, s: s_j_batch} )

s_j_batch、a_batch、r_batch、s_j1_batch是从经验池D中提取到的马尔科夫序列(Java童鞋羡慕Python的列表推导式啊),y_batch为标签值,若游戏结束,则不存在下一步中状态对应的Q值(回忆Q值更新过程),直接添加r_batch,若未结束,则用折合因子(0.99)和下一步中状态的最大Q值的乘积,添加至y_batch。
最后,执行梯度下降训练,train_step的入参是s_j_batch、a_batch和y_batch。差不多经过2000000步(在本机上大概10个小时)训练之后,就能达到本文开头动图中的效果啦。

以上。

 

1.本站遵循行业规范,任何转载的稿件都会明确标注作者和来源;2.本站的原创文章,请转载时务必注明文章作者和来源,不尊重原创的行为我们将追究责任;3.作者投稿可能会经我们编辑修改或补充。

相关文章
  • net.sz.framework 框架 登录服务器架构 单服2 万 TPS(QPS) - 失足程序员

    net.sz.framework 框架 登录服务器架构 单服2 万 TPS(QPS) - 失足

    2017-04-13 11:05

  • [.NET] 一步步打造一个简单的 MVC 电商网站 - BooksStore(三) - 反骨仔(二五仔)

    [.NET] 一步步打造一个简单的 MVC 电商网站 - BooksStore(三) - 反

    2017-04-02 11:00

  • net.sz.framework 框架 轻松搭建服务---让你更专注逻辑功能---初探 - 失足程序员

    net.sz.framework 框架 轻松搭建服务---让你更专注逻辑功能---初探 -

    2017-04-02 10:11

  • 微服务--webapi实现,脱离iis,脱离tomcat - 失足程序员

    微服务--webapi实现,脱离iis,脱离tomcat - 失足程序员

    2017-03-30 11:00

网友点评
6