1607 lines
42 KiB
Plaintext
1607 lines
42 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "lIYdn1woOS1n",
|
||
"outputId": "cece5524-0d94-4cc4-b260-23e2f0ecc744"
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"id": "Y2upWg_Qvax1"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"import functools\n",
|
||
"\n",
|
||
"import datasets\n",
|
||
"\n",
|
||
"import torchtext\n",
|
||
"import torch\n",
|
||
"import torch.nn as nn\n",
|
||
"import torch.optim as optim"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "ZIVeVqVUvdcK",
|
||
"outputId": "db0dbf36-4a75-4d30-bcef-8cb52b7a5b30"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "b5304f46c35d4fe6985cf45389babfda",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"Downloading: 0%| | 0.00/1.92k [00:00<?, ?B/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "afaae0e285b84b6caf5c65779b42f6ac",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"Downloading: 0%| | 0.00/1.05k [00:00<?, ?B/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Downloading and preparing dataset imdb/plain_text (download: 80.23 MiB, generated: 127.02 MiB, post-processed: Unknown size, total: 207.25 MiB) to /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a...\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "7ed6be629b54466f81d4bf603d23afef",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"Downloading: 0%| | 0.00/84.1M [00:00<?, ?B/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"0 examples [00:00, ? examples/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"0 examples [00:00, ? examples/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
"0 examples [00:00, ? examples/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Dataset imdb downloaded and prepared to /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a. Subsequent calls will reuse this data.\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# IMDB provides separate 'train' and 'test' splits (25,000 reviews each)\n",
"train_data, test_data = datasets.load_dataset(path='imdb', split=['train', 'test'])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "2f6USIQOvkV_",
|
||
"outputId": "3444d1ad-e2eb-4766-92d8-ea4d85f7011f"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(Dataset({\n",
|
||
" features: ['label', 'text'],\n",
|
||
" num_rows: 25000\n",
|
||
" }),\n",
|
||
" Dataset({\n",
|
||
" features: ['label', 'text'],\n",
|
||
" num_rows: 25000\n",
|
||
" }))"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"(train_data, test_data)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {
|
||
"id": "3Av1UUtZxoSL"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"tokenizer = torchtext.data.utils.get_tokenizer(tokenizer='basic_english')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {
|
||
"id": "ju4G0eKEx2RO"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def tokenize_data(example, tokenizer):\n",
"    \"\"\"Return a new 'tokens' column built from a single example's 'text'.\"\"\"\n",
"    return {'tokens': tokenizer(example['text'])}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "T6RecxI4xs7s",
|
||
"outputId": "0c97c5cb-7bca-4664-98fa-5b665ab510c9"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "a1f91bc27c104b679ee954c62e6b33f8",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
" 0%| | 0/25000 [00:00<?, ?ex/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "aaa75c602f5d46b78a8c79a023f66db2",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
" 0%| | 0/25000 [00:00<?, ?ex/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"train_data, test_data = (\n",
"    split.map(tokenize_data, fn_kwargs={'tokenizer': tokenizer})\n",
"    for split in (train_data, test_data)\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "-Foj4qesxqiz",
|
||
"outputId": "09fc27eb-0f50-45df-aacb-debf53948cdc"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(Dataset({\n",
|
||
" features: ['label', 'text', 'tokens'],\n",
|
||
" num_rows: 25000\n",
|
||
" }),\n",
|
||
" Dataset({\n",
|
||
" features: ['label', 'text', 'tokens'],\n",
|
||
" num_rows: 25000\n",
|
||
" }))"
|
||
]
|
||
},
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"(train_data, test_data)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {
|
||
"id": "OfzYwhN2wFcA"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Fixed seed so the train/validation split is identical on every Restart & Run All.\n",
"# NOTE(review): this cell overwrites train_data, so running it twice re-splits the already-reduced set.\n",
"split_seed = 1234\n",
"train_valid_data = train_data.train_test_split(test_size=0.25, seed=split_seed)\n",
"train_data = train_valid_data['train']\n",
"valid_data = train_valid_data['test']"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "ovGkrOJkwZsC",
|
||
"outputId": "b5ca62c6-96e5-4eca-8c0f-558fd087dee1"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(18750, 6250, 25000)"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"tuple(len(split) for split in (train_data, valid_data, test_data))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {
|
||
"id": "9VMeQG_FxUVO"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"min_freq = 3  # tokens rarer than this in the training split are left out of the vocab\n",
"special_tokens = ['<unk>', '<pad>']  # specials come first: '<unk>' -> 0, '<pad>' -> 1\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(\n",
"    iterator=train_data['tokens'],\n",
"    min_freq=min_freq,\n",
"    specials=special_tokens,\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "rbroBAClxXGB",
|
||
"outputId": "91c5da92-7f97-4ad8-a946-6de1899e64a2"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"35341"
|
||
]
|
||
},
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"len(vocab.get_itos())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "3bKHqCxPyQSb",
|
||
"outputId": "35b3c437-f0f8-43fb-8f10-b968c597d5b4"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"['<unk>', '<pad>', 'the', '.', ',', 'and', 'a', 'of', 'to', \"'\"]"
|
||
]
|
||
},
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"vocab.lookup_tokens(list(range(10)))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "uStvd2szyUGR",
|
||
"outputId": "d4c9d2a4-86a9-413f-9200-5f1a27ece925"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"0"
|
||
]
|
||
},
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"unk_index = vocab.get_stoi()['<unk>']\n",
"unk_index"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "gd5R8NCJyws4",
|
||
"outputId": "46666a6d-56ff-42e3-ebe2-6f87930270c7"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"1"
|
||
]
|
||
},
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"pad_index = vocab.get_stoi()['<pad>']\n",
"pad_index"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {
|
||
"id": "_syj_YR8yp7B"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# From here on, any token missing from the vocabulary looks up as unk_index\n",
"vocab.set_default_index(unk_index)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {
|
||
"id": "ENlE1eAM0lHe"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def numericalize_data(example, vocab):\n",
"    \"\"\"Return a new 'ids' column: the vocabulary index of every token in example['tokens'].\"\"\"\n",
"    return {'ids': vocab.lookup_indices(example['tokens'])}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/",
|
||
"height": 164,
|
||
"referenced_widgets": [
|
||
"ab96a5a035b140aa9958ada5eeaa8edc",
|
||
"e393253753984998b533c18502dd4477",
|
||
"9c6b3d3c5c5a4c0b940c0aec3d833f58",
|
||
"a24b5ac5e18b48d48467217c8ea67463",
|
||
"b699f48ad80442d38481d4f960cf23a3",
|
||
"707426cd59454f53b955a2ed9e46b90e",
|
||
"70a2943fbcd3495aac9fae93886789dc",
|
||
"38e9f642f0714b95ba51da7cafbd1ff2",
|
||
"852f269e78204293afba0e452e43eba2",
|
||
"c87dea49f30b471d9128b2b2082adf89",
|
||
"a26e53975a1142828b8425895847b47b",
|
||
"6c44fd275fa34a2ab02b16519e48aca7",
|
||
"fdc8037ddba74cb7993f6f986e491894",
|
||
"398dcca3803c46beb2854bd52d66e322",
|
||
"f17e77bdd542441caf8d1d427c360601",
|
||
"0cf052dd17f54df19b0e5de1891ec1a4",
|
||
"32aa00f0d71b4655880a19ad08b9f02d",
|
||
"3c20326b78dc4e458bacb85742002a9a",
|
||
"96e6499f54c348628bc33ff48fd1ee6a",
|
||
"d52ef97893144dffa64d1da2b8ea52e6",
|
||
"d0832872ea074689b8c7e3b8b647e68d",
|
||
"b7fc360206ca42f5b51eacb4f1de763d",
|
||
"eeac1f30dad64508ba7f1a9473f390bd",
|
||
"958e2685a018455b9ea02803ca96b464"
|
||
]
|
||
},
|
||
"id": "ux_YLzDA069-",
|
||
"outputId": "11d67399-ace5-49ed-8327-0f85de3386a6"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "416782d88fee4bed84a1a08c59f4e549",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
" 0%| | 0/18750 [00:00<?, ?ex/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "ec6fef000aa447b794f102eda085ce84",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
" 0%| | 0/6250 [00:00<?, ?ex/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
},
|
||
{
|
||
"data": {
|
||
"application/vnd.jupyter.widget-view+json": {
|
||
"model_id": "28feaa784dda446ba6001d82dfe9514f",
|
||
"version_major": 2,
|
||
"version_minor": 0
|
||
},
|
||
"text/plain": [
|
||
" 0%| | 0/25000 [00:00<?, ?ex/s]"
|
||
]
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"source": [
|
||
"train_data, valid_data, test_data = (\n",
"    split.map(numericalize_data, fn_kwargs={'vocab': vocab})\n",
"    for split in (train_data, valid_data, test_data)\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {
|
||
"id": "GAXGqlXT1FD9"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Expose only the model inputs ('ids') and targets ('label'), as torch tensors\n",
"train_data.set_format(columns=['ids', 'label'], type='torch')\n",
"valid_data.set_format(columns=['ids', 'label'], type='torch')\n",
"test_data.set_format(columns=['ids', 'label'], type='torch')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {
|
||
"id": "qCmyoFKAvmnj"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"class NBoW(nn.Module):\n",
"    \"\"\"Neural bag-of-words classifier: average the token embeddings, then apply one linear layer.\n",
"\n",
"    Args:\n",
"        vocab_size: number of rows in the embedding table.\n",
"        embedding_dim: size of each token embedding.\n",
"        output_dim: number of output logits.\n",
"        pad_index: token id used for padding; excluded from the average.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
"        super().__init__()\n",
"        self.pad_index = pad_index\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
"        self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
"    def forward(self, text):\n",
"        # text = [batch size, seq len]\n",
"        embedded = self.embedding(text)\n",
"        # embedded = [batch size, seq len, embedding dim]\n",
"        # Average over real tokens only. A plain mean over dim 1 also counts the padding that\n",
"        # collate appends, so a review's prediction would depend on the longest review in its batch.\n",
"        mask = (text != self.pad_index).unsqueeze(-1).to(embedded.dtype)\n",
"        # mask = [batch size, seq len, 1]\n",
"        lengths = mask.sum(dim=1).clamp(min=1)  # clamp avoids 0/0 for an all-padding row\n",
"        # lengths = [batch size, 1]\n",
"        pooled = (embedded * mask).sum(dim=1) / lengths\n",
"        # pooled = [batch size, embedding dim]\n",
"        prediction = self.fc(pooled)\n",
"        # prediction = [batch size, output dim]\n",
"        return prediction"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {
|
||
"id": "lPym0qxrwC7e"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"vocab_size = len(vocab)\n",
"embedding_dim = 256\n",
"output_dim = 2  # one logit per class, consumed by CrossEntropyLoss\n",
"\n",
"model = NBoW(\n",
"    vocab_size=vocab_size,\n",
"    embedding_dim=embedding_dim,\n",
"    output_dim=output_dim,\n",
"    pad_index=pad_index,\n",
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "Prvx6C3TFyI4",
|
||
"outputId": "c9069fbe-ec6e-423d-ff59-820f934c6432"
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
".vector_cache/wiki.en.vec: 7%|▋ | 465M/6.60G [08:18<3:21:18, 508kB/s] "
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# FastText (wiki.en.vec, ~6.6 GB) is never used: `vectors` is overwritten by GloVe in the next cell.\n",
"# Opt in explicitly so Restart & Run All does not trigger the multi-hour download.\n",
"download_fasttext = False\n",
"if download_fasttext:\n",
"    vectors = torchtext.vocab.FastText()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "e3gI343FIETN",
|
||
"outputId": "44ed85d4-3e22-4bf0-e86e-0d7c01e33c35"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# NOTE(review): `vectors` is not yet wired into the model (its embedding_dim is 256)\n",
"vectors = torchtext.vocab.GloVe(name='840B', dim=300)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "xLg8TKFCzAfL"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"optimizer = optim.Adam(params=model.parameters())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "XJrIlwlfzQqY"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"criterion = nn.CrossEntropyLoss(reduction='mean')"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "KbEREGWmzR7J",
|
||
"outputId": "44ec5013-ab39-4f35-b425-250f7e4093ef"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"\n",
"device"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "IBtd-0I3zVTo"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"model, criterion = model.to(device), criterion.to(device)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "TNaDUz3M2QDM"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def collate(batch, pad_index):\n",
"    \"\"\"Turn a list of examples into one batch dict.\n",
"\n",
"    'ids' sequences are padded with pad_index to the longest one in the batch (batch dimension first);\n",
"    'label' values are stacked along a new batch dimension and returned under the key 'labels'.\n",
"    \"\"\"\n",
"    padded_ids = nn.utils.rnn.pad_sequence(\n",
"        [example['ids'] for example in batch], batch_first=True, padding_value=pad_index\n",
"    )\n",
"    stacked_labels = torch.stack([example['label'] for example in batch])\n",
"    return {'ids': padded_ids, 'labels': stacked_labels}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "LYsAzjrV0AnH"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"batch_size = 512\n",
"\n",
"# Bind pad_index under a new name; re-binding `collate` itself would wrap the partial again on every re-run.\n",
"collate_fn = functools.partial(collate, pad_index=pad_index)\n",
"\n",
"# Reshuffle the training set every epoch; evaluation order does not affect the metrics.\n",
"train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=True)\n",
"valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, collate_fn=collate_fn)\n",
"test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "mKeLtjK5zX41"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def train(dataloader, model, criterion, optimizer, device):\n",
"    \"\"\"Run one optimisation pass over `dataloader`.\n",
"\n",
"    Returns (loss, accuracy), each averaged over batches with every batch weighted equally.\n",
"    \"\"\"\n",
"    model.train()\n",
"    running_loss, running_accuracy = 0, 0\n",
"    for batch in dataloader:\n",
"        ids = batch['ids'].to(device)\n",
"        targets = batch['labels'].to(device)\n",
"        logits = model(ids)\n",
"        loss = criterion(logits, targets)\n",
"        accuracy = get_accuracy(logits, targets)\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        running_loss += loss.item()\n",
"        running_accuracy += accuracy.item()\n",
"    n_batches = len(dataloader)\n",
"    return running_loss / n_batches, running_accuracy / n_batches"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "3gJHwUZZ0NC6"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def evaluate(dataloader, model, criterion, device):\n",
"    \"\"\"Score the model on `dataloader` in eval mode without tracking gradients.\n",
"\n",
"    Returns (loss, accuracy), each averaged over batches with every batch weighted equally.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    batch_losses = []\n",
"    batch_accuracies = []\n",
"    with torch.no_grad():\n",
"        for batch in dataloader:\n",
"            ids = batch['ids'].to(device)\n",
"            targets = batch['labels'].to(device)\n",
"            logits = model(ids)\n",
"            batch_losses.append(criterion(logits, targets).item())\n",
"            batch_accuracies.append(get_accuracy(logits, targets).item())\n",
"    return sum(batch_losses) / len(dataloader), sum(batch_accuracies) / len(dataloader)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "DOPRg4Gg5L44"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"def get_accuracy(predictions, labels):\n",
"    \"\"\"Fraction of rows whose arg-max logit (over dim 1) equals the label, as a 0-d float tensor.\"\"\"\n",
"    predicted_classes = predictions.argmax(dim=1)\n",
"    n_correct = (predicted_classes == labels.view(-1)).sum()\n",
"    return n_correct.float() / predictions.size(0)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"colab": {
|
||
"base_uri": "https://localhost:8080/"
|
||
},
|
||
"id": "y0T-FtNN0PO3",
|
||
"outputId": "4781c655-bd66-4326-df6d-0d9dbedbaac5"
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"n_epochs = 10\n",
"\n",
"for epoch in range(n_epochs):\n",
"    train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)\n",
"    valid_loss, valid_acc = evaluate(valid_dataloader, model, criterion, device)\n",
"    # One print call, one line per field (same output as three separate prints)\n",
"    print(\n",
"        f'epoch: {epoch+1}',\n",
"        f'train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}',\n",
"        f'valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}',\n",
"        sep='\\n',\n",
"    )"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"id": "OD_BKHY_42XL"
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"accelerator": "GPU",
|
||
"colab": {
|
||
"machine_shape": "hm",
|
||
"name": "torchtext_0.10_imdb_nbow",
|
||
"provenance": []
|
||
},
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.9.5"
|
||
},
|
||
"widgets": {
|
||
"application/vnd.jupyter.widget-state+json": {
|
||
"0cf052dd17f54df19b0e5de1891ec1a4": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"32aa00f0d71b4655880a19ad08b9f02d": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "HBoxModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "HBoxModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "HBoxView",
|
||
"box_style": "",
|
||
"children": [
|
||
"IPY_MODEL_96e6499f54c348628bc33ff48fd1ee6a",
|
||
"IPY_MODEL_d52ef97893144dffa64d1da2b8ea52e6"
|
||
],
|
||
"layout": "IPY_MODEL_3c20326b78dc4e458bacb85742002a9a"
|
||
}
|
||
},
|
||
"38e9f642f0714b95ba51da7cafbd1ff2": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"398dcca3803c46beb2854bd52d66e322": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"3c20326b78dc4e458bacb85742002a9a": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"6c44fd275fa34a2ab02b16519e48aca7": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "HTMLModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "HTMLModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "HTMLView",
|
||
"description": "",
|
||
"description_tooltip": null,
|
||
"layout": "IPY_MODEL_0cf052dd17f54df19b0e5de1891ec1a4",
|
||
"placeholder": "",
|
||
"style": "IPY_MODEL_f17e77bdd542441caf8d1d427c360601",
|
||
"value": " 6250/6250 [00:14<00:00, 428.58ex/s]"
|
||
}
|
||
},
|
||
"707426cd59454f53b955a2ed9e46b90e": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"70a2943fbcd3495aac9fae93886789dc": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "DescriptionStyleModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "DescriptionStyleModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "StyleView",
|
||
"description_width": ""
|
||
}
|
||
},
|
||
"852f269e78204293afba0e452e43eba2": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "HBoxModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "HBoxModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "HBoxView",
|
||
"box_style": "",
|
||
"children": [
|
||
"IPY_MODEL_a26e53975a1142828b8425895847b47b",
|
||
"IPY_MODEL_6c44fd275fa34a2ab02b16519e48aca7"
|
||
],
|
||
"layout": "IPY_MODEL_c87dea49f30b471d9128b2b2082adf89"
|
||
}
|
||
},
|
||
"958e2685a018455b9ea02803ca96b464": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"96e6499f54c348628bc33ff48fd1ee6a": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "FloatProgressModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "FloatProgressModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "ProgressView",
|
||
"bar_style": "success",
|
||
"description": "100%",
|
||
"description_tooltip": null,
|
||
"layout": "IPY_MODEL_b7fc360206ca42f5b51eacb4f1de763d",
|
||
"max": 25000,
|
||
"min": 0,
|
||
"orientation": "horizontal",
|
||
"style": "IPY_MODEL_d0832872ea074689b8c7e3b8b647e68d",
|
||
"value": 25000
|
||
}
|
||
},
|
||
"9c6b3d3c5c5a4c0b940c0aec3d833f58": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "FloatProgressModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "FloatProgressModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "ProgressView",
|
||
"bar_style": "success",
|
||
"description": "100%",
|
||
"description_tooltip": null,
|
||
"layout": "IPY_MODEL_707426cd59454f53b955a2ed9e46b90e",
|
||
"max": 18750,
|
||
"min": 0,
|
||
"orientation": "horizontal",
|
||
"style": "IPY_MODEL_b699f48ad80442d38481d4f960cf23a3",
|
||
"value": 18750
|
||
}
|
||
},
|
||
"a24b5ac5e18b48d48467217c8ea67463": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "HTMLModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "HTMLModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "HTMLView",
|
||
"description": "",
|
||
"description_tooltip": null,
|
||
"layout": "IPY_MODEL_38e9f642f0714b95ba51da7cafbd1ff2",
|
||
"placeholder": "",
|
||
"style": "IPY_MODEL_70a2943fbcd3495aac9fae93886789dc",
|
||
"value": " 18750/18750 [00:28<00:00, 652.90ex/s]"
|
||
}
|
||
},
|
||
"a26e53975a1142828b8425895847b47b": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "FloatProgressModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "FloatProgressModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "ProgressView",
|
||
"bar_style": "success",
|
||
"description": "100%",
|
||
"description_tooltip": null,
|
||
"layout": "IPY_MODEL_398dcca3803c46beb2854bd52d66e322",
|
||
"max": 6250,
|
||
"min": 0,
|
||
"orientation": "horizontal",
|
||
"style": "IPY_MODEL_fdc8037ddba74cb7993f6f986e491894",
|
||
"value": 6250
|
||
}
|
||
},
|
||
"ab96a5a035b140aa9958ada5eeaa8edc": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "HBoxModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "HBoxModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "HBoxView",
|
||
"box_style": "",
|
||
"children": [
|
||
"IPY_MODEL_9c6b3d3c5c5a4c0b940c0aec3d833f58",
|
||
"IPY_MODEL_a24b5ac5e18b48d48467217c8ea67463"
|
||
],
|
||
"layout": "IPY_MODEL_e393253753984998b533c18502dd4477"
|
||
}
|
||
},
|
||
"b699f48ad80442d38481d4f960cf23a3": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "ProgressStyleModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "ProgressStyleModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "StyleView",
|
||
"bar_color": null,
|
||
"description_width": "initial"
|
||
}
|
||
},
|
||
"b7fc360206ca42f5b51eacb4f1de763d": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"c87dea49f30b471d9128b2b2082adf89": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"d0832872ea074689b8c7e3b8b647e68d": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "ProgressStyleModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "ProgressStyleModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "StyleView",
|
||
"bar_color": null,
|
||
"description_width": "initial"
|
||
}
|
||
},
|
||
"d52ef97893144dffa64d1da2b8ea52e6": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "HTMLModel",
|
||
"state": {
|
||
"_dom_classes": [],
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "HTMLModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/controls",
|
||
"_view_module_version": "1.5.0",
|
||
"_view_name": "HTMLView",
|
||
"description": "",
|
||
"description_tooltip": null,
|
||
"layout": "IPY_MODEL_958e2685a018455b9ea02803ca96b464",
|
||
"placeholder": "",
|
||
"style": "IPY_MODEL_eeac1f30dad64508ba7f1a9473f390bd",
|
||
"value": " 25000/25000 [00:16<00:00, 1504.95ex/s]"
|
||
}
|
||
},
|
||
"e393253753984998b533c18502dd4477": {
|
||
"model_module": "@jupyter-widgets/base",
|
||
"model_name": "LayoutModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/base",
|
||
"_model_module_version": "1.2.0",
|
||
"_model_name": "LayoutModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "LayoutView",
|
||
"align_content": null,
|
||
"align_items": null,
|
||
"align_self": null,
|
||
"border": null,
|
||
"bottom": null,
|
||
"display": null,
|
||
"flex": null,
|
||
"flex_flow": null,
|
||
"grid_area": null,
|
||
"grid_auto_columns": null,
|
||
"grid_auto_flow": null,
|
||
"grid_auto_rows": null,
|
||
"grid_column": null,
|
||
"grid_gap": null,
|
||
"grid_row": null,
|
||
"grid_template_areas": null,
|
||
"grid_template_columns": null,
|
||
"grid_template_rows": null,
|
||
"height": null,
|
||
"justify_content": null,
|
||
"justify_items": null,
|
||
"left": null,
|
||
"margin": null,
|
||
"max_height": null,
|
||
"max_width": null,
|
||
"min_height": null,
|
||
"min_width": null,
|
||
"object_fit": null,
|
||
"object_position": null,
|
||
"order": null,
|
||
"overflow": null,
|
||
"overflow_x": null,
|
||
"overflow_y": null,
|
||
"padding": null,
|
||
"right": null,
|
||
"top": null,
|
||
"visibility": null,
|
||
"width": null
|
||
}
|
||
},
|
||
"eeac1f30dad64508ba7f1a9473f390bd": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "DescriptionStyleModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "DescriptionStyleModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "StyleView",
|
||
"description_width": ""
|
||
}
|
||
},
|
||
"f17e77bdd542441caf8d1d427c360601": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "DescriptionStyleModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "DescriptionStyleModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "StyleView",
|
||
"description_width": ""
|
||
}
|
||
},
|
||
"fdc8037ddba74cb7993f6f986e491894": {
|
||
"model_module": "@jupyter-widgets/controls",
|
||
"model_name": "ProgressStyleModel",
|
||
"state": {
|
||
"_model_module": "@jupyter-widgets/controls",
|
||
"_model_module_version": "1.5.0",
|
||
"_model_name": "ProgressStyleModel",
|
||
"_view_count": null,
|
||
"_view_module": "@jupyter-widgets/base",
|
||
"_view_module_version": "1.2.0",
|
||
"_view_name": "StyleView",
|
||
"bar_color": null,
|
||
"description_width": "initial"
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 1
|
||
}
|