updated experimental

This commit is contained in:
bentrevett 2020-08-25 18:15:24 +01:00
parent c0a31c6b17
commit 270c1e16c1
3 changed files with 2589 additions and 1731 deletions

File diff suppressed because it is too large Load Diff

View File

@ -10,30 +10,10 @@
},
"colab_type": "code",
"id": "lIYdn1woOS1n",
"outputId": "8f3dd381-2f03-4eb0-a5bf-b2cfbbe31439"
"outputId": "05f43a3e-f111-4f96-ee3e-d95027c041c8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already up-to-date: torchtext in /usr/local/lib/python3.6/dist-packages (0.7.0)\n",
"Requirement already satisfied, skipping upgrade: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext) (4.41.1)\n",
"Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from torchtext) (2.23.0)\n",
"Requirement already satisfied, skipping upgrade: sentencepiece in /usr/local/lib/python3.6/dist-packages (from torchtext) (0.1.91)\n",
"Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.18.5)\n",
"Requirement already satisfied, skipping upgrade: torch in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.6.0+cu101)\n",
"Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (1.24.3)\n",
"Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2020.6.20)\n",
"Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2.10)\n",
"Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (3.0.4)\n",
"Requirement already satisfied, skipping upgrade: future in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (0.16.0)\n"
]
}
],
"outputs": [],
"source": [
"!pip install torchtext --upgrade\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
@ -91,7 +71,7 @@
},
"outputs": [],
"source": [
"def get_train_valid_split(raw_train_data, split_ratio = 0.8):\n",
"def get_train_valid_split(raw_train_data, split_ratio = 0.7):\n",
"\n",
" raw_train_data = list(raw_train_data)\n",
" \n",
@ -161,7 +141,7 @@
},
"outputs": [],
"source": [
"max_length = 250\n",
"max_length = 500\n",
"\n",
"tokenizer = Tokenizer(max_length = max_length)"
]
@ -199,10 +179,9 @@
},
"outputs": [],
"source": [
"min_freq = 2\n",
"max_size = 25_000\n",
"\n",
"vocab = build_vocab_from_data(raw_train_data, tokenizer, min_freq = min_freq, max_size = max_size)"
"vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size = max_size)"
]
},
{
@ -244,26 +223,14 @@
},
"outputs": [],
"source": [
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "SlBqLei8QXeF"
},
"outputs": [],
"source": [
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)\n",
"valid_data = process_raw_data(raw_valid_data, tokenizer, vocab)\n",
"test_data = process_raw_data(raw_test_data, tokenizer, vocab)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {
"colab": {},
"colab_type": "code",
@ -291,7 +258,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"metadata": {
"colab": {},
"colab_type": "code",
@ -299,13 +266,14 @@
},
"outputs": [],
"source": [
"pad_idx = vocab['<pad>']\n",
"pad_token = '<pad>'\n",
"pad_idx = vocab[pad_token]\n",
"collator = Collator(pad_idx)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"metadata": {
"colab": {},
"colab_type": "code",
@ -313,7 +281,7 @@
},
"outputs": [],
"source": [
"batch_size = 128\n",
"batch_size = 256\n",
"\n",
"train_iterator = torch.utils.data.DataLoader(train_data, \n",
" batch_size, \n",
@ -333,7 +301,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"metadata": {
"colab": {},
"colab_type": "code",
@ -377,7 +345,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"metadata": {
"colab": {},
"colab_type": "code",
@ -387,7 +355,7 @@
"source": [
"input_dim = len(vocab)\n",
"emb_dim = 100\n",
"hid_dim = 128\n",
"hid_dim = 256\n",
"output_dim = 2\n",
"\n",
"model = GRU(input_dim, emb_dim, hid_dim, output_dim, pad_idx)"
@ -395,7 +363,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 17,
"metadata": {
"colab": {},
"colab_type": "code",
@ -409,7 +377,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@ -417,14 +385,14 @@
},
"colab_type": "code",
"id": "SJdVErKTTogS",
"outputId": "6524fe5a-26c8-4fe6-a665-24f8912f77b4"
"outputId": "aaf74c2e-2b9f-47df-a672-b809ffffd6e5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 2,588,778 trainable parameters\n"
"The model has 2,775,658 trainable parameters\n"
]
}
],
@ -434,7 +402,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 19,
"metadata": {
"colab": {},
"colab_type": "code",
@ -448,7 +416,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 20,
"metadata": {
"colab": {},
"colab_type": "code",
@ -479,7 +447,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 21,
"metadata": {
"colab": {},
"colab_type": "code",
@ -494,7 +462,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@ -502,7 +470,7 @@
},
"colab_type": "code",
"id": "LhlnYb2ZTvPr",
"outputId": "13a14d89-1c33-4038-f56c-f119f2363816"
"outputId": "8d56d0e2-6af1-40fe-ea1e-9ec7a42d8b15"
},
"outputs": [
{
@ -512,15 +480,13 @@
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [-0.4769, 0.6460, -0.2009, ..., -0.2221, -0.2449, 0.8116],\n",
" [ 0.7019, -0.0129, 0.7528, ..., -0.8730, 0.3202, 0.0773],\n",
" [-0.1876, 0.1964, 0.4381, ..., 0.0729, -0.5052, 0.3773]])"
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
]
},
"execution_count": 23,
"metadata": {
"tags": []
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
@ -530,7 +496,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 23,
"metadata": {
"colab": {},
"colab_type": "code",
@ -543,7 +509,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 24,
"metadata": {
"colab": {},
"colab_type": "code",
@ -556,7 +522,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 25,
"metadata": {
"colab": {},
"colab_type": "code",
@ -569,7 +535,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 26,
"metadata": {
"colab": {},
"colab_type": "code",
@ -583,7 +549,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 27,
"metadata": {
"colab": {},
"colab_type": "code",
@ -600,7 +566,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 28,
"metadata": {
"colab": {},
"colab_type": "code",
@ -640,7 +606,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 29,
"metadata": {
"colab": {},
"colab_type": "code",
@ -676,7 +642,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 30,
"metadata": {
"colab": {},
"colab_type": "code",
@ -693,7 +659,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@ -701,43 +667,43 @@
},
"colab_type": "code",
"id": "lG-dJsjFUF8x",
"outputId": "947b2f9a-53cd-4159-d422-6336027dc6a9"
"outputId": "c434d13f-4efa-4a7c-c346-5e886db0405d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.565 | Train Acc: 69.29%\n",
"\t Val. Loss: 0.408 | Val. Acc: 81.66%\n",
"Epoch: 02 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.299 | Train Acc: 87.81%\n",
"\t Val. Loss: 0.294 | Val. Acc: 87.70%\n",
"Epoch: 03 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.200 | Train Acc: 92.57%\n",
"\t Val. Loss: 0.337 | Val. Acc: 86.00%\n",
"Epoch: 04 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.120 | Train Acc: 96.10%\n",
"\t Val. Loss: 0.320 | Val. Acc: 88.85%\n",
"Epoch: 05 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.064 | Train Acc: 98.16%\n",
"\t Val. Loss: 0.420 | Val. Acc: 88.30%\n",
"Epoch: 06 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.033 | Train Acc: 99.22%\n",
"\t Val. Loss: 0.493 | Val. Acc: 87.30%\n",
"Epoch: 07 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.017 | Train Acc: 99.60%\n",
"\t Val. Loss: 0.560 | Val. Acc: 87.32%\n",
"Epoch: 08 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.012 | Train Acc: 99.71%\n",
"\t Val. Loss: 0.576 | Val. Acc: 88.03%\n",
"Epoch: 09 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.006 | Train Acc: 99.88%\n",
"\t Val. Loss: 0.623 | Val. Acc: 86.68%\n",
"Epoch: 10 | Epoch Time: 0m 9s\n",
"\tTrain Loss: 0.003 | Train Acc: 99.94%\n",
"\t Val. Loss: 0.802 | Val. Acc: 87.73%\n"
"Epoch: 01 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.634 | Train Acc: 62.44%\n",
"\t Val. Loss: 0.474 | Val. Acc: 77.64%\n",
"Epoch: 02 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.375 | Train Acc: 83.86%\n",
"\t Val. Loss: 0.333 | Val. Acc: 86.20%\n",
"Epoch: 03 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.251 | Train Acc: 90.32%\n",
"\t Val. Loss: 0.286 | Val. Acc: 89.07%\n",
"Epoch: 04 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.170 | Train Acc: 93.78%\n",
"\t Val. Loss: 0.316 | Val. Acc: 89.58%\n",
"Epoch: 05 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.106 | Train Acc: 96.58%\n",
"\t Val. Loss: 0.319 | Val. Acc: 89.63%\n",
"Epoch: 06 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.066 | Train Acc: 98.08%\n",
"\t Val. Loss: 0.327 | Val. Acc: 89.52%\n",
"Epoch: 07 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.041 | Train Acc: 98.82%\n",
"\t Val. Loss: 0.451 | Val. Acc: 88.07%\n",
"Epoch: 08 | Epoch Time: 0m 8s\n",
"\tTrain Loss: 0.021 | Train Acc: 99.43%\n",
"\t Val. Loss: 0.472 | Val. Acc: 88.16%\n",
"Epoch: 09 | Epoch Time: 0m 7s\n",
"\tTrain Loss: 0.014 | Train Acc: 99.71%\n",
"\t Val. Loss: 0.520 | Val. Acc: 88.43%\n",
"Epoch: 10 | Epoch Time: 0m 7s\n",
"\tTrain Loss: 0.005 | Train Acc: 99.93%\n",
"\t Val. Loss: 0.660 | Val. Acc: 88.43%\n"
]
}
],
@ -768,7 +734,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@ -776,14 +742,14 @@
},
"colab_type": "code",
"id": "PH7-0f6nUKRb",
"outputId": "74b95e1d-afbd-4ffa-da42-1cff889de955"
"outputId": "faf1e6dd-c99e-4fda-c6f8-435a08ca0073"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.302 | Test Acc: 87.18%\n"
"Test Loss: 0.290 | Test Acc: 88.71%\n"
]
}
],
@ -797,7 +763,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 33,
"metadata": {
"colab": {},
"colab_type": "code",
@ -819,7 +785,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
@ -827,19 +793,17 @@
},
"colab_type": "code",
"id": "hb7bC-aEeC1q",
"outputId": "5f38953a-083e-4e9a-d06c-d56c697acfda"
"outputId": "059cccd1-efb4-404c-81f9-606983c23b33"
},
"outputs": [
{
"data": {
"text/plain": [
"0.08329976350069046"
"0.07642160356044769"
]
},
"execution_count": 35,
"metadata": {
"tags": []
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
@ -849,6 +813,36 @@
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "APEVZ3D4eEVw",
"outputId": "0d188e29-6e4e-4183-c7aa-467ea8f1afe6"
},
"outputs": [
{
"data": {
"text/plain": [
"0.8930155634880066"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
"\n",
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
},
{
"cell_type": "code",
"execution_count": 36,
@ -858,25 +852,24 @@
"height": 35
},
"colab_type": "code",
"id": "APEVZ3D4eEVw",
"outputId": "36f07ec8-5fb2-4646-aed0-007d2998e909"
"id": "X7GMey_jebjg",
"outputId": "04ca4196-51f0-4661-ffe4-8f4dd199baf4"
},
"outputs": [
{
"data": {
"text/plain": [
"0.869444727897644"
"0.2206803858280182"
]
},
"execution_count": 36,
"metadata": {
"tags": []
},
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
"but it was actually the absolute worst movie of all time.\"\n",
"\n",
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
@ -890,53 +883,18 @@
"height": 35
},
"colab_type": "code",
"id": "X7GMey_jebjg",
"outputId": "69d7008d-726a-49ac-8d40-e994b3cad83d"
"id": "kOoESlQSxYx2",
"outputId": "e5826bef-5f9c-41f6-9eb0-795318280045"
},
"outputs": [
{
"data": {
"text/plain": [
"0.22929930686950684"
"0.5373267531394958"
]
},
"execution_count": 37,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
"but it was actually the absolute worst movie of all time.\"\n",
"\n",
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "kOoESlQSxYx2",
"outputId": "b9644e61-d2b8-4fd0-beba-db3cf8607514"
},
"outputs": [
{
"data": {
"text/plain": [
"0.42314645648002625"
]
},
"execution_count": 38,
"metadata": {
"tags": []
},
"metadata": {},
"output_type": "execute_result"
}
],
@ -970,7 +928,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.3"
}
},
"nbformat": 4,

View File

@ -0,0 +1,948 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 228
},
"colab_type": "code",
"id": "lIYdn1woOS1n",
"outputId": "a30c21d5-b7cc-4ea6-a0d3-f9f1392ee04a"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"import torchtext\n",
"import torchtext.experimental\n",
"import torchtext.experimental.vectors\n",
"from torchtext.experimental.datasets.raw.text_classification import RawTextIterableDataset\n",
"from torchtext.experimental.datasets.text_classification import TextClassificationDataset\n",
"from torchtext.experimental.functional import sequential_transforms, vocab_func, totensor\n",
"\n",
"import collections\n",
"import random\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "II-XIfhSkZS-"
},
"outputs": [],
"source": [
"seed = 1234\n",
"\n",
"torch.manual_seed(seed)\n",
"random.seed(seed)\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cudnn.benchmark = False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "kIkeEy2mkcT6"
},
"outputs": [],
"source": [
"raw_train_data, raw_test_data = torchtext.experimental.datasets.raw.IMDB()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "_a5ucP1ZkeDv"
},
"outputs": [],
"source": [
"def get_train_valid_split(raw_train_data, split_ratio = 0.7):\n",
"\n",
" raw_train_data = list(raw_train_data)\n",
" \n",
" random.shuffle(raw_train_data)\n",
" \n",
" n_train_examples = int(len(raw_train_data) * split_ratio)\n",
" \n",
" train_data = raw_train_data[:n_train_examples]\n",
" valid_data = raw_train_data[n_train_examples:]\n",
" \n",
" train_data = RawTextIterableDataset(train_data)\n",
" valid_data = RawTextIterableDataset(valid_data)\n",
" \n",
" return train_data, valid_data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "1WP4nz-_kf_0"
},
"outputs": [],
"source": [
"raw_train_data, raw_valid_data = get_train_valid_split(raw_train_data)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "pPvrMZlWkicJ"
},
"outputs": [],
"source": [
"class Tokenizer:\n",
" def __init__(self, tokenize_fn = 'basic_english', lower = True, max_length = None):\n",
" \n",
" self.tokenize_fn = torchtext.data.utils.get_tokenizer(tokenize_fn)\n",
" self.lower = lower\n",
" self.max_length = max_length\n",
" \n",
" def tokenize(self, s):\n",
" \n",
" tokens = self.tokenize_fn(s)\n",
" \n",
" if self.lower:\n",
" tokens = [token.lower() for token in tokens]\n",
" \n",
" if self.max_length is not None:\n",
" tokens = tokens[:max_length]\n",
" \n",
" return tokens"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "SMsMQSuSkkt3"
},
"outputs": [],
"source": [
"max_length = 500\n",
"\n",
"tokenizer = Tokenizer(max_length = max_length)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Yie7TKWKkmeK"
},
"outputs": [],
"source": [
"def build_vocab_from_data(raw_data, tokenizer, **vocab_kwargs):\n",
" \n",
" token_freqs = collections.Counter()\n",
" \n",
" for label, text in raw_data:\n",
" tokens = tokenizer.tokenize(text)\n",
" token_freqs.update(tokens)\n",
" \n",
" vocab = torchtext.vocab.Vocab(token_freqs, **vocab_kwargs)\n",
" \n",
" return vocab"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9jW7Ci7WkoSn"
},
"outputs": [],
"source": [
"max_size = 25_000\n",
"\n",
"vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size = max_size)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "cvSZt_iFkqkt"
},
"outputs": [],
"source": [
"def process_raw_data(raw_data, tokenizer, vocab):\n",
" \n",
" raw_data = [(label, text) for (label, text) in raw_data]\n",
"\n",
" text_transform = sequential_transforms(tokenizer.tokenize,\n",
" vocab_func(vocab),\n",
" totensor(dtype=torch.long))\n",
" \n",
" label_transform = sequential_transforms(totensor(dtype=torch.long))\n",
"\n",
" transforms = (label_transform, text_transform)\n",
"\n",
" dataset = TextClassificationDataset(raw_data,\n",
" vocab,\n",
" transforms)\n",
" \n",
" return dataset"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "bwsSiBdkktRk"
},
"outputs": [],
"source": [
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)\n",
"valid_data = process_raw_data(raw_valid_data, tokenizer, vocab)\n",
"test_data = process_raw_data(raw_test_data, tokenizer, vocab)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "5m3xRusSk8v3"
},
"outputs": [],
"source": [
"class Collator:\n",
" def __init__(self, pad_idx):\n",
" \n",
" self.pad_idx = pad_idx\n",
" \n",
" def collate(self, batch):\n",
" \n",
" labels, text = zip(*batch)\n",
" \n",
" labels = torch.LongTensor(labels)\n",
" \n",
" lengths = torch.LongTensor([len(x) for x in text])\n",
"\n",
" text = nn.utils.rnn.pad_sequence(text, padding_value = self.pad_idx)\n",
" \n",
" return labels, text, lengths"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "1ZMuZqZxk8-p"
},
"outputs": [],
"source": [
"pad_token = '<pad>'\n",
"pad_idx = vocab[pad_token]\n",
"collator = Collator(pad_idx)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "mxG97Si9lAI2"
},
"outputs": [],
"source": [
"batch_size = 256\n",
"\n",
"train_iterator = torch.utils.data.DataLoader(train_data, \n",
" batch_size, \n",
" shuffle = True, \n",
" collate_fn = collator.collate)\n",
"\n",
"valid_iterator = torch.utils.data.DataLoader(valid_data, \n",
" batch_size, \n",
" shuffle = False, \n",
" collate_fn = collator.collate)\n",
"\n",
"test_iterator = torch.utils.data.DataLoader(test_data, \n",
" batch_size, \n",
" shuffle = False, \n",
" collate_fn = collator.collate)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ty3NbheMlPYs"
},
"outputs": [],
"source": [
"class BiLSTM(nn.Module):\n",
" def __init__(self, input_dim, emb_dim, hid_dim, output_dim, n_layers, dropout, pad_idx):\n",
"\n",
" super().__init__()\n",
"\n",
" self.embedding = nn.Embedding(input_dim, emb_dim)\n",
" self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers = n_layers, bidirectional = True, dropout = dropout)\n",
" self.fc = nn.Linear(2 * hid_dim, output_dim)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, text, lengths):\n",
"\n",
" # text = [seq len, batch size]\n",
" # lengths = [batch size]\n",
"\n",
" embedded = self.dropout(self.embedding(text))\n",
"\n",
" # embedded = [seq len, batch size, emb dim]\n",
"\n",
" packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths, enforce_sorted = False)\n",
"\n",
" packed_output, (hidden, cell) = self.lstm(packed_embedded)\n",
"\n",
" output, _ = nn.utils.rnn.pad_packed_sequence(packed_output)\n",
"\n",
" # outputs = [seq_len, batch size, n directions * hid dim]\n",
" # hidden = [n layers * n directions, batch size, hid dim]\n",
"\n",
" hidden_fwd = hidden[-1]\n",
" hidden_bck = hidden[-2]\n",
"\n",
" # hidden_fwd/bck = [batch size, hid dim]\n",
"\n",
" hidden = torch.cat((hidden_fwd, hidden_bck), dim = 1)\n",
"\n",
" # hidden = [batch size, hid dim * 2]\n",
"\n",
" prediction = self.fc(self.dropout(hidden))\n",
"\n",
" # prediction = [batch size, output dim]\n",
"\n",
" return prediction"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "trg6yTjBqOLZ"
},
"outputs": [],
"source": [
"input_dim = len(vocab)\n",
"emb_dim = 100\n",
"hid_dim = 256\n",
"output_dim = 2\n",
"n_layers = 2\n",
"dropout = 0.5\n",
"\n",
"model = BiLSTM(input_dim, emb_dim, hid_dim, output_dim, n_layers, dropout, pad_idx)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9dgdCRsqqQoD"
},
"outputs": [],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "bfiGzjvnqV-s",
"outputId": "168a3662-b95a-48de-d722-c76264e8c8ab"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 4,811,370 trainable parameters\n"
]
}
],
"source": [
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Sah17A41qW5d"
},
"outputs": [],
"source": [
"glove = torchtext.experimental.vectors.GloVe(name = '6B',\n",
" dim = emb_dim)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "S1Dfcn2Nqabo"
},
"outputs": [],
"source": [
"def get_pretrained_embedding(vectors, vocab, unk_token):\n",
" \n",
" unk_vector = vectors[unk_token]\n",
" emb_dim = unk_vector.shape[-1]\n",
" zero_vector = torch.zeros(emb_dim)\n",
"\n",
" pretrained_embedding = torch.zeros(len(vocab), emb_dim) \n",
" \n",
" unk_tokens = []\n",
" \n",
" for idx, token in enumerate(vocab.itos):\n",
" pretrained_vector = vectors[token]\n",
" if torch.all(torch.eq(pretrained_vector, unk_vector)):\n",
" unk_tokens.append(token)\n",
" pretrained_embedding[idx] = zero_vector\n",
" else:\n",
" pretrained_embedding[idx] = pretrained_vector\n",
" \n",
" return pretrained_embedding, unk_tokens"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "sGyV94f7qvdr"
},
"outputs": [],
"source": [
"unk_token = '<unk>'\n",
"\n",
"pretrained_embedding, unk_tokens = get_pretrained_embedding(glove.vectors, vocab, unk_token)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
},
"colab_type": "code",
"id": "KYnGxbVisUsk",
"outputId": "e1a88c1c-0f3e-48c6-afcf-9d791fd54bb9"
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
" ...,\n",
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.embedding.weight.data.copy_(pretrained_embedding)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "DTwNU41WseMS"
},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Rxlx7a72s1ze"
},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "1CLimBxus2yX"
},
"outputs": [],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "108fm55ftBgO"
},
"outputs": [],
"source": [
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "IYCxbvXUvE5v"
},
"outputs": [],
"source": [
"def calculate_accuracy(predictions, labels):\n",
" top_predictions = predictions.argmax(1, keepdim = True)\n",
" correct = top_predictions.eq(labels.view_as(top_predictions)).sum()\n",
" accuracy = correct.float() / labels.shape[0]\n",
" return accuracy"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ik2JQo6TvGml"
},
"outputs": [],
"source": [
"def train(model, iterator, optimizer, criterion, device):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.train()\n",
" \n",
" for labels, text, lengths in iterator:\n",
" \n",
" labels = labels.to(device)\n",
" text = text.to(device)\n",
"\n",
" optimizer.zero_grad()\n",
" \n",
" predictions = model(text, lengths)\n",
" \n",
" loss = criterion(predictions, labels)\n",
" \n",
" acc = calculate_accuracy(predictions, labels)\n",
" \n",
" loss.backward()\n",
" \n",
" optimizer.step()\n",
" \n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
"\n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "aGy1Zk6jvIf8"
},
"outputs": [],
"source": [
"def evaluate(model, iterator, criterion, device):\n",
" \n",
" epoch_loss = 0\n",
" epoch_acc = 0\n",
" \n",
" model.eval()\n",
" \n",
" with torch.no_grad():\n",
" \n",
" for labels, text, lengths in iterator:\n",
"\n",
" labels = labels.to(device)\n",
" text = text.to(device)\n",
" \n",
" predictions = model(text, lengths)\n",
" \n",
" loss = criterion(predictions, labels)\n",
" \n",
" acc = calculate_accuracy(predictions, labels)\n",
"\n",
" epoch_loss += loss.item()\n",
" epoch_acc += acc.item()\n",
" \n",
" return epoch_loss / len(iterator), epoch_acc / len(iterator)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9MyMRRzbvKPx"
},
"outputs": [],
"source": [
"def epoch_time(start_time, end_time):\n",
" elapsed_time = end_time - start_time\n",
" elapsed_mins = int(elapsed_time / 60)\n",
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
" return elapsed_mins, elapsed_secs"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 537
},
"colab_type": "code",
"id": "dRKwD51WvMa3",
"outputId": "79389e66-c1bf-45c9-a919-63ee787ad660"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.668 | Train Acc: 59.01%\n",
"\t Val. Loss: 0.652 | Val. Acc: 62.33%\n",
"Epoch: 02 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.602 | Train Acc: 67.75%\n",
"\t Val. Loss: 0.478 | Val. Acc: 77.18%\n",
"Epoch: 03 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.497 | Train Acc: 76.59%\n",
"\t Val. Loss: 0.478 | Val. Acc: 80.62%\n",
"Epoch: 04 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.456 | Train Acc: 79.24%\n",
"\t Val. Loss: 0.397 | Val. Acc: 83.30%\n",
"Epoch: 05 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.391 | Train Acc: 82.72%\n",
"\t Val. Loss: 0.344 | Val. Acc: 85.23%\n",
"Epoch: 06 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.345 | Train Acc: 85.30%\n",
"\t Val. Loss: 0.350 | Val. Acc: 86.12%\n",
"Epoch: 07 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.314 | Train Acc: 86.75%\n",
"\t Val. Loss: 0.310 | Val. Acc: 87.50%\n",
"Epoch: 08 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.266 | Train Acc: 89.54%\n",
"\t Val. Loss: 0.315 | Val. Acc: 88.16%\n",
"Epoch: 09 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.247 | Train Acc: 90.21%\n",
"\t Val. Loss: 0.285 | Val. Acc: 89.02%\n",
"Epoch: 10 | Epoch Time: 0m 24s\n",
"\tTrain Loss: 0.217 | Train Acc: 91.79%\n",
"\t Val. Loss: 0.282 | Val. Acc: 88.98%\n"
]
}
],
"source": [
"n_epochs = 10\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(n_epochs):\n",
"\n",
" start_time = time.monotonic()\n",
" \n",
" train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)\n",
" valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)\n",
" \n",
" end_time = time.monotonic()\n",
"\n",
" epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
" \n",
" if valid_loss < best_valid_loss:\n",
" best_valid_loss = valid_loss\n",
" torch.save(model.state_dict(), 'bilstm-model.pt')\n",
" \n",
" print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
" print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
" print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "hKOg4oARvPHJ",
"outputId": "7cfe4b85-de2f-47f3-8437-45589c32ceca"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.294 | Test Acc: 87.95%\n"
]
}
],
"source": [
"model.load_state_dict(torch.load('bilstm-model.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion, device)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "tQ4Jsf_vvWgB"
},
"outputs": [],
"source": [
"def predict_sentiment(tokenizer, vocab, model, device, sentence):\n",
" model.eval()\n",
" tokens = tokenizer.tokenize(sentence)\n",
" length = torch.LongTensor([len(tokens)]).to(device)\n",
" indexes = [vocab.stoi[token] for token in tokens]\n",
" tensor = torch.LongTensor(indexes).unsqueeze(-1).to(device)\n",
" prediction = model(tensor, length)\n",
" probabilities = nn.functional.softmax(prediction, dim = -1)\n",
" pos_probability = probabilities.squeeze()[-1].item()\n",
" return pos_probability"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "Yy7_6rhovZTE",
"outputId": "78860852-39ea-4a7b-eb33-9a1a077fb9e0"
},
"outputs": [
{
"data": {
"text/plain": [
"0.008071469143033028"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = 'the absolute worst movie of all time.'\n",
"\n",
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "L3LmQxrgvau9",
"outputId": "0204aa17-0bc1-45f2-9be1-c014798af120"
},
"outputs": [
{
"data": {
"text/plain": [
"0.9896865487098694"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
"\n",
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "t7Qoy21Bvb7v",
"outputId": "6094a141-4f37-4110-edc7-aa14b9a3c667"
},
"outputs": [
{
"data": {
"text/plain": [
"0.029767075553536415"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
"but it was actually the absolute worst movie of all time.\"\n",
"\n",
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"colab_type": "code",
"id": "EPGXBr18vdQT",
"outputId": "e5b3d210-0254-4d5f-bdbe-609c0b7d6a8a"
},
"outputs": [
{
"data": {
"text/plain": [
"0.5513127446174622"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sentence = \"i thought it was going to be the absolute worst movie of all time, \\\n",
"but it was actually one of the greatest films i have ever seen in my life.\"\n",
"\n",
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"machine_shape": "hm",
"name": "scratchpad",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}