{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3 - Faster Sentiment Analysis\n",
"\n",
"In the previous notebook, we managed to achieve a decent test accuracy of ~84% using all of the common techniques for sentiment analysis. In this notebook, we'll implement a model that gets comparable results whilst training significantly faster and using around half of the parameters. More specifically, we'll be implementing the \"FastText\" model from the paper [Bag of Tricks for Efficient Text Classification](https://arxiv.org/abs/1607.01759)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preparing Data\n",
"\n",
"One of the key concepts in the FastText paper is that they calculate the n-grams of an input sentence and append them to the end of a sentence. Here, we'll use bi-grams. Briefly, a bi-gram is a pair of words/tokens that appear consecutively within a sentence. \n",
"\n",
"For example, in the sentence \"how are you ?\", the bi-grams are: \"how are\", \"are you\" and \"you ?\".\n",
"\n",
"The `generate_bigrams` function takes a sentence that has already been tokenized, calculates the bi-grams and appends them to the end of the tokenized list."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"def generate_bigrams(x):\n",
" n_grams = set(zip(*[x[i:] for i in range(2)]))\n",
" for n_gram in n_grams:\n",
" x.append(' '.join(n_gram))\n",
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an example:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['This', 'film', 'is', 'terrible', 'film is', 'This film', 'is terrible']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"generate_bigrams(['This', 'film', 'is', 'terrible'])"
]
},
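{
"cell_type": "markdown",
"metadata": {},
"source": [
"The FastText paper uses n-grams in general, not just bi-grams. As a rough sketch (a hypothetical `generate_ngrams` helper that we won't use in the rest of this notebook), the same trick generalizes to appending all n-grams up to a chosen order:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_ngrams(x, max_n = 3):\n",
"    #collect every n-gram for n = 2 .. max_n from the original tokens\n",
"    n_grams = set()\n",
"    for n in range(2, max_n + 1):\n",
"        n_grams.update(zip(*[x[i:] for i in range(n)]))\n",
"    #append the n-grams to the end of the token list\n",
"    for n_gram in n_grams:\n",
"        x.append(' '.join(n_gram))\n",
"    return x\n",
"\n",
"generate_ngrams(['This', 'film', 'is', 'terrible'])"
]
},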
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TorchText `Field`s have a `preprocessing` argument. A function passed here will be applied to a sentence after it has been tokenized (transformed from a string into a list of tokens), but before it has been numericalized (transformed from a list of tokens to a list of indexes). This is where we'll pass our `generate_bigrams` function.\n",
"\n",
"As we aren't using an RNN, we can't use packed padded sequences, so we do not need to set `include_lengths = True`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"import torch\n",
"from torchtext.legacy import data\n",
"from torchtext.legacy import datasets\n",
"\n",
"SEED = 1234\n",
"\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True\n",
"\n",
"TEXT = data.Field(tokenize = 'spacy',\n",
" tokenizer_language = 'en_core_web_sm',\n",
" preprocessing = generate_bigrams)\n",
"\n",
"LABEL = data.LabelField(dtype = torch.float)"
]
},
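{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on a toy sentence (this assumes the legacy `Field.preprocess` method, which tokenizes a string and then applies the `preprocessing` pipeline), we can confirm that the field appends the bi-grams after tokenization:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#toy sentence, not from the dataset\n",
"TEXT.preprocess('How are you ?')"
]
},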
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we load the IMDb dataset and create the splits."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
]
}
],
"source": [
"import random\n",
"\n",
"train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
"\n",
"train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Build the vocab and load the pre-trained word embeddings."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"MAX_VOCAB_SIZE = 25_000\n",
"\n",
"TEXT.build_vocab(train_data, \n",
" max_size = MAX_VOCAB_SIZE, \n",
" vectors = \"glove.6B.100d\", \n",
" unk_init = torch.Tensor.normal_)\n",
"\n",
"LABEL.build_vocab(train_data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And create the iterators."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"BATCH_SIZE = 64\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
"    (train_data, valid_data, test_data), \n",
"    batch_size = BATCH_SIZE, \n",
"    device = device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the Model\n",
"\n",
"This model has far fewer parameters than the previous model as it only has 2 layers that have any parameters, the embedding layer and the linear layer. There is no RNN component in sight!\n",
"\n",
2019-03-10 23:45:40 +08:00
"Instead, it first calculates the word embedding for each word using the `Embedding` layer (blue), then calculates the average of all of the word embeddings (pink) and feeds this through the `Linear` layer (silver), and that's it!\n",
2018-06-06 05:16:05 +08:00
"\n",
2019-03-10 23:45:40 +08:00
"![](assets/sentiment8.png)\n",
2018-06-06 05:16:05 +08:00
"\n",
2018-10-18 00:54:31 +08:00
"We implement the averaging with the `avg_pool2d` (average pool 2-dimensions) function. Initially, you may think using a 2-dimensional pooling seems strange, surely our sentences are 1-dimensional, not 2-dimensional? However, you can think of the word embeddings as a 2-dimensional grid, where the words are along one axis and the dimensions of the word embeddings are along the other. The image below is an example sentence after being converted into 5-dimensional word embeddings, with the words along the vertical axis and the embeddings along the horizontal axis. Each element in this [4x5] tensor is represented by a green block.\n",
2018-06-06 05:16:05 +08:00
"\n",
2019-03-10 23:45:40 +08:00
"![](assets/sentiment9.png)\n",
2018-06-06 05:16:05 +08:00
"\n",
2018-10-18 00:54:31 +08:00
"The `avg_pool2d` uses a filter of size `embedded.shape[1]` (i.e. the length of the sentence) by 1. This is shown in pink in the image below.\n",
2018-06-06 05:16:05 +08:00
"\n",
2019-03-10 23:45:40 +08:00
"![](assets/sentiment10.png)\n",
2018-06-06 05:16:05 +08:00
"\n",
2019-03-10 23:45:40 +08:00
"We calculate the average value of all elements covered by the filter, then the filter then slides to the right, calculating the average over the next column of embedding values for each word in the sentence. \n",
"\n",
"![](assets/sentiment11.png)\n",
"\n",
"Each filter position gives us a single value, the average of all covered elements. After the filter has covered all embedding dimensions we get a [1x5] tensor. This tensor is then passed through the linear layer to produce our prediction."
2018-06-06 05:16:05 +08:00
]
},
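{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before defining the model, here is a small sanity check of the pooling step on a random tensor (a sketch, not part of the model): for a `[batch size, sent len, emb dim]` tensor, `avg_pool2d` with a filter of size `[sent len, 1]` gives the same result as simply taking the mean over the sentence length dimension."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"example = torch.randn(64, 10, 100) #[batch size, sent len, emb dim]\n",
"\n",
"pooled = F.avg_pool2d(example, (example.shape[1], 1)).squeeze(1)\n",
"averaged = example.mean(dim = 1)\n",
"\n",
"print(pooled.shape) #[batch size, emb dim]\n",
"print(torch.allclose(pooled, averaged, atol = 1e-6)) #should print True"
]
},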
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class FastText(nn.Module):\n",
"    def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):\n",
"        \n",
"        super().__init__()\n",
"        \n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)\n",
"        \n",
"        self.fc = nn.Linear(embedding_dim, output_dim)\n",
"        \n",
"    def forward(self, text):\n",
"        \n",
"        #text = [sent len, batch size]\n",
"        \n",
"        embedded = self.embedding(text)\n",
" \n",
" #embedded = [sent len, batch size, emb dim]\n",
" \n",
" embedded = embedded.permute(1, 0, 2)\n",
" \n",
" #embedded = [batch size, sent len, emb dim]\n",
" \n",
" pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1) \n",
" \n",
" #pooled = [batch size, embedding_dim]\n",
" \n",
" return self.fc(pooled)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As previously, we'll create an instance of our `FastText` class."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"INPUT_DIM = len(TEXT.vocab)\n",
"EMBEDDING_DIM = 100\n",
"OUTPUT_DIM = 1\n",
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
"\n",
"model = FastText(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the number of parameters in our model, we see we have about the same number as the standard RNN from the first notebook and around half the parameters of the previous model."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 2,500,301 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
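{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check this count by hand: the embedding layer has `INPUT_DIM * EMBEDDING_DIM` weights and the linear layer has `EMBEDDING_DIM * OUTPUT_DIM` weights plus `OUTPUT_DIM` bias terms. With a 25,002 token vocabulary (the 25,000 most frequent tokens plus `<unk>` and `<pad>`), that is 25,002 * 100 + 100 * 1 + 1 = 2,500,301."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#embedding weights + linear weights + linear bias\n",
"INPUT_DIM * EMBEDDING_DIM + EMBEDDING_DIM * OUTPUT_DIM + OUTPUT_DIM"
]
},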
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And copy the pre-trained vectors to our embedding layer."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.2647, -0.2753, -0.1325],\n",
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [-0.1606, -0.7357, 0.5809, ..., 0.8704, -1.5637, -1.5724],\n",
" [-1.3126, -1.6717, 0.4203, ..., 0.2348, -0.9110, 1.0914],\n",
" [-1.5268, 1.5639, -1.0541, ..., 1.0045, -0.6813, -0.8846]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_embeddings = TEXT.vocab.vectors\n",
"\n",
"model.embedding.weight.data.copy_(pretrained_embeddings)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not forgetting to zero the initial weights of our unknown and padding tokens."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
"\n",
"model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
"model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Training the model is the exact same as last time.\n",
"\n",
"We initialize our optimizer..."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import torch.optim as optim\n",
"\n",
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define the criterion and place the model and criterion on the GPU (if available)..."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.BCEWithLogitsLoss()\n",
"\n",
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We implement the function to calculate accuracy..."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def binary_accuracy(preds, y):\n",
" \"\"\"\n",
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
" \"\"\"\n",
"\n",
" #round predictions to the closest integer\n",
"    rounded_preds = torch.round(torch.sigmoid(preds))\n",
"    correct = (rounded_preds == y).float() #convert into float for division \n",
"    acc = correct.sum() / len(correct)\n",
" return acc"
]
},
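{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on toy logits and labels (not from the dataset): three of the four predictions below round to the correct label after the sigmoid, so `binary_accuracy` should return 0.75."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"binary_accuracy(torch.tensor([2.0, -3.0, 0.5, -0.1]), torch.tensor([1., 0., 0., 0.]))"
]
},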
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define a function for training our model...\n",
"\n",
"**Note**: we are no longer using dropout so we do not need to use `model.train()`, but as mentioned in the 1st notebook, it is good practice to use it."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def train(model, iterator, optimizer, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.train()\n",
" \n",
" for batch in iterator:\n",
" \n",
" optimizer.zero_grad()\n",
" \n",
" predictions = model(batch.text).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
" \n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We define a function for testing our model...\n",
"\n",
"**Note**: again, we leave `model.eval()` even though we do not use dropout."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(model, iterator, criterion):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.eval()\n",
" \n",
" with torch.no_grad():\n",
" \n",
" for batch in iterator:\n",
"\n",
" predictions = model(batch.text).squeeze(1)\n",
" \n",
" loss = criterion(predictions, batch.label)\n",
" \n",
" acc = binary_accuracy(predictions, batch.label)\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we'll implement a useful function to tell us how long an epoch takes."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we train our model."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 0m 7s\n",
"\tTrain Loss: 0.688 | Train Acc: 61.31%\n",
"\t Val. Loss: 0.637 | Val. Acc: 72.46%\n",
"Epoch: 02 | Epoch Time: 0m 6s\n",
"\tTrain Loss: 0.651 | Train Acc: 75.04%\n",
"\t Val. Loss: 0.507 | Val. Acc: 76.92%\n",
"Epoch: 03 | Epoch Time: 0m 6s\n",
"\tTrain Loss: 0.578 | Train Acc: 79.91%\n",
"\t Val. Loss: 0.424 | Val. Acc: 80.97%\n",
"Epoch: 04 | Epoch Time: 0m 6s\n",
"\tTrain Loss: 0.501 | Train Acc: 83.97%\n",
"\t Val. Loss: 0.377 | Val. Acc: 84.34%\n",
"Epoch: 05 | Epoch Time: 0m 6s\n",
"\tTrain Loss: 0.435 | Train Acc: 86.96%\n",
"\t Val. Loss: 0.363 | Val. Acc: 86.18%\n"
]
}
],
"source": [
"N_EPOCHS = 5\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"\n",
" start_time = time.time()\n",
" \n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
"    \n",
" end_time = time.time()\n",
"\n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'tut3-model.pt')\n",
" \n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and get the test accuracy!\n",
"\n",
"The results are comparable to those in the previous notebook, but training takes considerably less time!"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.381 | Test Acc: 85.42%\n"
]
}
],
"source": [
"model.load_state_dict(torch.load('tut3-model.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## User Input\n",
"\n",
"And, as before, we can test on any input the user provides, making sure to generate the bi-grams from our tokenized sentence."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"import spacy\n",
"nlp = spacy.load('en_core_web_sm')\n",
"\n",
"def predict_sentiment(model, sentence):\n",
"    model.eval()\n",
"    tokenized = generate_bigrams([tok.text for tok in nlp.tokenizer(sentence)])\n",
"    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
"    tensor = torch.LongTensor(indexed).to(device)\n",
"    tensor = tensor.unsqueeze(1)\n",
"    prediction = torch.sigmoid(model(tensor))\n",
" return prediction.item()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example negative review..."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.1313092350011553e-12"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, \"This film is terrible\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An example positive review..."
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict_sentiment(model, \"This film is great\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Next Steps\n",
"\n",
"In the next notebook we'll use convolutional neural networks (CNNs) to perform sentiment analysis."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}