pytorch-sentiment-analysis/4 - Convolutional Sentiment Analysis.ipynb
2021-03-12 13:25:47 +00:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4 - Convolutional Sentiment Analysis\n",
"\n",
"In the previous notebooks, we managed to achieve a test accuracy of ~85% using RNNs and an implementation of the [Bag of Tricks for Efficient Text Classification](https://arxiv.org/abs/1607.01759) model. In this notebook, we will be using a *convolutional neural network* (CNN) to conduct sentiment analysis, implementing the model from [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/abs/1408.5882).\n",
"\n",
"**Note**: This tutorial is not aiming to give a comprehensive introduction and explanation of CNNs. For a better and more in-depth explanation check out [here](https://ujjwalkarn.me/2016/08/11/intuitive-explanation-convnets/) and [here](https://cs231n.github.io/convolutional-networks/).\n",
"\n",
"Traditionally, CNNs are used to analyse images and are made up of one or more *convolutional* layers, followed by one or more linear layers. The convolutional layers use filters (also called *kernels* or *receptive fields*) which scan across an image and produce a processed version of the image. This processed version of the image can be fed into another convolutional layer or a linear layer. Each filter has a shape, e.g. a 3x3 filter covers a 3 pixel wide and 3 pixel high area of the image, and each element of the filter has a weight associated with it, the 3x3 filter would have 9 weights. In traditional image processing these weights were specified by hand by engineers, however the main advantage of the convolutional layers in neural networks is that these weights are learned via backpropagation. \n",
"\n",
"The intuitive idea behind learning the weights is that your convolutional layers act like *feature extractors*, extracting parts of the image that are most important for your CNN's goal, e.g. if using a CNN to detect faces in an image, the CNN may be looking for features such as the existance of a nose, mouth or a pair of eyes in the image.\n",
"\n",
"So why use CNNs on text? In the same way that a 3x3 filter can look over a patch of an image, a 1x2 filter can look over a 2 sequential words in a piece of text, i.e. a bi-gram. In the previous tutorial we looked at the FastText model which used bi-grams by explicitly adding them to the end of a text, in this CNN model we will instead use multiple filters of different sizes which will look at the bi-grams (a 1x2 filter), tri-grams (a 1x3 filter) and/or n-grams (a 1x$n$ filter) within the text.\n",
"\n",
"The intuition here is that the appearance of certain bi-grams, tri-grams and n-grams within the review will be a good indication of the final sentiment."
]
},
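{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the n-gram idea concrete, here is a minimal sketch in plain Python of the windows that a 1x2 and a 1x3 filter would slide over. The `ngrams` helper and the toy sentence are illustrative assumptions, not part of the model built in this notebook.\n",
"\n",
"```python\n",
"def ngrams(tokens, n):\n",
"    # Slide a window of n tokens across the sentence, one position at a time.\n",
"    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]\n",
"\n",
"tokens = ['i', 'hate', 'this', 'film']\n",
"\n",
"print(ngrams(tokens, 2))  # the bi-grams a 1x2 filter would cover\n",
"print(ngrams(tokens, 3))  # the tri-grams a 1x3 filter would cover\n",
"```\n",
"\n",
"A sentence of 4 tokens yields 3 bi-grams and 2 tri-grams. The convolutional filters score each of these windows rather than storing them explicitly."
]
},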
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparing Data\n",
"\n",
"As in the previous notebooks, we'll prepare the data. \n",
"\n",
"Unlike the previous notebook with the FastText model, we no longer explicitly need to create the bi-grams and append them to the end of the sentence.\n",
"\n",
"As convolutional layers expect the batch dimension to be first we can tell TorchText to return the data already permuted using the `batch_first = True` argument on the field."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
]
}
],
"source": [
"import torch\n",
"from torchtext.legacy import data\n",
"from torchtext.legacy import datasets\n",
"import random\n",
"import numpy as np\n",
"\n",
"SEED = 1234\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True\n",
"\n",
"TEXT = data.Field(tokenize = 'spacy', \n",
" tokenizer_language = 'en_core_web_sm',\n",
" batch_first = True)\n",
"LABEL = data.LabelField(dtype = torch.float)\n",
"\n",
"train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
"\n",
"train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build the vocab and load the pre-trained word embeddings."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"MAX_VOCAB_SIZE = 25_000\n",
"\n",
"TEXT.build_vocab(train_data, \n",
" max_size = MAX_VOCAB_SIZE, \n",
" vectors = \"glove.6B.100d\", \n",
" unk_init = torch.Tensor.normal_)\n",
"\n",
"LABEL.build_vocab(train_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we create the iterators."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"BATCH_SIZE = 64\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data), \n",
" batch_size = BATCH_SIZE, \n",
" device = device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the Model\n",
"\n",
"Now to build our model.\n",
"\n",
"The first major hurdle is visualizing how CNNs are used for text. Images are typically 2 dimensional (we'll ignore the fact that there is a third \"colour\" dimension for now) whereas text is 1 dimensional. However, we know that the first step in almost all of our previous tutorials (and pretty much all NLP pipelines) is converting the words into word embeddings. This is how we can visualize our words in 2 dimensions, each word along one axis and the elements of vectors aross the other dimension. Consider the 2 dimensional representation of the embedded sentence below:\n",
"\n",
"![](assets/sentiment9.png)\n",
"\n",
"We can then use a filter that is **[n x emb_dim]**. This will cover $n$ sequential words entirely, as their width will be `emb_dim` dimensions. Consider the image below, with our word vectors are represented in green. Here we have 4 words with 5 dimensional embeddings, creating a [4x5] \"image\" tensor. A filter that covers two words at a time (i.e. bi-grams) will be **[2x5]** filter, shown in yellow, and each element of the filter with have a _weight_ associated with it. The output of this filter (shown in red) will be a single real number that is the weighted sum of all elements covered by the filter.\n",
"\n",
"![](assets/sentiment12.png)\n",
"\n",
"The filter then moves \"down\" the image (or across the sentence) to cover the next bi-gram and another output (weighted sum) is calculated. \n",
"\n",
"![](assets/sentiment13.png)\n",
"\n",
"Finally, the filter moves down again and the final output for this filter is calculated.\n",
"\n",
"![](assets/sentiment14.png)\n",
"\n",
"In our case (and in the general case where the width of the filter equals the width of the \"image\"), our output will be a vector with number of elements equal to the height of the image (or lenth of the word) minus the height of the filter plus one, $4-2+1=3$ in this case.\n",
"\n",
"This example showed how to calculate the output of one filter. Our model (and pretty much all CNNs) will have lots of these filters. The idea is that each filter will learn a different feature to extract. In the above example, we are hoping each of the **[2 x emb_dim]** filters will be looking for the occurence of different bi-grams. \n",
"\n",
"In our model, we will also have different sizes of filters, heights of 3, 4 and 5, with 100 of each of them. The intuition is that we will be looking for the occurence of different tri-grams, 4-grams and 5-grams that are relevant for analysing sentiment of movie reviews.\n",
"\n",
"The next step in our model is to use *pooling* (specifically *max pooling*) on the output of the convolutional layers. This is similar to the FastText model where we performed the average over each of the word vectors, implemented by the `F.avg_pool2d` function, however instead of taking the average over a dimension, we are taking the maximum value over a dimension. Below an example of taking the maximum value (0.9) from the output of the convolutional layer on the example sentence (not shown is the activation function applied to the output of the convolutions).\n",
"\n",
"![](assets/sentiment15.png)\n",
"\n",
"The idea here is that the maximum value is the \"most important\" feature for determining the sentiment of the review, which corresponds to the \"most important\" n-gram within the review. How do we know what the \"most important\" n-gram is? Luckily, we don't have to! Through backpropagation, the weights of the filters are changed so that whenever certain n-grams that are highly indicative of the sentiment are seen, the output of the filter is a \"high\" value. This \"high\" value then passes through the max pooling layer if it is the maximum value in the output. \n",
"\n",
"As our model has 100 filters of 3 different sizes, that means we have 300 different n-grams the model thinks are important. We concatenate these together into a single vector and pass them through a linear layer to predict the sentiment. We can think of the weights of this linear layer as \"weighting up the evidence\" from each of the 300 n-grams and making a final decision. \n",
"\n",
"### Implementation Details\n",
"\n",
"We implement the convolutional layers with `nn.Conv2d`. The `in_channels` argument is the number of \"channels\" in your image going into the convolutional layer. In actual images this is usually 3 (one channel for each of the red, blue and green channels), however when using text we only have a single channel, the text itself. The `out_channels` is the number of filters and the `kernel_size` is the size of the filters. Each of our `kernel_size`s is going to be **[n x emb_dim]** where $n$ is the size of the n-grams.\n",
"\n",
"In PyTorch, RNNs want the input with the batch dimension second, whereas CNNs want the batch dimension first - we do not have to permute the data here as we have already set `batch_first = True` in our `TEXT` field. We then pass the sentence through an embedding layer to get our embeddings. The second dimension of the input into a `nn.Conv2d` layer must be the channel dimension. As text technically does not have a channel dimension, we `unsqueeze` our tensor to create one. This matches with our `in_channels=1` in the initialization of our convolutional layers. \n",
"\n",
"We then pass the tensors through the convolutional and pooling layers, using the `ReLU` activation function after the convolutional layers. Another nice feature of the pooling layers is that they handle sentences of different lengths. The size of the output of the convolutional layer is dependent on the size of the input to it, and different batches contain sentences of different lengths. Without the max pooling layer the input to our linear layer would depend on the size of the input sentence (not what we want). One option to rectify this would be to trim/pad all sentences to the same length, however with the max pooling layer we always know the input to the linear layer will be the total number of filters. **Note**: there an exception to this if your sentence(s) are shorter than the largest filter used. You will then have to pad your sentences to the length of the largest filter. In the IMDb data there are no reviews only 5 words long so we don't have to worry about that, but you will if you are using your own data.\n",
"\n",
"Finally, we perform dropout on the concatenated filter outputs and then pass them through a linear layer to make our predictions."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class CNN(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
" dropout, pad_idx):\n",
" \n",
" super().__init__()\n",
" \n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
" \n",
" self.conv_0 = nn.Conv2d(in_channels = 1, \n",
" out_channels = n_filters, \n",
" kernel_size = (filter_sizes[0], embedding_dim))\n",
" \n",
" self.conv_1 = nn.Conv2d(in_channels = 1, \n",
" out_channels = n_filters, \n",
" kernel_size = (filter_sizes[1], embedding_dim))\n",
" \n",
" self.conv_2 = nn.Conv2d(in_channels = 1, \n",
" out_channels = n_filters, \n",
" kernel_size = (filter_sizes[2], embedding_dim))\n",
" \n",
" self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, text):\n",
" \n",
" #text = [batch size, sent len]\n",
" \n",
" embedded = self.embedding(text)\n",
" \n",
" #embedded = [batch size, sent len, emb dim]\n",
" \n",
" embedded = embedded.unsqueeze(1)\n",
" \n",
" #embedded = [batch size, 1, sent len, emb dim]\n",
" \n",
" conved_0 = F.relu(self.conv_0(embedded).squeeze(3))\n",
" conved_1 = F.relu(self.conv_1(embedded).squeeze(3))\n",
" conved_2 = F.relu(self.conv_2(embedded).squeeze(3))\n",
" \n",
" #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
" \n",
" pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)\n",
" pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)\n",
" pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)\n",
" \n",
" #pooled_n = [batch size, n_filters]\n",
" \n",
" cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim = 1))\n",
"\n",
" #cat = [batch size, n_filters * len(filter_sizes)]\n",
" \n",
" return self.fc(cat)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Currently the `CNN` model can only use 3 different sized filters, but we can actually improve the code of our model to make it more generic and take any number of filters.\n",
"\n",
"We do this by placing all of our convolutional layers in a `nn.ModuleList`, a function used to hold a list of PyTorch `nn.Module`s. If we simply used a standard Python list, the modules within the list cannot be \"seen\" by any modules outside the list which will cause us some errors.\n",
"\n",
"We can now pass an arbitrary sized list of filter sizes and the list comprehension will create a convolutional layer for each of them. Then, in the `forward` method we iterate through the list applying each convolutional layer to get a list of convolutional outputs, which we also feed through the max pooling in a list comprehension before concatenating together and passing through the dropout and linear layers."
]
},
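{
"cell_type": "markdown",
"metadata": {},
"source": [
"The difference between a plain Python list and `nn.ModuleList` can be checked with a small sketch. The two classes below are hypothetical stand-ins that use tiny `nn.Linear` layers instead of convolutions, purely to keep the parameter counts easy to verify.\n",
"\n",
"```python\n",
"import torch.nn as nn\n",
"\n",
"class WithPlainList(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Layers held in a plain list are not registered as submodules.\n",
"        self.layers = [nn.Linear(2, 2), nn.Linear(2, 2)]\n",
"\n",
"class WithModuleList(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # nn.ModuleList registers every layer it holds.\n",
"        self.layers = nn.ModuleList([nn.Linear(2, 2), nn.Linear(2, 2)])\n",
"\n",
"def n_params(module):\n",
"    return sum(p.numel() for p in module.parameters())\n",
"\n",
"print(n_params(WithPlainList()))   # 0\n",
"print(n_params(WithModuleList()))  # 12: two layers of 4 weights + 2 biases\n",
"```\n",
"\n",
"With the plain list, an optimizer built from `model.parameters()` would receive nothing to update for those layers."
]
},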
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
" dropout, pad_idx):\n",
" \n",
" super().__init__()\n",
" \n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
" \n",
" self.convs = nn.ModuleList([\n",
" nn.Conv2d(in_channels = 1, \n",
" out_channels = n_filters, \n",
" kernel_size = (fs, embedding_dim)) \n",
" for fs in filter_sizes\n",
" ])\n",
" \n",
" self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, text):\n",
" \n",
" #text = [batch size, sent len]\n",
" \n",
" embedded = self.embedding(text)\n",
" \n",
" #embedded = [batch size, sent len, emb dim]\n",
" \n",
" embedded = embedded.unsqueeze(1)\n",
" \n",
" #embedded = [batch size, 1, sent len, emb dim]\n",
" \n",
" conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
" \n",
" #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
" \n",
" pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
" \n",
" #pooled_n = [batch size, n_filters]\n",
" \n",
" cat = self.dropout(torch.cat(pooled, dim = 1))\n",
"\n",
" #cat = [batch size, n_filters * len(filter_sizes)]\n",
" \n",
" return self.fc(cat)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also implement the above model using 1-dimensional convolutional layers, where the embedding dimension is the \"depth\" of the filter and the number of tokens in the sentence is the width.\n",
"\n",
"We'll run our tests in this notebook using the 2-dimensional convolutional model, but leave the implementation for the 1-dimensional model below for anyone interested. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class CNN1d(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, \n",
" dropout, pad_idx):\n",
" \n",
" super().__init__()\n",
" \n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
" \n",
" self.convs = nn.ModuleList([\n",
" nn.Conv1d(in_channels = embedding_dim, \n",
" out_channels = n_filters, \n",
" kernel_size = fs)\n",
" for fs in filter_sizes\n",
" ])\n",
" \n",
" self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, text):\n",
" \n",
" #text = [batch size, sent len]\n",
" \n",
" embedded = self.embedding(text)\n",
" \n",
" #embedded = [batch size, sent len, emb dim]\n",
" \n",
" embedded = embedded.permute(0, 2, 1)\n",
" \n",
" #embedded = [batch size, emb dim, sent len]\n",
" \n",
" conved = [F.relu(conv(embedded)) for conv in self.convs]\n",
" \n",
" #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
" \n",
" pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
" \n",
" #pooled_n = [batch size, n_filters]\n",
" \n",
" cat = self.dropout(torch.cat(pooled, dim = 1))\n",
" \n",
" #cat = [batch size, n_filters * len(filter_sizes)]\n",
" \n",
" return self.fc(cat)"
]
},
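{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to convince ourselves that the 1-dimensional and 2-dimensional formulations compute the same thing is to copy the weights of a `nn.Conv2d` layer into a `nn.Conv1d` layer and compare their outputs. This is only a sketch on a single layer with small, arbitrarily chosen sizes and a random stand-in for the embedded batch. It is not part of the training code in this notebook.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# Small sizes chosen only for illustration.\n",
"batch_size, sent_len, emb_dim, n_filters, fs = 2, 7, 4, 3, 3\n",
"\n",
"conv2d = nn.Conv2d(in_channels = 1, out_channels = n_filters, kernel_size = (fs, emb_dim))\n",
"conv1d = nn.Conv1d(in_channels = emb_dim, out_channels = n_filters, kernel_size = fs)\n",
"\n",
"with torch.no_grad():\n",
"    # conv2d.weight = [n_filters, 1, fs, emb dim]\n",
"    # conv1d.weight = [n_filters, emb dim, fs]\n",
"    conv1d.weight.copy_(conv2d.weight.squeeze(1).permute(0, 2, 1))\n",
"    conv1d.bias.copy_(conv2d.bias)\n",
"\n",
"# Random stand-in for an embedded batch: [batch size, sent len, emb dim]\n",
"embedded = torch.randn(batch_size, sent_len, emb_dim)\n",
"\n",
"with torch.no_grad():\n",
"    out2d = conv2d(embedded.unsqueeze(1)).squeeze(3)\n",
"    out1d = conv1d(embedded.permute(0, 2, 1))\n",
"\n",
"# Both outputs are [batch size, n_filters, sent len - fs + 1]\n",
"print(out2d.shape, out1d.shape)\n",
"print(torch.allclose(out2d, out1d, atol = 1e-5))\n",
"```"
]
},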
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create an instance of our `CNN` class. \n",
"\n",
"We can change `CNN` to `CNN1d` if we want to run the 1-dimensional convolutional model, noting that both models give almost identical results."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"INPUT_DIM = len(TEXT.vocab)\n",
"EMBEDDING_DIM = 100\n",
"N_FILTERS = 100\n",
"FILTER_SIZES = [3,4,5]\n",
"OUTPUT_DIM = 1\n",
"DROPOUT = 0.5\n",
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
"\n",
"model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT, PAD_IDX)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Checking the number of parameters in our model we can see it has about the same as the FastText model. \n",
"\n",
"Both the `CNN` and the `CNN1d` models have the exact same number of parameters."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 2,620,801 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
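{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter count can also be reproduced by hand. The sketch below assumes a vocabulary of 25,002 entries (the 25,000 most common words plus the `<unk>` and `<pad>` tokens), which is not printed in this notebook but is consistent with the total shown above.\n",
"\n",
"```python\n",
"vocab_size = 25_002  # assumed: MAX_VOCAB_SIZE plus <unk> and <pad>\n",
"embedding_dim = 100\n",
"n_filters = 100\n",
"filter_sizes = [3, 4, 5]\n",
"output_dim = 1\n",
"\n",
"# One embedding_dim-sized vector per vocabulary entry.\n",
"embedding = vocab_size * embedding_dim\n",
"\n",
"# Each filter has fs * embedding_dim weights (single input channel) plus one bias.\n",
"convs = sum(n_filters * (fs * embedding_dim + 1) for fs in filter_sizes)\n",
"\n",
"# Linear layer: one weight per pooled feature plus one bias, per output.\n",
"linear = len(filter_sizes) * n_filters * output_dim + output_dim\n",
"\n",
"total = embedding + convs + linear\n",
"print(f'{total:,}')  # 2,620,801\n",
"```\n",
"\n",
"The same arithmetic applies to `CNN1d`, since a `nn.Conv1d` filter with `embedding_dim` input channels and kernel size `fs` also has `fs * embedding_dim` weights and one bias."
]
},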
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we'll load the pre-trained embeddings"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_embeddings = TEXT.vocab.vectors\n",
"\n",
"model.embedding.weight.data.copy_(pretrained_embeddings)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then zero the initial weights of the unknown and padding tokens."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
"\n",
"model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
"model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training is the same as before. We initialize the optimizer, loss function (criterion) and place the model and criterion on the GPU (if available)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"optimizer = optim.Adam(model.parameters())\n",
"\n",
"criterion = nn.BCEWithLogitsLoss()\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We implement the function to calculate accuracy..."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def binary_accuracy(preds, y):\n",
" \"\"\"\n",
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
" \"\"\"\n",
"\n",
" #round predictions to the closest integer\n",
" rounded_preds = torch.round(torch.sigmoid(preds))\n",
" correct = (rounded_preds == y).float() #convert into float for division \n",
" acc = correct.sum() / len(correct)\n",
" return acc"
]
},
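{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition, the same calculation can be followed on a handful of numbers in plain Python. `binary_accuracy_py` is a hypothetical list-based mirror of the tensor function above, and the logits and labels are invented for illustration.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def binary_accuracy_py(logits, labels):\n",
"    # Squash each logit with the sigmoid, round to 0 or 1, then compare with the label.\n",
"    preds = [round(1 / (1 + math.exp(-z))) for z in logits]\n",
"    correct = [float(p == y) for p, y in zip(preds, labels)]\n",
"    return sum(correct) / len(correct)\n",
"\n",
"logits = [2.0, -1.0, 0.5, -3.0, 1.5]\n",
"labels = [1.0, 0.0, 0.0, 0.0, 1.0]\n",
"\n",
"print(binary_accuracy_py(logits, labels))  # 0.8, i.e. 4 of 5 correct\n",
"```\n",
"\n",
"Because the sigmoid of a positive logit is above 0.5, rounding amounts to predicting the positive class whenever the raw model output is greater than zero."
]
},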
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define a function for training our model...\n",
"\n",
"**Note**: as we are using dropout again, we must remember to use `model.train()` to ensure the dropout is \"turned on\" while training."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def train(model, iterator, optimizer, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.train()\n",
" \n",
" for batch in iterator:\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" predictions = model(batch.text).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
" \n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define a function for testing our model...\n",
"\n",
"**Note**: again, as we are now using dropout, we must remember to use `model.eval()` to ensure the dropout is \"turned off\" while evaluating."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, iterator, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.eval()\n",
" \n",
" with torch.no_grad():\n",
" \n",
" for batch in iterator:\n",
"\n",
" predictions = model(batch.text).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define our function to tell us how long epochs take."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we train our model..."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 0m 13s\n",
"\tTrain Loss: 0.649 | Train Acc: 61.79%\n",
"\t Val. Loss: 0.507 | Val. Acc: 78.93%\n",
"Epoch: 02 | Epoch Time: 0m 13s\n",
"\tTrain Loss: 0.433 | Train Acc: 79.86%\n",
"\t Val. Loss: 0.357 | Val. Acc: 84.57%\n",
"Epoch: 03 | Epoch Time: 0m 13s\n",
"\tTrain Loss: 0.305 | Train Acc: 87.36%\n",
"\t Val. Loss: 0.312 | Val. Acc: 86.76%\n",
"Epoch: 04 | Epoch Time: 0m 13s\n",
"\tTrain Loss: 0.224 | Train Acc: 91.20%\n",
"\t Val. Loss: 0.303 | Val. Acc: 87.16%\n",
"Epoch: 05 | Epoch Time: 0m 14s\n",
"\tTrain Loss: 0.159 | Train Acc: 94.16%\n",
"\t Val. Loss: 0.317 | Val. Acc: 87.37%\n"
]
}
],
"source": [
"N_EPOCHS = 5\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"\n",
" start_time = time.time()\n",
" \n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
" \n",
" end_time = time.time()\n",
"\n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'tut4-model.pt')\n",
" \n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We get test results comparable to the previous 2 models!"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.343 | Test Acc: 85.31%\n"
]
}
],
"source": [
"model.load_state_dict(torch.load('tut4-model.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## User Input\n",
"\n",
"And again, as a sanity check we can check some input sentences\n",
"\n",
"**Note**: As mentioned in the implementation details, the input sentence has to be at least as long as the largest filter height used. We modify our `predict_sentiment` function to also accept a minimum length argument. If the tokenized input sentence is less than `min_len` tokens, we append padding tokens (`<pad>`) to make it `min_len` tokens."
]
},
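{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reason for the padding can be seen from the output-length formula: a filter of height 5 has no valid position on a 4-token sentence. The sketch below uses two hypothetical helpers, `conv_output_len` and `pad_tokens`, that are not part of the notebook's code.\n",
"\n",
"```python\n",
"def conv_output_len(sent_len, filter_size):\n",
"    # Number of positions the filter can occupy: sent len - filter size + 1\n",
"    return sent_len - filter_size + 1\n",
"\n",
"def pad_tokens(tokens, min_len = 5, pad_token = '<pad>'):\n",
"    # Append padding until the sentence is at least min_len tokens long.\n",
"    if len(tokens) < min_len:\n",
"        tokens = tokens + [pad_token] * (min_len - len(tokens))\n",
"    return tokens\n",
"\n",
"tokens = ['This', 'film', 'is', 'terrible']\n",
"print(conv_output_len(len(tokens), 5))  # 0, nothing for the filter to cover\n",
"\n",
"padded = pad_tokens(tokens)\n",
"print(padded)                           # ['This', 'film', 'is', 'terrible', '<pad>']\n",
"print(conv_output_len(len(padded), 5))  # 1\n",
"```"
]
},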
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"import spacy\n",
"nlp = spacy.load('en_core_web_sm')\n",
"\n",
"def predict_sentiment(model, sentence, min_len = 5):\n",
" model.eval()\n",
" tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
" if len(tokenized) < min_len:\n",
" tokenized += ['<pad>'] * (min_len - len(tokenized))\n",
" indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
" tensor = torch.LongTensor(indexed).to(device)\n",
" tensor = tensor.unsqueeze(0)\n",
" prediction = torch.sigmoid(model(tensor))\n",
" return prediction.item()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example negative review..."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.09913548082113266"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, \"This film is terrible\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example positive review..."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9769725799560547"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, \"This film is great\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}