fix lstm notebook

- added missing pad_index to the embedding layer, so the pad token's vector stays zero (see the sketch below)
- increased dropout from 0.5 to 0.75
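
A minimal sketch of what passing padding_idx to nn.Embedding changes, assuming pad_index = 1 as in the model printout further down:

import torch
import torch.nn as nn

pad_index = 1  # matches padding_idx=1 in the printed model below

# With padding_idx set, the embedding row for the pad token is initialized
# to zeros and receives no gradient, so padding stops injecting a learned,
# arbitrary vector into the LSTM input.
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=pad_index)

ids = torch.tensor([[2, 5, pad_index, pad_index]])  # a padded sequence
print(embedding(ids)[0, 2])  # pad positions embed to all zeros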
bentrevett 2021-07-08 11:43:28 +01:00
parent bc5e0f492a
commit c8d4c1d136


@@ -27,7 +27,7 @@
 {
 "data": {
 "text/plain": [
-"<torch._C.Generator at 0x7f0ba8bbc9b0>"
+"<torch._C.Generator at 0x7faeccf1e990>"
 ]
 },
 "execution_count": 2,
@@ -184,9 +184,9 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-0f4f4fd2aa8eec54.arrow\n",
-"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-3f54ccc520f805f3.arrow\n",
-"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-008d75ce96cff98f.arrow\n"
+"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-631f583ffb0d9c68.arrow\n",
+"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-6f5a6c52dcbaf2d0.arrow\n",
+"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-c455c5d5c41c2779.arrow\n"
 ]
 }
 ],
@@ -265,9 +265,9 @@
 "source": [
 "class LSTM(nn.Module):\n",
 " def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional,\n",
-" dropout_rate):\n",
+" dropout_rate, pad_index):\n",
 " super().__init__()\n",
-" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
+" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
 " self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, bidirectional=bidirectional,\n",
 " dropout=dropout_rate, batch_first=True)\n",
 " self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)\n",
@@ -309,9 +309,10 @@
 "output_dim = len(train_data.unique('label'))\n",
 "n_layers = 2\n",
 "bidirectional = True\n",
-"dropout_rate = 0.5\n",
+"dropout_rate = 0.75\n",
 "\n",
-"model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout_rate)"
+"model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout_rate, \n",
+" pad_index)"
 ]
 },
 {
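
pad_index itself is defined earlier in the notebook and isn't part of this diff; in torchtext-style code it would typically come from the vocabulary, along these lines (hypothetical lookup):

pad_index = vocab['<pad>']  # resolves to 1, per padding_idx=1 in the printout below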
@@ -364,10 +365,10 @@
 "data": {
 "text/plain": [
 "LSTM(\n",
-" (embedding): Embedding(21543, 300)\n",
-" (lstm): LSTM(300, 300, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)\n",
+" (embedding): Embedding(21543, 300, padding_idx=1)\n",
+" (lstm): LSTM(300, 300, num_layers=2, batch_first=True, dropout=0.75, bidirectional=True)\n",
 " (fc): Linear(in_features=600, out_features=2, bias=True)\n",
-" (dropout): Dropout(p=0.5, inplace=False)\n",
+" (dropout): Dropout(p=0.75, inplace=False)\n",
 ")"
 ]
 },
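
A quick sanity check that the fix took effect, right after constructing the model (a sketch, assuming the model object printed above):

import torch
# the row reserved for the pad token should be all zeros, and it stays
# that way during training because its gradient is always zero
assert torch.all(model.embedding.weight[1] == 0)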
@@ -573,35 +574,35 @@
 "output_type": "stream",
 "text": [
 "epoch: 1\n",
-"train_loss: 0.654, train_acc: 0.619\n",
-"valid_loss: 0.600, valid_acc: 0.687\n",
+"train_loss: 0.649, train_acc: 0.605\n",
+"valid_loss: 0.637, valid_acc: 0.682\n",
 "epoch: 2\n",
-"train_loss: 0.588, train_acc: 0.701\n",
-"valid_loss: 0.605, valid_acc: 0.672\n",
+"train_loss: 0.563, train_acc: 0.710\n",
+"valid_loss: 0.611, valid_acc: 0.642\n",
 "epoch: 3\n",
-"train_loss: 0.533, train_acc: 0.738\n",
-"valid_loss: 0.533, valid_acc: 0.737\n",
+"train_loss: 0.594, train_acc: 0.697\n",
+"valid_loss: 0.710, valid_acc: 0.688\n",
 "epoch: 4\n",
-"train_loss: 0.415, train_acc: 0.813\n",
-"valid_loss: 0.392, valid_acc: 0.840\n",
+"train_loss: 0.590, train_acc: 0.692\n",
+"valid_loss: 0.517, valid_acc: 0.750\n",
 "epoch: 5\n",
-"train_loss: 0.359, train_acc: 0.849\n",
-"valid_loss: 0.408, valid_acc: 0.848\n",
+"train_loss: 0.535, train_acc: 0.733\n",
+"valid_loss: 0.687, valid_acc: 0.507\n",
 "epoch: 6\n",
-"train_loss: 0.285, train_acc: 0.885\n",
-"valid_loss: 0.359, valid_acc: 0.858\n",
+"train_loss: 0.677, train_acc: 0.566\n",
+"valid_loss: 0.635, valid_acc: 0.679\n",
 "epoch: 7\n",
-"train_loss: 0.238, train_acc: 0.909\n",
-"valid_loss: 0.359, valid_acc: 0.857\n",
+"train_loss: 0.593, train_acc: 0.704\n",
+"valid_loss: 0.685, valid_acc: 0.545\n",
 "epoch: 8\n",
-"train_loss: 0.195, train_acc: 0.927\n",
-"valid_loss: 0.382, valid_acc: 0.852\n",
+"train_loss: 0.653, train_acc: 0.587\n",
+"valid_loss: 0.648, valid_acc: 0.630\n",
 "epoch: 9\n",
-"train_loss: 0.177, train_acc: 0.934\n",
-"valid_loss: 0.367, valid_acc: 0.863\n",
+"train_loss: 0.613, train_acc: 0.658\n",
+"valid_loss: 0.545, valid_acc: 0.737\n",
 "epoch: 10\n",
-"train_loss: 0.145, train_acc: 0.948\n",
-"valid_loss: 0.488, valid_acc: 0.842\n"
+"train_loss: 0.533, train_acc: 0.754\n",
+"valid_loss: 0.474, valid_acc: 0.796\n"
 ]
 }
 ],
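
The log lines above come from a per-epoch driver along these lines (a sketch; train, evaluate, and the dataloader names are assumptions, not shown in this diff):

for epoch in range(1, n_epochs + 1):
    train_loss, train_acc = train(model, train_dataloader, criterion, optimizer, device)
    valid_loss, valid_acc = evaluate(model, valid_dataloader, criterion, device)
    print(f'epoch: {epoch}')
    print(f'train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}')
    print(f'valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}')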
@@ -633,7 +634,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"test_loss: 0.370, test_acc: 0.848\n"
+"test_loss: 0.469, test_acc: 0.797\n"
 ]
 }
 ],
@@ -673,7 +674,7 @@
 {
 "data": {
 "text/plain": [
-"(0, 0.9606466889381409)"
+"(0, 0.8488112688064575)"
 ]
 },
 "execution_count": 35,
@@ -696,7 +697,7 @@
 {
 "data": {
 "text/plain": [
-"(1, 0.9446237087249756)"
+"(1, 0.7834495902061462)"
 ]
 },
 "execution_count": 36,
@@ -719,7 +720,7 @@
 {
 "data": {
 "text/plain": [
-"(0, 0.7930739521980286)"
+"(1, 0.5829631090164185)"
 ]
 },
 "execution_count": 37,
@@ -742,7 +743,7 @@
 {
 "data": {
 "text/plain": [
-"(0, 0.6138491034507751)"
+"(0, 0.6761167645454407)"
 ]
 },
 "execution_count": 38,