updated tutorials + readme with latest versions of libs.
This commit is contained in:
parent 3cea8e83b8
commit ba82b9231e
@@ -37,7 +37,7 @@
 "\n",
 "We use the `TEXT` field to define how the review should be processed, and the `LABEL` field to process the sentiment. \n",
 "\n",
-"Our `TEXT` field has `tokenize='spacy'` as an argument. This defines that the \"tokenization\" (the act of splitting the string into discrete \"tokens\") should be done using the [spaCy](https://spacy.io) tokenizer. If no `tokenize` argument is passed, the default is simply splitting the string on spaces.\n",
+"Our `TEXT` field has `tokenize='spacy'` as an argument. This defines that the \"tokenization\" (the act of splitting the string into discrete \"tokens\") should be done using the [spaCy](https://spacy.io) tokenizer. If no `tokenize` argument is passed, the default is simply splitting the string on spaces. We also need to specify a `tokenizer_language` which tells torchtext which spaCy model to use. We use the `en_core_web_sm` model which has to be downloaded with `python -m spacy download en_core_web_sm` before you run this notebook!\n",
 "\n",
 "`LABEL` is defined by a `LabelField`, a special subset of the `Field` class specifically used for handling labels. We will explain the `dtype` argument later.\n",
 "\n",
@@ -60,7 +60,8 @@
 "torch.manual_seed(SEED)\n",
 "torch.backends.cudnn.deterministic = True\n",
 "\n",
-"TEXT = data.Field(tokenize = 'spacy')\n",
+"TEXT = data.Field(tokenize = 'spacy',\n",
+"                  tokenizer_language = 'en_core_web_sm')\n",
 "LABEL = data.LabelField(dtype = torch.float)"
 ]
 },
@@ -817,7 +818,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.0"
+"version": "3.8.5"
 }
 },
 "nbformat": 4,

@@ -35,7 +35,18 @@
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+}
+],
 "source": [
 "import torch\n",
 "from torchtext import data\n",
@@ -46,7 +57,10 @@
 "torch.manual_seed(SEED)\n",
 "torch.backends.cudnn.deterministic = True\n",
 "\n",
-"TEXT = data.Field(tokenize = 'spacy', include_lengths = True)\n",
+"TEXT = data.Field(tokenize = 'spacy',\n",
+"                  tokenizer_language = 'en_core_web_sm',\n",
+"                  include_lengths = True)\n",
+"\n",
 "LABEL = data.LabelField(dtype = torch.float)"
 ]
 },
@@ -61,7 +75,16 @@
 "cell_type": "code",
 "execution_count": 2,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
+]
+}
+],
 "source": [
 "from torchtext import datasets\n",
 "\n",
@@ -133,7 +156,16 @@
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+}
+],
 "source": [
 "BATCH_SIZE = 64\n",
 "\n",
@@ -204,7 +236,7 @@
 "\n",
 "As we are passing the lengths of our sentences to be able to use packed padded sequences, we have to add a second argument, `text_lengths`, to `forward`. \n",
 "\n",
-"Before we pass our embeddings to the RNN, we need to pack them, which we do with `nn.utils.rnn.packed_padded_sequence`. This will cause our RNN to only process the non-padded elements of our sequence. The RNN will then return `packed_output` (a packed sequence) as well as the `hidden` and `cell` states (both of which are tensors). Without packed padded sequences, `hidden` and `cell` are tensors from the last element in the sequence, which will most probably be a pad token, however when using packed padded sequences they are both from the last non-padded element in the sequence. \n",
+"Before we pass our embeddings to the RNN, we need to pack them, which we do with `nn.utils.rnn.packed_padded_sequence`. This will cause our RNN to only process the non-padded elements of our sequence. The RNN will then return `packed_output` (a packed sequence) as well as the `hidden` and `cell` states (both of which are tensors). Without packed padded sequences, `hidden` and `cell` are tensors from the last element in the sequence, which will most probably be a pad token, however when using packed padded sequences they are both from the last non-padded element in the sequence. Note that the `lengths` argument of `packed_padded_sequence` must be a CPU tensor so we explicitly make it one by using `.to('cpu')`.\n",
 "\n",
 "We then unpack the output sequence, with `nn.utils.rnn.pad_packed_sequence`, to transform it from a packed sequence to a tensor. The elements of `output` from padding tokens will be zero tensors (tensors where every element is zero). Usually, we only have to unpack output if we are going to use it later on in the model. Although we aren't in this case, we still unpack the sequence just to show how it is done.\n",
 "\n",
@@ -246,7 +278,8 @@
 " #embedded = [sent len, batch size, emb dim]\n",
 " \n",
 " #pack sequence\n",
-" packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths)\n",
+" # lengths need to be on CPU!\n",
+" packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths.to('cpu'))\n",
 " \n",
 " packed_output, (hidden, cell) = self.rnn(packed_embedded)\n",
 " \n",
@@ -383,9 +416,9 @@
 " [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
 " [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
 " ...,\n",
-" [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n",
-" [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n",
-" [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])"
+" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
+" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
+" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])"
 ]
 },
 "execution_count": 10,
@@ -421,9 +454,9 @@
 " [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
 " [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
 " ...,\n",
-" [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n",
-" [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n",
-" [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])\n"
+" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
+" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
+" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])\n"
 ]
 }
 ],
@@ -638,25 +671,33 @@
 "execution_count": 18,
 "metadata": {},
 "outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+},
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Epoch: 01 | Epoch Time: 0m 28s\n",
-"\tTrain Loss: 0.648 | Train Acc: 62.05%\n",
-"\t Val. Loss: 0.620 | Val. Acc: 66.72%\n",
-"Epoch: 02 | Epoch Time: 0m 27s\n",
-"\tTrain Loss: 0.622 | Train Acc: 66.51%\n",
-"\t Val. Loss: 0.669 | Val. Acc: 62.83%\n",
-"Epoch: 03 | Epoch Time: 0m 27s\n",
-"\tTrain Loss: 0.586 | Train Acc: 69.01%\n",
-"\t Val. Loss: 0.522 | Val. Acc: 75.52%\n",
-"Epoch: 04 | Epoch Time: 0m 27s\n",
-"\tTrain Loss: 0.415 | Train Acc: 82.02%\n",
-"\t Val. Loss: 0.457 | Val. Acc: 77.10%\n",
-"Epoch: 05 | Epoch Time: 0m 27s\n",
-"\tTrain Loss: 0.335 | Train Acc: 86.15%\n",
-"\t Val. Loss: 0.305 | Val. Acc: 87.15%\n"
+"Epoch: 01 | Epoch Time: 0m 36s\n",
+"\tTrain Loss: 0.673 | Train Acc: 58.05%\n",
+"\t Val. Loss: 0.619 | Val. Acc: 64.97%\n",
+"Epoch: 02 | Epoch Time: 0m 36s\n",
+"\tTrain Loss: 0.611 | Train Acc: 66.33%\n",
+"\t Val. Loss: 0.510 | Val. Acc: 74.32%\n",
+"Epoch: 03 | Epoch Time: 0m 37s\n",
+"\tTrain Loss: 0.484 | Train Acc: 77.04%\n",
+"\t Val. Loss: 0.397 | Val. Acc: 82.95%\n",
+"Epoch: 04 | Epoch Time: 0m 37s\n",
+"\tTrain Loss: 0.384 | Train Acc: 83.57%\n",
+"\t Val. Loss: 0.407 | Val. Acc: 83.23%\n",
+"Epoch: 05 | Epoch Time: 0m 37s\n",
+"\tTrain Loss: 0.314 | Train Acc: 86.98%\n",
+"\t Val. Loss: 0.314 | Val. Acc: 86.36%\n"
 ]
 }
 ],
@@ -701,7 +742,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Test Loss: 0.308 | Test Acc: 87.07%\n"
+"Test Loss: 0.334 | Test Acc: 85.28%\n"
 ]
 }
 ],
@@ -744,7 +785,7 @@
 "outputs": [],
 "source": [
 "import spacy\n",
-"nlp = spacy.load('en')\n",
+"nlp = spacy.load('en_core_web_sm')\n",
 "\n",
 "def predict_sentiment(model, sentence):\n",
 " model.eval()\n",
@@ -773,7 +814,7 @@
 {
 "data": {
 "text/plain": [
-"0.005683214403688908"
+"0.05380420759320259"
 ]
 },
 "execution_count": 21,
@@ -800,7 +841,7 @@
 {
 "data": {
 "text/plain": [
-"0.9926869869232178"
+"0.94941645860672"
 ]
 },
 "execution_count": 22,
@@ -838,7 +879,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.0"
+"version": "3.8.5"
 }
 },
 "nbformat": 4,

@@ -50,7 +50,7 @@
 {
 "data": {
 "text/plain": [
-"['This', 'film', 'is', 'terrible', 'This film', 'film is', 'is terrible']"
+"['This', 'film', 'is', 'terrible', 'film is', 'This film', 'is terrible']"
 ]
 },
 "execution_count": 2,
@@ -75,7 +75,18 @@
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+}
+],
 "source": [
 "import torch\n",
 "from torchtext import data\n",
@@ -86,7 +97,10 @@
 "torch.manual_seed(SEED)\n",
 "torch.backends.cudnn.deterministic = True\n",
 "\n",
-"TEXT = data.Field(tokenize = 'spacy', preprocessing = generate_bigrams)\n",
+"TEXT = data.Field(tokenize = 'spacy',\n",
+"                  tokenizer_language = 'en_core_web_sm',\n",
+"                  preprocessing = generate_bigrams)\n",
+"\n",
 "LABEL = data.LabelField(dtype = torch.float)"
 ]
 },
@@ -101,7 +115,16 @@
 "cell_type": "code",
 "execution_count": 4,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
+]
+}
+],
 "source": [
 "import random\n",
 "\n",
@@ -144,7 +167,16 @@
 "cell_type": "code",
 "execution_count": 6,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+}
+],
 "source": [
 "BATCH_SIZE = 64\n",
 "\n",
@@ -287,9 +319,9 @@
 " [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
 " [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
 " ...,\n",
-" [ 0.3199, 0.0746, 0.0231, ..., -0.3609, 1.1303, 0.5668],\n",
-" [-1.0530, -1.0757, 0.3903, ..., 0.0792, -0.3059, 1.9734],\n",
-" [-0.1734, -0.3195, 0.3694, ..., -0.2435, 0.4767, 0.1151]])"
+" [-0.1606, -0.7357, 0.5809, ..., 0.8704, -1.5637, -1.5724],\n",
+" [-1.3126, -1.6717, 0.4203, ..., 0.2348, -0.9110, 1.0914],\n",
+" [-1.5268, 1.5639, -1.0541, ..., 1.0045, -0.6813, -0.8846]])"
 ]
 },
 "execution_count": 10,
@@ -507,25 +539,33 @@
 "execution_count": 18,
 "metadata": {},
 "outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+},
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Epoch: 01 | Epoch Time: 0m 6s\n",
-"\tTrain Loss: 0.688 | Train Acc: 57.23%\n",
-"\t Val. Loss: 0.642 | Val. Acc: 71.23%\n",
-"Epoch: 02 | Epoch Time: 0m 5s\n",
-"\tTrain Loss: 0.653 | Train Acc: 71.09%\n",
-"\t Val. Loss: 0.521 | Val. Acc: 75.28%\n",
-"Epoch: 03 | Epoch Time: 0m 5s\n",
-"\tTrain Loss: 0.582 | Train Acc: 78.88%\n",
-"\t Val. Loss: 0.449 | Val. Acc: 79.64%\n",
-"Epoch: 04 | Epoch Time: 0m 5s\n",
-"\tTrain Loss: 0.505 | Train Acc: 83.15%\n",
-"\t Val. Loss: 0.426 | Val. Acc: 82.12%\n",
-"Epoch: 05 | Epoch Time: 0m 5s\n",
-"\tTrain Loss: 0.439 | Train Acc: 85.99%\n",
-"\t Val. Loss: 0.397 | Val. Acc: 85.02%\n"
+"Epoch: 01 | Epoch Time: 0m 7s\n",
+"\tTrain Loss: 0.688 | Train Acc: 61.31%\n",
+"\t Val. Loss: 0.637 | Val. Acc: 72.46%\n",
+"Epoch: 02 | Epoch Time: 0m 6s\n",
+"\tTrain Loss: 0.651 | Train Acc: 75.04%\n",
+"\t Val. Loss: 0.507 | Val. Acc: 76.92%\n",
+"Epoch: 03 | Epoch Time: 0m 6s\n",
+"\tTrain Loss: 0.578 | Train Acc: 79.91%\n",
+"\t Val. Loss: 0.424 | Val. Acc: 80.97%\n",
+"Epoch: 04 | Epoch Time: 0m 6s\n",
+"\tTrain Loss: 0.501 | Train Acc: 83.97%\n",
+"\t Val. Loss: 0.377 | Val. Acc: 84.34%\n",
+"Epoch: 05 | Epoch Time: 0m 6s\n",
+"\tTrain Loss: 0.435 | Train Acc: 86.96%\n",
+"\t Val. Loss: 0.363 | Val. Acc: 86.18%\n"
 ]
 }
 ],
@@ -572,7 +612,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Test Loss: 0.391 | Test Acc: 85.11%\n"
+"Test Loss: 0.381 | Test Acc: 85.42%\n"
 ]
 }
 ],
@@ -600,7 +640,7 @@
 "outputs": [],
 "source": [
 "import spacy\n",
-"nlp = spacy.load('en')\n",
+"nlp = spacy.load('en_core_web_sm')\n",
 "\n",
 "def predict_sentiment(model, sentence):\n",
 " model.eval()\n",
@@ -627,7 +667,7 @@
 {
 "data": {
 "text/plain": [
-"1.621993561684576e-07"
+"2.1313092350011553e-12"
 ]
 },
 "execution_count": 21,
@@ -692,7 +732,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.0"
+"version": "3.8.5"
 }
 },
 "nbformat": 4,

@@ -36,7 +36,20 @@
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
+]
+}
+],
 "source": [
 "import torch\n",
 "from torchtext import data\n",
@@ -51,7 +64,9 @@
 "torch.manual_seed(SEED)\n",
 "torch.backends.cudnn.deterministic = True\n",
 "\n",
-"TEXT = data.Field(tokenize = 'spacy', batch_first = True)\n",
+"TEXT = data.Field(tokenize = 'spacy', \n",
+"                  tokenizer_language = 'en_core_web_sm',\n",
+"                  batch_first = True)\n",
 "LABEL = data.LabelField(dtype = torch.float)\n",
 "\n",
 "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
@@ -93,7 +108,16 @@
 "cell_type": "code",
 "execution_count": 3,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+}
+],
 "source": [
 "BATCH_SIZE = 64\n",
 "\n",
@@ -416,9 +440,9 @@
 " [-0.8555, -0.7208, 1.3755, ..., 0.0825, -1.1314, 0.3997],\n",
 " [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
 " ...,\n",
-" [-0.0614, -0.0516, -0.6159, ..., -0.0354, 0.0379, -0.1809],\n",
-" [ 0.1885, -0.1690, 0.1530, ..., -0.2077, 0.5473, -0.4517],\n",
-" [-0.1182, -0.4701, -0.0600, ..., 0.7991, -0.0194, 0.4785]])"
+" [ 0.6783, 0.0488, 0.5860, ..., 0.2680, -0.0086, 0.5758],\n",
+" [-0.6208, -0.0480, -0.1046, ..., 0.3718, 0.1225, 0.1061],\n",
+" [-0.6553, -0.6292, 0.9967, ..., 0.2278, -0.1975, 0.0857]])"
 ]
 },
 "execution_count": 9,
@@ -622,25 +646,33 @@
 "scrolled": true
 },
 "outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+},
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
 "Epoch: 01 | Epoch Time: 0m 13s\n",
-"\tTrain Loss: 0.645 | Train Acc: 62.08%\n",
-"\t Val. Loss: 0.488 | Val. Acc: 78.64%\n",
-"Epoch: 02 | Epoch Time: 0m 11s\n",
-"\tTrain Loss: 0.418 | Train Acc: 81.14%\n",
-"\t Val. Loss: 0.361 | Val. Acc: 84.59%\n",
-"Epoch: 03 | Epoch Time: 0m 11s\n",
-"\tTrain Loss: 0.300 | Train Acc: 87.33%\n",
-"\t Val. Loss: 0.348 | Val. Acc: 85.06%\n",
-"Epoch: 04 | Epoch Time: 0m 11s\n",
-"\tTrain Loss: 0.217 | Train Acc: 91.49%\n",
-"\t Val. Loss: 0.320 | Val. Acc: 86.71%\n",
-"Epoch: 05 | Epoch Time: 0m 11s\n",
-"\tTrain Loss: 0.156 | Train Acc: 94.22%\n",
-"\t Val. Loss: 0.334 | Val. Acc: 87.06%\n"
+"\tTrain Loss: 0.649 | Train Acc: 61.79%\n",
+"\t Val. Loss: 0.507 | Val. Acc: 78.93%\n",
+"Epoch: 02 | Epoch Time: 0m 13s\n",
+"\tTrain Loss: 0.433 | Train Acc: 79.86%\n",
+"\t Val. Loss: 0.357 | Val. Acc: 84.57%\n",
+"Epoch: 03 | Epoch Time: 0m 13s\n",
+"\tTrain Loss: 0.305 | Train Acc: 87.36%\n",
+"\t Val. Loss: 0.312 | Val. Acc: 86.76%\n",
+"Epoch: 04 | Epoch Time: 0m 13s\n",
+"\tTrain Loss: 0.224 | Train Acc: 91.20%\n",
+"\t Val. Loss: 0.303 | Val. Acc: 87.16%\n",
+"Epoch: 05 | Epoch Time: 0m 14s\n",
+"\tTrain Loss: 0.159 | Train Acc: 94.16%\n",
+"\t Val. Loss: 0.317 | Val. Acc: 87.37%\n"
 ]
 }
 ],
@@ -685,7 +717,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Test Loss: 0.339 | Test Acc: 85.39%\n"
+"Test Loss: 0.343 | Test Acc: 85.31%\n"
 ]
 }
 ],
@@ -710,12 +742,12 @@
 },
 {
 "cell_type": "code",
-"execution_count": 20,
+"execution_count": 18,
 "metadata": {},
 "outputs": [],
 "source": [
 "import spacy\n",
-"nlp = spacy.load('en')\n",
+"nlp = spacy.load('en_core_web_sm')\n",
 "\n",
 "def predict_sentiment(model, sentence, min_len = 5):\n",
 " model.eval()\n",
@@ -738,16 +770,16 @@
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": 19,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"0.11022213101387024"
+"0.09913548082113266"
 ]
 },
-"execution_count": 21,
+"execution_count": 19,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -765,16 +797,16 @@
 },
 {
 "cell_type": "code",
-"execution_count": 22,
+"execution_count": 20,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
-"0.9785954356193542"
+"0.9769725799560547"
 ]
 },
-"execution_count": 22,
+"execution_count": 20,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -800,7 +832,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.0"
+"version": "3.8.5"
 }
 },
 "nbformat": 4,

@@ -21,7 +21,20 @@
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
+]
+}
+],
 "source": [
 "import torch\n",
 "from torchtext import data\n",
@@ -33,7 +46,9 @@
 "torch.manual_seed(SEED)\n",
 "torch.backends.cudnn.deterministic = True\n",
 "\n",
-"TEXT = data.Field(tokenize = 'spacy')\n",
+"TEXT = data.Field(tokenize = 'spacy',\n",
+"                  tokenizer_language = 'en_core_web_sm')\n",
+"\n",
 "LABEL = data.LabelField()\n",
 "\n",
 "train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)\n",
@@ -115,7 +130,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"defaultdict(<function _default_unk_index at 0x7f0a50190d08>, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n"
+"defaultdict(None, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n"
 ]
 }
 ],
@@ -134,7 +149,16 @@
 "cell_type": "code",
 "execution_count": 5,
 "metadata": {},
-"outputs": [],
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+]
+}
+],
 "source": [
 "BATCH_SIZE = 64\n",
 "\n",
@@ -254,7 +278,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"The model has 842,406 trainable parameters\n"
+"The model has 841,806 trainable parameters\n"
 ]
 }
 ],
@@ -286,7 +310,7 @@
 " ...,\n",
 " [-0.3110, -0.3398, 1.0308, ..., 0.5317, 0.2836, -0.0640],\n",
 " [ 0.0091, 0.2810, 0.7356, ..., -0.7508, 0.8967, -0.7631],\n",
-" [ 0.4306, 1.2011, 0.0873, ..., 0.8817, 0.3722, 0.3458]])"
+" [ 0.5831, -0.2514, 0.4156, ..., -0.2735, -0.8659, -1.4063]])"
 ]
 },
 "execution_count": 9,
@ -367,9 +391,10 @@
|
||||
" \"\"\"\n",
|
||||
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
|
||||
" \"\"\"\n",
|
||||
" max_preds = preds.argmax(dim = 1, keepdim = True) # get the index of the max probability\n",
|
||||
" correct = max_preds.squeeze(1).eq(y)\n",
|
||||
" return correct.sum() / torch.FloatTensor([y.shape[0]])"
|
||||
" top_pred = preds.argmax(1, keepdim = True)\n",
|
||||
" correct = top_pred.eq(y.view_as(top_pred)).sum()\n",
|
||||
" acc = correct.float() / y.shape[0]\n",
|
||||
" return acc"
|
||||
]
|
||||
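Reassembled from the replacement lines in this hunk, the updated accuracy cell behaves like this (a runnable sketch; the toy `preds`/`y` tensors are made up for illustration):

```python
import torch

def categorical_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    top_pred = preds.argmax(1, keepdim=True)          # index of the largest logit per example
    correct = top_pred.eq(y.view_as(top_pred)).sum()  # how many predictions match the labels
    acc = correct.float() / y.shape[0]                # normalise by batch size
    return acc

# toy batch: the first two predictions match the labels, the third does not
preds = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
y = torch.tensor([1, 0, 0])
print(categorical_accuracy(preds, y))  # tensor(0.6667)
```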
},
{
@ -479,25 +504,33 @@
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
"  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 1.310 | Train Acc: 47.99%\n",
"\t Val. Loss: 0.947 | Val. Acc: 66.81%\n",
"\tTrain Loss: 1.312 | Train Acc: 47.11%\n",
"\t Val. Loss: 0.947 | Val. Acc: 66.41%\n",
"Epoch: 02 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 0.869 | Train Acc: 69.09%\n",
"\t Val. Loss: 0.746 | Val. Acc: 74.18%\n",
"\tTrain Loss: 0.870 | Train Acc: 69.18%\n",
"\t Val. Loss: 0.741 | Val. Acc: 74.14%\n",
"Epoch: 03 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 0.665 | Train Acc: 76.94%\n",
"\t Val. Loss: 0.627 | Val. Acc: 78.03%\n",
"\tTrain Loss: 0.675 | Train Acc: 76.32%\n",
"\t Val. Loss: 0.621 | Val. Acc: 78.49%\n",
"Epoch: 04 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 0.503 | Train Acc: 83.42%\n",
"\t Val. Loss: 0.548 | Val. Acc: 79.73%\n",
"\tTrain Loss: 0.506 | Train Acc: 83.97%\n",
"\t Val. Loss: 0.547 | Val. Acc: 80.32%\n",
"Epoch: 05 | Epoch Time: 0m 0s\n",
"\tTrain Loss: 0.376 | Train Acc: 87.88%\n",
"\t Val. Loss: 0.506 | Val. Acc: 81.40%\n"
"\tTrain Loss: 0.373 | Train Acc: 88.23%\n",
"\t Val. Loss: 0.487 | Val. Acc: 82.92%\n"
]
}
],
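The `Epoch Time: Xm Ys` lines in the output above are produced by a small minutes/seconds helper; a sketch (the name `epoch_time` follows the convention used throughout these tutorials and is assumed here):

```python
def epoch_time(start_time, end_time):
    # split an elapsed wall-clock duration in seconds into whole minutes and seconds
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

mins, secs = epoch_time(0.0, 433.2)
print(f'Epoch Time: {mins}m {secs}s')  # Epoch Time: 7m 13s
```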
@ -542,7 +575,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.411 | Test Acc: 87.15%\n"
"Test Loss: 0.415 | Test Acc: 86.07%\n"
]
}
],
@ -570,7 +603,7 @@
"outputs": [],
"source": [
"import spacy\n",
"nlp = spacy.load('en')\n",
"nlp = spacy.load('en_core_web_sm')\n",
"\n",
"def predict_class(model, sentence, min_len = 4):\n",
" model.eval()\n",
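The `predict_class` cell continues beyond this hunk; the inference pattern it implements can be sketched self-containedly as follows (the toy vocabulary and `ToyClassifier` are stand-ins for the real `TEXT.vocab` and CNN model built earlier in the notebook):

```python
import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    # stand-in for the notebook's CNN: embed, average over the sequence, classify
    def __init__(self, vocab_size, emb_dim, n_classes):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim)
        self.fc = nn.Linear(emb_dim, n_classes)

    def forward(self, text):                    # text: [seq len, batch]
        return self.fc(self.emb(text).mean(0))  # -> [batch, n_classes]

def predict_class(model, vocab, tokens, min_len=4):
    model.eval()
    if len(tokens) < min_len:                   # pad short inputs so the CNN filters fit
        tokens = tokens + ['<pad>'] * (min_len - len(tokens))
    indexed = [vocab.get(t, vocab['<unk>']) for t in tokens]
    tensor = torch.LongTensor(indexed).unsqueeze(1)  # shape: [seq len, 1]
    with torch.no_grad():
        preds = model(tensor)
    return preds.argmax(dim=1).item()           # index of the predicted class

vocab = {'<unk>': 0, '<pad>': 1, 'who': 2, 'is': 3, 'keyser': 4}
model = ToyClassifier(len(vocab), 8, 6)
print(predict_class(model, vocab, ['who', 'is', 'keyser']))
```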
@ -681,7 +714,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.8.5"
}
},
"nbformat": 4,
@ -52,16 +52,7 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1106 14:55:11.110527 139759243081536 file_utils.py:39] PyTorch version 1.3.0 available.\n",
"I1106 14:55:11.917650 139759243081536 tokenization_utils.py:374] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/ben/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084\n"
]
}
],
"outputs": [],
"source": [
"from transformers import BertTokenizer\n",
"\n",
@ -294,7 +285,18 @@
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
"  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
"  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"from torchtext import data\n",
"\n",
@ -321,7 +323,16 @@
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
"  warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
]
}
],
"source": [
"from torchtext import datasets\n",
"\n",
@ -367,7 +378,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'text': [5949, 1997, 2026, 2166, 1010, 1012, 1012, 1012, 1012, 1996, 2472, 2323, 2022, 10339, 1012, 2339, 2111, 2514, 2027, 2342, 2000, 2191, 22692, 5691, 2097, 2196, 2191, 3168, 2000, 2033, 1012, 2043, 2016, 2351, 2012, 1996, 2203, 1010, 2009, 2081, 2033, 4756, 1012, 1045, 2018, 2000, 2689, 1996, 3149, 2116, 2335, 2802, 1996, 2143, 2138, 1045, 2001, 2893, 10339, 3666, 2107, 3532, 3772, 1012, 11504, 1996, 3124, 2040, 2209, 9895, 2196, 4152, 2147, 2153, 1012, 2006, 2327, 1997, 2008, 1045, 3246, 1996, 2472, 2196, 4152, 2000, 2191, 2178, 2143, 1010, 1998, 2038, 2010, 3477, 5403, 3600, 2579, 2067, 2005, 2023, 10231, 1012, 1063, 1012, 6185, 2041, 1997, 2184, 1065], 'label': 'neg'}\n"
"{'text': [1042, 4140, 1996, 2087, 2112, 1010, 2023, 3185, 5683, 2066, 1037, 1000, 2081, 1011, 2005, 1011, 2694, 1000, 3947, 1012, 1996, 3257, 2003, 10654, 1011, 28273, 1010, 1996, 3772, 1006, 2007, 1996, 6453, 1997, 5965, 1043, 11761, 2638, 1007, 2003, 2058, 13088, 10593, 2102, 1998, 7815, 2100, 1012, 15339, 14282, 1010, 3391, 1010, 18058, 2014, 3210, 2066, 2016, 1005, 1055, 3147, 3752, 2068, 2125, 1037, 16091, 4003, 1012, 2069, 2028, 2518, 3084, 2023, 2143, 4276, 3666, 1010, 1998, 2008, 2003, 2320, 10012, 3310, 2067, 2013, 1996, 1000, 7367, 11368, 5649, 1012, 1000, 2045, 2003, 2242, 14888, 2055, 3666, 1037, 2235, 2775, 4028, 2619, 1010, 1998, 2023, 3185, 2453, 2022, 2062, 2084, 2070, 2064, 5047, 2074, 2005, 2008, 3114, 1012, 2009, 2003, 7078, 5923, 1011, 27017, 1012, 2023, 2143, 2069, 2515, 2028, 2518, 2157, 1010, 2021, 2009, 21145, 2008, 2028, 2518, 2157, 2041, 1997, 1996, 2380, 1012, 4276, 3773, 2074, 2005, 1996, 2197, 2184, 2781, 2030, 2061, 1012], 'label': 'neg'}\n"
]
}
],
@ -391,7 +402,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"['waste', 'of', 'my', 'life', ',', '.', '.', '.', '.', 'the', 'director', 'should', 'be', 'embarrassed', '.', 'why', 'people', 'feel', 'they', 'need', 'to', 'make', 'worthless', 'movies', 'will', 'never', 'make', 'sense', 'to', 'me', '.', 'when', 'she', 'died', 'at', 'the', 'end', ',', 'it', 'made', 'me', 'laugh', '.', 'i', 'had', 'to', 'change', 'the', 'channel', 'many', 'times', 'throughout', 'the', 'film', 'because', 'i', 'was', 'getting', 'embarrassed', 'watching', 'such', 'poor', 'acting', '.', 'hopefully', 'the', 'guy', 'who', 'played', 'heath', 'never', 'gets', 'work', 'again', '.', 'on', 'top', 'of', 'that', 'i', 'hope', 'the', 'director', 'never', 'gets', 'to', 'make', 'another', 'film', ',', 'and', 'has', 'his', 'pay', '##che', '##ck', 'taken', 'back', 'for', 'this', 'crap', '.', '{', '.', '02', 'out', 'of', '10', '}']\n"
"['f', '##ot', 'the', 'most', 'part', ',', 'this', 'movie', 'feels', 'like', 'a', '\"', 'made', '-', 'for', '-', 'tv', '\"', 'effort', '.', 'the', 'direction', 'is', 'ham', '-', 'fisted', ',', 'the', 'acting', '(', 'with', 'the', 'exception', 'of', 'fred', 'g', '##wyn', '##ne', ')', 'is', 'over', '##wr', '##ough', '##t', 'and', 'soap', '##y', '.', 'denise', 'crosby', ',', 'particularly', ',', 'delivers', 'her', 'lines', 'like', 'she', \"'\", 's', 'cold', 'reading', 'them', 'off', 'a', 'cue', 'card', '.', 'only', 'one', 'thing', 'makes', 'this', 'film', 'worth', 'watching', ',', 'and', 'that', 'is', 'once', 'gage', 'comes', 'back', 'from', 'the', '\"', 'se', '##met', '##ary', '.', '\"', 'there', 'is', 'something', 'disturbing', 'about', 'watching', 'a', 'small', 'child', 'murder', 'someone', ',', 'and', 'this', 'movie', 'might', 'be', 'more', 'than', 'some', 'can', 'handle', 'just', 'for', 'that', 'reason', '.', 'it', 'is', 'absolutely', 'bone', '-', 'chilling', '.', 'this', 'film', 'only', 'does', 'one', 'thing', 'right', ',', 'but', 'it', 'knocks', 'that', 'one', 'thing', 'right', 'out', 'of', 'the', 'park', '.', 'worth', 'seeing', 'just', 'for', 'the', 'last', '10', 'minutes', 'or', 'so', '.']\n"
]
}
],
@ -445,7 +456,16 @@
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
"  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
}
],
"source": [
"BATCH_SIZE = 128\n",
"\n",
@ -470,39 +490,7 @@
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1106 14:57:06.877642 139759243081536 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/ben/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c\n",
"I1106 14:57:06.878792 139759243081536 configuration_utils.py:168] Model config {\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"finetuning_task\": null,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"num_labels\": 2,\n",
" \"output_attentions\": false,\n",
" \"output_hidden_states\": false,\n",
" \"output_past\": true,\n",
" \"pruned_heads\": {},\n",
" \"torchscript\": false,\n",
" \"type_vocab_size\": 2,\n",
" \"use_bfloat16\": false,\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n",
"I1106 14:57:07.421291 139759243081536 modeling_utils.py:337] loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin from cache at /home/ben/.cache/torch/transformers/aa1ef1aede4482d0dbcd4d52baad8ae300e60902e88fcb0bebdec09afd232066.36ca03ab34a1a5d5fa7bc3d03d55c4fa650fed07220e2eeebc06ce58d0e9a157\n"
]
}
],
"outputs": [],
"source": [
"from transformers import BertTokenizer, BertModel\n",
"\n",
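This hunk drops the verbose loading logs; the cell itself loads the pretrained model in the usual way (a sketch — the first call downloads several hundred MB of weights; the config values shown below match the log the diff removes):

```python
from transformers import BertModel

# downloads and caches the pretrained bert-base-uncased weights on first use
bert = BertModel.from_pretrained('bert-base-uncased')

print(bert.config.hidden_size)       # 768
print(bert.config.num_hidden_layers) # 12
```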
@ -880,28 +868,36 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ben/miniconda3/envs/pytorch17/lib/python3.8/site-packages/torchtext-0.9.0a0+c38fd42-py3.8-linux-x86_64.egg/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
"  warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 01 | Epoch Time: 7m 27s\n",
"\tTrain Loss: 0.286 | Train Acc: 88.16%\n",
"\t Val. Loss: 0.247 | Val. Acc: 90.26%\n",
"Epoch: 02 | Epoch Time: 7m 27s\n",
"\tTrain Loss: 0.234 | Train Acc: 90.77%\n",
"\t Val. Loss: 0.229 | Val. Acc: 91.00%\n",
"Epoch: 03 | Epoch Time: 7m 27s\n",
"\tTrain Loss: 0.209 | Train Acc: 91.83%\n",
"\t Val. Loss: 0.225 | Val. Acc: 91.10%\n",
"Epoch: 04 | Epoch Time: 7m 27s\n",
"\tTrain Loss: 0.182 | Train Acc: 92.97%\n",
"\t Val. Loss: 0.217 | Val. Acc: 91.98%\n",
"Epoch: 05 | Epoch Time: 7m 27s\n",
"\tTrain Loss: 0.156 | Train Acc: 94.17%\n",
"\t Val. Loss: 0.230 | Val. Acc: 91.76%\n"
"Epoch: 01 | Epoch Time: 7m 13s\n",
"\tTrain Loss: 0.502 | Train Acc: 74.41%\n",
"\t Val. Loss: 0.270 | Val. Acc: 89.15%\n",
"Epoch: 02 | Epoch Time: 7m 7s\n",
"\tTrain Loss: 0.281 | Train Acc: 88.49%\n",
"\t Val. Loss: 0.224 | Val. Acc: 91.32%\n",
"Epoch: 03 | Epoch Time: 7m 17s\n",
"\tTrain Loss: 0.239 | Train Acc: 90.67%\n",
"\t Val. Loss: 0.211 | Val. Acc: 91.91%\n",
"Epoch: 04 | Epoch Time: 7m 14s\n",
"\tTrain Loss: 0.206 | Train Acc: 91.81%\n",
"\t Val. Loss: 0.206 | Val. Acc: 92.01%\n",
"Epoch: 05 | Epoch Time: 7m 15s\n",
"\tTrain Loss: 0.188 | Train Acc: 92.63%\n",
"\t Val. Loss: 0.211 | Val. Acc: 91.92%\n"
]
}
],
@ -939,14 +935,14 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Loss: 0.198 | Test Acc: 92.31%\n"
"Test Loss: 0.209 | Test Acc: 91.58%\n"
]
}
],
@ -969,7 +965,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@ -986,16 +982,16 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.02264496125280857"
"0.03391794115304947"
]
},
"execution_count": 37,
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
@ -1006,16 +1002,16 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9411056041717529"
"0.8869886994361877"
]
},
"execution_count": 38,
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
@ -1041,7 +1037,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
"version": "3.8.5"
}
},
"nbformat": 4,
README.md
@ -6,11 +6,11 @@

# PyTorch Sentiment Analysis

This repo contains tutorials covering how to perform sentiment analysis using [PyTorch](https://github.com/pytorch/pytorch) 1.7 and [torchtext](https://github.com/pytorch/text) 0.8 using Python 3.8.
This repo contains tutorials covering how to perform sentiment analysis using [PyTorch](https://github.com/pytorch/pytorch) 1.7, [torchtext](https://github.com/pytorch/text) 0.8 and spaCy 3.0, using Python 3.8.

The first 2 tutorials will cover getting started with the de facto approach to sentiment analysis: recurrent neural networks (RNNs). The third notebook covers the [FastText](https://arxiv.org/abs/1607.01759) model and the final covers a [convolutional neural network](https://arxiv.org/abs/1408.5882) (CNN) model.

There are also 2 bonus "appendix" notebooks. The first covers loading your own datasets with TorchText, while the second contains a brief look at the pre-trained word embeddings provided by TorchText.
There are also 2 bonus "appendix" notebooks. The first covers loading your own datasets with torchtext, while the second contains a brief look at the pre-trained word embeddings provided by torchtext.

**If you find any mistakes or disagree with any of the explanations, please do not hesitate to [submit an issue](https://github.com/bentrevett/pytorch-sentiment-analysis/issues/new). I welcome any feedback, positive or negative!**

@ -18,7 +18,7 @@ There are also 2 bonus "appendix" notebooks. The first covers loading your own d

To install PyTorch, see installation instructions on the [PyTorch website](https://pytorch.org/get-started/locally).

To install TorchText:
To install torchtext:

``` bash
pip install torchtext
@ -27,7 +27,7 @@ pip install torchtext
We'll also make use of spaCy to tokenize our data. To install spaCy, follow the instructions [here](https://spacy.io/usage/) making sure to install the English models with:

``` bash
python -m spacy download en
python -m spacy download en_core_web_sm
```

For tutorial 6, we'll use the transformers library, which can be installed via:

@ -35,13 +35,13 @@ For tutorial 6, we'll use the transformers library, which can be installed via:

```bash
pip install transformers
```
These tutorials were created using version 1.2 of the transformers library.
These tutorials were created using version 4.3 of the transformers library.

## Tutorials

* 1 - [Simple Sentiment Analysis](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/1%20-%20Simple%20Sentiment%20Analysis.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-sentiment-analysis/blob/master/1%20-%20Simple%20Sentiment%20Analysis.ipynb)

This tutorial covers the workflow of a PyTorch with TorchText project. We'll learn how to: load data, create train/test/validation splits, build a vocabulary, create data iterators, define a model and implement the train/evaluate/test loop. The model will be simple and achieve poor performance, but this will be improved in the subsequent tutorials.
This tutorial covers the workflow of a PyTorch with torchtext project. We'll learn how to: load data, create train/test/validation splits, build a vocabulary, create data iterators, define a model and implement the train/evaluate/test loop. The model will be simple and achieve poor performance, but this will be improved in the subsequent tutorials.

* 2 - [Upgraded Sentiment Analysis](https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bentrevett/pytorch-sentiment-analysis/blob/master/2%20-%20Upgraded%20Sentiment%20Analysis.ipynb)