updated experimental
This commit is contained in:
parent
c0a31c6b17
commit
270c1e16c1
File diff suppressed because it is too large
Load Diff
@ -10,30 +10,10 @@
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "lIYdn1woOS1n",
|
||||
"outputId": "8f3dd381-2f03-4eb0-a5bf-b2cfbbe31439"
|
||||
"outputId": "05f43a3e-f111-4f96-ee3e-d95027c041c8"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Requirement already up-to-date: torchtext in /usr/local/lib/python3.6/dist-packages (0.7.0)\n",
|
||||
"Requirement already satisfied, skipping upgrade: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext) (4.41.1)\n",
|
||||
"Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from torchtext) (2.23.0)\n",
|
||||
"Requirement already satisfied, skipping upgrade: sentencepiece in /usr/local/lib/python3.6/dist-packages (from torchtext) (0.1.91)\n",
|
||||
"Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.18.5)\n",
|
||||
"Requirement already satisfied, skipping upgrade: torch in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.6.0+cu101)\n",
|
||||
"Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (1.24.3)\n",
|
||||
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2020.6.20)\n",
|
||||
"Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2.10)\n",
|
||||
"Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (3.0.4)\n",
|
||||
"Requirement already satisfied, skipping upgrade: future in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (0.16.0)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install torchtext --upgrade\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
@ -91,7 +71,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_train_valid_split(raw_train_data, split_ratio = 0.8):\n",
|
||||
"def get_train_valid_split(raw_train_data, split_ratio = 0.7):\n",
|
||||
"\n",
|
||||
" raw_train_data = list(raw_train_data)\n",
|
||||
" \n",
|
||||
@ -161,7 +141,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_length = 250\n",
|
||||
"max_length = 500\n",
|
||||
"\n",
|
||||
"tokenizer = Tokenizer(max_length = max_length)"
|
||||
]
|
||||
@ -199,10 +179,9 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"min_freq = 2\n",
|
||||
"max_size = 25_000\n",
|
||||
"\n",
|
||||
"vocab = build_vocab_from_data(raw_train_data, tokenizer, min_freq = min_freq, max_size = max_size)"
|
||||
"vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size = max_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -244,26 +223,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "SlBqLei8QXeF"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)\n",
|
||||
"valid_data = process_raw_data(raw_valid_data, tokenizer, vocab)\n",
|
||||
"test_data = process_raw_data(raw_test_data, tokenizer, vocab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -291,7 +258,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -299,13 +266,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pad_idx = vocab['<pad>']\n",
|
||||
"pad_token = '<pad>'\n",
|
||||
"pad_idx = vocab[pad_token]\n",
|
||||
"collator = Collator(pad_idx)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -313,7 +281,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_size = 128\n",
|
||||
"batch_size = 256\n",
|
||||
"\n",
|
||||
"train_iterator = torch.utils.data.DataLoader(train_data, \n",
|
||||
" batch_size, \n",
|
||||
@ -333,7 +301,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -377,7 +345,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -387,7 +355,7 @@
|
||||
"source": [
|
||||
"input_dim = len(vocab)\n",
|
||||
"emb_dim = 100\n",
|
||||
"hid_dim = 128\n",
|
||||
"hid_dim = 256\n",
|
||||
"output_dim = 2\n",
|
||||
"\n",
|
||||
"model = GRU(input_dim, emb_dim, hid_dim, output_dim, pad_idx)"
|
||||
@ -395,7 +363,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -409,7 +377,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -417,14 +385,14 @@
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "SJdVErKTTogS",
|
||||
"outputId": "6524fe5a-26c8-4fe6-a665-24f8912f77b4"
|
||||
"outputId": "aaf74c2e-2b9f-47df-a672-b809ffffd6e5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The model has 2,588,778 trainable parameters\n"
|
||||
"The model has 2,775,658 trainable parameters\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -434,7 +402,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -448,7 +416,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -479,7 +447,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -494,7 +462,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -502,7 +470,7 @@
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "LhlnYb2ZTvPr",
|
||||
"outputId": "13a14d89-1c33-4038-f56c-f119f2363816"
|
||||
"outputId": "8d56d0e2-6af1-40fe-ea1e-9ec7a42d8b15"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -512,15 +480,13 @@
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.4769, 0.6460, -0.2009, ..., -0.2221, -0.2449, 0.8116],\n",
|
||||
" [ 0.7019, -0.0129, 0.7528, ..., -0.8730, 0.3202, 0.0773],\n",
|
||||
" [-0.1876, 0.1964, 0.4381, ..., 0.0729, -0.5052, 0.3773]])"
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
@ -530,7 +496,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -543,7 +509,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -556,7 +522,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -569,7 +535,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -583,7 +549,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -600,7 +566,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -640,7 +606,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -676,7 +642,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -693,7 +659,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -701,43 +667,43 @@
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "lG-dJsjFUF8x",
|
||||
"outputId": "947b2f9a-53cd-4159-d422-6336027dc6a9"
|
||||
"outputId": "c434d13f-4efa-4a7c-c346-5e886db0405d"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.565 | Train Acc: 69.29%\n",
|
||||
"\t Val. Loss: 0.408 | Val. Acc: 81.66%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.299 | Train Acc: 87.81%\n",
|
||||
"\t Val. Loss: 0.294 | Val. Acc: 87.70%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.200 | Train Acc: 92.57%\n",
|
||||
"\t Val. Loss: 0.337 | Val. Acc: 86.00%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.120 | Train Acc: 96.10%\n",
|
||||
"\t Val. Loss: 0.320 | Val. Acc: 88.85%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.064 | Train Acc: 98.16%\n",
|
||||
"\t Val. Loss: 0.420 | Val. Acc: 88.30%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.033 | Train Acc: 99.22%\n",
|
||||
"\t Val. Loss: 0.493 | Val. Acc: 87.30%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.017 | Train Acc: 99.60%\n",
|
||||
"\t Val. Loss: 0.560 | Val. Acc: 87.32%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.012 | Train Acc: 99.71%\n",
|
||||
"\t Val. Loss: 0.576 | Val. Acc: 88.03%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.006 | Train Acc: 99.88%\n",
|
||||
"\t Val. Loss: 0.623 | Val. Acc: 86.68%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.003 | Train Acc: 99.94%\n",
|
||||
"\t Val. Loss: 0.802 | Val. Acc: 87.73%\n"
|
||||
"Epoch: 01 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.634 | Train Acc: 62.44%\n",
|
||||
"\t Val. Loss: 0.474 | Val. Acc: 77.64%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.375 | Train Acc: 83.86%\n",
|
||||
"\t Val. Loss: 0.333 | Val. Acc: 86.20%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.251 | Train Acc: 90.32%\n",
|
||||
"\t Val. Loss: 0.286 | Val. Acc: 89.07%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.170 | Train Acc: 93.78%\n",
|
||||
"\t Val. Loss: 0.316 | Val. Acc: 89.58%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.106 | Train Acc: 96.58%\n",
|
||||
"\t Val. Loss: 0.319 | Val. Acc: 89.63%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.066 | Train Acc: 98.08%\n",
|
||||
"\t Val. Loss: 0.327 | Val. Acc: 89.52%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.041 | Train Acc: 98.82%\n",
|
||||
"\t Val. Loss: 0.451 | Val. Acc: 88.07%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.021 | Train Acc: 99.43%\n",
|
||||
"\t Val. Loss: 0.472 | Val. Acc: 88.16%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.014 | Train Acc: 99.71%\n",
|
||||
"\t Val. Loss: 0.520 | Val. Acc: 88.43%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.005 | Train Acc: 99.93%\n",
|
||||
"\t Val. Loss: 0.660 | Val. Acc: 88.43%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -768,7 +734,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -776,14 +742,14 @@
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "PH7-0f6nUKRb",
|
||||
"outputId": "74b95e1d-afbd-4ffa-da42-1cff889de955"
|
||||
"outputId": "faf1e6dd-c99e-4fda-c6f8-435a08ca0073"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.302 | Test Acc: 87.18%\n"
|
||||
"Test Loss: 0.290 | Test Acc: 88.71%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -797,7 +763,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -819,7 +785,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -827,19 +793,17 @@
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "hb7bC-aEeC1q",
|
||||
"outputId": "5f38953a-083e-4e9a-d06c-d56c697acfda"
|
||||
"outputId": "059cccd1-efb4-404c-81f9-606983c23b33"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.08329976350069046"
|
||||
"0.07642160356044769"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
@ -849,6 +813,36 @@
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "APEVZ3D4eEVw",
|
||||
"outputId": "0d188e29-6e4e-4183-c7aa-467ea8f1afe6"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.8930155634880066"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
@ -858,25 +852,24 @@
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "APEVZ3D4eEVw",
|
||||
"outputId": "36f07ec8-5fb2-4646-aed0-007d2998e909"
|
||||
"id": "X7GMey_jebjg",
|
||||
"outputId": "04ca4196-51f0-4661-ffe4-8f4dd199baf4"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.869444727897644"
|
||||
"0.2206803858280182"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
|
||||
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
|
||||
"but it was actually the absolute worst movie of all time.\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
@ -890,53 +883,18 @@
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "X7GMey_jebjg",
|
||||
"outputId": "69d7008d-726a-49ac-8d40-e994b3cad83d"
|
||||
"id": "kOoESlQSxYx2",
|
||||
"outputId": "e5826bef-5f9c-41f6-9eb0-795318280045"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.22929930686950684"
|
||||
"0.5373267531394958"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
|
||||
"but it was actually the absolute worst movie of all time.\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "kOoESlQSxYx2",
|
||||
"outputId": "b9644e61-d2b8-4fd0-beba-db3cf8607514"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.42314645648002625"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
@ -970,7 +928,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5"
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
948
experimental/3_rnn_bilstm.ipynb
Normal file
948
experimental/3_rnn_bilstm.ipynb
Normal file
@ -0,0 +1,948 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 228
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "lIYdn1woOS1n",
|
||||
"outputId": "a30c21d5-b7cc-4ea6-a0d3-f9f1392ee04a"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.optim as optim\n",
|
||||
"\n",
|
||||
"import torchtext\n",
|
||||
"import torchtext.experimental\n",
|
||||
"import torchtext.experimental.vectors\n",
|
||||
"from torchtext.experimental.datasets.raw.text_classification import RawTextIterableDataset\n",
|
||||
"from torchtext.experimental.datasets.text_classification import TextClassificationDataset\n",
|
||||
"from torchtext.experimental.functional import sequential_transforms, vocab_func, totensor\n",
|
||||
"\n",
|
||||
"import collections\n",
|
||||
"import random\n",
|
||||
"import time"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "II-XIfhSkZS-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"seed = 1234\n",
|
||||
"\n",
|
||||
"torch.manual_seed(seed)\n",
|
||||
"random.seed(seed)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n",
|
||||
"torch.backends.cudnn.benchmark = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "kIkeEy2mkcT6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"raw_train_data, raw_test_data = torchtext.experimental.datasets.raw.IMDB()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "_a5ucP1ZkeDv"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_train_valid_split(raw_train_data, split_ratio = 0.7):\n",
"    \"\"\"Shuffle the raw training examples and split them into train and validation sets.\n",
"\n",
"    The first int(len * split_ratio) shuffled examples form the training split and the\n",
"    remainder the validation split; each split is wrapped in a RawTextIterableDataset.\n",
"    Shuffling uses the module-level `random` generator, so seed it for reproducibility.\n",
"    \"\"\"\n",
"\n",
"    examples = list(raw_train_data)\n",
"    random.shuffle(examples)\n",
"\n",
"    cutoff = int(len(examples) * split_ratio)\n",
"    train_split, valid_split = examples[:cutoff], examples[cutoff:]\n",
"\n",
"    return RawTextIterableDataset(train_split), RawTextIterableDataset(valid_split)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "1WP4nz-_kf_0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"raw_train_data, raw_valid_data = get_train_valid_split(raw_train_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "pPvrMZlWkicJ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Tokenizer:\n",
"    \"\"\"Tokenize strings, optionally lower-casing and truncating the tokens.\n",
"\n",
"    Args:\n",
"        tokenize_fn: tokenizer spec handed to torchtext.data.utils.get_tokenizer.\n",
"        lower: when True, every token is lower-cased.\n",
"        max_length: keep at most this many tokens; None disables truncation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, tokenize_fn = 'basic_english', lower = True, max_length = None):\n",
"\n",
"        self.tokenize_fn = torchtext.data.utils.get_tokenizer(tokenize_fn)\n",
"        self.lower = lower\n",
"        self.max_length = max_length\n",
"\n",
"    def tokenize(self, s):\n",
"        \"\"\"Return the list of tokens for the string `s`.\"\"\"\n",
"\n",
"        tokens = self.tokenize_fn(s)\n",
"\n",
"        if self.lower:\n",
"            tokens = [token.lower() for token in tokens]\n",
"\n",
"        if self.max_length is not None:\n",
"            # Truncate with the instance's own limit, not a notebook-level global\n",
"            # that may be undefined or hold a different value.\n",
"            tokens = tokens[:self.max_length]\n",
"\n",
"        return tokens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "SMsMQSuSkkt3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_length = 500\n",
|
||||
"\n",
|
||||
"tokenizer = Tokenizer(max_length = max_length)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Yie7TKWKkmeK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_vocab_from_data(raw_data, tokenizer, **vocab_kwargs):\n",
"    \"\"\"Count token frequencies over `raw_data` and build a torchtext Vocab from them.\n",
"\n",
"    `raw_data` yields (label, text) pairs; only the text is tokenized and counted.\n",
"    Extra keyword arguments are forwarded unchanged to torchtext.vocab.Vocab.\n",
"    \"\"\"\n",
"\n",
"    token_freqs = collections.Counter()\n",
"\n",
"    for _, text in raw_data:\n",
"        token_freqs.update(tokenizer.tokenize(text))\n",
"\n",
"    return torchtext.vocab.Vocab(token_freqs, **vocab_kwargs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "9jW7Ci7WkoSn"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_size = 25_000\n",
|
||||
"\n",
|
||||
"vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size = max_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "cvSZt_iFkqkt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def process_raw_data(raw_data, tokenizer, vocab):\n",
"    \"\"\"Wrap raw (label, text) pairs in a TextClassificationDataset.\n",
"\n",
"    Text goes through tokenizer.tokenize -> vocab_func(vocab) -> long tensor;\n",
"    labels are converted to long tensors.\n",
"    \"\"\"\n",
"\n",
"    # Materialize the pairs so the dataset holds a concrete list.\n",
"    examples = [(lbl, txt) for lbl, txt in raw_data]\n",
"\n",
"    label_pipeline = sequential_transforms(totensor(dtype=torch.long))\n",
"\n",
"    text_pipeline = sequential_transforms(tokenizer.tokenize,\n",
"                                          vocab_func(vocab),\n",
"                                          totensor(dtype=torch.long))\n",
"\n",
"    # The transforms tuple is ordered (label, text), matching the example layout.\n",
"    return TextClassificationDataset(examples,\n",
"                                     vocab,\n",
"                                     (label_pipeline, text_pipeline))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "bwsSiBdkktRk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)\n",
|
||||
"valid_data = process_raw_data(raw_valid_data, tokenizer, vocab)\n",
|
||||
"test_data = process_raw_data(raw_test_data, tokenizer, vocab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "5m3xRusSk8v3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Collator:\n",
"    \"\"\"Collate (label, token tensor) examples into padded batch tensors.\"\"\"\n",
"\n",
"    def __init__(self, pad_idx):\n",
"        # Index written into the positions added by padding.\n",
"        self.pad_idx = pad_idx\n",
"\n",
"    def collate(self, batch):\n",
"        \"\"\"Return (labels, padded text, lengths) for a list of (label, text) examples.\n",
"\n",
"        Lengths are the pre-padding sequence lengths, in batch order.\n",
"        \"\"\"\n",
"\n",
"        label_seq, token_seqs = zip(*batch)\n",
"\n",
"        seq_lengths = [len(seq) for seq in token_seqs]\n",
"\n",
"        padded = nn.utils.rnn.pad_sequence(token_seqs, padding_value = self.pad_idx)\n",
"\n",
"        return torch.LongTensor(label_seq), padded, torch.LongTensor(seq_lengths)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "1ZMuZqZxk8-p"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pad_token = '<pad>'\n",
|
||||
"pad_idx = vocab[pad_token]\n",
|
||||
"collator = Collator(pad_idx)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "mxG97Si9lAI2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_size = 256\n",
|
||||
"\n",
|
||||
"train_iterator = torch.utils.data.DataLoader(train_data, \n",
|
||||
" batch_size, \n",
|
||||
" shuffle = True, \n",
|
||||
" collate_fn = collator.collate)\n",
|
||||
"\n",
|
||||
"valid_iterator = torch.utils.data.DataLoader(valid_data, \n",
|
||||
" batch_size, \n",
|
||||
" shuffle = False, \n",
|
||||
" collate_fn = collator.collate)\n",
|
||||
"\n",
|
||||
"test_iterator = torch.utils.data.DataLoader(test_data, \n",
|
||||
" batch_size, \n",
|
||||
" shuffle = False, \n",
|
||||
" collate_fn = collator.collate)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "ty3NbheMlPYs"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class BiLSTM(nn.Module):\n",
"    \"\"\"Multi-layer bidirectional LSTM classifier over padded token-index batches.\n",
"\n",
"    Args:\n",
"        input_dim: vocabulary size (number of embedding rows).\n",
"        emb_dim: embedding dimension.\n",
"        hid_dim: LSTM hidden size per direction.\n",
"        output_dim: number of output classes.\n",
"        n_layers: number of stacked LSTM layers.\n",
"        dropout: dropout probability (embeddings, between LSTM layers, before fc).\n",
"        pad_idx: vocabulary index of the padding token.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_dim, emb_dim, hid_dim, output_dim, n_layers, dropout, pad_idx):\n",
"\n",
"        super().__init__()\n",
"\n",
"        # padding_idx keeps the padding row out of the gradient updates;\n",
"        # previously pad_idx was accepted but never used.\n",
"        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
"        self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers = n_layers, bidirectional = True, dropout = dropout)\n",
"        self.fc = nn.Linear(2 * hid_dim, output_dim)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, text, lengths):\n",
"        \"\"\"Return class scores of shape [batch size, output dim].\"\"\"\n",
"\n",
"        # text = [seq len, batch size]\n",
"        # lengths = [batch size]\n",
"\n",
"        embedded = self.dropout(self.embedding(text))\n",
"\n",
"        # embedded = [seq len, batch size, emb dim]\n",
"\n",
"        # pack_padded_sequence wants the lengths on the CPU, even if text is on the GPU\n",
"        if torch.is_tensor(lengths):\n",
"            lengths = lengths.cpu()\n",
"\n",
"        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths, enforce_sorted = False)\n",
"\n",
"        # only the final hidden state is needed, so the packed outputs are not unpacked\n",
"        _, (hidden, _) = self.lstm(packed_embedded)\n",
"\n",
"        # hidden = [n layers * n directions, batch size, hid dim]\n",
"\n",
"        # for the last layer, index -2 is the forward direction and -1 the backward one\n",
"        hidden_fwd = hidden[-2]\n",
"        hidden_bck = hidden[-1]\n",
"\n",
"        # hidden_fwd/bck = [batch size, hid dim]\n",
"\n",
"        # concatenation order (hidden[-1] first) is kept so existing fc weights stay valid\n",
"        hidden = torch.cat((hidden_bck, hidden_fwd), dim = 1)\n",
"\n",
"        # hidden = [batch size, hid dim * 2]\n",
"\n",
"        prediction = self.fc(self.dropout(hidden))\n",
"\n",
"        # prediction = [batch size, output dim]\n",
"\n",
"        return prediction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "trg6yTjBqOLZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"input_dim = len(vocab)\n",
|
||||
"emb_dim = 100\n",
|
||||
"hid_dim = 256\n",
|
||||
"output_dim = 2\n",
|
||||
"n_layers = 2\n",
|
||||
"dropout = 0.5\n",
|
||||
"\n",
|
||||
"model = BiLSTM(input_dim, emb_dim, hid_dim, output_dim, n_layers, dropout, pad_idx)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "9dgdCRsqqQoD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def count_parameters(model):\n",
"    \"\"\"Return the total number of elements across the model's trainable parameters.\"\"\"\n",
"    total = 0\n",
"    for param in model.parameters():\n",
"        # parameters with requires_grad == False are not counted\n",
"        if param.requires_grad:\n",
"            total += param.numel()\n",
"    return total"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "bfiGzjvnqV-s",
|
||||
"outputId": "168a3662-b95a-48de-d722-c76264e8c8ab"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The model has 4,811,370 trainable parameters\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(f'The model has {count_parameters(model):,} trainable parameters')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Sah17A41qW5d"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"glove = torchtext.experimental.vectors.GloVe(name = '6B',\n",
|
||||
" dim = emb_dim)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "S1Dfcn2Nqabo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_pretrained_embedding(vectors, vocab, unk_token):\n",
"    \"\"\"Build an embedding matrix for `vocab` from pretrained `vectors`.\n",
"\n",
"    A token whose looked-up vector equals the vector of `unk_token` is treated as\n",
"    missing: its row stays all-zero and the token is recorded.\n",
"\n",
"    Returns:\n",
"        (embedding matrix of shape [len(vocab), emb dim], list of missing tokens)\n",
"    \"\"\"\n",
"\n",
"    unk_vector = vectors[unk_token]\n",
"    emb_dim = unk_vector.shape[-1]\n",
"\n",
"    # rows start as zeros, so missing tokens need no explicit write\n",
"    pretrained_embedding = torch.zeros(len(vocab), emb_dim)\n",
"    unk_tokens = []\n",
"\n",
"    for idx, token in enumerate(vocab.itos):\n",
"        vector = vectors[token]\n",
"        if torch.all(torch.eq(vector, unk_vector)):\n",
"            unk_tokens.append(token)\n",
"            continue\n",
"        pretrained_embedding[idx] = vector\n",
"\n",
"    return pretrained_embedding, unk_tokens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "sGyV94f7qvdr"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"unk_token = '<unk>'\n",
|
||||
"\n",
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(glove.vectors, vocab, unk_token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 139
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "KYnGxbVisUsk",
|
||||
"outputId": "e1a88c1c-0f3e-48c6-afcf-9d791fd54bb9"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.embedding.weight.data.copy_(pretrained_embedding)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "DTwNU41WseMS"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"optimizer = optim.Adam(model.parameters())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Rxlx7a72s1ze"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"criterion = nn.CrossEntropyLoss()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "1CLimBxus2yX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "108fm55ftBgO"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = model.to(device)\n",
|
||||
"criterion = criterion.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "IYCxbvXUvE5v"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def calculate_accuracy(predictions, labels):\n",
|
||||
"    \"\"\"Return the fraction of examples whose top-scoring class matches its label.\n",
|
||||
"\n",
|
||||
"    The arg-max is taken over dimension 1 of `predictions`; the count of\n",
|
||||
"    matches is divided by `labels.shape[0]` and returned as a float tensor.\n",
|
||||
"    \"\"\"\n",
|
||||
"    predicted_class = predictions.argmax(dim = 1)\n",
|
||||
"    hits = predicted_class == labels.view_as(predicted_class)\n",
|
||||
"    return hits.sum().float() / labels.shape[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Ik2JQo6TvGml"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train(model, iterator, optimizer, criterion, device):\n",
|
||||
"    \"\"\"Run one optimisation pass over `iterator` and report mean metrics.\n",
|
||||
"\n",
|
||||
"    Each batch is a `(labels, text, lengths)` triple. Labels and text are\n",
|
||||
"    moved to `device`; `lengths` is handed to the model untouched.\n",
|
||||
"\n",
|
||||
"    Returns the per-batch loss and accuracy averaged over `len(iterator)`.\n",
|
||||
"    \"\"\"\n",
|
||||
"    model.train()\n",
|
||||
"\n",
|
||||
"    running_loss = 0\n",
|
||||
"    running_acc = 0\n",
|
||||
"\n",
|
||||
"    for batch_labels, batch_text, batch_lengths in iterator:\n",
|
||||
"        batch_labels = batch_labels.to(device)\n",
|
||||
"        batch_text = batch_text.to(device)\n",
|
||||
"\n",
|
||||
"        optimizer.zero_grad()\n",
|
||||
"\n",
|
||||
"        logits = model(batch_text, batch_lengths)\n",
|
||||
"        batch_loss = criterion(logits, batch_labels)\n",
|
||||
"        batch_acc = calculate_accuracy(logits, batch_labels)\n",
|
||||
"\n",
|
||||
"        batch_loss.backward()\n",
|
||||
"        optimizer.step()\n",
|
||||
"\n",
|
||||
"        running_loss += batch_loss.item()\n",
|
||||
"        running_acc += batch_acc.item()\n",
|
||||
"\n",
|
||||
"    n_batches = len(iterator)\n",
|
||||
"    return running_loss / n_batches, running_acc / n_batches"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "aGy1Zk6jvIf8"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def evaluate(model, iterator, criterion, device):\n",
|
||||
"    \"\"\"Score `model` on every batch of `iterator` without updating weights.\n",
|
||||
"\n",
|
||||
"    Batches are `(labels, text, lengths)` triples; labels and text are moved\n",
|
||||
"    to `device`, `lengths` is passed to the model as-is. The model is put in\n",
|
||||
"    eval mode and gradient tracking is disabled for the whole pass.\n",
|
||||
"\n",
|
||||
"    Returns the per-batch loss and accuracy averaged over `len(iterator)`.\n",
|
||||
"    \"\"\"\n",
|
||||
"    model.eval()\n",
|
||||
"\n",
|
||||
"    running_loss = 0\n",
|
||||
"    running_acc = 0\n",
|
||||
"\n",
|
||||
"    with torch.no_grad():\n",
|
||||
"        for batch_labels, batch_text, batch_lengths in iterator:\n",
|
||||
"            batch_labels = batch_labels.to(device)\n",
|
||||
"            batch_text = batch_text.to(device)\n",
|
||||
"\n",
|
||||
"            logits = model(batch_text, batch_lengths)\n",
|
||||
"            batch_loss = criterion(logits, batch_labels)\n",
|
||||
"            batch_acc = calculate_accuracy(logits, batch_labels)\n",
|
||||
"\n",
|
||||
"            running_loss += batch_loss.item()\n",
|
||||
"            running_acc += batch_acc.item()\n",
|
||||
"\n",
|
||||
"    n_batches = len(iterator)\n",
|
||||
"    return running_loss / n_batches, running_acc / n_batches"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "9MyMRRzbvKPx"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def epoch_time(start_time, end_time):\n",
|
||||
"    \"\"\"Split the span between two timestamps (in seconds) into whole minutes and seconds.\n",
|
||||
"\n",
|
||||
"    Both parts are truncated with `int`, so fractional seconds are dropped.\n",
|
||||
"    \"\"\"\n",
|
||||
"    duration = end_time - start_time\n",
|
||||
"    whole_minutes = int(duration / 60)\n",
|
||||
"    leftover_seconds = int(duration - (whole_minutes * 60))\n",
|
||||
"    return whole_minutes, leftover_seconds"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 537
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "dRKwD51WvMa3",
|
||||
"outputId": "79389e66-c1bf-45c9-a919-63ee787ad660"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.668 | Train Acc: 59.01%\n",
|
||||
"\t Val. Loss: 0.652 | Val. Acc: 62.33%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.602 | Train Acc: 67.75%\n",
|
||||
"\t Val. Loss: 0.478 | Val. Acc: 77.18%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.497 | Train Acc: 76.59%\n",
|
||||
"\t Val. Loss: 0.478 | Val. Acc: 80.62%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.456 | Train Acc: 79.24%\n",
|
||||
"\t Val. Loss: 0.397 | Val. Acc: 83.30%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.391 | Train Acc: 82.72%\n",
|
||||
"\t Val. Loss: 0.344 | Val. Acc: 85.23%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.345 | Train Acc: 85.30%\n",
|
||||
"\t Val. Loss: 0.350 | Val. Acc: 86.12%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.314 | Train Acc: 86.75%\n",
|
||||
"\t Val. Loss: 0.310 | Val. Acc: 87.50%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.266 | Train Acc: 89.54%\n",
|
||||
"\t Val. Loss: 0.315 | Val. Acc: 88.16%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.247 | Train Acc: 90.21%\n",
|
||||
"\t Val. Loss: 0.285 | Val. Acc: 89.02%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.217 | Train Acc: 91.79%\n",
|
||||
"\t Val. Loss: 0.282 | Val. Acc: 88.98%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"n_epochs = 10\n",
|
||||
"\n",
|
||||
"best_valid_loss = float('inf')\n",
|
||||
"\n",
|
||||
"for epoch in range(n_epochs):\n",
|
||||
"\n",
|
||||
"    # Time one full pass: training followed by validation.\n",
|
||||
"    start_time = time.monotonic()\n",
|
||||
"    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)\n",
|
||||
"    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)\n",
|
||||
"    end_time = time.monotonic()\n",
|
||||
"\n",
|
||||
"    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
|
||||
"\n",
|
||||
"    # Checkpoint only the weights with the lowest validation loss seen so far.\n",
|
||||
"    if best_valid_loss > valid_loss:\n",
|
||||
"        best_valid_loss = valid_loss\n",
|
||||
"        torch.save(model.state_dict(), 'bilstm-model.pt')\n",
|
||||
"\n",
|
||||
"    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
|
||||
"    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
|
||||
"    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "hKOg4oARvPHJ",
|
||||
"outputId": "7cfe4b85-de2f-47f3-8437-45589c32ceca"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.294 | Test Acc: 87.95%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Restore the best checkpoint written by the training loop before scoring the\n",
|
||||
"# test split. `map_location` remaps the saved tensors onto the current device,\n",
|
||||
"# so a checkpoint written on GPU still loads on a CPU-only runtime.\n",
|
||||
"model.load_state_dict(torch.load('bilstm-model.pt', map_location = device))\n",
|
||||
"\n",
|
||||
"test_loss, test_acc = evaluate(model, test_iterator, criterion, device)\n",
|
||||
"\n",
|
||||
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "tQ4Jsf_vvWgB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def predict_sentiment(tokenizer, vocab, model, device, sentence):\n",
|
||||
"    \"\"\"Return the probability the model assigns to the last class for `sentence`.\n",
|
||||
"\n",
|
||||
"    The sentence is split with `tokenizer.tokenize`, mapped to indices through\n",
|
||||
"    `vocab.stoi`, and passed to `model` as a `(seq_len, 1)` index tensor on\n",
|
||||
"    `device` together with a one-element length tensor. The model output is\n",
|
||||
"    soft-maxed over its last dimension and the final entry is returned as a\n",
|
||||
"    Python float.\n",
|
||||
"    \"\"\"\n",
|
||||
"    model.eval()\n",
|
||||
"    tokens = tokenizer.tokenize(sentence)\n",
|
||||
"    # Left on the CPU, matching how `train`/`evaluate` pass lengths to the model.\n",
|
||||
"    # Recent PyTorch rejects non-CPU lengths in `pack_padded_sequence`.\n",
|
||||
"    # NOTE(review): assumes the model packs sequences with these lengths - confirm.\n",
|
||||
"    length = torch.LongTensor([len(tokens)])\n",
|
||||
"    indexes = [vocab.stoi[token] for token in tokens]\n",
|
||||
"    tensor = torch.LongTensor(indexes).unsqueeze(-1).to(device)\n",
|
||||
"    # Inference only: do not build an autograd graph.\n",
|
||||
"    with torch.no_grad():\n",
|
||||
"        prediction = model(tensor, length)\n",
|
||||
"    probabilities = nn.functional.softmax(prediction, dim = -1)\n",
|
||||
"    pos_probability = probabilities.squeeze()[-1].item()\n",
|
||||
"    return pos_probability"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "Yy7_6rhovZTE",
|
||||
"outputId": "78860852-39ea-4a7b-eb33-9a1a077fb9e0"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.008071469143033028"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = 'the absolute worst movie of all time.'\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "L3LmQxrgvau9",
|
||||
"outputId": "0204aa17-0bc1-45f2-9be1-c014798af120"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9896865487098694"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "t7Qoy21Bvb7v",
|
||||
"outputId": "6094a141-4f37-4110-edc7-aa14b9a3c667"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.029767075553536415"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
|
||||
"but it was actually the absolute worst movie of all time.\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "EPGXBr18vdQT",
|
||||
"outputId": "e5b3d210-0254-4d5f-bdbe-609c0b7d6a8a"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.5513127446174622"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = \"i thought it was going to be the absolute worst movie of all time, \\\n",
|
||||
"but it was actually one of the greatest films i have ever seen in my life.\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"machine_shape": "hm",
|
||||
"name": "scratchpad",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
Loading…
Reference in New Issue
Block a user