fix lstm notebook

- added missing pad_index to embedding layer
- increased dropout
parent bc5e0f492a
commit c8d4c1d136
2_lstm.ipynb (73 changed lines)
@@ -27,7 +27,7 @@
 {
 "data": {
 "text/plain": [
-"<torch._C.Generator at 0x7f0ba8bbc9b0>"
+"<torch._C.Generator at 0x7faeccf1e990>"
 ]
 },
 "execution_count": 2,
@@ -184,9 +184,9 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-0f4f4fd2aa8eec54.arrow\n",
-"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-3f54ccc520f805f3.arrow\n",
-"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-008d75ce96cff98f.arrow\n"
+"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-631f583ffb0d9c68.arrow\n",
+"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-6f5a6c52dcbaf2d0.arrow\n",
+"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-c455c5d5c41c2779.arrow\n"
 ]
 }
 ],
@@ -265,9 +265,9 @@
 "source": [
 "class LSTM(nn.Module):\n",
 " def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional,\n",
-" dropout_rate):\n",
+" dropout_rate, pad_index):\n",
 " super().__init__()\n",
-" self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
+" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
 " self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, bidirectional=bidirectional,\n",
 " dropout=dropout_rate, batch_first=True)\n",
 " self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)\n",
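For context on the `padding_idx` change in this hunk: in PyTorch, `nn.Embedding(..., padding_idx=...)` initialises the pad token's row to zeros and keeps its gradient at zero, so padding positions contribute a constant zero vector instead of a trained one. A minimal sketch of that behaviour (the sizes below are illustrative; the notebook's actual embedding is 21543 × 300 with pad index 1):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

pad_index = 1  # illustrative; matches the padding_idx=1 shown later in the model repr
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=pad_index)

# The pad row starts out as all zeros...
print(embedding.weight[pad_index])        # tensor([0., 0., 0., 0.], ...)

# ...and its gradient is always zero, so it is never updated during training.
embedding(torch.tensor([1, 2, 3])).sum().backward()
print(embedding.weight.grad[pad_index])   # tensor([0., 0., 0., 0.])
```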
@@ -309,9 +309,10 @@
 "output_dim = len(train_data.unique('label'))\n",
 "n_layers = 2\n",
 "bidirectional = True\n",
-"dropout_rate = 0.5\n",
+"dropout_rate = 0.75\n",
 "\n",
-"model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout_rate)"
+"model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout_rate, \n",
+" pad_index)"
 ]
 },
 {
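Where `pad_index` comes from is not visible in this hunk; a hedged sketch of the instantiation, assuming it is looked up from the notebook's vocabulary by the pad token string (that lookup is an assumption, not part of the diff):

```python
# Assumed: the pad index is taken from the vocabulary built earlier in the notebook.
pad_index = vocab['<pad>']   # hypothetical lookup; the model repr below shows it resolves to 1

dropout_rate = 0.75          # raised from 0.5 in this commit

model = LSTM(vocab_size, embedding_dim, hidden_dim, output_dim, n_layers,
             bidirectional, dropout_rate, pad_index)
```

Note that the `dropout=dropout_rate` argument passed to `nn.LSTM` only applies between stacked layers, so raising it has an effect here because `n_layers = 2`; the `Dropout(p=0.75)` module shown in the model repr below is a separate, additional dropout.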
@@ -364,10 +365,10 @@
 "data": {
 "text/plain": [
 "LSTM(\n",
-" (embedding): Embedding(21543, 300)\n",
-" (lstm): LSTM(300, 300, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)\n",
+" (embedding): Embedding(21543, 300, padding_idx=1)\n",
+" (lstm): LSTM(300, 300, num_layers=2, batch_first=True, dropout=0.75, bidirectional=True)\n",
 " (fc): Linear(in_features=600, out_features=2, bias=True)\n",
-" (dropout): Dropout(p=0.5, inplace=False)\n",
+" (dropout): Dropout(p=0.75, inplace=False)\n",
 ")"
 ]
 },
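Two details the printed module confirms: the embedding now registers `padding_idx=1`, and the classifier input size is `hidden_dim * 2 = 600` because the LSTM is bidirectional (the final forward and backward states are presumably concatenated; the forward pass itself is not shown in this diff). A small sanity check, assuming the `model` built in the cell above:

```python
# Right after construction the pad row of the embedding should be all zeros.
assert model.embedding.weight[1].abs().sum().item() == 0.0

# Bidirectional LSTM: the linear classifier sees 2 * hidden_dim features.
assert model.fc.in_features == 2 * 300 == 600
```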
@@ -573,35 +574,35 @@
 "output_type": "stream",
 "text": [
 "epoch: 1\n",
-"train_loss: 0.654, train_acc: 0.619\n",
-"valid_loss: 0.600, valid_acc: 0.687\n",
+"train_loss: 0.649, train_acc: 0.605\n",
+"valid_loss: 0.637, valid_acc: 0.682\n",
 "epoch: 2\n",
-"train_loss: 0.588, train_acc: 0.701\n",
-"valid_loss: 0.605, valid_acc: 0.672\n",
+"train_loss: 0.563, train_acc: 0.710\n",
+"valid_loss: 0.611, valid_acc: 0.642\n",
 "epoch: 3\n",
-"train_loss: 0.533, train_acc: 0.738\n",
-"valid_loss: 0.533, valid_acc: 0.737\n",
+"train_loss: 0.594, train_acc: 0.697\n",
+"valid_loss: 0.710, valid_acc: 0.688\n",
 "epoch: 4\n",
-"train_loss: 0.415, train_acc: 0.813\n",
-"valid_loss: 0.392, valid_acc: 0.840\n",
+"train_loss: 0.590, train_acc: 0.692\n",
+"valid_loss: 0.517, valid_acc: 0.750\n",
 "epoch: 5\n",
-"train_loss: 0.359, train_acc: 0.849\n",
-"valid_loss: 0.408, valid_acc: 0.848\n",
+"train_loss: 0.535, train_acc: 0.733\n",
+"valid_loss: 0.687, valid_acc: 0.507\n",
 "epoch: 6\n",
-"train_loss: 0.285, train_acc: 0.885\n",
-"valid_loss: 0.359, valid_acc: 0.858\n",
+"train_loss: 0.677, train_acc: 0.566\n",
+"valid_loss: 0.635, valid_acc: 0.679\n",
 "epoch: 7\n",
-"train_loss: 0.238, train_acc: 0.909\n",
-"valid_loss: 0.359, valid_acc: 0.857\n",
+"train_loss: 0.593, train_acc: 0.704\n",
+"valid_loss: 0.685, valid_acc: 0.545\n",
 "epoch: 8\n",
-"train_loss: 0.195, train_acc: 0.927\n",
-"valid_loss: 0.382, valid_acc: 0.852\n",
+"train_loss: 0.653, train_acc: 0.587\n",
+"valid_loss: 0.648, valid_acc: 0.630\n",
 "epoch: 9\n",
-"train_loss: 0.177, train_acc: 0.934\n",
-"valid_loss: 0.367, valid_acc: 0.863\n",
+"train_loss: 0.613, train_acc: 0.658\n",
+"valid_loss: 0.545, valid_acc: 0.737\n",
 "epoch: 10\n",
-"train_loss: 0.145, train_acc: 0.948\n",
-"valid_loss: 0.488, valid_acc: 0.842\n"
+"train_loss: 0.533, train_acc: 0.754\n",
+"valid_loss: 0.474, valid_acc: 0.796\n"
 ]
 }
 ],
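The training loop itself is outside this diff; the per-epoch numbers above are presumably mean loss and mean accuracy over batches. A sketch of how such an accuracy figure is typically computed for a two-class output like this one (the helper name and example values are illustrative, not the notebook's own code):

```python
import torch

def batch_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of examples whose argmax prediction matches the label."""
    predictions = logits.argmax(dim=-1)
    return (predictions == labels).float().mean().item()

# logits of shape [batch_size, output_dim=2], integer labels in {0, 1}
logits = torch.tensor([[0.2, 0.9], [1.5, -0.3], [0.1, 0.4]])
labels = torch.tensor([1, 0, 0])
print(batch_accuracy(logits, labels))  # 0.666...
```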
@@ -633,7 +634,7 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"test_loss: 0.370, test_acc: 0.848\n"
+"test_loss: 0.469, test_acc: 0.797\n"
 ]
 }
 ],
@@ -673,7 +674,7 @@
 {
 "data": {
 "text/plain": [
-"(0, 0.9606466889381409)"
+"(0, 0.8488112688064575)"
 ]
 },
 "execution_count": 35,
@@ -696,7 +697,7 @@
 {
 "data": {
 "text/plain": [
-"(1, 0.9446237087249756)"
+"(1, 0.7834495902061462)"
 ]
 },
 "execution_count": 36,
@@ -719,7 +720,7 @@
 {
 "data": {
 "text/plain": [
-"(0, 0.7930739521980286)"
+"(1, 0.5829631090164185)"
 ]
 },
 "execution_count": 37,
@@ -742,7 +743,7 @@
 {
 "data": {
 "text/plain": [
-"(0, 0.6138491034507751)"
+"(0, 0.6761167645454407)"
 ]
 },
 "execution_count": 38,
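The `(label, probability)` pairs above come from a prediction helper defined later in the notebook, which this diff does not touch. A hypothetical sketch of what such a helper usually looks like in notebooks of this kind (the function name, tokenizer, vocab lookup, and the model's call signature are all assumptions):

```python
import torch

def predict_sentiment(text, model, tokenizer, vocab, device):
    """Hypothetical helper: return (predicted class, probability of that class)."""
    tokens = tokenizer(text)
    ids = [vocab[token] for token in tokens]                     # assumed vocab lookup
    tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)  # shape [1, seq_len]
    with torch.no_grad():
        logits = model(tensor).squeeze(dim=0)                   # assumed forward signature
    probability = torch.softmax(logits, dim=-1)
    predicted_class = probability.argmax(dim=-1).item()
    return predicted_class, probability[predicted_class].item()
```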