内容简介:很简单的一个小游戏,名字叫"FizzBuzz",游戏规则如下:从1开始数数,当遇到3的倍数的时候,说fizz,当遇到5的倍数的时候,说buzz,当遇到15的倍数的时候,就说fizzbuzz,其他情况则正常数数可以想到,在这个游戏中,总共只有四类,
Game rules
很简单的一个小游戏,名字叫"FizzBuzz",游戏规则如下:
从1开始数数,当遇到3的倍数的时候,说fizz,当遇到5的倍数的时候,说buzz,当遇到15的倍数的时候,就说fizzbuzz,其他情况则正常数数
Game conversion to classification problem
可以想到,在这个游戏中,总共只有四类, fizzbuzz , buzz , fizz , number
所以我们先定义一个函数,这个函数的作用是将输入的数字,离散为这四类中的某一类
def fizz_buzz_encode(i):
if i % 15 == 0:
return 3
elif i % 5 == 0:
return 2
elif i % 3 == 0:
return 1
else:
return 0
有了 encode 函数,还需要一个decode函数,参数是个数字,以及这个数字的类别,返回是这个数字应该喊什么,比方说 decode(15, 3) ,返回的就应该是 fizzbuzz ,再比如 decode(7, 0) ,就应该返回 7
def fizz_buzz_decode(i, label):
return [str(i), 'fizz', 'buzz', 'fizzbuzz'][label]
写个测试函数测试一下
def helper(i):
print(fizz_buzz_decode(i, fizz_buzz_encode(i)))
for i in range(1, 16):
helper(i)
输出: 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz
Generate training set
import numpy as np import torch from torch import nn
对于一个神经网络,我们的输入是一个数字,我们要他返回的是这个数字属于哪个类别(知道哪个类别之后调用decode函数就行了)
但其实输入如果单纯是个十进制数字特征不够明显,我们可以尝试把十进制转换为二进制,将 01 编码作为输入
NUM_DIGITS = 10
def binary_encode(i, NUM_DIGITS): # 将一个十进制数转换为二进制
return np.array([i >> d & 1 for d in range(NUM_DIGITS)][::-1])
#print(binary_encode(15, NUM_DIGITS))
然后生成训练集 X 和 y ,我把$[101,1024]$之间的所有整数转为二进制作为 X_train ,掉用 encode 函数生成的标签作为 y_train
X_train = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)]) y_train = torch.LongTensor([fizz_buzz_encode(i) for i in range(101, 2 ** NUM_DIGITS)])
Construct neural network
首先设计网络结构
然后利用PyTorch定义模型
NUM_HIDDEN = 100 # 隐藏层100个神经元
model = nn.Sequential( # 网络结构:Input -> Hidden_Layer1 -> OutPut
nn.Linear(NUM_DIGITS, NUM_HIDDEN, bias = False), # z = w1*x, 其中w1.shape=(10, 100), x.shape=(923, 10)
nn.ReLU(), # z = relu(z), 其中z.shape=(923, 100)
nn.Linear(NUM_HIDDEN, 4, bias = False) # y_pred = z*w2, 其中z.shape(923, 100), w2.shape=(100, 4)
# 输出的是个923*4的矩阵
)
定义Loss_Function和梯度下降的方法
loss_fn = nn.CrossEntropyLoss() # 专为分类问题设计的Loss optimizer = torch.optim.SGD(model.parameters(), lr = 0.1) # lr is learning_rate
开始训练模型
BATCH_SIZE = 128
for epoch in range(10000):
for start in range(0, len(X_train), BATCH_SIZE):
end = start + BATCH_SIZE
batchX = X_train[start:end]
batchY = y_train[start:end]
y_pred = model(batchX)
loss = loss_fn(y_pred, batchY)
print('Epoch', epoch, loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
如果关于 BATCH_SIZE 和 EPOCH 不清楚作用,可以看这篇 文章
训练最终结果如下图,我们说,如果一个人通过瞎猜玩这个游戏,那他每次的正确率只有$\frac{1}{4}$,但是从训练结果来看,很明显我们的网络的准确度比瞎猜要高很多
训练完以后生成测试数据 X_test
X_test = torch.Tensor([binary_encode(i, NUM_DIGITS) for i in range(1, 100)])
然后用训练好的模型对测试数据进行预测,生成 y_test ,假设测试数据有100个,那 y_test 的大小就是 (100, 4) ,4列分别对应每个类型的概率,我们取出最大概率对应的下标值,带入 decode 函数,就能看到他在测试数据上的表现了
with torch.no_grad():
y_test = model(X_test)
#y_test.max(1)[1]
predicts = zip(range(0, 101), list(y_test.max(1)[1].data.tolist()))
print([fizz_buzz_decode(i, x) for i, x in predicts])
输出: ['0', '1', 'fizz', '3', 'buzz', 'fizz', '6', '7', 'fizz', 'buzz', '10', 'fizz', '12', '13', 'fizzbuzz', '15', '16', 'fizz', '18', 'buzz', '20', '21', '22', 'fizz', 'buzz', '25', 'fizz', '27', '28', 'fizzbuzz', '30', 'fizz', 'fizz', '33', 'buzz', 'fizz', '36', '37', 'fizz', 'buzz', '40', 'fizz', '42', '43', 'fizzbuzz', '45', '46', 'fizz', '48', 'buzz', 'fizz', '51', '52', 'fizz', 'fizz', '55', 'fizz', '57', '58', 'fizzbuzz', '60', '61', 'fizz', '63', 'buzz', 'fizz', '66', '67', 'fizz', 'buzz', '70', 'fizz', '72', '73', '74', '75', '76', 'fizz', '78', 'buzz', 'fizz', '81', '82', 'fizz', 'buzz', '85', 'fizz', '87', '88', 'fizzbuzz', '90', '91', 'fizz', '93', 'buzz', 'fizz', '96', '97', 'fizz']
最终测试的效果并不是特别好,但是从一些数据当中可以看到,我们这个网络实际还是找到了这个游戏的部分规律。单从 fizzbuzz 的结果来看,虽然他并没有准确的达到每次都在15的倍数输出,但是它隐约知道在15的倍数附近要输出
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网
本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们。
Programming Amazon Web Services
James Murty / O'Reilly Media / 2008-3-25 / USD 49.99
Building on the success of its storefront and fulfillment services, Amazon now allows businesses to "rent" computing power, data storage and bandwidth on its vast network platform. This book demonstra......一起来看看 《Programming Amazon Web Services》 这本书的介绍吧!