浅谈深度学习训练中数据规范化(Normalization)的重要性

栏目: Python · 发布时间: 5年前

内容简介:在pytorch附带的模型中我们可以选择预训练模型即模型中的权重参数都被训练好了,在构造模型后读取模型权重即可。

前言

数据规范-Normalization 是深度学习中我们很容易忽视,也很容易出错的问题。我们训练的所有数据在输入到模型中的时候都要进行一些规范化。例如在pytorch中,有些模型是通过规范化后的数据进行训练的,所以我们在使用这些预训练好的模型的时候,要注意在将自己的数据投入模型中之前要首先对数据进行规范化。

在pytorch附带的模型中我们可以选择 预训练模型

  1. import torchvision.models as models
    resnet18 = models.resnet18(pretrained=True)
    alexnet = models.alexnet(pretrained=True)
    squeezenet = models.squeezenet1_0(pretrained=True)
    vgg16 = models.vgg16(pretrained=True)
    densenet = models.densenet161(pretrained=True)
    inception = models.inception_v3(pretrained=True)复制代码

预训练模型即模型中的权重参数都被训练好了,在构造模型后读取模型权重即可。

但是有些东西需要注意:

  • 模型的权重参数是训练好的,但是要确定你输入的数据和预训练时使用的数据格式一致。
  • 要注意什么时候需要格式化什么时候不需要。

也就是说,模型设计的正确只是第一步,我们输入的图像数据的格式的正确性也是特别重要的,我们平常输入的图像大部分都是三通道RGB彩色图像,数据范围大部分都是[0-255],也就是通常意义上的24-bit图(RGB三通道各8位)。在pytorch中有专门的一些模块: transforms 模块来对图像进行一些预处理操作:

transform = transforms.Compose([
        transforms.RandomResizedCrop(100),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])复制代码

正文

而且在pytorch的官方介绍中也提到了,pytorch使用的预训练模型搭配的数据必须是:

浅谈深度学习训练中数据规范化(Normalization)的重要性

也就是3通道RGB图像(3 x H x W),而且高和宽最好不低于224(因为拿来做预训练的模型大小就是224 x 224),并且图像数据大小的范围为[0-1],使用mean和std去Normalize。

为什么这样说官方也说了,因为所有的预训练模型是使用经过normal后的数据得到的,所以我们输入的数据也必须经过格式化,否则很容易出现损失爆炸。

为什么我们要进行格式化呢?

我们选取一组人脸图片来举个例子,这组人脸图像的格式是这样的:

  • 取100组人脸图像
  • 图像的高和宽都是100
  • 3通道,像素点的范围是[0-255]

这里从 Labeled faces in the Wild 数据集中取出100个人脸图像,这个数据集中每张图像对应着一个名字,而且每张图像的脸都差不多被定位到了中间。

浅谈深度学习训练中数据规范化(Normalization)的重要性

我们有了这一组数据后,接下来要做的一般是这几个步骤:

统一形状和大小

在图像输入到神经网络之前要注意,每张图都要保证一样的尺寸和大小。大部分的模型要求输入的图像的形状是正方形,一般都是256 x 256、128 x 128 、 64 x 64或者其他的形状,这种方形是最好进行训练的。当然其他形状也是可以的,比如长方形,但如果是长方形的话就要注意设计卷积层通道的时候要稍微注意一下。总之,我们都是先对图像极性crop,crop成正方形,一般取图像的中心位置。

浅谈深度学习训练中数据规范化(Normalization)的重要性

比如下面这张人脸图(256 x 256)就很舒服,呃,因为不用修剪了。

浅谈深度学习训练中数据规范化(Normalization)的重要性

图像比例

比例也是比较重要的,图像形状确定了,但是有些时候我们在训练时随着卷积层越来越深,特征图越来越小,为了实现一些功能,我们所需要的图像的比例也要稍微改变一下。不论是放大还是缩小,假如缩小到100像素,我们就让上面的图像乘以0.39(100/256)。但是放大和缩小时都要考虑四舍五入,是floor还是ceil就各有见地了。

均值,方差

一组图像集的均值和方差可以很好地概括这组图像的信息和特征。均值就是一组数据的平均水平,而方差代表的是数据的离散程度。下图是之前展示的100张人脸图的均值图和方差图,可以看到左面的均值图中,明显看到一个模糊的人脸。并且可以看出100张人脸图中,人的脸是分布在中心的,而右边的方差图可以看到中心颜色偏暗(小于100),四周偏亮(大于100),也就是说明100张图中,图像四周的分布明显变化比较剧烈。

在这样Normalize之后,神经网络在训练的过程中,梯度对每一张图片的作用都是平均的,也就是不存在比例不匹配的情况。而在normalize之前每张图片的特征分布都是不一样的,有的陡峭有的平缓,如果不进行预处理,那么在梯度下降的时候有些图片的特征值比较大而有些则比较小,这样梯度运算无法顾及到不同特征不同维度不同层次的下降趋势,这样很难进行训练,loss会不停的震荡。

浅谈深度学习训练中数据规范化(Normalization)的重要性

格式化(Normalization)

说到重点了,我们在文章最开始说的格式化,其实即使在一组图中,每个图像的像素点首先减去所有图像均值的像素点,然后再除以方差。这样可以保证所有的图像分布都相似,也就是在训练的时候更容易收敛,也就是训练的更快更好了。另外,不同图像像素点范围的mean和std是不一样的,一般我们输入的都是[0-1]或者[0-255]的图像数据,在pytorch的模型中,输入的是[0-1],而在caffe的模型中,我们输入的是[0-255]。

下面这个图就是在格式化后的100张人脸图。

浅谈深度学习训练中数据规范化(Normalization)的重要性

显然,格式化就是使数据中心对齐,如cs231n中的示例图,左边是原始数据,中间是减去mean的数据分布,右边是除以std方差的数据分布,当然cs231n中说除以std其实可以不去执行,因为只要数据都遵循一定范围的时候(比如图像都是[0-255])就没有必要这样做了。

浅谈深度学习训练中数据规范化(Normalization)的重要性

维数变化

有时候需要输入不是彩色图,这时候可能需要对数据进行降维操作,也就是RGB->GRAY,当然还有颜色通道和色彩通道的改变,例如RGB->BGR,或者RGB->YUV。颜色通道的改变是为了实现不同的任务和功能,这就要视情况来决定。

浅谈深度学习训练中数据规范化(Normalization)的重要性

其他变化:数据增强

在pytorch的transforms模块中有很多的变化,都可以用来做数据增强,比如图像翻转,旋转,极坐标变换,都可以得到不同的“原始图”从而加大训练变量达到很好的训练效果。这里不多说,这个需要单独说明。


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

MySQL性能调优与架构设计

MySQL性能调优与架构设计

简朝阳 / 2009-6 / 59.80元

《MySQL性能调优与架构设计》以 MySQL 数据库的基础及维护为切入点,重点介绍了 MySQL 数据库应用系统的性能调优,以及高可用可扩展的架构设计。 全书共分3篇,基础篇介绍了MySQL软件的基础知识、架构组成、存储引擎、安全管理及基本的备份恢复知识。性能优化篇从影响 MySQL 数据库应用系统性能的因素开始,针对性地对各个影响因素进行调优分析。如 MySQL Schema 设计的技巧......一起来看看 《MySQL性能调优与架构设计》 这本书的介绍吧!

HTML 压缩/解压工具
HTML 压缩/解压工具

在线压缩/解压 HTML 代码

在线进制转换器
在线进制转换器

各进制数互转换器

XML 在线格式化
XML 在线格式化

在线 XML 格式化压缩工具