Fix bug with max_length in tokenizer; improve loading of pretrained vectors; add weight initialization.
parent 42e9e4850d
commit b15b88ae50

.gitignore (vendored): 1 addition
@@ -107,3 +107,4 @@ saves/*
*.pt
.vscode/
custom_embeddings/trained_embeddings.*
+experimental/.data/

@@ -78,7 +78,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
-"<torchtext.experimental.datasets.raw.text_classification.RawTextIterableDataset object at 0x7fbf244dc250>\n"
+"<torchtext.experimental.datasets.raw.text_classification.RawTextIterableDataset object at 0x7f63ac6b23a0>\n"
]
}
],
@@ -275,7 +275,7 @@
" tokens = [token.lower() for token in tokens]\n",
" \n",
" if self.max_length is not None:\n",
-" tokens = tokens[:max_length]\n",
+" tokens = tokens[:self.max_length]\n",
" \n",
" return tokens"
]
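Note: the hunk above is the tokenizer fix named in the commit message. A minimal sketch of the surrounding class, reconstructed only for illustration (everything except the lines visible in the hunk is an assumption):

    class Tokenizer:
        # hypothetical reconstruction of the patched tokenizer
        def __init__(self, tokenize_fn=str.split, lower=True, max_length=None):
            self.tokenize_fn = tokenize_fn
            self.lower = lower
            self.max_length = max_length

        def tokenize(self, s):
            tokens = self.tokenize_fn(s)
            if self.lower:
                tokens = [token.lower() for token in tokens]
            if self.max_length is not None:
                # the fix: the old code sliced with a bare `max_length`, an undefined
                # name inside the method, instead of the attribute set in __init__
                tokens = tokens[:self.max_length]
            return tokens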
@@ -631,6 +631,7 @@
"source": [
"pad_token = '<pad>'\n",
"pad_idx = vocab[pad_token]\n",
"\n",
"collator = Collator(pad_idx)"
]
},
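The Collator constructed here is defined elsewhere in the notebook and not shown in this diff. A minimal sketch of what a collate object built around pad_idx typically does, assuming batches of (label, token-id sequence) pairs; all names and the exact structure below are assumptions:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    class Collator:
        # hypothetical sketch: pad variable-length sequences in a batch with pad_idx
        def __init__(self, pad_idx):
            self.pad_idx = pad_idx

        def collate(self, batch):
            labels, sequences = zip(*batch)
            labels = torch.LongTensor(labels)
            # pad every sequence to the length of the longest one in the batch
            text = pad_sequence([torch.LongTensor(seq) for seq in sequences],
                                padding_value=self.pad_idx)
            return labels, text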
@@ -674,7 +675,6 @@
"source": [
"class NBOW(nn.Module):\n",
" def __init__(self, input_dim, emb_dim, output_dim, pad_idx):\n",
" \n",
" super().__init__()\n",
" \n",
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
@@ -806,7 +806,7 @@
}
],
"source": [
-"glove.vectors['the']"
+"glove['the']"
]
},
{
@@ -838,7 +838,7 @@
}
],
"source": [
-"glove.vectors['shoggoth']"
+"glove['shoggoth']"
]
},
{
@ -870,12 +870,61 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"glove.vectors['The']"
|
||||
"glove['The']"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"glove_vocab = glove.vectors.get_stoi()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"'the' in glove_vocab"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"False"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"'The' in glove_vocab"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@@ -883,30 +932,26 @@
},
"outputs": [],
"source": [
"def get_pretrained_embedding(vectors, vocab, unk_token):\n",
"def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):\n",
" \n",
" unk_vector = vectors[unk_token]\n",
" emb_dim = unk_vector.shape[-1]\n",
" zero_vector = torch.zeros(emb_dim)\n",
"\n",
" pretrained_embedding = torch.zeros(len(vocab), emb_dim) \n",
" pretrained_embedding = torch.FloatTensor(initial_embedding.weight.clone()).detach() \n",
" pretrained_vocab = pretrained_vectors.vectors.get_stoi()\n",
" \n",
" unk_tokens = []\n",
" \n",
" for idx, token in enumerate(vocab.itos):\n",
" pretrained_vector = vectors[token]\n",
" if torch.all(torch.eq(pretrained_vector, unk_vector)):\n",
" unk_tokens.append(token)\n",
" pretrained_embedding[idx] = zero_vector\n",
" else:\n",
" if token in pretrained_vocab:\n",
" pretrained_vector = pretrained_vectors[token]\n",
" pretrained_embedding[idx] = pretrained_vector\n",
" else:\n",
" unk_tokens.append(token)\n",
" \n",
" return pretrained_embedding, unk_tokens"
]
},
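This is the "improved loading vectors" change from the commit message: instead of allocating a zero matrix and detecting out-of-vocabulary words by comparing every looked-up vector against the <unk> vector, the rewritten function starts from the model's freshly initialized embedding weights and only overwrites rows whose token actually appears in the pretrained vocabulary (via pretrained_vectors.vectors.get_stoi()), so unknown tokens keep their random initialization instead of being zeroed. A short usage sketch assembled from the calls that appear later in this diff:

    unk_token = '<unk>'

    # start from the model's own randomly initialized embedding and overwrite only
    # the rows whose token exists in the pretrained GloVe vocabulary
    pretrained_embedding, unk_tokens = get_pretrained_embedding(model.embedding, glove, vocab, unk_token)

    model.embedding.weight.data.copy_(pretrained_embedding)
    print(len(unk_tokens))  # tokens that kept their random initialization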
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -916,108 +961,13 @@
|
||||
"source": [
|
||||
"unk_token = '<unk>'\n",
|
||||
"\n",
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(glove.vectors, vocab, unk_token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 139
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "YGO-9DWBdFn1",
|
||||
"outputId": "fb8178fc-0f1c-43ba-b9a8-0b25ec343a54"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pretrained_embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "j36jzQpPdFn3",
|
||||
"outputId": "7ebe041d-b092-498e-ea16-0fce8c20ed33"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"678"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(unk_tokens)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "yzvhgf8tdFn5",
|
||||
"outputId": "8c30dc4a-9a2b-4c11-8c7b-1d2cb3ba0aee"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['<unk>', '<pad>', '\\x96', 'hadn', '****', '100%', 'camera-work', '*1/2', '$1', '*****']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(unk_tokens[:10])"
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(model.embedding, glove, vocab, unk_token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 139
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "OKxH6f7ddFn8",
|
||||
"outputId": "8ff8c71b-725e-417d-e3fa-7ffeac2e342d"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -1043,20 +993,12 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 139
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "AnE6D4MAdFn_",
|
||||
"outputId": "8b3fea1a-9bcb-4fd9-ba78-72baee94f96a"
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.5903, -0.1947, -0.2415],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
@ -1071,12 +1013,99 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.embedding.weight.data.copy_(pretrained_embedding)"
|
||||
"pretrained_embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "j36jzQpPdFn3",
|
||||
"outputId": "7ebe041d-b092-498e-ea16-0fce8c20ed33"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"678"
|
||||
]
|
||||
},
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"len(unk_tokens)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "yzvhgf8tdFn5",
|
||||
"outputId": "8c30dc4a-9a2b-4c11-8c7b-1d2cb3ba0aee"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['<unk>', '<pad>', '\\x96', 'hadn', '****', '100%', 'camera-work', '*1/2', '$1', '*****']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(unk_tokens[:10])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 139
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "AnE6D4MAdFn_",
|
||||
"outputId": "8b3fea1a-9bcb-4fd9-ba78-72baee94f96a"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[-0.1117, -0.4966, 0.1631, ..., 1.5903, -0.1947, -0.2415],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.embedding.weight.data.copy_(pretrained_embedding)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1089,7 +1118,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 47,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1102,7 +1131,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 48,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -1129,7 +1158,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 49,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1143,7 +1172,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": 50,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1160,7 +1189,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": 51,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1200,7 +1229,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"execution_count": 52,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1236,7 +1265,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"execution_count": 53,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1253,7 +1282,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"execution_count": 54,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -1269,35 +1298,35 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.687 | Train Acc: 56.68%\n",
|
||||
"\t Val. Loss: 0.677 | Val. Acc: 63.89%\n",
|
||||
"\tTrain Loss: 0.687 | Train Acc: 56.51%\n",
|
||||
"\t Val. Loss: 0.677 | Val. Acc: 62.87%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.664 | Train Acc: 65.74%\n",
|
||||
"\t Val. Loss: 0.650 | Val. Acc: 69.64%\n",
|
||||
"\tTrain Loss: 0.665 | Train Acc: 65.13%\n",
|
||||
"\t Val. Loss: 0.650 | Val. Acc: 69.13%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.628 | Train Acc: 73.00%\n",
|
||||
"\t Val. Loss: 0.610 | Val. Acc: 74.11%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.581 | Train Acc: 76.68%\n",
|
||||
"\t Val. Loss: 0.564 | Val. Acc: 77.74%\n",
|
||||
"\tTrain Loss: 0.629 | Train Acc: 72.45%\n",
|
||||
"\t Val. Loss: 0.611 | Val. Acc: 73.54%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 5s\n",
|
||||
"\tTrain Loss: 0.583 | Train Acc: 76.17%\n",
|
||||
"\t Val. Loss: 0.566 | Val. Acc: 77.00%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.530 | Train Acc: 80.47%\n",
|
||||
"\t Val. Loss: 0.518 | Val. Acc: 80.41%\n",
|
||||
"\tTrain Loss: 0.533 | Train Acc: 80.22%\n",
|
||||
"\t Val. Loss: 0.521 | Val. Acc: 80.28%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.481 | Train Acc: 83.37%\n",
|
||||
"\t Val. Loss: 0.477 | Val. Acc: 82.58%\n",
|
||||
"\tTrain Loss: 0.484 | Train Acc: 83.24%\n",
|
||||
"\t Val. Loss: 0.480 | Val. Acc: 82.53%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.437 | Train Acc: 85.62%\n",
|
||||
"\t Val. Loss: 0.441 | Val. Acc: 84.43%\n",
|
||||
"\tTrain Loss: 0.440 | Train Acc: 85.46%\n",
|
||||
"\t Val. Loss: 0.443 | Val. Acc: 84.40%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.398 | Train Acc: 87.28%\n",
|
||||
"\t Val. Loss: 0.413 | Val. Acc: 85.61%\n",
|
||||
"\tTrain Loss: 0.401 | Train Acc: 87.10%\n",
|
||||
"\t Val. Loss: 0.414 | Val. Acc: 85.45%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.365 | Train Acc: 88.43%\n",
|
||||
"\t Val. Loss: 0.390 | Val. Acc: 86.38%\n",
|
||||
"\tTrain Loss: 0.367 | Train Acc: 88.41%\n",
|
||||
"\t Val. Loss: 0.390 | Val. Acc: 86.39%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 4s\n",
|
||||
"\tTrain Loss: 0.339 | Train Acc: 89.21%\n",
|
||||
"\t Val. Loss: 0.370 | Val. Acc: 86.92%\n"
|
||||
"\tTrain Loss: 0.340 | Train Acc: 89.23%\n",
|
||||
"\t Val. Loss: 0.370 | Val. Acc: 86.96%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1328,7 +1357,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"execution_count": 55,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -1343,7 +1372,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.393 | Test Acc: 85.31%\n"
|
||||
"Test Loss: 0.393 | Test Acc: 85.39%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1357,7 +1386,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"execution_count": 56,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -1378,7 +1407,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"execution_count": 57,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -1392,10 +1421,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"1.0650165677361656e-05"
|
||||
"9.809165021579247e-06"
|
||||
]
|
||||
},
|
||||
"execution_count": 54,
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -1408,7 +1437,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 55,
|
||||
"execution_count": 58,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -1422,10 +1451,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9999966621398926"
|
||||
"0.9999963045120239"
|
||||
]
|
||||
},
|
||||
"execution_count": 55,
|
||||
"execution_count": 58,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -1438,7 +1467,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 56,
|
||||
"execution_count": 59,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -1452,10 +1481,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.7397726774215698"
|
||||
"0.7485461235046387"
|
||||
]
|
||||
},
|
||||
"execution_count": 56,
|
||||
"execution_count": 59,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -1469,7 +1498,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 60,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -1483,10 +1512,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.7397726774215698"
|
||||
"0.7485461235046387"
|
||||
]
|
||||
},
|
||||
"execution_count": 57,
|
||||
"execution_count": 60,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -126,7 +126,7 @@
|
||||
" tokens = [token.lower() for token in tokens]\n",
|
||||
" \n",
|
||||
" if self.max_length is not None:\n",
|
||||
" tokens = tokens[:max_length]\n",
|
||||
" tokens = tokens[:self.max_length]\n",
|
||||
" \n",
|
||||
" return tokens"
|
||||
]
|
||||
@ -268,6 +268,7 @@
|
||||
"source": [
|
||||
"pad_token = '<pad>'\n",
|
||||
"pad_idx = vocab[pad_token]\n",
|
||||
"\n",
|
||||
"collator = Collator(pad_idx)"
|
||||
]
|
||||
},
|
||||
@ -307,7 +308,6 @@
|
||||
"source": [
|
||||
"class GRU(nn.Module):\n",
|
||||
" def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
|
||||
@ -347,7 +347,6 @@
|
||||
"source": [
|
||||
"class GRU(nn.Module):\n",
|
||||
" def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
|
||||
@ -387,7 +386,6 @@
|
||||
"source": [
|
||||
"class GRU(nn.Module):\n",
|
||||
" def __init__(self, input_dim, emb_dim, hid_dim, output_dim, pad_idx):\n",
|
||||
"\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
|
||||
@ -479,6 +477,85 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"name: embedding.weight, shape: torch.Size([25002, 100])\n",
|
||||
"name: gru.weight_ih_l0, shape: torch.Size([768, 100])\n",
|
||||
"name: gru.weight_hh_l0, shape: torch.Size([768, 256])\n",
|
||||
"name: gru.bias_ih_l0, shape: torch.Size([768])\n",
|
||||
"name: gru.bias_hh_l0, shape: torch.Size([768])\n",
|
||||
"name: fc.weight, shape: torch.Size([2, 256])\n",
|
||||
"name: fc.bias, shape: torch.Size([2])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for n, p in model.named_parameters():\n",
|
||||
" print(f'name: {n}, shape: {p.shape}')"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def initialize_parameters(m):\n",
" if isinstance(m, nn.Embedding):\n",
" nn.init.uniform_(m.weight, -0.05, 0.05)\n",
" elif isinstance(m, nn.GRU):\n",
" for n, p in m.named_parameters():\n",
" if 'weight_ih' in n:\n",
" r, z, n = p.chunk(3)\n",
" nn.init.xavier_uniform_(r)\n",
" nn.init.xavier_uniform_(z)\n",
" nn.init.xavier_uniform_(n)\n",
" elif 'weight_hh' in n:\n",
" r, z, n = p.chunk(3)\n",
" nn.init.orthogonal_(r)\n",
" nn.init.orthogonal_(z)\n",
" nn.init.orthogonal_(n)\n",
" elif 'bias' in n:\n",
" r, z, n = p.chunk(3)\n",
" nn.init.zeros_(r)\n",
" nn.init.zeros_(z)\n",
" nn.init.zeros_(n)\n",
" elif isinstance(m, nn.Linear):\n",
" nn.init.xavier_uniform_(m.weight)\n",
" nn.init.zeros_(m.bias)"
]
},
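This cell is the "added weight initialization" change for the GRU notebook. PyTorch stacks the three GRU gate blocks (reset, update, new) along dimension 0 of each weight_ih_l{k}, weight_hh_l{k} and bias tensor, so p.chunk(3) yields one view per gate and the in-place nn.init calls write through those views into the underlying parameter. A small sketch of the split, with shapes matching the named_parameters printout above:

    import torch.nn as nn

    gru = nn.GRU(input_size=100, hidden_size=256)
    w_ih = gru.weight_ih_l0        # (3 * 256, 100) == (768, 100)
    r, z, n = w_ih.chunk(3)        # three (256, 100) views: reset, update, new gates
    nn.init.xavier_uniform_(r)     # initializes that slice of w_ih in place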
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"GRU(\n",
|
||||
" (embedding): Embedding(25002, 100, padding_idx=1)\n",
|
||||
" (gru): GRU(100, 256)\n",
|
||||
" (fc): Linear(in_features=256, out_features=2, bias=True)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.apply(initialize_parameters)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -492,7 +569,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -500,30 +577,26 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_pretrained_embedding(vectors, vocab, unk_token):\n",
|
||||
"def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):\n",
|
||||
" \n",
|
||||
" unk_vector = vectors[unk_token]\n",
|
||||
" emb_dim = unk_vector.shape[-1]\n",
|
||||
" zero_vector = torch.zeros(emb_dim)\n",
|
||||
"\n",
|
||||
" pretrained_embedding = torch.zeros(len(vocab), emb_dim) \n",
|
||||
" pretrained_embedding = torch.FloatTensor(initial_embedding.weight.clone()).detach() \n",
|
||||
" pretrained_vocab = pretrained_vectors.vectors.get_stoi()\n",
|
||||
" \n",
|
||||
" unk_tokens = []\n",
|
||||
" \n",
|
||||
" for idx, token in enumerate(vocab.itos):\n",
|
||||
" pretrained_vector = vectors[token]\n",
|
||||
" if torch.all(torch.eq(pretrained_vector, unk_vector)):\n",
|
||||
" unk_tokens.append(token)\n",
|
||||
" pretrained_embedding[idx] = zero_vector\n",
|
||||
" else:\n",
|
||||
" if token in pretrained_vocab:\n",
|
||||
" pretrained_vector = pretrained_vectors[token]\n",
|
||||
" pretrained_embedding[idx] = pretrained_vector\n",
|
||||
" else:\n",
|
||||
" unk_tokens.append(token)\n",
|
||||
" \n",
|
||||
" return pretrained_embedding, unk_tokens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -533,12 +606,12 @@
|
||||
"source": [
|
||||
"unk_token = '<unk>'\n",
|
||||
"\n",
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(glove.vectors, vocab, unk_token)"
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(model.embedding, glove, vocab, unk_token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -552,8 +625,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
"tensor([[ 0.0098, 0.0150, -0.0099, ..., 0.0211, -0.0092, 0.0027],\n",
|
||||
" [ 0.0347, 0.0276, 0.0468, ..., -0.0315, -0.0472, -0.0326],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
@ -561,7 +634,7 @@
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 24,
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -572,7 +645,42 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.embedding.weight.data[pad_idx] = torch.zeros(emb_dim)"
|
||||
]
|
||||
},
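The cell above re-zeroes the <pad> row after the embedding weights are replaced. padding_idx in nn.Embedding only keeps that row's gradient at zero; it does not keep the values at zero once the weight matrix has been overwritten, so after loading the result of get_pretrained_embedding (where the pad row carries the uniform initialization rather than zeros) the row is cleared by hand:

    # padding_idx blocks gradient updates for the pad row but does not re-zero it,
    # so the row is reset explicitly once the new embedding weights are in place
    model.embedding.weight.data[pad_idx] = torch.zeros(emb_dim)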
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0098, 0.0150, -0.0099, ..., 0.0211, -0.0092, 0.0027],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.embedding.weight.data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -585,7 +693,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -598,7 +706,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -611,7 +719,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -625,7 +733,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -642,7 +750,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -682,7 +790,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -718,7 +826,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -735,7 +843,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -750,36 +858,36 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.634 | Train Acc: 62.44%\n",
|
||||
"\t Val. Loss: 0.474 | Val. Acc: 77.64%\n",
|
||||
"Epoch: 01 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.654 | Train Acc: 60.73%\n",
|
||||
"\t Val. Loss: 0.584 | Val. Acc: 68.87%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.375 | Train Acc: 83.86%\n",
|
||||
"\t Val. Loss: 0.333 | Val. Acc: 86.20%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.251 | Train Acc: 90.32%\n",
|
||||
"\t Val. Loss: 0.286 | Val. Acc: 89.07%\n",
|
||||
"\tTrain Loss: 0.423 | Train Acc: 80.73%\n",
|
||||
"\t Val. Loss: 0.332 | Val. Acc: 86.04%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.252 | Train Acc: 90.15%\n",
|
||||
"\t Val. Loss: 0.285 | Val. Acc: 88.63%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.170 | Train Acc: 93.78%\n",
|
||||
"\t Val. Loss: 0.316 | Val. Acc: 89.58%\n",
|
||||
"\tTrain Loss: 0.186 | Train Acc: 93.05%\n",
|
||||
"\t Val. Loss: 0.286 | Val. Acc: 89.40%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.106 | Train Acc: 96.58%\n",
|
||||
"\t Val. Loss: 0.319 | Val. Acc: 89.63%\n",
|
||||
"\tTrain Loss: 0.116 | Train Acc: 95.85%\n",
|
||||
"\t Val. Loss: 0.307 | Val. Acc: 89.56%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.066 | Train Acc: 98.08%\n",
|
||||
"\t Val. Loss: 0.327 | Val. Acc: 89.52%\n",
|
||||
"\tTrain Loss: 0.065 | Train Acc: 97.90%\n",
|
||||
"\t Val. Loss: 0.354 | Val. Acc: 89.64%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.041 | Train Acc: 98.82%\n",
|
||||
"\t Val. Loss: 0.451 | Val. Acc: 88.07%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.021 | Train Acc: 99.43%\n",
|
||||
"\t Val. Loss: 0.472 | Val. Acc: 88.16%\n",
|
||||
"\tTrain Loss: 0.042 | Train Acc: 98.74%\n",
|
||||
"\t Val. Loss: 0.403 | Val. Acc: 89.35%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 8s\n",
|
||||
"\tTrain Loss: 0.020 | Train Acc: 99.47%\n",
|
||||
"\t Val. Loss: 0.408 | Val. Acc: 89.35%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.014 | Train Acc: 99.71%\n",
|
||||
"\t Val. Loss: 0.520 | Val. Acc: 88.43%\n",
|
||||
"\tTrain Loss: 0.010 | Train Acc: 99.81%\n",
|
||||
"\t Val. Loss: 0.505 | Val. Acc: 88.53%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 7s\n",
|
||||
"\tTrain Loss: 0.005 | Train Acc: 99.93%\n",
|
||||
"\t Val. Loss: 0.660 | Val. Acc: 88.43%\n"
|
||||
"\tTrain Loss: 0.007 | Train Acc: 99.85%\n",
|
||||
"\t Val. Loss: 0.657 | Val. Acc: 88.27%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -810,7 +918,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -825,7 +933,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.290 | Test Acc: 88.71%\n"
|
||||
"Test Loss: 0.290 | Test Acc: 87.93%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -839,7 +947,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -861,7 +969,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -875,10 +983,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.07642160356044769"
|
||||
"0.06520231813192368"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -891,7 +999,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -905,10 +1013,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.8930155634880066"
|
||||
"0.8539475798606873"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -921,7 +1029,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -935,10 +1043,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.2206803858280182"
|
||||
"0.15590433776378632"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
"execution_count": 43,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -952,7 +1060,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 44,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -966,10 +1074,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.5373267531394958"
|
||||
"0.3470574617385864"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -126,7 +126,7 @@
|
||||
" tokens = [token.lower() for token in tokens]\n",
|
||||
" \n",
|
||||
" if self.max_length is not None:\n",
|
||||
" tokens = tokens[:max_length]\n",
|
||||
" tokens = tokens[:self.max_length]\n",
|
||||
" \n",
|
||||
" return tokens"
|
||||
]
|
||||
@ -268,6 +268,7 @@
|
||||
"source": [
|
||||
"pad_token = '<pad>'\n",
|
||||
"pad_idx = vocab[pad_token]\n",
|
||||
"\n",
|
||||
"collator = Collator(pad_idx)"
|
||||
]
|
||||
},
|
||||
@ -415,6 +416,101 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"name: embedding.weight, shape: torch.Size([25002, 100])\n",
|
||||
"name: lstm.weight_ih_l0, shape: torch.Size([1024, 100])\n",
|
||||
"name: lstm.weight_hh_l0, shape: torch.Size([1024, 256])\n",
|
||||
"name: lstm.bias_ih_l0, shape: torch.Size([1024])\n",
|
||||
"name: lstm.bias_hh_l0, shape: torch.Size([1024])\n",
|
||||
"name: lstm.weight_ih_l0_reverse, shape: torch.Size([1024, 100])\n",
|
||||
"name: lstm.weight_hh_l0_reverse, shape: torch.Size([1024, 256])\n",
|
||||
"name: lstm.bias_ih_l0_reverse, shape: torch.Size([1024])\n",
|
||||
"name: lstm.bias_hh_l0_reverse, shape: torch.Size([1024])\n",
|
||||
"name: lstm.weight_ih_l1, shape: torch.Size([1024, 512])\n",
|
||||
"name: lstm.weight_hh_l1, shape: torch.Size([1024, 256])\n",
|
||||
"name: lstm.bias_ih_l1, shape: torch.Size([1024])\n",
|
||||
"name: lstm.bias_hh_l1, shape: torch.Size([1024])\n",
|
||||
"name: lstm.weight_ih_l1_reverse, shape: torch.Size([1024, 512])\n",
|
||||
"name: lstm.weight_hh_l1_reverse, shape: torch.Size([1024, 256])\n",
|
||||
"name: lstm.bias_ih_l1_reverse, shape: torch.Size([1024])\n",
|
||||
"name: lstm.bias_hh_l1_reverse, shape: torch.Size([1024])\n",
|
||||
"name: fc.weight, shape: torch.Size([2, 512])\n",
|
||||
"name: fc.bias, shape: torch.Size([2])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for n, p in model.named_parameters():\n",
|
||||
" print(f'name: {n}, shape: {p.shape}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def initialize_parameters(m):\n",
|
||||
" if isinstance(m, nn.Embedding):\n",
|
||||
" nn.init.uniform_(m.weight, -0.05, 0.05)\n",
|
||||
" elif isinstance(m, nn.LSTM):\n",
|
||||
" for n, p in m.named_parameters():\n",
|
||||
" if 'weight_ih' in n:\n",
|
||||
" i, f, g, o = p.chunk(4)\n",
|
||||
" nn.init.xavier_uniform_(i)\n",
|
||||
" nn.init.xavier_uniform_(f)\n",
|
||||
" nn.init.xavier_uniform_(g)\n",
|
||||
" nn.init.xavier_uniform_(o)\n",
|
||||
" elif 'weight_hh' in n:\n",
|
||||
" i, f, g, o = p.chunk(4)\n",
|
||||
" nn.init.orthogonal_(i)\n",
|
||||
" nn.init.orthogonal_(f)\n",
|
||||
" nn.init.orthogonal_(g)\n",
|
||||
" nn.init.orthogonal_(o)\n",
|
||||
" elif 'bias' in n:\n",
|
||||
" i, f, g, o = p.chunk(4)\n",
|
||||
" nn.init.zeros_(i)\n",
|
||||
" nn.init.ones_(f)\n",
|
||||
" nn.init.zeros_(g)\n",
|
||||
" nn.init.zeros_(o)\n",
|
||||
" elif isinstance(m, nn.Linear):\n",
|
||||
" nn.init.xavier_uniform_(m.weight)\n",
|
||||
" nn.init.zeros_(m.bias)"
|
||||
]
|
||||
},
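The LSTM notebook gets the same initializer with p.chunk(4) instead of chunk(3): PyTorch packs the LSTM gate blocks in (input, forget, cell, output) order, and the bias branch sets the forget-gate chunk to ones rather than zeros, a common heuristic so the network starts out retaining its cell state. A small sketch of that split, with shapes matching the printout above:

    import torch.nn as nn

    lstm = nn.LSTM(input_size=100, hidden_size=256, num_layers=2, bidirectional=True)
    b_ih = lstm.bias_ih_l0           # (4 * 256,) == (1024,)
    i, f, g, o = b_ih.chunk(4)       # PyTorch order: input, forget, cell, output gates
    nn.init.ones_(f)                 # forget-gate bias starts at 1
    for gate in (i, g, o):
        nn.init.zeros_(gate)         # remaining bias chunks start at 0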
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"BiLSTM(\n",
|
||||
" (embedding): Embedding(25002, 100, padding_idx=1)\n",
|
||||
" (lstm): LSTM(100, 256, num_layers=2, dropout=0.5, bidirectional=True)\n",
|
||||
" (fc): Linear(in_features=512, out_features=2, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.5, inplace=False)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.apply(initialize_parameters)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -428,7 +524,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -436,30 +532,26 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_pretrained_embedding(vectors, vocab, unk_token):\n",
|
||||
"def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):\n",
|
||||
" \n",
|
||||
" unk_vector = vectors[unk_token]\n",
|
||||
" emb_dim = unk_vector.shape[-1]\n",
|
||||
" zero_vector = torch.zeros(emb_dim)\n",
|
||||
"\n",
|
||||
" pretrained_embedding = torch.zeros(len(vocab), emb_dim) \n",
|
||||
" pretrained_embedding = torch.FloatTensor(initial_embedding.weight.clone()).detach() \n",
|
||||
" pretrained_vocab = pretrained_vectors.vectors.get_stoi()\n",
|
||||
" \n",
|
||||
" unk_tokens = []\n",
|
||||
" \n",
|
||||
" for idx, token in enumerate(vocab.itos):\n",
|
||||
" pretrained_vector = vectors[token]\n",
|
||||
" if torch.all(torch.eq(pretrained_vector, unk_vector)):\n",
|
||||
" unk_tokens.append(token)\n",
|
||||
" pretrained_embedding[idx] = zero_vector\n",
|
||||
" else:\n",
|
||||
" if token in pretrained_vocab:\n",
|
||||
" pretrained_vector = pretrained_vectors[token]\n",
|
||||
" pretrained_embedding[idx] = pretrained_vector\n",
|
||||
" else:\n",
|
||||
" unk_tokens.append(token)\n",
|
||||
" \n",
|
||||
" return pretrained_embedding, unk_tokens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -469,12 +561,12 @@
|
||||
"source": [
|
||||
"unk_token = '<unk>'\n",
|
||||
"\n",
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(glove.vectors, vocab, unk_token)"
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(model.embedding, glove, vocab, unk_token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -488,8 +580,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
"tensor([[-0.0398, 0.0357, -0.0046, ..., -0.0485, -0.0088, 0.0329],\n",
|
||||
" [-0.0330, 0.0428, 0.0304, ..., 0.0236, 0.0487, 0.0101],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
@ -497,7 +589,7 @@
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -508,7 +600,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.embedding.weight.data[pad_idx] = torch.zeros(emb_dim)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -521,7 +622,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -534,7 +635,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -547,7 +648,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -561,7 +662,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -578,7 +679,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -618,7 +719,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -654,7 +755,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -671,7 +772,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -686,36 +787,36 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.656 | Train Acc: 60.38%\n",
|
||||
"\t Val. Loss: 0.609 | Val. Acc: 64.97%\n",
|
||||
"Epoch: 01 | Epoch Time: 0m 23s\n",
|
||||
"\tTrain Loss: 0.777 | Train Acc: 52.23%\n",
|
||||
"\t Val. Loss: 0.683 | Val. Acc: 53.70%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.594 | Train Acc: 69.33%\n",
|
||||
"\t Val. Loss: 0.578 | Val. Acc: 71.96%\n",
|
||||
"\tTrain Loss: 0.683 | Train Acc: 57.90%\n",
|
||||
"\t Val. Loss: 0.676 | Val. Acc: 53.47%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.565 | Train Acc: 70.59%\n",
|
||||
"\t Val. Loss: 0.476 | Val. Acc: 77.87%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.487 | Train Acc: 76.72%\n",
|
||||
"\t Val. Loss: 0.453 | Val. Acc: 79.55%\n",
|
||||
"\tTrain Loss: 0.625 | Train Acc: 65.60%\n",
|
||||
"\t Val. Loss: 0.482 | Val. Acc: 78.27%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 25s\n",
|
||||
"\tTrain Loss: 0.483 | Train Acc: 77.15%\n",
|
||||
"\t Val. Loss: 0.410 | Val. Acc: 82.67%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.415 | Train Acc: 81.43%\n",
|
||||
"\t Val. Loss: 0.403 | Val. Acc: 83.60%\n",
|
||||
"\tTrain Loss: 0.350 | Train Acc: 85.31%\n",
|
||||
"\t Val. Loss: 0.315 | Val. Acc: 86.75%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.349 | Train Acc: 85.02%\n",
|
||||
"\t Val. Loss: 0.337 | Val. Acc: 86.41%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.308 | Train Acc: 86.93%\n",
|
||||
"\t Val. Loss: 0.344 | Val. Acc: 85.36%\n",
|
||||
"\tTrain Loss: 0.294 | Train Acc: 88.14%\n",
|
||||
"\t Val. Loss: 0.288 | Val. Acc: 88.41%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 25s\n",
|
||||
"\tTrain Loss: 0.258 | Train Acc: 89.92%\n",
|
||||
"\t Val. Loss: 0.277 | Val. Acc: 89.14%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.279 | Train Acc: 88.49%\n",
|
||||
"\t Val. Loss: 0.315 | Val. Acc: 87.62%\n",
|
||||
"\tTrain Loss: 0.231 | Train Acc: 91.03%\n",
|
||||
"\t Val. Loss: 0.280 | Val. Acc: 88.89%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.252 | Train Acc: 89.98%\n",
|
||||
"\t Val. Loss: 0.326 | Val. Acc: 88.16%\n",
|
||||
"\tTrain Loss: 0.196 | Train Acc: 92.50%\n",
|
||||
"\t Val. Loss: 0.285 | Val. Acc: 89.27%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.218 | Train Acc: 91.31%\n",
|
||||
"\t Val. Loss: 0.293 | Val. Acc: 89.16%\n"
|
||||
"\tTrain Loss: 0.175 | Train Acc: 93.53%\n",
|
||||
"\t Val. Loss: 0.316 | Val. Acc: 89.55%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -746,7 +847,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -761,7 +862,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.307 | Test Acc: 88.17%\n"
|
||||
"Test Loss: 0.291 | Test Acc: 88.06%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -775,7 +876,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -797,7 +898,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -811,10 +912,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.015085793100297451"
|
||||
"0.06933268904685974"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -827,7 +928,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -841,10 +942,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.991886556148529"
|
||||
"0.9730159640312195"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -857,7 +958,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -871,10 +972,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.6712993383407593"
|
||||
"0.1614144891500473"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -888,7 +989,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -902,10 +1003,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.28944137692451477"
|
||||
"0.5040232539176941"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -127,7 +127,7 @@
|
||||
" tokens = [token.lower() for token in tokens]\n",
|
||||
" \n",
|
||||
" if self.max_length is not None:\n",
|
||||
" tokens = tokens[:max_length]\n",
|
||||
" tokens = tokens[:self.max_length]\n",
|
||||
" \n",
|
||||
" return tokens"
|
||||
]
|
||||
@ -419,6 +419,78 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"name: embedding.weight, shape: torch.Size([25002, 100])\n",
|
||||
"name: convs.0.weight, shape: torch.Size([100, 100, 3])\n",
|
||||
"name: convs.0.bias, shape: torch.Size([100])\n",
|
||||
"name: convs.1.weight, shape: torch.Size([100, 100, 4])\n",
|
||||
"name: convs.1.bias, shape: torch.Size([100])\n",
|
||||
"name: convs.2.weight, shape: torch.Size([100, 100, 5])\n",
|
||||
"name: convs.2.bias, shape: torch.Size([100])\n",
|
||||
"name: fc.weight, shape: torch.Size([2, 300])\n",
|
||||
"name: fc.bias, shape: torch.Size([2])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for n, p in model.named_parameters():\n",
|
||||
" print(f'name: {n}, shape: {p.shape}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def initialize_parameters(m):\n",
|
||||
" if isinstance(m, nn.Embedding):\n",
|
||||
" nn.init.uniform_(m.weight, -0.05, 0.05)\n",
|
||||
" elif isinstance(m, nn.Conv1d):\n",
|
||||
" nn.init.xavier_uniform_(m.weight)\n",
|
||||
" nn.init.zeros_(m.bias)\n",
|
||||
" elif isinstance(m, nn.Linear):\n",
|
||||
" nn.init.xavier_uniform_(m.weight)\n",
|
||||
" nn.init.zeros_(m.bias)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"CNN(\n",
|
||||
" (embedding): Embedding(25002, 100, padding_idx=1)\n",
|
||||
" (convs): ModuleList(\n",
|
||||
" (0): Conv1d(100, 100, kernel_size=(3,), stride=(1,))\n",
|
||||
" (1): Conv1d(100, 100, kernel_size=(4,), stride=(1,))\n",
|
||||
" (2): Conv1d(100, 100, kernel_size=(5,), stride=(1,))\n",
|
||||
" )\n",
|
||||
" (fc): Linear(in_features=300, out_features=2, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.5, inplace=False)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.apply(initialize_parameters)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -432,7 +504,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -440,30 +512,26 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_pretrained_embedding(vectors, vocab, unk_token):\n",
|
||||
"def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):\n",
|
||||
" \n",
|
||||
" unk_vector = vectors[unk_token]\n",
|
||||
" emb_dim = unk_vector.shape[-1]\n",
|
||||
" zero_vector = torch.zeros(emb_dim)\n",
|
||||
"\n",
|
||||
" pretrained_embedding = torch.zeros(len(vocab), emb_dim) \n",
|
||||
" pretrained_embedding = torch.FloatTensor(initial_embedding.weight.clone()).detach() \n",
|
||||
" pretrained_vocab = pretrained_vectors.vectors.get_stoi()\n",
|
||||
" \n",
|
||||
" unk_tokens = []\n",
|
||||
" \n",
|
||||
" for idx, token in enumerate(vocab.itos):\n",
|
||||
" pretrained_vector = vectors[token]\n",
|
||||
" if torch.all(torch.eq(pretrained_vector, unk_vector)):\n",
|
||||
" unk_tokens.append(token)\n",
|
||||
" pretrained_embedding[idx] = zero_vector\n",
|
||||
" else:\n",
|
||||
" if token in pretrained_vocab:\n",
|
||||
" pretrained_vector = pretrained_vectors[token]\n",
|
||||
" pretrained_embedding[idx] = pretrained_vector\n",
|
||||
" else:\n",
|
||||
" unk_tokens.append(token)\n",
|
||||
" \n",
|
||||
" return pretrained_embedding, unk_tokens"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -473,12 +541,12 @@
|
||||
"source": [
|
||||
"unk_token = '<unk>'\n",
|
||||
"\n",
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(glove.vectors, vocab, unk_token)"
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(model.embedding, glove, vocab, unk_token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -492,8 +560,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
"tensor([[-0.0220, -0.0288, -0.0422, ..., 0.0103, 0.0218, -0.0141],\n",
|
||||
" [ 0.0326, 0.0222, 0.0044, ..., 0.0249, 0.0163, 0.0052],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
||||
@ -501,7 +569,7 @@
|
||||
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -512,7 +580,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.embedding.weight.data[pad_idx] = torch.zeros(emb_dim)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -525,7 +602,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -538,7 +615,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -551,7 +628,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -565,7 +642,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -582,7 +659,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -622,7 +699,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -658,7 +735,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -675,7 +752,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -691,35 +768,35 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.720 | Train Acc: 55.50%\n",
|
||||
"\t Val. Loss: 0.617 | Val. Acc: 67.59%\n",
|
||||
"\tTrain Loss: 1.370 | Train Acc: 53.26%\n",
|
||||
"\t Val. Loss: 0.588 | Val. Acc: 69.31%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.593 | Train Acc: 68.32%\n",
|
||||
"\t Val. Loss: 0.480 | Val. Acc: 79.81%\n",
|
||||
"\tTrain Loss: 0.796 | Train Acc: 60.77%\n",
|
||||
"\t Val. Loss: 0.562 | Val. Acc: 73.82%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.456 | Train Acc: 79.10%\n",
|
||||
"\t Val. Loss: 0.385 | Val. Acc: 83.43%\n",
|
||||
"\tTrain Loss: 0.620 | Train Acc: 67.86%\n",
|
||||
"\t Val. Loss: 0.523 | Val. Acc: 78.67%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.384 | Train Acc: 82.80%\n",
|
||||
"\t Val. Loss: 0.347 | Val. Acc: 85.50%\n",
|
||||
"\tTrain Loss: 0.523 | Train Acc: 74.40%\n",
|
||||
"\t Val. Loss: 0.459 | Val. Acc: 81.48%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.348 | Train Acc: 84.98%\n",
|
||||
"\t Val. Loss: 0.327 | Val. Acc: 86.35%\n",
|
||||
"\tTrain Loss: 0.459 | Train Acc: 78.51%\n",
|
||||
"\t Val. Loss: 0.416 | Val. Acc: 83.35%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.316 | Train Acc: 86.65%\n",
|
||||
"\t Val. Loss: 0.312 | Val. Acc: 87.05%\n",
|
||||
"\tTrain Loss: 0.412 | Train Acc: 81.52%\n",
|
||||
"\t Val. Loss: 0.381 | Val. Acc: 84.52%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.291 | Train Acc: 87.85%\n",
|
||||
"\t Val. Loss: 0.305 | Val. Acc: 87.50%\n",
|
||||
"\tTrain Loss: 0.374 | Train Acc: 83.71%\n",
|
||||
"\t Val. Loss: 0.369 | Val. Acc: 84.95%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.263 | Train Acc: 89.14%\n",
|
||||
"\t Val. Loss: 0.301 | Val. Acc: 87.73%\n",
|
||||
"\tTrain Loss: 0.356 | Train Acc: 84.29%\n",
|
||||
"\t Val. Loss: 0.356 | Val. Acc: 85.49%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.242 | Train Acc: 90.23%\n",
|
||||
"\t Val. Loss: 0.296 | Val. Acc: 88.02%\n",
|
||||
"\tTrain Loss: 0.339 | Train Acc: 85.20%\n",
|
||||
"\t Val. Loss: 0.344 | Val. Acc: 85.92%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 9s\n",
|
||||
"\tTrain Loss: 0.226 | Train Acc: 90.82%\n",
|
||||
"\t Val. Loss: 0.290 | Val. Acc: 88.22%\n"
|
||||
"\tTrain Loss: 0.318 | Train Acc: 86.43%\n",
|
||||
"\t Val. Loss: 0.334 | Val. Acc: 86.28%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -750,7 +827,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -765,7 +842,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.295 | Test Acc: 87.70%\n"
|
||||
"Test Loss: 0.338 | Test Acc: 85.99%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -779,7 +856,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -800,7 +877,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -814,10 +891,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.025242310017347336"
|
||||
"0.08827298134565353"
|
||||
]
|
||||
},
|
||||
"execution_count": 34,
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -830,7 +907,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -844,10 +921,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.7093814015388489"
|
||||
"0.6329940557479858"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -860,7 +937,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -874,10 +951,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.04373614862561226"
|
||||
"0.060872383415699005"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -891,7 +968,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -905,10 +982,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.049333445727825165"
|
||||
"0.07820437103509903"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
|
@ -130,7 +130,7 @@
" tokens = [self.sos_token] + tokens\n",
"\n",
" if self.max_length is not None:\n",
" tokens = tokens[:max_length]\n",
" tokens = tokens[:self.max_length]\n",
" \n",
" return tokens"
]
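The hunk above is the max_length fix named in the commit message: inside the tokenizer's method the bare name max_length is not defined (the value is stored on the instance at construction time), so truncation either raised a NameError or silently picked up an unrelated global. A minimal sketch of the corrected logic, assuming a tokenizer class of roughly this shape (plain whitespace splitting stands in for the notebook's real tokenize function):

class Tokenizer:
    def __init__(self, lower=True, max_length=None, sos_token=None):
        self.lower = lower
        self.max_length = max_length
        self.sos_token = sos_token

    def tokenize(self, s):
        tokens = s.split()                        # stand-in for the real tokenizer
        if self.lower:
            tokens = [token.lower() for token in tokens]
        if self.sos_token is not None:
            tokens = [self.sos_token] + tokens
        if self.max_length is not None:
            tokens = tokens[:self.max_length]     # the fix: use the stored attribute
        return tokens

tokenizer = Tokenizer(max_length=5, sos_token='<sos>')
print(tokenizer.tokenize('One of the greatest films I have ever seen'))
# ['<sos>', 'one', 'of', 'the', 'greatest']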
@ -416,27 +416,6 @@
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def init_parameters(model):\n",
" for n, p in model.named_parameters():\n",
" if p.dim() > 1:\n",
" nn.init.xavier_uniform_(p)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"init_parameters(model)"
]
},
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -450,7 +429,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -473,9 +452,151 @@
|
||||
"print(f'The model has {count_parameters(model):,} trainable parameters')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"name: tok_embedding.weight, shape: torch.Size([25002, 100])\n",
|
||||
"name: pos_embedding.weight, shape: torch.Size([250, 100])\n",
|
||||
"name: transformer.layers.0.self_attn.in_proj_weight, shape: torch.Size([300, 100])\n",
|
||||
"name: transformer.layers.0.self_attn.in_proj_bias, shape: torch.Size([300])\n",
|
||||
"name: transformer.layers.0.self_attn.out_proj.weight, shape: torch.Size([100, 100])\n",
|
||||
"name: transformer.layers.0.self_attn.out_proj.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.0.linear1.weight, shape: torch.Size([1024, 100])\n",
|
||||
"name: transformer.layers.0.linear1.bias, shape: torch.Size([1024])\n",
|
||||
"name: transformer.layers.0.linear2.weight, shape: torch.Size([100, 1024])\n",
|
||||
"name: transformer.layers.0.linear2.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.0.norm1.weight, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.0.norm1.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.0.norm2.weight, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.0.norm2.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.1.self_attn.in_proj_weight, shape: torch.Size([300, 100])\n",
|
||||
"name: transformer.layers.1.self_attn.in_proj_bias, shape: torch.Size([300])\n",
|
||||
"name: transformer.layers.1.self_attn.out_proj.weight, shape: torch.Size([100, 100])\n",
|
||||
"name: transformer.layers.1.self_attn.out_proj.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.1.linear1.weight, shape: torch.Size([1024, 100])\n",
|
||||
"name: transformer.layers.1.linear1.bias, shape: torch.Size([1024])\n",
|
||||
"name: transformer.layers.1.linear2.weight, shape: torch.Size([100, 1024])\n",
|
||||
"name: transformer.layers.1.linear2.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.1.norm1.weight, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.1.norm1.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.1.norm2.weight, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.1.norm2.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.2.self_attn.in_proj_weight, shape: torch.Size([300, 100])\n",
|
||||
"name: transformer.layers.2.self_attn.in_proj_bias, shape: torch.Size([300])\n",
|
||||
"name: transformer.layers.2.self_attn.out_proj.weight, shape: torch.Size([100, 100])\n",
|
||||
"name: transformer.layers.2.self_attn.out_proj.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.2.linear1.weight, shape: torch.Size([1024, 100])\n",
|
||||
"name: transformer.layers.2.linear1.bias, shape: torch.Size([1024])\n",
|
||||
"name: transformer.layers.2.linear2.weight, shape: torch.Size([100, 1024])\n",
|
||||
"name: transformer.layers.2.linear2.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.2.norm1.weight, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.2.norm1.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.2.norm2.weight, shape: torch.Size([100])\n",
|
||||
"name: transformer.layers.2.norm2.bias, shape: torch.Size([100])\n",
|
||||
"name: transformer.norm.weight, shape: torch.Size([100])\n",
|
||||
"name: transformer.norm.bias, shape: torch.Size([100])\n",
|
||||
"name: fc.weight, shape: torch.Size([2, 100])\n",
|
||||
"name: fc.bias, shape: torch.Size([2])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for n, p in model.named_parameters():\n",
|
||||
" print(f'name: {n}, shape: {p.shape}')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def initialize_parameters(m):\n",
|
||||
" if isinstance(m, nn.Embedding):\n",
|
||||
" nn.init.normal_(m.weight, std = 0.02)\n",
|
||||
" elif isinstance(m, nn.Linear):\n",
|
||||
" nn.init.normal_(m.weight, std = 0.02)\n",
|
||||
" nn.init.zeros_(m.bias)\n",
|
||||
" elif isinstance(m, nn.LayerNorm):\n",
|
||||
" nn.init.ones_(m.weight)\n",
|
||||
" nn.init.zeros_(m.bias)"
|
||||
]
|
||||
},
|
||||
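This replaces the earlier blanket xavier_uniform_ initializer with a per-module scheme (normal with std 0.02 for embeddings and linear layers, ones/zeros for LayerNorm), the convention used by GPT/BERT-style Transformers. A small self-contained sketch (toy module, not the notebook's model) of how nn.Module.apply drives it, assuming the initialize_parameters function defined above:

import torch.nn as nn

def initialize_parameters(m):
    if isinstance(m, nn.Embedding):
        nn.init.normal_(m.weight, std=0.02)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.02)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

# apply() walks the module tree recursively and passes every submodule to the
# function, so each Embedding / Linear / LayerNorm hits the matching branch.
toy = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 2), nn.LayerNorm(2))
toy.apply(initialize_parameters)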
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Transformer(\n",
|
||||
" (tok_embedding): Embedding(25002, 100, padding_idx=1)\n",
|
||||
" (pos_embedding): Embedding(250, 100)\n",
|
||||
" (transformer): TransformerEncoder(\n",
|
||||
" (layers): ModuleList(\n",
|
||||
" (0): TransformerEncoderLayer(\n",
|
||||
" (self_attn): MultiheadAttention(\n",
|
||||
" (out_proj): _LinearWithBias(in_features=100, out_features=100, bias=True)\n",
|
||||
" )\n",
|
||||
" (linear1): Linear(in_features=100, out_features=1024, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (linear2): Linear(in_features=1024, out_features=100, bias=True)\n",
|
||||
" (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)\n",
|
||||
" (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)\n",
|
||||
" (dropout1): Dropout(p=0.1, inplace=False)\n",
|
||||
" (dropout2): Dropout(p=0.1, inplace=False)\n",
|
||||
" )\n",
|
||||
" (1): TransformerEncoderLayer(\n",
|
||||
" (self_attn): MultiheadAttention(\n",
|
||||
" (out_proj): _LinearWithBias(in_features=100, out_features=100, bias=True)\n",
|
||||
" )\n",
|
||||
" (linear1): Linear(in_features=100, out_features=1024, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (linear2): Linear(in_features=1024, out_features=100, bias=True)\n",
|
||||
" (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)\n",
|
||||
" (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)\n",
|
||||
" (dropout1): Dropout(p=0.1, inplace=False)\n",
|
||||
" (dropout2): Dropout(p=0.1, inplace=False)\n",
|
||||
" )\n",
|
||||
" (2): TransformerEncoderLayer(\n",
|
||||
" (self_attn): MultiheadAttention(\n",
|
||||
" (out_proj): _LinearWithBias(in_features=100, out_features=100, bias=True)\n",
|
||||
" )\n",
|
||||
" (linear1): Linear(in_features=100, out_features=1024, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
||||
" (linear2): Linear(in_features=1024, out_features=100, bias=True)\n",
|
||||
" (norm1): LayerNorm((100,), eps=1e-05, elementwise_affine=True)\n",
|
||||
" (norm2): LayerNorm((100,), eps=1e-05, elementwise_affine=True)\n",
|
||||
" (dropout1): Dropout(p=0.1, inplace=False)\n",
|
||||
" (dropout2): Dropout(p=0.1, inplace=False)\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" (norm): LayerNorm((100,), eps=1e-05, elementwise_affine=True)\n",
|
||||
" )\n",
|
||||
" (fc): Linear(in_features=100, out_features=2, bias=True)\n",
|
||||
" (dropout): Dropout(p=0.5, inplace=False)\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"model.apply(initialize_parameters)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -489,7 +610,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -497,30 +618,26 @@
},
"outputs": [],
"source": [
"def get_pretrained_embedding(vectors, vocab, unk_token):\n",
"def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):\n",
" \n",
" unk_vector = vectors[unk_token]\n",
" emb_dim = unk_vector.shape[-1]\n",
" zero_vector = torch.zeros(emb_dim)\n",
"\n",
" pretrained_embedding = torch.zeros(len(vocab), emb_dim) \n",
" pretrained_embedding = torch.FloatTensor(initial_embedding.weight.clone()).detach() \n",
" pretrained_vocab = pretrained_vectors.vectors.get_stoi()\n",
" \n",
" unk_tokens = []\n",
" \n",
" for idx, token in enumerate(vocab.itos):\n",
" pretrained_vector = vectors[token]\n",
" if torch.all(torch.eq(pretrained_vector, unk_vector)):\n",
" unk_tokens.append(token)\n",
" pretrained_embedding[idx] = zero_vector\n",
" else:\n",
" if token in pretrained_vocab:\n",
" pretrained_vector = pretrained_vectors[token]\n",
" pretrained_embedding[idx] = pretrained_vector\n",
" else:\n",
" unk_tokens.append(token)\n",
" \n",
" return pretrained_embedding, unk_tokens"
]
},
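This hunk is the "improved loading vectors" part of the commit: instead of detecting out-of-vocabulary words by comparing every looked-up vector against the unk vector (and zeroing those rows), the new version checks membership in the pretrained string-to-index table and, for words GloVe does not cover, keeps the row from the freshly initialized embedding. A condensed sketch of the same idea, assuming a torchtext GloVe vectors object with get_stoi() as used above and a vocab exposing an itos list:

import torch

def load_pretrained_rows(initial_embedding, glove, vocab):
    # start from the initialized weights so OOV rows keep their random init
    weights = initial_embedding.weight.clone().detach()
    stoi = glove.vectors.get_stoi()           # pretrained token -> row index
    oov = []
    for idx, token in enumerate(vocab.itos):
        if token in stoi:
            weights[idx] = glove[token]       # copy the GloVe vector for known tokens
        else:
            oov.append(token)                 # leave the initialized row untouched
    return weights, oov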
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -530,12 +647,12 @@
|
||||
"source": [
|
||||
"unk_token = '<unk>'\n",
|
||||
"\n",
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(glove.vectors, vocab, unk_token)"
|
||||
"pretrained_embedding, unk_tokens = get_pretrained_embedding(model.tok_embedding, glove, vocab, unk_token)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -549,16 +666,16 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
"tensor([[-0.0118, 0.0220, -0.0321, ..., 0.0011, 0.0252, 0.0027],\n",
|
||||
" [ 0.0154, -0.0052, 0.0104, ..., -0.0116, 0.0198, -0.0480],\n",
|
||||
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
||||
" ...,\n",
|
||||
" [-0.0288, -0.0316, 0.4083, ..., 0.6288, -0.5348, -0.8080],\n",
|
||||
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
|
||||
" [ 0.0215, 0.0027, -0.0050, ..., -0.0036, -0.0102, 0.0206],\n",
|
||||
" [-0.2612, 0.6821, -0.2295, ..., -0.5306, 0.0863, 0.4852]])"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"execution_count": 26,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@ -569,7 +686,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.tok_embedding.weight.data[pad_idx] = torch.zeros(emb_dim)"
|
||||
]
|
||||
},
|
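Because nn.Embedding only keeps the padding row at zero if it is never overwritten, the <pad> row has to be re-zeroed after the pretrained matrix is copied in. A minimal sketch of that loading sequence, assuming the pretrained_embedding returned by get_pretrained_embedding above and the model's tok_embedding / pad_idx / emb_dim names (the copy_ call is the usual recipe, not shown verbatim in this diff):

# copy the assembled matrix into the model, then re-zero the padding row,
# since the pad token may have received a non-zero initialized or GloVe row
model.tok_embedding.weight.data.copy_(pretrained_embedding)
model.tok_embedding.weight.data[pad_idx] = torch.zeros(emb_dim)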
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -582,7 +708,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -595,7 +721,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -608,7 +734,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -622,7 +748,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -639,7 +765,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -679,7 +805,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -715,7 +841,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -732,7 +858,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -748,35 +874,35 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 01 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.986 | Train Acc: 49.99%\n",
|
||||
"\t Val. Loss: 0.693 | Val. Acc: 50.61%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 25s\n",
|
||||
"\tTrain Loss: 0.711 | Train Acc: 50.42%\n",
|
||||
"\t Val. Loss: 0.694 | Val. Acc: 49.39%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 25s\n",
|
||||
"\tTrain Loss: 0.701 | Train Acc: 50.51%\n",
|
||||
"\t Val. Loss: 0.690 | Val. Acc: 54.57%\n",
|
||||
"\tTrain Loss: 0.652 | Train Acc: 59.58%\n",
|
||||
"\t Val. Loss: 0.492 | Val. Acc: 77.16%\n",
|
||||
"Epoch: 02 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.433 | Train Acc: 80.47%\n",
|
||||
"\t Val. Loss: 0.380 | Val. Acc: 83.43%\n",
|
||||
"Epoch: 03 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.351 | Train Acc: 84.96%\n",
|
||||
"\t Val. Loss: 0.366 | Val. Acc: 83.79%\n",
|
||||
"Epoch: 04 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.646 | Train Acc: 60.74%\n",
|
||||
"\t Val. Loss: 0.431 | Val. Acc: 80.03%\n",
|
||||
"\tTrain Loss: 0.301 | Train Acc: 87.46%\n",
|
||||
"\t Val. Loss: 0.323 | Val. Acc: 86.52%\n",
|
||||
"Epoch: 05 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.396 | Train Acc: 82.62%\n",
|
||||
"\t Val. Loss: 0.345 | Val. Acc: 84.42%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 25s\n",
|
||||
"\tTrain Loss: 0.303 | Train Acc: 87.41%\n",
|
||||
"\t Val. Loss: 0.329 | Val. Acc: 86.88%\n",
|
||||
"\tTrain Loss: 0.258 | Train Acc: 89.47%\n",
|
||||
"\t Val. Loss: 0.324 | Val. Acc: 87.18%\n",
|
||||
"Epoch: 06 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.232 | Train Acc: 90.77%\n",
|
||||
"\t Val. Loss: 0.320 | Val. Acc: 86.84%\n",
|
||||
"Epoch: 07 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.250 | Train Acc: 89.76%\n",
|
||||
"\t Val. Loss: 0.362 | Val. Acc: 86.24%\n",
|
||||
"\tTrain Loss: 0.200 | Train Acc: 92.21%\n",
|
||||
"\t Val. Loss: 0.439 | Val. Acc: 82.49%\n",
|
||||
"Epoch: 08 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.213 | Train Acc: 91.69%\n",
|
||||
"\t Val. Loss: 0.353 | Val. Acc: 87.64%\n",
|
||||
"\tTrain Loss: 0.188 | Train Acc: 92.84%\n",
|
||||
"\t Val. Loss: 0.381 | Val. Acc: 86.18%\n",
|
||||
"Epoch: 09 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.199 | Train Acc: 92.32%\n",
|
||||
"\t Val. Loss: 0.327 | Val. Acc: 87.93%\n",
|
||||
"\tTrain Loss: 0.159 | Train Acc: 94.02%\n",
|
||||
"\t Val. Loss: 0.358 | Val. Acc: 87.33%\n",
|
||||
"Epoch: 10 | Epoch Time: 0m 24s\n",
|
||||
"\tTrain Loss: 0.171 | Train Acc: 93.47%\n",
|
||||
"\t Val. Loss: 0.362 | Val. Acc: 87.67%\n"
|
||||
"\tTrain Loss: 0.145 | Train Acc: 94.66%\n",
|
||||
"\t Val. Loss: 0.420 | Val. Acc: 86.20%\n"
|
||||
]
|
||||
}
|
||||
],
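The new training log above reflects both changes: the old run sat at roughly chance level (loss around 0.69, accuracy around 50%) for its first three epochs, while with the per-module initialization and the rebuilt pretrained-vector loading the model reaches about 77% validation accuracy after one epoch and peaks around 87% before it starts to overfit.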
|
||||
@ -807,7 +933,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -822,7 +948,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Loss: 0.354 | Test Acc: 86.14%\n"
|
||||
"Test Loss: 0.349 | Test Acc: 85.45%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -836,7 +962,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -857,7 +983,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -871,67 +997,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.00010680717241484672"
|
||||
]
|
||||
},
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = 'the absolute worst movie of all time.'\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
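predict_sentiment itself is defined earlier in the notebook; the cells here only exercise it on hand-written reviews. As a rough, hypothetical sketch of what such a helper typically does with this tokenizer/vocab/model combination (not the notebook's exact implementation, and the batch-first layout and positive-class index are assumptions):

def predict_sentiment(tokenizer, vocab, model, device, sentence):
    model.eval()
    tokens = tokenizer.tokenize(sentence)
    indexes = [vocab[token] for token in tokens]      # assumed vocab lookup API
    tensor = torch.LongTensor(indexes).unsqueeze(0).to(device)  # [1, seq len], batch-first assumed
    with torch.no_grad():
        prediction = model(tensor)
    probabilities = torch.softmax(prediction, dim=-1)
    return probabilities[0, 1].item()                 # assumed: index 1 = positive class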
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "aLqml9PenBMp",
|
||||
"outputId": "1614cf67-7583-4cb6-ab17-09ea8d1774a6"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9999222755432129"
|
||||
]
|
||||
},
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "MyjsYDeJnCui",
|
||||
"outputId": "d87ccbee-9e91-4e64-fb2b-aaaf474f12e6"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.07242873311042786"
|
||||
"0.0066763292998075485"
|
||||
]
|
||||
},
|
||||
"execution_count": 39,
|
||||
@ -940,8 +1006,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
|
||||
"but it was actually the absolute worst movie of all time.\"\n",
|
||||
"sentence = 'the absolute worst movie of all time.'\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
@ -949,6 +1014,67 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "aLqml9PenBMp",
|
||||
"outputId": "1614cf67-7583-4cb6-ab17-09ea8d1774a6"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.9929355978965759"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 35
|
||||
},
|
||||
"colab_type": "code",
|
||||
"id": "MyjsYDeJnCui",
|
||||
"outputId": "d87ccbee-9e91-4e64-fb2b-aaaf474f12e6"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.021573422476649284"
|
||||
]
|
||||
},
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
|
||||
"but it was actually the absolute worst movie of all time.\"\n",
|
||||
"\n",
|
||||
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
@ -962,10 +1088,10 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"0.13505800068378448"
|
||||
"0.026321368291974068"
|
||||
]
|
||||
},
|
||||
"execution_count": 40,
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}