巧断梯度:单个loss实现GAN模型(附开源代码)

栏目: 数据库 · 发布时间: 5年前

内容简介:我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的,

我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。

现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。

如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的, 这也就是 GAN 的同步训练。

注:本文不是介绍新的 GAN,而是介绍 GAN 的新写法,这只是一道编程题,不是一道算法题。

如果在TF中

如果是在 TensorFlow 中,实现同步训练并不困难,因为我们定义好了判别器和生成器的训练算子了(假设为 D_solver 和 G_solver ),那么直接执行:

sess.run([D_solver, G_solver], feed_dict={x_in: x_train, z_in: z_train})

就行了。这建立在我们能分别获取判别器和生成器的参数、能直接操作 sess.run 的基础上。

更通用的方法

但是如果是 Keras 呢?Keras 中已经把流程封装好了,一般来说我们没法去操作得如此精细。

所以,下面我们介绍一个通用的技巧, 只需要定义单一一个 loss,然后扔给优化器,就能够实现 GAN 的训练。 同时,从这个技巧中,我们还可以学习到如何更加灵活地操作 loss 来控制梯度。

判别器的优化

我们以 GAN 的 hinge loss 为例子,它的形式是:

巧断梯度:单个loss实现GAN模型(附开源代码)

注意 巧断梯度:单个loss实现GAN模型(附开源代码) 意味着要固定 G,因为 G 本身也是有优化参数的,不固定的话就应该是 巧断梯度:单个loss实现GAN模型(附开源代码)

为了固定G,除了“把 G 的参数从优化器中去掉”这个方法之外,我们也可以利用 stop_gradient去手动固定:

巧断梯度:单个loss实现GAN模型(附开源代码)

这里:

巧断梯度:单个loss实现GAN模型(附开源代码)

这样一来,在式 (2) 中,我们虽然同时放开了 D,G 的权重,但是不断地优化式 (2),会变的只有 D,而 G 是不会变的,因为我们用的是基于梯度下降的优化器,而 G 的梯度已经被停止了,换句话说,我们可以理解为 G 的梯度被强行设置为 0,所以它的更新量一直都是 0。

生成器的优化

现在解决了 D 的优化,那么 G 呢? stop_gradient 可以很方便地放我们固定里边部分的梯度(比如 D(G(z)) 的 G(z)),但 G 的优化是要我们去固定外边的 D,没有函数实现它。但不要灰心,我们可以用一个数学技巧进行转化。

首先,我们要清楚,我们想要 D(G(z)) 里边的 G 的梯度,不想要 D 的梯度,如果直接对 D(G(z)) 求梯度,那么同时会得到 D,G 的梯度。如果直接求 巧断梯度:单个loss实现GAN模型(附开源代码) 的梯度呢?只能得到 D 的梯度,因为 G 已经被停止了。那么,重点来了,将这两个相减,不就得到单纯的 G 的梯度了吗!

巧断梯度:单个loss实现GAN模型(附开源代码)

现在优化式 (4) ,那么 D 是不会变的,改变的是 G。

值得一提的是,直接输出这个式子,结果是恒等于 0,因为两部分都是一样的,直接相减自然是 0,但它的梯度不是 0。也就是说,这是一个恒等于 0 的 loss,但是梯度却不恒等于 0。

合成单一loss 

好了,现在式 (2) 和式 (4) 都同时放开了 D,G,大家都是 arg min,所以可以将两步合成一个 loss:

巧断梯度:单个loss实现GAN模型(附开源代码)

写出这个 loss,就可以同时完成判别器和生成器的优化了,而不需要交替训练,但是效果基本上等效于 1:1 的交替训练。引入 λ 的作用,相当于让判别器和生成器的学习率之比为 1:λ。

参考代码:

https://github.com/bojone/gan/blob/master/gan_one_step_with_hinge_loss.py

文章小结

文章主要介绍了实现 GAN 的一个小技巧,允许我们只写单个模型、用单个 loss 就实现 GAN 的训练。它本质上就是用 stop_gradient 来手动控制梯度的技巧,在其他任务上也可能用得到它。

所以,以后我写 GAN 都用这种写法了,省力省时。当然,理论上这种写法需要多耗些显存,这也算是牺牲空间换时间吧。


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

写给Web开发人员看的HTML5教程

写给Web开发人员看的HTML5教程

2012-3 / 45.00元

《写给Web开发人员看的HTML5教程》通过结合大量实际案例和源代码对HTML5的重要特性进行了详细讲解,内容全面丰富,易于理解。全书共分为12章,从HTML5的历史故事讲起,涉及了文档结构和语义、智能表单、视频与音频、画布、SVG与MathML、地理定位、Web存储与离线Web应用程序、WebSockets套接字、WebWorker多线程、微数据以及以拖曳为代表的一些全局属性,涵盖了HTML5所......一起来看看 《写给Web开发人员看的HTML5教程》 这本书的介绍吧!

在线进制转换器
在线进制转换器

各进制数互转换器

SHA 加密
SHA 加密

SHA 加密工具

RGB HSV 转换
RGB HSV 转换

RGB HSV 互转工具