Tensorflow上手2: Keras的技巧和弊端

栏目: 数据库 · 发布时间: 5年前

内容简介:就在不久前,TF 2.0的预告发布,大家都在讨论着Tensorflow接口的变化,于是我也开始尝试使用Tensorflow版本的Keras.Keras是一个非常易用的框架,提供了更好的神经网络层级Layer的抽象,但是真正实现大规模模型训练时却遇到了一些坑.本文从我个人使用经历中看看Keras的灵巧和缺点,其中有一些坑可能大家会有更好的解决方案,欢迎多多指教.层(Layer)是Keras当中最重要的概念之一,但是Keras本身提供的Layer实现又不如Tensorflow的计算图那么的多样化与方便,所以Ke
Tensorflow上手2: Keras的技巧和弊端

就在不久前,TF 2.0的预告发布,大家都在讨论着Tensorflow接口的变化,于是我也开始尝试使用Tensorflow版本的Keras.Keras是一个非常易用的框架,提供了更好的神经网络层级Layer的抽象,但是真正实现大规模模型训练时却遇到了一些坑.

本文从我个人使用经历中看看Keras的灵巧和缺点,其中有一些坑可能大家会有更好的解决方案,欢迎多多指教.

Lambda层与自定义层

层(Layer)是Keras当中最重要的概念之一,但是Keras本身提供的Layer实现又不如Tensorflow的计算图那么的多样化与方便,所以Keras就通过Lambda层和自定义层对其灵活性进行拓展.

Lambda层的使用最为简单:

from tensorflow.keras.layers import Lambda
model.add(Lambda(lambda x: x**2))

其次是自定义层,相比Lambda层的好处是,自定义可以给Layer增加新的可以训练的参数,这些参数需要在build函数中进行定义,比如说一个自定义的dot product层(代码来自官网):

class MyLayer(Layer):
 def __init__(self, output_dim, **kwargs):
 super(MyLayer, self).__init__(**kwargs)
 self.output_dim = output_dim
def build(self, input_shape):
 self.kernel = self.add_weight(
 name=’kernel’,
 shape=(input_shape[1], self.output_dim)
 initializer=’uniform’,
 trainable=True)
def call(self, x):
 return K.dot(x, self.kernel)
def compute_output_shape(self, input_shape):
 return (input_shape[0], self.output_dim)

通过Lambda层和自定义层的灵活运用,人们可以用Keras写出一个很好的Mask RCNN代码,并且通过Keras提供的可视化函数plot_model,讲网络结构打印出来.

Lambda层和自定义层虽然很灵活,但是真正使用过程中还是会遇到不少坑.

TypeError: can’t pickle _thread.lock objects

有时候在你的网络中有Lambda层,保存到时候会遇到以上错误.这通常是因为Lambda层在进行序列化的时候无法序列化你使用的某一个 Python 非静态函数.

这时候有三种解决办法,第一种是在保存模型的时候选择model.save_weights而非model.save,毕竟大部分时候你不需要原先的训练结构.第二种办法是采用functools.partial等函数将Python函数包装成静态函数.第三种办法时放弃Lambda层,讲其函数和参数包装成一个自定义层.我个人推荐这一方案,原因会在下一个问题揭晓.

推荐使用自定义层

从我个人的使用来看,使用自定义层能够更好的对模型结构信息进行存储,包括每一个自定义层采用的参数等等.在新版的Tensorflow中,可以通过Keras导出Estimator使用的模型,这时候我们需要每一个自定义层使用的参数,这些参数可以通过自定义层的如下函数导出:

class MyLayer(Layer):
 def __init__(self, output_dim, **kwargs):
 super(MyLayer, self).__init__(**kwargs)
 self.output_dim = output_dim
def get_conf(self):
 config = super(MyLayer, self).get_config()
 config.update({‘output_dim’: self.output_dim})
 return config

双输出时一定要用list

Keras的灵活性还在于一个层可以有多个输出,就比方Mask RCNN里面的fpn_classifier_graph,在我们使用自定义层产生多输出的时候,既可以

return out1, out2

也可以

return [out1, out2]

这时候请记住,一定要选择返回list,不然你试试多GPU训练的情况就知道了.

关于Keras模型中的Loss

Keras中对Loss的基本定义是一个输入为y_true和y_pred的函数,但是在特殊情况下,他也可以结合权重进行复杂的运算.

就我个人写代码,阅读代码,阅读博客的经验来看,Kera的自定义loss有很多种写法,非常灵活,但与此同时也会遇到不同的问题.

model.add_loss

抛开最基本的model.compile(loss=…),我第一个尝试的是模仿Mask RCNN,通过model.add_loss函数另模型同时优化多个损失函数.使用add_loss的主要原因是传入的函数不能通过简单的y_true, y_pred进行计算,比方Mask RCNN要同时计算边界的损失和分类的损失,写成两个张量的表达形式很复杂并且不容易扩展.

然而add_loss也有自己的弊端,比方说如果我调用函数model_to_estimator,那么我加入的loss就没有了.至于我为什么要调用这个函数,下文会给予介绍.

layer.loss

通过阅读让Keras更酷一些,我了解到我们还可以通过自定义层添加loss,如下所述:

class MyLayer(Layer):
 def loss(self, y_true, y_pred):
 # do something

在原文中,作者给出的函数可以访问该层的一些参数,这样以来可以更灵活的给自定义层的参数进行一定的约束,但是由于一方面我本人没有试验过,不清楚这样做的利弊,另一方面看起来函数的借口还是不够灵活,而且我也不确定如何将多个不同层的损失叠加到一起,所以不能做更多的评价.

回到model.compile(loss=…)

因为一些特殊的原因,我最后还是尝试了将所有的损失函数通过一个自定义层合并,然后通过compile添加到模型里的做法.写成程序很简单:

model.compile(
 optimizer=’adam’,
 loss=lambda y_true, y_pred: pred)

结语,关于回调和分布式训练

Keras的回调是一个很厉害的功能,可以通过回调生成Tensorboard的Summary,调整训练速率,提前结束训练等等,但是使用不好的话很可能在训练过程中造成内存溢出.

最后,我暂时的放弃了Keras,其实并不是因为别的什么原因,而是在我研究如何利用Keras进行分布式训练的时候,我看到了这么一个推荐方案:

End to end example for multi worker training in tensorflow/ecosystem using Kuberentes templates. This example starts with a Keras model and converts it to an Estimator using the tf.keras.estimator.model_to_estimator API.

那么,还是让我们直接使用Estimator比较方便吧.


以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Mobilizing Web Sites

Mobilizing Web Sites

Layon, Kristofer / 2011-12 / 266.00元

Everyone has been talking about the mobile web in recent years, and more of us are browsing the web on smartphones and similar devices than ever before. But most of what we are viewing has not yet bee......一起来看看 《Mobilizing Web Sites》 这本书的介绍吧!

JSON 在线解析
JSON 在线解析

在线 JSON 格式化工具

URL 编码/解码
URL 编码/解码

URL 编码/解码

MD5 加密
MD5 加密

MD5 加密工具