added tqdm

This commit is contained in:
bentrevett 2021-07-08 19:05:18 +01:00
parent e8f7d1ef96
commit ef687f9e92
3 changed files with 161 additions and 64 deletions

View File

@ -8,18 +8,16 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import functools\n", "import functools\n",
"import sys\n",
"\n", "\n",
"import datasets\n", "import datasets\n",
"\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n", "import numpy as np\n",
"\n",
"import torch\n", "import torch\n",
"import torch.nn as nn\n", "import torch.nn as nn\n",
"import torch.optim as optim\n", "import torch.optim as optim\n",
"\n", "import torchtext\n",
"import torchtext" "import tqdm"
] ]
}, },
{ {
@ -31,7 +29,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"<torch._C.Generator at 0x7fd4fab4c930>" "<torch._C.Generator at 0x7f851a9849b0>"
] ]
}, },
"execution_count": 2, "execution_count": 2,
@ -1910,7 +1908,7 @@
" epoch_losses = []\n", " epoch_losses = []\n",
" epoch_accs = []\n", " epoch_accs = []\n",
"\n", "\n",
" for batch in dataloader:\n", " for batch in tqdm.tqdm(dataloader, desc='training...', file=sys.stdout):\n",
" ids = batch['ids'].to(device)\n", " ids = batch['ids'].to(device)\n",
" label = batch['label'].to(device)\n", " label = batch['label'].to(device)\n",
" prediction = model(ids)\n", " prediction = model(ids)\n",
@ -1939,7 +1937,7 @@
" epoch_accs = []\n", " epoch_accs = []\n",
"\n", "\n",
" with torch.no_grad():\n", " with torch.no_grad():\n",
" for batch in dataloader:\n", " for batch in tqdm.tqdm(dataloader, desc='evaluating...', file=sys.stdout):\n",
" ids = batch['ids'].to(device)\n", " ids = batch['ids'].to(device)\n",
" label = batch['label'].to(device)\n", " label = batch['label'].to(device)\n",
" prediction = model(ids)\n", " prediction = model(ids)\n",
@ -1976,33 +1974,53 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.13it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.12it/s]\n",
"epoch: 1\n", "epoch: 1\n",
"train_loss: 0.684, train_acc: 0.604\n", "train_loss: 0.684, train_acc: 0.604\n",
"valid_loss: 0.671, valid_acc: 0.682\n", "valid_loss: 0.671, valid_acc: 0.682\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.06it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.24it/s]\n",
"epoch: 2\n", "epoch: 2\n",
"train_loss: 0.648, train_acc: 0.718\n", "train_loss: 0.648, train_acc: 0.718\n",
"valid_loss: 0.627, valid_acc: 0.729\n", "valid_loss: 0.627, valid_acc: 0.729\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.65it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 15.88it/s]\n",
"epoch: 3\n", "epoch: 3\n",
"train_loss: 0.588, train_acc: 0.764\n", "train_loss: 0.588, train_acc: 0.764\n",
"valid_loss: 0.567, valid_acc: 0.769\n", "valid_loss: 0.567, valid_acc: 0.769\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.81it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 15.66it/s]\n",
"epoch: 4\n", "epoch: 4\n",
"train_loss: 0.516, train_acc: 0.807\n", "train_loss: 0.516, train_acc: 0.807\n",
"valid_loss: 0.504, valid_acc: 0.803\n", "valid_loss: 0.504, valid_acc: 0.803\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.80it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 15.93it/s]\n",
"epoch: 5\n", "epoch: 5\n",
"train_loss: 0.446, train_acc: 0.847\n", "train_loss: 0.446, train_acc: 0.847\n",
"valid_loss: 0.450, valid_acc: 0.833\n", "valid_loss: 0.450, valid_acc: 0.833\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.83it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.03it/s]\n",
"epoch: 6\n", "epoch: 6\n",
"train_loss: 0.388, train_acc: 0.870\n", "train_loss: 0.388, train_acc: 0.870\n",
"valid_loss: 0.411, valid_acc: 0.844\n", "valid_loss: 0.411, valid_acc: 0.844\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.40it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.37it/s]\n",
"epoch: 7\n", "epoch: 7\n",
"train_loss: 0.343, train_acc: 0.886\n", "train_loss: 0.343, train_acc: 0.886\n",
"valid_loss: 0.384, valid_acc: 0.852\n", "valid_loss: 0.384, valid_acc: 0.852\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.13it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.03it/s]\n",
"epoch: 8\n", "epoch: 8\n",
"train_loss: 0.308, train_acc: 0.899\n", "train_loss: 0.308, train_acc: 0.899\n",
"valid_loss: 0.364, valid_acc: 0.857\n", "valid_loss: 0.364, valid_acc: 0.857\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.99it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.12it/s]\n",
"epoch: 9\n", "epoch: 9\n",
"train_loss: 0.280, train_acc: 0.909\n", "train_loss: 0.280, train_acc: 0.909\n",
"valid_loss: 0.349, valid_acc: 0.862\n", "valid_loss: 0.349, valid_acc: 0.862\n",
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.62it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.37it/s]\n",
"epoch: 10\n", "epoch: 10\n",
"train_loss: 0.257, train_acc: 0.917\n", "train_loss: 0.257, train_acc: 0.917\n",
"valid_loss: 0.336, valid_acc: 0.867\n" "valid_loss: 0.336, valid_acc: 0.867\n"
@ -2110,6 +2128,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"evaluating...: 100%|██████████| 49/49 [00:03<00:00, 15.38it/s]\n",
"test_loss: 0.353, test_acc: 0.857\n" "test_loss: 0.353, test_acc: 0.857\n"
] ]
} }

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long