Few-shot Learning with Prototypical Networks



Learn to code a Few-shot Learning algorithm on the Omniglot dataset


Image credit: https://unsplash.com/photos/kZO9xqmO_TA

Introduction

We humans have the ability to recognize a class given only a few examples of it. For instance, a child only needs two or three images of a rabbit to recognize this animal among other species. This capacity to learn from few examples surpasses any classical machine learning algorithm. Many people think humankind is being overtaken by AI, but here is the truth: to differentiate classes well, a classifier is often fed several thousand images per class… while we only need two or three!

Prototypical Networks is an algorithm introduced by Snell et al. in 2017 (in “Prototypical Networks for Few-shot Learning”) that addresses the Few-shot Learning paradigm. Let’s understand it step by step with an example. In this article, our goal is to classify images. The code provided is in PyTorch, available here.

The Omniglot dataset

In Few-shot Learning, we are given a dataset with few images per class (usually 1 to 10). In this article, we will work on the Omniglot dataset, which contains 1,623 different handwritten characters collected from 50 alphabets. This dataset can be found in this GitHub repository. I used the “images_background.zip” and “images_evaluation.zip” files.


Examples of characters found in the Omniglot dataset

As suggested in the original paper, data augmentation is performed to increase the number of classes: all the images are rotated by 90°, 180° and 270°, each rotation producing an additional class. After this augmentation, we have 1,623 × 4 = 6,492 classes. I split the whole dataset into a training set (images of 4,200 classes) and a testing set (images of 2,292 classes).
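The rotation augmentation can be sketched with NumPy's `np.rot90`, which rotates an array by quarter turns. The dictionary layout (class name mapped to a list of 2-D image arrays) and the helper name `augment_with_rotations` are assumptions for illustration, not the article's actual code.

```python
import numpy as np

def augment_with_rotations(classes):
    """Turn every class into four classes: the original plus its 90, 180 and 270 degree rotations.

    Hypothetical helper: `classes` maps a class name to a list of 2-D (H x W) image arrays.
    New classes are keyed by (original name, rotation angle in degrees).
    """
    augmented = {}
    for name, images in classes.items():
        for k in range(4):  # k counter-clockwise quarter turns: 0, 90, 180, 270 degrees
            augmented[(name, 90 * k)] = [np.rot90(img, k) for img in images]
    return augmented
```

Applied to the 1,623 Omniglot characters, this yields 1,623 × 4 = 6,492 classes, one per (character, angle) pair.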

Select a sample

To create a sample, Nc classes are randomly picked among all classes. For each class we have two sets of images: the support set of size Ns and the query set of size Nq.
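Drawing one sample (an episode) can be sketched as follows. The helper `sample_episode` and the dataset layout (class label mapped to a list of images) are hypothetical; the article only states that classes are picked randomly.

```python
import random

def sample_episode(dataset, n_way, n_support, n_query, rng=random):
    """Pick Nc = n_way classes, then split Ns + Nq images of each class into support and query sets.

    Hypothetical helper: `dataset` maps a class label to a list of images, and every
    chosen class must contain at least n_support + n_query images.
    """
    classes = rng.sample(list(dataset), n_way)
    support, query = [], []
    for label in classes:
        # Sampling without replacement: an image is never both a support and a query point
        images = rng.sample(dataset[label], n_support + n_query)
        support.append(images[:n_support])
        query.append(images[n_support:])
    return classes, support, query
```

Passing a seeded `random.Random(seed)` as `rng` makes episodes reproducible.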


Illustration of a sample of Nc classes, each containing a support set and a query set

Embed the images

“Our approach is based on the idea that there exists an embedding in which points cluster around a single prototype representation for each class.” claim the authors of the original paper.

In other words, there exists a mathematical representation of the images in which images of the same class gather in groups called clusters. The main advantage of working in that embedding space is that two images that look alike will be close to each other, and two images that are completely different will be far apart.

In our case, with the Omniglot dataset, the embedding block takes 28x28x3 images as inputs and returns 64-dimensional column vectors. The image2vector function is composed of 4 modules. Each module consists of a convolutional layer, a batch normalization layer, a ReLU activation function and a 2x2 max pooling layer.
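A PyTorch sketch of such a four-module network is shown below. The 3x3 kernels with 64 filters and padding 1 follow the original paper; `Image2Vector`, `conv_block` and the default `in_channels=3` (matching the 28x28x3 input described above) are assumptions, not the article's code. Since Omniglot characters are black and white, `in_channels=1` works as well. PyTorch expects channels first, so each image is passed as (3, 28, 28).

```python
import torch
import torch.nn as nn

def conv_block(in_channels, out_channels):
    # One module: 3x3 convolution -> batch normalization -> ReLU -> 2x2 max pooling
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )

class Image2Vector(nn.Module):
    """Hypothetical embedding network: four modules mapping a 28x28 image to a 64-d vector."""

    def __init__(self, in_channels=3, hidden=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(in_channels, hidden),  # 28x28 -> 14x14
            conv_block(hidden, hidden),       # 14x14 -> 7x7
            conv_block(hidden, hidden),       # 7x7 -> 3x3 (pooling floors odd sizes)
            conv_block(hidden, hidden),       # 3x3 -> 1x1
        )

    def forward(self, x):
        # x: (batch, channels, 28, 28) -> (batch, 64 * 1 * 1) = (batch, 64)
        return self.encoder(x).flatten(start_dim=1)
```

The four poolings shrink the 28x28 grid to 1x1, which is why each image ends up as exactly 64 numbers, one per filter of the last module.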


The 4 modules of the image2vector function

Compute the class prototypes

In this step, we compute a prototype for each cluster. Once the support images are embedded, the vectors of each class are averaged to form a class prototype, a kind of “delegate” for that class.

$$v^{(k)} = \frac{1}{N_s} \sum_{i=1}^{N_s} f_\phi(x_i)$$

where v(k) is the prototype of class k, f_phi is the embedding function and x_i are the Ns support images of class k.


One prototype is computed per class
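In array form, the averaging above is a single mean over the support axis. This NumPy sketch assumes the embedded support points are stored as an array of shape (Nc, Ns, D); the function name is hypothetical.

```python
import numpy as np

def compute_prototypes(support_embeddings):
    """Average the embedded support points of each class.

    support_embeddings: array of shape (Nc, Ns, D) holding f_phi(x_i) for the
    Ns support images of each of the Nc classes.
    Returns an array of shape (Nc, D): one prototype v(k) per class.
    """
    return support_embeddings.mean(axis=1)
```

With Ns = 1 (one-shot), the prototype is simply the embedding of the single support image.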

Compute distances between queries and prototypes

This step consists of classifying the query images. To do so, we compute the distance between each query image and each prototype. The choice of metric is crucial here, and the inventors of Prototypical Networks deserve credit for their choice of distance: the squared Euclidean distance.

Once distances are computed, a softmax is applied to their negatives to get the probability of belonging to each class: the closer a query is to a prototype, the higher the probability of that class.
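Both steps fit in a few lines of NumPy. This is a sketch under the assumption that queries and prototypes are already embedded as arrays of shape (Q, D) and (Nc, D); it uses the squared Euclidean distance, as in the original paper, and the function names are hypothetical.

```python
import numpy as np

def squared_euclidean(queries, prototypes):
    """Pairwise squared Euclidean distances, shape (num_queries, Nc)."""
    diff = queries[:, None, :] - prototypes[None, :, :]  # broadcast to (Q, Nc, D)
    return (diff ** 2).sum(axis=-1)

def class_probabilities(queries, prototypes):
    """Softmax over negative distances: the nearest prototype gets the highest probability."""
    logits = -squared_euclidean(queries, prototypes)
    logits -= logits.max(axis=1, keepdims=True)  # shift for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)
```

Negating the distances matters: a softmax over the raw distances would favour the farthest prototype instead of the nearest one.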

Compute the loss and backpropagate

The learning phase of Prototypical Networks minimizes the negative log-probability of the true class, also called the log-softmax loss. The main advantage of using a logarithm is that the loss increases sharply when the model assigns a low probability to the right class.
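The loss can be written directly from the distances. This NumPy sketch assumes a (Q, Nc) distance matrix and integer class labels for the queries; `prototypical_loss` is a hypothetical name, and the log-sum-exp trick is only there to keep the exponentials from overflowing.

```python
import numpy as np

def prototypical_loss(distances, labels):
    """Mean negative log-probability of the true class.

    distances: (num_queries, Nc) squared distances to the prototypes.
    labels: (num_queries,) integer index of each query's true class.
    """
    logits = -distances
    # log-softmax via the log-sum-exp trick
    m = logits.max(axis=1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()
```

For a query at distance 0 from its own prototype and 10 from the other one, the loss is close to 0; if the labels are swapped (the model is confidently wrong), the loss jumps to about 10.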

The gradients are computed by backpropagation, and the weights are updated with Stochastic Gradient Descent.
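Putting the steps together, one training step could look like the PyTorch sketch below. It assumes any `model` that maps a batch of images to a batch of vectors (for example a four-module network as described above) and tensors laid out as (Nc, Ns, C, H, W) and (Nc, Nq, C, H, W). `train_episode` is a hypothetical name, not the article's function; `F.cross_entropy` is used because it combines the log-softmax and the negative log-likelihood.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def train_episode(model, optimizer, support, query):
    """Run one training episode and return its loss. Queries in row k belong to class k."""
    n_way, n_support = support.shape[:2]
    n_query = query.shape[1]
    # Embed support and query images in a single forward pass
    images = torch.cat([support.flatten(0, 1), query.flatten(0, 1)])
    embeddings = model(images)
    n_s = n_way * n_support
    # Prototype of class k = mean of its Ns support embeddings
    prototypes = embeddings[:n_s].reshape(n_way, n_support, -1).mean(dim=1)
    queries = embeddings[n_s:]
    # Squared Euclidean distance between every query and every prototype: (Nc*Nq, Nc)
    distances = ((queries[:, None, :] - prototypes[None, :, :]) ** 2).sum(dim=-1)
    labels = torch.arange(n_way).repeat_interleave(n_query)
    # log-softmax over -distances followed by the negative log-likelihood
    loss = F.cross_entropy(-distances, labels)
    optimizer.zero_grad()
    loss.backward()   # backpropagation
    optimizer.step()  # SGD update of the embedding weights
    return loss.item()

if __name__ == "__main__":
    # Toy run: a linear "embedding" on random 1x2x2 images, 3-way 2-shot with 4 queries per class
    torch.manual_seed(0)
    model = nn.Sequential(nn.Flatten(), nn.Linear(4, 8))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    print(train_episode(model, optimizer, torch.randn(3, 2, 1, 2, 2), torch.randn(3, 4, 1, 2, 2)))
```

The distances are computed explicitly rather than with `torch.cdist` so that the gradient stays well defined when a query coincides with a prototype.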

Launch training

The whole sequence described above forms an episode, and the training phase consists of many episodes. I tried to reproduce the results of the original paper. Here are the training settings:

  • Nc: 60 classes
  • Ns: 1 or 5 support points / class
  • Nq: 5 query points / class
  • 5 epochs
  • 2000 episodes / epoch
  • Learning rate: initially 0.001, divided by 2 after each epoch
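The epoch/episode structure and the halving learning rate map naturally onto PyTorch's `StepLR` scheduler. This is a sketch only: the article does not say how its schedule is implemented, and the `nn.Linear` model is a placeholder standing in for the embedding network.

```python
import torch
import torch.nn as nn

model = nn.Linear(64, 64)  # placeholder for the embedding network
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.5 every epoch (step_size counts scheduler.step() calls)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

N_EPOCHS, EPISODES_PER_EPOCH = 5, 2000
lrs = []  # learning rate used during each epoch
for epoch in range(N_EPOCHS):
    lrs.append(optimizer.param_groups[0]["lr"])
    for episode in range(EPISODES_PER_EPOCH):
        # Hypothetical: sample an episode, compute its loss and call loss.backward() here.
        # Without gradients, this optimizer.step() leaves the weights unchanged.
        optimizer.step()
    scheduler.step()  # once per epoch, after the optimizer steps
```

Over the 5 epochs the learning rate goes 0.001, 0.0005, 0.00025, 0.000125, 0.0000625.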

The training took 30 min to run.

Results

Once the ProtoNet is trained, we can test it on new data. We sample episodes from the testing set in the same way. The support set is used to compute the prototypes, and each point of the query set is then assigned the label of the nearest prototype.

For testing, I tried 5-way and 20-way scenarios, with the same number of support and query points as during the training phase. The tests were performed on 1,000 episodes.
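At test time no weights are updated: each query simply takes the label of its nearest prototype, and accuracy is averaged over the test episodes. The NumPy sketch below assumes already-embedded queries and prototypes; `predict`, `episode_accuracy` and `mean_accuracy` are hypothetical names, and averaging per-episode accuracies is an assumption about how the reported figures were computed.

```python
import numpy as np

def predict(query_embeddings, prototypes):
    """Label each query with the index of its nearest prototype (squared Euclidean distance)."""
    d = ((query_embeddings[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    return d.argmin(axis=1)

def episode_accuracy(query_embeddings, prototypes, labels):
    """Fraction of queries in one episode whose nearest prototype is their true class."""
    return float((predict(query_embeddings, prototypes) == labels).mean())

def mean_accuracy(episodes):
    """Average accuracy over an iterable of (query_embeddings, prototypes, labels) episodes."""
    return float(np.mean([episode_accuracy(q, p, y) for q, p, y in episodes]))
```

For example, with prototypes at (0, 0), (10, 0) and (0, 10), a query at (6, 0) is labelled as class 1, since its squared distance to (10, 0) is 16 against 36 to the origin.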

The results are presented in the table below. “5-way 1-shot” means Nc = 5 and Ns = 1.

Obtained vs. paper results

I obtained results similar to those of the original paper, slightly better in some cases. This may be due to the sampling strategy, which is not specified in the paper. I used random sampling at each episode.

