深度有趣 | 21 从FlappyBird到DQN

栏目: 后端 · 发布时间: 7年前

内容简介:介绍强化学习(Reinforcement Learning,RL)的概念,并用DQN训练一个会玩FlappyBird的模型这个游戏很多人都玩过,很虐,以下是一个用pygame重现的FlappyBird,如果没有pygame则安装

介绍强化学习(Reinforcement Learning,RL)的概念,并用DQN训练一个会玩FlappyBird的模型

FlappyBird

这个游戏很多人都玩过,很虐,以下是一个用pygame重现的FlappyBird, github.com/sourabhv/Fl…

如果没有pygame则安装

pip install pygame
复制代码

运行 flappy.py 即可开始游戏,如果出现按键无法控制的情况,用 pythonw 运行代码即可

pythonw flappy.py
复制代码
深度有趣 | 21 从FlappyBird到DQN

原理

无监督学习没有标签,例如聚类;有监督学习有标签,例如分类;而强化学习介于两者之间,标签是通过不断尝试积累的

RL包括几个组成部分:

  • State(S):环境的状态,例如FlappyBird中的当前游戏界面,可以用一张图片来表示
  • Action(A):每个S下可采取的行动集合,例如在FlappyBird中可选择两个A,“跳一下”或者“什么都不做”
  • Reward(R):在某个S下执行某个A之后得到的回报,例如在FlappyBird中,可以是成功跳过一根水管(正回报),撞到水管或者掉到地上(负回报)

这样一来,游戏的进行过程,无非是从一个初始S开始,执行A、得到R、进入下一个S,如此往复,直到进入一个终止S

定义一个函数,用来计算游戏过程中回报的总和

以及从某个时刻开始之后的回报总和

但我们对未来每一步能获取的回报并不是完全肯定的,所以不妨乘上一个0到1之间的衰减系数

这样一来,可以得到相邻两步总回报之间的递推关系

DQN是强化学习中的一种常用算法,主要是引入了Q函数(Quality,价值函数),用于计算在某个S下执行某个A可以得到的最大总回报

有了Q函数之后,对于当前状态S,只需要计算每一个A对应的Q值,然后选择Q值最大的一个A,便是最优的行动策略(策略函数)

当Q函数收敛后,还可以得到Q函数的递推公式

可以使用神经网络实现Q函数并训练:

  • 定义神经网络的结构并随机初始化,输入为S,输出的个数和行动集合的大小一样
  • 每次以一定概率随机选择A,否则使用策略函数选择最优的A,即随机探索和有向策略相结合
  • 维护一个记忆模块,用于积累游戏过程中产生的数据
  • 预热期:不训练,主要是为了让记忆模块先积累一定数据
  • 探索期:逐渐降低随机概率,从随机探索过渡到有向策略,并且每次从记忆模块中取出一些数据训练模型
  • 训练期:固定随机概率,进一步训练模型,使得Q函数进一步收敛

关于强化学习和DQN的原理介绍,可以参考以下文章, ai.intel.com/demystifyin…

实现

基于以下项目进行修改, github.com/yenchenlin/…

game 中的代码对之前的 flappy.py 进行了简化和修改,去掉了背景图并固定角色和水管颜色,游戏会自动开始,挂掉之后也会自动继续,主要是便于模型自动进行和采集数据

加载库

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import random
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb
from collections import deque
复制代码

定义一些参数

ACTIONS = 2
GAMMA = 0.99
OBSERVE = 10000
EXPLORE = 3000000
INITIAL_EPSILON = 0.1
FINAL_EPSILON = 0.0001
REPLAY_MEMORY = 50000
BATCH = 32
IMAGE_SIZE = 80
复制代码

定义一些网络输入和辅助函数,每一个S由连续的四帧游戏截图组成

S = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 4], name='S')
A = tf.placeholder(dtype=tf.float32, shape=[None, ACTIONS], name='A')
Y = tf.placeholder(dtype=tf.float32, shape=[None], name='Y')
k_initializer = tf.truncated_normal_initializer(0, 0.01)
b_initializer = tf.constant_initializer(0.01)

def conv2d(inputs, kernel_size, filters, strides):
    return tf.layers.conv2d(inputs, kernel_size=kernel_size, filters=filters, strides=strides, padding='same', kernel_initializer=k_initializer, bias_initializer=b_initializer)

def max_pool(inputs):
    return tf.layers.max_pooling2d(inputs, pool_size=2, strides=2, padding='same')

def relu(inputs):
    return tf.nn.relu(inputs)
复制代码

