LSTM Text Classification Using Pytorch


A step-by-step guide teaching you how to build a bidirectional LSTM in Pytorch!


Intro

Welcome to this tutorial! This tutorial will teach you how to build a bidirectional LSTM for text classification in just a few minutes. If you haven't already, check out my previous article on BERT Text Classification; this tutorial contains similar code, with some modifications to support LSTM. This article also explains how I preprocessed the dataset used in both articles, the REAL and FAKE News Dataset from Kaggle.

First of all, what is an LSTM and why do we use it? LSTM stands for Long Short-Term Memory Network [1], which belongs to a larger category of neural networks called Recurrent Neural Networks (RNNs). Its main advantage over the vanilla RNN is that it handles long-term dependencies better, thanks to an architecture with three gates: the input gate, the output gate, and the forget gate. Together, the three gates decide what information the LSTM cell remembers and what it forgets over arbitrary time intervals.

[Figure: LSTM cell]
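To make the gate description concrete, here is a minimal sketch of one LSTM time step written out by hand in PyTorch. It borrows the weights of a built-in nn.LSTMCell (which stacks the four gate blocks in the order input, forget, cell candidate, output) so the hand-written result can be checked against the library; the function name lstm_cell_step is hypothetical.

```python
import torch
import torch.nn as nn


def lstm_cell_step(x, h, c, cell: nn.LSTMCell):
    """One LSTM step computed by hand from an nn.LSTMCell's parameters.

    x: (batch, input_size); h, c: (batch, hidden_size).
    """
    # Pre-activations of all four blocks at once; PyTorch stacks them as [i, f, g, o].
    gates = x @ cell.weight_ih.T + cell.bias_ih + h @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i = torch.sigmoid(i)  # input gate: how much new information to write
    f = torch.sigmoid(f)  # forget gate: how much of the old cell state to keep
    g = torch.tanh(g)     # candidate values that could be written to the cell
    o = torch.sigmoid(o)  # output gate: how much of the cell state to expose
    c_next = f * c + i * g
    h_next = o * torch.tanh(c_next)
    return h_next, c_next


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=8, hidden_size=4)
x = torch.randn(2, 8)
h0, c0 = torch.zeros(2, 4), torch.zeros(2, 4)

h1, c1 = lstm_cell_step(x, h0, c0, cell)
ref_h, ref_c = cell(x, (h0, c0))
print(torch.allclose(h1, ref_h, atol=1e-5), torch.allclose(c1, ref_c, atol=1e-5))  # True True
```

The forget gate multiplies the previous cell state element-wise, which is what lets information persist (or be erased) across many time steps.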

Now that we have a bit more understanding of LSTMs, let's focus on how to implement one for text classification. The tutorial is divided into the following steps:

  1. Preprocess Dataset
  2. Importing Libraries
  3. Load Dataset
  4. Build Model
  5. Training
  6. Evaluation

Before we dive right into the tutorial, here is where you can access the code in this article:

Step 1: Preprocess Dataset

The raw dataset looks like the following:

[Figure: Dataset overview]

The dataset contains an arbitrary index, title, text, and the corresponding label.

For preprocessing, we import Pandas and scikit-learn and define variables for the data paths and the train/validation/test split ratios, as well as the trim_string function, which cuts each sample down to its first first_n_words words. Trimming the samples is not strictly necessary, but it speeds up training for heavier models, and the opening words of a sample are normally enough to predict the outcome.
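A minimal sketch of that configuration and of trim_string follows. The paths and numeric values are hypothetical placeholders (the article does not list them), and the interpretation of the two ratios (train_test_ratio as the share kept for train plus validation, train_valid_ratio as the share of that used for training) is an assumption.

```python
# Hypothetical configuration: the article's actual paths and values are not shown.
raw_data_path = "data/news.csv"
destination_folder = "data"

train_test_ratio = 0.90   # assumed: fraction of samples kept for train + validation
train_valid_ratio = 0.80  # assumed: fraction of those used for training
first_n_words = 200       # assumed: maximum number of words kept per sample


def trim_string(x: str, n: int = first_n_words) -> str:
    """Keep only the first n whitespace-separated words of x."""
    return " ".join(x.split()[:n])


print(trim_string("the quick brown fox jumps", n=3))  # the quick brown
```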

Next, we convert REAL to 0 and FAKE to 1, concatenate title and text into a new column, titletext (we use both the title and the text to decide the outcome), drop rows with empty text, trim each sample to its first first_n_words words, and split the dataset according to train_test_ratio and train_valid_ratio. We save the resulting dataframes as .csv files: train.csv, valid.csv, and test.csv.
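The pipeline just described might look roughly like this with pandas and scikit-learn. This is a sketch, not the article's exact code: the ". " separator between title and text, stratifying the splits by label, the random seed, and the file names in the usage comments are all assumptions.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def preprocess(df: pd.DataFrame, first_n_words: int = 200,
               train_test_ratio: float = 0.90, train_valid_ratio: float = 0.80,
               seed: int = 1):
    """Return (train, valid, test) DataFrames with label, title, text, titletext."""
    df = df.copy()
    # REAL -> 0, FAKE -> 1
    df["label"] = (df["label"] == "FAKE").astype(int)
    # Combine title and text; the ". " separator is an assumption.
    df["titletext"] = df["title"].fillna("") + ". " + df["text"].fillna("")
    # Drop rows whose text is missing or only whitespace.
    df = df[df["text"].fillna("").str.strip() != ""].copy()
    # Keep only the first `first_n_words` words of the text columns.
    for col in ["text", "titletext"]:
        df[col] = df[col].apply(lambda s: " ".join(s.split()[:first_n_words]))
    df = df[["label", "title", "text", "titletext"]]
    # Split off the test set first, then divide the remainder into train/valid.
    full_train, test = train_test_split(df, train_size=train_test_ratio,
                                        random_state=seed, stratify=df["label"])
    train, valid = train_test_split(full_train, train_size=train_valid_ratio,
                                    random_state=seed, stratify=full_train["label"])
    return train, valid, test


# Usage (file names are hypothetical):
# train, valid, test = preprocess(pd.read_csv("data/news.csv"))
# train.to_csv("data/train.csv", index=False)
# valid.to_csv("data/valid.csv", index=False)
# test.to_csv("data/test.csv", index=False)
```

Stratifying keeps the REAL/FAKE proportions similar across the three files; drop the stratify arguments to reproduce a plain random split.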

Step 2: Importing Libraries

We import PyTorch for model construction, torchtext for loading data, Matplotlib for plotting, and scikit-learn for evaluation.

Step 3: Load Dataset

First, we use torchtext to create a label field for the label in our dataset and a text field for the title, text, and titletext columns. We then build a TabularDataset by pointing it to the folder containing the train.csv, valid.csv, and test.csv files. We create the train, valid, and test iterators that load the data, and finally build the vocabulary from the training data, keeping only tokens with a minimum frequency of 3.
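The torchtext Field, TabularDataset, and iterator classes used in this step were later moved to torchtext.legacy and then removed from torchtext, so that code does not run on current releases. Below is a plain-PyTorch sketch that performs the same job: it builds a vocabulary from training texts keeping tokens seen at least 3 times, maps text to token IDs, and pads each batch while recording the true lengths. The lowercase whitespace tokenizer, the `<pad>`/`<unk>` tokens, and all function names are assumptions, not the article's code.

```python
from collections import Counter

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD, UNK = "<pad>", "<unk>"


def tokenize(text: str) -> list[str]:
    # Assumed tokenizer: lowercase, then split on whitespace.
    return text.lower().split()


def build_vocab(texts, min_freq: int = 3) -> dict[str, int]:
    """Map each token seen at least `min_freq` times to an integer ID."""
    counts = Counter(tok for t in texts for tok in tokenize(t))
    vocab = {PAD: 0, UNK: 1}
    for tok, n in counts.most_common():
        if n >= min_freq:
            vocab[tok] = len(vocab)
    return vocab


def encode(text: str, vocab: dict[str, int]) -> torch.Tensor:
    ids = [vocab.get(tok, vocab[UNK]) for tok in tokenize(text)]
    # Never return an empty sequence: packing later requires length >= 1.
    return torch.tensor(ids or [vocab[UNK]], dtype=torch.long)


def make_collate(vocab: dict[str, int]):
    def collate(batch):
        # batch: list of (label, text) pairs -> (labels, padded ids, lengths)
        labels = torch.tensor([float(y) for y, _ in batch])
        seqs = [encode(t, vocab) for _, t in batch]
        lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
        ids = pad_sequence(seqs, batch_first=True, padding_value=vocab[PAD])
        return labels, ids, lengths
    return collate


# Usage (hypothetical, with `train` a DataFrame loaded from train.csv):
# vocab = build_vocab(train["titletext"], min_freq=3)
# pairs = list(zip(train["label"], train["titletext"]))
# train_loader = DataLoader(pairs, batch_size=32, shuffle=True, collate_fn=make_collate(vocab))
```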

Step 4: Build Model

We construct the LSTM class, which inherits from nn.Module. Inside it, we create an Embedding layer, followed by a bi-LSTM layer, and end with a fully connected linear layer. In the forward function, we pass the token IDs through the embedding layer to get the embeddings, feed them to the bi-LSTM in a way that accommodates variable-length sequences so that it reads each sequence in both directions, pass the result through the fully connected linear layer, and finally apply a sigmoid to get the probability that the sequence belongs to FAKE (label 1).
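Here is a minimal sketch of such a model, assuming padded token IDs plus a tensor of true lengths as input. It handles variable lengths with pack_padded_sequence and reads the final hidden states h_n of both directions; this is one reasonable design, not necessarily the article's exact forward pass. The embedding size of 300, hidden size of 128, and padding index 0 are assumptions.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class LSTM(nn.Module):
    """Embedding -> one-layer bidirectional LSTM -> linear -> sigmoid."""

    def __init__(self, vocab_size: int, embed_dim: int = 300,
                 hidden_dim: int = 128, pad_idx: int = 0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1,
                            batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_dim, 1)

    def forward(self, ids: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        emb = self.embedding(ids)  # (batch, seq_len, embed_dim)
        # Packing makes the LSTM skip padded positions; lengths must be on the CPU.
        packed = pack_padded_sequence(emb, lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)  # h_n: (2, batch, hidden_dim), original order
        # h_n[0]: forward direction after each sequence's last real token;
        # h_n[1]: backward direction after reading back to the first token.
        features = torch.cat([h_n[0], h_n[1]], dim=1)
        logits = self.fc(features).squeeze(1)
        return torch.sigmoid(logits)  # probability of FAKE (label 1)
```

Because the sequences are packed, adding extra padding to a batch does not change the predictions.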

Step 5: Training

Before training, we build save and load functions for checkpoints and metrics. For checkpoints, we save the model parameters and the optimizer state; for metrics, we save the training loss, validation loss, and global steps so that the plots can easily be reconstructed later.
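A minimal sketch of these four helpers using torch.save and torch.load follows. The function names, dictionary keys, and file names are hypothetical; storing the validation loss alongside the checkpoint is an assumption that makes it easy to compare runs later.

```python
import torch


def save_checkpoint(path, model, optimizer, valid_loss):
    # Model parameters and optimizer state, plus the validation loss they achieved.
    torch.save({"model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "valid_loss": valid_loss}, path)


def load_checkpoint(path, model, optimizer=None, device="cpu"):
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer_state_dict"])
    return state["valid_loss"]


def save_metrics(path, train_losses, valid_losses, global_steps):
    # Plain Python lists, enough to redraw the loss curves later.
    torch.save({"train_loss_list": train_losses,
                "valid_loss_list": valid_losses,
                "global_steps_list": global_steps}, path)


def load_metrics(path, device="cpu"):
    state = torch.load(path, map_location=device)
    return state["train_loss_list"], state["valid_loss_list"], state["global_steps_list"]
```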

We train the LSTM for 10 epochs and save the checkpoint and metrics whenever the model achieves a new best (lowest) validation loss. Here is the output during training:

[Figure: training output]

The whole training process was fast on Google Colab. It took less than two minutes to train!
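A training loop in this spirit might look like the sketch below. It assumes each batch is a (labels, ids, lengths) tuple and that the model returns sigmoid probabilities, so it uses binary cross-entropy (nn.BCELoss); evaluating twice per epoch by default, averaging the validation loss over batches, and the file names are assumptions.

```python
import math

import torch
import torch.nn as nn


def train(model, optimizer, train_loader, valid_loader, num_epochs=10,
          eval_every=None, checkpoint_path="model.pt", metrics_path="metrics.pt",
          device="cpu"):
    """Train with BCELoss; save checkpoint and metrics whenever valid loss hits a new low."""
    criterion = nn.BCELoss()  # the model already outputs sigmoid probabilities
    eval_every = eval_every or len(train_loader) // 2 or 1  # default: twice per epoch
    best_valid_loss = math.inf
    running_loss, global_step = 0.0, 0
    train_losses, valid_losses, global_steps = [], [], []
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        for labels, ids, lengths in train_loader:
            loss = criterion(model(ids.to(device), lengths), labels.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            global_step += 1

            if global_step % eval_every != 0:
                continue
            # Validation loss, averaged over batches.
            model.eval()
            with torch.no_grad():
                valid_loss = sum(criterion(model(i.to(device), n), y.to(device)).item()
                                 for y, i, n in valid_loader) / len(valid_loader)
            model.train()
            train_losses.append(running_loss / eval_every)
            valid_losses.append(valid_loss)
            global_steps.append(global_step)
            running_loss = 0.0
            print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{global_step}], "
                  f"Train Loss: {train_losses[-1]:.4f}, Valid Loss: {valid_loss:.4f}")

            if valid_loss < best_valid_loss:  # new best: save checkpoint and metrics
                best_valid_loss = valid_loss
                torch.save({"model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                            "valid_loss": valid_loss}, checkpoint_path)
                torch.save({"train_loss_list": train_losses,
                            "valid_loss_list": valid_losses,
                            "global_steps_list": global_steps}, metrics_path)
    return best_valid_loss, train_losses, valid_losses, global_steps
```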

Once training has finished, we can load the saved metrics and plot the training and validation loss over time.

[Figure: training and validation loss over time]
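Given the three metric lists saved during training, a plot like this can be drawn with a few Matplotlib calls. This is a sketch; the function name and axis labels are assumptions.

```python
import matplotlib.pyplot as plt


def plot_losses(global_steps, train_losses, valid_losses):
    """Plot training and validation loss against the global step."""
    fig, ax = plt.subplots()
    ax.plot(global_steps, train_losses, label="Train")
    ax.plot(global_steps, valid_losses, label="Valid")
    ax.set_xlabel("Global steps")
    ax.set_ylabel("Loss")
    ax.legend()
    return fig, ax


# Usage: pass the lists loaded from the saved metrics file, then display.
# fig, ax = plot_losses(global_steps, train_losses, valid_losses)
# plt.show()
```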

Step 6: Evaluation

Finally, for evaluation, we load the best model saved earlier and evaluate it on our test dataset. We use a default threshold of 0.5: if the model output is greater than 0.5, we classify the news as FAKE; otherwise, as REAL. We output the classification report with the precision, recall, and F1-score for each class, along with the overall accuracy and the confusion matrix.

[Figure: classification report and confusion matrix]

We can see that with a one-layer bi-LSTM, we can achieve an accuracy of 77.53% on the fake news detection task.
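The thresholding and reporting step can be sketched with scikit-learn's metrics, assuming the test-set labels and model probabilities have already been collected into two lists. The function name is hypothetical, and listing FAKE (1) before REAL (0) in the report and confusion matrix is a presentation choice.

```python
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


def evaluate_predictions(y_true, y_prob, threshold=0.5):
    """Threshold probabilities and return (predictions, accuracy, report, confusion matrix)."""
    y_pred = (np.asarray(y_prob) > threshold).astype(int)  # > threshold -> FAKE (1), else REAL (0)
    report = classification_report(y_true, y_pred, labels=[1, 0],
                                   target_names=["FAKE", "REAL"], digits=4)
    cm = confusion_matrix(y_true, y_pred, labels=[1, 0])  # rows: true FAKE, true REAL
    return y_pred, accuracy_score(y_true, y_pred), report, cm
```

Note that a probability of exactly 0.5 is classified as REAL, matching the "greater than 0.5" rule above.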

Conclusion

This tutorial gives a step-by-step explanation of implementing your own LSTM model for text classification using PyTorch. We find that a bi-LSTM achieves acceptable accuracy on fake news detection but still has room for improvement. For more competitive performance, check out my previous article on BERT Text Classification!

If you want to learn more about modern NLP and deep learning, make sure to follow me for updates on upcoming articles :)

References

[1] S. Hochreiter, J. Schmidhuber, Long Short-Term Memory (1997), Neural Computation
