pytorch-sentiment-analysis/2 - Upgraded Sentiment Analysis.ipynb
2021-03-12 13:25:47 +00:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2 - Updated Sentiment Analysis\n",
"\n",
"In the previous notebook, we got the fundamentals down for sentiment analysis. In this notebook, we'll actually get decent results.\n",
"\n",
"We will use:\n",
"- packed padded sequences\n",
"- pre-trained word embeddings\n",
"- different RNN architecture\n",
"- bidirectional RNN\n",
"- multi-layer RNN\n",
"- regularization\n",
"- a different optimizer\n",
"\n",
"This will allow us to achieve ~84% test accuracy."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparing Data\n",
"\n",
"As before, we'll set the seed, define the `Fields` and get the train/valid/test splits.\n",
"\n",
"We'll be using *packed padded sequences*, which will make our RNN only process the non-padded elements of our sequence, and for any padded element the `output` will be a zero tensor. To use packed padded sequences, we have to tell the RNN how long the actual sequences are. We do this by setting `include_lengths = True` for our `TEXT` field. This will cause `batch.text` to now be a tuple with the first element being our sentence (a numericalized tensor that has been padded) and the second element being the actual lengths of our sentences."
]
},
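{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the idea of padding plus lengths concrete, here is a minimal pure-Python sketch that does not use PyTorch or TorchText. The helper name `pad_batch` and the toy sentences are assumptions made for illustration; in this notebook the real padding and length bookkeeping are handled by the `TEXT` field and the iterator.\n",
"\n",
"```python\n",
"def pad_batch(sentences, pad_token='<pad>'):\n",
"    # sort longest first, the order packed padded sequences expect within a batch\n",
"    ordered = sorted(sentences, key=len, reverse=True)\n",
"    lengths = [len(s) for s in ordered]\n",
"    max_len = lengths[0]\n",
"    # pad every sentence on the right up to the longest length\n",
"    padded = [s + [pad_token] * (max_len - len(s)) for s in ordered]\n",
"    return padded, lengths\n",
"\n",
"batch = [['great', 'film'], ['i', 'really', 'hated', 'it'], ['boring']]\n",
"padded, lengths = pad_batch(batch)\n",
"print(padded)\n",
"print(lengths)\n",
"```\n",
"\n",
"The `lengths` list plays the role of the second element of `batch.text`: only the first `lengths[i]` tokens of row `i` are real, and everything after them is padding that a packed sequence would skip."
]
},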
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"import torch\n",
"from torchtext.legacy import data\n",
"\n",
"SEED = 1234\n",
"\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True\n",
"\n",
"TEXT = data.Field(tokenize = 'spacy',\n",
" tokenizer_language = 'en_core_web_sm',\n",
" include_lengths = True)\n",
"\n",
"LABEL = data.LabelField(dtype = torch.float)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We then load the IMDb dataset."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
]
}
],
"source": [
"from torchtext.legacy import datasets\n",
"\n",
"train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then create the validation set from our training set."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next is the use of pre-trained word embeddings. Now, instead of having our word embeddings initialized randomly, they are initialized with these pre-trained vectors.\n",
"We get these vectors simply by specifying which vectors we want and passing it as an argument to `build_vocab`. `TorchText` handles downloading the vectors and associating them with the correct words in our vocabulary.\n",
"\n",
"Here, we'll be using the `\"glove.6B.100d\" vectors\"`. `glove` is the algorithm used to calculate the vectors, go [here](https://nlp.stanford.edu/projects/glove/) for more. `6B` indicates these vectors were trained on 6 billion tokens and `100d` indicates these vectors are 100-dimensional.\n",
"\n",
"You can see the other available vectors [here](https://github.com/pytorch/text/blob/master/torchtext/vocab.py#L113).\n",
"\n",
"The theory is that these pre-trained vectors already have words with similar semantic meaning close together in vector space, e.g. \"terrible\", \"awful\", \"dreadful\" are nearby. This gives our embedding layer a good initialization as it does not have to learn these relations from scratch.\n",
"\n",
"**Note**: these vectors are about 862MB, so watch out if you have a limited internet connection.\n",
"\n",
"By default, TorchText will initialize words in your vocabulary but not in your pre-trained embeddings to zero. We don't want this, and instead initialize them randomly by setting `unk_init` to `torch.Tensor.normal_`. This will now initialize those words via a Gaussian distribution."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"MAX_VOCAB_SIZE = 25_000\n",
"\n",
"TEXT.build_vocab(train_data, \n",
" max_size = MAX_VOCAB_SIZE, \n",
" vectors = \"glove.6B.100d\", \n",
" unk_init = torch.Tensor.normal_)\n",
"\n",
"LABEL.build_vocab(train_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we create the iterators, placing the tensors on the GPU if one is available.\n",
"\n",
"Another thing for packed padded sequences all of the tensors within a batch need to be sorted by their lengths. This is handled in the iterator by setting `sort_within_batch = True`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"BATCH_SIZE = 64\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data), \n",
" batch_size = BATCH_SIZE,\n",
" sort_within_batch = True,\n",
" device = device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the Model\n",
"\n",
"The model features the most drastic changes.\n",
"\n",
"### Different RNN Architecture\n",
"\n",
"We'll be using a different RNN architecture called a Long Short-Term Memory (LSTM). Why is an LSTM better than a standard RNN? Standard RNNs suffer from the [vanishing gradient problem](https://en.wikipedia.org/wiki/Vanishing_gradient_problem). LSTMs overcome this by having an extra recurrent state called a _cell_, $c$ - which can be thought of as the \"memory\" of the LSTM - and the use use multiple _gates_ which control the flow of information into and out of the memory. For more information, go [here](https://colah.github.io/posts/2015-08-Understanding-LSTMs/). We can simply think of the LSTM as a function of $x_t$, $h_t$ and $c_t$, instead of just $x_t$ and $h_t$.\n",
"\n",
"$$(h_t, c_t) = \\text{LSTM}(x_t, h_t, c_t)$$\n",
"\n",
"Thus, the model using an LSTM looks something like (with the embedding layers omitted):\n",
"\n",
"![](assets/sentiment2.png)\n",
"\n",
"The initial cell state, $c_0$, like the initial hidden state is initialized to a tensor of all zeros. The sentiment prediction is still, however, only made using the final hidden state, not the final cell state, i.e. $\\hat{y}=f(h_T)$.\n",
"\n",
"### Bidirectional RNN\n",
"\n",
"The concept behind a bidirectional RNN is simple. As well as having an RNN processing the words in the sentence from the first to the last (a forward RNN), we have a second RNN processing the words in the sentence from the **last to the first** (a backward RNN). At time step $t$, the forward RNN is processing word $x_t$, and the backward RNN is processing word $x_{T-t+1}$. \n",
"\n",
"In PyTorch, the hidden state (and cell state) tensors returned by the forward and backward RNNs are stacked on top of each other in a single tensor. \n",
"\n",
"We make our sentiment prediction using a concatenation of the last hidden state from the forward RNN (obtained from final word of the sentence), $h_T^\\rightarrow$, and the last hidden state from the backward RNN (obtained from the first word of the sentence), $h_T^\\leftarrow$, i.e. $\\hat{y}=f(h_T^\\rightarrow, h_T^\\leftarrow)$ \n",
"\n",
"The image below shows a bi-directional RNN, with the forward RNN in orange, the backward RNN in green and the linear layer in silver. \n",
"\n",
"![](assets/sentiment3.png)\n",
"\n",
"### Multi-layer RNN\n",
"\n",
"Multi-layer RNNs (also called *deep RNNs*) are another simple concept. The idea is that we add additional RNNs on top of the initial standard RNN, where each RNN added is another *layer*. The hidden state output by the first (bottom) RNN at time-step $t$ will be the input to the RNN above it at time step $t$. The prediction is then made from the final hidden state of the final (highest) layer.\n",
"\n",
"The image below shows a multi-layer unidirectional RNN, where the layer number is given as a superscript. Also note that each layer needs their own initial hidden state, $h_0^L$.\n",
"\n",
"![](assets/sentiment4.png)\n",
"\n",
"### Regularization\n",
"\n",
"Although we've added improvements to our model, each one adds additional parameters. Without going into overfitting into too much detail, the more parameters you have in in your model, the higher the probability that your model will overfit (memorize the training data, causing a low training error but high validation/testing error, i.e. poor generalization to new, unseen examples). To combat this, we use regularization. More specifically, we use a method of regularization called *dropout*. Dropout works by randomly *dropping out* (setting to 0) neurons in a layer during a forward pass. The probability that each neuron is dropped out is set by a hyperparameter and each neuron with dropout applied is considered indepenently. One theory about why dropout works is that a model with parameters dropped out can be seen as a \"weaker\" (less parameters) model. The predictions from all these \"weaker\" models (one for each forward pass) get averaged together withinin the parameters of the model. Thus, your one model can be thought of as an ensemble of weaker models, none of which are over-parameterized and thus should not overfit.\n",
"\n",
"### Implementation Details\n",
"\n",
"Another addition to this model is that we are not going to learn the embedding for the `<pad>` token. This is because we want to explitictly tell our model that padding tokens are irrelevant to determining the sentiment of a sentence. This means the embedding for the pad token will remain at what it is initialized to (we initialize it to all zeros later). We do this by passing the index of our pad token as the `padding_idx` argument to the `nn.Embedding` layer.\n",
"\n",
"To use an LSTM instead of the standard RNN, we use `nn.LSTM` instead of `nn.RNN`. Also, note that the LSTM returns the `output` and a tuple of the final `hidden` state and the final `cell` state, whereas the standard RNN only returned the `output` and final `hidden` state. \n",
"\n",
"As the final hidden state of our LSTM has both a forward and a backward component, which will be concatenated together, the size of the input to the `nn.Linear` layer is twice that of the hidden dimension size.\n",
"\n",
"Implementing bidirectionality and adding additional layers are done by passing values for the `num_layers` and `bidirectional` arguments for the RNN/LSTM. \n",
"\n",
"Dropout is implemented by initializing an `nn.Dropout` layer (the argument is the probability of dropping out each neuron) and using it within the `forward` method after each layer we want to apply dropout to. **Note**: never use dropout on the input or output layers (`text` or `fc` in this case), you only ever want to use dropout on intermediate layers. The LSTM has a `dropout` argument which adds dropout on the connections between hidden states in one layer to hidden states in the next layer. \n",
"\n",
"As we are passing the lengths of our sentences to be able to use packed padded sequences, we have to add a second argument, `text_lengths`, to `forward`. \n",
"\n",
"Before we pass our embeddings to the RNN, we need to pack them, which we do with `nn.utils.rnn.packed_padded_sequence`. This will cause our RNN to only process the non-padded elements of our sequence. The RNN will then return `packed_output` (a packed sequence) as well as the `hidden` and `cell` states (both of which are tensors). Without packed padded sequences, `hidden` and `cell` are tensors from the last element in the sequence, which will most probably be a pad token, however when using packed padded sequences they are both from the last non-padded element in the sequence. Note that the `lengths` argument of `packed_padded_sequence` must be a CPU tensor so we explicitly make it one by using `.to('cpu')`.\n",
"\n",
"We then unpack the output sequence, with `nn.utils.rnn.pad_packed_sequence`, to transform it from a packed sequence to a tensor. The elements of `output` from padding tokens will be zero tensors (tensors where every element is zero). Usually, we only have to unpack output if we are going to use it later on in the model. Although we aren't in this case, we still unpack the sequence just to show how it is done.\n",
"\n",
"The final hidden state, `hidden`, has a shape of _**[num layers * num directions, batch size, hid dim]**_. These are ordered: **[forward_layer_0, backward_layer_0, forward_layer_1, backward_layer 1, ..., forward_layer_n, backward_layer n]**. As we want the final (top) layer forward and backward hidden states, we get the top two hidden layers from the first dimension, `hidden[-2,:,:]` and `hidden[-1,:,:]`, and concatenate them together before passing them to the linear layer (after applying dropout). "
]
},
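{
"cell_type": "markdown",
"metadata": {},
"source": [
"The indexing into `hidden` can be checked with plain Python, using a list of labels to stand in for the first dimension of the tensor. The helper name `hidden_state_order` is an assumption for illustration; it only reproduces the ordering described in the text and does not call PyTorch.\n",
"\n",
"```python\n",
"def hidden_state_order(n_layers, bidirectional=True):\n",
"    # layers are the outer loop, directions the inner loop\n",
"    directions = ['forward', 'backward'] if bidirectional else ['forward']\n",
"    return [f'{d}_layer_{layer}' for layer in range(n_layers) for d in directions]\n",
"\n",
"order = hidden_state_order(n_layers=2)\n",
"print(order)\n",
"print(order[-2], order[-1])\n",
"```\n",
"\n",
"With two layers, indices `-2` and `-1` pick out the top layer's forward and backward states, which are the two slices concatenated before the linear layer."
]
},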
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class RNN(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, \n",
" bidirectional, dropout, pad_idx):\n",
" \n",
" super().__init__()\n",
" \n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)\n",
" \n",
" self.rnn = nn.LSTM(embedding_dim, \n",
" hidden_dim, \n",
" num_layers=n_layers, \n",
" bidirectional=bidirectional, \n",
" dropout=dropout)\n",
" \n",
" self.fc = nn.Linear(hidden_dim * 2, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, text, text_lengths):\n",
" \n",
" #text = [sent len, batch size]\n",
" \n",
" embedded = self.dropout(self.embedding(text))\n",
" \n",
" #embedded = [sent len, batch size, emb dim]\n",
" \n",
" #pack sequence\n",
" # lengths need to be on CPU!\n",
" packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))\n",
" \n",
" packed_output, (hidden, cell) = self.rnn(packed_embedded)\n",
" \n",
" #unpack sequence\n",
" output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)\n",
"\n",
" #output = [sent len, batch size, hid dim * num directions]\n",
" #output over padding tokens are zero tensors\n",
" \n",
" #hidden = [num layers * num directions, batch size, hid dim]\n",
" #cell = [num layers * num directions, batch size, hid dim]\n",
" \n",
" #concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers\n",
" #and apply dropout\n",
" \n",
" hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))\n",
" \n",
" #hidden = [batch size, hid dim * num directions]\n",
" \n",
" return self.fc(hidden)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Like before, we'll create an instance of our RNN class, with the new parameters and arguments for the number of layers, bidirectionality and dropout probability.\n",
"\n",
"To ensure the pre-trained vectors can be loaded into the model, the `EMBEDDING_DIM` must be equal to that of the pre-trained GloVe vectors loaded earlier.\n",
"\n",
"We get our pad token index from the vocabulary, getting the actual string representing the pad token from the field's `pad_token` attribute, which is `<pad>` by default."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"INPUT_DIM = len(TEXT.vocab)\n",
"EMBEDDING_DIM = 100\n",
"HIDDEN_DIM = 256\n",
"OUTPUT_DIM = 1\n",
"N_LAYERS = 2\n",
"BIDIRECTIONAL = True\n",
"DROPOUT = 0.5\n",
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
"\n",
"model = RNN(INPUT_DIM, \n",
" EMBEDDING_DIM, \n",
" HIDDEN_DIM, \n",
" OUTPUT_DIM, \n",
" N_LAYERS, \n",
" BIDIRECTIONAL, \n",
" DROPOUT, \n",
" PAD_IDX)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll print out the number of parameters in our model. \n",
"\n",
"Notice how we have almost twice as many parameters as before!"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 4,810,857 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
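{
"cell_type": "markdown",
"metadata": {},
"source": [
"The printed figure can be reproduced by hand, which is a useful check on how the layers fit together. The sketch below is plain arithmetic and assumes a vocabulary of 25,002 entries (the 25,000 most common tokens plus `<unk>` and `<pad>`) and the PyTorch LSTM layout of four gates, each with input weights, hidden weights and two bias vectors. The helper name `lstm_direction_params` is made up for illustration.\n",
"\n",
"```python\n",
"def lstm_direction_params(input_size, hidden_size):\n",
"    # 4 gates, each with input weights, hidden weights and two bias vectors\n",
"    return 4 * (input_size * hidden_size + hidden_size * hidden_size + 2 * hidden_size)\n",
"\n",
"vocab_size, emb_dim, hid_dim, out_dim = 25_002, 100, 256, 1\n",
"\n",
"embedding = vocab_size * emb_dim\n",
"layer_0 = 2 * lstm_direction_params(emb_dim, hid_dim)      # forward + backward\n",
"layer_1 = 2 * lstm_direction_params(2 * hid_dim, hid_dim)  # input is both directions of layer 0\n",
"linear = 2 * hid_dim * out_dim + out_dim                   # weights + bias\n",
"\n",
"total = embedding + layer_0 + layer_1 + linear\n",
"print(f'{total:,}')\n",
"```\n",
"\n",
"The total comes to 4,810,857, matching the count above. Roughly half of it (2,500,200) sits in the embedding layer."
]
},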
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final addition is copying the pre-trained word embeddings we loaded earlier into the `embedding` layer of our model.\n",
"\n",
"We retrieve the embeddings from the field's vocab, and check they're the correct size, _**[vocab size, embedding dim]**_ "
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([25002, 100])\n"
]
}
],
"source": [
"pretrained_embeddings = TEXT.vocab.vectors\n",
"\n",
"print(pretrained_embeddings.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We then replace the initial weights of the `embedding` layer with the pre-trained embeddings.\n",
"\n",
"**Note**: this should always be done on the `weight.data` and not the `weight`!"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.embedding.weight.data.copy_(pretrained_embeddings)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As our `<unk>` and `<pad>` token aren't in the pre-trained vocabulary they have been initialized using `unk_init` (an $\\mathcal{N}(0,1)$ distribution) when building our vocab. It is preferable to initialize them both to all zeros to explicitly tell our model that, initially, they are irrelevant for determining sentiment. \n",
"\n",
"We do this by manually setting their row in the embedding weights matrix to zeros. We get their row by finding the index of the tokens, which we have already done for the padding index.\n",
"\n",
"**Note**: like initializing the embeddings, this should be done on the `weight.data` and not the `weight`!"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])\n"
]
}
],
"source": [
"UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
"\n",
"model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
"model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)\n",
"\n",
"print(model.embedding.weight.data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now see the first two rows of the embedding weights matrix have been set to zeros. As we passed the index of the pad token to the `padding_idx` of the embedding layer it will remain zeros throughout training, however the `<unk>` token embedding will be learned."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now to training the model.\n",
"\n",
"The only change we'll make here is changing the optimizer from `SGD` to `Adam`. SGD updates all parameters with the same learning rate and choosing this learning rate can be tricky. `Adam` adapts the learning rate for each parameter, giving parameters that are updated more frequently lower learning rates and parameters that are updated infrequently higher learning rates. More information about `Adam` (and other optimizers) can be found [here](http://ruder.io/optimizing-gradient-descent/index.html).\n",
"\n",
"To change `SGD` to `Adam`, we simply change `optim.SGD` to `optim.Adam`, also note how we do not have to provide an initial learning rate for Adam as PyTorch specifies a sensibile default initial learning rate."
]
},
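{
"cell_type": "markdown",
"metadata": {},
"source": [
"To give a feel for what \"adapting the learning rate for each parameter\" means, here is a sketch of the Adam update for a single scalar parameter in plain Python. The function name `adam_step` and the toy gradients are assumptions for illustration; the default values for `lr`, `beta1`, `beta2` and `eps` are the defaults of `optim.Adam`.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):\n",
"    # exponential moving averages of the gradient and the squared gradient\n",
"    m = beta1 * m + (1 - beta1) * grad\n",
"    v = beta2 * v + (1 - beta2) * grad ** 2\n",
"    # bias correction for the zero-initialized averages (t starts at 1)\n",
"    m_hat = m / (1 - beta1 ** t)\n",
"    v_hat = v / (1 - beta2 ** t)\n",
"    # the step is scaled by the running magnitude of this parameter's gradients\n",
"    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)\n",
"    return theta, m, v\n",
"\n",
"theta, m, v = 1.0, 0.0, 0.0\n",
"for t, grad in enumerate([0.2, 0.1, -0.3], start=1):\n",
"    theta, m, v = adam_step(theta, grad, m, v, t)\n",
"print(theta)\n",
"```\n",
"\n",
"Because the step is divided by the square root of the running squared gradient, the very first update has a size of roughly `lr` whether the gradient is 0.2 or 200."
]
},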
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rest of the steps for training the model are unchanged.\n",
"\n",
"We define the criterion and place the model and criterion on the GPU (if available)..."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.BCEWithLogitsLoss()\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We implement the function to calculate accuracy..."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def binary_accuracy(preds, y):\n",
" \"\"\"\n",
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
" \"\"\"\n",
"\n",
" #round predictions to the closest integer\n",
" rounded_preds = torch.round(torch.sigmoid(preds))\n",
" correct = (rounded_preds == y).float() #convert into float for division \n",
" acc = correct.sum() / len(correct)\n",
" return acc"
]
},
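{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same calculation can be followed on a handful of numbers without PyTorch. The sketch below is a pure-Python stand-in for `binary_accuracy`; the name `binary_accuracy_py` and the toy logits and labels are assumptions for illustration.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def binary_accuracy_py(logits, labels):\n",
"    # sigmoid squashes each logit into (0, 1); rounding turns it into a 0/1 prediction\n",
"    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]\n",
"    preds = [round(p) for p in probs]\n",
"    correct = sum(1 for p, y in zip(preds, labels) if p == y)\n",
"    return correct / len(labels)\n",
"\n",
"print(binary_accuracy_py([2.0, -1.0, 0.3, -0.2], [1, 0, 0, 0]))\n",
"```\n",
"\n",
"Three of the four predictions match their labels (the logit 0.3 rounds to 1 but is labelled 0), so this prints 0.75."
]
},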
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define a function for training our model. \n",
"\n",
"As we have set `include_lengths = True`, our `batch.text` is now a tuple with the first element being the numericalized tensor and the second element being the actual lengths of each sequence. We separate these into their own variables, `text` and `text_lengths`, before passing them to the model.\n",
"\n",
"**Note**: as we are now using dropout, we must remember to use `model.train()` to ensure the dropout is \"turned on\" while training."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def train(model, iterator, optimizer, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.train()\n",
" \n",
" for batch in iterator:\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" text, text_lengths = batch.text\n",
" \n",
" predictions = model(text, text_lengths).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
" \n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we define a function for testing our model, again remembering to separate `batch.text`.\n",
"\n",
"**Note**: as we are now using dropout, we must remember to use `model.eval()` to ensure the dropout is \"turned off\" while evaluating."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, iterator, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.eval()\n",
" \n",
" with torch.no_grad():\n",
" \n",
" for batch in iterator:\n",
"\n",
" text, text_lengths = batch.text\n",
" \n",
" predictions = model(text, text_lengths).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And also create a nice function to tell us how long our epochs are taking."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we train our model..."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 0m 36s\n",
"\tTrain Loss: 0.673 | Train Acc: 58.05%\n",
"\t Val. Loss: 0.619 | Val. Acc: 64.97%\n",
"Epoch: 02 | Epoch Time: 0m 36s\n",
"\tTrain Loss: 0.611 | Train Acc: 66.33%\n",
"\t Val. Loss: 0.510 | Val. Acc: 74.32%\n",
"Epoch: 03 | Epoch Time: 0m 37s\n",
"\tTrain Loss: 0.484 | Train Acc: 77.04%\n",
"\t Val. Loss: 0.397 | Val. Acc: 82.95%\n",
"Epoch: 04 | Epoch Time: 0m 37s\n",
"\tTrain Loss: 0.384 | Train Acc: 83.57%\n",
"\t Val. Loss: 0.407 | Val. Acc: 83.23%\n",
"Epoch: 05 | Epoch Time: 0m 37s\n",
"\tTrain Loss: 0.314 | Train Acc: 86.98%\n",
"\t Val. Loss: 0.314 | Val. Acc: 86.36%\n"
]
}
],
"source": [
"N_EPOCHS = 5\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"\n",
" start_time = time.time()\n",
" \n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
" \n",
" end_time = time.time()\n",
"\n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'tut2-model.pt')\n",
" \n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and get our new and vastly improved test accuracy!"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.334 | Test Acc: 85.28%\n"
]
}
],
"source": [
"model.load_state_dict(torch.load('tut2-model.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## User Input\n",
"\n",
"We can now use our model to predict the sentiment of any sentence we give it. As it has been trained on movie reviews, the sentences provided should also be movie reviews.\n",
"\n",
"When using a model for inference it should always be in evaluation mode. If this tutorial is followed step-by-step then it should already be in evaluation mode (from doing `evaluate` on the test set), however we explicitly set it to avoid any risk.\n",
"\n",
"Our `predict_sentiment` function does a few things:\n",
"- sets the model to evaluation mode\n",
"- tokenizes the sentence, i.e. splits it from a raw string into a list of tokens\n",
"- indexes the tokens by converting them into their integer representation from our vocabulary\n",
"- gets the length of our sequence\n",
"- converts the indexes, which are a Python list into a PyTorch tensor\n",
"- add a batch dimension by `unsqueeze`ing \n",
"- converts the length into a tensor\n",
"- squashes the output prediction from a real number between 0 and 1 with the `sigmoid` function\n",
"- converts the tensor holding a single value into an integer with the `item()` method\n",
"\n",
"We are expecting reviews with a negative sentiment to return a value close to 0 and positive reviews to return a value close to 1."
]
},
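{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tokenize, index and length steps listed above can be sketched with a toy vocabulary in plain Python. Everything here is an assumption for illustration: the seven-word vocabulary, the whitespace split standing in for the spaCy tokenizer, and the `defaultdict` that sends unknown words to index 0, which is meant to mimic `TEXT.vocab.stoi` falling back to `<unk>`.\n",
"\n",
"```python\n",
"from collections import defaultdict\n",
"\n",
"# toy vocabulary; index 0 plays the role of <unk> and index 1 of <pad>\n",
"itos = ['<unk>', '<pad>', 'This', 'film', 'is', 'great', 'terrible']\n",
"stoi = defaultdict(int, {tok: i for i, tok in enumerate(itos)})\n",
"\n",
"def numericalize(sentence):\n",
"    tokens = sentence.split()            # stand-in for the spaCy tokenizer\n",
"    indexed = [stoi[t] for t in tokens]  # unknown tokens fall back to 0 (<unk>)\n",
"    return indexed, len(indexed)\n",
"\n",
"print(numericalize('This film is superb'))\n",
"```\n",
"\n",
"The word `superb` is not in the toy vocabulary, so it is mapped to index 0, while the returned length (4) is what would be passed to the model as `text_lengths`."
]
},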
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import spacy\n",
"nlp = spacy.load('en_core_web_sm')\n",
"\n",
"def predict_sentiment(model, sentence):\n",
" model.eval()\n",
" tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
" indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
" length = [len(indexed)]\n",
" tensor = torch.LongTensor(indexed).to(device)\n",
" tensor = tensor.unsqueeze(1)\n",
" length_tensor = torch.LongTensor(length)\n",
" prediction = torch.sigmoid(model(tensor, length_tensor))\n",
" return prediction.item()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example negative review..."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.05380420759320259"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, \"This film is terrible\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example positive review..."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.94941645860672"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, \"This film is great\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next Steps\n",
"\n",
"We've now built a decent sentiment analysis model for movie reviews! In the next notebook we'll implement a model that gets comparable accuracy with far fewer parameters and trains much, much faster."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}