定义网络结构,典型的卷积、池化、全连接层结构

h0 = max_pool(relu(conv2d(S, 8, 32, 4)))
h0 = relu(conv2d(h0, 4, 64, 2))
h0 = relu(conv2d(h0, 3, 64, 1))
h0 = tf.contrib.layers.flatten(h0)
h0 = tf.layers.dense(h0, units=512, activation=tf.nn.relu, bias_initializer=b_initializer)

Q = tf.layers.dense(h0, units=ACTIONS, bias_initializer=b_initializer, name='Q')
Q_ = tf.reduce_sum(tf.multiply(Q, A), axis=1)
loss = tf.losses.mean_squared_error(Y, Q_)
optimizer = tf.train.AdamOptimizer(1e-6).minimize(loss)
复制代码

用一个队列实现记忆模块,开始游戏,对于初始状态选择什么都不做

game_state = fb.GameState()
D = deque()

do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
S0 = np.stack((img, img, img, img), axis=2)
复制代码

继续进行游戏并训练模型

sess = tf.Session()
sess.run(tf.global_variables_initializer())

t = 0
success = 0
saver = tf.train.Saver()
epsilon = INITIAL_EPSILON
while True:
    if epsilon > FINAL_EPSILON and t > OBSERVE:
        epsilon = INITIAL_EPSILON - (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE * (t - OBSERVE)

    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS)
    if np.random.random() <= epsilon:
        action_index = np.random.randint(ACTIONS)
    else:
        action_index = np.argmax(Qv) 
    Av[action_index] = 1

    img, reward, terminal = game_state.frame_step(Av)
    if reward == 1:
        success += 1
    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S1 = np.append(S0[:, :, 1:], img, axis=2)

    D.append((S0, Av, reward, S1, terminal))
    if len(D) > REPLAY_MEMORY:
        D.popleft()

    if t > OBSERVE:
        minibatch = random.sample(D, BATCH)
        S_batch = [d[0] for d in minibatch]
        A_batch = [d[1] for d in minibatch]
        R_batch = [d[2] for d in minibatch]
        S_batch_next = [d[3] for d in minibatch]
        T_batch = [d[4] for d in minibatch]

        Y_batch = []
        Q_batch_next = sess.run(Q, feed_dict={S: S_batch_next})
        for i in range(BATCH):
            if T_batch[i]:
                Y_batch.append(R_batch[i])
            else:
                Y_batch.append(R_batch[i] + GAMMA * np.max(Q_batch_next[i]))

        sess.run(optimizer, feed_dict={S: S_batch, A: A_batch, Y: Y_batch})

    S0 = S1
    t += 1

    if t > OBSERVE and t % 10000 == 0:
        saver.save(sess, './flappy_bird_dqn', global_step=t)

    if t <= OBSERVE:
        state = 'observe'
    elif t <= OBSERVE + EXPLORE:
        state = 'explore'
    else:
        state = 'train'
    print('Current Step %d Success %d State %s Epsilon %.6f Action %d Reward %f Q_MAX %f' % (t, success, state, epsilon, action_index, reward, np.max(Qv)))
复制代码

运行 dqn_flappy.py 即可从零开始训练模型,一开始角色各种乱跳,一根水管都跳不过去,但随着训练的进行,角色会通过学习获得越来越稳定的表现

深度有趣 | 21 从FlappyBird到DQN

也可以直接使用以下代码运行训练好的模型

# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb

ACTIONS = 2
IMAGE_SIZE = 80

sess = tf.Session()
sess.run(tf.global_variables_initializer())

saver = tf.train.import_meta_graph('./flappy_bird_dqn-8500000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()

S = graph.get_tensor_by_name('S:0')
Q = graph.get_tensor_by_name('Q/BiasAdd:0')

game_state = fb.GameState()

do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
S0 = np.stack((img, img, img, img), axis=2)

while True:
    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS) 
    Av[np.argmax(Qv)] = 1

    img, reward, terminal = game_state.frame_step(Av)
    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S0 = np.append(S0[:, :, 1:], img, axis=2)
复制代码

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Open Data Structures

Open Data Structures

Pat Morin / AU Press / 2013-6 / USD 29.66

Offered as an introduction to the field of data structures and algorithms, Open Data Structures covers the implementation and analysis of data structures for sequences (lists), queues, priority queues......一起来看看 《Open Data Structures》 这本书的介绍吧!

图片转BASE64编码
图片转BASE64编码

在线图片转Base64编码工具

随机密码生成器
随机密码生成器

多种字符组合密码

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具