Implementing Batch Normalization in Python

Category: IT Technology · Published: 5 years ago


Why and How You Implement Batch Normalization in Neural Network

Jan 30 · 4 min read

I'm currently taking Convolutional Neural Networks for Visual Recognition, a course offered online by Stanford University, and have just started working on its second assignment. One part of this assignment asks us to implement batch normalization in a fully connected neural network.

The forward pass was relatively simple to implement, but backpropagation, which is more challenging, took me quite some time to complete. After a few hours of work and struggle, I finally got through it. Here I would love to share some of my notes and thoughts on batch normalization.

So … what is batch normalization?

Batch normalization deals with the problem of poor initialization of neural networks. It can be interpreted as doing preprocessing at every layer of the network. It forces the activations in a network to take on a unit Gaussian distribution at the beginning of training. This ensures that all neurons have about the same output distribution in the network and improves the rate of convergence.

To see why the distribution of the activations in a network matters, you can refer to pp. 46-62 in the lecture slides offered by the course.

Let's say we have a batch of activations x at a layer. The zero-mean, unit-variance version of x is

$$\hat{x} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}$$

where E[x] and Var[x] are the mean and variance of x computed over the batch. This is a differentiable operation, which is why we can apply batch normalization during training.

In the implementation, we insert the batch normalization layer right after a fully connected layer or a convolutional layer, and before nonlinear layers.

Forward pass of batch normalization

[Figure: algorithm of the batch normalizing transform, from the original paper]

Let's look at the gist from the original research paper.

As I said earlier, the whole concept of batch normalization is pretty easy to understand. After computing the mean and the variance of a batch of activations x, we can normalize x by the operation in the third line of the gist. Also note that we introduce learnable scale and shift parameters γ and β in case the zero-mean, unit-variance constraint is too restrictive for our network.
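For reference, the four steps of the batch normalizing transform can be written out as below. This is a restatement of the algorithm in the original paper, where m is the mini-batch size and ε is a small constant added to the variance for numerical stability; the third line is the normalization step mentioned above.

```latex
\begin{aligned}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i \\
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{aligned}
```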

So the code for the forward pass looks like:
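A minimal sketch of the training-time forward pass, assuming NumPy and an input `x` of shape (N, D) with one row per example. The function name, the `eps` default and the toy data are illustrative choices, not taken from the assignment.

```python
import numpy as np

def batchnorm_forward_train(x, gamma, beta, eps=1e-5):
    """Normalize a mini-batch x of shape (N, D), feature by feature."""
    mu = x.mean(axis=0)                    # per-feature mean, shape (D,)
    var = x.var(axis=0)                    # per-feature (biased) variance, shape (D,)
    x_hat = (x - mu) / np.sqrt(var + eps)  # zero-mean, unit-variance activations
    out = gamma * x_hat + beta             # learnable scale and shift
    return out, x_hat

# Toy usage: badly scaled activations come out roughly standardized.
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=5.0, size=(64, 4))
out, _ = batchnorm_forward_train(x, gamma=np.ones(4), beta=np.zeros(4))
print(out.mean(axis=0))  # close to 0 for every feature
print(out.std(axis=0))   # close to 1 for every feature
```

With γ = 1 and β = 0 the layer is a pure normalization; during training the network is free to learn other values.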

One thing to pay attention to is that the estimates of the mean and variance depend on the mini-batches we send into the network, and we can't do this at test time. So the mean μ and variance σ² used for normalization at test time are running averages of the values we computed during training.

This is also why batch normalization has a regularizing effect. We add some randomness during training and average it out at test time to reduce generalization error (much like dropout).

So this is my complete implementation of the forward pass from assignment 2 of the course:
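A sketch of a forward pass that handles both training and test mode, assuming NumPy and inputs of shape (N, D). The function name, the layout of the `bn_param` dictionary and the `momentum` default of 0.9 are assumptions made for illustration; this is one plausible version, not necessarily the author's code.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, bn_param):
    """Batch norm forward pass for x of shape (N, D).

    bn_param is a plain dict (an assumed layout) holding 'mode' and,
    optionally, 'eps', 'momentum', 'running_mean' and 'running_var'.
    """
    mode = bn_param["mode"]
    eps = bn_param.get("eps", 1e-5)
    momentum = bn_param.get("momentum", 0.9)
    D = x.shape[1]
    running_mean = bn_param.get("running_mean", np.zeros(D, dtype=x.dtype))
    running_var = bn_param.get("running_var", np.zeros(D, dtype=x.dtype))

    cache = None
    if mode == "train":
        mu = x.mean(axis=0)
        var = x.var(axis=0)
        x_hat = (x - mu) / np.sqrt(var + eps)
        out = gamma * x_hat + beta
        # Exponential moving average of the batch statistics.
        running_mean = momentum * running_mean + (1 - momentum) * mu
        running_var = momentum * running_var + (1 - momentum) * var
        cache = (x_hat, gamma, var, eps)
    elif mode == "test":
        # No batch statistics here: reuse the averages gathered in training.
        x_hat = (x - running_mean) / np.sqrt(running_var + eps)
        out = gamma * x_hat + beta
    else:
        raise ValueError(f"Invalid batch norm mode: {mode!r}")

    bn_param["running_mean"] = running_mean
    bn_param["running_var"] = running_var
    return out, cache

# Toy usage: accumulate statistics in train mode, then switch to test mode.
rng = np.random.default_rng(0)
gamma, beta = np.ones(3), np.zeros(3)
bn_param = {"mode": "train"}
for _ in range(200):
    batch = rng.normal(loc=2.0, scale=3.0, size=(128, 3))
    batchnorm_forward(batch, gamma, beta, bn_param)

bn_param["mode"] = "test"
x_test = rng.normal(loc=2.0, scale=3.0, size=(1000, 3))
test_out, _ = batchnorm_forward(x_test, gamma, beta, bn_param)
print(bn_param["running_mean"])  # roughly the true mean of 2
print(test_out.std(axis=0))      # roughly 1
```

Storing the running statistics back into `bn_param` keeps the function itself stateless, so the same dictionary can be passed in again on the next call.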

Backpropagation

Now we want to derive a way to compute the gradients of batch normalization. What makes it challenging is the fact that μ itself is a function of x, and σ² is a function of both μ and x. Thus we need to be extremely careful and clear when applying the chain rule to this normalization function.

One of the things I found very helpful when taking the course is the concept of a computational graph. It breaks down a complex function into several small and simple operations and helps you perform backpropagation in a neat, organized way (by deriving the local gradient of each simple operation and multiplying them together to get the result).

Kratzert's post explains every step of computing the gradients of batch normalization using a computational graph in detail. Check it out to understand more.

In Python we can write code like this:
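A sketch of the staged approach, assuming NumPy and inputs of shape (N, D). The forward pass is split into eight small nodes so that the backward pass can undo them one at a time; the function names and the step numbering are illustrative, not taken from the assignment or from the post mentioned above.

```python
import numpy as np

def forward_staged(x, gamma, beta, eps=1e-5):
    """Forward pass that keeps every intermediate node of the graph."""
    N = x.shape[0]
    mu = x.sum(axis=0) / N       # step 1: mean
    xmu = x - mu                 # step 2: center
    sq = xmu ** 2                # step 3: squared deviations
    var = sq.sum(axis=0) / N     # step 4: variance
    std = np.sqrt(var + eps)     # step 5: standard deviation
    istd = 1.0 / std             # step 6: invert
    xhat = xmu * istd            # step 7: normalize
    out = gamma * xhat + beta    # step 8: scale and shift
    cache = (xhat, gamma, xmu, istd, std)
    return out, cache

def backward_staged(dout, cache):
    """Walk the graph in reverse, one local gradient at a time."""
    xhat, gamma, xmu, istd, std = cache
    N = dout.shape[0]
    dbeta = dout.sum(axis=0)                 # step 8
    dgamma = (dout * xhat).sum(axis=0)
    dxhat = dout * gamma
    distd = (dxhat * xmu).sum(axis=0)        # step 7
    dxmu1 = dxhat * istd
    dstd = -distd / std ** 2                 # step 6
    dvar = 0.5 * dstd / std                  # step 5
    dsq = np.ones_like(dout) * dvar / N      # step 4
    dxmu2 = 2 * xmu * dsq                    # step 3
    dx1 = dxmu1 + dxmu2                      # step 2
    dmu = -dx1.sum(axis=0)
    dx2 = np.ones_like(dout) * dmu / N       # step 1
    dx = dx1 + dx2
    return dx, dgamma, dbeta

# Toy usage with random data and a random upstream gradient.
rng = np.random.default_rng(1)
x = rng.normal(size=(5, 3))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
dout = rng.normal(size=(5, 3))
out, cache = forward_staged(x, gamma, beta)
dx, dgamma, dbeta = backward_staged(dout, cache)
# Shifting a whole feature column leaves the output unchanged,
# so each column of dx sums to zero up to rounding error.
print(np.abs(dx.sum(axis=0)).max())
```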

One downside of this staged computation is that it takes much longer to compute the final gradients, since we compute a lot of intermediate values that might cancel out when multiplied together. To make everything faster, we need to differentiate the function ourselves to get a simpler result.

While writing this article, I found a post on Kevin's blog that walks through every step of deriving the gradients by the chain rule. It explains the details very clearly, so please refer to it if you are interested in the derivation.

And finally here’s my implementation:
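A sketch of the simplified backward pass, assuming NumPy and inputs of shape (N, D). Working the chain rule through by hand collapses the staged version into a single expression, dx = (γ / (N·sqrt(σ² + ε))) · (N·dout − Σ dout − x̂ · Σ(dout · x̂)), with the sums taken over the batch. The function names and the compact forward pass are illustrative, not necessarily the author's code.

```python
import numpy as np

def batchnorm_forward_train(x, gamma, beta, eps=1e-5):
    """Training-time forward pass that caches only what backward needs."""
    mu = x.mean(axis=0)
    istd = 1.0 / np.sqrt(x.var(axis=0) + eps)   # 1 / sqrt(var + eps)
    xhat = (x - mu) * istd
    return gamma * xhat + beta, (xhat, gamma, istd)

def batchnorm_backward_alt(dout, cache):
    """Backward pass using the hand-simplified gradient for x."""
    xhat, gamma, istd = cache
    N = dout.shape[0]
    dbeta = dout.sum(axis=0)
    dgamma = (dout * xhat).sum(axis=0)
    # The sums over the batch reuse dbeta and dgamma computed above.
    dx = (gamma * istd / N) * (N * dout - dbeta - xhat * dgamma)
    return dx, dgamma, dbeta

# Toy usage with random data and a random upstream gradient.
rng = np.random.default_rng(2)
x = rng.normal(size=(6, 4))
gamma, beta = rng.normal(size=4), rng.normal(size=4)
dout = rng.normal(size=(6, 4))
out, cache = batchnorm_forward_train(x, gamma, beta)
dx, dgamma, dbeta = batchnorm_backward_alt(dout, cache)
print(dx.shape, dgamma.shape, dbeta.shape)
```

Compared with the staged version, this needs a smaller cache and only a handful of array operations, at the cost of being harder to check by eye; a numerical gradient check is a cheap way to confirm that it is correct.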

Summary

In this article, we learned how batch normalization improves convergence and why it serves as a kind of regularization. We also implemented the forward pass and backpropagation for batch normalization in Python.

Although you probably don’t need to worry about the implementation since everything is already there in those popular deep learning frameworks, I always believe that doing things on our own allows us to have a better understanding. Hope you have gained something after reading this article!

