pytorch-sentiment-analysis/6 - Transformers for Sentiment Analysis.ipynb
2021-03-12 13:25:47 +00:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 6 - Transformers for Sentiment Analysis\n",
"\n",
"In this notebook we will be using the transformer model, first introduced in [this](https://arxiv.org/abs/1706.03762) paper. Specifically, we will be using the BERT (Bidirectional Encoder Representations from Transformers) model from [this](https://arxiv.org/abs/1810.04805) paper. \n",
"\n",
"Transformer models are considerably larger than anything else covered in these tutorials. As such, we are going to use the [transformers library](https://github.com/huggingface/transformers) to get pre-trained transformers and use them as our embedding layers. We will freeze (not train) the transformer and only train the remainder of the model, which learns from the representations produced by the transformer. In this case we will be using a multi-layer bidirectional GRU; however, any model can learn from these representations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparing Data\n",
"\n",
"First, as always, let's set the random seeds for deterministic results."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"import random\n",
"import numpy as np\n",
"\n",
"SEED = 1234\n",
"\n",
"random.seed(SEED)\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The transformer has already been trained with a specific vocabulary, which means we need to train with the exact same vocabulary and also tokenize our data in the same way that the transformer did when it was initially trained.\n",
"\n",
"Luckily, the transformers library has tokenizers for each of the transformer models provided. In this case we are using the uncased BERT model, which ignores casing (i.e. it lowercases every word). We get this by loading the pre-trained `bert-base-uncased` tokenizer."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertTokenizer\n",
"\n",
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `tokenizer` has a `vocab` attribute which contains the actual vocabulary we will be using. We can check how many tokens are in it by checking its length."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"30522"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(tokenizer.vocab)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using the tokenizer is as simple as calling `tokenizer.tokenize` on a string. This will tokenize and lower case the data in a way that is consistent with the pre-trained transformer model."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['hello', 'world', 'how', 'are', 'you', '?']\n"
]
}
],
"source": [
"tokens = tokenizer.tokenize('Hello WORLD how ARE yoU?')\n",
"\n",
"print(tokens)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can numericalize tokens with our vocabulary using `tokenizer.convert_tokens_to_ids`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[7592, 2088, 2129, 2024, 2017, 1029]\n"
]
}
],
"source": [
"indexes = tokenizer.convert_tokens_to_ids(tokens)\n",
"\n",
"print(indexes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The transformer was also trained with special tokens to mark the beginning and end of the sentence, detailed [here](https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel), as well as standard padding and unknown tokens. We can also get these from the tokenizer.\n",
"\n",
"**Note**: the tokenizer does have beginning-of-sequence and end-of-sequence attributes (`bos_token` and `eos_token`), but these are not set and should not be used for this transformer."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[CLS] [SEP] [PAD] [UNK]\n"
]
}
],
"source": [
"init_token = tokenizer.cls_token\n",
"eos_token = tokenizer.sep_token\n",
"pad_token = tokenizer.pad_token\n",
"unk_token = tokenizer.unk_token\n",
"\n",
"print(init_token, eos_token, pad_token, unk_token)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can get the indexes of the special tokens by converting them using the vocabulary..."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"101 102 0 100\n"
]
}
],
"source": [
"init_token_idx = tokenizer.convert_tokens_to_ids(init_token)\n",
"eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)\n",
"pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)\n",
"unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)\n",
"\n",
"print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...or by explicitly getting them from the tokenizer."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"101 102 0 100\n"
]
}
],
"source": [
"init_token_idx = tokenizer.cls_token_id\n",
"eos_token_idx = tokenizer.sep_token_id\n",
"pad_token_idx = tokenizer.pad_token_id\n",
"unk_token_idx = tokenizer.unk_token_id\n",
"\n",
"print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another thing we need to handle is that the model was trained on sequences with a defined maximum length: it does not know how to handle sequences longer than those it was trained on. We can get this maximum input length by looking up the version of the transformer we want to use in `max_model_input_sizes`. In this case, it is 512 tokens."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"512\n"
]
}
],
"source": [
"max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']\n",
"\n",
"print(max_input_length)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Previously we used the `spaCy` tokenizer to tokenize our examples. However, we now need to define a function, passed to our `TEXT` field, that handles all the tokenization for us. It also truncates each sequence to a maximum length. Note that this maximum length is 2 less than the model's actual maximum length, because we need to add two tokens to each sequence: one at the start and one at the end."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def tokenize_and_cut(sentence):\n",
" tokens = tokenizer.tokenize(sentence) \n",
" tokens = tokens[:max_input_length-2]\n",
" return tokens"
]
},
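{
"cell_type": "markdown",
"metadata": {},
"source": [
"The truncation arithmetic above can be checked with a small, hypothetical sketch. `cut_and_wrap` is not part of the notebook or of the transformers library: it works on made-up token ids, cuts them to `max_len - 2` as `tokenize_and_cut` does, and then adds the `[CLS]` (101) and `[SEP]` (102) ids printed earlier, which the `TEXT` field will add for us. However long the input, the result never exceeds 512 ids."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sketch: mirrors tokenize_and_cut plus the init/eos tokens the TEXT field adds\n",
"CLS_ID, SEP_ID, MAX_LEN = 101, 102, 512  # values printed earlier in this notebook\n",
"\n",
"def cut_and_wrap(token_ids, max_len = MAX_LEN):\n",
"    # leave room for the two special tokens, then add them at either end\n",
"    return [CLS_ID] + token_ids[:max_len - 2] + [SEP_ID]\n",
"\n",
"long_review = list(range(1000, 1700))  # 700 made-up token ids\n",
"short_review = [7592, 2088]            # ids for 'hello' and 'world'\n",
"\n",
"print(len(cut_and_wrap(long_review)))  # 512\n",
"print(cut_and_wrap(short_review))      # [101, 7592, 2088, 102]"
]
},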
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we define our fields. The transformer expects the batch dimension to be first, so we set `batch_first = True`. As we already have the vocabulary for our text, provided by the transformer, we set `use_vocab = False` to tell torchtext that we'll be handling the vocabulary side of things. We pass our `tokenize_and_cut` function as the tokenizer. The `preprocessing` argument is a function that takes in the example after it has been tokenized; this is where we will convert the tokens to their indexes. Finally, we define the special tokens, noting that we are defining them by their index value and not their string value, i.e. `100` instead of `[UNK]`. This is because the sequences will already be converted into indexes.\n",
"\n",
"We define the label field as before."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"from torchtext.legacy import data\n",
"\n",
"TEXT = data.Field(batch_first = True,\n",
" use_vocab = False,\n",
" tokenize = tokenize_and_cut,\n",
" preprocessing = tokenizer.convert_tokens_to_ids,\n",
" init_token = init_token_idx,\n",
" eos_token = eos_token_idx,\n",
" pad_token = pad_token_idx,\n",
" unk_token = unk_token_idx)\n",
"\n",
"LABEL = data.LabelField(dtype = torch.float)"
]
},
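{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough, simplified sketch of what the `TEXT` field then does when it builds a batch: each numericalized sequence is right-padded with the padding id (0 here) up to the longest sequence in the batch, and the result is stacked with the batch dimension first. `pad_batch` is a hypothetical helper for illustration only; torchtext's own `Field.pad` also handles options such as `fix_length` and `pad_first`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"PAD_ID = 0  # pad_token_idx printed earlier\n",
"\n",
"def pad_batch(sequences, pad_id = PAD_ID):\n",
"    # right-pad every sequence to the longest one, giving [batch size, sent len]\n",
"    longest = max(len(seq) for seq in sequences)\n",
"    padded = [seq + [pad_id] * (longest - len(seq)) for seq in sequences]\n",
"    return torch.tensor(padded, dtype = torch.long)\n",
"\n",
"demo_batch = pad_batch([[101, 7592, 2088, 102], [101, 7592, 102]])\n",
"\n",
"print(demo_batch.shape)       # torch.Size([2, 4])\n",
"print(demo_batch[1].tolist())  # [101, 7592, 102, 0]"
]
},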
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We load the data and create the validation splits as before."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
]
}
],
"source": [
"from torchtext.legacy import datasets\n",
"\n",
"train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
"\n",
"train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of training examples: 17500\n",
"Number of validation examples: 7500\n",
"Number of testing examples: 25000\n"
]
}
],
"source": [
"print(f\"Number of training examples: {len(train_data)}\")\n",
"print(f\"Number of validation examples: {len(valid_data)}\")\n",
"print(f\"Number of testing examples: {len(test_data)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check an example and ensure that the text has already been numericalized."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'text': [1042, 4140, 1996, 2087, 2112, 1010, 2023, 3185, 5683, 2066, 1037, 1000, 2081, 1011, 2005, 1011, 2694, 1000, 3947, 1012, 1996, 3257, 2003, 10654, 1011, 28273, 1010, 1996, 3772, 1006, 2007, 1996, 6453, 1997, 5965, 1043, 11761, 2638, 1007, 2003, 2058, 13088, 10593, 2102, 1998, 7815, 2100, 1012, 15339, 14282, 1010, 3391, 1010, 18058, 2014, 3210, 2066, 2016, 1005, 1055, 3147, 3752, 2068, 2125, 1037, 16091, 4003, 1012, 2069, 2028, 2518, 3084, 2023, 2143, 4276, 3666, 1010, 1998, 2008, 2003, 2320, 10012, 3310, 2067, 2013, 1996, 1000, 7367, 11368, 5649, 1012, 1000, 2045, 2003, 2242, 14888, 2055, 3666, 1037, 2235, 2775, 4028, 2619, 1010, 1998, 2023, 3185, 2453, 2022, 2062, 2084, 2070, 2064, 5047, 2074, 2005, 2008, 3114, 1012, 2009, 2003, 7078, 5923, 1011, 27017, 1012, 2023, 2143, 2069, 2515, 2028, 2518, 2157, 1010, 2021, 2009, 21145, 2008, 2028, 2518, 2157, 2041, 1997, 1996, 2380, 1012, 4276, 3773, 2074, 2005, 1996, 2197, 2184, 2781, 2030, 2061, 1012], 'label': 'neg'}\n"
]
}
],
"source": [
"print(vars(train_data.examples[6]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use `convert_ids_to_tokens` to transform these indexes back into readable tokens."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['f', '##ot', 'the', 'most', 'part', ',', 'this', 'movie', 'feels', 'like', 'a', '\"', 'made', '-', 'for', '-', 'tv', '\"', 'effort', '.', 'the', 'direction', 'is', 'ham', '-', 'fisted', ',', 'the', 'acting', '(', 'with', 'the', 'exception', 'of', 'fred', 'g', '##wyn', '##ne', ')', 'is', 'over', '##wr', '##ough', '##t', 'and', 'soap', '##y', '.', 'denise', 'crosby', ',', 'particularly', ',', 'delivers', 'her', 'lines', 'like', 'she', \"'\", 's', 'cold', 'reading', 'them', 'off', 'a', 'cue', 'card', '.', 'only', 'one', 'thing', 'makes', 'this', 'film', 'worth', 'watching', ',', 'and', 'that', 'is', 'once', 'gage', 'comes', 'back', 'from', 'the', '\"', 'se', '##met', '##ary', '.', '\"', 'there', 'is', 'something', 'disturbing', 'about', 'watching', 'a', 'small', 'child', 'murder', 'someone', ',', 'and', 'this', 'movie', 'might', 'be', 'more', 'than', 'some', 'can', 'handle', 'just', 'for', 'that', 'reason', '.', 'it', 'is', 'absolutely', 'bone', '-', 'chilling', '.', 'this', 'film', 'only', 'does', 'one', 'thing', 'right', ',', 'but', 'it', 'knocks', 'that', 'one', 'thing', 'right', 'out', 'of', 'the', 'park', '.', 'worth', 'seeing', 'just', 'for', 'the', 'last', '10', 'minutes', 'or', 'so', '.']\n"
]
}
],
"source": [
"tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[6])['text'])\n",
"\n",
"print(tokens)"
]
},
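{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tokens such as `##ot` and `##met` above are WordPiece sub-word units: the `##` prefix marks a piece that continues the previous token. BERT's tokenizer splits each word greedily, repeatedly taking the longest prefix of the remaining characters that is in the vocabulary. The next cell illustrates that idea with a tiny made-up vocabulary; `wordpiece` is a hypothetical, simplified function, not the transformers implementation (which also handles punctuation splitting, lowercasing and a maximum word length)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def wordpiece(word, vocab, unk_token = '[UNK]'):\n",
"    # greedy longest-match-first split of a single word\n",
"    pieces = []\n",
"    start = 0\n",
"    while start < len(word):\n",
"        end = len(word)\n",
"        found = None\n",
"        while start < end:\n",
"            candidate = word[start:end]\n",
"            if start > 0:\n",
"                candidate = '##' + candidate  # continuation pieces carry the ## prefix\n",
"            if candidate in vocab:\n",
"                found = candidate\n",
"                break\n",
"            end -= 1\n",
"        if found is None:\n",
"            return [unk_token]  # no piece fits, so the whole word becomes unknown\n",
"        pieces.append(found)\n",
"        start = end\n",
"    return pieces\n",
"\n",
"toy_vocab = {'se', '##met', '##ary', 'the'}  # made-up vocabulary\n",
"\n",
"print(wordpiece('semetary', toy_vocab))  # ['se', '##met', '##ary']\n",
"print(wordpiece('xyz', toy_vocab))       # ['[UNK]']"
]
},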
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although we've handled the vocabulary for the text, we still need to build the vocabulary for the labels."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"LABEL.build_vocab(train_data)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"defaultdict(None, {'neg': 0, 'pos': 1})\n"
]
}
],
"source": [
"print(LABEL.vocab.stoi)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we create the iterators. Ideally, we want to use the largest batch size that we can, as I've found this gives the best results for transformers."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"BATCH_SIZE = 128\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data), \n",
" batch_size = BATCH_SIZE, \n",
" device = device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the Model\n",
"\n",
"Next, we'll load the pre-trained model, making sure to load the same model as we did for the tokenizer."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertTokenizer, BertModel\n",
"\n",
"bert = BertModel.from_pretrained('bert-base-uncased')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we'll define our actual model. \n",
"\n",
"Instead of using an embedding layer to get embeddings for our text, we'll be using the pre-trained transformer model. These embeddings will then be fed into a GRU to produce a prediction for the sentiment of the input sentence. We get the embedding dimension size (called the `hidden_size`) from the transformer via its config attribute. The rest of the initialization is standard.\n",
"\n",
"Within the forward pass, we wrap the transformer in a `no_grad` to ensure no gradients are calculated over this part of the model. The transformer actually returns the embeddings for the whole sequence as well as a *pooled* output. The [documentation](https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel) states that the pooled output is \"usually not a good summary of the semantic content of the input, you're often better with averaging or pooling the sequence of hidden-states for the whole input sequence\", hence we will not be using it. The rest of the forward pass is the standard implementation of a recurrent model, where we take the final hidden state of the last layer (concatenating the forward and backward directions, as the GRU is bidirectional) and pass it through a linear layer to get our predictions."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"class BERTGRUSentiment(nn.Module):\n",
" def __init__(self,\n",
" bert,\n",
" hidden_dim,\n",
" output_dim,\n",
" n_layers,\n",
" bidirectional,\n",
" dropout):\n",
" \n",
" super().__init__()\n",
" \n",
" self.bert = bert\n",
" \n",
" embedding_dim = bert.config.to_dict()['hidden_size']\n",
" \n",
" self.rnn = nn.GRU(embedding_dim,\n",
" hidden_dim,\n",
" num_layers = n_layers,\n",
" bidirectional = bidirectional,\n",
" batch_first = True,\n",
" dropout = 0 if n_layers < 2 else dropout)\n",
" \n",
" self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)\n",
" \n",
" self.dropout = nn.Dropout(dropout)\n",
" \n",
" def forward(self, text):\n",
" \n",
" #text = [batch size, sent len]\n",
" \n",
" with torch.no_grad():\n",
" embedded = self.bert(text)[0]\n",
" \n",
" #embedded = [batch size, sent len, emb dim]\n",
" \n",
" _, hidden = self.rnn(embedded)\n",
" \n",
" #hidden = [n layers * n directions, batch size, hid dim]\n",
" \n",
" if self.rnn.bidirectional:\n",
" hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))\n",
" else:\n",
" hidden = self.dropout(hidden[-1,:,:])\n",
" \n",
" #hidden = [batch size, hid dim * n directions]\n",
" \n",
" output = self.out(hidden)\n",
" \n",
" #output = [batch size, out dim]\n",
" \n",
" return output"
]
},
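{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see the shapes referred to in the comments, here is a standalone sketch that feeds a random tensor, standing in for the BERT output `embedded`, through a GRU configured like the one above (two layers, bidirectional, `batch_first = True`). It shows that `batch_first` only affects the GRU outputs, not `hidden`, and that `hidden[-2]` and `hidden[-1]` are the last layer's final forward and backward states. The random input and the `demo_` names are illustrative assumptions; the cell is wrapped in `torch.random.fork_rng` so it does not disturb the seeded random state used when the real model is created."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"with torch.random.fork_rng(devices = []):  # restore the CPU RNG state afterwards\n",
"    demo_embedded = torch.randn(3, 10, 768)  # [batch size, sent len, emb dim], random stand-in for BERT\n",
"    demo_rnn = nn.GRU(768, 256, num_layers = 2, bidirectional = True, batch_first = True)\n",
"    demo_outputs, demo_hidden = demo_rnn(demo_embedded)\n",
"\n",
"print(demo_outputs.shape)  # torch.Size([3, 10, 512]): [batch size, sent len, hid dim * 2]\n",
"print(demo_hidden.shape)   # torch.Size([4, 3, 256]): [n layers * n directions, batch size, hid dim]\n",
"\n",
"# last layer: the forward state ends at the final step, the backward state ends at the first step\n",
"print(torch.allclose(demo_hidden[-2], demo_outputs[:, -1, :256]))  # True\n",
"print(torch.allclose(demo_hidden[-1], demo_outputs[:, 0, 256:]))   # True\n",
"\n",
"demo_final = torch.cat((demo_hidden[-2], demo_hidden[-1]), dim = 1)\n",
"print(demo_final.shape)    # torch.Size([3, 512]), the input size of self.out"
]
},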
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we create an instance of our model using standard hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"HIDDEN_DIM = 256\n",
"OUTPUT_DIM = 1\n",
"N_LAYERS = 2\n",
"BIDIRECTIONAL = True\n",
"DROPOUT = 0.25\n",
"\n",
"model = BERTGRUSentiment(bert,\n",
" HIDDEN_DIM,\n",
" OUTPUT_DIM,\n",
" N_LAYERS,\n",
" BIDIRECTIONAL,\n",
" DROPOUT)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check how many parameters the model has. Our standard models have under 5M parameters, but this one has 112M! Luckily, roughly 110M of these parameters are from the transformer, and we will not be training those."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 112,241,409 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to freeze parameters (not train them), we need to set their `requires_grad` attribute to `False`. To do this, we loop through all of the `named_parameters` in our model and, if they are part of the `bert` transformer model, set `requires_grad = False`."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"for name, param in model.named_parameters(): \n",
" if name.startswith('bert'):\n",
" param.requires_grad = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now see that our model has under 3M trainable parameters, making it almost comparable to the `FastText` model. However, the text still has to propagate through the transformer, which causes training to take considerably longer."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 2,759,169 trainable parameters\n"
]
}
],
"source": [
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
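{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 2,759,169 figure can be reproduced by hand. For each GRU layer and direction, PyTorch stores `weight_ih` of shape `[3 * hid dim, input size]`, `weight_hh` of shape `[3 * hid dim, hid dim]` and two biases of size `3 * hid dim` (the 3 covers the reset gate, update gate and candidate state). The first layer reads BERT's 768-dimensional embeddings; the second reads the 512-dimensional concatenation of both directions of the first. The helper below is a hypothetical worked example of that arithmetic, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def gru_layer_params(input_size, hidden_size):\n",
"    # weight_ih + weight_hh + bias_ih + bias_hh for one layer in one direction\n",
"    return 3 * hidden_size * (input_size + hidden_size) + 2 * 3 * hidden_size\n",
"\n",
"EMB_DIM, HID_DIM = 768, 256  # BERT-base hidden_size and HIDDEN_DIM above\n",
"\n",
"layer_0 = 2 * gru_layer_params(EMB_DIM, HID_DIM)      # 1,575,936 (both directions)\n",
"layer_1 = 2 * gru_layer_params(2 * HID_DIM, HID_DIM)  # 1,182,720 (both directions)\n",
"linear = 2 * HID_DIM * 1 + 1                          # 513: out.weight [1, 512] + out.bias\n",
"trainable = layer_0 + layer_1 + linear\n",
"\n",
"print(f'{trainable:,}')                # 2,759,169\n",
"print(f'{112_241_409 - trainable:,}')  # 109,482,240 frozen BERT parameters"
]
},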
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can double check the names of the trainable parameters, ensuring they make sense. As we can see, they are all the parameters of the GRU (`rnn`) and the linear layer (`out`)."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rnn.weight_ih_l0\n",
"rnn.weight_hh_l0\n",
"rnn.bias_ih_l0\n",
"rnn.bias_hh_l0\n",
"rnn.weight_ih_l0_reverse\n",
"rnn.weight_hh_l0_reverse\n",
"rnn.bias_ih_l0_reverse\n",
"rnn.bias_hh_l0_reverse\n",
"rnn.weight_ih_l1\n",
"rnn.weight_hh_l1\n",
"rnn.bias_ih_l1\n",
"rnn.bias_hh_l1\n",
"rnn.weight_ih_l1_reverse\n",
"rnn.weight_hh_l1_reverse\n",
"rnn.bias_ih_l1_reverse\n",
"rnn.bias_hh_l1_reverse\n",
"out.weight\n",
"out.bias\n"
]
}
],
"source": [
"for name, param in model.named_parameters(): \n",
" if param.requires_grad:\n",
" print(name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Model\n",
"\n",
"As is standard, we define our optimizer and criterion (loss function)."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.BCEWithLogitsLoss()"
]
},
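{
"cell_type": "markdown",
"metadata": {},
"source": [
"`BCEWithLogitsLoss` applies the sigmoid internally, which is why the model outputs raw scores (logits) rather than probabilities. A quick check with made-up logits and labels shows that it matches applying `torch.sigmoid` followed by `BCELoss`; the fused version is preferred because it is more numerically stable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"demo_logits = torch.tensor([2.0, -1.0, 0.5, -3.0])  # made-up model outputs\n",
"demo_labels = torch.tensor([1.0, 0.0, 0.0, 0.0])\n",
"\n",
"fused = nn.BCEWithLogitsLoss()(demo_logits, demo_labels)\n",
"two_step = nn.BCELoss()(torch.sigmoid(demo_logits), demo_labels)\n",
"\n",
"print(torch.allclose(fused, two_step))  # True"
]
},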
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Place the model and criterion onto the GPU (if available)."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we'll define functions for: calculating accuracy, performing a training epoch, performing an evaluation epoch and calculating how long a training/evaluation epoch takes."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"def binary_accuracy(preds, y):\n",
" \"\"\"\n",
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
" \"\"\"\n",
"\n",
" #round predictions to the closest integer\n",
" rounded_preds = torch.round(torch.sigmoid(preds))\n",
" correct = (rounded_preds == y).float() #convert into float for division \n",
" acc = correct.sum() / len(correct)\n",
" return acc"
]
},
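{
"cell_type": "markdown",
"metadata": {},
"source": [
"A worked example of the steps inside `binary_accuracy`, using made-up logits: the sigmoid turns them into probabilities, rounding gives hard 0/1 predictions, and the mean of the matches is the accuracy. Note that `torch.round` rounds halves to even, so a probability of exactly 0.5 counts as a negative prediction."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"demo_logits = torch.tensor([2.0, -1.0, 0.5, -3.0])  # made-up model outputs (logits)\n",
"demo_labels = torch.tensor([1.0, 0.0, 0.0, 0.0])\n",
"\n",
"demo_probs = torch.sigmoid(demo_logits)   # about 0.881, 0.269, 0.622, 0.047\n",
"demo_rounded = torch.round(demo_probs)    # tensor([1., 0., 1., 0.])\n",
"demo_correct = (demo_rounded == demo_labels).float()\n",
"\n",
"print(demo_correct.sum() / len(demo_correct))  # tensor(0.7500): 3 of 4 correct"
]
},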
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"def train(model, iterator, optimizer, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.train()\n",
" \n",
" for batch in iterator:\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" predictions = model(batch.text).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
" \n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, iterator, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.eval()\n",
" \n",
" with torch.no_grad():\n",
" \n",
" for batch in iterator:\n",
"\n",
" predictions = model(batch.text).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we'll train our model. This takes considerably longer than any of the previous models due to the size of the transformer. Even though we are not training any of the transformer's parameters we still need to pass the data through the model which takes a considerable amount of time on a standard GPU."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 7m 13s\n",
"\tTrain Loss: 0.502 | Train Acc: 74.41%\n",
"\t Val. Loss: 0.270 | Val. Acc: 89.15%\n",
"Epoch: 02 | Epoch Time: 7m 7s\n",
"\tTrain Loss: 0.281 | Train Acc: 88.49%\n",
"\t Val. Loss: 0.224 | Val. Acc: 91.32%\n",
"Epoch: 03 | Epoch Time: 7m 17s\n",
"\tTrain Loss: 0.239 | Train Acc: 90.67%\n",
"\t Val. Loss: 0.211 | Val. Acc: 91.91%\n",
"Epoch: 04 | Epoch Time: 7m 14s\n",
"\tTrain Loss: 0.206 | Train Acc: 91.81%\n",
"\t Val. Loss: 0.206 | Val. Acc: 92.01%\n",
"Epoch: 05 | Epoch Time: 7m 15s\n",
"\tTrain Loss: 0.188 | Train Acc: 92.63%\n",
"\t Val. Loss: 0.211 | Val. Acc: 91.92%\n"
]
}
],
"source": [
"N_EPOCHS = 5\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
" \n",
" start_time = time.time()\n",
" \n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
" \n",
" end_time = time.time()\n",
" \n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'tut6-model.pt')\n",
" \n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll load up the parameters that gave us the best validation loss and try these on the test set, which gives us our best results so far!"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.209 | Test Acc: 91.58%\n"
]
}
],
"source": [
"model.load_state_dict(torch.load('tut6-model.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference\n",
"\n",
"We'll then use the model to test the sentiment of some sequences. We tokenize the input sequence, trim it down to the maximum length, add the special tokens to either side, convert it to a tensor, add a fake batch dimension and then pass it through our model."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(model, tokenizer, sentence):\n",
" model.eval()\n",
" tokens = tokenizer.tokenize(sentence)\n",
" tokens = tokens[:max_input_length-2]\n",
" indexed = [init_token_idx] + tokenizer.convert_tokens_to_ids(tokens) + [eos_token_idx]\n",
" tensor = torch.LongTensor(indexed).to(device)\n",
" tensor = tensor.unsqueeze(0)\n",
" prediction = torch.sigmoid(model(tensor))\n",
" return prediction.item()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.03391794115304947"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, tokenizer, \"This film is terrible\")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8869886994361877"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, tokenizer, \"This film is great\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}