pytorch-sentiment-analysis/2_lstm.ipynb

884 lines
148 KiB
Plaintext
Raw Permalink Normal View History

2021-07-08 18:12:50 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
2021-07-09 02:05:18 +08:00
"id": "f23a152d",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
2021-07-09 02:05:18 +08:00
"import sys\n",
2021-07-08 18:12:50 +08:00
"\n",
"import datasets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
2021-07-08 18:12:50 +08:00
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
2021-07-09 02:05:18 +08:00
"import torchtext\n",
"import tqdm"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 2,
2021-07-09 02:05:18 +08:00
"id": "c661e3c4",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-07-09 02:05:18 +08:00
"<torch._C.Generator at 0x7f37d143a9d0>"
2021-07-08 18:12:50 +08:00
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seed = 0\n",
"\n",
"torch.manual_seed(seed)"
]
},
{
"cell_type": "code",
"execution_count": 3,
2021-07-09 02:05:18 +08:00
"id": "638a120e",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset imdb (/home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)\n"
]
}
],
"source": [
"train_data, test_data = datasets.load_dataset('imdb', split=['train', 'test'])"
]
},
{
"cell_type": "code",
"execution_count": 4,
2021-07-09 02:05:18 +08:00
"id": "7b34799a",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"tokenizer = torchtext.data.utils.get_tokenizer('basic_english')"
]
},
{
"cell_type": "code",
"execution_count": 5,
2021-07-09 02:05:18 +08:00
"id": "6fa2f2e3",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def tokenize_data(example, tokenizer, max_length):\n",
" tokens = tokenizer(example['text'])[:max_length]\n",
" length = len(tokens)\n",
" return {'tokens': tokens, 'length': length}"
]
},
{
"cell_type": "code",
"execution_count": 6,
2021-07-09 02:05:18 +08:00
"id": "1f3a3894",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-98e263656f586667.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-08f9ff60341e8990.arrow\n"
]
}
],
"source": [
"max_length = 256\n",
"\n",
"train_data = train_data.map(tokenize_data, fn_kwargs={'tokenizer': tokenizer, 'max_length': max_length})\n",
"test_data = test_data.map(tokenize_data, fn_kwargs={'tokenizer': tokenizer, 'max_length': max_length})"
]
},
{
"cell_type": "code",
"execution_count": 7,
2021-07-09 02:05:18 +08:00
"id": "7e5bd85d",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached split indices for dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-28b136fb2a4d67fd.arrow and /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-8fa05fb343b1e79a.arrow\n"
]
}
],
"source": [
"test_size = 0.25\n",
"\n",
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
"train_data = train_valid_data['train']\n",
"valid_data = train_valid_data['test']"
]
},
{
"cell_type": "code",
"execution_count": 8,
2021-07-09 02:05:18 +08:00
"id": "0bf984df",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = ['<unk>', '<pad>']\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(train_data['tokens'],\n",
" min_freq=min_freq,\n",
" specials=special_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 9,
2021-07-09 02:05:18 +08:00
"id": "5147a8fd",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"unk_index = vocab['<unk>']\n",
"pad_index = vocab['<pad>']"
]
},
{
"cell_type": "code",
"execution_count": 10,
2021-07-09 02:05:18 +08:00
"id": "8b97bda7",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"vocab.set_default_index(unk_index)"
]
},
{
"cell_type": "code",
"execution_count": 11,
2021-07-09 02:05:18 +08:00
"id": "843282aa",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def numericalize_data(example, vocab):\n",
2021-07-08 18:26:28 +08:00
" ids = [vocab[token] for token in example['tokens']]\n",
" return {'ids': ids}"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 12,
2021-07-09 02:05:18 +08:00
"id": "885b504a",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-631f583ffb0d9c68.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-6f5a6c52dcbaf2d0.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-c455c5d5c41c2779.arrow\n"
2021-07-08 18:12:50 +08:00
]
}
],
"source": [
"train_data = train_data.map(numericalize_data, fn_kwargs={'vocab': vocab})\n",
"valid_data = valid_data.map(numericalize_data, fn_kwargs={'vocab': vocab})\n",
"test_data = test_data.map(numericalize_data, fn_kwargs={'vocab': vocab})"
]
},
{
"cell_type": "code",
"execution_count": 13,
2021-07-09 02:05:18 +08:00
"id": "2b956558",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"train_data = train_data.with_format(type='torch', columns=['ids', 'label', 'length'])\n",
"valid_data = valid_data.with_format(type='torch', columns=['ids', 'label', 'length'])\n",
"test_data = test_data.with_format(type='torch', columns=['ids', 'label', 'length'])"
]
},
{
"cell_type": "code",
"execution_count": 14,
2021-07-09 02:05:18 +08:00
"id": "53575424",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': tensor(1),\n",
" 'length': tensor(256),\n",
" 'ids': tensor([ 98, 13, 6329, 6, 559, 13, 6491, 4, 2, 763,\n",
" 6, 6300, 17, 34, 7, 2, 195, 116, 98, 40,\n",
" 2302, 102, 3497, 44, 3318, 15422, 21, 261, 3609, 3433,\n",
" 3, 474, 4, 6093, 6, 10888, 54, 396, 1198, 338,\n",
" 4479, 4, 14, 23, 1481, 3596, 19, 5, 13453, 850,\n",
" 23, 3, 2, 639, 7, 14, 23, 10, 1073, 20,\n",
" 2302, 9, 7180, 9372, 7, 1045, 2522, 4, 1706, 2115,\n",
" 4, 212, 8127, 6, 3179, 1485, 3, 2, 386, 7,\n",
" 13210, 860, 233, 10, 5, 12948, 8, 2, 984, 212,\n",
" 628, 346, 13, 1228, 7, 462, 4, 6, 2, 1236,\n",
" 1675, 114, 6, 2, 905, 13, 10802, 59, 71, 35,\n",
" 1132, 19, 2, 3009, 4, 13, 117, 771, 4, 8,\n",
" 3582, 3534, 9, 16, 10802, 3, 446, 4, 11, 10,\n",
" 2302, 9, 3013, 20, 1810, 6389, 15, 4846, 14, 23,\n",
" 13, 100, 6865, 3, 2, 113, 7, 14, 64, 10,\n",
" 406, 443, 527, 5, 525, 4470, 10812, 7, 23, 3324,\n",
" 3, 190, 4, 500, 14, 3049, 4, 2, 64, 2675,\n",
" 20, 356, 389, 40, 2302, 4, 7843, 6, 262, 3111,\n",
" 14039, 25, 2146, 24, 106, 14, 23, 5, 1777, 8,\n",
" 108, 3, 4281, 2302, 2890, 137, 29, 71, 2, 2386,\n",
" 0, 4, 22, 2, 3544, 7847, 19, 2, 1172, 22,\n",
" 1813, 7915, 3, 6041, 7843, 4, 42, 17, 922, 8,\n",
" 2302, 38, 2, 65, 4, 2100, 20, 50, 604, 556,\n",
" 19, 5, 416, 11476, 6, 310, 27, 5, 221, 7,\n",
" 158, 1248, 6, 16505, 3, 454, 4, 3111, 14039, 637,\n",
" 6, 1079, 1074, 49, 8928, 9])}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 15,
2021-07-09 02:05:18 +08:00
"id": "53427b55",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"class LSTM(nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional,\n",
" dropout_rate, pad_index):\n",
2021-07-08 18:12:50 +08:00
" super().__init__()\n",
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
2021-07-08 18:12:50 +08:00
" self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, bidirectional=bidirectional,\n",
" dropout=dropout_rate, batch_first=True)\n",
" self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)\n",
" self.dropout = nn.Dropout(dropout_rate)\n",
" \n",
" def forward(self, ids, length):\n",
" # ids = [batch size, seq len]\n",
" # length = [batch size]\n",
" embedded = self.dropout(self.embedding(ids))\n",
" # embedded = [batch size, seq len, embedding dim]\n",
" packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True, \n",
" enforce_sorted=False)\n",
" packed_output, (hidden, cell) = self.lstm(packed_embedded)\n",
" # hidden = [n layers * n directions, batch size, hidden dim]\n",
" # cell = [n layers * n directions, batch size, hidden dim]\n",
" output, output_length = nn.utils.rnn.pad_packed_sequence(packed_output)\n",
" # output = [batch size, seq len, hidden dim * n directions]\n",
" if self.lstm.bidirectional:\n",
" hidden = self.dropout(torch.cat([hidden[-1], hidden[-2]], dim=-1))\n",
" # hidden = [batch size, hidden dim * 2]\n",
" else:\n",
" hidden = self.dropout(hidden[-1])\n",
" # hidden = [batch size, hidden dim]\n",
" prediction = self.fc(hidden)\n",
" # prediction = [batch size, output dim]\n",
" return prediction"
]
},
{
"cell_type": "code",
"execution_count": 16,
2021-07-09 02:05:18 +08:00
"id": "11206188",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"vocab_size = len(vocab)\n",
"embedding_dim = 300\n",
"hidden_dim = 300\n",
"output_dim = len(train_data.unique('label'))\n",
"n_layers = 2\n",
"bidirectional = True\n",
"dropout_rate = 0.5\n",
2021-07-08 18:12:50 +08:00
"\n",
"model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout_rate, \n",
" pad_index)"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 17,
2021-07-09 02:05:18 +08:00
"id": "5feb9512",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 10,073,702 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
{
"cell_type": "code",
"execution_count": 18,
2021-07-09 02:05:18 +08:00
"id": "3edc8e02",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def initialize_weights(m):\n",
" if isinstance(m, nn.Linear):\n",
" nn.init.xavier_normal_(m.weight)\n",
" nn.init.zeros_(m.bias)\n",
" elif isinstance(m, nn.LSTM):\n",
" for name, param in m.named_parameters():\n",
" if 'bias' in name:\n",
" nn.init.zeros_(param)\n",
" elif 'weight' in name:\n",
" nn.init.orthogonal_(param)"
]
},
{
"cell_type": "code",
"execution_count": 19,
2021-07-09 02:05:18 +08:00
"id": "98490c40",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LSTM(\n",
" (embedding): Embedding(21543, 300, padding_idx=1)\n",
" (lstm): LSTM(300, 300, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)\n",
2021-07-08 18:12:50 +08:00
" (fc): Linear(in_features=600, out_features=2, bias=True)\n",
" (dropout): Dropout(p=0.5, inplace=False)\n",
2021-07-08 18:12:50 +08:00
")"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.apply(initialize_weights)"
]
},
{
"cell_type": "code",
"execution_count": 20,
2021-07-09 02:05:18 +08:00
"id": "72bfd654",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"vectors = torchtext.vocab.FastText()"
]
},
{
"cell_type": "code",
"execution_count": 21,
2021-07-09 02:05:18 +08:00
"id": "6ec1ed34",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())"
]
},
{
"cell_type": "code",
"execution_count": 22,
2021-07-09 02:05:18 +08:00
"id": "7489711f",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"model.embedding.weight.data = pretrained_embedding"
]
},
{
"cell_type": "code",
"execution_count": 23,
2021-07-09 02:05:18 +08:00
"id": "e2d0b14e",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"lr = 5e-4\n",
"\n",
"optimizer = optim.Adam(model.parameters(), lr=lr)"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 24,
2021-07-09 02:05:18 +08:00
"id": "d798a6bd",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 25,
2021-07-09 02:05:18 +08:00
"id": "4a780705",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 26,
2021-07-09 02:05:18 +08:00
"id": "5c8302f0",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 27,
2021-07-09 02:05:18 +08:00
"id": "070a2098",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def collate(batch, pad_index):\n",
" batch_ids = [i['ids'] for i in batch]\n",
" batch_ids = nn.utils.rnn.pad_sequence(batch_ids, padding_value=pad_index, batch_first=True)\n",
" batch_length = [i['length'] for i in batch]\n",
" batch_length = torch.stack(batch_length)\n",
" batch_label = [i['label'] for i in batch]\n",
" batch_label = torch.stack(batch_label)\n",
" batch = {'ids': batch_ids,\n",
" 'length': batch_length,\n",
" 'label': batch_label}\n",
" return batch"
]
},
{
"cell_type": "code",
"execution_count": 28,
2021-07-09 02:05:18 +08:00
"id": "48efdace",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"batch_size = 512\n",
"\n",
"collate = functools.partial(collate, pad_index=pad_index)\n",
"\n",
"train_dataloader = torch.utils.data.DataLoader(train_data, \n",
" batch_size=batch_size, \n",
" collate_fn=collate, \n",
" shuffle=True)\n",
"\n",
"valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, collate_fn=collate)\n",
"test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collate)"
]
},
{
"cell_type": "code",
"execution_count": 29,
2021-07-09 02:05:18 +08:00
"id": "8a1e9b07",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def train(dataloader, model, criterion, optimizer, device):\n",
"\n",
" model.train()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
2021-07-08 18:12:50 +08:00
"\n",
2021-07-09 02:05:18 +08:00
" for batch in tqdm.tqdm(dataloader, desc='training...', file=sys.stdout):\n",
2021-07-08 18:12:50 +08:00
" ids = batch['ids'].to(device)\n",
" length = batch['length']\n",
" label = batch['label'].to(device)\n",
" prediction = model(ids, length)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
2021-07-08 18:12:50 +08:00
"\n",
" return epoch_losses, epoch_accs"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 30,
2021-07-09 02:05:18 +08:00
"id": "c7988786",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def evaluate(dataloader, model, criterion, device):\n",
" \n",
" model.eval()\n",
" epoch_losses = []\n",
" epoch_accs = []\n",
2021-07-08 18:12:50 +08:00
"\n",
" with torch.no_grad():\n",
2021-07-09 02:05:18 +08:00
" for batch in tqdm.tqdm(dataloader, desc='evaluating...', file=sys.stdout):\n",
2021-07-08 18:12:50 +08:00
" ids = batch['ids'].to(device)\n",
" length = batch['length']\n",
" label = batch['label'].to(device)\n",
" prediction = model(ids, length)\n",
" loss = criterion(prediction, label)\n",
" accuracy = get_accuracy(prediction, label)\n",
" epoch_losses.append(loss.item())\n",
" epoch_accs.append(accuracy.item())\n",
2021-07-08 18:12:50 +08:00
"\n",
" return epoch_losses, epoch_accs"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 31,
2021-07-09 02:05:18 +08:00
"id": "d66535bd",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
" batch_size, _ = prediction.shape\n",
" predicted_classes = prediction.argmax(dim=-1)\n",
" correct_predictions = predicted_classes.eq(label).sum()\n",
" accuracy = correct_predictions / batch_size\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": 32,
2021-07-09 02:05:18 +08:00
"id": "24c05b57",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:09<00:00, 3.80it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.70it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 1\n",
"train_loss: 0.632, train_acc: 0.632\n",
"valid_loss: 0.470, valid_acc: 0.779\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:09<00:00, 3.91it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.53it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 2\n",
"train_loss: 0.522, train_acc: 0.743\n",
"valid_loss: 0.465, valid_acc: 0.775\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.08it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.66it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 3\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.443, train_acc: 0.804\n",
"valid_loss: 0.393, valid_acc: 0.830\n",
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.05it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.67it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 4\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.369, train_acc: 0.840\n",
"valid_loss: 0.367, valid_acc: 0.848\n",
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.03it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.60it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 5\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.336, train_acc: 0.857\n",
"valid_loss: 0.393, valid_acc: 0.842\n",
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.05it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.63it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 6\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.375, train_acc: 0.837\n",
"valid_loss: 0.387, valid_acc: 0.842\n",
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.04it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.56it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 7\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.368, train_acc: 0.848\n",
"valid_loss: 0.437, valid_acc: 0.791\n",
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.05it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.59it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 8\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.324, train_acc: 0.866\n",
"valid_loss: 0.343, valid_acc: 0.854\n",
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.04it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.53it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 9\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.276, train_acc: 0.890\n",
"valid_loss: 0.362, valid_acc: 0.866\n",
"training...: 100%|██████████| 37/37 [00:09<00:00, 4.05it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:01<00:00, 8.57it/s]\n",
2021-07-08 18:12:50 +08:00
"epoch: 10\n",
2021-07-09 02:05:18 +08:00
"train_loss: 0.245, train_acc: 0.899\n",
"valid_loss: 0.330, valid_acc: 0.873\n"
2021-07-08 18:12:50 +08:00
]
}
],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float('inf')\n",
"\n",
"train_losses = []\n",
"train_accs = []\n",
"valid_losses = []\n",
"valid_accs = []\n",
"\n",
2021-07-08 18:12:50 +08:00
"for epoch in range(n_epochs):\n",
"\n",
" train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)\n",
" valid_loss, valid_acc = evaluate(valid_dataloader, model, criterion, device)\n",
"\n",
" train_losses.extend(train_loss)\n",
" train_accs.extend(train_acc)\n",
" valid_losses.extend(valid_loss)\n",
" valid_accs.extend(valid_acc)\n",
" \n",
" epoch_train_loss = np.mean(train_loss)\n",
" epoch_train_acc = np.mean(train_acc)\n",
" epoch_valid_loss = np.mean(valid_loss)\n",
" epoch_valid_acc = np.mean(valid_acc)\n",
" \n",
" if epoch_valid_loss < best_valid_loss:\n",
" best_valid_loss = epoch_valid_loss\n",
2021-07-08 18:12:50 +08:00
" torch.save(model.state_dict(), 'lstm.pt')\n",
" \n",
" print(f'epoch: {epoch+1}')\n",
" print(f'train_loss: {epoch_train_loss:.3f}, train_acc: {epoch_train_acc:.3f}')\n",
" print(f'valid_loss: {epoch_valid_loss:.3f}, valid_acc: {epoch_valid_acc:.3f}')"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 33,
2021-07-09 02:05:18 +08:00
"id": "b360d0cd",
"metadata": {},
"outputs": [
{
"data": {
2021-07-09 02:05:18 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFzCAYAAAB2A95GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAACzDklEQVR4nOy9eZxkVXn//zl3qb2qq7dZenpWmIFZGWBYBBFQJLjhggoG3GI0GpcYEyNmMcbEXzAxxpgvajRqFhVEVNwQlAgiyDYsMwyz79Oz9l571V3O749zz12qq7qr967meb9e8+ruWm6dqq7p+6nP8znPwzjnIAiCIAiCIGYWZbYXQBAEQRAE8WKERBhBEARBEMQsQCKMIAiCIAhiFiARRhAEQRAEMQuQCCMIgiAIgpgFSIQRBEEQBEHMAtp0Hpwxdh2AfwOgAvhPzvltVdcvB/BNAJ0ABgDcwjnvGe2YHR0dfMWKFdOzYIIgCIIgiCnk6aef7uOcd9a6btpEGGNMBXA7gFcC6AHwFGPsJ5zznb6bfR7A/3DO/5sx9nIA/wjg7aMdd8WKFdi6det0LZsgCIIgCGLKYIwdqXfddJYjLwawn3N+kHNeAXAngNdX3WYdgF873z9Y43qCIAiCIIh5yXSKsCUAjvl+7nEu87MNwJuc798IIMkYa5/GNREEQRAEQcwJZjuY/+cArmSMPQvgSgDHAVjVN2KMvY8xtpUxtrW3t3em10gQBEEQBDHlTGcw/ziApb6fu53LXDjnJ+A4YYyxBIAbOOdD1QfinH8NwNcAYMuWLTTskiAIgiCmCMMw0NPTg1KpNNtLaWoikQi6u7uh63rD95lOEfYUgNWMsZUQ4usmAL/vvwFjrAPAAOfcBvBJiJ2SBEEQBEHMED09PUgmk1ixYgUYY7O9nKaEc47+/n709PRg5cqVDd9v2sqRnHMTwIcA3A9gF4C7OOcvMMY+wxi73rnZVQD2MMb2AlgI4LPTtR6CIAiCIEZSKpXQ3t5OAmwSMMbQ3t4+bjdxWvuEcc7vBXBv1WWf8n1/N4C7p3MNBEEQBEGMDgmwyTOR13C2g/kEQRAEQbyIGRoawpe//OUJ3ffVr341hoaGGr79pz/9aXz+85+f0GNNByTCCIIgCIKYNUYTYaZpjnrfe++9F+l0ehpWNTOQCCMIgiAIYta49dZbceDAAWzevBkf//jH8dBDD+GKK67A9ddfj3Xr1gEA3vCGN+DCCy/E+vXr8bWvfc2974oVK9DX14fDhw9j7dq1eO9734v169fj2muvRbFYHPVxn3vuOVx66aXYtGkT3vjGN2JwcBAA8KUvfQnr1q3Dpk2bcNNNNwEAfvOb32Dz5s3YvHkzzj//fGSz2Sl57tOaCSMIgiAIonn4u5++gJ0nMlN6zHVdKfzt69bXvf62227Djh078NxzzwEAHnroITzzzDPYsWOHu9Pwm9/8Jtra2lAsFnHRRRfhhhtuQHt7sLf7vn37cMcdd+DrX/863vrWt+IHP/gBbrnllrqP+453vAP//u//jiuvvBKf+tSn8Hd/93f44he/iNtuuw2HDh1COBx2S52f//zncfvtt+Pyyy9HLpdDJBKZ3IviQE4YMa84PlREoTK6fU0QBEHMbS6++OJAq4cvfelLOO+883DppZfi2LFj2Ldv34j7rFy5Eps3bwYAXHjhhTh8+HDd4w8PD2NoaAhXXnklAOCd73wnHn74YQDApk2bcPPNN+Pb3/42NE14VZdffjk+9rGP4Utf+hKGhobcyycLOWHEvOKNtz+Kd7xkOT708tWzvRSCIIimYzTHaiaJx+Pu9w899BAeeOABPPbYY4jFYrjqqqtqtoIIh8Pu96qqjlmOrMfPf/5zPPzww/jpT3+Kz372s3j++edx66234jWveQ3uvfdeXH755bj//vtx7rnnTuj4fsgJI+YVQ0UDw0VjtpdBEARBNEgymRw1YzU8PIzW1lbEYjHs3r0bjz/++KQfs6WlBa2trfjtb38LAPjf//1fXHnllbBtG8eOHcPVV1+Nz33ucxgeHkYul8OBAwewceNGfOITn8BFF12E3bt3T3oNADlhxDyDcw6bBlsRBEE0De3t7bj88suxYcMGvOpVr8JrXvOawPXXXXcdvvrVr2Lt2rU455xzcOmll07J4/73f/833v/+96NQKGDVqlX41re+BcuycMstt2B4eBicc3zkIx9BOp3G3/zN3+DBBx+EoihYv349XvWqV03JGhjnzXXG2rJlC9+6detsL4OYo5z1l/fiHS9ZPmcsdYIgiLnOrl27sHbt2tlexryg1mvJGHuac76l1u2pHEnMK2zO0WSfKwiCIIgXKSTCiHkDdwSYTSqMIAiCaAJIhBHzBqm9SIMRBEEQzQCJMGLeIB0wcsIIgiCIZoBEGDFvkLsiaXckQRAE0QyQCCPmDdIBa7YdvwRBEMSLExJhxLyBMmEEQRAvDhKJBADgxIkTePOb31zzNldddRVqtbSqd/lsQCKMmDdQJowgCOLFRVdXF+6+++7ZXsaEIRFGzBs8ETbLCyEIgiAa5tZbb8Xtt9/u/vzpT38an//855HL5fCKV7wCF1xwATZu3Igf//jHI+57+PBhbNiwAQBQLBZx0003Ye3atXjjG9/Y0OzIO+64Axs3bsSGDRvwiU98AgBgWRbe9a53YcOGDdi4cSP+9V//FYAYIr5u3Tps2rQJN91001Q8dRpbRMwfbLccSSqMIAhiQvziVuDU81N7zEUbgVfdVvfqG2+8ER/96EfxwQ9+EABw11134f7770ckEsGPfvQjpFIp9PX14dJLL8X1118PxljN43zlK19BLBbDrl27sH37dlxwwQWjLuvEiRP4xCc+gaeffhqtra249tprcc8992Dp0qU4fvw4duzYAQAYGhoCANx22204dOgQwuGwe9lkISeMmDdwKkcSBEE0Heeffz7OnDmDEydOYNu2bWhtbcXSpUvBOcdf/uVfYtOmTbjmmmtw/PhxnD59uu5xHn74Ydxyyy0AgE2bNmHTpk2jPu5TTz2Fq666Cp2dndA0DTfffDMefvhhrFq1CgcPHsSHP/xh3HfffUilUu4xb775Znz729+Gpk2Nh0VOGDFvcJ2w2V0GQRBE8zKKYzWdvOUtb8Hdd9+NU6dO4cYbbwQAfOc730Fvby+efvpp6LqOFStWoFQqTftaWltbsW3bNtx///346le/irvuugvf/OY38fOf/xwPP/wwfvrTn+Kzn/0snn/++UmLMXLCiHkDZcIIgiCakxtvvBF33nkn7r77brzlLW8BAAwPD2PBggXQdR0PPvggjhw5MuoxXvayl+G73/0uAGDHjh3Yvn37qLe/+OKL8Zvf/AZ9fX2wLAt33HEHrrzySvT19cG2bdxwww34h3/4BzzzzDOwbRvHjh3D1Vdfjc997nMYHh5GLpeb9PMmJ4yYN9DuSIIgiOZk/fr1yGazWLJkCRYvXgwAuPnmm/G6170OGzduxJYtW3DuueeOeowPfOADePe73421a9di7dq1uPDCC0e9/eLFi3Hbbbfh6quvBuccr3nNa/D6178e27Ztw7vf/W7Ytg0A+Md//EdYloVbbrkFw8PD4JzjIx/5CNLp9KSfN2u2EPOWLVv4XOnvQcwtTmdKuOT/+z+8euMifPnm0f/zEQRBEIJdu3Zh7dq1s72MeUGt15Ix9jTnfEut21M5kpg3eB3zZ3khBEEQBNEAJMKIeYM3O5JUGEEQBDH3IRFGzBtsm4L5BEEQRPNAIoyYN3Bq1koQBDEh6O/m5JnIa0gijJg3UIsKgiCI8ROJRNDf309CbBJwztHf349IJDKu+1GLCmLe4AXz6Q8JQRBEo3R3d6Onpwe9vb2zvZSmJhKJoLu7e1z3IRFGzBu8YP7sroMgCKKZ0HUdK1eunO1lvCihciQxb6DZkQRBEEQzQSKMmDe4syNJgxEEQRBNAIkwYt7gZsJohDdBEATRBJAII+YN7u5Ie5YXQhAEQRANMK0ijDF2HWNsD2NsP2Ps1hrXL2OMPcgYe5Yxtp0x9urpXA8xv+HUMZ8gCIJoIqZNhDHGVAC3A3gVgHU
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure(figsize=(10,6))\n",
"ax = fig.add_subplot(1,1,1)\n",
"ax.plot(train_losses, label='train loss')\n",
"ax.plot(valid_losses, label='valid loss')\n",
"plt.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('loss');"
]
},
{
"cell_type": "code",
"execution_count": 34,
2021-07-09 02:05:18 +08:00
"id": "742a6855",
"metadata": {},
"outputs": [
{
"data": {
2021-07-09 02:05:18 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFzCAYAAAB2A95GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAC0x0lEQVR4nOydd5gkV3n1T6Wuzj05bM5Zu5JWCWWUEEECASIHYYKNCcYYbNnGIBM+MrYxUdgkEwSIHCRAGeW8q7g5zOzu5Omcqqvq++PWvXWru7qnJ+3M7N7f8+wzO93V3VXdPV2nz/ve80q2bUMgEAgEAoFAcHyR53oHBAKBQCAQCE5GhAgTCAQCgUAgmAOECBMIBAKBQCCYA4QIEwgEAoFAIJgDhAgTCAQCgUAgmAOECBMIBAKBQCCYA9S53oHJ0tHRYa9YsWKud0MgEAgEAoFgQh577LER27Y7/a5bcCJsxYoVePTRR+d6NwQCgUAgEAgmRJKkQ/WuE+VIgUAgEAgEgjlAiDCBQCAQCASCOUCIMIFAIBAIBII5YMH1hPlhGAb6+/tRLBbnelcEExAMBrFkyRJomjbXuyIQCAQCwZxyQoiw/v5+xGIxrFixApIkzfXuCOpg2zZGR0fR39+PlStXzvXuCAQCgUAwp5wQ5chisYj29nYhwOY5kiShvb1dOJYCgUAgEOAEEWEAhABbIIjXSSAQCAQCwgkjwuaSZDKJr33ta1O67Ute8hIkk8mZ3SGBQCAQCATzHiHCZoBGIqxSqTS87R/+8Ae0tLTMwl5ND9u2YVnWXO+GQCAQCAQnLEKEzQDXX3899u3bh1NPPRUf/vCHcdddd+GCCy7A1VdfjU2bNgEAXvGKV2D79u3YvHkzbrzxRnbbFStWYGRkBAcPHsTGjRvxzne+E5s3b8YVV1yBQqFQ81i//e1vcfbZZ+O0007DZZddhsHBQQBANpvF2972NpxyyinYunUrfv7znwMAbr31Vpx++unYtm0bLr30UgDADTfcgC984QvsPrds2YKDBw/i4MGDWL9+Pd7ylrdgy5Yt6Ovrw7vf/W6cccYZ2Lx5Mz72sY+x2zzyyCM499xzsW3bNpx11lnIZDK48MIL8eSTT7Jtzj//fOzYsWPmnmiBQCAQCE4gTojVkTz//ttn8OzR9Ize56ZFcXzsqs11r//MZz6Dp59+mgmQu+66C48//jiefvpptgrw29/+Ntra2lAoFHDmmWfiVa96Fdrb2z33s2fPHvz4xz/Gt771LbzmNa/Bz3/+c7zpTW/ybHP++efjwQcfhCRJ+J//+R987nOfwxe/+EV84hOfQCKRwFNPPQUAGB8fx/DwMN75znfinnvuwcqVKzE2Njbhse7Zswff+973cM455wAAPvWpT6GtrQ2maeLSSy/Fzp07sWHDBrz2ta/FT37yE5x55plIp9MIhUJ4+9vfju9+97v4z//8T+zevRvFYhHbtm1r+nkWCAQCgeBk4oQTYfOFs846yxPD8OUvfxm//OUvAQB9fX3Ys2dPjQhbuXIlTj31VADA9u3bcfDgwZr77e/vx2tf+1ocO3YM5XKZPcZtt92Gm266iW3X2tqK3/72t7jwwgvZNm1tbRPu9/Lly5kAA4Cf/vSnuPHGG1GpVHDs2DE8++yzkCQJvb29OPPMMwEA8XgcAHDttdfiE5/4BD7/+c/j29/+Nq677roJH08gEAgEgslwNFlAPKQhqi98CbPwj6CKRo7V8SQSibD/33XXXbjtttvwwAMPIBwO4+KLL/aNadB1nf1fURTfcuT73vc+fPCDH8TVV1+Nu+66CzfccMOk901VVU+/F78v/H4fOHAAX/jCF/DII4+gtbUV1113XcN4iXA4jMsvvxy//vWv8dOf/hSPPfbYpPdNIBAIBIJGvPbGB/CiTT34yMs2zfWuTBvREzYDxGIxZDKZutenUim0trYiHA7j+eefx4MPPjjlx0qlUli8eDEA4Hvf+x67/PLLL8dXv/pV9vv4+DjOOecc3HPPPThw4AAAsHLkihUr8PjjjwMAHn/8cXZ9Nel0GpFIBIlEAoODg7jlllsAAOvXr8exY8fwyCOPAAAymQxbgPCOd7wD73//+3HmmWeitbV1yscpEAgEAoEfg6kSjiRrTYqFiBBhM0B7ezvOO+88bNmyBR/+8Idrrr/yyitRqVSwceNGXH/99Z5y32S54YYbcO2112L79u3o6Ohgl3/kIx/B+Pg4tmzZgm3btuHOO+9EZ2cnbrzxRrzyla/Etm3b8NrXvhYA8KpXvQpjY2PYvHkzvvKVr2DdunW+j7Vt2zacdtpp2LBhA97whjfgvPPOAwAEAgH85Cc/wfve9z5s27YNl19+OXPItm/fjng8jre97W1TPkaBQCAQCPwoVUyUTQvj+fJc78qMINm2Pdf7MCnOOOMM+9FHH/Vc9txzz2Hjxo1ztEcCnqNHj+Liiy/G888/D1n21/ji9RIIBALBVBjLlXH6J/6MDT0x3PqBC+d6d5pCkqTHbNs+w+864YQJZozvf//7OPvss/GpT32qrgATCAQCgWCqZIuk9SWZN+Z4T2YGcaYUzBhvectb0NfXh2uvvXaud0UgEJwE2LaNV3z1Ptz69LG53pUTmv+8bTdW/fPvccoNf0TfWH7G7verd+7Fu38wuQVcmRIRX9MpR961awiXfPEuFMrmlO9jphAiTCAQCAQLkoJh4sm+JJ4+MrPZkAIvjxwcg6bIyBQreO7Y5J/rimnhM7c8j9FsyXP5k31J3PrMAJKTEFTUCStVrLoi6rZnB3HTw4fr3sdvnjyK/cM5HJ5BQTlVhAgTCAQCwYIk75yEi8bcOxonMgOpIrYuSQAAjqXqxxTVY9dgBt+4ex9uf37Ic3nRMGHbwIP7R5u+r2zJHQXo54ZligY+dPMOfPbW5+HX827bNu7bNwIAOJqa+xWWQoQJBAKBYEFCnZBSRcy5nU0G0yVsXpSApkhTEmHUvRrPeUUTff3u3zdzIux/7z2AZN7AeN7AcJXzBgD7R3IYTJPLB6ZwLDONEGECgUAgWJAUDOGEzTbZUgXZUgU9iSB6EkEcm4J7lHFE2Fi1CHNet/v2jkz6voDa5nzDtPC/fzmAxS0hAMDugWzN7e/nHuvYPMgaEyJsjohGowBIpMOrX/1q320uvvhiVMdxCAQCgYDAypHCCZs1BtPELeqO6+iNh3AsOXn3iDbT1xNh+4ZzGEwX8ejBMXzq988yUd03lsenb3kOluWWFXknbCBVxD/8dAfe8b1H8NsdR5EqGMiUKnjV9iUASBmUx7Zt3PbcEBa3hNATD+KocMIEixYtws033zzXu+ELTcEXCASC+UhB9ITNOq4IC6K3JYhj6ZlzwkqGhdWdZFTezv4UfvxwH771lwN45/cfRdEwcdtzg/jm3fs9vVtZzgm7Y9cQfv54P25/fgg/eugwe5wV7WF0RAPYNeAuIrBtGx/7zTO4e/cwXnPGUvS2BEU58kTh+uuv94wMuuGGG/CFL3wB2WwWl156KU4//XSccsop+PWvf11z24MHD2LLli0AgEKhgNe97nXYuHEjrrnmGt/ZkQDw8Y9/HGeeeSa2bNmCd73rXaz5cO/evbjsssuwbds2nH766di3bx8A4LOf/SxOOeUUbNu2Dddffz0Ar8s2MjKCFStWAAC++93v4uqrr8Yll1yCSy+9tOExfP/738fWrVuxbds2vPnNb0Ymk8HKlSthGORbTzqd9vwuEAgEM0nBICddIcImj23byJUqKFUaP3dUhPXESTlyIFX0OFPNwERYvtYJ27akBQCwayCN3YMZxIMq/rJnBHftGmJOGV+CzJYqCAcUAMBD+8kovs2L4siUDGSK5FwTC2pY1x3DrkG3HHnr0wP4/gOH8M4LVuL9l65BbyI4LxrzT7gB3rjlemDgqZm9z55TgBd/pu7Vr33ta/GBD3wA73nPewAAP/3pT/HHP/4RwWAQv/zlLxGPxzEyMoJzzjkHV199NSRJ8r2fr3/96wiHw3juueewc+dOnH766b7bvfe978VHP/pRAMCb3/xm/O53v8NVV12FN77xjbj++utxzTXXoFg
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure(figsize=(10,6))\n",
"ax = fig.add_subplot(1,1,1)\n",
"ax.plot(train_accs, label='train accuracy')\n",
"ax.plot(valid_accs, label='valid accuracy')\n",
"plt.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('accuracy');"
]
},
{
"cell_type": "code",
"execution_count": 35,
2021-07-09 02:05:18 +08:00
"id": "8b89f53f",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-07-09 02:05:18 +08:00
"evaluating...: 100%|██████████| 49/49 [00:06<00:00, 8.03it/s]\n",
"test_loss: 0.351, test_acc: 0.859\n"
2021-07-08 18:12:50 +08:00
]
}
],
"source": [
"model.load_state_dict(torch.load('lstm.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(test_dataloader, model, criterion, device)\n",
"\n",
"epoch_test_loss = np.mean(test_loss)\n",
"epoch_test_acc = np.mean(test_acc)\n",
"\n",
"print(f'test_loss: {epoch_test_loss:.3f}, test_acc: {epoch_test_acc:.3f}')"
2021-07-08 18:12:50 +08:00
]
},
{
"cell_type": "code",
"execution_count": 36,
2021-07-09 02:05:18 +08:00
"id": "c07df383",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
" tokens = tokenizer(text)\n",
" ids = [vocab[t] for t in tokens]\n",
" length = torch.LongTensor([len(ids)])\n",
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
" prediction = model(tensor, length).squeeze(dim=0)\n",
" probability = torch.softmax(prediction, dim=-1)\n",
" predicted_class = prediction.argmax(dim=-1).item()\n",
" predicted_probability = probability[predicted_class].item()\n",
" return predicted_class, predicted_probability"
]
},
{
"cell_type": "code",
"execution_count": 37,
2021-07-09 02:05:18 +08:00
"id": "8d9d591d",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-07-09 02:05:18 +08:00
"(0, 0.8874172568321228)"
2021-07-08 18:12:50 +08:00
]
},
"execution_count": 37,
2021-07-08 18:12:50 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": 38,
2021-07-09 02:05:18 +08:00
"id": "f392b05a",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-07-09 02:05:18 +08:00
"(1, 0.9508437514305115)"
2021-07-08 18:12:50 +08:00
]
},
"execution_count": 38,
2021-07-08 18:12:50 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": 39,
2021-07-09 02:05:18 +08:00
"id": "3196951d",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-07-09 02:05:18 +08:00
"(0, 0.5246995091438293)"
2021-07-08 18:12:50 +08:00
]
},
"execution_count": 39,
2021-07-08 18:12:50 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is not terrible, it's great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": 40,
2021-07-09 02:05:18 +08:00
"id": "c35aeb03",
2021-07-08 18:12:50 +08:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-07-09 02:05:18 +08:00
"(1, 0.5568666458129883)"
2021-07-08 18:12:50 +08:00
]
},
"execution_count": 40,
2021-07-08 18:12:50 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is not great, it's terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
2021-07-09 02:05:18 +08:00
}