GAN入门实践(二)--Pytorch实现

栏目: Python · 发布时间: 6年前

内容简介:最近在网上看到一个据说是 Alex Smola 写的关于这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

最近在网上看到一个据说是 Alex Smola 写的关于 生成对抗网络 (Generative Adversarial Network, GAN)的入门教程,目的是从实践的角度讲解 GAN 的基本思想和实现过程。这个教程没有采用最普遍的用 GAN 生成图像作为例子,而是用生成一个特定的高斯分布作为切入点来讲解 GAN 的工作原理和效果。对 GAN 感兴趣的朋友不妨去看看这个简短的教程,直观易懂而不失有趣,链接如下:

https://github.com/zackchase/mxnet-the-straight-dope/blob/master/P10-C01-gan-intro.ipynb

这个教程的实现是基于 mxnet 的,而我最近比较常用的是 tensorflow 和 pytorch,因此花了点小时间把其中的内容用这两个框架实现了一遍。一方面是加深自己对 GAN 和 tensorflow / pytorch 的理解,同时也给感兴趣的朋友多一些参考。由于教程已经非常通俗易懂了,本文不会再重复其中的讲解,而是记录一些实现过程中个人觉得要注意的地方。

Pytorch实现的全代码在 这里 。 其中还参考了 CycleGAN 的Pytorch实现和 DCGAN 的Pytorch版本实现。

Tensorflow篇请看这里。

模型定义

class GAN(object):
    def __init__(self):

        # define the network
        self.netG = nn.Linear(2, 2)
        self.netD = nn.Sequential(
                        nn.Linear(2, 5), nn.Tanh(),
                        nn.Linear(5, 3), nn.Tanh(),
                        nn.Linear(3, 2))
        
        # define loss function
        self.criterion = nn.CrossEntropyLoss()

Pytorch 里面的函数调用比 Tensorflow 清晰很多,对初学者来说更容易接受。这里我们直接用 nn.Linear() 定义简单的全连接网络,并且用 nn.Sequential() 将多层网络拼接起来。

定义前向传播

def forward(self, x, z):
        self.z = Variable((torch.from_numpy(z)).float())
        self.fake_x = self.netG(self.z)
        self.real_x = Variable((torch.from_numpy(x)).float())
        
        self.label = Variable(torch.LongTensor(x.shape[0]).fill_(1), requires_grad=False)    # define null label with specific size

Pytorch里对前向和后向的传播都可以方便地进行细致的定义。这里要注意的是 网络的输入输出是一个变量 (Variable),而定义一个变量需要 torch.Tensor。上面的代码提供了一种从 numpy 矩阵转化为 Tensor 的方法,并且将类型转为 FloatTensor(因为 nn.Linear() 需要)。另外,我们还定义了一个 label 变量,为了方便之后计算损失函数,因为损失函数的输入也是需要是变量。由于 label 变量是 不需要计算梯度 的,我们将 requires_grad 设为 False,并且为了符合 CrossEntropyLoss() 的要求使用 LongTensor 类型。

定义后向传播

def backward_D(self):
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        pred_fake = self.netD(self.fake_x.detach())    # stop backprop to the generator by detaching fake_B
        self.loss_D_fake = self.criterion(pred_fake, self.label*0)

        pred_real = self.netD(self.real_x)
        self.loss_D_real = self.criterion(pred_real, self.label*1)
        
        self.loss_D = self.loss_D_fake + self.loss_D_real
        
        self.loss_D.backward()
    
    def backward_G(self):
        # (2) Update G network: maximize log(D(G(z)))
        pred_fake = self.netD(self.fake_x)
        self.loss_G = self.criterion(pred_fake, self.label*1)    # fool the discriminator
        
        self.loss_G.backward()

这两个函数定义了损失函数是如何计算以及梯度是如何反向传播的。注意的是判别器的反向传播函数中用到了 detach(),这是为了避免梯度反向传播到生成器中,因为这里我们只更新判别器的梯度。不然的话,调用 loss_D.backward() 会导致生成器中变量的 grad 发生变化。定义好各种损失函数后,调用 backward() 函数就可以实现自动求各个变量的梯度,十分方便。

定义优化

d_optim = torch.optim.Adam(gan.netD.parameters(), lr=0.05)
g_optim = torch.optim.Adam(gan.netG.parameters(), lr=0.01)

虽然上面定义了梯度是怎么计算的,但具体怎么使用这些梯度进行权值的更新则是通过torch.optim 中的优化器来定义。其中第一个输入参数为该优化器要优化的变量。

模型初始化

# initialize weights in the model
def init_weights(m):
    if type(m) == nn.Linear:
        m.weight.data.normal_(0.0, 0.02)

gan.netD.apply(init_weights)
gan.netG.apply(init_weights)

Pytorch中将 nn 中的 网络层定义与初始化分开 了,通过上面的方法可以对网络的变量进行初始化。

模型训练

......
	
        # forward
        gan.forward(x_batch, z_batch)
            
        # update D network
        d_optim.zero_grad()
        gan.backward_D()
        d_optim.step()
        
        # update G network
        g_optim.zero_grad()
        gan.backward_G()
        g_optim.step()

这里唯一需要注意的是,在计算反向传播前,要先对变量的梯度进行清零,使用 zero_grad() 函数即可。然后使用 step() 执行变量的更新。

模型测试

def test(self, x, z):
        z = Variable((torch.from_numpy(z)).float(), volatile=True)
        fake_x = self.netG(z)
        real_x = Variable((torch.from_numpy(x)).float(), volatile=True)
        
        pred_fake = self.netD(fake_x)
        pred_real = self.netD(real_x)
        
        return fake_x.data.numpy(), pred_real.data.numpy(), pred_fake.data.numpy()

其实前面 GAN 类中还定义了一个用于测试的函数。这里需要注意的是,因为变量都不需要求梯度(不需要训练)了,所以我们把 volatile 设为 True。上面代码还给出了一个例子如何将 Variable 转化为 numpy 矩阵。


以上所述就是小编给大家介绍的《GAN入门实践(二)--Pytorch实现》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 码农网 的支持!

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Web ReDesign 2.0

Web ReDesign 2.0

Kelly Goto、Emily Cotler / Peachpit Press / 2004-12-10 / USD 45.00

If anything, this volume's premise--that the business of Web design is one of constant change-has only proven truer over time. So much so, in fact, that the 12-month design cycles cited in the last ed......一起来看看 《Web ReDesign 2.0》 这本书的介绍吧!

XML、JSON 在线转换
XML、JSON 在线转换

在线XML、JSON转换工具

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

RGB CMYK 转换工具
RGB CMYK 转换工具

RGB CMYK 互转工具