added transformer notebook results
This commit is contained in:
parent
96fe748dac
commit
9ae32cf124
@ -2,9 +2,26 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I1105 23:37:22.746299 140101965489984 file_utils.py:39] PyTorch version 1.3.0 available.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.3.0\n",
|
||||
"0.4.0\n",
|
||||
"2.1.1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torchtext\n",
|
||||
@ -17,7 +34,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -45,27 +62,51 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I1105 23:37:23.678925 140101965489984 tokenization_utils.py:374] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/ben/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"512\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(tokenizer.max_model_input_sizes['bert-base-uncased'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['hello', 'world', 'how', 'are', 'you', '?']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokens = tokenizer.tokenize('Hello WORLD how ARE yoU?')\n",
|
||||
"\n",
|
||||
@ -74,9 +115,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[7592, 2088, 2129, 2024, 2017, 1029]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"indexes = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||
"\n",
|
||||
@ -85,9 +134,20 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"101\n",
|
||||
"102\n",
|
||||
"0\n",
|
||||
"100\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"init_token = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)\n",
|
||||
"eos_token = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)\n",
|
||||
@ -102,7 +162,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -112,7 +172,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -130,7 +190,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -141,9 +201,19 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of training examples: 17500\n",
|
||||
"Number of validation examples: 7500\n",
|
||||
"Number of testing examples: 25000\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(f\"Number of training examples: {len(train_data)}\")\n",
|
||||
"print(f\"Number of validation examples: {len(valid_data)}\")\n",
|
||||
@ -152,18 +222,34 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'text': [5949, 1997, 2026, 2166, 1010, 1012, 1012, 1012, 1012, 1996, 2472, 2323, 2022, 10339, 1012, 2339, 2111, 2514, 2027, 2342, 2000, 2191, 22692, 5691, 2097, 2196, 2191, 3168, 2000, 2033, 1012, 2043, 2016, 2351, 2012, 1996, 2203, 1010, 2009, 2081, 2033, 4756, 1012, 1045, 2018, 2000, 2689, 1996, 3149, 2116, 2335, 2802, 1996, 2143, 2138, 1045, 2001, 2893, 10339, 3666, 2107, 3532, 3772, 1012, 11504, 1996, 3124, 2040, 2209, 9895, 2196, 4152, 2147, 2153, 1012, 2006, 2327, 1997, 2008, 1045, 3246, 1996, 2472, 2196, 4152, 2000, 2191, 2178, 2143, 1010, 1998, 2038, 2010, 3477, 5403, 3600, 2579, 2067, 2005, 2023, 10231, 1012, 1063, 1012, 6185, 2041, 1997, 2184, 1065], 'label': 'neg'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(vars(train_data.examples[6]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['waste', 'of', 'my', 'life', ',', '.', '.', '.', '.', 'the', 'director', 'should', 'be', 'embarrassed', '.', 'why', 'people', 'feel', 'they', 'need', 'to', 'make', 'worthless', 'movies', 'will', 'never', 'make', 'sense', 'to', 'me', '.', 'when', 'she', 'died', 'at', 'the', 'end', ',', 'it', 'made', 'me', 'laugh', '.', 'i', 'had', 'to', 'change', 'the', 'channel', 'many', 'times', 'throughout', 'the', 'film', 'because', 'i', 'was', 'getting', 'embarrassed', 'watching', 'such', 'poor', 'acting', '.', 'hopefully', 'the', 'guy', 'who', 'played', 'heath', 'never', 'gets', 'work', 'again', '.', 'on', 'top', 'of', 'that', 'i', 'hope', 'the', 'director', 'never', 'gets', 'to', 'make', 'another', 'film', ',', 'and', 'has', 'his', 'pay', '##che', '##ck', 'taken', 'back', 'for', 'this', 'crap', '.', '{', '.', '02', 'out', 'of', '10', '}']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[6])['text'])\n",
|
||||
"\n",
|
||||
@ -172,7 +258,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -181,16 +267,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"defaultdict(None, {'neg': 0, 'pos': 1})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(LABEL.vocab.stoi)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -206,16 +300,48 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I1105 23:39:17.115942 140101965489984 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/ben/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
|
||||
"I1105 23:39:17.119767 140101965489984 configuration_utils.py:168] Model config {\n",
|
||||
" \"attention_probs_dropout_prob\": 0.1,\n",
|
||||
" \"finetuning_task\": null,\n",
|
||||
" \"hidden_act\": \"gelu\",\n",
|
||||
" \"hidden_dropout_prob\": 0.1,\n",
|
||||
" \"hidden_size\": 768,\n",
|
||||
" \"initializer_range\": 0.02,\n",
|
||||
" \"intermediate_size\": 3072,\n",
|
||||
" \"layer_norm_eps\": 1e-12,\n",
|
||||
" \"max_position_embeddings\": 512,\n",
|
||||
" \"num_attention_heads\": 12,\n",
|
||||
" \"num_hidden_layers\": 12,\n",
|
||||
" \"num_labels\": 2,\n",
|
||||
" \"output_attentions\": false,\n",
|
||||
" \"output_hidden_states\": false,\n",
|
||||
" \"output_past\": true,\n",
|
||||
" \"pruned_heads\": {},\n",
|
||||
" \"torchscript\": false,\n",
|
||||
" \"type_vocab_size\": 2,\n",
|
||||
" \"use_bfloat16\": false,\n",
|
||||
" \"vocab_size\": 30522\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"I1105 23:39:17.664919 140101965489984 modeling_utils.py:337] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin from cache at /home/ben/.cache/torch/transformers/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"bert = BertModel.from_pretrained('bert-base-uncased')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -274,7 +400,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -294,9 +420,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The model has 112,241,409 trainable parameters\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||
@ -306,7 +440,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -317,9 +451,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The model has 2,759,169 trainable parameters\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||
@ -329,9 +471,34 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"rnn.weight_ih_l0\n",
|
||||
"rnn.weight_hh_l0\n",
|
||||
"rnn.bias_ih_l0\n",
|
||||
"rnn.bias_hh_l0\n",
|
||||
"rnn.weight_ih_l0_reverse\n",
|
||||
"rnn.weight_hh_l0_reverse\n",
|
||||
"rnn.bias_ih_l0_reverse\n",
|
||||
"rnn.bias_hh_l0_reverse\n",
|
||||
"rnn.weight_ih_l1\n",
|
||||
"rnn.weight_hh_l1\n",
|
||||
"rnn.bias_ih_l1\n",
|
||||
"rnn.bias_hh_l1\n",
|
||||
"rnn.weight_ih_l1_reverse\n",
|
||||
"rnn.weight_hh_l1_reverse\n",
|
||||
"rnn.bias_ih_l1_reverse\n",
|
||||
"rnn.bias_hh_l1_reverse\n",
|
||||
"out.weight\n",
|
||||
"out.bias\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for name, param in model.named_parameters(): \n",
|
||||
" if param.requires_grad:\n",
|
||||
@ -340,7 +507,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -349,7 +516,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -358,7 +525,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -368,7 +535,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -386,7 +553,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -419,7 +586,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -448,9 +615,235 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "332c7228ec2a4ae6b5ee03914979c296",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=137), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "bdb737cd4a664f0fb8f6d12cf36f8448",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=59), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch: 01\n",
|
||||
"\tTrain Loss: 0.485 | Train Acc: 75.87%\n",
|
||||
"\t Val. Loss: 0.346 | Val. Acc: 85.89%\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c9b20cff72684de2961554bcfb8ba8c4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=137), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b08e28ae06174c8b9dbecfd47c449f33",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=59), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch: 02\n",
|
||||
"\tTrain Loss: 0.286 | Train Acc: 88.16%\n",
|
||||
"\t Val. Loss: 0.247 | Val. Acc: 90.26%\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "676a29ed0adc4a5596eae764c1c1f212",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=137), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "dddf83901c0b4f57bf2ac9f902f3ef41",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=59), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch: 03\n",
|
||||
"\tTrain Loss: 0.234 | Train Acc: 90.77%\n",
|
||||
"\t Val. Loss: 0.229 | Val. Acc: 91.00%\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a4ae06ccbcd041a8a83667acd4cd0167",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=137), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "21f1f978e9b44f95a7ad4068caeaad70",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=59), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch: 04\n",
|
||||
"\tTrain Loss: 0.209 | Train Acc: 91.83%\n",
|
||||
"\t Val. Loss: 0.225 | Val. Acc: 91.10%\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "30bcf72065ae4e9fae8d5f29fab2a29d",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=137), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "06f7c8835cc44558bc5f0789039f1e20",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=59), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Epoch: 05\n",
|
||||
"\tTrain Loss: 0.182 | Train Acc: 92.97%\n",
|
||||
"\t Val. Loss: 0.217 | Val. Acc: 91.98%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"N_EPOCHS = 5\n",
|
||||
"\n",
|
||||
@ -472,9 +865,32 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "58055c0390544943ac64fa66c8d0f1d8",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"HBox(children=(IntProgress(value=0, max=196), HTML(value='')))"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"Test Loss: 0.198 | Test Acc: 92.31%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.load_state_dict(torch.load('tut6-model.pt'))\n",
|
||||
"\n",
|
||||
@ -485,7 +901,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -502,18 +918,40 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.02264496125280857"
|
||||
]
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"predict_sentiment(model, tokenizer, \"This film is terrible\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9411056041717529"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"predict_sentiment(model, tokenizer, \"This film is great\")"
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user