1075 lines
82 KiB
Plaintext
1075 lines
82 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "891b5e34",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import functools\n",
|
||
|
"import sys\n",
|
||
|
"\n",
|
||
|
"import datasets\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import numpy as np\n",
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import torch.optim as optim\n",
|
||
|
"import torchtext\n",
|
||
|
"import tqdm\n",
|
||
|
"import transformers"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "895ef909",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<torch._C.Generator at 0x7fb36a9ab990>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"seed = 0\n",
|
||
|
"\n",
|
||
|
"torch.manual_seed(seed)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "98f4ab7c",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Reusing dataset imdb (/home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_data, test_data = datasets.load_dataset('imdb', split=['train', 'test'])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "d0a1e49f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"transformer_name = 'bert-base-uncased'\n",
|
||
|
"\n",
|
||
|
"tokenizer = transformers.AutoTokenizer.from_pretrained(transformer_name)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "4c814ae1",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"['hello', 'world', '!']"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer.tokenize('hello world!')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "ab3e4c32",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"[101, 7592, 2088, 999, 102]"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer.encode('hello world!')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "8f10453a",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"['[CLS]', 'hello', 'world', '[SEP]']"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer.convert_ids_to_tokens(tokenizer.encode('hello world'))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "824ba733",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'input_ids': [101, 7592, 2088, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer('hello world!')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"id": "9358a7aa",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def tokenize_and_numericalize_data(example, tokenizer):\n",
|
||
|
" ids = tokenizer(example['text'], truncation=True)['input_ids']\n",
|
||
|
" return {'ids': ids}"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "d0259875",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-2d2ddeb0d544c918.arrow\n",
|
||
|
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-1c96c2577c929948.arrow\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_data = train_data.map(tokenize_and_numericalize_data, fn_kwargs={'tokenizer': tokenizer})\n",
|
||
|
"test_data = test_data.map(tokenize_and_numericalize_data, fn_kwargs={'tokenizer': tokenizer})"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "e39e64b1",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'label': 1,\n",
|
||
|
" 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as \"Teachers\". My 35 years in the teaching profession lead me to believe that Bromwell High\\'s satire is much closer to reality than is \"Teachers\". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\\'t!',\n",
|
||
|
" 'ids': [101,\n",
|
||
|
" 22953,\n",
|
||
|
" 2213,\n",
|
||
|
" 4381,\n",
|
||
|
" 2152,\n",
|
||
|
" 2003,\n",
|
||
|
" 1037,\n",
|
||
|
" 9476,\n",
|
||
|
" 4038,\n",
|
||
|
" 1012,\n",
|
||
|
" 2009,\n",
|
||
|
" 2743,\n",
|
||
|
" 2012,\n",
|
||
|
" 1996,\n",
|
||
|
" 2168,\n",
|
||
|
" 2051,\n",
|
||
|
" 2004,\n",
|
||
|
" 2070,\n",
|
||
|
" 2060,\n",
|
||
|
" 3454,\n",
|
||
|
" 2055,\n",
|
||
|
" 2082,\n",
|
||
|
" 2166,\n",
|
||
|
" 1010,\n",
|
||
|
" 2107,\n",
|
||
|
" 2004,\n",
|
||
|
" 1000,\n",
|
||
|
" 5089,\n",
|
||
|
" 1000,\n",
|
||
|
" 1012,\n",
|
||
|
" 2026,\n",
|
||
|
" 3486,\n",
|
||
|
" 2086,\n",
|
||
|
" 1999,\n",
|
||
|
" 1996,\n",
|
||
|
" 4252,\n",
|
||
|
" 9518,\n",
|
||
|
" 2599,\n",
|
||
|
" 2033,\n",
|
||
|
" 2000,\n",
|
||
|
" 2903,\n",
|
||
|
" 2008,\n",
|
||
|
" 22953,\n",
|
||
|
" 2213,\n",
|
||
|
" 4381,\n",
|
||
|
" 2152,\n",
|
||
|
" 1005,\n",
|
||
|
" 1055,\n",
|
||
|
" 18312,\n",
|
||
|
" 2003,\n",
|
||
|
" 2172,\n",
|
||
|
" 3553,\n",
|
||
|
" 2000,\n",
|
||
|
" 4507,\n",
|
||
|
" 2084,\n",
|
||
|
" 2003,\n",
|
||
|
" 1000,\n",
|
||
|
" 5089,\n",
|
||
|
" 1000,\n",
|
||
|
" 1012,\n",
|
||
|
" 1996,\n",
|
||
|
" 25740,\n",
|
||
|
" 2000,\n",
|
||
|
" 5788,\n",
|
||
|
" 13732,\n",
|
||
|
" 1010,\n",
|
||
|
" 1996,\n",
|
||
|
" 12369,\n",
|
||
|
" 3993,\n",
|
||
|
" 2493,\n",
|
||
|
" 2040,\n",
|
||
|
" 2064,\n",
|
||
|
" 2156,\n",
|
||
|
" 2157,\n",
|
||
|
" 2083,\n",
|
||
|
" 2037,\n",
|
||
|
" 17203,\n",
|
||
|
" 5089,\n",
|
||
|
" 1005,\n",
|
||
|
" 13433,\n",
|
||
|
" 8737,\n",
|
||
|
" 1010,\n",
|
||
|
" 1996,\n",
|
||
|
" 9004,\n",
|
||
|
" 10196,\n",
|
||
|
" 4757,\n",
|
||
|
" 1997,\n",
|
||
|
" 1996,\n",
|
||
|
" 2878,\n",
|
||
|
" 3663,\n",
|
||
|
" 1010,\n",
|
||
|
" 2035,\n",
|
||
|
" 10825,\n",
|
||
|
" 2033,\n",
|
||
|
" 1997,\n",
|
||
|
" 1996,\n",
|
||
|
" 2816,\n",
|
||
|
" 1045,\n",
|
||
|
" 2354,\n",
|
||
|
" 1998,\n",
|
||
|
" 2037,\n",
|
||
|
" 2493,\n",
|
||
|
" 1012,\n",
|
||
|
" 2043,\n",
|
||
|
" 1045,\n",
|
||
|
" 2387,\n",
|
||
|
" 1996,\n",
|
||
|
" 2792,\n",
|
||
|
" 1999,\n",
|
||
|
" 2029,\n",
|
||
|
" 1037,\n",
|
||
|
" 3076,\n",
|
||
|
" 8385,\n",
|
||
|
" 2699,\n",
|
||
|
" 2000,\n",
|
||
|
" 6402,\n",
|
||
|
" 2091,\n",
|
||
|
" 1996,\n",
|
||
|
" 2082,\n",
|
||
|
" 1010,\n",
|
||
|
" 1045,\n",
|
||
|
" 3202,\n",
|
||
|
" 7383,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 2012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 1012,\n",
|
||
|
" 2152,\n",
|
||
|
" 1012,\n",
|
||
|
" 1037,\n",
|
||
|
" 4438,\n",
|
||
|
" 2240,\n",
|
||
|
" 1024,\n",
|
||
|
" 7742,\n",
|
||
|
" 1024,\n",
|
||
|
" 1045,\n",
|
||
|
" 1005,\n",
|
||
|
" 1049,\n",
|
||
|
" 2182,\n",
|
||
|
" 2000,\n",
|
||
|
" 12803,\n",
|
||
|
" 2028,\n",
|
||
|
" 1997,\n",
|
||
|
" 2115,\n",
|
||
|
" 5089,\n",
|
||
|
" 1012,\n",
|
||
|
" 3076,\n",
|
||
|
" 1024,\n",
|
||
|
" 6160,\n",
|
||
|
" 2000,\n",
|
||
|
" 22953,\n",
|
||
|
" 2213,\n",
|
||
|
" 4381,\n",
|
||
|
" 2152,\n",
|
||
|
" 1012,\n",
|
||
|
" 1045,\n",
|
||
|
" 5987,\n",
|
||
|
" 2008,\n",
|
||
|
" 2116,\n",
|
||
|
" 6001,\n",
|
||
|
" 1997,\n",
|
||
|
" 2026,\n",
|
||
|
" 2287,\n",
|
||
|
" 2228,\n",
|
||
|
" 2008,\n",
|
||
|
" 22953,\n",
|
||
|
" 2213,\n",
|
||
|
" 4381,\n",
|
||
|
" 2152,\n",
|
||
|
" 2003,\n",
|
||
|
" 2521,\n",
|
||
|
" 18584,\n",
|
||
|
" 2098,\n",
|
||
|
" 1012,\n",
|
||
|
" 2054,\n",
|
||
|
" 1037,\n",
|
||
|
" 12063,\n",
|
||
|
" 2008,\n",
|
||
|
" 2009,\n",
|
||
|
" 3475,\n",
|
||
|
" 1005,\n",
|
||
|
" 1056,\n",
|
||
|
" 999,\n",
|
||
|
" 102]}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"train_data[0]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "96d87205",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"999"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer.vocab['!']"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"id": "614d747b",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'[PAD]'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer.pad_token"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"id": "18f18b66",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer.pad_token_id"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"id": "2a0d0a6d",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 15,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"tokenizer.vocab[tokenizer.pad_token]"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 16,
|
||
|
"id": "3cdaa10f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"pad_index = tokenizer.pad_token_id"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"id": "8532b705",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Loading cached split indices for dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-0b46d41b3a9a5f87.arrow and /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-83b98d9279b85695.arrow\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"test_size = 0.25\n",
|
||
|
"\n",
|
||
|
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
|
||
|
"train_data = train_valid_data['train']\n",
|
||
|
"valid_data = train_valid_data['test']"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 18,
|
||
|
"id": "cc54e1eb",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"train_data = train_data.with_format(type='torch', columns=['ids', 'label'])\n",
|
||
|
"valid_data = valid_data.with_format(type='torch', columns=['ids', 'label'])\n",
|
||
|
"test_data = test_data.with_format(type='torch', columns=['ids', 'label'])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"id": "c7bf1626",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stderr",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']\n",
|
||
|
"- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
||
|
"- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"transformer = transformers.AutoModel.from_pretrained(transformer_name)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 20,
|
||
|
"id": "269344e7",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"768"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 20,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"transformer.config.hidden_size"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 21,
|
||
|
"id": "01a637ac",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"class Transformer(nn.Module):\n",
|
||
|
" def __init__(self, transformer, output_dim, freeze):\n",
|
||
|
" super().__init__()\n",
|
||
|
" self.transformer = transformer\n",
|
||
|
" hidden_dim = transformer.config.hidden_size\n",
|
||
|
" self.fc = nn.Linear(hidden_dim, output_dim)\n",
|
||
|
" \n",
|
||
|
" if freeze:\n",
|
||
|
" for param in self.transformer.parameters():\n",
|
||
|
" param.requires_grad = False\n",
|
||
|
" \n",
|
||
|
" def forward(self, ids):\n",
|
||
|
" # ids = [batch size, seq len]\n",
|
||
|
" output = self.transformer(ids, output_attentions=True)\n",
|
||
|
" hidden = output.last_hidden_state\n",
|
||
|
" # hidden = [batch size, seq len, hidden dim]\n",
|
||
|
" attention = output.attentions[-1]\n",
|
||
|
" # attention = [batch size, n heads, seq len, seq len]\n",
|
||
|
" cls_hidden = hidden[:,0,:]\n",
|
||
|
" prediction = self.fc(torch.tanh(cls_hidden))\n",
|
||
|
" # prediction = [batch size, output dim]\n",
|
||
|
" return prediction"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 22,
|
||
|
"id": "ff995192",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"output_dim = len(train_data['label'].unique())\n",
|
||
|
"freeze = False\n",
|
||
|
"\n",
|
||
|
"model = Transformer(transformer, output_dim, freeze)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 23,
|
||
|
"id": "8e2f95bd",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"The model has 109,483,778 trainable parameters\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"def count_parameters(model):\n",
|
||
|
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||
|
"\n",
|
||
|
"print(f'The model has {count_parameters(model):,} trainable parameters')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"id": "d2a9f4f7",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"lr = 1e-5\n",
|
||
|
"\n",
|
||
|
"optimizer = optim.Adam(model.parameters(), lr=lr)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 25,
|
||
|
"id": "e82c640f",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"criterion = nn.CrossEntropyLoss()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"id": "0bb2a2a4",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"id": "6fc62ff8",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"model = model.to(device)\n",
|
||
|
"criterion = criterion.to(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"id": "63ca51df",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def collate(batch, pad_index):\n",
|
||
|
" batch_ids = [i['ids'] for i in batch]\n",
|
||
|
" batch_ids = nn.utils.rnn.pad_sequence(batch_ids, padding_value=pad_index, batch_first=True)\n",
|
||
|
" batch_label = [i['label'] for i in batch]\n",
|
||
|
" batch_label = torch.stack(batch_label)\n",
|
||
|
" batch = {'ids': batch_ids,\n",
|
||
|
" 'label': batch_label}\n",
|
||
|
" return batch"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 29,
|
||
|
"id": "611063db",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"batch_size = 8\n",
|
||
|
"\n",
|
||
|
"collate = functools.partial(collate, pad_index=pad_index)\n",
|
||
|
"\n",
|
||
|
"train_dataloader = torch.utils.data.DataLoader(train_data, \n",
|
||
|
" batch_size=batch_size, \n",
|
||
|
" collate_fn=collate, \n",
|
||
|
" shuffle=True)\n",
|
||
|
"\n",
|
||
|
"valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, collate_fn=collate)\n",
|
||
|
"test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collate)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 30,
|
||
|
"id": "98f54638",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def train(dataloader, model, criterion, optimizer, device):\n",
|
||
|
"\n",
|
||
|
" model.train()\n",
|
||
|
" epoch_losses = []\n",
|
||
|
" epoch_accs = []\n",
|
||
|
"\n",
|
||
|
" for batch in tqdm.tqdm(dataloader, desc='training...', file=sys.stdout):\n",
|
||
|
" ids = batch['ids'].to(device)\n",
|
||
|
" label = batch['label'].to(device)\n",
|
||
|
" prediction = model(ids)\n",
|
||
|
" loss = criterion(prediction, label)\n",
|
||
|
" accuracy = get_accuracy(prediction, label)\n",
|
||
|
" optimizer.zero_grad()\n",
|
||
|
" loss.backward()\n",
|
||
|
" optimizer.step()\n",
|
||
|
" epoch_losses.append(loss.item())\n",
|
||
|
" epoch_accs.append(accuracy.item())\n",
|
||
|
"\n",
|
||
|
" return epoch_losses, epoch_accs"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 31,
|
||
|
"id": "df0424bd",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def evaluate(dataloader, model, criterion, device):\n",
|
||
|
" \n",
|
||
|
" model.eval()\n",
|
||
|
" epoch_losses = []\n",
|
||
|
" epoch_accs = []\n",
|
||
|
"\n",
|
||
|
" with torch.no_grad():\n",
|
||
|
" for batch in tqdm.tqdm(dataloader, desc='evaluating...', file=sys.stdout):\n",
|
||
|
" ids = batch['ids'].to(device)\n",
|
||
|
" label = batch['label'].to(device)\n",
|
||
|
" prediction = model(ids)\n",
|
||
|
" loss = criterion(prediction, label)\n",
|
||
|
" accuracy = get_accuracy(prediction, label)\n",
|
||
|
" epoch_losses.append(loss.item())\n",
|
||
|
" epoch_accs.append(accuracy.item())\n",
|
||
|
"\n",
|
||
|
" return epoch_losses, epoch_accs"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 32,
|
||
|
"id": "34331854",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def get_accuracy(prediction, label):\n",
|
||
|
" batch_size, _ = prediction.shape\n",
|
||
|
" predicted_classes = prediction.argmax(dim=-1)\n",
|
||
|
" correct_predictions = predicted_classes.eq(label).sum()\n",
|
||
|
" accuracy = correct_predictions / batch_size\n",
|
||
|
" return accuracy"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 33,
|
||
|
"id": "df33ac5d",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"training...: 100%|██████████| 2344/2344 [06:41<00:00, 5.84it/s]\n",
|
||
|
"evaluating...: 100%|██████████| 782/782 [00:40<00:00, 19.24it/s]\n",
|
||
|
"epoch: 1\n",
|
||
|
"train_loss: 0.249, train_acc: 0.894\n",
|
||
|
"valid_loss: 0.196, valid_acc: 0.925\n",
|
||
|
"training...: 100%|██████████| 2344/2344 [06:37<00:00, 5.89it/s]\n",
|
||
|
"evaluating...: 100%|██████████| 782/782 [00:40<00:00, 19.26it/s]\n",
|
||
|
"epoch: 2\n",
|
||
|
"train_loss: 0.135, train_acc: 0.950\n",
|
||
|
"valid_loss: 0.202, valid_acc: 0.927\n",
|
||
|
"training...: 100%|██████████| 2344/2344 [06:36<00:00, 5.91it/s]\n",
|
||
|
"evaluating...: 100%|██████████| 782/782 [00:40<00:00, 19.25it/s]\n",
|
||
|
"epoch: 3\n",
|
||
|
"train_loss: 0.067, train_acc: 0.978\n",
|
||
|
"valid_loss: 0.249, valid_acc: 0.929\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"n_epochs = 3\n",
|
||
|
"best_valid_loss = float('inf')\n",
|
||
|
"\n",
|
||
|
"train_losses = []\n",
|
||
|
"train_accs = []\n",
|
||
|
"valid_losses = []\n",
|
||
|
"valid_accs = []\n",
|
||
|
"\n",
|
||
|
"for epoch in range(n_epochs):\n",
|
||
|
"\n",
|
||
|
" train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)\n",
|
||
|
" valid_loss, valid_acc = evaluate(valid_dataloader, model, criterion, device)\n",
|
||
|
"\n",
|
||
|
" train_losses.extend(train_loss)\n",
|
||
|
" train_accs.extend(train_acc)\n",
|
||
|
" valid_losses.extend(valid_loss)\n",
|
||
|
" valid_accs.extend(valid_acc)\n",
|
||
|
" \n",
|
||
|
" epoch_train_loss = np.mean(train_loss)\n",
|
||
|
" epoch_train_acc = np.mean(train_acc)\n",
|
||
|
" epoch_valid_loss = np.mean(valid_loss)\n",
|
||
|
" epoch_valid_acc = np.mean(valid_acc)\n",
|
||
|
" \n",
|
||
|
" if epoch_valid_loss < best_valid_loss:\n",
|
||
|
" best_valid_loss = epoch_valid_loss\n",
|
||
|
" torch.save(model.state_dict(), 'transformer.pt')\n",
|
||
|
" \n",
|
||
|
" print(f'epoch: {epoch+1}')\n",
|
||
|
" print(f'train_loss: {epoch_train_loss:.3f}, train_acc: {epoch_train_acc:.3f}')\n",
|
||
|
" print(f'valid_loss: {epoch_valid_loss:.3f}, valid_acc: {epoch_valid_acc:.3f}')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 34,
|
||
|
"id": "8ac2a935",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmcAAAFzCAYAAAB7Ha4BAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABcV0lEQVR4nO3dd5hU1f0G8Pe7BZYmIKBRARcSo4AUBRWDikZjw2iMJmKLKcZoTExikp+YGIMdSyzYUbGLBUVBQIp0qUvvsCwL7C5s731nz++PubvMzk65M3PrzPt5HmXn1nPP3PKdc+45R5RSICIiIiJnSLI7AURERER0FIMzIiIiIgdhcEZERETkIAzOiIiIiByEwRkRERGRgzA4IyIiInKQFLsTYKTevXur9PR0u5NBREREFNb69euLlFJ9/KfHVXCWnp6OjIwMu5NBREREFJaIHAg0ndWaRERERA7C4IyIiIjIQRicERERETlIXL1zRkRERMZpbGxETk4O6urq7E6Kq6WlpaFv375ITU3VtTyDMyIiIgooJycH3bp1Q3p6OkTE7uS4klIKxcXFyMnJwYABA3Stw2pNIiIiCqiurg69evViYBYDEUGvXr0iKn1kcEZERERBMTCLXaR5yOCMiIiIHKmsrAyvvPJKVOteeeWVKCsr0738xIkT8cwzz0S1L6MxOCMiIiJHChWcNTU1hVx3zpw56NGjhwmpMh+DMyIiInKkCRMmYN++fRgxYgT++c9/YsmSJTj//PNx9dVXY/DgwQCAn/3sZxg5ciSGDBmCKVOmtK6bnp6OoqIiZGdnY9CgQfj973+PIUOG4NJLL0VtbW3I/W7atAmjR4/GsGHDcO2116K0tBQAMHnyZAwePBjDhg3D+PHjAQBLly7FiBEjMGLECJxxxhmorKyM+bjZWpOIiIjCemjWduzIqzB0m4NPPAb//emQoPMnTZqEbdu2YdOmTQCAJUuWYMOGDdi2bVtry8epU6fi2GOPRW1tLc466yxcd9116NWrV5vt7N27F9OmTcMbb7yBX/7yl/j8889xyy23BN3vr371K7z44osYO3YsHnzwQTz00EN4/vnnMWnSJOzfvx8dO3ZsrTJ95pln8PLLL2PMmDGoqqpCWlpabJkClpwRuY+nESjeZ3cqiIhscfbZZ7fpkmLy5MkYPnw4Ro8ejUOHDmHv3r3t1hkwYABGjBgBABg5ciSys7ODbr+8vBxlZWUYO3YsAOC2227DsmXLAADDhg3DzTffjA8++AApKd7yrTFjxuDee+/F5MmTUVZW1jo9Fiw5I3Kb+Q8Aa14D/rYD6H6S3akhogQRqoTLSl26dGn9e8mSJVi4cCFWrVqFzp0748ILLwzYZUXHjh1b/05OTg5brRnM7NmzsWzZMsyaNQuPPfYYtm7digkTJmDcuHGYM2cOxowZg3nz5uG0006LavstWHJG5DbZK7z/1pbYmw4iIpN169Yt5Dtc5eXl6NmzJzp37oxdu3Zh9erVMe+ze/fu6NmzJ5YvXw4AeP/99zF27Fg0Nzfj0KFDuOiii/Dkk0+ivLwcVVVV2LdvH4YOHYr77rsPZ511Fnbt2hVzGlhyRkRERI7Uq1cvjBkzBqeffjquuOIKjBs3rs38yy+/HK+99hoGDRqEU089FaNHjzZkv++++y7uvPNO1NTUYODAgXj77bfh8Xhwyy23oLy8HEop3HPPPejRowf+85//YPHixUhKSsKQIUNwxRVXxLx/UUoZcBjOMGrUKJWRkWF3MojM9eoYIH8bcOcK4HtD7U4NEcWxnTt3YtCgQXYnIy4EyksRWa+UGuW/LKs1iYiIiByEwRkRERGRgzA4IyIiInIQBmdEREREDsLgjBJTUwPg9sYwbk8/EREFxOCMEo+nEXi0j7czV1cSuxNAREQmYnBGiadJ6z16/Tu2JoOIiIzXtWtXAEBeXh6uv/76gMtceOGFCNT1VrDpVmNwRmSnid2BT39ldyqIiOLOiSeeiOnTp9udjKgwOCOy246v7E4BEZEjTZgwAS+//HLr54kTJ+KZZ55BVVUVLr74Ypx55pkYOnQovvqq/X00Ozsbp59+OgCgtrYW48ePx6BBg3DttdfqGltz2rRpGDp0KE4//XTcd999AACPx4Nf//rXOP300zF06FA899xzALyDrw8ePBjDhg3D+PHjYz5uDt9ERERE4c2dABzZauw2vzcUuGJS0Nk33HAD/vrXv+Luu+8GAHz66aeYN28e0tLSMGPGDBxzzDEoKirC6NGjcfXVV0Mk8Du5r776Kjp37oydO3diy5YtOPPMM0MmKy8vD/fddx/Wr1+Pnj174tJLL8WXX36Jfv36ITc3F9u2bQMAlJWVAQAmTZqE/fv3o2PHjq3TYsGSMyIiInKkM844AwUFBcjLy8PmzZvRs2dP9OvXD0op/Otf/8KwYcNwySWXIDc3F/n5+UG3s2zZMtxyyy0AgGHDhmHYsGEh97tu3TpceOGF6NOnD1JSUnDzzTdj2bJlGDhwILKysvDnP/8Z33zzDY455pjWbd5888344IMPkJISe7kXS86IiIgovBAlXGb6xS9+genTp+PIkSO44YYbAAAffvghCgsLsX79eqSmpiI9PR11dXWmp6Vnz57YvHkz5s2bh9deew2ffvoppk6ditmzZ2PZsmWYNWsWHnvsMWzdujWmII0lZ5S42E8YEZHj3XDDDfj4448xffp0/OIXvwAAlJeX47jjjkNqaioWL16MAwcOhNzGBRdcgI8++ggAsG3bNmzZsiXk8meffTaWLl2KoqIieDweTJs2DWPHjkVRURGam5tx3XXX4dFHH8WGDRvQ3NyMQ4cO4aKLLsKTTz6J8vJyVFVVxXTMLDmjBMR+woiI3GLIkCGorKzESSedhBNOOAEAcPPNN+OnP/0phg4dilGjRuG0004LuY277roLv/nNbzBo0CAMGjQII0eODLn8CSecgEmTJuGiiy6CUgrjxo3DNddcg82bN+M3v/kNmpubAQBPPPEEPB4PbrnlFpSXl0MphXvuuQc9evSI6ZhFmVR6ICJTAVwFoEApdXqA+f8EcLP2MQXAIAB9lFIlIpINoBKAB0CTUmqUnn2OGjVKOaF/EnK4+irgiZOA1C7Av/PsTcvE7tq/5frXefU8IH8r8IflwAmh35sgIorFzp07MWjQILuTERcC5aWIrA8U45hZrfkOgMuDzVRKPa2UGqGUGgHgfgBLlVIlPotcpM3XFZgROdLuucDeBXangoiIXMS0ak2l1DIRSde5+I0AppmVFiLbTNP6u4mkZIyIiBKa7Q0CRKQzvCVsn/tMVgDmi8h6EbnDnpQRERERWc8JDQJ+CuA7vyrN85RSuSJyHIAFIrJLKbUs0Mpa8HYHAPTv39/81BIRESUQpVTQzl1Jn0jf77e95AzAePhVaSqlcrV/CwDMAHB2sJWVUlOUUqOUUqP69OljakKJiIgSSVpaGoqLiyMOLugopRSKi4uRlpamex1bS85EpDuAsQBu8ZnWBUCSUqpS+/tSAA/blEQiB+PNkojM1bdvX+Tk5KCwsNDupLhaWloa+vbtq3t504IzEZkG4EIAvUUkB8B/AaQCgFLqNW2xawHMV0pV+6x6PIAZWhFqCoCPlFLfmJVOItdh7QIRWSQ1NRUDBgywOxkJx8zWmjfqWOYdeLvc8J2WBWC4OakiIiIicjYnvHNGRERERBoGZ0REREQOwuCMiIiIyEEYnFECY2tHIiJyHgZnlHjYmSIRETkYgzMit2KnkEREcYnBGREREZGDMDgjcitWzxIRxSUGZ0REREQOwuCMiIiIyEEYnBERERE5CIMzIiIiIgdhcEZERETkIAzOiNyK/ZwREcUlBmdErsMuNIiI4hmDMyIiIiIHYXBGiYvVgkRE5EAMzigBsVqQiIici8EZERERkYMwOCMiIiJyEAZnRERERA7C4IzItdiggYgoHjE4IyIiInIQBmdErsVWp0RE8YjBGREREZGDMDgjIiIichAGZ0REREQOwuCMiIiIyEEYnBERERE5CIMzItdiP2dERPGIwRklMJcGN8IuNIiI4plpwZmITBWRAhH
|
||
|
"text/plain": [
|
||
|
"<Figure size 720x432 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"fig = plt.figure(figsize=(10,6))\n",
|
||
|
"ax = fig.add_subplot(1,1,1)\n",
|
||
|
"ax.plot(train_losses, label='train loss')\n",
|
||
|
"ax.plot(valid_losses, label='valid loss')\n",
|
||
|
"plt.legend()\n",
|
||
|
"ax.set_xlabel('updates')\n",
|
||
|
"ax.set_ylabel('loss');"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 35,
|
||
|
"id": "8796527d",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFzCAYAAAB2A95GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABIl0lEQVR4nO3deZwcdZ3/8ddn7vvK5J4kEyD3JEOSCSCXgXAjEQREDlmRwxN0WVyjIsbzxyLub9ddxR+6gLhyiavcoCiIrqIkAWI4k5BAEkIyOWYyyRyZ4/v7o3smPTM9Pd0zXV1dM+8njyHdVdXf+ta3qqs+/f1+61vmnENEREREUivD7wyIiIiIjEYKwkRERER8oCBMRERExAcKwkRERER8oCBMRERExAcKwkRERER8kOV3BhJVWVnpqqur/c6GiIiIyKBWr169yzk3Ntq8wAVh1dXVrFq1yu9siIiIiAzKzN4eaJ6aI0VERER8oCBMRERExAcKwkRERER8oCBMRERExAcKwkRERER8oCBMRERExAcKwkRERER8oCBMRERExAcKwkRERER84FkQZmZ3mNlOM1s3wHwzs++b2QYzW2tmi7zKi4iIiEi68bIm7C7gjBjzzwRmhP+uAW7zMC8iIiIiacWzZ0c6554zs+oYi3wQuNs554DnzazMzCY657Z7lae47N8Jf/lP2LYGxhwBBRVgGdDews4Wo6D+JbIbN/FqyQlktTXQvn83kzP3si5rHmNKi9mzaweu4gi27qwnv3Qcubl5zG9/mY6sYjZQRfbeDWRlZ7OncAYz7R0A9u1rpKCggL0tnRzjXmY/hezJmcyrGYexoOEZ3siZx/RS40BLK0WZHbQUV2MFFWS17OKdA1mUNb1JbmU1VlhB5v4dZLc38U5bEXNtE23lM9hYv5/W3ErGt26io2gS+bvWMi6vi505U2jPr6Qsq51c2mHnK1hOEZRMpqLh77S6TN7Om0NHRh7j27eyP7OUPRSzvyOL7K5W9maUU9b2LhM7tpJTMo49LV3kFRZj7QfYZNPIO7iLaVkN5GV2sa+5jVK3D1dxGMVNm8jLz2cHFXS2NuGyC9lZvpiJrevZ05FHea5jz94GWrNLaWk+QEXFGFxGFrvaMinr2M34jL20t7XS1dJAQWEJnVn5PJ+5mCML6mnfV09OTi45GZ3s3NNIcX4OXWRQ3NnAwYw8XFYBhV1NdBVPJHv3G7iMLPZnllKc1cHuwpkc2LmZzJx8DlgBB7uMvPwCdrkS5re9SMO+fewqnEl2cSWZHc0UuyYqczpocvmU7HuT+typNDQ2clLGGnaPWcxuV8Lu5k4m2m7eZSwFOZkcyBlH1sFGKtrfozUjnyP2r2ZD4ULmNK9mY9FimnInQE4B7sAuig/Wk9HZijlHW1EVb9lkclt2MYY91GdN5hd5H2JSaT45WRk0H+xk1oRiDrR1MLEsnx2Nrbzb2MLOfW2MK8nlvcZWZo4v5rXt+1hQVUp7p2POxGLqm9rYc6CdptZ2Wju6KMzJpMs51u/cz8Ip5XQ5x9a9LRTmZpKXlUlBbiYZZkwuy6cgJxMHbN3bzOvbm5gzsYTXtu+jJD+b7ExjbHEur29v4qz5E8nOymDN23t5c0cTpfnZ7G0+yM59bdRVl1Oan82OfW2UFWSTk5nBjqY29rW009DSzpjCHKZXFvJeYysl+Vl0dDq2NbRw5JQytjW0MLksny4HL29pYMe+Vk6vmUB2ZgYvvrMXgNkTinm3oZUpFQU0trTT3tlFQ/NBZk8oocs5CnOzKM3PZvXbe8nONF7dvo8zayay58BBDrR10NjSzs6mNqrHFLKtoZmq8gI6OruoKMwlK9PIy8rghc17GV+Sy4zxxbzX2Mo7e5pZv6OJeZNLcQ5yszNwzlGQk0VZfjbbGlrY39ZBSV42Y4tz2bW/jcqiXFrbOyktyOb5t/ZQVZ5PUU4WxXlZbG9sZVJZHnnZmax5Zy9VZaFtmVCaR0FOJh1djtb2Thqa26kozGHX/jbGFecxZ2Ixr7/XRGVRLn/f1kB2ZgYHO7rIycrgqOkV/OaVHXR0deEc1Ewu5aUtDeRmZfCBBRNZ/fZemlo7WDi1jHXb9rG9sYUJpflMrcinobmdbQ0tTK8spLmtk6ryfDbW76eyKBcXLvO/btpDU2s7p8+bwN7mdgAqi3JYt62RHftC23vCjEow2LmvlfycLDbVH8DhOHxsES9s3kNbexdFeVkU5GSya/9BDnZ0MWdiMWbGzqZW3m1oxTnHzPHF7GtpZ3tjK1UV+VSV5fPOnmaOmj6GF9/ZS1tHF/vC5bV1b6jsx5fkctKscbz+XhM797Vy+LgiDrR1Mq44l8LcLP68cReHjy3ite37aD7YSadzTCjJwznY23yQKRUFtHV0YhgTS/PY1tBCXlYGW/a2UDO5lM6uLloOdlFRGDq2KwpzaGnvpPlgB/taOpgzsYSKwmw6uhyPrd3OsUdUUpafTcvBTjbvPsCBttAya7c1Ul6QTW5WJm/uaKI4L4vS/JyedRpQXVlIeUEOOVnGn9bvZtOu/Rw/YyxFuZmU5uewsX5/6NLW1sGOxlb2tbZzxLjQueLtPc3MnlBMUW4WM8cX89au/azb1kjN5FLyszPp7HI0tXYwb1IJABvrD7B7fxs5WRn87rWdXHbMVDbvbmbm+CJe3tpIXlYmf9/WwLxJpexv6+BgRxcTy/Lo6nJMHVPI9oYW5k0q5c8bd1EzuZQ3dzTR2NxOYW4Wi6aVsWlXM7lZGbz4TgPzJpXw5o4m8rMzWTStnE27DlCYk8n+tk6K87Jo6+giJ9PIzMggMwN2HzhIY3M7cyeVUJyXxZq3G8jPyaQ0P5vMDGPPgYOUFWSza38bOZmZ7NrfRobBP546k2ljClMdafSwUAzkUeKhIOxR51xNlHmPAjc75/4Ufv874IvOuX5P5zazawjVljF16tTFb7894LMwh2/N3fDwtd6lL5IkX2y/mvs7T/I7GyIigXXGvAn86KOLPV2Hma12ztVFmxeIjvnOududc3XOubqxY8d6u7KDB7xNXyRJlmS84XcWRIastqqUe6462u9syCh34GCHr+v3MwjbBkyJeF8VnuYz8zsDInHxsBJbxHNTKgpYMKXM72yMGFkZunYNhZm/5eZnEPYwcHn4LsljgEbf+4OJBIjTDwYJsAyfL34i4H+1i2cd883sXmApUGlmW4GvAdkAzrkfAY8DZwEbgGbgCq/yIiIi6cXM/wugiN+/Bby8O/LiQeY74DNerX/I/N4jIiKjgM60yaVL19D4XWyB6JgvIv2pOVKCLMNMgUMSqY/o0PjdLK4grB+dFSQYdM6VQNOpVtKA3z8EFISJBJRqwiTITMevpAXVhKUZ1S9IMOhIlSALdcxXICb+Uk1YulHDuoiI5zSslYiCMJEA01VMgstQx3wRBWEiIpJyCsAkHfh9GCoIEwkodcyXIFMQllwqz2BSENaP+oSJiHhPUUMyqTvz0PgdvCoIEwkonXMlyNQxX0RBWH/6OSEBoeZICTIz/2shRPweJkVBmEhAKQgTEQk2BWH9qCZMRMRrFv5PkkNXrqHxuzZWQVhfao6UgNCRKiLdnK5dgaQgTEREUk59wkQUhIkElvqEiYgMj98/BBSEiQSUgjAR6abGyGBSENaPDmURkVTQz4jkUZewofH75hAFYSIiknIKwEQUhPWnnxMiIilhfnfIEVGfMBERGW0UgEk68PsoVBDWj2rCRERSwe8LoIjfFISJiIiI+EBBmIiI+EItkjLaKQjrSx3zRURERgW/+yYqCBMRkZRTLZikA78PQwVh/agmTILBdKxKwPldCyHiNwVhIiKScn6PVC4C/tfIKgjrS33CREREJAUUhPWjIExExGt+10CIpAMFYSIiIiI+UBAmElDqmC8iMjx+V8gqCBMRkZTz++InAv7foasgrC91zBcREZEUUBAmIiIpp475kg78PgwVhPWjmjAJBr9PHiIiMjwKwvrKKfY7ByJxaSPb7yyIDFlxno7fZMnLzqAwJ9PvbARSSb6/x6GCsL6WXAlZ+X7nQmRQf+iqjTq9ZnJJXJ8fU5iTzOw
|
||
|
"text/plain": [
|
||
|
"<Figure size 720x432 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"fig = plt.figure(figsize=(10,6))\n",
|
||
|
"ax = fig.add_subplot(1,1,1)\n",
|
||
|
"ax.plot(train_accs, label='train accuracy')\n",
|
||
|
"ax.plot(valid_accs, label='valid accuracy')\n",
|
||
|
"plt.legend()\n",
|
||
|
"ax.set_xlabel('updates')\n",
|
||
|
"ax.set_ylabel('accuracy');"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 36,
|
||
|
"id": "c9f36818",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"evaluating...: 100%|██████████| 3125/3125 [02:45<00:00, 18.89it/s]\n",
|
||
|
"test_loss: 0.178, test_acc: 0.932\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"model.load_state_dict(torch.load('transformer.pt'))\n",
|
||
|
"\n",
|
||
|
"test_loss, test_acc = evaluate(test_dataloader, model, criterion, device)\n",
|
||
|
"\n",
|
||
|
"epoch_test_loss = np.mean(test_loss)\n",
|
||
|
"epoch_test_acc = np.mean(test_acc)\n",
|
||
|
"\n",
|
||
|
"print(f'test_loss: {epoch_test_loss:.3f}, test_acc: {epoch_test_acc:.3f}')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"id": "fca30739",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def predict_sentiment(text, model, tokenizer, device):\n",
|
||
|
" ids = tokenizer(text)['input_ids']\n",
|
||
|
" tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
|
||
|
" prediction = model(tensor).squeeze(dim=0)\n",
|
||
|
" probability = torch.softmax(prediction, dim=-1)\n",
|
||
|
" predicted_class = prediction.argmax(dim=-1).item()\n",
|
||
|
" predicted_probability = probability[predicted_class].item()\n",
|
||
|
" return predicted_class, predicted_probability"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 38,
|
||
|
"id": "bd35e378",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(0, 0.9898680448532104)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 38,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"text = \"This film is terrible!\"\n",
|
||
|
"\n",
|
||
|
"predict_sentiment(text, model, tokenizer, device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 39,
|
||
|
"id": "b53900c7",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(1, 0.9934925436973572)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 39,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"text = \"This film is great!\"\n",
|
||
|
"\n",
|
||
|
"predict_sentiment(text, model, tokenizer, device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 40,
|
||
|
"id": "ad677cd1",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(1, 0.99724280834198)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 40,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"text = \"This film is not terrible, it's great!\"\n",
|
||
|
"\n",
|
||
|
"predict_sentiment(text, model, tokenizer, device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 41,
|
||
|
"id": "e7c35353",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(0, 0.9856145977973938)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 41,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"text = \"This film is not great, it's terrible!\"\n",
|
||
|
"\n",
|
||
|
"predict_sentiment(text, model, tokenizer, device)"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.9.5"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|