Merge branch 'master' of https://github.com/bentrevett/pytorch-sentiment-analysis
This commit is contained in:
commit
d925186eb4
@ -37,7 +37,7 @@
|
||||
"\n",
|
||||
"We use the `TEXT` field to define how the review should be processed, and the `LABEL` field to process the sentiment. \n",
|
||||
"\n",
|
||||
"Our `TEXT` field has `tokenize='spacy'` as an argument. This defines that the \"tokenization\" (the act of splitting the string into discrete \"tokens\") should be done using the [spaCy](https://spacy.io) tokenizer. If no `tokenize` argument is passed, the default is simply splitting the string on spaces.\n",
|
||||
"Our `TEXT` field has `tokenize='spacy'` as an argument. This defines that the \"tokenization\" (the act of splitting the string into discrete \"tokens\") should be done using the [spaCy](https://spacy.io) tokenizer. If no `tokenize` argument is passed, the default is simply splitting the string on spaces. We also need to specify a `tokenizer_language` which tells torchtext which spaCy model to use. We use the `en_core_web_sm` model which has to be downloaded with `python -m spacy download en_core_web_sm` before you run this notebook!\n",
|
||||
"\n",
|
||||
"`LABEL` is defined by a `LabelField`, a special subset of the `Field` class specifically used for handling labels. We will explain the `dtype` argument later.\n",
|
||||
"\n",
|
||||
@ -60,7 +60,8 @@
|
||||
"torch.manual_seed(SEED)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n",
|
||||
"\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy')\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy',\n",
|
||||
" tokenizer_language = 'en_core_web_sm')\n",
|
||||
"LABEL = data.LabelField(dtype = torch.float)"
|
||||
]
|
||||
},
|
||||
@ -817,7 +818,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -35,7 +35,18 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torchtext import data\n",
|
||||
@ -46,7 +57,10 @@
|
||||
"torch.manual_seed(SEED)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n",
|
||||
"\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy', include_lengths = True)\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy',\n",
|
||||
" tokenizer_language = 'en_core_web_sm',\n",
|
||||
" include_lengths = True)\n",
|
||||
"\n",
|
||||
"LABEL = data.LabelField(dtype = torch.float)"
|
||||
]
|
||||
},
|
||||
@ -61,7 +75,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torchtext import datasets\n",
|
||||
"\n",
|
||||
@ -133,7 +156,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"BATCH_SIZE = 64\n",
|
||||
"\n",
|
||||
@ -204,7 +236,7 @@
|
||||
"\n",
|
||||
"As we are passing the lengths of our sentences to be able to use packed padded sequences, we have to add a second argument, `text_lengths`, to `forward`. \n",
|
||||
"\n",
|
||||
"Before we pass our embeddings to the RNN, we need to pack them, which we do with `nn.utils.rnn.packed_padded_sequence`. This will cause our RNN to only process the non-padded elements of our sequence. The RNN will then return `packed_output` (a packed sequence) as well as the `hidden` and `cell` states (both of which are tensors). Without packed padded sequences, `hidden` and `cell` are tensors from the last element in the sequence, which will most probably be a pad token, however when using packed padded sequences they are both from the last non-padded element in the sequence. \n",
|
||||
"Before we pass our embeddings to the RNN, we need to pack them, which we do with `nn.utils.rnn.packed_padded_sequence`. This will cause our RNN to only process the non-padded elements of our sequence. The RNN will then return `packed_output` (a packed sequence) as well as the `hidden` and `cell` states (both of which are tensors). Without packed padded sequences, `hidden` and `cell` are tensors from the last element in the sequence, which will most probably be a pad token, however when using packed padded sequences they are both from the last non-padded element in the sequence. Note that the `lengths` argument of `packed_padded_sequence` must be a CPU tensor so we explicitly make it one by using `.to('cpu')`.\n",
|
||||
"\n",
|
||||
"We then unpack the output sequence, with `nn.utils.rnn.pad_packed_sequence`, to transform it from a packed sequence to a tensor. The elements of `output` from padding tokens will be zero tensors (tensors where every element is zero). Usually, we only have to unpack output if we are going to use it later on in the model. Although we aren't in this case, we still unpack the sequence just to show how it is done.\n",
|
||||
"\n",
|
||||
@ -246,7 +278,8 @@
|
||||
" #embedded = [sent len, batch size, emb dim]\n",
|
||||
" \n",
|
||||
" #pack sequence\n",
|
||||
" packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths)\n",
|
||||
" # lengths need to be on CPU!\n",
|
||||
" packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))\n",
|
||||
" \n",
|
||||
" packed_output, (hidden, cell) = self.rnn(packed_embedded)\n",
|
||||
" \n",
|
||||
@ -383,9 +416,9 @@
|
||||
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n",
|
||||
" [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n",
|
||||
" [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])"
|
||||
" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
|
||||
" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
|
||||
" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
@ -421,9 +454,9 @@
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n",
|
||||
" [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n",
|
||||
" [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])\n"
|
||||
" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
|
||||
" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
|
||||
" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -638,25 +671,33 @@
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 28s\n",
|
||||
"\tTrain Loss: 0.648 | Train Acc: 62.05%\n",
|
||||
"\t Val. Loss: 0.620 | Val. Acc: 66.72%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 27s\n",
|
||||
"\tTrain Loss: 0.622 | Train Acc: 66.51%\n",
|
||||
"\t Val. Loss: 0.669 | Val. Acc: 62.83%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 27s\n",
|
||||
"\tTrain Loss: 0.586 | Train Acc: 69.01%\n",
|
||||
"\t Val. Loss: 0.522 | Val. Acc: 75.52%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 27s\n",
|
||||
"\tTrain Loss: 0.415 | Train Acc: 82.02%\n",
|
||||
"\t Val. Loss: 0.457 | Val. Acc: 77.10%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 27s\n",
|
||||
"\tTrain Loss: 0.335 | Train Acc: 86.15%\n",
|
||||
"\t Val. Loss: 0.305 | Val. Acc: 87.15%\n"
|
||||
"Epoch: 01 | Epoch Time: 0m 36s\n",
|
||||
"\tTrain Loss: 0.673 | Train Acc: 58.05%\n",
|
||||
"\t Val. Loss: 0.619 | Val. Acc: 64.97%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 36s\n",
|
||||
"\tTrain Loss: 0.611 | Train Acc: 66.33%\n",
|
||||
"\t Val. Loss: 0.510 | Val. Acc: 74.32%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 37s\n",
|
||||
"\tTrain Loss: 0.484 | Train Acc: 77.04%\n",
|
||||
"\t Val. Loss: 0.397 | Val. Acc: 82.95%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 37s\n",
|
||||
"\tTrain Loss: 0.384 | Train Acc: 83.57%\n",
|
||||
"\t Val. Loss: 0.407 | Val. Acc: 83.23%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 37s\n",
|
||||
"\tTrain Loss: 0.314 | Train Acc: 86.98%\n",
|
||||
"\t Val. Loss: 0.314 | Val. Acc: 86.36%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -701,7 +742,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.308 | Test Acc: 87.07%\n"
|
||||
"Test Loss: 0.334 | Test Acc: 85.28%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -744,7 +785,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import spacy\n",
|
||||
"nlp = spacy.load('en')\n",
|
||||
"nlp = spacy.load('en_core_web_sm')\n",
|
||||
"\n",
|
||||
"def predict_sentiment(model, sentence):\n",
|
||||
" model.eval()\n",
|
||||
@ -773,7 +814,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.005683214403688908"
|
||||
"0.05380420759320259"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
@ -800,7 +841,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9926869869232178"
|
||||
"0.94941645860672"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
@ -838,7 +879,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -50,7 +50,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['This', 'film', 'is', 'terrible', 'This film', 'film is', 'is terrible']"
|
||||
"['This', 'film', 'is', 'terrible', 'film is', 'This film', 'is terrible']"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
@ -75,7 +75,18 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torchtext import data\n",
|
||||
@ -86,7 +97,10 @@
|
||||
"torch.manual_seed(SEED)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n",
|
||||
"\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy', preprocessing = generate_bigrams)\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy',\n",
|
||||
" tokenizer_language = 'en_core_web_sm',\n",
|
||||
" preprocessing = generate_bigrams)\n",
|
||||
"\n",
|
||||
"LABEL = data.LabelField(dtype = torch.float)"
|
||||
]
|
||||
},
|
||||
@ -101,7 +115,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"\n",
|
||||
@ -144,7 +167,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"BATCH_SIZE = 64\n",
|
||||
"\n",
|
||||
@ -287,9 +319,9 @@
|
||||
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [ 0.3199, 0.0746, 0.0231, ..., -0.3609, 1.1303, 0.5668],\n",
|
||||
" [-1.0530, -1.0757, 0.3903, ..., 0.0792, -0.3059, 1.9734],\n",
|
||||
" [-0.1734, -0.3195, 0.3694, ..., -0.2435, 0.4767, 0.1151]])"
|
||||
" [-0.1606, -0.7357, 0.5809, ..., 0.8704, -1.5637, -1.5724],\n",
|
||||
" [-1.3126, -1.6717, 0.4203, ..., 0.2348, -0.9110, 1.0914],\n",
|
||||
" [-1.5268, 1.5639, -1.0541, ..., 1.0045, -0.6813, -0.8846]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
@ -507,25 +539,33 @@
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 6s\n",
|
||||
"\tTrain Loss: 0.688 | Train Acc: 57.23%\n",
|
||||
"\t Val. Loss: 0.642 | Val. Acc: 71.23%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 5s\n",
|
||||
"\tTrain Loss: 0.653 | Train Acc: 71.09%\n",
|
||||
"\t Val. Loss: 0.521 | Val. Acc: 75.28%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 5s\n",
|
||||
"\tTrain Loss: 0.582 | Train Acc: 78.88%\n",
|
||||
"\t Val. Loss: 0.449 | Val. Acc: 79.64%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 5s\n",
|
||||
"\tTrain Loss: 0.505 | Train Acc: 83.15%\n",
|
||||
"\t Val. Loss: 0.426 | Val. Acc: 82.12%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 5s\n",
|
||||
"\tTrain Loss: 0.439 | Train Acc: 85.99%\n",
|
||||
"\t Val. Loss: 0.397 | Val. Acc: 85.02%\n"
|
||||
"Epoch: 01 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.688 | Train Acc: 61.31%\n",
|
||||
"\t Val. Loss: 0.637 | Val. Acc: 72.46%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 6s\n",
|
||||
"\tTrain Loss: 0.651 | Train Acc: 75.04%\n",
|
||||
"\t Val. Loss: 0.507 | Val. Acc: 76.92%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 6s\n",
|
||||
"\tTrain Loss: 0.578 | Train Acc: 79.91%\n",
|
||||
"\t Val. Loss: 0.424 | Val. Acc: 80.97%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 6s\n",
|
||||
"\tTrain Loss: 0.501 | Train Acc: 83.97%\n",
|
||||
"\t Val. Loss: 0.377 | Val. Acc: 84.34%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 6s\n",
|
||||
"\tTrain Loss: 0.435 | Train Acc: 86.96%\n",
|
||||
"\t Val. Loss: 0.363 | Val. Acc: 86.18%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -572,7 +612,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.391 | Test Acc: 85.11%\n"
|
||||
"Test Loss: 0.381 | Test Acc: 85.42%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -600,7 +640,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import spacy\n",
|
||||
"nlp = spacy.load('en')\n",
|
||||
"nlp = spacy.load('en_core_web_sm')\n",
|
||||
"\n",
|
||||
"def predict_sentiment(model, sentence):\n",
|
||||
" model.eval()\n",
|
||||
@ -627,7 +667,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1.621993561684576e-07"
|
||||
"2.1313092350011553e-12"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
@ -692,7 +732,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -36,7 +36,20 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torchtext import data\n",
|
||||
@ -51,7 +64,9 @@
|
||||
"torch.manual_seed(SEED)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n",
|
||||
"\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy', batch_first = True)\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy', \n",
|
||||
" tokenizer_language = 'en_core_web_sm',\n",
|
||||
" batch_first = True)\n",
|
||||
"LABEL = data.LabelField(dtype = torch.float)\n",
|
||||
"\n",
|
||||
"train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
|
||||
@ -93,7 +108,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"BATCH_SIZE = 64\n",
|
||||
"\n",
|
||||
@ -416,9 +440,9 @@
|
||||
" [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n",
|
||||
" [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n",
|
||||
" [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])"
|
||||
" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
|
||||
" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
|
||||
" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
@ -622,25 +646,33 @@
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 13s\n",
|
||||
"\tTrain Loss: 0.645 | Train Acc: 62.08%\n",
|
||||
"\t Val. Loss: 0.488 | Val. Acc: 78.64%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 11s\n",
|
||||
"\tTrain Loss: 0.418 | Train Acc: 81.14%\n",
|
||||
"\t Val. Loss: 0.361 | Val. Acc: 84.59%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 11s\n",
|
||||
"\tTrain Loss: 0.300 | Train Acc: 87.33%\n",
|
||||
"\t Val. Loss: 0.348 | Val. Acc: 85.06%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 11s\n",
|
||||
"\tTrain Loss: 0.217 | Train Acc: 91.49%\n",
|
||||
"\t Val. Loss: 0.320 | Val. Acc: 86.71%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 11s\n",
|
||||
"\tTrain Loss: 0.156 | Train Acc: 94.22%\n",
|
||||
"\t Val. Loss: 0.334 | Val. Acc: 87.06%\n"
|
||||
"\tTrain Loss: 0.649 | Train Acc: 61.79%\n",
|
||||
"\t Val. Loss: 0.507 | Val. Acc: 78.93%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 13s\n",
|
||||
"\tTrain Loss: 0.433 | Train Acc: 79.86%\n",
|
||||
"\t Val. Loss: 0.357 | Val. Acc: 84.57%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 13s\n",
|
||||
"\tTrain Loss: 0.305 | Train Acc: 87.36%\n",
|
||||
"\t Val. Loss: 0.312 | Val. Acc: 86.76%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 13s\n",
|
||||
"\tTrain Loss: 0.224 | Train Acc: 91.20%\n",
|
||||
"\t Val. Loss: 0.303 | Val. Acc: 87.16%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 14s\n",
|
||||
"\tTrain Loss: 0.159 | Train Acc: 94.16%\n",
|
||||
"\t Val. Loss: 0.317 | Val. Acc: 87.37%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -685,7 +717,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.339 | Test Acc: 85.39%\n"
|
||||
"Test Loss: 0.343 | Test Acc: 85.31%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -710,12 +742,12 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import spacy\n",
|
||||
"nlp = spacy.load('en')\n",
|
||||
"nlp = spacy.load('en_core_web_sm')\n",
|
||||
"\n",
|
||||
"def predict_sentiment(model, sentence, min_len = 5):\n",
|
||||
" model.eval()\n",
|
||||
@ -738,16 +770,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.11022213101387024"
|
||||
"0.09913548082113266"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -765,16 +797,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9785954356193542"
|
||||
"0.9769725799560547"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -800,7 +832,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -21,7 +21,20 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from torchtext import data\n",
|
||||
@ -33,7 +46,9 @@
|
||||
"torch.manual_seed(SEED)\n",
|
||||
"torch.backends.cudnn.deterministic = True\n",
|
||||
"\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy')\n",
|
||||
"TEXT = data.Field(tokenize = 'spacy',\n",
|
||||
" tokenizer_language = 'en_core_web_sm')\n",
|
||||
"\n",
|
||||
"LABEL = data.LabelField()\n",
|
||||
"\n",
|
||||
"train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)\n",
|
||||
@ -115,7 +130,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"defaultdict(<function _default_unk_index at 0x7f0a50190d08>, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n"
|
||||
"defaultdict(None, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -134,7 +149,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"BATCH_SIZE = 64\n",
|
||||
"\n",
|
||||
@ -254,7 +278,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The model has 842,406 trainable parameters\n"
|
||||
"The model has 841,806 trainable parameters\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -286,7 +310,7 @@
|
||||
" ...,\n",
|
||||
" [-0.3110, -0.3398, 1.0308, ..., 0.5317, 0.2836, -0.0640],\n",
|
||||
" [ 0.0091, 0.2810, 0.7356, ..., -0.7508, 0.8967, -0.7631],\n",
|
||||
" [ 0.4306, 1.2011, 0.0873, ..., 0.8817, 0.3722, 0.3458]])"
|
||||
" [ 0.5831, -0.2514, 0.4156, ..., -0.2735, -0.8659, -1.4063]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
@ -367,9 +391,10 @@
|
||||
" \"\"\"\n",
|
||||
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
|
||||
" \"\"\"\n",
|
||||
" max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability\n",
|
||||
" correct = max_preds.squeeze(1).eq(y)\n",
|
||||
" return correct.sum() / torch.FloatTensor([y.shape[0]])"
|
||||
" top_pred = preds.argmax(1, keepdim = True)\n",
|
||||
" correct = top_pred.eq(y.view_as(top_pred)).sum()\n",
|
||||
" acc = correct.float() / y.shape[0]\n",
|
||||
" return acc"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -479,25 +504,33 @@
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 0s\n",
|
||||
"\tTrain Loss: 1.310 | Train Acc: 47.99%\n",
|
||||
"\t Val. Loss: 0.947 | Val. Acc: 66.81%\n",
|
||||
"\tTrain Loss: 1.312 | Train Acc: 47.11%\n",
|
||||
"\t Val. Loss: 0.947 | Val. Acc: 66.41%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 0s\n",
|
||||
"\tTrain Loss: 0.869 | Train Acc: 69.09%\n",
|
||||
"\t Val. Loss: 0.746 | Val. Acc: 74.18%\n",
|
||||
"\tTrain Loss: 0.870 | Train Acc: 69.18%\n",
|
||||
"\t Val. Loss: 0.741 | Val. Acc: 74.14%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 0s\n",
|
||||
"\tTrain Loss: 0.665 | Train Acc: 76.94%\n",
|
||||
"\t Val. Loss: 0.627 | Val. Acc: 78.03%\n",
|
||||
"\tTrain Loss: 0.675 | Train Acc: 76.32%\n",
|
||||
"\t Val. Loss: 0.621 | Val. Acc: 78.49%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 0s\n",
|
||||
"\tTrain Loss: 0.503 | Train Acc: 83.42%\n",
|
||||
"\t Val. Loss: 0.548 | Val. Acc: 79.73%\n",
|
||||
"\tTrain Loss: 0.506 | Train Acc: 83.97%\n",
|
||||
"\t Val. Loss: 0.547 | Val. Acc: 80.32%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 0s\n",
|
||||
"\tTrain Loss: 0.376 | Train Acc: 87.88%\n",
|
||||
"\t Val. Loss: 0.506 | Val. Acc: 81.40%\n"
|
||||
"\tTrain Loss: 0.373 | Train Acc: 88.23%\n",
|
||||
"\t Val. Loss: 0.487 | Val. Acc: 82.92%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -542,7 +575,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.411 | Test Acc: 87.15%\n"
|
||||
"Test Loss: 0.415 | Test Acc: 86.07%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -570,7 +603,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import spacy\n",
|
||||
"nlp = spacy.load('en')\n",
|
||||
"nlp = spacy.load('en_core_web_sm')\n",
|
||||
"\n",
|
||||
"def predict_class(model, sentence, min_len = 4):\n",
|
||||
" model.eval()\n",
|
||||
@ -681,7 +714,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -52,16 +52,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I1106 14:55:11.110527 139759243081536 file_utils.py:39] PyTorch version 1.3.0 available.\n",
|
||||
"I1106 14:55:11.917650 139759243081536 tokenization_utils.py:374] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/ben/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import BertTokenizer\n",
|
||||
"\n",
|
||||
@ -294,7 +285,18 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torchtext import data\n",
|
||||
"\n",
|
||||
@ -321,7 +323,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torchtext import datasets\n",
|
||||
"\n",
|
||||
@ -367,7 +378,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{'text': [5949, 1997, 2026, 2166, 1010, 1012, 1012, 1012, 1012, 1996, 2472, 2323, 2022, 10339, 1012, 2339, 2111, 2514, 2027, 2342, 2000, 2191, 22692, 5691, 2097, 2196, 2191, 3168, 2000, 2033, 1012, 2043, 2016, 2351, 2012, 1996, 2203, 1010, 2009, 2081, 2033, 4756, 1012, 1045, 2018, 2000, 2689, 1996, 3149, 2116, 2335, 2802, 1996, 2143, 2138, 1045, 2001, 2893, 10339, 3666, 2107, 3532, 3772, 1012, 11504, 1996, 3124, 2040, 2209, 9895, 2196, 4152, 2147, 2153, 1012, 2006, 2327, 1997, 2008, 1045, 3246, 1996, 2472, 2196, 4152, 2000, 2191, 2178, 2143, 1010, 1998, 2038, 2010, 3477, 5403, 3600, 2579, 2067, 2005, 2023, 10231, 1012, 1063, 1012, 6185, 2041, 1997, 2184, 1065], 'label': 'neg'}\n"
|
||||
"{'text': [1042, 4140, 1996, 2087, 2112, 1010, 2023, 3185, 5683, 2066, 1037, 1000, 2081, 1011, 2005, 1011, 2694, 1000, 3947, 1012, 1996, 3257, 2003, 10654, 1011, 28273, 1010, 1996, 3772, 1006, 2007, 1996, 6453, 1997, 5965, 1043, 11761, 2638, 1007, 2003, 2058, 13088, 10593, 2102, 1998, 7815, 2100, 1012, 15339, 14282, 1010, 3391, 1010, 18058, 2014, 3210, 2066, 2016, 1005, 1055, 3147, 3752, 2068, 2125, 1037, 16091, 4003, 1012, 2069, 2028, 2518, 3084, 2023, 2143, 4276, 3666, 1010, 1998, 2008, 2003, 2320, 10012, 3310, 2067, 2013, 1996, 1000, 7367, 11368, 5649, 1012, 1000, 2045, 2003, 2242, 14888, 2055, 3666, 1037, 2235, 2775, 4028, 2619, 1010, 1998, 2023, 3185, 2453, 2022, 2062, 2084, 2070, 2064, 5047, 2074, 2005, 2008, 3114, 1012, 2009, 2003, 7078, 5923, 1011, 27017, 1012, 2023, 2143, 2069, 2515, 2028, 2518, 2157, 1010, 2021, 2009, 21145, 2008, 2028, 2518, 2157, 2041, 1997, 1996, 2380, 1012, 4276, 3773, 2074, 2005, 1996, 2197, 2184, 2781, 2030, 2061, 1012], 'label': 'neg'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -391,7 +402,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['waste', 'of', 'my', 'life', ',', '.', '.', '.', '.', 'the', 'director', 'should', 'be', 'embarrassed', '.', 'why', 'people', 'feel', 'they', 'need', 'to', 'make', 'worthless', 'movies', 'will', 'never', 'make', 'sense', 'to', 'me', '.', 'when', 'she', 'died', 'at', 'the', 'end', ',', 'it', 'made', 'me', 'laugh', '.', 'i', 'had', 'to', 'change', 'the', 'channel', 'many', 'times', 'throughout', 'the', 'film', 'because', 'i', 'was', 'getting', 'embarrassed', 'watching', 'such', 'poor', 'acting', '.', 'hopefully', 'the', 'guy', 'who', 'played', 'heath', 'never', 'gets', 'work', 'again', '.', 'on', 'top', 'of', 'that', 'i', 'hope', 'the', 'director', 'never', 'gets', 'to', 'make', 'another', 'film', ',', 'and', 'has', 'his', 'pay', '##che', '##ck', 'taken', 'back', 'for', 'this', 'crap', '.', '{', '.', '02', 'out', 'of', '10', '}']\n"
|
||||
"['f', '##ot', 'the', 'most', 'part', ',', 'this', 'movie', 'feels', 'like', 'a', '\"', 'made', '-', 'for', '-', 'tv', '\"', 'effort', '.', 'the', 'direction', 'is', 'ham', '-', 'fisted', ',', 'the', 'acting', '(', 'with', 'the', 'exception', 'of', 'fred', 'g', '##wyn', '##ne', ')', 'is', 'over', '##wr', '##ough', '##t', 'and', 'soap', '##y', '.', 'denise', 'crosby', ',', 'particularly', ',', 'delivers', 'her', 'lines', 'like', 'she', \"'\", 's', 'cold', 'reading', 'them', 'off', 'a', 'cue', 'card', '.', 'only', 'one', 'thing', 'makes', 'this', 'film', 'worth', 'watching', ',', 'and', 'that', 'is', 'once', 'gage', 'comes', 'back', 'from', 'the', '\"', 'se', '##met', '##ary', '.', '\"', 'there', 'is', 'something', 'disturbing', 'about', 'watching', 'a', 'small', 'child', 'murder', 'someone', ',', 'and', 'this', 'movie', 'might', 'be', 'more', 'than', 'some', 'can', 'handle', 'just', 'for', 'that', 'reason', '.', 'it', 'is', 'absolutely', 'bone', '-', 'chilling', '.', 'this', 'film', 'only', 'does', 'one', 'thing', 'right', ',', 'but', 'it', 'knocks', 'that', 'one', 'thing', 'right', 'out', 'of', 'the', 'park', '.', 'worth', 'seeing', 'just', 'for', 'the', 'last', '10', 'minutes', 'or', 'so', '.']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -445,7 +456,16 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"BATCH_SIZE = 128\n",
|
||||
"\n",
|
||||
@ -470,39 +490,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"I1106 14:57:06.877642 139759243081536 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/ben/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
|
||||
"I1106 14:57:06.878792 139759243081536 configuration_utils.py:168] Model config {\n",
|
||||
" \"attention_probs_dropout_prob\": 0.1,\n",
|
||||
" \"finetuning_task\": null,\n",
|
||||
" \"hidden_act\": \"gelu\",\n",
|
||||
" \"hidden_dropout_prob\": 0.1,\n",
|
||||
" \"hidden_size\": 768,\n",
|
||||
" \"initializer_range\": 0.02,\n",
|
||||
" \"intermediate_size\": 3072,\n",
|
||||
" \"layer_norm_eps\": 1e-12,\n",
|
||||
" \"max_position_embeddings\": 512,\n",
|
||||
" \"num_attention_heads\": 12,\n",
|
||||
" \"num_hidden_layers\": 12,\n",
|
||||
" \"num_labels\": 2,\n",
|
||||
" \"output_attentions\": false,\n",
|
||||
" \"output_hidden_states\": false,\n",
|
||||
" \"output_past\": true,\n",
|
||||
" \"pruned_heads\": {},\n",
|
||||
" \"torchscript\": false,\n",
|
||||
" \"type_vocab_size\": 2,\n",
|
||||
" \"use_bfloat16\": false,\n",
|
||||
" \"vocab_size\": 30522\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"I1106 14:57:07.421291 139759243081536 modeling_utils.py:337] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin from cache at /home/ben/.cache/torch/transformers/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import BertTokenizer, BertModel\n",
|
||||
"\n",
|
||||
@ -880,28 +868,36 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
||||
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 7m 27s\n",
|
||||
"\tTrain Loss: 0.286 | Train Acc: 88.16%\n",
|
||||
"\t Val. Loss: 0.247 | Val. Acc: 90.26%\n",
|
||||
"Epoch: 02 | Epoch Time: 7m 27s\n",
|
||||
"\tTrain Loss: 0.234 | Train Acc: 90.77%\n",
|
||||
"\t Val. Loss: 0.229 | Val. Acc: 91.00%\n",
|
||||
"Epoch: 03 | Epoch Time: 7m 27s\n",
|
||||
"\tTrain Loss: 0.209 | Train Acc: 91.83%\n",
|
||||
"\t Val. Loss: 0.225 | Val. Acc: 91.10%\n",
|
||||
"Epoch: 04 | Epoch Time: 7m 27s\n",
|
||||
"\tTrain Loss: 0.182 | Train Acc: 92.97%\n",
|
||||
"\t Val. Loss: 0.217 | Val. Acc: 91.98%\n",
|
||||
"Epoch: 05 | Epoch Time: 7m 27s\n",
|
||||
"\tTrain Loss: 0.156 | Train Acc: 94.17%\n",
|
||||
"\t Val. Loss: 0.230 | Val. Acc: 91.76%\n"
|
||||
"Epoch: 01 | Epoch Time: 7m 13s\n",
|
||||
"\tTrain Loss: 0.502 | Train Acc: 74.41%\n",
|
||||
"\t Val. Loss: 0.270 | Val. Acc: 89.15%\n",
|
||||
"Epoch: 02 | Epoch Time: 7m 7s\n",
|
||||
"\tTrain Loss: 0.281 | Train Acc: 88.49%\n",
|
||||
"\t Val. Loss: 0.224 | Val. Acc: 91.32%\n",
|
||||
"Epoch: 03 | Epoch Time: 7m 17s\n",
|
||||
"\tTrain Loss: 0.239 | Train Acc: 90.67%\n",
|
||||
"\t Val. Loss: 0.211 | Val. Acc: 91.91%\n",
|
||||
"Epoch: 04 | Epoch Time: 7m 14s\n",
|
||||
"\tTrain Loss: 0.206 | Train Acc: 91.81%\n",
|
||||
"\t Val. Loss: 0.206 | Val. Acc: 92.01%\n",
|
||||
"Epoch: 05 | Epoch Time: 7m 15s\n",
|
||||
"\tTrain Loss: 0.188 | Train Acc: 92.63%\n",
|
||||
"\t Val. Loss: 0.211 | Val. Acc: 91.92%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -939,14 +935,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.198 | Test Acc: 92.31%\n"
|
||||
"Test Loss: 0.209 | Test Acc: 91.58%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -969,7 +965,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -986,16 +982,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.02264496125280857"
|
||||
"0.03391794115304947"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -1006,16 +1002,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9411056041717529"
|
||||
"0.8869886994361877"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -1041,7 +1037,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.5"
|
||||
"version": "3.8.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -6,7 +6,7 @@ This repo contains tutorials covering how to do sentiment analysis using [PyTorc
|
||||
|
||||
The first 2 tutorials will cover getting started with the de facto approach to sentiment analysis: recurrent neural networks (RNNs). The third notebook covers the [FastText](https://arxiv.org/abs/1607.01759) model and the final covers a [convolutional neural network](https://arxiv.org/abs/1408.5882) (CNN) model.
|
||||
|
||||
There are also 2 bonus "appendix" notebooks. The first covers loading your own datasets with TorchText, while the second contains a brief look at the pre-trained word embeddings provided by TorchText.
|
||||
There are also 2 bonus "appendix" notebooks. The first covers loading your own datasets with torchtext, while the second contains a brief look at the pre-trained word embeddings provided by torchtext.
|
||||
|
||||
**If you find any mistakes or disagree with any of the explanations, please do not hesitate to [submit an issue](https://github.com/bentrevett/pytorch-sentiment-analysis/issues/new). I welcome any feedback, positive or negative!**
|
||||
|
||||
@ -14,7 +14,7 @@ There are also 2 bonus "appendix" notebooks. The first covers loading your own d
|
||||
|
||||
To install PyTorch, see installation instructions on the [PyTorch website](https://pytorch.org/get-started/locally).
|
||||
|
||||
To install TorchText:
|
||||
To install torchtext:
|
||||
|
||||
``` bash
|
||||
pip install torchtext
|
||||
@ -23,7 +23,7 @@ pip install torchtext
|
||||
We'll also make use of spaCy to tokenize our data. To install spaCy, follow the instructions [here](https://spacy.io/usage/) making sure to install the English models with:
|
||||
|
||||
``` bash
|
||||
python -m spacy download en
|
||||
python -m spacy download en_core_web_sm
|
||||
```
|
||||
|
||||
For tutorial 6, we'll use the transformers library, which can be installed via:
|
||||
@ -38,7 +38,7 @@ These tutorials were created using version 4.3 of the transformers library.
|
||||
|
||||
* 1 - [Simple Sentiment Analysis](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/1%20-%20Simple%20Sentiment%20Analysis.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-sentiment-analysis/blob/master/1%20-%20Simple%20Sentiment%20Analysis.ipynb)
|
||||
|
||||
This tutorial covers the workflow of a PyTorch with TorchText project. We'll learn how to: load data, create train/test/validation splits, build a vocabulary, create data iterators, define a model and implement the train/evaluate/test loop. The model will be simple and achieve poor performance, but this will be improved in the subsequent tutorials.
|
||||
This tutorial covers the workflow of a PyTorch with torchtext project. We'll learn how to: load data, create train/test/validation splits, build a vocabulary, create data iterators, define a model and implement the train/evaluate/test loop. The model will be simple and achieve poor performance, but this will be improved in the subsequent tutorials.
|
||||
|
||||
* 2 - [Upgraded Sentiment Analysis](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-sentiment-analysis/blob/master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb)
|
||||
|
||||
|
@ -64,94 +64,59 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "y-yPGXY_dFmH",
|
||||
"outputId": "3b6e5a98-f073-4281-8ff7-0ec873c059ae"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"<torchtext.experimental.datasets.raw.text_classification.RawTextIterableDataset object at 0x7f56d84f0ac0>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(raw_train_data)"
|
||||
"raw_train_data = list(raw_train_data)\n",
|
||||
"raw_test_data = list(raw_test_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 55
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "UXWtJbsXdFmO",
|
||||
"outputId": "b6f80eb7-4b91-4188-a800-c328067f339e"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(0, 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered \"controversial\" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it\\'s not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn\\'t have much of a plot.')\n"
|
||||
]
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"('neg',\n",
|
||||
" 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered \"controversial\" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it\\'s not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.<br /><br />I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn\\'t have much of a plot.')"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"raw_train_data = list(raw_train_data)\n",
|
||||
"\n",
|
||||
"print(raw_train_data[0])"
|
||||
"raw_train_data[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 55
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "B2HoQ4VOdFmS",
|
||||
"outputId": "b63ed1a5-82b8-4bf9-c590-82a371e1ec84"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(0, 'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\\'t match the background, and painfully one-dimensional characters cannot be overcome with a \\'sci-fi\\' setting. (I\\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\\'s not. It\\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\\'s rubbish as they have to always say \"Gene Roddenberry\\'s Earth...\" otherwise people would not continue watching. Roddenberry\\'s ashes must be turning in their orbit as this dull, cheap, poorly edited (watching it without advert breaks really brings this home) trudging Trabant of a show lumbers into space. Spoiler. So, kill off a main character. And then bring him back as another actor. Jeeez! Dallas all over again.')\n"
|
||||
]
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"('neg',\n",
|
||||
" 'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\\'t match the background, and painfully one-dimensional characters cannot be overcome with a \\'sci-fi\\' setting. (I\\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\\'s not. It\\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\\'s rubbish as they have to always say \"Gene Roddenberry\\'s Earth...\" otherwise people would not continue watching. Roddenberry\\'s ashes must be turning in their orbit as this dull, cheap, poorly edited (watching it without advert breaks really brings this home) trudging Trabant of a show lumbers into space. Spoiler. So, kill off a main character. And then bring him back as another actor. Jeeez! Dallas all over again.')"
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"raw_test_data = list(raw_test_data)\n",
|
||||
"\n",
|
||||
"print(raw_test_data[0])"
|
||||
"raw_test_data[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 52
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "Tq8DjZTzdFmU",
|
||||
"outputId": "2aeab986-6922-4a28-ad2f-507282ceb60f"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -178,8 +143,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_train_valid_split(raw_train_data, split_ratio = 0.7):\n",
|
||||
"\n",
|
||||
" raw_train_data = list(raw_train_data)\n",
|
||||
" \n",
|
||||
" random.shuffle(raw_train_data)\n",
|
||||
" \n",
|
||||
@ -188,9 +151,6 @@
|
||||
" train_data = raw_train_data[:n_train_examples]\n",
|
||||
" valid_data = raw_train_data[n_train_examples:]\n",
|
||||
" \n",
|
||||
" train_data = RawTextIterableDataset(train_data)\n",
|
||||
" valid_data = RawTextIterableDataset(valid_data)\n",
|
||||
" \n",
|
||||
" return train_data, valid_data"
|
||||
]
|
||||
},
|
||||
@ -210,29 +170,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "iS9aLR8rdFmc"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"raw_train_data = list(raw_train_data)\n",
|
||||
"raw_valid_data = list(raw_valid_data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 69
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "yzEGgkz5dFmf",
|
||||
"outputId": "8e1cd5a4-76cd-492c-c387-65604dba13c1"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -252,7 +190,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -282,7 +220,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -290,14 +228,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_length = 500\n",
|
||||
"max_length = 250\n",
|
||||
"\n",
|
||||
"tokenizer = Tokenizer(max_length = max_length)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -324,7 +262,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -333,7 +271,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_vocab_from_data(raw_data, tokenizer, **vocab_kwargs):\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" token_freqs = collections.Counter()\n",
|
||||
" \n",
|
||||
" for label, text in raw_data:\n",
|
||||
@ -347,7 +285,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -360,6 +298,23 @@
|
||||
"vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size = max_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Unique tokens in vocab: 25,002\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(f'Unique tokens in vocab: {len(vocab):,}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
@ -376,26 +331,26 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[('the', 211890),\n",
|
||||
" ('.', 208427),\n",
|
||||
" (',', 173201),\n",
|
||||
" ('a', 103447),\n",
|
||||
" ('and', 103052),\n",
|
||||
" ('of', 91695),\n",
|
||||
" ('to', 84931),\n",
|
||||
" (\"'\", 83805),\n",
|
||||
" ('is', 67948),\n",
|
||||
" ('it', 61322),\n",
|
||||
" ('in', 58757),\n",
|
||||
" ('i', 57420),\n",
|
||||
" ('this', 49214),\n",
|
||||
" ('that', 46061),\n",
|
||||
" ('s', 38864),\n",
|
||||
" ('was', 31341),\n",
|
||||
" ('as', 29171),\n",
|
||||
" ('movie', 28776),\n",
|
||||
" ('for', 27903),\n",
|
||||
" ('with', 27680)]"
|
||||
"[('the', 165322),\n",
|
||||
" ('.', 164239),\n",
|
||||
" (',', 133647),\n",
|
||||
" ('a', 81952),\n",
|
||||
" ('and', 80334),\n",
|
||||
" ('of', 71820),\n",
|
||||
" ('to', 65662),\n",
|
||||
" (\"'\", 64249),\n",
|
||||
" ('is', 53598),\n",
|
||||
" ('it', 49589),\n",
|
||||
" ('i', 48810),\n",
|
||||
" ('in', 45611),\n",
|
||||
" ('this', 40868),\n",
|
||||
" ('that', 35609),\n",
|
||||
" ('s', 29273),\n",
|
||||
" ('was', 26159),\n",
|
||||
" ('movie', 24543),\n",
|
||||
" ('as', 22276),\n",
|
||||
" ('with', 21494),\n",
|
||||
" ('for', 21332)]"
|
||||
]
|
||||
},
|
||||
"execution_count": 17,
|
||||
@ -473,15 +428,14 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def process_raw_data(raw_data, tokenizer, vocab):\n",
|
||||
" \n",
|
||||
" raw_data = [(label, text) for (label, text) in raw_data]\n",
|
||||
"\n",
|
||||
"def raw_data_to_dataset(raw_data, tokenizer, vocab):\n",
|
||||
" \n",
|
||||
" text_transform = sequential_transforms(tokenizer.tokenize,\n",
|
||||
" vocab_func(vocab),\n",
|
||||
" totensor(dtype=torch.long))\n",
|
||||
" \n",
|
||||
" label_transform = sequential_transforms(totensor(dtype=torch.long))\n",
|
||||
" label_transform = sequential_transforms(lambda x: 1 if x == 'pos' else 0, \n",
|
||||
" totensor(dtype=torch.long))\n",
|
||||
"\n",
|
||||
" transforms = (label_transform, text_transform)\n",
|
||||
"\n",
|
||||
@ -502,12 +456,35 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)"
|
||||
"train_data = raw_data_to_dataset(raw_train_data, tokenizer, vocab)\n",
|
||||
"valid_data = raw_data_to_dataset(raw_valid_data, tokenizer, vocab)\n",
|
||||
"test_data = raw_data_to_dataset(raw_test_data, tokenizer, vocab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Number of training examples: 17,500\n",
|
||||
"Number of validation examples: 7,500\n",
|
||||
"Number of testing examples: 25,000\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(f'Number of training examples: {len(train_data):,}')\n",
|
||||
"print(f'Number of validation examples: {len(valid_data):,}')\n",
|
||||
"print(f'Number of testing examples: {len(test_data):,}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -522,41 +499,43 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([ 610, 8612, 10, 5, 60, 221, 6, 5, 60, 0,\n",
|
||||
" 172, 3, 32, 91, 19, 9, 330, 7, 1487, 9,\n",
|
||||
" 17, 2, 35, 13, 414, 128, 4, 6, 11, 293,\n",
|
||||
" 5, 210, 7, 116, 21, 111, 632, 1851, 7, 125,\n",
|
||||
" 1427, 18, 39, 84, 90, 2552, 1114, 4, 6, 44,\n",
|
||||
" 131, 2, 1851, 7, 2, 490, 3, 44, 10, 9,\n",
|
||||
" 6126, 9, 72, 43, 2, 428, 496, 8, 293, 2,\n",
|
||||
" 348, 7, 2, 490, 8, 2, 584, 831, 481, 3,\n",
|
||||
" 2, 4673, 111, 30, 119, 1783, 4, 35, 3138, 6,\n",
|
||||
" 35, 2349, 42, 854, 21, 5, 17557, 5565, 1586, 3,\n",
|
||||
" 2, 560, 7, 41, 174, 3138, 2038, 1996, 42, 1431,\n",
|
||||
" 8, 34, 41, 2004, 2289, 7, 2, 6375, 5693, 346,\n",
|
||||
" 0, 2, 3138, 2140, 12, 2, 348, 6, 536, 7,\n",
|
||||
" 2, 3138, 1271, 3, 14, 10, 77, 2, 1564, 7,\n",
|
||||
" 2, 23, 30, 2, 59, 592, 3, 2, 1852, 7,\n",
|
||||
" 6854, 10, 2482, 6, 265, 8, 264, 4, 2, 600,\n",
|
||||
" 7, 3138, 0, 10, 10481, 4, 6, 2, 105, 2,\n",
|
||||
" 1271, 851, 2341, 8, 2, 514, 1128, 5216, 10, 1136,\n",
|
||||
" 3, 2, 139, 7, 2, 23, 432, 10, 0, 1374,\n",
|
||||
" 4, 22, 8291, 43, 5, 440, 937, 1851, 3, 869,\n",
|
||||
" 10103, 6, 939, 4340, 203, 564, 349, 4, 22, 2,\n",
|
||||
" 1564, 7, 2, 69, 30, 101, 3643, 8, 34, 703,\n",
|
||||
" 13280, 3])\n"
|
||||
"tensor([ 12, 121, 1013, 6, 219, 1855, 8, 276, 70, 20,\n",
|
||||
" 5, 177, 3, 1013, 0, 30, 541, 0, 4, 15259,\n",
|
||||
" 6, 7022, 3, 12, 751, 8, 45, 14, 4, 12,\n",
|
||||
" 69, 123, 4, 22, 11, 10, 8, 56, 241, 1013,\n",
|
||||
" 19, 12534, 563, 10, 8, 338, 1803, 25, 2, 196,\n",
|
||||
" 24, 3, 717, 0, 4, 745, 3428, 686, 4, 4315,\n",
|
||||
" 3437, 4, 4258, 15, 170, 9, 28, 1209, 2, 951,\n",
|
||||
" 4, 6, 2005, 5083, 113, 544, 35, 2957, 20, 5,\n",
|
||||
" 9, 1013, 9, 925, 3, 25, 12, 9, 145, 255,\n",
|
||||
" 46, 30, 160, 7, 26, 54, 46, 42, 107, 12534,\n",
|
||||
" 563, 10, 56, 1013, 241, 3, 11, 9, 16, 29,\n",
|
||||
" 3, 11, 9, 16, 2966, 6, 8018, 3, 24, 143,\n",
|
||||
" 199, 773, 249, 45, 1364, 6, 120, 893, 4, 1013,\n",
|
||||
" 10, 5, 516, 15, 135, 29, 205, 437, 599, 25,\n",
|
||||
" 24229, 3, 338, 1803, 24, 3, 11, 222, 1655, 734,\n",
|
||||
" 1296, 4, 265, 29, 19, 5, 618, 4793, 3, 11,\n",
|
||||
" 9, 16, 69, 866, 8, 474, 47, 2, 113, 138,\n",
|
||||
" 19, 39, 30, 29, 343, 6136, 4, 48, 984, 5,\n",
|
||||
" 5212, 7, 122, 3, 77, 1894, 6, 3550, 30, 1650,\n",
|
||||
" 6, 634, 4, 403, 1266, 8, 110, 3, 2, 1332,\n",
|
||||
" 7, 649, 130, 11, 9, 16, 1834, 19, 39, 31,\n",
|
||||
" 8, 215, 134, 1965, 13961, 9, 16, 649, 3, 3,\n",
|
||||
" 3, 910, 81, 68, 29, 1677, 142, 3, 13961, 9,\n",
|
||||
" 16, 13264, 208, 35, 1685, 13, 77, 13826, 19, 14,\n",
|
||||
" 696, 4, 745, 4, 793, 2192, 25, 142, 11, 211])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"label, indexes = train_data[0]\n",
|
||||
"label, indexes = test_data[0]\n",
|
||||
"\n",
|
||||
"print(indexes)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -571,7 +550,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['david', 'mamet', 'is', 'a', 'very', 'interesting', 'and', 'a', 'very', '<unk>', 'director', '.', 'his', 'first', 'movie', \"'\", 'house', 'of', 'games', \"'\", 'was', 'the', 'one', 'i', 'liked', 'best', ',', 'and', 'it', 'set', 'a', 'series', 'of', 'films', 'with', 'characters', 'whose', 'perspective', 'of', 'life', 'changes', 'as', 'they', 'get', 'into', 'complicated', 'situations', ',', 'and', 'so', 'does', 'the', 'perspective', 'of', 'the', 'viewer', '.', 'so', 'is', \"'\", 'homicide', \"'\", 'which', 'from', 'the', 'title', 'tries', 'to', 'set', 'the', 'mind', 'of', 'the', 'viewer', 'to', 'the', 'usual', 'crime', 'drama', '.', 'the', 'principal', 'characters', 'are', 'two', 'cops', ',', 'one', 'jewish', 'and', 'one', 'irish', 'who', 'deal', 'with', 'a', 'racially', 'charged', 'area', '.', 'the', 'murder', 'of', 'an', 'old', 'jewish', 'shop', 'owner', 'who', 'proves', 'to', 'be', 'an', 'ancient', 'veteran', 'of', 'the', 'israeli', 'independence', 'war', '<unk>', 'the', 'jewish', 'identity', 'in', 'the', 'mind', 'and', 'heart', 'of', 'the', 'jewish', 'detective', '.', 'this', 'is', 'were', 'the', 'flaws', 'of', 'the', 'film', 'are', 'the', 'more', 'obvious', '.', 'the', 'process', 'of', 'awakening', 'is', 'theatrical', 'and', 'hard', 'to', 'believe', ',', 'the', 'group', 'of', 'jewish', '<unk>', 'is', 'operatic', ',', 'and', 'the', 'way', 'the', 'detective', 'eventually', 'walks', 'to', 'the', 'final', 'violent', 'confrontation', 'is', 'pathetic', '.', 'the', 'end', 'of', 'the', 'film', 'itself', 'is', '<unk>', 'smart', ',', 'but', 'disappoints', 'from', 'a', 'human', 'emotional', 'perspective', '.', 'joe', 'mantegna', 'and', 'william', 'macy', 'give', 'strong', 'performances', ',', 'but', 'the', 'flaws', 'of', 'the', 'story', 'are', 'too', 'evident', 'to', 'be', 'easily', 'compensated', '.']\n"
|
||||
"['i', 'love', 'sci-fi', 'and', 'am', 'willing', 'to', 'put', 'up', 'with', 'a', 'lot', '.', 'sci-fi', '<unk>', 'are', 'usually', '<unk>', ',', 'under-appreciated', 'and', 'misunderstood', '.', 'i', 'tried', 'to', 'like', 'this', ',', 'i', 'really', 'did', ',', 'but', 'it', 'is', 'to', 'good', 'tv', 'sci-fi', 'as', 'babylon', '5', 'is', 'to', 'star', 'trek', '(', 'the', 'original', ')', '.', 'silly', '<unk>', ',', 'cheap', 'cardboard', 'sets', ',', 'stilted', 'dialogues', ',', 'cg', 'that', 'doesn', \"'\", 't', 'match', 'the', 'background', ',', 'and', 'painfully', 'one-dimensional', 'characters', 'cannot', 'be', 'overcome', 'with', 'a', \"'\", 'sci-fi', \"'\", 'setting', '.', '(', 'i', \"'\", 'm', 'sure', 'there', 'are', 'those', 'of', 'you', 'out', 'there', 'who', 'think', 'babylon', '5', 'is', 'good', 'sci-fi', 'tv', '.', 'it', \"'\", 's', 'not', '.', 'it', \"'\", 's', 'clichéd', 'and', 'uninspiring', '.', ')', 'while', 'us', 'viewers', 'might', 'like', 'emotion', 'and', 'character', 'development', ',', 'sci-fi', 'is', 'a', 'genre', 'that', 'does', 'not', 'take', 'itself', 'seriously', '(', 'cf', '.', 'star', 'trek', ')', '.', 'it', 'may', 'treat', 'important', 'issues', ',', 'yet', 'not', 'as', 'a', 'serious', 'philosophy', '.', 'it', \"'\", 's', 'really', 'difficult', 'to', 'care', 'about', 'the', 'characters', 'here', 'as', 'they', 'are', 'not', 'simply', 'foolish', ',', 'just', 'missing', 'a', 'spark', 'of', 'life', '.', 'their', 'actions', 'and', 'reactions', 'are', 'wooden', 'and', 'predictable', ',', 'often', 'painful', 'to', 'watch', '.', 'the', 'makers', 'of', 'earth', 'know', 'it', \"'\", 's', 'rubbish', 'as', 'they', 'have', 'to', 'always', 'say', 'gene', 'roddenberry', \"'\", 's', 'earth', '.', '.', '.', 'otherwise', 'people', 'would', 'not', 'continue', 'watching', '.', 'roddenberry', \"'\", 's', 'ashes', 'must', 'be', 'turning', 'in', 'their', 'orbit', 'as', 'this', 'dull', ',', 'cheap', ',', 'poorly', 'edited', '(', 'watching', 'it', 'without']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -579,20 +558,6 @@
|
||||
"print([vocab.itos[i] for i in indexes])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "4Rec_Wk6dFnD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"valid_data = process_raw_data(raw_valid_data, tokenizer, vocab)\n",
|
||||
"test_data = process_raw_data(raw_test_data, tokenizer, vocab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
@ -1002,9 +967,9 @@
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
" [ 0.4029, 0.1353, 0.6673, ..., -0.3300, 0.7533, -0.1666],\n",
|
||||
" [ 0.1226, 0.0419, 0.0746, ..., -0.0024, -0.2733, -1.0033],\n",
|
||||
" [-0.1009, -0.1484, 0.3141, ..., -0.3414, -0.3768, 0.5605]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 42,
|
||||
@ -1032,7 +997,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"678"
|
||||
"734"
|
||||
]
|
||||
},
|
||||
"execution_count": 43,
|
||||
@ -1061,7 +1026,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['<unk>', '<pad>', '\\x96', 'hadn', '****', '100%', 'camera-work', '*1/2', '$1', '*****']\n"
|
||||
"['<unk>', '<pad>', '\\x96', '****', 'hadn', 'camera-work', '*1/2', '100%', '*****', '$1']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1089,9 +1054,9 @@
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
" [ 0.4029, 0.1353, 0.6673, ..., -0.3300, 0.7533, -0.1666],\n",
|
||||
" [ 0.1226, 0.0419, 0.0746, ..., -0.0024, -0.2733, -1.0033],\n",
|
||||
" [-0.1009, -0.1484, 0.3141, ..., -0.3414, -0.3768, 0.5605]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
@ -1298,35 +1263,35 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.687 | Train Acc: 56.51%\n",
|
||||
"\t Val. Loss: 0.677 | Val. Acc: 62.87%\n",
|
||||
"\tTrain Loss: 0.683 | Train Acc: 60.00%\n",
|
||||
"\t Val. Loss: 0.669 | Val. Acc: 67.02%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.665 | Train Acc: 65.13%\n",
|
||||
"\t Val. Loss: 0.650 | Val. Acc: 69.13%\n",
|
||||
"\tTrain Loss: 0.651 | Train Acc: 68.09%\n",
|
||||
"\t Val. Loss: 0.632 | Val. Acc: 71.31%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.629 | Train Acc: 72.45%\n",
|
||||
"\t Val. Loss: 0.611 | Val. Acc: 73.54%\n",
|
||||
"\tTrain Loss: 0.603 | Train Acc: 74.06%\n",
|
||||
"\t Val. Loss: 0.582 | Val. Acc: 74.86%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.583 | Train Acc: 76.17%\n",
|
||||
"\t Val. Loss: 0.566 | Val. Acc: 77.00%\n",
|
||||
"\tTrain Loss: 0.545 | Train Acc: 78.13%\n",
|
||||
"\t Val. Loss: 0.528 | Val. Acc: 78.88%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.533 | Train Acc: 80.22%\n",
|
||||
"\t Val. Loss: 0.521 | Val. Acc: 80.28%\n",
|
||||
"\tTrain Loss: 0.485 | Train Acc: 82.10%\n",
|
||||
"\t Val. Loss: 0.477 | Val. Acc: 81.64%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.484 | Train Acc: 83.24%\n",
|
||||
"\t Val. Loss: 0.480 | Val. Acc: 82.53%\n",
|
||||
"\tTrain Loss: 0.430 | Train Acc: 85.15%\n",
|
||||
"\t Val. Loss: 0.437 | Val. Acc: 83.25%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.440 | Train Acc: 85.46%\n",
|
||||
"\t Val. Loss: 0.443 | Val. Acc: 84.40%\n",
|
||||
"\tTrain Loss: 0.386 | Train Acc: 86.92%\n",
|
||||
"\t Val. Loss: 0.404 | Val. Acc: 84.59%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.401 | Train Acc: 87.10%\n",
|
||||
"\t Val. Loss: 0.414 | Val. Acc: 85.45%\n",
|
||||
"\tTrain Loss: 0.350 | Train Acc: 88.21%\n",
|
||||
"\t Val. Loss: 0.383 | Val. Acc: 85.19%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.367 | Train Acc: 88.41%\n",
|
||||
"\t Val. Loss: 0.390 | Val. Acc: 86.39%\n",
|
||||
"\tTrain Loss: 0.319 | Train Acc: 89.36%\n",
|
||||
"\t Val. Loss: 0.363 | Val. Acc: 85.86%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.340 | Train Acc: 89.23%\n",
|
||||
"\t Val. Loss: 0.370 | Val. Acc: 86.96%\n"
|
||||
"\tTrain Loss: 0.295 | Train Acc: 90.17%\n",
|
||||
"\t Val. Loss: 0.349 | Val. Acc: 86.27%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1372,7 +1337,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.393 | Test Acc: 85.39%\n"
|
||||
"Test Loss: 0.374 | Test Acc: 84.75%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1421,7 +1386,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"9.809165021579247e-06"
|
||||
"2.818893153744284e-05"
|
||||
]
|
||||
},
|
||||
"execution_count": 57,
|
||||
@ -1451,7 +1416,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9999963045120239"
|
||||
"0.9997795224189758"
|
||||
]
|
||||
},
|
||||
"execution_count": 58,
|
||||
@ -1481,7 +1446,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.7485461235046387"
|
||||
"0.6041761040687561"
|
||||
]
|
||||
},
|
||||
"execution_count": 59,
|
||||
@ -1512,7 +1477,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.7485461235046387"
|
||||
"0.6041760444641113"
|
||||
]
|
||||
},
|
||||
"execution_count": 60,
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user