updates
This commit is contained in:
parent
d9520e6aad
commit
2e9c4f7f9d
@ -302,6 +302,82 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class GRU(nn.Module):\n",
|
||||
" def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
|
||||
" self.gru = nn.GRUCell(emb_dim, hid_dim)\n",
|
||||
" self.fc = nn.Linear(hid_dim, output_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, text, lengths):\n",
|
||||
"\n",
|
||||
" # text = [seq len, batch size]\n",
|
||||
" # lengths = [batch size]\n",
|
||||
"\n",
|
||||
" embedded = self.embedding(text)\n",
|
||||
"\n",
|
||||
" # embedded = [seq len, batch size, emb dim]\n",
|
||||
"\n",
|
||||
" seq_len, batch_size, _ = embedded.shape\n",
|
||||
" hid_dim = self.gru.hidden_size\n",
|
||||
" \n",
|
||||
" hidden = torch.zeros(batch_size, hid_dim).to(embedded.device)\n",
|
||||
" \n",
|
||||
" for i in range(seq_len):\n",
|
||||
" x = embedded[i]\n",
|
||||
" hidden = self.gru(x, hidden)\n",
|
||||
" \n",
|
||||
" prediction = self.fc(hidden)\n",
|
||||
"\n",
|
||||
" # prediction = [batch size, output dim]\n",
|
||||
"\n",
|
||||
" return prediction"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class GRU(nn.Module):\n",
|
||||
" def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
|
||||
" self.gru = nn.GRU(emb_dim, hid_dim)\n",
|
||||
" self.fc = nn.Linear(hid_dim, output_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, text, lengths):\n",
|
||||
"\n",
|
||||
" # text = [seq len, batch size]\n",
|
||||
" # lengths = [batch size]\n",
|
||||
"\n",
|
||||
" embedded = self.embedding(text)\n",
|
||||
"\n",
|
||||
" # embedded = [seq len, batch size, emb dim]\n",
|
||||
"\n",
|
||||
" output, hidden = self.gru(embedded)\n",
|
||||
"\n",
|
||||
" # output = [seq_len, batch size, n directions * hid dim]\n",
|
||||
" # hidden = [n layers * n directions, batch size, hid dim]\n",
|
||||
"\n",
|
||||
" prediction = self.fc(hidden.squeeze(0))\n",
|
||||
"\n",
|
||||
" # prediction = [batch size, output dim]\n",
|
||||
"\n",
|
||||
" return prediction "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -314,7 +390,7 @@
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim)\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
|
||||
" self.gru = nn.GRU(emb_dim, hid_dim)\n",
|
||||
" self.fc = nn.Linear(hid_dim, output_dim)\n",
|
||||
"\n",
|
||||
@ -333,7 +409,7 @@
|
||||
"\n",
|
||||
" output, _ = nn.utils.rnn.pad_packed_sequence(packed_output)\n",
|
||||
"\n",
|
||||
" # outputs = [seq_len, batch size, n directions * hid dim]\n",
|
||||
" # output = [seq_len, batch size, n directions * hid dim]\n",
|
||||
" # hidden = [n layers * n directions, batch size, hid dim]\n",
|
||||
"\n",
|
||||
" prediction = self.fc(hidden.squeeze(0))\n",
|
||||
@ -345,7 +421,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -363,7 +439,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -377,7 +453,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -402,7 +478,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -416,7 +492,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -447,7 +523,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -462,7 +538,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -485,7 +561,7 @@
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -496,7 +572,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -509,7 +585,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -522,7 +598,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -535,7 +611,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -549,7 +625,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -566,7 +642,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -606,7 +682,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -642,7 +718,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -659,7 +735,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -677,7 +753,7 @@
|
||||
"Epoch: 01 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.634 | Train Acc: 62.44%\n",
|
||||
"\t Val. Loss: 0.474 | Val. Acc: 77.64%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 8s\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.375 | Train Acc: 83.86%\n",
|
||||
"\t Val. Loss: 0.333 | Val. Acc: 86.20%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 8s\n",
|
||||
@ -686,16 +762,16 @@
|
||||
"Epoch: 04 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.170 | Train Acc: 93.78%\n",
|
||||
"\t Val. Loss: 0.316 | Val. Acc: 89.58%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 8s\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.106 | Train Acc: 96.58%\n",
|
||||
"\t Val. Loss: 0.319 | Val. Acc: 89.63%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 8s\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.066 | Train Acc: 98.08%\n",
|
||||
"\t Val. Loss: 0.327 | Val. Acc: 89.52%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.041 | Train Acc: 98.82%\n",
|
||||
"\t Val. Loss: 0.451 | Val. Acc: 88.07%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 8s\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.021 | Train Acc: 99.43%\n",
|
||||
"\t Val. Loss: 0.472 | Val. Acc: 88.16%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 7s\n",
|
||||
@ -734,7 +810,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -763,7 +839,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -785,7 +861,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -802,7 +878,7 @@
|
||||
"0.07642160356044769"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -815,7 +891,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -832,7 +908,7 @@
|
||||
"0.8930155634880066"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -845,7 +921,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -862,7 +938,7 @@
|
||||
"0.2206803858280182"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -876,7 +952,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -893,7 +969,7 @@
|
||||
"0.5373267531394958"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -314,7 +314,7 @@
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim)\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
|
||||
" self.lstm = nn.LSTM(emb_dim, hid_dim, num_layers = n_layers, bidirectional = True, dropout = dropout)\n",
|
||||
" self.fc = nn.Linear(2 * hid_dim, output_dim)\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
@ -337,8 +337,8 @@
|
||||
" # outputs = [seq_len, batch size, n directions * hid dim]\n",
|
||||
" # hidden = [n layers * n directions, batch size, hid dim]\n",
|
||||
"\n",
|
||||
" hidden_fwd = hidden[-1]\n",
|
||||
" hidden_bck = hidden[-2]\n",
|
||||
" hidden_fwd = hidden[-2]\n",
|
||||
" hidden_bck = hidden[-1]\n",
|
||||
"\n",
|
||||
" # hidden_fwd/bck = [batch size, hid dim]\n",
|
||||
"\n",
|
||||
@ -687,35 +687,35 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.668 | Train Acc: 59.01%\n",
|
||||
"\t Val. Loss: 0.652 | Val. Acc: 62.33%\n",
|
||||
"\tTrain Loss: 0.656 | Train Acc: 60.38%\n",
|
||||
"\t Val. Loss: 0.609 | Val. Acc: 64.97%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.602 | Train Acc: 67.75%\n",
|
||||
"\t Val. Loss: 0.478 | Val. Acc: 77.18%\n",
|
||||
"\tTrain Loss: 0.594 | Train Acc: 69.33%\n",
|
||||
"\t Val. Loss: 0.578 | Val. Acc: 71.96%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.497 | Train Acc: 76.59%\n",
|
||||
"\t Val. Loss: 0.478 | Val. Acc: 80.62%\n",
|
||||
"\tTrain Loss: 0.565 | Train Acc: 70.59%\n",
|
||||
"\t Val. Loss: 0.476 | Val. Acc: 77.87%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.456 | Train Acc: 79.24%\n",
|
||||
"\t Val. Loss: 0.397 | Val. Acc: 83.30%\n",
|
||||
"\tTrain Loss: 0.487 | Train Acc: 76.72%\n",
|
||||
"\t Val. Loss: 0.453 | Val. Acc: 79.55%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.391 | Train Acc: 82.72%\n",
|
||||
"\t Val. Loss: 0.344 | Val. Acc: 85.23%\n",
|
||||
"\tTrain Loss: 0.415 | Train Acc: 81.43%\n",
|
||||
"\t Val. Loss: 0.403 | Val. Acc: 83.60%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.345 | Train Acc: 85.30%\n",
|
||||
"\t Val. Loss: 0.350 | Val. Acc: 86.12%\n",
|
||||
"\tTrain Loss: 0.349 | Train Acc: 85.02%\n",
|
||||
"\t Val. Loss: 0.337 | Val. Acc: 86.41%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.314 | Train Acc: 86.75%\n",
|
||||
"\t Val. Loss: 0.310 | Val. Acc: 87.50%\n",
|
||||
"\tTrain Loss: 0.308 | Train Acc: 86.93%\n",
|
||||
"\t Val. Loss: 0.344 | Val. Acc: 85.36%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.266 | Train Acc: 89.54%\n",
|
||||
"\t Val. Loss: 0.315 | Val. Acc: 88.16%\n",
|
||||
"\tTrain Loss: 0.279 | Train Acc: 88.49%\n",
|
||||
"\t Val. Loss: 0.315 | Val. Acc: 87.62%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.247 | Train Acc: 90.21%\n",
|
||||
"\t Val. Loss: 0.285 | Val. Acc: 89.02%\n",
|
||||
"\tTrain Loss: 0.252 | Train Acc: 89.98%\n",
|
||||
"\t Val. Loss: 0.326 | Val. Acc: 88.16%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.217 | Train Acc: 91.79%\n",
|
||||
"\t Val. Loss: 0.282 | Val. Acc: 88.98%\n"
|
||||
"\tTrain Loss: 0.218 | Train Acc: 91.31%\n",
|
||||
"\t Val. Loss: 0.293 | Val. Acc: 89.16%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -761,7 +761,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.294 | Test Acc: 87.95%\n"
|
||||
"Test Loss: 0.307 | Test Acc: 88.17%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -811,7 +811,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.008071469143033028"
|
||||
"0.015085793100297451"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
@ -841,7 +841,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9896865487098694"
|
||||
"0.991886556148529"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
@ -871,7 +871,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.029767075553536415"
|
||||
"0.6712993383407593"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
@ -902,7 +902,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.5513127446174622"
|
||||
"0.28944137692451477"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user