pytorch-sentiment-analysis/1_nbow.ipynb

2021-07-08 01:04:25 +08:00
{
"cells": [
{
"cell_type": "markdown",
"id": "a36e41e8",
"metadata": {},
"source": [
"# 1 - NBoW\n",
"\n",
"In this series we'll be building a machine learning model to perform sentiment analysis -- a subset of text classification where the task is to detect if a given sentence is positive or negative -- using [PyTorch](https://github.com/pytorch/pytorch) and [torchtext](https://github.com/pytorch/text). The dataset used will be movie reviews from the [IMDb dataset](http://ai.stanford.edu/~amaas/data/sentiment/), which we'll obtain using the [datasets](https://github.com/huggingface/datasets) library.\n",
"\n",
"## Introduction\n",
"\n",
"In this first notebook, we'll start very simple with one of the most basic models for *NLP* (natural language processing): an *NBoW* (*neural bag-of-words*) model (also known as a *continuous bag-of-words*, *CBoW*, model). NBoW models are strong, commonly used baseline models for NLP tasks, and should be one of the first models you implement when performing sentiment analysis or other text classification tasks.\n",
"\n",
"![](assets/nbow_model.png)\n",
"\n",
"An NBoW model takes in a sequence of $T$ *tokens*, $X=\\{x_1,...,x_T\\} \\in \\mathbb{Z}^T$, and passes each token through an *embedding layer* to obtain a sequence of *embedding vectors*. This sequence of embedding vectors is known as an *embedding*, $E=\\{e_1,...,e_T\\} \\in \\mathbb{R}^{T \\times D}$, where $D$ is known as the *embedding dimension*. The model then *pools* the embeddings across the sequence dimension to get $P \\in \\mathbb{R}^D$, and finally passes $P$ through a linear layer (also known as a fully connected layer) to get a prediction, $\\hat{Y} \\in \\mathbb{R}^C$, where $C$ is the number of classes. We'll explain what a token is, and what each of the layers -- embedding layer, pooling, and linear layer -- does, in due course.\n",
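"\n",
"As a rough sketch of the shapes involved, the forward pass below uses `nn.Embedding`, a mean over the sequence dimension as the pooling step (an assumption at this point, since pooling is explained later), and `nn.Linear`. The sizes are hypothetical and chosen only for illustration.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Hypothetical sizes, for illustration only\n",
"vocab_size, embedding_dim, output_dim = 100, 8, 2\n",
"seq_len = 5  # T\n",
"\n",
"embedding = nn.Embedding(vocab_size, embedding_dim)\n",
"fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
"x = torch.randint(0, vocab_size, (seq_len,))  # X: T token indices (int64, a LongTensor)\n",
"e = embedding(x)                              # E: [T, D]\n",
"p = e.mean(dim=0)                             # P: [D], mean pooling over the sequence\n",
"y_hat = fc(p)                                 # Y_hat: [C], one score per class\n",
"print(x.dtype, e.shape, p.shape, y_hat.shape)\n",
"```\n",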
"\n",
"A note on notation: what does something like $E=\\{e_1,...,e_T\\} \\in \\mathbb{R}^{T \\times D}$ mean? $\\mathbb{R}^{T \\times D}$ denotes a $T \\times D$ tensor of real numbers, i.e. a `torch.FloatTensor`. Similarly, $X=\\{x_1,...,x_T\\} \\in \\mathbb{Z}^T$ is a tensor of $T$ integers, i.e. a `torch.LongTensor`.\n",
"\n",
"## Preparing Data\n",
"\n",
"Before we can implement our NBoW model, we first have to perform quite a few steps to get our data ready to use. NLP usually requires quite a lot of data wrangling beforehand, though libraries such as `datasets` and `torchtext` handle most of this for us.\n",
"\n",
"The steps to take are:\n",
"- importing modules\n",
"- loading data\n",
"- tokenizing data\n",
"- creating data splits\n",
"- creating a vocabulary\n",
"- numericalizing data\n",
"- creating the dataloaders\n",
"\n",
"### Importing Modules\n",
"\n",
"First, we'll import the required modules. \n",
"\n",
"We use the `datasets` module for handling datasets, `matplotlib` for plotting our results, `numpy` for numerical analysis, `torch` for tensor computations, `torch.nn` for neural networks, `torch.optim` for neural network optimizers, `torchtext` for text processing, and `tqdm` for progress bars."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e322bd29",
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"import sys\n",
"\n",
"import datasets\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchtext\n",
"import tqdm"
]
},
{
"cell_type": "markdown",
"id": "a5478fc3",
"metadata": {},
"source": [
"We'll also set the random seeds for `torch` and `numpy`. This makes the notebook reproducible, i.e. we should get the same results each time we run it.\n",
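"\n",
"As a quick check of what seeding buys us, this sketch (using a hypothetical `sample` helper) re-seeds both libraries and draws random numbers twice. On the CPU the two draws should match exactly.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import torch\n",
"\n",
"def sample(seed):\n",
"    # Hypothetical helper: seed both libraries, then draw some random numbers\n",
"    torch.manual_seed(seed)\n",
"    np.random.seed(seed)\n",
"    return torch.rand(3), np.random.rand(3)\n",
"\n",
"t1, n1 = sample(0)\n",
"t2, n2 = sample(0)\n",
"print(torch.equal(t1, t2), np.array_equal(n1, n2))  # True True\n",
"```\n",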
"\n",
"It is usually good practice to run your experiments multiple times with different random seeds -- both to measure the variance of your model and also to avoid having results only calculated with either \"good\" or \"bad\" seeds, i.e. being very lucky or unlucky with the randomness in the training process."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fcc98ce9",
"metadata": {},
"outputs": [],
"source": [
"seed = 0\n",
"\n",
"torch.manual_seed(seed)\n",
"np.random.seed(seed)"
]
},
{
"cell_type": "markdown",
"id": "55b1eb74",
"metadata": {},
"source": [
"Next, we'll load our dataset using the `datasets` library. The first argument is the name of the dataset and the `split` argument chooses which *splits* of the data we want. \n",
"\n",
"Datasets usually come in two or more *splits*: non-overlapping subsets of the examples. The most common are a *train split*, which we train our model on, and a *test split*, which we evaluate our trained model on. Some datasets also come with a *validation split*, which we'll talk more about later. The train, test and validation splits are also commonly called the train, test and validation sets -- we'll use split and set interchangeably in these tutorials -- and the dataset usually refers to all of the sets combined. The IMDb dataset also comes with an extra split, called *unsupervised*, which contains examples without labels. We don't want these, so we don't include them in our `split` argument. Note that if we didn't pass a `split` argument, all available splits of the data would be loaded.\n",
"\n",
"How do we know that we have to use \"imdb\" for the IMDb dataset and that there's an \"unsupervised\" split? The `datasets` library has a great website used to browse the available datasets, see: https://huggingface.co/datasets/. By navigating to the [IMDb dataset page](https://huggingface.co/datasets/imdb) we can see more information specifically about the IMDb dataset.\n",
"\n",
"The output received when loading the dataset tells us that it is using a locally cached version instead of downloading the dataset from online."
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "798f5387",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset imdb (/home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)\n"
]
}
],
"source": [
"train_data, test_data = datasets.load_dataset(\"imdb\", split=[\"train\", \"test\"])"
]
},
{
"cell_type": "markdown",
"id": "93721296",
"metadata": {},
"source": [
"We can print out the splits, which shows us the *features* and *num_rows* of each. *num_rows* is the number of examples in a split; as we can see, there are 25,000 examples in each. Each example in a dataset provided by the `datasets` library is a dictionary, and the features are the keys which appear in every one of those dictionaries/examples. So, each example in the IMDb dataset has a *text* and a *label* key."
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "42338609",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Dataset({\n",
" features: ['label', 'text'],\n",
" num_rows: 25000\n",
" }),\n",
" Dataset({\n",
" features: ['label', 'text'],\n",
" num_rows: 25000\n",
" }))"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data, test_data"
]
},
{
"cell_type": "markdown",
"id": "8ec70556",
"metadata": {},
"source": [
"We can check the `features` attribute of a split to get more information about the features. We can see that *text* is a `Value` of `dtype=string` -- in other words, it's a string -- and that *label* is a `ClassLabel`. A `ClassLabel` means the feature is an integer representation of which class the example belongs to. `num_classes=2` means that our labels are one of two values, 0 or 1, and `names=['neg', 'pos']` gives us the human-readable versions of those values. Thus, a label of 0 means the example is a negative review and a label of 1 means the example is a positive review."
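, "\n",
"\n",
"If you want to convert between the two representations yourself, `ClassLabel` provides `int2str` and `str2int`. A minimal sketch, which builds an equivalent feature directly rather than reading it from `train_data.features`:\n",
"\n",
"```python\n",
"import datasets\n",
"\n",
"# Same class names as the IMDb label feature\n",
"label_feature = datasets.ClassLabel(names=['neg', 'pos'])\n",
"\n",
"print(label_feature.num_classes)     # 2\n",
"print(label_feature.int2str(1))      # pos\n",
"print(label_feature.str2int('neg'))  # 0\n",
"```"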
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "58f5cc56",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': ClassLabel(num_classes=2, names=['neg', 'pos'], names_file=None, id=None),\n",
" 'text': Value(dtype='string', id=None)}"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.features"
]
},
{
"cell_type": "markdown",
"id": "84271369",
"metadata": {},
"source": [
"We can look at an example by indexing into the train set. As we can see, the text is quite noisy and also rambles on quite a bit."
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "25a6e8cb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': 1,\n",
" 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as \"Teachers\". My 35 years in the teaching profession lead me to believe that Bromwell High\\'s satire is much closer to reality than is \"Teachers\". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\\'t!'}"
]
},
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "markdown",
"id": "f8536207",
"metadata": {},
"source": [
"One of the first things we need to do to our data is *tokenize* it. Machine learning models aren't designed to handle strings; they're designed to handle numbers. So what we need to do is break our strings down into individual *tokens*, and then convert these tokens to numbers. We'll get to the conversion later, but first we'll look at *tokenization*.\n",
"\n",
"Tokenization involves using a *tokenizer* to process the strings in our dataset. A tokenizer is a function that goes from a string to a list of strings. There are many types of tokenizers available, but we're going to use a relatively simple one provided by `torchtext` called the `basic_english` tokenizer. We load our tokenizer as follows:"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "3017c0ab",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = torchtext.data.utils.get_tokenizer(\"basic_english\")"
]
},
{
"cell_type": "markdown",
"id": "4db58859",
"metadata": {},
"source": [
"We can use the tokenizer by calling it on a string.\n",
"\n",
"Notice that it splits the string on spaces, puts punctuation such as `!` and `?` into tokens of its own, and lowercases every word. Contractions are split too, e.g. `I'm` becomes the three tokens `i`, `'` and `m`.\n",
"\n",
"The `get_tokenizer` function also supports other tokenizers, such as ones provided by [spaCy](https://spacy.io/) and [nltk](https://www.nltk.org/)."
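, "\n",
"\n",
"To make this behaviour concrete, here is a rough approximation in plain Python. It is only a sketch: the hypothetical `simple_tokenize` is not how `basic_english` is implemented and will differ on some inputs (for example, text containing double quotes or hyphens), but it gives the same tokens for simple sentences like the one below.\n",
"\n",
"```python\n",
"import string\n",
"\n",
"def simple_tokenize(text):\n",
"    # Hypothetical approximation: lowercase, surround punctuation with spaces, split on whitespace\n",
"    text = text.lower()\n",
"    for p in string.punctuation:\n",
"        text = text.replace(p, f' {p} ')\n",
"    return text.split()\n",
"\n",
"print(simple_tokenize('Hello world! How are you doing today?'))\n",
"# ['hello', 'world', '!', 'how', 'are', 'you', 'doing', 'today', '?']\n",
"```"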
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "2d0de969",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['hello',\n",
" 'world',\n",
" '!',\n",
" 'how',\n",
" 'are',\n",
" 'you',\n",
" 'doing',\n",
" 'today',\n",
" '?',\n",
" 'i',\n",
" \"'\",\n",
" 'm',\n",
" 'doing',\n",
" 'fantastic',\n",
" '!']"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer(\"Hello world! How are you doing today? I'm doing fantastic!\")"
]
},
{
"cell_type": "markdown",
"id": "593711b9",
"metadata": {},
"source": [
"Now that our tokenizer is defined, we can tokenize our data.\n",
"\n",
"Each dataset provided by the `datasets` library is an instance of a `Dataset` class. We can see all the methods in a `Dataset` [here](https://huggingface.co/docs/datasets/package_reference/main_classes.html#dataset), but the main one we are interested in is [`map`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map). By using `map` we can apply a function to every example in the dataset and either update the example or create a new feature.\n",
"\n",
"We define the `tokenize_example` function below, which takes in an `example`, a `tokenizer` and a `max_length` argument. It tokenizes the text in the example, given by `example['text']`, trims the tokens to a maximum length and then returns a dictionary with the new feature name and feature value for that example. Note that the first argument to a function which we are going to `map` must always be the example dictionary, and the function must always return a dictionary whose keys are the feature names and whose values are the feature values to be added to this example.\n",
"\n",
"We're trimming the tokens to a maximum length here as some examples are unnecessarily long, and we can predict sentiment pretty well using just the first couple of hundred tokens -- though this might not be true for you if you're using a different dataset!"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "876ad3b9",
"metadata": {},
"outputs": [],
"source": [
"def tokenize_example(example, tokenizer, max_length):\n",
" tokens = tokenizer(example['text'])[:max_length]\n",
" return {'tokens': tokens}"
]
},
{
"cell_type": "markdown",
"id": "35129a1b",
"metadata": {},
"source": [
"We apply the `tokenize_example` function below to both the train and test sets. Any arguments to the function, other than the example itself, need to be passed in the `fn_kwargs` dictionary, with the keys being the argument names and the values being the values passed to those arguments.\n",
"\n",
"Operations on a `Dataset` are **not** performed in-place. You should always return the result into a new variable.\n",
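"\n",
"To see what `map` does with `fn_kwargs` without loading any data, the sketch below applies `tokenize_example` to two made-up examples stored in a plain list. It uses `str.split` as a stand-in tokenizer (an assumption for illustration only) and shows that binding the same keyword arguments with `functools.partial` gives an equivalent single-argument function. Like `map`, it builds new dictionaries rather than modifying `examples`.\n",
"\n",
"```python\n",
"import functools\n",
"\n",
"def tokenize_example(example, tokenizer, max_length):\n",
"    tokens = tokenizer(example['text'])[:max_length]\n",
"    return {'tokens': tokens}\n",
"\n",
"# Made-up examples; str.split stands in for the torchtext tokenizer\n",
"examples = [{'text': 'a great film , truly great .', 'label': 1},\n",
"            {'text': 'dull and far too long', 'label': 0}]\n",
"fn_kwargs = {'tokenizer': str.split, 'max_length': 4}\n",
"\n",
"# Roughly what map does: call the function with fn_kwargs, then merge the returned features\n",
"mapped = [{**ex, **tokenize_example(ex, **fn_kwargs)} for ex in examples]\n",
"\n",
"# functools.partial binds the same keyword arguments ahead of time\n",
"tokenize = functools.partial(tokenize_example, **fn_kwargs)\n",
"\n",
"print(mapped[1]['tokens'])      # ['dull', 'and', 'far', 'too']\n",
"print(tokenize(examples[0]))    # {'tokens': ['a', 'great', 'film', ',']}\n",
"print('tokens' in examples[0])  # False\n",
"```\n",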
"\n",
"Note the messages showing that, as I have performed this `map` before, the results are cached and are loaded from the cache instead of being calculated again."
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "5e295030",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-ad1b7a77180a232c.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-01c0069c185da175.arrow\n"
]
}
],
"source": [
"max_length = 256\n",
"\n",
"train_data = train_data.map(tokenize_example, fn_kwargs={'tokenizer': tokenizer, 'max_length': max_length})\n",
"test_data = test_data.map(tokenize_example, fn_kwargs={'tokenizer': tokenizer, 'max_length': max_length})"
]
},
{
"cell_type": "markdown",
"id": "a61b38c0",
"metadata": {},
"source": [
"We can now see that our `train_data` has a *tokens* feature -- as \"tokens\" was a key in the dictionary returned by the function we used for the `map`."
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "f647bdf9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['label', 'text', 'tokens'],\n",
" num_rows: 25000\n",
"})"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data"
]
},
{
"cell_type": "markdown",
"id": "db3443a0",
"metadata": {},
"source": [
"By looking at the `features` attribute we can see that information about the *tokens* feature has been added automatically: each example's tokens are a sequence (a list) of strings. `length=-1` means the sequences have variable length, i.e. not all of our token sequences are the same length."
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "1605d52b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': ClassLabel(num_classes=2, names=['neg', 'pos'], names_file=None, id=None),\n",
" 'text': Value(dtype='string', id=None),\n",
" 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data.features"
]
},
{
"cell_type": "markdown",
"id": "1735d91a",
"metadata": {},
"source": [
"We can check the first example in our train set to see the result of the tokenization:"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "2f3de3b9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['bromwell',\n",
" 'high',\n",
" 'is',\n",
" 'a',\n",
" 'cartoon',\n",
" 'comedy',\n",
" '.',\n",
" 'it',\n",
" 'ran',\n",
" 'at',\n",
" 'the',\n",
" 'same',\n",
" 'time',\n",
" 'as',\n",
" 'some',\n",
" 'other',\n",
" 'programs',\n",
" 'about',\n",
" 'school',\n",
" 'life',\n",
" ',',\n",
" 'such',\n",
" 'as',\n",
" 'teachers',\n",
" '.']"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]['tokens'][:25]"
]
},
{
"cell_type": "markdown",
"id": "04d4ee14",
"metadata": {},
"source": [
"Next up, we'll create a *validation set* from our data. This is similar to our test set in that we do not train our model on it; we only evaluate our model on it.\n",
"\n",
"Why have both a validation set and a test set? Your test set represents the real-world data that you'd see if you actually deployed this model. You won't be able to see what data your model will be fed once deployed, and your test set is supposed to reflect that. Every time we tune our model hyperparameters or training set-up to make it do a bit better on the test set, we leak information from the test set into the training process. If we do this too often, we begin to overfit to the test set. Hence, we need some data which can act as a \"proxy\" test set, which we can look at more frequently in order to evaluate how well our model actually does on unseen data -- this is the validation set.\n",
"\n",
"We can split a `Dataset` using the `train_test_split` method, which splits a dataset in two and returns a `DatasetDict` containing two splits, one called `train` and another called `test` -- a bit confusing, because here these are our train and validation sets, not the test set. We use `test_size` to set the proportion of the data used for the validation set -- 0.25 means we use 25% of the training set -- and the examples are chosen randomly."
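, "\n",
"\n",
"The split sizes follow from simple arithmetic: 25% of 25,000 is 6,250 validation examples, leaving 18,750 for training. The sketch below reproduces those numbers and shows a random, non-overlapping split of example indices with NumPy. It is not how `train_test_split` is implemented internally, and rounding the held-out size up is an assumption that only matters when the product isn't a whole number.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"import numpy as np\n",
"\n",
"n_examples = 25_000\n",
"test_size = 0.25\n",
"\n",
"# Assumed rounding: round the held-out size up, the rest goes to training\n",
"n_valid = math.ceil(test_size * n_examples)\n",
"n_train = n_examples - n_valid\n",
"\n",
"# Shuffle the example indices, then cut them into two non-overlapping parts\n",
"rng = np.random.default_rng(0)\n",
"perm = rng.permutation(n_examples)\n",
"train_idx, valid_idx = perm[:n_train], perm[n_train:]\n",
"\n",
"print(n_train, n_valid)  # 18750 6250\n",
"```"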
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "15e48bfb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached split indices for dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-09bdb9cf28fcbb3c.arrow and /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-8e0a9e291c417a75.arrow\n"
]
}
],
"source": [
"test_size = 0.25\n",
"\n",
"train_valid_data = train_data.train_test_split(test_size=test_size)\n",
"train_data = train_valid_data['train']\n",
"valid_data = train_valid_data['test']"
]
},
{
"cell_type": "markdown",
"id": "870c829b",
"metadata": {},
"source": [
"By showing the lengths of each split within our dataset, we can see the 25,000 training examples have now been split into 18,750 training examples and 6,250 validation examples, with the original 25,000 test examples remaining untouched."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "c227e4fc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(18750, 6250, 25000)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_data), len(valid_data), len(test_data)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4865e94a",
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = ['<unk>', '<pad>']\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(train_data['tokens'],\n",
" min_freq=min_freq,\n",
" specials=special_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "123ceb33",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"21526"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d4ec89de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['<unk>', '<pad>', 'the', '.', ',', 'a', 'and', 'of', 'to', \"'\"]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab.get_itos()[:10]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "29ac49c8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unk_index = vocab['<unk>']\n",
"\n",
"unk_index"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "447020e1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pad_index = vocab['<pad>']\n",
"\n",
"pad_index"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "201b5383",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'some_token' in vocab"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7a951ea0",
"metadata": {},
"outputs": [],
"source": [
"vocab.set_default_index(unk_index)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "407fe05d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab['some_token']"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "76518d11",
"metadata": {},
"outputs": [],
"source": [
"def numericalize_data(example, vocab):\n",
" ids = [vocab[token] for token in example['tokens']]\n",
" return {'ids': ids}"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "dacaeaef",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-d266a0df023fa6e2.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-296fd4058bb43b50.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-6bc06b9661e3abbb.arrow\n"
]
}
],
"source": [
"train_data = train_data.map(numericalize_data, fn_kwargs={'vocab': vocab})\n",
"valid_data = valid_data.map(numericalize_data, fn_kwargs={'vocab': vocab})\n",
"test_data = test_data.map(numericalize_data, fn_kwargs={'vocab': vocab})"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "08751c45",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': 0,\n",
" 'text': 'This documentary is at its best when it is simply showing the ayurvedic healers\\' offices and treatment preparation. There is no denying the grinding poverty in India and desperation of even their wealthier clients. However, as an argument for ayurvedic medicine in general, this film fails miserably. Although Indian clients mention having seen \"aleopathic\" doctors, those doctors are not interviewed, and we have to take the vague statements of their patients at face value-- \"the doctor said there was no cure,\" \"the doctor said it was cancer\" etc. Well, \"no cure\" doesn\\'t mean \"no treatment,\" and what type of cancer exactly does the patient have? The film is at its most feeble when showing ayurvedic practice in America. There it is reduced, apparently, to the stunning suggestion that having a high powered Wall Street job can make your stomach hurt.',\n",
" 'tokens': ['this',\n",
" 'documentary',\n",
" 'is',\n",
" 'at',\n",
" 'its',\n",
" 'best',\n",
" 'when',\n",
" 'it',\n",
" 'is',\n",
" 'simply',\n",
" 'showing',\n",
" 'the',\n",
" 'ayurvedic',\n",
" 'healers',\n",
" \"'\",\n",
" 'offices',\n",
" 'and',\n",
" 'treatment',\n",
" 'preparation',\n",
" '.',\n",
" 'there',\n",
" 'is',\n",
" 'no',\n",
" 'denying',\n",
" 'the',\n",
" 'grinding',\n",
" 'poverty',\n",
" 'in',\n",
" 'india',\n",
" 'and',\n",
" 'desperation',\n",
" 'of',\n",
" 'even',\n",
" 'their',\n",
" 'wealthier',\n",
" 'clients',\n",
" '.',\n",
" 'however',\n",
" ',',\n",
" 'as',\n",
" 'an',\n",
" 'argument',\n",
" 'for',\n",
" 'ayurvedic',\n",
" 'medicine',\n",
" 'in',\n",
" 'general',\n",
" ',',\n",
" 'this',\n",
" 'film',\n",
" 'fails',\n",
" 'miserably',\n",
" '.',\n",
" 'although',\n",
" 'indian',\n",
" 'clients',\n",
" 'mention',\n",
" 'having',\n",
" 'seen',\n",
" 'aleopathic',\n",
" 'doctors',\n",
" ',',\n",
" 'those',\n",
" 'doctors',\n",
" 'are',\n",
" 'not',\n",
" 'interviewed',\n",
" ',',\n",
" 'and',\n",
" 'we',\n",
" 'have',\n",
" 'to',\n",
" 'take',\n",
" 'the',\n",
" 'vague',\n",
" 'statements',\n",
" 'of',\n",
" 'their',\n",
" 'patients',\n",
" 'at',\n",
" 'face',\n",
" 'value--',\n",
" 'the',\n",
" 'doctor',\n",
" 'said',\n",
" 'there',\n",
" 'was',\n",
" 'no',\n",
" 'cure',\n",
" ',',\n",
" 'the',\n",
" 'doctor',\n",
" 'said',\n",
" 'it',\n",
" 'was',\n",
" 'cancer',\n",
" 'etc',\n",
" '.',\n",
" 'well',\n",
" ',',\n",
" 'no',\n",
" 'cure',\n",
" 'doesn',\n",
" \"'\",\n",
" 't',\n",
" 'mean',\n",
" 'no',\n",
" 'treatment',\n",
" ',',\n",
" 'and',\n",
" 'what',\n",
" 'type',\n",
" 'of',\n",
" 'cancer',\n",
" 'exactly',\n",
" 'does',\n",
" 'the',\n",
" 'patient',\n",
" 'have',\n",
" '?',\n",
" 'the',\n",
" 'film',\n",
" 'is',\n",
" 'at',\n",
" 'its',\n",
" 'most',\n",
" 'feeble',\n",
" 'when',\n",
" 'showing',\n",
" 'ayurvedic',\n",
" 'practice',\n",
" 'in',\n",
" 'america',\n",
" '.',\n",
" 'there',\n",
" 'it',\n",
" 'is',\n",
" 'reduced',\n",
" ',',\n",
" 'apparently',\n",
" ',',\n",
" 'to',\n",
" 'the',\n",
" 'stunning',\n",
" 'suggestion',\n",
" 'that',\n",
" 'having',\n",
" 'a',\n",
" 'high',\n",
" 'powered',\n",
" 'wall',\n",
" 'street',\n",
" 'job',\n",
" 'can',\n",
" 'make',\n",
" 'your',\n",
" 'stomach',\n",
" 'hurt',\n",
" '.'],\n",
" 'ids': [14,\n",
" 627,\n",
" 10,\n",
" 37,\n",
" 100,\n",
" 125,\n",
" 60,\n",
" 11,\n",
" 10,\n",
" 361,\n",
" 834,\n",
" 2,\n",
" 0,\n",
" 0,\n",
" 9,\n",
" 12187,\n",
" 6,\n",
" 2407,\n",
" 9694,\n",
" 3,\n",
" 46,\n",
" 10,\n",
" 66,\n",
" 8861,\n",
" 2,\n",
" 16732,\n",
" 3705,\n",
" 13,\n",
" 2360,\n",
" 6,\n",
" 4374,\n",
" 7,\n",
" 69,\n",
" 77,\n",
" 0,\n",
" 13332,\n",
" 3,\n",
" 190,\n",
" 4,\n",
" 19,\n",
" 41,\n",
" 4597,\n",
" 21,\n",
" 0,\n",
" 6574,\n",
" 13,\n",
" 822,\n",
" 4,\n",
" 14,\n",
" 23,\n",
" 962,\n",
" 3426,\n",
" 3,\n",
" 265,\n",
" 1267,\n",
" 13332,\n",
" 798,\n",
" 266,\n",
" 111,\n",
" 0,\n",
" 5592,\n",
" 4,\n",
" 157,\n",
" 5592,\n",
" 30,\n",
" 29,\n",
" 8351,\n",
" 4,\n",
" 6,\n",
" 78,\n",
" 31,\n",
2021-07-08 01:04:25 +08:00
" 8,\n",
2021-07-16 05:41:50 +08:00
" 203,\n",
" 2,\n",
" 3400,\n",
" 6614,\n",
" 7,\n",
" 77,\n",
" 5229,\n",
" 37,\n",
" 454,\n",
" 0,\n",
" 2,\n",
" 937,\n",
" 307,\n",
" 46,\n",
" 17,\n",
" 66,\n",
" 4845,\n",
" 4,\n",
" 2,\n",
" 937,\n",
" 307,\n",
" 11,\n",
" 17,\n",
" 5362,\n",
" 487,\n",
" 3,\n",
" 82,\n",
" 4,\n",
" 66,\n",
" 4845,\n",
" 173,\n",
" 9,\n",
" 28,\n",
" 384,\n",
" 66,\n",
" 2407,\n",
" 4,\n",
" 6,\n",
" 55,\n",
" 618,\n",
" 7,\n",
" 5362,\n",
" 615,\n",
" 135,\n",
" 2,\n",
" 3307,\n",
" 31,\n",
" 56,\n",
" 2,\n",
" 23,\n",
" 10,\n",
" 37,\n",
" 100,\n",
" 94,\n",
" 6702,\n",
" 60,\n",
" 834,\n",
" 0,\n",
" 4335,\n",
" 13,\n",
" 865,\n",
" 3,\n",
" 46,\n",
" 11,\n",
" 10,\n",
" 4647,\n",
" 4,\n",
" 694,\n",
" 4,\n",
" 8,\n",
" 2,\n",
" 1253,\n",
" 5657,\n",
" 15,\n",
" 266,\n",
" 5,\n",
" 325,\n",
" 10526,\n",
" 1698,\n",
" 874,\n",
" 279,\n",
" 59,\n",
" 105,\n",
" 133,\n",
" 3035,\n",
" 1559,\n",
" 3]}"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "678d0397",
"metadata": {},
"outputs": [],
"source": [
"train_data = train_data.with_format(type='torch', columns=['ids', 'label'])\n",
"valid_data = valid_data.with_format(type='torch', columns=['ids', 'label'])\n",
"test_data = test_data.with_format(type='torch', columns=['ids', 'label'])"
]
},
{
"cell_type": "markdown",
"id": "00a00726",
"metadata": {},
"source": [
"`with_format` does the same thing as `set_format`, but returns a new dataset instead of modifying the existing one in place."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "be56bf90",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': tensor(0),\n",
" 'ids': tensor([ 14, 627, 10, 37, 100, 125, 60, 11, 10, 361,\n",
" 834, 2, 0, 0, 9, 12187, 6, 2407, 9694, 3,\n",
" 46, 10, 66, 8861, 2, 16732, 3705, 13, 2360, 6,\n",
" 4374, 7, 69, 77, 0, 13332, 3, 190, 4, 19,\n",
" 41, 4597, 21, 0, 6574, 13, 822, 4, 14, 23,\n",
" 962, 3426, 3, 265, 1267, 13332, 798, 266, 111, 0,\n",
" 5592, 4, 157, 5592, 30, 29, 8351, 4, 6, 78,\n",
" 31, 8, 203, 2, 3400, 6614, 7, 77, 5229, 37,\n",
" 454, 0, 2, 937, 307, 46, 17, 66, 4845, 4,\n",
" 2, 937, 307, 11, 17, 5362, 487, 3, 82, 4,\n",
" 66, 4845, 173, 9, 28, 384, 66, 2407, 4, 6,\n",
" 55, 618, 7, 5362, 615, 135, 2, 3307, 31, 56,\n",
" 2, 23, 10, 37, 100, 94, 6702, 60, 834, 0,\n",
" 4335, 13, 865, 3, 46, 11, 10, 4647, 4, 694,\n",
" 4, 8, 2, 1253, 5657, 15, 266, 5, 325, 10526,\n",
" 1698, 874, 279, 59, 105, 133, 3035, 1559, 3])}"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "d97786a1",
"metadata": {},
"outputs": [],
"source": [
"def collate(batch, pad_index):\n",
"    batch_ids = [i['ids'] for i in batch]\n",
"    batch_ids = nn.utils.rnn.pad_sequence(batch_ids, padding_value=pad_index, batch_first=True)\n",
"    batch_label = [i['label'] for i in batch]\n",
"    batch_label = torch.stack(batch_label)\n",
"    batch = {'ids': batch_ids,\n",
"             'label': batch_label}\n",
"    return batch"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "d3098a96",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 512\n",
"\n",
"collate = functools.partial(collate, pad_index=pad_index)\n",
"\n",
"train_dataloader = torch.utils.data.DataLoader(train_data,\n",
"                                               batch_size=batch_size,\n",
"                                               collate_fn=collate,\n",
"                                               shuffle=True)\n",
"\n",
"valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, collate_fn=collate)\n",
"test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collate)"
]
},
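{
"cell_type": "markdown",
"id": "7b2e91c4",
"metadata": {},
"source": [
"To see what `collate` produces, the sketch below applies the same logic to a toy batch of three hand-made examples of different lengths. The token ids and `pad_index=1` are assumptions for illustration. In the notebook, the examples come from the `DataLoader` and `pad_index` comes from the vocabulary. `pad_sequence` appends `pad_index` to each shorter sequence until it matches the longest one, which gives an `ids` tensor of shape `[batch size, max seq len]`. The scalar labels are stacked into a `[batch size]` tensor.\n",
"\n",
"```python\n",
"import functools\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"def collate(batch, pad_index):\n",
"    # Same logic as the notebook's collate function.\n",
"    batch_ids = [i['ids'] for i in batch]\n",
"    batch_ids = nn.utils.rnn.pad_sequence(batch_ids, padding_value=pad_index, batch_first=True)\n",
"    batch_label = torch.stack([i['label'] for i in batch])\n",
"    return {'ids': batch_ids, 'label': batch_label}\n",
"\n",
"\n",
"# A toy batch of three numericalized reviews of different lengths.\n",
"toy_batch = [\n",
"    {'ids': torch.tensor([14, 627, 10]), 'label': torch.tensor(0)},\n",
"    {'ids': torch.tensor([9, 3]), 'label': torch.tensor(1)},\n",
"    {'ids': torch.tensor([46, 10, 66, 8861]), 'label': torch.tensor(1)},\n",
"]\n",
"\n",
"# pad_index=1 is an assumption for this sketch.\n",
"toy_collate = functools.partial(collate, pad_index=1)\n",
"out = toy_collate(toy_batch)\n",
"print(out['ids'].tolist())    # [[14, 627, 10, 1], [9, 3, 1, 1], [46, 10, 66, 8861]]\n",
"print(out['label'].tolist())  # [0, 1, 1]\n",
"```\n",
"\n",
"`functools.partial` fixes `pad_index` ahead of time. This is why the notebook can pass the partially applied `collate` directly to `collate_fn`: the `DataLoader` only ever calls it with the list of examples."
]
},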
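{
"cell_type": "markdown",
"id": "d6ba2ac8",
"metadata": {},
"source": [
"Note: passing `output_all_columns=True` to `with_format` (or `set_format`) keeps the columns that were not converted to tensors, such as `tokens`, and returns them as regular Python objects."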
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "081f04a6",
"metadata": {},
"outputs": [],
"source": [
"class NBoW(nn.Module):\n",
"    def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
"        super().__init__()\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
"        self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
"    def forward(self, ids):\n",
"        # ids = [batch size, seq len]\n",
"        embedded = self.embedding(ids)\n",
"        # embedded = [batch size, seq len, embedding dim]\n",
"        pooled = embedded.mean(dim=1)\n",
"        # pooled = [batch size, embedding dim]\n",
"        prediction = self.fc(pooled)\n",
"        # prediction = [batch size, output dim]\n",
"        return prediction"
]
},
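{
"cell_type": "markdown",
"id": "9d4f6a1e",
"metadata": {},
"source": [
"`NBoW.forward` averages over every position of the padded batch, including the `<pad>` tokens. In this notebook the `<pad>` embedding stays all zeros: the FastText lookup returns a zero vector for `<pad>`, and `padding_idx` keeps that row's gradient at zero. Padding therefore adds nothing to the sum, but it still counts toward the divisor. As a result, a short review in a batch of long ones ends up with a smaller pooled vector than the same review on its own. The sketch below is a hypothetical variant, `MaskedNBoW`, that divides by the number of non-pad tokens instead. It is not the model this notebook trains, and whether it changes accuracy here is untested.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class MaskedNBoW(nn.Module):\n",
"    # Hypothetical variant of NBoW that averages only over non-pad tokens.\n",
"    def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
"        super().__init__()\n",
"        self.pad_index = pad_index\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
"        self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
"    def forward(self, ids):\n",
"        # ids = [batch size, seq len]\n",
"        mask = (ids != self.pad_index).unsqueeze(-1).float()\n",
"        # mask = [batch size, seq len, 1]; 1.0 for real tokens, 0.0 for padding\n",
"        embedded = self.embedding(ids) * mask\n",
"        # embedded = [batch size, seq len, embedding dim]\n",
"        lengths = mask.sum(dim=1).clamp(min=1.0)\n",
"        # lengths = [batch size, 1]; clamped so an all-pad row does not divide by zero\n",
"        pooled = embedded.sum(dim=1) / lengths\n",
"        # pooled = [batch size, embedding dim]\n",
"        return self.fc(pooled)\n",
"\n",
"\n",
"# Padding a sequence should leave its prediction unchanged.\n",
"torch.manual_seed(0)\n",
"toy_model = MaskedNBoW(vocab_size=10, embedding_dim=4, output_dim=2, pad_index=1)\n",
"short = torch.tensor([[5, 7]])\n",
"padded = torch.tensor([[5, 7, 1, 1, 1]])\n",
"print(torch.allclose(toy_model(short), toy_model(padded)))  # True\n",
"```\n",
"\n",
"The final check confirms the property the mask is meant to provide: appending `<pad>` tokens to a sequence does not change its output. With the notebook's plain `mean`, the padded version would be averaged over five positions instead of two."
]
},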
{
"cell_type": "code",
"execution_count": 31,
"id": "97897898",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = len(vocab)\n",
"embedding_dim = 300\n",
"output_dim = len(train_data.unique('label'))\n",
"\n",
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "4acc5118",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 6,458,402 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
"    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
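{
"cell_type": "markdown",
"id": "e4c8a0f2",
"metadata": {},
"source": [
"As a sanity check, the count above can be reproduced by hand. Assume a vocabulary of $|V| = 21{,}526$ tokens (the first dimension of `pretrained_embedding` further down), $D = 300$, and $C = 2$ classes. The embedding layer holds $|V| \\times D$ weights, and the linear layer holds $D \\times C$ weights plus $C$ biases:\n",
"\n",
"$$|V| \\times D + D \\times C + C = 21{,}526 \\times 300 + 300 \\times 2 + 2 = 6{,}457{,}800 + 600 + 2 = 6{,}458{,}402$$\n",
"\n",
"The `<pad>` row is included in this total. `padding_idx` means that row receives no gradient, but it is still part of `model.parameters()`."
]
},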
{
"cell_type": "code",
"execution_count": 33,
"id": "866e0b64",
"metadata": {},
"outputs": [],
"source": [
"vectors = torchtext.vocab.FastText()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "ead7be53",
"metadata": {},
"outputs": [],
"source": [
"hello_vector = vectors.get_vecs_by_tokens('hello')"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "1a64ead7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([300])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hello_vector.shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "7ecc5d88",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-1.5945e-01, -1.8259e-01, 3.3443e-02, 1.8813e-01, -6.7903e-02,\n",
" -1.3663e-01, -2.5559e-01, 1.1000e-01, 1.7275e-01, 5.1971e-02,\n",
" -2.3302e-02, 3.8866e-02, -2.4515e-01, -2.1588e-01, 3.5925e-01,\n",
" -8.2526e-02, 1.2176e-01, -2.6775e-01, 1.0072e-01, -1.3639e-01,\n",
" -9.2658e-02, 5.1837e-01, 1.7736e-01, 9.4878e-02, -1.8461e-01,\n",
" -4.2829e-02, 1.4114e-02, 1.6811e-01, -1.8565e-01, 3.4976e-02,\n",
" -1.0293e-01, 1.7954e-01, -5.2766e-02, 7.2047e-02, -4.2704e-01,\n",
" -1.1616e-01, -9.4875e-03, 1.4199e-01, -2.2782e-01, -1.7292e-02,\n",
" 8.2802e-02, -4.4512e-01, -7.5935e-02, -1.4392e-01, -8.2461e-02,\n",
" 2.0123e-01, -9.5344e-02, -1.1042e-01, -4.6817e-01, 2.0362e-01,\n",
" -1.7140e-01, -4.9850e-01, 2.8963e-01, -1.0305e-01, 2.0393e-01,\n",
" 5.2971e-01, -2.5396e-01, -5.1891e-01, 2.9941e-01, 1.7933e-01,\n",
" 3.0683e-01, 2.5828e-01, -1.8168e-01, -1.0225e-01, -1.1435e-01,\n",
" -1.6304e-01, -1.2424e-01, 3.2814e-01, -2.3099e-01, 1.7912e-01,\n",
" 9.9206e-02, 1.8595e-01, 2.7996e-01, 1.8323e-01, -1.7397e-01,\n",
" 2.6633e-01, -1.8151e-02, 2.8386e-01, 1.7328e-01, 2.9131e-01,\n",
" 8.2289e-02, 1.8560e-01, -1.5544e-01, 2.3311e-01, 3.6578e-01,\n",
" -3.0802e-01, -1.5908e-01, 4.0382e-01, 1.5332e-01, -1.1630e-01,\n",
" 1.3978e-01, 6.4237e-02, 2.2087e-01, 8.2723e-02, 1.2785e-01,\n",
" -6.6854e-02, -2.3016e-02, -1.9224e-01, -5.4482e-02, 3.7509e-01,\n",
" 5.1194e-01, -2.3650e-01, -7.1224e-02, 8.1112e-02, -3.2017e-01,\n",
" 5.0264e-02, -3.3223e-01, 2.2167e-02, 9.9936e-02, -2.7215e-01,\n",
" -7.2833e-02, -3.6598e-01, 1.7541e-01, -3.1303e-01, -2.3134e-01,\n",
" -1.5491e-01, 3.2102e-01, 1.2347e-01, 7.3616e-02, 2.0575e-01,\n",
" 6.1732e-01, 7.1909e-02, -3.6930e-01, 4.7641e-01, 1.7456e-01,\n",
" 3.2928e-01, 2.8792e-01, -7.6989e-02, 2.7030e-01, 6.9828e-01,\n",
" 4.6247e-01, 4.1444e-01, -5.3405e-01, 4.4302e-01, 1.1631e-01,\n",
" -2.3425e-01, -1.5030e-01, -6.8092e-02, 3.3537e-01, 2.8618e-01,\n",
" -3.9781e-02, 2.3245e-01, 3.6262e-01, -1.7151e-01, -3.5204e-01,\n",
" 1.9951e-01, 1.1345e-01, -4.5134e-01, -3.9699e-03, -2.0620e-01,\n",
" -4.9251e-02, 1.0825e-01, 1.2571e-01, -2.8134e-01, 1.0355e-01,\n",
" 7.3498e-02, -2.6716e-01, -1.0001e-01, -2.2600e-01, 3.0784e-01,\n",
" 2.5934e-01, -1.8112e-03, -2.0522e-01, -2.5115e-01, -1.5368e-01,\n",
" 5.6060e-02, -6.4802e-02, 9.2786e-03, 2.6150e-01, -9.3972e-02,\n",
" -3.1032e-01, -2.6632e-01, -1.9598e-01, -4.5088e-02, -2.7611e-02,\n",
" -7.7027e-02, 1.5070e-01, 1.7185e-01, -8.5416e-02, -1.4448e-01,\n",
" -2.4800e-03, -3.2881e-01, -1.6913e-01, -1.2778e-01, -2.3352e-01,\n",
" 1.5178e-01, -6.9358e-01, -3.8922e-01, 3.7190e-01, 2.6020e-01,\n",
" -1.0232e-01, -6.0247e-01, -5.4548e-02, 6.6532e-01, -7.3208e-02,\n",
" -2.3644e-01, -2.5550e-01, 1.9755e-02, -4.8908e-01, -7.3706e-02,\n",
" 3.0545e-01, 2.4459e-01, 2.0426e-01, -3.0128e-01, 6.0666e-02,\n",
" 1.8107e-02, -9.6162e-02, -2.0348e-02, -1.9801e-04, 2.9652e-02,\n",
" 5.0787e-01, -2.0225e-01, -6.1565e-02, -2.7330e-01, -3.7789e-01,\n",
" -2.4373e-01, 9.4902e-02, -3.7236e-01, -8.5854e-02, 2.4096e-01,\n",
" -1.7998e-01, 7.3902e-02, -7.8217e-04, -1.8559e-01, -2.6445e-01,\n",
" -2.3306e-02, -1.8644e-01, -1.0638e-01, 8.9330e-02, 4.1039e-01,\n",
" 1.0452e-02, -9.8721e-03, -1.8335e-01, -2.8524e-01, -1.4771e-01,\n",
" -1.9499e-01, -1.0175e-01, 1.2292e-01, 8.3651e-02, -2.1228e-01,\n",
" 3.4773e-02, 6.1831e-02, 2.9237e-01, 1.4371e-01, -9.2354e-02,\n",
" 8.1267e-03, 2.7648e-01, 2.1753e-01, 2.6609e-01, -3.6083e-01,\n",
" 2.8347e-01, -2.9295e-01, -2.6441e-01, 2.1056e-01, 3.2068e-01,\n",
" -1.6156e-01, 1.5298e-01, -1.5577e-01, 2.2035e-01, -1.1888e-01,\n",
" 1.3766e-01, -9.9048e-02, 4.1584e-01, -3.6029e-02, -6.2504e-02,\n",
" 3.3177e-01, -1.3997e-01, 8.7884e-02, -2.1428e-01, -6.2643e-01,\n",
" -3.1293e-01, -3.4895e-01, 5.2294e-01, -1.2635e-01, -1.9371e-01,\n",
" -2.0631e-01, 5.3758e-01, -1.1522e-01, -2.3659e-01, 2.0457e-01,\n",
" 1.9534e-01, 3.3260e-01, -2.2254e-01, 8.1346e-02, -7.2798e-02,\n",
" -8.6357e-04, -1.0199e-01, 3.1601e-01, 2.0040e-01, 1.9014e-01,\n",
" -9.6766e-02, 2.5155e-01, -2.0484e-01, -4.5859e-01, 1.1687e-01,\n",
" -3.3574e-01, -3.3371e-01, 8.6787e-02, 2.4920e-01, 6.5367e-02])"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hello_vector"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "e8540b4b",
"metadata": {},
"outputs": [],
"source": [
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "9d31228e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([21526, 300])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_embedding.shape"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "3a6f4173",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[-1.1258, -1.1524, -0.2506, ..., 0.8200, -0.6332, 1.2948],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.1483, 2.4187, 1.3279, ..., -1.0328, 1.1305, -0.5703],\n",
" ...,\n",
" [-0.9497, -1.5705, -0.5629, ..., -0.5853, 0.1596, -1.3159],\n",
" [ 0.6322, -0.5610, 0.4423, ..., -0.5541, -0.5787, -0.6026],\n",
" [ 1.1698, 0.1340, 1.5503, ..., 1.5039, -0.6415, 1.1412]],\n",
" requires_grad=True)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.embedding.weight"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "5c1cbd5c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0653, -0.0930, -0.0176, ..., 0.1664, -0.1308, 0.0354],\n",
" ...,\n",
" [-0.1329, 0.2494, -0.3875, ..., 0.3734, 0.4520, -0.2060],\n",
" [-0.2301, -0.1799, -0.2485, ..., 0.5203, 0.6245, 0.1723],\n",
" [ 0.1161, -0.0390, 0.1120, ..., 0.0925, -0.1058, 0.5641]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_embedding"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "6ea34c9b",
"metadata": {},
"outputs": [],
"source": [
"model.embedding.weight.data = pretrained_embedding"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "1332d9a6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0653, -0.0930, -0.0176, ..., 0.1664, -0.1308, 0.0354],\n",
" ...,\n",
" [-0.1329, 0.2494, -0.3875, ..., 0.3734, 0.4520, -0.2060],\n",
" [-0.2301, -0.1799, -0.2485, ..., 0.5203, 0.6245, 0.1723],\n",
" [ 0.1161, -0.0390, 0.1120, ..., 0.0925, -0.1058, 0.5641]],\n",
" requires_grad=True)"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.embedding.weight"
]
},
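{
"cell_type": "markdown",
"id": "b6e3f8d0",
"metadata": {},
"source": [
"The FastText vectors are loaded above by assigning directly to `.data`. That works, but PyTorch offers two other common ways to load pretrained vectors, sketched below on a small made-up tensor. The 5-by-3 shape and `pad_index = 1` are assumptions, not the notebook's values. The first option builds the layer with `nn.Embedding.from_pretrained`. The second copies the vectors into an existing layer inside `torch.no_grad()`. Either way, the weights should end up identical to the pretrained tensor, as in the output above.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Made-up stand-in for pretrained_embedding: 5 tokens, 3 dimensions.\n",
"torch.manual_seed(0)\n",
"pretrained = torch.randn(5, 3)\n",
"pad_index = 1  # assumption for this sketch\n",
"pretrained[pad_index] = 0.0\n",
"\n",
"# Option 1: build the layer from the tensor; freeze=False keeps it trainable.\n",
"embedding = nn.Embedding.from_pretrained(pretrained, freeze=False, padding_idx=pad_index)\n",
"print(torch.equal(embedding.weight, pretrained))  # True\n",
"print(embedding.weight.requires_grad)  # True\n",
"\n",
"# Option 2: copy into an existing layer without recording the copy in autograd.\n",
"existing = nn.Embedding(5, 3, padding_idx=pad_index)\n",
"with torch.no_grad():\n",
"    existing.weight.copy_(pretrained)\n",
"print(torch.equal(existing.weight, pretrained))  # True\n",
"```\n",
"\n",
"`freeze=False` matches this notebook, where the embeddings are fine-tuned along with the linear layer. Leaving `freeze` at its default of `True` would keep the FastText vectors fixed."
]
},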
{
"cell_type": "code",
"execution_count": 43,
"id": "4fcb95e0",
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "f8829cd4",
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "7ed273e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "3cdaf3b3",
"metadata": {},
"outputs": [],
"source": [
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "729aa9c8",
"metadata": {},
"outputs": [],
"source": [
"def train(dataloader, model, criterion, optimizer, device):\n",
"\n",
"    model.train()\n",
"    epoch_losses = []\n",
"    epoch_accs = []\n",
"\n",
"    for batch in tqdm.tqdm(dataloader, desc='training...', file=sys.stdout):\n",
"        ids = batch['ids'].to(device)\n",
"        label = batch['label'].to(device)\n",
"        prediction = model(ids)\n",
"        loss = criterion(prediction, label)\n",
"        accuracy = get_accuracy(prediction, label)\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        epoch_losses.append(loss.item())\n",
"        epoch_accs.append(accuracy.item())\n",
"\n",
"    return epoch_losses, epoch_accs"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "e0a80c30",
"metadata": {},
"outputs": [],
"source": [
"def evaluate(dataloader, model, criterion, device):\n",
"\n",
"    model.eval()\n",
"    epoch_losses = []\n",
"    epoch_accs = []\n",
"\n",
"    with torch.no_grad():\n",
"        for batch in tqdm.tqdm(dataloader, desc='evaluating...', file=sys.stdout):\n",
"            ids = batch['ids'].to(device)\n",
"            label = batch['label'].to(device)\n",
"            prediction = model(ids)\n",
"            loss = criterion(prediction, label)\n",
"            accuracy = get_accuracy(prediction, label)\n",
"            epoch_losses.append(loss.item())\n",
"            epoch_accs.append(accuracy.item())\n",
"\n",
"    return epoch_losses, epoch_accs"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "703aa1e1",
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
"    batch_size, _ = prediction.shape\n",
"    predicted_classes = prediction.argmax(dim=-1)\n",
"    correct_predictions = predicted_classes.eq(label).sum()\n",
"    accuracy = correct_predictions / batch_size\n",
"    return accuracy"
]
},
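{
"cell_type": "markdown",
"id": "c3a7e5b9",
"metadata": {},
"source": [
"Here is a small worked example of `get_accuracy`, using made-up logits for a batch of four reviews. The function body is copied from the cell above. One consequence of this design: `train` and `evaluate` return per-batch accuracies that are later averaged with `np.mean`. A smaller final batch therefore counts as much as a full one, so the reported epoch accuracy can differ slightly from the exact per-example accuracy.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def get_accuracy(prediction, label):\n",
"    # Same logic as the notebook's get_accuracy.\n",
"    batch_size, _ = prediction.shape\n",
"    predicted_classes = prediction.argmax(dim=-1)\n",
"    correct_predictions = predicted_classes.eq(label).sum()\n",
"    return correct_predictions / batch_size\n",
"\n",
"\n",
"# Made-up logits for four reviews and two classes (0 = negative, 1 = positive).\n",
"prediction = torch.tensor([[2.0, -1.0],\n",
"                           [0.1, 0.3],\n",
"                           [-0.5, 1.5],\n",
"                           [1.2, 0.4]])\n",
"label = torch.tensor([0, 0, 1, 1])\n",
"\n",
"# argmax gives [0, 1, 1, 0], so two of the four predictions are correct.\n",
"print(get_accuracy(prediction, label).item())  # 0.5\n",
"```"
]
},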
{
"cell_type": "code",
"execution_count": 50,
"id": "31343f1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.43it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.76it/s]\n",
"epoch: 1\n",
"train_loss: 0.681, train_acc: 0.625\n",
"valid_loss: 0.666, valid_acc: 0.704\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.49it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 17.17it/s]\n",
"epoch: 2\n",
"train_loss: 0.645, train_acc: 0.721\n",
"valid_loss: 0.619, valid_acc: 0.739\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.24it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.93it/s]\n",
"epoch: 3\n",
"train_loss: 0.586, train_acc: 0.761\n",
"valid_loss: 0.554, valid_acc: 0.777\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.52it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.60it/s]\n",
"epoch: 4\n",
"train_loss: 0.514, train_acc: 0.810\n",
"valid_loss: 0.487, valid_acc: 0.819\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.56it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.70it/s]\n",
"epoch: 5\n",
"train_loss: 0.445, train_acc: 0.846\n",
"valid_loss: 0.433, valid_acc: 0.843\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.24it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.74it/s]\n",
"epoch: 6\n",
"train_loss: 0.389, train_acc: 0.870\n",
"valid_loss: 0.390, valid_acc: 0.857\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.61it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.95it/s]\n",
"epoch: 7\n",
"train_loss: 0.346, train_acc: 0.884\n",
"valid_loss: 0.361, valid_acc: 0.863\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.37it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.98it/s]\n",
"epoch: 8\n",
"train_loss: 0.312, train_acc: 0.896\n",
"valid_loss: 0.340, valid_acc: 0.870\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.57it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.57it/s]\n",
"epoch: 9\n",
"train_loss: 0.286, train_acc: 0.907\n",
"valid_loss: 0.325, valid_acc: 0.875\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.31it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.36it/s]\n",
"epoch: 10\n",
"train_loss: 0.262, train_acc: 0.915\n",
"valid_loss: 0.315, valid_acc: 0.877\n"
]
}
],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float('inf')\n",
"\n",
"train_losses = []\n",
"train_accs = []\n",
"valid_losses = []\n",
"valid_accs = []\n",
"\n",
"for epoch in range(n_epochs):\n",
"\n",
"    train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)\n",
"    valid_loss, valid_acc = evaluate(valid_dataloader, model, criterion, device)\n",
"\n",
"    train_losses.extend(train_loss)\n",
"    train_accs.extend(train_acc)\n",
"    valid_losses.extend(valid_loss)\n",
"    valid_accs.extend(valid_acc)\n",
"\n",
"    epoch_train_loss = np.mean(train_loss)\n",
"    epoch_train_acc = np.mean(train_acc)\n",
"    epoch_valid_loss = np.mean(valid_loss)\n",
"    epoch_valid_acc = np.mean(valid_acc)\n",
"\n",
"    if epoch_valid_loss < best_valid_loss:\n",
"        best_valid_loss = epoch_valid_loss\n",
"        torch.save(model.state_dict(), 'nbow.pt')\n",
"\n",
"    print(f'epoch: {epoch+1}')\n",
"    print(f'train_loss: {epoch_train_loss:.3f}, train_acc: {epoch_train_acc:.3f}')\n",
"    print(f'valid_loss: {epoch_valid_loss:.3f}, valid_acc: {epoch_valid_acc:.3f}')"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "2d791c70",
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10,6))\n",
"ax = fig.add_subplot(1,1,1)\n",
"ax.plot(train_losses, label='train loss')\n",
"ax.plot(valid_losses, label='valid loss')\n",
"plt.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('loss');"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "bc422190",
"metadata": {},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(10,6))\n",
"ax = fig.add_subplot(1,1,1)\n",
"ax.plot(train_accs, label='train accuracy')\n",
"ax.plot(valid_accs, label='valid accuracy')\n",
"ax.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('accuracy');"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "cac26e8e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"evaluating...: 100%|██████████| 49/49 [00:03<00:00, 15.72it/s]\n",
"test_loss: 0.359, test_acc: 0.854\n"
]
}
],
"source": [
"model.load_state_dict(torch.load('nbow.pt'))\n",
"\n",
"test_loss, test_acc = evaluate(test_dataloader, model, criterion, device)\n",
"\n",
"epoch_test_loss = np.mean(test_loss)\n",
"epoch_test_acc = np.mean(test_acc)\n",
"\n",
"print(f'test_loss: {epoch_test_loss:.3f}, test_acc: {epoch_test_acc:.3f}')"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "b22e040a",
"metadata": {},
"outputs": [],
"source": [
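"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
"    # tokenize the raw string and map each token to its vocabulary index\n",
"    tokens = tokenizer(text)\n",
"    ids = [vocab[t] for t in tokens]\n",
"    # add a batch dimension: [seq len] -> [1, seq len]\n",
"    tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
"    # inference only: switch to eval mode and skip gradient tracking\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        # remove the batch dimension: [1, output dim] -> [output dim]\n",
"        prediction = model(tensor).squeeze(dim=0)\n",
"    probability = torch.softmax(prediction, dim=-1)\n",
"    predicted_class = prediction.argmax(dim=-1).item()\n",
"    predicted_probability = probability[predicted_class].item()\n",
"    return predicted_class, predicted_probability"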
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "9cfa14eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0.9999850988388062)"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "1da60d90",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.9999994039535522)"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "4bee6190",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.6572516560554504)"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is not terrible, it's great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "e3d55c92",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.6572516560554504)"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is not great, it's terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}