pytorch-sentiment-analysis/3_cnn.ipynb
2021-07-08 19:05:18 +01:00

825 lines
109 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b1797021",
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"import sys\n",
"\n",
"import datasets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchtext\n",
"import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0d5b5146",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fb81d11a950>"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seed = 0\n",
"\n",
"# Seed torch's global RNG, which drives weight init, dropout masks and the\n",
"# shuffling train DataLoader below. The trailing ';' hides the returned\n",
"# Generator object, whose repr is only a memory address.\n",
"torch.manual_seed(seed);"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1f9cda19",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset imdb (/home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)\n"
]
}
],
"source": [
"# Load (or reuse the cached copy of) the IMDB train and test splits.\n",
"train_data, test_data = datasets.load_dataset(path='imdb', split=['train', 'test'])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "48f4c1f6",
"metadata": {},
"outputs": [],
"source": [
"# torchtext's built-in 'basic_english' tokenizer: str -> list of token strings.\n",
"tokenizer = torchtext.data.utils.get_tokenizer(tokenizer='basic_english')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ae7e60ad",
"metadata": {},
"outputs": [],
"source": [
"def tokenize_data(example, tokenizer, max_length):\n",
"    \"\"\"Tokenize one dataset row, keeping only its first ``max_length`` tokens.\n",
"\n",
"    Meant for ``Dataset.map``: the returned ``'tokens'`` entry becomes a new column.\n",
"    \"\"\"\n",
"    all_tokens = tokenizer(example['text'])\n",
"    truncated = all_tokens[:max_length]\n",
"    return {'tokens': truncated}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "eca685b6",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-ad1b7a77180a232c.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-01c0069c185da175.arrow\n"
]
}
],
"source": [
"max_length = 256\n",
"tokenize_kwargs = {'tokenizer': tokenizer, 'max_length': max_length}\n",
"\n",
"# Each review is cut to its first max_length tokens.\n",
"train_data = train_data.map(tokenize_data, fn_kwargs=tokenize_kwargs)\n",
"test_data = test_data.map(tokenize_data, fn_kwargs=tokenize_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cb53b268",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached split indices for dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-90b2a85f23273ecd.arrow and /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-99371bdf1a536e7c.arrow\n"
]
}
],
"source": [
"test_size = 0.25\n",
"\n",
"# Hold out test_size of the original training split for validation. Passing\n",
"# `seed` makes the split reproducible on a fresh kernel/cache; without it\n",
"# `train_test_split` shuffles using fresh OS entropy.\n",
"# NOTE: this cell rebinds `train_data`, so re-running it would split the\n",
"# already-reduced training set again - re-run from the loading cell instead.\n",
"train_valid_data = train_data.train_test_split(test_size=test_size, seed=seed)\n",
"train_data = train_valid_data['train']\n",
"valid_data = train_valid_data['test']"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a7f7d1d7",
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = ['<unk>', '<pad>']\n",
"\n",
"# Vocabulary comes from the training split only; tokens seen fewer than\n",
"# min_freq times are excluded (they later fall back to <unk>).\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
"    train_data['tokens'],\n",
"    min_freq=min_freq,\n",
"    specials=special_tokens,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d029794a",
"metadata": {},
"outputs": [],
"source": [
"# Ids torchtext assigned to the special tokens.\n",
"unk_index, pad_index = vocab['<unk>'], vocab['<pad>']"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "03aa4514",
"metadata": {},
"outputs": [],
"source": [
"# Out-of-vocabulary tokens are looked up as <unk>.\n",
"vocab.set_default_index(unk_index)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0133bdd3",
"metadata": {},
"outputs": [],
"source": [
"def numericalize_data(example, vocab):\n",
"    \"\"\"Convert ``example['tokens']`` into a list of vocabulary ids.\n",
"\n",
"    Returned as ``{'ids': [...]}`` so ``Dataset.map`` adds an ``'ids'`` column.\n",
"    \"\"\"\n",
"    lookup = vocab.__getitem__\n",
"    return {'ids': list(map(lookup, example['tokens']))}"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a8deac4e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-4fa96f7122a515e2.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-cabd43c688223ded.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-087b09fd94e05553.arrow\n"
]
}
],
"source": [
"numericalize_kwargs = {'vocab': vocab}\n",
"\n",
"train_data = train_data.map(numericalize_data, fn_kwargs=numericalize_kwargs)\n",
"valid_data = valid_data.map(numericalize_data, fn_kwargs=numericalize_kwargs)\n",
"test_data = test_data.map(numericalize_data, fn_kwargs=numericalize_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "29f4bd82",
"metadata": {},
"outputs": [],
"source": [
"# Expose only the model inputs/targets, returned as torch tensors.\n",
"torch_columns = ['ids', 'label']\n",
"train_data = train_data.with_format(type='torch', columns=torch_columns)\n",
"valid_data = valid_data.with_format(type='torch', columns=torch_columns)\n",
"test_data = test_data.with_format(type='torch', columns=torch_columns)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "65cd046f",
"metadata": {},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
"    \"\"\"Convolutional text classifier over token ids.\n",
"\n",
"    The ids are embedded, passed through one ``nn.Conv1d`` per entry of\n",
"    ``filter_sizes`` (``n_filters`` output channels each, ReLU-activated),\n",
"    max-pooled over the sequence, concatenated and projected to\n",
"    ``output_dim`` logits. Dropout is applied to the embeddings and to the\n",
"    concatenated pooled features.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout_rate,\n",
"                 pad_index):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
"        self.convs = nn.ModuleList([nn.Conv1d(embedding_dim, n_filters, filter_size)\n",
"                                    for filter_size in filter_sizes])\n",
"        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
"        self.dropout = nn.Dropout(dropout_rate)\n",
"        # A Conv1d with kernel size k needs at least k time steps, so inputs\n",
"        # shorter than the widest filter are right-padded in forward().\n",
"        self.min_length = max(filter_sizes, default=0)\n",
"\n",
"    def forward(self, ids):\n",
"        # ids = [batch size, seq len]\n",
"        ids = self._pad_to_min_length(ids)\n",
"        # ids = [batch size, max(seq len, min_length)]\n",
"        embedded = self.dropout(self.embedding(ids))\n",
"        # embedded = [batch size, seq len, embedding dim]\n",
"        embedded = embedded.permute(0, 2, 1)\n",
"        # embedded = [batch size, embedding dim, seq len]\n",
"        conved = [torch.relu(conv(embedded)) for conv in self.convs]\n",
"        # conved_n = [batch size, n filters, seq len - filter_sizes[n] + 1]\n",
"        pooled = [conv.max(dim=-1).values for conv in conved]\n",
"        # pooled_n = [batch size, n filters]\n",
"        cat = self.dropout(torch.cat(pooled, dim=-1))\n",
"        # cat = [batch size, n filters * len(filter_sizes)]\n",
"        prediction = self.fc(cat)\n",
"        # prediction = [batch size, output dim]\n",
"        return prediction\n",
"\n",
"    def _pad_to_min_length(self, ids):\n",
"        \"\"\"Right-pad ``ids`` with the embedding's padding index to ``self.min_length`` columns.\n",
"\n",
"        Returned unchanged if already long enough or if no padding index was given.\n",
"        \"\"\"\n",
"        pad_index = self.embedding.padding_idx\n",
"        missing = self.min_length - ids.shape[-1]\n",
"        if missing <= 0 or pad_index is None:\n",
"            return ids\n",
"        padding = ids.new_full((ids.shape[0], missing), pad_index)\n",
"        return torch.cat([ids, padding], dim=-1)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ad3da9c4",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = len(vocab)\n",
"output_dim = len(train_data.unique('label'))  # one output logit per distinct label\n",
"\n",
"embedding_dim = 300\n",
"n_filters = 100\n",
"filter_sizes = [3, 5, 7]\n",
"dropout_rate = 0.25\n",
"\n",
"model = CNN(vocab_size=vocab_size,\n",
"            embedding_dim=embedding_dim,\n",
"            n_filters=n_filters,\n",
"            filter_sizes=filter_sizes,\n",
"            output_dim=output_dim,\n",
"            dropout_rate=dropout_rate,\n",
"            pad_index=pad_index)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e5b9314c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 6,913,802 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
"    \"\"\"Return the total number of elements in ``model``'s trainable parameters.\"\"\"\n",
"    trainable = (p for p in model.parameters() if p.requires_grad)\n",
"    return sum(p.numel() for p in trainable)\n",
"\n",
"n_trainable = count_parameters(model)\n",
"print(f'The model has {n_trainable:,} trainable parameters')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "48dd9079",
"metadata": {},
"outputs": [],
"source": [
"def initialize_weights(m):\n",
"    \"\"\"Initialise one module's weights; intended for ``model.apply``.\n",
"\n",
"    ``nn.Linear`` weights get Xavier-normal init and ``nn.Conv1d`` weights get\n",
"    Kaiming-normal init (for the ReLU applied after each conv); the biases of\n",
"    both are zeroed. Any other module type is left unchanged.\n",
"    \"\"\"\n",
"    if isinstance(m, nn.Linear):\n",
"        nn.init.xavier_normal_(m.weight)\n",
"    elif isinstance(m, nn.Conv1d):\n",
"        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')\n",
"    else:\n",
"        return\n",
"    # Layers built with bias=False have m.bias set to None.\n",
"    if m.bias is not None:\n",
"        nn.init.zeros_(m.bias)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "e455a168",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CNN(\n",
" (embedding): Embedding(21543, 300, padding_idx=1)\n",
" (convs): ModuleList(\n",
" (0): Conv1d(300, 100, kernel_size=(3,), stride=(1,))\n",
" (1): Conv1d(300, 100, kernel_size=(5,), stride=(1,))\n",
" (2): Conv1d(300, 100, kernel_size=(7,), stride=(1,))\n",
" )\n",
" (fc): Linear(in_features=300, out_features=2, bias=True)\n",
" (dropout): Dropout(p=0.25, inplace=False)\n",
")"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Only nn.Linear and nn.Conv1d layers are re-initialised; the embedding is\n",
"# overwritten with pretrained vectors further down.\n",
"model.apply(initialize_weights)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "cca8ce6b",
"metadata": {},
"outputs": [],
"source": [
"# Pretrained English FastText word vectors.\n",
"vectors = torchtext.vocab.FastText(language='en')"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "e8f96c10",
"metadata": {},
"outputs": [],
"source": [
"# Row i is the pretrained vector for token id i (get_itos() lists tokens in id order).\n",
"pretrained_embedding = vectors.get_vecs_by_tokens(tokens=vocab.get_itos())"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "bb09a2aa",
"metadata": {},
"outputs": [],
"source": [
"# Copy the pretrained rows into the existing embedding parameter in place.\n",
"# Unlike rebinding `weight.data`, copy_ fails loudly if the shapes are\n",
"# incompatible (e.g. embedding_dim != the vectors' size) and keeps the\n",
"# parameter's own dtype and device.\n",
"with torch.no_grad():\n",
"    model.embedding.weight.copy_(pretrained_embedding)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "7a5e39e9",
"metadata": {},
"outputs": [],
"source": [
"# Adam with its default hyper-parameters over every model parameter.\n",
"optimizer = optim.Adam(params=model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "e123ae07",
"metadata": {},
"outputs": [],
"source": [
"# Cross-entropy on raw logits vs. integer class labels, averaged over the batch.\n",
"criterion = nn.CrossEntropyLoss(reduction='mean')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "825a973d",
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "f9512ae1",
"metadata": {},
"outputs": [],
"source": [
"# Module.to returns the (moved) module itself.\n",
"model, criterion = model.to(device), criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2216fbd4",
"metadata": {},
"outputs": [],
"source": [
"def collate(batch, pad_index):\n",
"    \"\"\"Merge a list of ``{'ids', 'label'}`` examples into batch tensors.\n",
"\n",
"    ``ids`` sequences are right-padded with ``pad_index`` to the longest one\n",
"    (batch-first); labels are stacked. Returns ``{'ids': ..., 'label': ...}``.\n",
"    \"\"\"\n",
"    padded_ids = nn.utils.rnn.pad_sequence([example['ids'] for example in batch],\n",
"                                           padding_value=pad_index,\n",
"                                           batch_first=True)\n",
"    labels = torch.stack([example['label'] for example in batch])\n",
"    return {'ids': padded_ids, 'label': labels}"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0513db80",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 512\n",
"\n",
"# Bind pad_index under a new name instead of rebinding `collate`, so `collate`\n",
"# keeps referring to the function defined above rather than a partial of it.\n",
"collate_fn = functools.partial(collate, pad_index=pad_index)\n",
"\n",
"train_dataloader = torch.utils.data.DataLoader(train_data,\n",
"                                               batch_size=batch_size,\n",
"                                               collate_fn=collate_fn,\n",
"                                               shuffle=True)\n",
"\n",
"valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, collate_fn=collate_fn)\n",
"test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "c3651ca7",
"metadata": {},
"outputs": [],
"source": [
"def train(dataloader, model, criterion, optimizer, device):\n",
"    \"\"\"Run one training epoch and return per-batch losses and accuracies.\n",
"\n",
"    Puts ``model`` in training mode and, for every batch from ``dataloader``,\n",
"    moves ``ids``/``label`` to ``device`` and takes one optimizer step on\n",
"    ``criterion(model(ids), label)``.\n",
"\n",
"    Uses ``get_accuracy``, which is defined in a later cell.\n",
"\n",
"    Returns:\n",
"        (epoch_losses, epoch_accs): two lists of Python floats, one entry per batch.\n",
"    \"\"\"\n",
"    model.train()\n",
"    epoch_losses = []\n",
"    epoch_accs = []\n",
"\n",
"    for batch in tqdm.tqdm(dataloader, desc='training...', file=sys.stdout):\n",
"        ids = batch['ids'].to(device)\n",
"        label = batch['label'].to(device)\n",
"        prediction = model(ids)\n",
"        loss = criterion(prediction, label)\n",
"        accuracy = get_accuracy(prediction, label)\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        epoch_losses.append(loss.item())\n",
"        epoch_accs.append(accuracy.item())\n",
"\n",
"    return epoch_losses, epoch_accs"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "f2a96019",
"metadata": {},
"outputs": [],
"source": [
"def evaluate(dataloader, model, criterion, device):\n",
"    \"\"\"Evaluate ``model`` on every batch of ``dataloader`` without updating it.\n",
"\n",
"    Runs in eval mode under ``torch.no_grad()``. Uses ``get_accuracy``, which\n",
"    is defined in a later cell.\n",
"\n",
"    Returns:\n",
"        (epoch_losses, epoch_accs): two lists of Python floats, one entry per batch.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    epoch_losses = []\n",
"    epoch_accs = []\n",
"\n",
"    with torch.no_grad():\n",
"        for batch in tqdm.tqdm(dataloader, desc='evaluating...', file=sys.stdout):\n",
"            ids = batch['ids'].to(device)\n",
"            label = batch['label'].to(device)\n",
"            prediction = model(ids)\n",
"            loss = criterion(prediction, label)\n",
"            accuracy = get_accuracy(prediction, label)\n",
"            epoch_losses.append(loss.item())\n",
"            epoch_accs.append(accuracy.item())\n",
"\n",
"    return epoch_losses, epoch_accs"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "3cf2f1e1",
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
"    \"\"\"Return the fraction of rows whose arg-max class equals ``label``.\n",
"\n",
"    ``prediction`` must be 2-D (one row of class scores per example); the\n",
"    result is a 0-dim tensor.\n",
"    \"\"\"\n",
"    n_examples, _ = prediction.shape  # unpacking enforces a 2-D input\n",
"    n_correct = (prediction.argmax(dim=-1) == label).sum()\n",
"    return n_correct / n_examples"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "af6e8a15",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training...: 100%|██████████| 37/37 [00:03<00:00, 9.32it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.63it/s]\n",
"epoch: 1\n",
"train_loss: 0.735, train_acc: 0.594\n",
"valid_loss: 0.525, valid_acc: 0.753\n",
"training...: 100%|██████████| 37/37 [00:03<00:00, 9.73it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.57it/s]\n",
"epoch: 2\n",
"train_loss: 0.476, train_acc: 0.773\n",
"valid_loss: 0.390, valid_acc: 0.828\n",
"training...: 100%|██████████| 37/37 [00:03<00:00, 9.45it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.46it/s]\n",
"epoch: 3\n",
"train_loss: 0.347, train_acc: 0.852\n",
"valid_loss: 0.344, valid_acc: 0.846\n",
"training...: 100%|██████████| 37/37 [00:04<00:00, 8.96it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 12.55it/s]\n",
"epoch: 4\n",
"train_loss: 0.289, train_acc: 0.879\n",
"valid_loss: 0.322, valid_acc: 0.861\n",
"training...: 100%|██████████| 37/37 [00:04<00:00, 9.11it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.16it/s]\n",
"epoch: 5\n",
"train_loss: 0.248, train_acc: 0.903\n",
"valid_loss: 0.310, valid_acc: 0.869\n",
"training...: 100%|██████████| 37/37 [00:03<00:00, 9.26it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.52it/s]\n",
"epoch: 6\n",
"train_loss: 0.206, train_acc: 0.922\n",
"valid_loss: 0.296, valid_acc: 0.876\n",
"training...: 100%|██████████| 37/37 [00:03<00:00, 9.40it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.25it/s]\n",
"epoch: 7\n",
"train_loss: 0.172, train_acc: 0.941\n",
"valid_loss: 0.303, valid_acc: 0.876\n",
"training...: 100%|██████████| 37/37 [00:04<00:00, 9.24it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.03it/s]\n",
"epoch: 8\n",
"train_loss: 0.144, train_acc: 0.952\n",
"valid_loss: 0.289, valid_acc: 0.881\n",
"training...: 100%|██████████| 37/37 [00:03<00:00, 9.37it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 13.19it/s]\n",
"epoch: 9\n",
"train_loss: 0.116, train_acc: 0.965\n",
"valid_loss: 0.290, valid_acc: 0.884\n",
"training...: 100%|██████████| 37/37 [00:03<00:00, 9.36it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 12.98it/s]\n",
"epoch: 10\n",
"train_loss: 0.092, train_acc: 0.975\n",
"valid_loss: 0.295, valid_acc: 0.886\n"
]
}
],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float('inf')\n",
"\n",
"# Per-batch histories across all epochs, used by the plots below.\n",
"train_losses = []\n",
"train_accs = []\n",
"valid_losses = []\n",
"valid_accs = []\n",
"\n",
"for epoch in range(n_epochs):\n",
"\n",
"    train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)\n",
"    valid_loss, valid_acc = evaluate(valid_dataloader, model, criterion, device)\n",
"\n",
"    train_losses.extend(train_loss)\n",
"    train_accs.extend(train_acc)\n",
"    valid_losses.extend(valid_loss)\n",
"    valid_accs.extend(valid_acc)\n",
"\n",
"    # Unweighted means of the per-batch values.\n",
"    epoch_train_loss = np.mean(train_loss)\n",
"    epoch_train_acc = np.mean(train_acc)\n",
"    epoch_valid_loss = np.mean(valid_loss)\n",
"    epoch_valid_acc = np.mean(valid_acc)\n",
"\n",
"    # Checkpoint whenever the mean validation loss improves; the test cell\n",
"    # reloads this file.\n",
"    if epoch_valid_loss < best_valid_loss:\n",
"        best_valid_loss = epoch_valid_loss\n",
"        torch.save(model.state_dict(), 'cnn.pt')\n",
"\n",
"    print(f'epoch: {epoch+1}')\n",
"    print(f'train_loss: {epoch_train_loss:.3f}, train_acc: {epoch_train_acc:.3f}')\n",
"    print(f'valid_loss: {epoch_valid_loss:.3f}, valid_acc: {epoch_valid_acc:.3f}')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "03860181",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"# train_losses has one entry per optimizer update, valid_losses one per\n",
"# validation batch (fewer per epoch). Spread the validation points evenly over\n",
"# the same range so each epoch's validation batches line up with that epoch's\n",
"# training updates instead of being squeezed into the left of the axis.\n",
"valid_x = np.linspace(0, len(train_losses), num=len(valid_losses), endpoint=False)\n",
"ax.plot(train_losses, label='train loss')\n",
"ax.plot(valid_x, valid_losses, label='valid loss')\n",
"ax.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('loss');"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "df5d03f9",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"# train_accs has one entry per optimizer update, valid_accs one per\n",
"# validation batch (fewer per epoch). Spread the validation points evenly over\n",
"# the same range so each epoch's validation batches line up with that epoch's\n",
"# training updates instead of being squeezed into the left of the axis.\n",
"valid_x = np.linspace(0, len(train_accs), num=len(valid_accs), endpoint=False)\n",
"ax.plot(train_accs, label='train accuracy')\n",
"ax.plot(valid_x, valid_accs, label='valid accuracy')\n",
"ax.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('accuracy');"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "bb00498a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"evaluating...: 100%|██████████| 49/49 [00:03<00:00, 12.51it/s]\n",
"test_loss: 0.284, test_acc: 0.879\n"
]
}
],
"source": [
"# Restore the checkpoint with the lowest validation loss before testing.\n",
"# map_location puts the saved tensors on `device`, so a checkpoint written\n",
"# on a GPU machine can still be loaded on a CPU-only kernel.\n",
"model.load_state_dict(torch.load('cnn.pt', map_location=device))\n",
"\n",
"test_loss, test_acc = evaluate(test_dataloader, model, criterion, device)\n",
"\n",
"# Unweighted means of the per-batch values.\n",
"epoch_test_loss = np.mean(test_loss)\n",
"epoch_test_acc = np.mean(test_acc)\n",
"\n",
"print(f'test_loss: {epoch_test_loss:.3f}, test_acc: {epoch_test_acc:.3f}')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "7c08b412",
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(text, model, tokenizer, vocab, device, min_length, pad_index):\n",
"    \"\"\"Classify one piece of raw text.\n",
"\n",
"    Tokenizes ``text``, maps tokens to ids with ``vocab`` and right-pads with\n",
"    ``pad_index`` up to ``min_length`` ids (e.g. ``max(filter_sizes)`` so every\n",
"    conv filter fits), then runs ``model`` on the resulting batch of one.\n",
"\n",
"    The model runs in eval mode (dropout off) without building an autograd\n",
"    graph; its previous train/eval mode is restored afterwards.\n",
"\n",
"    Returns:\n",
"        (predicted_class, predicted_probability): the arg-max class index and\n",
"        its softmax probability, as Python scalars.\n",
"    \"\"\"\n",
"    tokens = tokenizer(text)\n",
"    ids = [vocab[t] for t in tokens]\n",
"    if len(ids) < min_length:\n",
"        ids += [pad_index] * (min_length - len(ids))\n",
"    tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    try:\n",
"        with torch.no_grad():\n",
"            prediction = model(tensor).squeeze(dim=0)\n",
"    finally:\n",
"        model.train(was_training)\n",
"    probability = torch.softmax(prediction, dim=-1)\n",
"    predicted_class = prediction.argmax(dim=-1).item()\n",
"    predicted_probability = probability[predicted_class].item()\n",
"    return predicted_class, predicted_probability"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "4fd0877a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0.8934109807014465)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inputs shorter than the widest filter are padded up to it.\n",
"min_length = max(filter_sizes)\n",
"\n",
"text = \"This film is terrible!\"\n",
"predict_sentiment(text, model, tokenizer, vocab, device, min_length=min_length, pad_index=pad_index)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "31063352",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.9333373308181763)"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is great!\"\n",
"predict_sentiment(text, model, tokenizer, vocab, device, min_length=min_length, pad_index=pad_index)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "162aea28",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0.6499875783920288)"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Negation probe. In the saved run the model predicts class 0 here, the same\n",
"# class it gives \"This film is terrible!\", so it misses the negation.\n",
"text = \"This film is not terrible, it's great!\"\n",
"predict_sentiment(text, model, tokenizer, vocab, device, min_length=min_length, pad_index=pad_index)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "83c036aa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.6004905700683594)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Negation probe. In the saved run the model predicts class 1 here, the same\n",
"# class it gives \"This film is great!\", so it misses the negation.\n",
"text = \"This film is not great, it's terrible!\"\n",
"predict_sentiment(text, model, tokenizer, vocab, device, min_length=min_length, pad_index=pad_index)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}