pytorch-sentiment-analysis/1_nbow.ipynb

2280 lines
155 KiB
Plaintext
Raw Normal View History

2021-07-08 01:04:25 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e322bd29",
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
2021-07-09 02:05:18 +08:00
"import sys\n",
2021-07-08 01:04:25 +08:00
"\n",
"import datasets\n",
2021-07-08 22:11:54 +08:00
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
2021-07-08 01:04:25 +08:00
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
2021-07-09 02:05:18 +08:00
"import torchtext\n",
"import tqdm"
2021-07-08 01:04:25 +08:00
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fcc98ce9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2021-07-09 02:05:18 +08:00
"<torch._C.Generator at 0x7f851a9849b0>"
2021-07-08 01:04:25 +08:00
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seed = 0\n",
"\n",
"# Seed numpy as well as torch so numpy-driven randomness (including libraries that\n",
"# draw from numpy's global RNG) is reproducible across runs.\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "798f5387",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset imdb (/home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)\n"
]
}
],
"source": [
"train_data, test_data = datasets.load_dataset('imdb', split=['train', 'test'])"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "42338609",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Dataset({\n",
" features: ['label', 'text'],\n",
" num_rows: 25000\n",
" }),\n",
" Dataset({\n",
" features: ['label', 'text'],\n",
" num_rows: 25000\n",
" }))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data, test_data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "25a6e8cb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': 1,\n",
" 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as \"Teachers\". My 35 years in the teaching profession lead me to believe that Bromwell High\\'s satire is much closer to reality than is \"Teachers\". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\\'t!'}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3017c0ab",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = torchtext.data.utils.get_tokenizer('basic_english')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "876ad3b9",
"metadata": {},
"outputs": [],
"source": [
"def tokenize_data(example, tokenizer, max_length):\n",
"    \"\"\"Tokenize example['text'] and keep only the first max_length tokens.\"\"\"\n",
"    text = example['text']\n",
"    return {'tokens': tokenizer(text)[:max_length]}"
2021-07-08 01:04:25 +08:00
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5e295030",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-07-08 22:11:54 +08:00
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-ad1b7a77180a232c.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-01c0069c185da175.arrow\n"
]
2021-07-08 01:04:25 +08:00
}
],
"source": [
"max_length = 256\n",
"\n",
"train_data = train_data.map(tokenize_data, fn_kwargs={'tokenizer': tokenizer, 'max_length': max_length})\n",
"test_data = test_data.map(tokenize_data, fn_kwargs={'tokenizer': tokenizer, 'max_length': max_length})"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f647bdf9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset({\n",
" features: ['label', 'text', 'tokens'],\n",
" num_rows: 25000\n",
"})"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2f3de3b9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': 1,\n",
" 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as \"Teachers\". My 35 years in the teaching profession lead me to believe that Bromwell High\\'s satire is much closer to reality than is \"Teachers\". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\\'t!',\n",
" 'tokens': ['bromwell',\n",
" 'high',\n",
" 'is',\n",
" 'a',\n",
" 'cartoon',\n",
" 'comedy',\n",
" '.',\n",
" 'it',\n",
" 'ran',\n",
" 'at',\n",
" 'the',\n",
" 'same',\n",
" 'time',\n",
" 'as',\n",
" 'some',\n",
" 'other',\n",
" 'programs',\n",
" 'about',\n",
" 'school',\n",
" 'life',\n",
" ',',\n",
" 'such',\n",
" 'as',\n",
" 'teachers',\n",
" '.',\n",
" 'my',\n",
" '35',\n",
" 'years',\n",
" 'in',\n",
" 'the',\n",
" 'teaching',\n",
" 'profession',\n",
" 'lead',\n",
" 'me',\n",
" 'to',\n",
" 'believe',\n",
" 'that',\n",
" 'bromwell',\n",
" 'high',\n",
" \"'\",\n",
" 's',\n",
" 'satire',\n",
" 'is',\n",
" 'much',\n",
" 'closer',\n",
" 'to',\n",
" 'reality',\n",
" 'than',\n",
" 'is',\n",
" 'teachers',\n",
" '.',\n",
" 'the',\n",
" 'scramble',\n",
" 'to',\n",
" 'survive',\n",
" 'financially',\n",
" ',',\n",
" 'the',\n",
" 'insightful',\n",
" 'students',\n",
" 'who',\n",
" 'can',\n",
" 'see',\n",
" 'right',\n",
" 'through',\n",
" 'their',\n",
" 'pathetic',\n",
" 'teachers',\n",
" \"'\",\n",
" 'pomp',\n",
" ',',\n",
" 'the',\n",
" 'pettiness',\n",
" 'of',\n",
" 'the',\n",
" 'whole',\n",
" 'situation',\n",
" ',',\n",
" 'all',\n",
" 'remind',\n",
" 'me',\n",
" 'of',\n",
" 'the',\n",
" 'schools',\n",
" 'i',\n",
" 'knew',\n",
" 'and',\n",
" 'their',\n",
" 'students',\n",
" '.',\n",
" 'when',\n",
" 'i',\n",
" 'saw',\n",
" 'the',\n",
" 'episode',\n",
" 'in',\n",
" 'which',\n",
" 'a',\n",
" 'student',\n",
" 'repeatedly',\n",
" 'tried',\n",
" 'to',\n",
" 'burn',\n",
" 'down',\n",
" 'the',\n",
" 'school',\n",
" ',',\n",
" 'i',\n",
" 'immediately',\n",
" 'recalled',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" 'at',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" '.',\n",
" 'high',\n",
" '.',\n",
" 'a',\n",
" 'classic',\n",
" 'line',\n",
" 'inspector',\n",
" 'i',\n",
" \"'\",\n",
" 'm',\n",
" 'here',\n",
" 'to',\n",
" 'sack',\n",
" 'one',\n",
" 'of',\n",
" 'your',\n",
" 'teachers',\n",
" '.',\n",
" 'student',\n",
" 'welcome',\n",
" 'to',\n",
" 'bromwell',\n",
" 'high',\n",
" '.',\n",
" 'i',\n",
" 'expect',\n",
" 'that',\n",
" 'many',\n",
" 'adults',\n",
" 'of',\n",
" 'my',\n",
" 'age',\n",
" 'think',\n",
" 'that',\n",
" 'bromwell',\n",
" 'high',\n",
" 'is',\n",
" 'far',\n",
" 'fetched',\n",
" '.',\n",
" 'what',\n",
" 'a',\n",
" 'pity',\n",
" 'that',\n",
" 'it',\n",
" 'isn',\n",
" \"'\",\n",
" 't',\n",
" '!']}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "15e48bfb",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-07-08 22:11:54 +08:00
"Loading cached split indices for dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-90b2a85f23273ecd.arrow and /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-99371bdf1a536e7c.arrow\n"
]
}
],
2021-07-08 01:04:25 +08:00
"source": [
"test_size = 0.25\n",
"\n",
"# Passing the seed makes the train/validation split identical on every run.\n",
"train_valid_data = train_data.train_test_split(test_size=test_size, seed=seed)\n",
"train_data = train_valid_data['train']\n",
"valid_data = train_valid_data['test']"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "881e83b3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': 1,\n",
" 'text': \"Made in 1946 and released in 1948, The Lady and Shanghai was one of the big films made by Welles after returning from relative exile for making Citizen Kane. Dark, brooding and expressing some early Cold War paranoia, this film stands tall as a Film-Noir crime film. The cinematography of this film is filled with Welles' characteristic quirks of odd angles, quick cuts, long pans and sinister lighting. The use of ambient street music is a precursor to the incredible long opening shot in Touch of Evil, and the mysterious Chinese characters and the sequences in Chinatown can only be considered as the inspiration, in many ways, to Roman Polanski's Chinatown. Unfortunately, it is Welles' obsession with technical filmmaking that hurts this film in its entirety. The plot of this story is often lost behind a sometimes incomprehensible clutter of film techniques.<br /><br />However, despite this criticism, the story combined with wonderful performances by Welles, Hayworth and especially Glenn Anders (Laughter) make this film a joy to watch. Orson Welles pulls off not only the Irish brogue, but the torn identities as the honest but dangerous sailor. Rita Hayworth, who was married to Welles at the time, breaks with her usual roles as a sex goddess and takes on a role of real depth and contradictions. Finally, Glenn Anders strange and bizarre portrayal or Elsa's husbands' law partner is nothing short of classic!\",\n",
" 'tokens': ['made',\n",
" 'in',\n",
" '1946',\n",
" 'and',\n",
" 'released',\n",
" 'in',\n",
" '1948',\n",
" ',',\n",
" 'the',\n",
" 'lady',\n",
" 'and',\n",
" 'shanghai',\n",
" 'was',\n",
" 'one',\n",
" 'of',\n",
" 'the',\n",
" 'big',\n",
" 'films',\n",
" 'made',\n",
" 'by',\n",
" 'welles',\n",
" 'after',\n",
" 'returning',\n",
" 'from',\n",
" 'relative',\n",
" 'exile',\n",
" 'for',\n",
" 'making',\n",
" 'citizen',\n",
" 'kane',\n",
" '.',\n",
" 'dark',\n",
" ',',\n",
" 'brooding',\n",
" 'and',\n",
" 'expressing',\n",
" 'some',\n",
" 'early',\n",
" 'cold',\n",
" 'war',\n",
" 'paranoia',\n",
" ',',\n",
" 'this',\n",
" 'film',\n",
" 'stands',\n",
" 'tall',\n",
" 'as',\n",
" 'a',\n",
" 'film-noir',\n",
" 'crime',\n",
" 'film',\n",
" '.',\n",
" 'the',\n",
" 'cinematography',\n",
" 'of',\n",
" 'this',\n",
" 'film',\n",
" 'is',\n",
" 'filled',\n",
" 'with',\n",
" 'welles',\n",
" \"'\",\n",
" 'characteristic',\n",
" 'quirks',\n",
" 'of',\n",
" 'odd',\n",
" 'angles',\n",
" ',',\n",
" 'quick',\n",
" 'cuts',\n",
" ',',\n",
" 'long',\n",
" 'pans',\n",
" 'and',\n",
" 'sinister',\n",
" 'lighting',\n",
" '.',\n",
" 'the',\n",
" 'use',\n",
" 'of',\n",
" 'ambient',\n",
" 'street',\n",
" 'music',\n",
" 'is',\n",
" 'a',\n",
" 'precursor',\n",
" 'to',\n",
" 'the',\n",
" 'incredible',\n",
" 'long',\n",
" 'opening',\n",
" 'shot',\n",
" 'in',\n",
" 'touch',\n",
" 'of',\n",
" 'evil',\n",
" ',',\n",
" 'and',\n",
" 'the',\n",
" 'mysterious',\n",
" 'chinese',\n",
" 'characters',\n",
" 'and',\n",
" 'the',\n",
" 'sequences',\n",
" 'in',\n",
" 'chinatown',\n",
" 'can',\n",
" 'only',\n",
" 'be',\n",
" 'considered',\n",
" 'as',\n",
" 'the',\n",
" 'inspiration',\n",
" ',',\n",
" 'in',\n",
" 'many',\n",
" 'ways',\n",
" ',',\n",
" 'to',\n",
" 'roman',\n",
" 'polanski',\n",
" \"'\",\n",
" 's',\n",
" 'chinatown',\n",
" '.',\n",
" 'unfortunately',\n",
" ',',\n",
" 'it',\n",
" 'is',\n",
" 'welles',\n",
" \"'\",\n",
" 'obsession',\n",
" 'with',\n",
" 'technical',\n",
" 'filmmaking',\n",
" 'that',\n",
" 'hurts',\n",
" 'this',\n",
" 'film',\n",
" 'in',\n",
" 'its',\n",
" 'entirety',\n",
" '.',\n",
" 'the',\n",
" 'plot',\n",
" 'of',\n",
" 'this',\n",
" 'story',\n",
" 'is',\n",
" 'often',\n",
" 'lost',\n",
" 'behind',\n",
" 'a',\n",
" 'sometimes',\n",
" 'incomprehensible',\n",
" 'clutter',\n",
" 'of',\n",
" 'film',\n",
" 'techniques',\n",
" '.',\n",
" 'however',\n",
" ',',\n",
" 'despite',\n",
" 'this',\n",
" 'criticism',\n",
" ',',\n",
" 'the',\n",
" 'story',\n",
" 'combined',\n",
" 'with',\n",
" 'wonderful',\n",
" 'performances',\n",
" 'by',\n",
" 'welles',\n",
" ',',\n",
" 'hayworth',\n",
" 'and',\n",
" 'especially',\n",
" 'glenn',\n",
" 'anders',\n",
" '(',\n",
" 'laughter',\n",
" ')',\n",
" 'make',\n",
" 'this',\n",
" 'film',\n",
" 'a',\n",
" 'joy',\n",
" 'to',\n",
" 'watch',\n",
" '.',\n",
" 'orson',\n",
" 'welles',\n",
" 'pulls',\n",
" 'off',\n",
" 'not',\n",
" 'only',\n",
" 'the',\n",
" 'irish',\n",
" 'brogue',\n",
" ',',\n",
" 'but',\n",
" 'the',\n",
" 'torn',\n",
" 'identities',\n",
" 'as',\n",
" 'the',\n",
" 'honest',\n",
" 'but',\n",
" 'dangerous',\n",
" 'sailor',\n",
" '.',\n",
" 'rita',\n",
" 'hayworth',\n",
" ',',\n",
" 'who',\n",
" 'was',\n",
" 'married',\n",
" 'to',\n",
" 'welles',\n",
" 'at',\n",
" 'the',\n",
" 'time',\n",
" ',',\n",
" 'breaks',\n",
" 'with',\n",
" 'her',\n",
" 'usual',\n",
" 'roles',\n",
" 'as',\n",
" 'a',\n",
" 'sex',\n",
" 'goddess',\n",
" 'and',\n",
" 'takes',\n",
" 'on',\n",
" 'a',\n",
" 'role',\n",
" 'of',\n",
" 'real',\n",
" 'depth',\n",
" 'and',\n",
" 'contradictions',\n",
" '.',\n",
" 'finally',\n",
" ',',\n",
" 'glenn',\n",
" 'anders',\n",
" 'strange',\n",
" 'and',\n",
" 'bizarre',\n",
" 'portrayal',\n",
" 'or',\n",
" 'elsa',\n",
" \"'\"]}"
2021-07-08 01:04:25 +08:00
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "c227e4fc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(18750, 6250, 25000)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_data), len(valid_data), len(test_data)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "4865e94a",
"metadata": {},
"outputs": [],
"source": [
"min_freq = 5\n",
"special_tokens = ['<unk>', '<pad>']\n",
"\n",
"vocab = torchtext.vocab.build_vocab_from_iterator(train_data['tokens'],\n",
" min_freq=min_freq,\n",
" specials=special_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "123ceb33",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"21543"
2021-07-08 01:04:25 +08:00
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "d4ec89de",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['<unk>', '<pad>', 'the', '.', ',', 'a', 'and', 'of', 'to', \"'\"]"
2021-07-08 01:04:25 +08:00
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab.get_itos()[:10]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "29ac49c8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unk_index = vocab['<unk>']\n",
"\n",
"unk_index"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "447020e1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pad_index = vocab['<pad>']\n",
"\n",
"pad_index"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "201b5383",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'some_token' in vocab"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "7a951ea0",
"metadata": {},
"outputs": [],
"source": [
"vocab.set_default_index(unk_index)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "407fe05d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab['some_token']"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "76518d11",
"metadata": {},
"outputs": [],
"source": [
"def numericalize_data(example, vocab):\n",
"    \"\"\"Look up vocab[token] for every token in example['tokens'] and return the indices as 'ids'.\"\"\"\n",
"    return {'ids': [vocab[token_str] for token_str in example['tokens']]}"
2021-07-08 01:04:25 +08:00
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "dacaeaef",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
2021-07-08 22:11:54 +08:00
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-4fa96f7122a515e2.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-cabd43c688223ded.arrow\n",
"Loading cached processed dataset at /home/ben/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-087b09fd94e05553.arrow\n"
]
2021-07-08 01:04:25 +08:00
}
],
"source": [
"train_data = train_data.map(numericalize_data, fn_kwargs={'vocab': vocab})\n",
"valid_data = valid_data.map(numericalize_data, fn_kwargs={'vocab': vocab})\n",
"test_data = test_data.map(numericalize_data, fn_kwargs={'vocab': vocab})"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "08751c45",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': 1,\n",
" 'text': \"Made in 1946 and released in 1948, The Lady and Shanghai was one of the big films made by Welles after returning from relative exile for making Citizen Kane. Dark, brooding and expressing some early Cold War paranoia, this film stands tall as a Film-Noir crime film. The cinematography of this film is filled with Welles' characteristic quirks of odd angles, quick cuts, long pans and sinister lighting. The use of ambient street music is a precursor to the incredible long opening shot in Touch of Evil, and the mysterious Chinese characters and the sequences in Chinatown can only be considered as the inspiration, in many ways, to Roman Polanski's Chinatown. Unfortunately, it is Welles' obsession with technical filmmaking that hurts this film in its entirety. The plot of this story is often lost behind a sometimes incomprehensible clutter of film techniques.<br /><br />However, despite this criticism, the story combined with wonderful performances by Welles, Hayworth and especially Glenn Anders (Laughter) make this film a joy to watch. Orson Welles pulls off not only the Irish brogue, but the torn identities as the honest but dangerous sailor. Rita Hayworth, who was married to Welles at the time, breaks with her usual roles as a sex goddess and takes on a role of real depth and contradictions. Finally, Glenn Anders strange and bizarre portrayal or Elsa's husbands' law partner is nothing short of classic!\",\n",
" 'tokens': ['made',\n",
" 'in',\n",
" '1946',\n",
" 'and',\n",
" 'released',\n",
" 'in',\n",
" '1948',\n",
" ',',\n",
" 'the',\n",
" 'lady',\n",
" 'and',\n",
" 'shanghai',\n",
" 'was',\n",
" 'one',\n",
" 'of',\n",
" 'the',\n",
" 'big',\n",
" 'films',\n",
" 'made',\n",
" 'by',\n",
" 'welles',\n",
" 'after',\n",
" 'returning',\n",
" 'from',\n",
" 'relative',\n",
" 'exile',\n",
" 'for',\n",
" 'making',\n",
" 'citizen',\n",
" 'kane',\n",
" '.',\n",
" 'dark',\n",
" ',',\n",
" 'brooding',\n",
" 'and',\n",
" 'expressing',\n",
" 'some',\n",
" 'early',\n",
" 'cold',\n",
" 'war',\n",
" 'paranoia',\n",
" ',',\n",
" 'this',\n",
" 'film',\n",
" 'stands',\n",
" 'tall',\n",
" 'as',\n",
" 'a',\n",
" 'film-noir',\n",
" 'crime',\n",
" 'film',\n",
" '.',\n",
" 'the',\n",
" 'cinematography',\n",
" 'of',\n",
" 'this',\n",
" 'film',\n",
" 'is',\n",
" 'filled',\n",
" 'with',\n",
" 'welles',\n",
" \"'\",\n",
" 'characteristic',\n",
" 'quirks',\n",
" 'of',\n",
" 'odd',\n",
" 'angles',\n",
" ',',\n",
" 'quick',\n",
" 'cuts',\n",
" ',',\n",
" 'long',\n",
" 'pans',\n",
" 'and',\n",
" 'sinister',\n",
" 'lighting',\n",
" '.',\n",
" 'the',\n",
" 'use',\n",
" 'of',\n",
" 'ambient',\n",
" 'street',\n",
" 'music',\n",
" 'is',\n",
" 'a',\n",
" 'precursor',\n",
" 'to',\n",
" 'the',\n",
" 'incredible',\n",
" 'long',\n",
" 'opening',\n",
" 'shot',\n",
" 'in',\n",
" 'touch',\n",
" 'of',\n",
" 'evil',\n",
" ',',\n",
" 'and',\n",
" 'the',\n",
" 'mysterious',\n",
" 'chinese',\n",
" 'characters',\n",
" 'and',\n",
" 'the',\n",
" 'sequences',\n",
" 'in',\n",
" 'chinatown',\n",
" 'can',\n",
" 'only',\n",
" 'be',\n",
" 'considered',\n",
" 'as',\n",
" 'the',\n",
" 'inspiration',\n",
" ',',\n",
" 'in',\n",
" 'many',\n",
" 'ways',\n",
" ',',\n",
" 'to',\n",
" 'roman',\n",
" 'polanski',\n",
" \"'\",\n",
" 's',\n",
" 'chinatown',\n",
" '.',\n",
" 'unfortunately',\n",
" ',',\n",
" 'it',\n",
" 'is',\n",
" 'welles',\n",
" \"'\",\n",
" 'obsession',\n",
" 'with',\n",
" 'technical',\n",
" 'filmmaking',\n",
" 'that',\n",
" 'hurts',\n",
" 'this',\n",
" 'film',\n",
" 'in',\n",
" 'its',\n",
" 'entirety',\n",
" '.',\n",
" 'the',\n",
" 'plot',\n",
" 'of',\n",
" 'this',\n",
" 'story',\n",
" 'is',\n",
" 'often',\n",
" 'lost',\n",
" 'behind',\n",
" 'a',\n",
" 'sometimes',\n",
" 'incomprehensible',\n",
" 'clutter',\n",
" 'of',\n",
" 'film',\n",
" 'techniques',\n",
" '.',\n",
" 'however',\n",
" ',',\n",
" 'despite',\n",
" 'this',\n",
" 'criticism',\n",
" ',',\n",
" 'the',\n",
" 'story',\n",
" 'combined',\n",
" 'with',\n",
" 'wonderful',\n",
" 'performances',\n",
" 'by',\n",
" 'welles',\n",
" ',',\n",
" 'hayworth',\n",
" 'and',\n",
" 'especially',\n",
" 'glenn',\n",
" 'anders',\n",
" '(',\n",
" 'laughter',\n",
" ')',\n",
" 'make',\n",
" 'this',\n",
" 'film',\n",
" 'a',\n",
" 'joy',\n",
" 'to',\n",
" 'watch',\n",
" '.',\n",
" 'orson',\n",
" 'welles',\n",
" 'pulls',\n",
" 'off',\n",
" 'not',\n",
" 'only',\n",
" 'the',\n",
" 'irish',\n",
" 'brogue',\n",
" ',',\n",
" 'but',\n",
" 'the',\n",
" 'torn',\n",
" 'identities',\n",
" 'as',\n",
" 'the',\n",
" 'honest',\n",
" 'but',\n",
" 'dangerous',\n",
" 'sailor',\n",
" '.',\n",
" 'rita',\n",
" 'hayworth',\n",
" ',',\n",
" 'who',\n",
" 'was',\n",
" 'married',\n",
" 'to',\n",
" 'welles',\n",
" 'at',\n",
" 'the',\n",
" 'time',\n",
" ',',\n",
" 'breaks',\n",
" 'with',\n",
" 'her',\n",
" 'usual',\n",
" 'roles',\n",
" 'as',\n",
" 'a',\n",
" 'sex',\n",
" 'goddess',\n",
" 'and',\n",
" 'takes',\n",
" 'on',\n",
" 'a',\n",
" 'role',\n",
" 'of',\n",
" 'real',\n",
" 'depth',\n",
" 'and',\n",
" 'contradictions',\n",
" '.',\n",
" 'finally',\n",
" ',',\n",
" 'glenn',\n",
" 'anders',\n",
" 'strange',\n",
" 'and',\n",
" 'bizarre',\n",
" 'portrayal',\n",
" 'or',\n",
" 'elsa',\n",
" \"'\"],\n",
" 'ids': [98,\n",
" 13,\n",
" 6329,\n",
" 6,\n",
" 559,\n",
" 13,\n",
" 6491,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 2,\n",
" 763,\n",
" 6,\n",
" 6300,\n",
2021-07-08 01:04:25 +08:00
" 17,\n",
" 34,\n",
2021-07-08 01:04:25 +08:00
" 7,\n",
" 2,\n",
" 195,\n",
" 116,\n",
" 98,\n",
2021-07-08 01:04:25 +08:00
" 40,\n",
" 2302,\n",
" 102,\n",
" 3497,\n",
" 44,\n",
" 3318,\n",
" 15422,\n",
" 21,\n",
" 261,\n",
" 3609,\n",
" 3433,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 474,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 6093,\n",
" 6,\n",
" 10888,\n",
" 54,\n",
" 396,\n",
" 1198,\n",
" 338,\n",
" 4479,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 14,\n",
" 23,\n",
" 1481,\n",
" 3596,\n",
" 19,\n",
" 5,\n",
" 13453,\n",
" 850,\n",
2021-07-08 01:04:25 +08:00
" 23,\n",
" 3,\n",
" 2,\n",
" 639,\n",
2021-07-08 01:04:25 +08:00
" 7,\n",
" 14,\n",
" 23,\n",
" 10,\n",
" 1073,\n",
" 20,\n",
" 2302,\n",
2021-07-08 01:04:25 +08:00
" 9,\n",
" 7180,\n",
" 9372,\n",
2021-07-08 01:04:25 +08:00
" 7,\n",
" 1045,\n",
" 2522,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 1706,\n",
" 2115,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 212,\n",
" 8127,\n",
" 6,\n",
" 3179,\n",
" 1485,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 2,\n",
" 386,\n",
2021-07-08 01:04:25 +08:00
" 7,\n",
" 13210,\n",
" 860,\n",
" 233,\n",
2021-07-08 01:04:25 +08:00
" 10,\n",
" 5,\n",
" 12948,\n",
2021-07-08 01:04:25 +08:00
" 8,\n",
" 2,\n",
" 984,\n",
" 212,\n",
" 628,\n",
" 346,\n",
" 13,\n",
" 1228,\n",
2021-07-08 01:04:25 +08:00
" 7,\n",
" 462,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 6,\n",
2021-07-08 01:04:25 +08:00
" 2,\n",
" 1236,\n",
" 1675,\n",
" 114,\n",
" 6,\n",
2021-07-08 01:04:25 +08:00
" 2,\n",
" 905,\n",
" 13,\n",
" 10802,\n",
" 59,\n",
" 71,\n",
" 35,\n",
" 1132,\n",
" 19,\n",
2021-07-08 01:04:25 +08:00
" 2,\n",
" 3009,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 13,\n",
2021-07-08 01:04:25 +08:00
" 117,\n",
" 771,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 8,\n",
" 3582,\n",
" 3534,\n",
2021-07-08 01:04:25 +08:00
" 9,\n",
" 16,\n",
" 10802,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 446,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 11,\n",
" 10,\n",
" 2302,\n",
2021-07-08 01:04:25 +08:00
" 9,\n",
" 3013,\n",
" 20,\n",
" 1810,\n",
" 6389,\n",
2021-07-08 01:04:25 +08:00
" 15,\n",
" 4846,\n",
2021-07-08 01:04:25 +08:00
" 14,\n",
" 23,\n",
" 13,\n",
" 100,\n",
" 6865,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 2,\n",
" 113,\n",
2021-07-08 01:04:25 +08:00
" 7,\n",
" 14,\n",
" 64,\n",
2021-07-08 01:04:25 +08:00
" 10,\n",
" 406,\n",
" 443,\n",
" 527,\n",
" 5,\n",
" 525,\n",
" 4470,\n",
" 10812,\n",
2021-07-08 01:04:25 +08:00
" 7,\n",
" 23,\n",
" 3324,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 190,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 500,\n",
2021-07-08 01:04:25 +08:00
" 14,\n",
" 3049,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 2,\n",
" 64,\n",
" 2675,\n",
" 20,\n",
" 356,\n",
2021-07-08 01:04:25 +08:00
" 389,\n",
" 40,\n",
" 2302,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 7843,\n",
" 6,\n",
" 262,\n",
" 3111,\n",
" 14039,\n",
2021-07-08 01:04:25 +08:00
" 25,\n",
" 2146,\n",
2021-07-08 01:04:25 +08:00
" 24,\n",
" 106,\n",
2021-07-08 01:04:25 +08:00
" 14,\n",
" 23,\n",
" 5,\n",
" 1777,\n",
2021-07-08 01:04:25 +08:00
" 8,\n",
" 108,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 4281,\n",
" 2302,\n",
" 2890,\n",
" 137,\n",
2021-07-08 01:04:25 +08:00
" 29,\n",
" 71,\n",
2021-07-08 01:04:25 +08:00
" 2,\n",
" 2386,\n",
2021-07-08 01:04:25 +08:00
" 0,\n",
" 4,\n",
" 22,\n",
" 2,\n",
" 3544,\n",
" 7847,\n",
" 19,\n",
2021-07-08 01:04:25 +08:00
" 2,\n",
" 1172,\n",
2021-07-08 01:04:25 +08:00
" 22,\n",
" 1813,\n",
" 7915,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 6041,\n",
" 7843,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 42,\n",
" 17,\n",
" 922,\n",
2021-07-08 01:04:25 +08:00
" 8,\n",
" 2302,\n",
" 38,\n",
2021-07-08 01:04:25 +08:00
" 2,\n",
" 65,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 2100,\n",
" 20,\n",
" 50,\n",
" 604,\n",
" 556,\n",
2021-07-08 01:04:25 +08:00
" 19,\n",
" 5,\n",
" 416,\n",
" 11476,\n",
2021-07-08 01:04:25 +08:00
" 6,\n",
" 310,\n",
" 27,\n",
2021-07-08 01:04:25 +08:00
" 5,\n",
" 221,\n",
" 7,\n",
" 158,\n",
" 1248,\n",
" 6,\n",
" 16505,\n",
2021-07-08 01:04:25 +08:00
" 3,\n",
" 454,\n",
2021-07-08 01:04:25 +08:00
" 4,\n",
" 3111,\n",
" 14039,\n",
" 637,\n",
" 6,\n",
" 1079,\n",
" 1074,\n",
" 49,\n",
" 8928,\n",
" 9]}"
2021-07-08 01:04:25 +08:00
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "678d0397",
"metadata": {},
"outputs": [],
"source": [
"train_data = train_data.with_format(type='torch', columns=['ids', 'label'])\n",
"valid_data = valid_data.with_format(type='torch', columns=['ids', 'label'])\n",
"test_data = test_data.with_format(type='torch', columns=['ids', 'label'])"
]
},
{
"cell_type": "markdown",
"id": "00a00726",
"metadata": {},
"source": [
"`with_format` does the same thing as `set_format`, but returns a new dataset instead of modifying the existing one in place."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "be56bf90",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'label': tensor(1),\n",
" 'ids': tensor([ 98, 13, 6329, 6, 559, 13, 6491, 4, 2, 763,\n",
" 6, 6300, 17, 34, 7, 2, 195, 116, 98, 40,\n",
" 2302, 102, 3497, 44, 3318, 15422, 21, 261, 3609, 3433,\n",
" 3, 474, 4, 6093, 6, 10888, 54, 396, 1198, 338,\n",
" 4479, 4, 14, 23, 1481, 3596, 19, 5, 13453, 850,\n",
" 23, 3, 2, 639, 7, 14, 23, 10, 1073, 20,\n",
" 2302, 9, 7180, 9372, 7, 1045, 2522, 4, 1706, 2115,\n",
" 4, 212, 8127, 6, 3179, 1485, 3, 2, 386, 7,\n",
" 13210, 860, 233, 10, 5, 12948, 8, 2, 984, 212,\n",
" 628, 346, 13, 1228, 7, 462, 4, 6, 2, 1236,\n",
" 1675, 114, 6, 2, 905, 13, 10802, 59, 71, 35,\n",
" 1132, 19, 2, 3009, 4, 13, 117, 771, 4, 8,\n",
" 3582, 3534, 9, 16, 10802, 3, 446, 4, 11, 10,\n",
" 2302, 9, 3013, 20, 1810, 6389, 15, 4846, 14, 23,\n",
" 13, 100, 6865, 3, 2, 113, 7, 14, 64, 10,\n",
" 406, 443, 527, 5, 525, 4470, 10812, 7, 23, 3324,\n",
" 3, 190, 4, 500, 14, 3049, 4, 2, 64, 2675,\n",
" 20, 356, 389, 40, 2302, 4, 7843, 6, 262, 3111,\n",
" 14039, 25, 2146, 24, 106, 14, 23, 5, 1777, 8,\n",
" 108, 3, 4281, 2302, 2890, 137, 29, 71, 2, 2386,\n",
" 0, 4, 22, 2, 3544, 7847, 19, 2, 1172, 22,\n",
" 1813, 7915, 3, 6041, 7843, 4, 42, 17, 922, 8,\n",
" 2302, 38, 2, 65, 4, 2100, 20, 50, 604, 556,\n",
" 19, 5, 416, 11476, 6, 310, 27, 5, 221, 7,\n",
" 158, 1248, 6, 16505, 3, 454, 4, 3111, 14039, 637,\n",
" 6, 1079, 1074, 49, 8928, 9])}"
2021-07-08 01:04:25 +08:00
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data[0]"
]
},
{
"cell_type": "markdown",
"id": "d6ba2ac8",
"metadata": {},
"source": [
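"By default only the requested columns are returned; pass `output_all_columns=True` to also keep the columns that are not converted to tensors (they are returned as plain Python objects)."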
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "081f04a6",
"metadata": {},
"outputs": [],
"source": [
"class NBoW(nn.Module):\n",
"    \"\"\"Neural bag-of-words classifier: embed each token, average the embeddings of the\n",
"    non-padding tokens, then apply a linear layer.\"\"\"\n",
"\n",
"    def __init__(self, vocab_size, embedding_dim, output_dim, pad_index):\n",
"        super().__init__()\n",
"        # Kept so forward() can exclude padding positions from the average.\n",
"        self.pad_index = pad_index\n",
"        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)\n",
"        self.fc = nn.Linear(embedding_dim, output_dim)\n",
"\n",
"    def forward(self, ids):\n",
"        # ids = [batch size, seq len]\n",
"        embedded = self.embedding(ids)\n",
"        # embedded = [batch size, seq len, embedding dim]\n",
"        # Average over real tokens only: a plain mean over dim 1 would also count the\n",
"        # padding positions, making a review's representation depend on how much\n",
"        # padding its batch needed.\n",
"        mask = (ids != self.pad_index).unsqueeze(-1).to(embedded.dtype)\n",
"        # mask = [batch size, seq len, 1]\n",
"        lengths = mask.sum(dim=1).clamp(min=1)\n",
"        # lengths = [batch size, 1]; clamp avoids dividing by zero for all-padding rows\n",
"        pooled = (embedded * mask).sum(dim=1) / lengths\n",
"        # pooled = [batch size, embedding dim]\n",
"        prediction = self.fc(pooled)\n",
"        # prediction = [batch size, output dim]\n",
"        return prediction"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "97897898",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = len(vocab)\n",
"embedding_dim = 300\n",
"output_dim = len(train_data.unique('label'))\n",
"\n",
"model = NBoW(vocab_size, embedding_dim, output_dim, pad_index)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "4acc5118",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The model has 6,463,502 trainable parameters\n"
]
}
],
"source": [
"def count_parameters(model):\n",
"    \"\"\"Sum numel() over every parameter tensor of model that has requires_grad set.\"\"\"\n",
"    trainable_counts = [param.numel() for param in model.parameters() if param.requires_grad]\n",
"    return sum(trainable_counts)\n",
"\n",
"print(f'The model has {count_parameters(model):,} trainable parameters')"
]
},
{
"cell_type": "code",
"execution_count": 30,
2021-07-08 01:04:25 +08:00
"id": "866e0b64",
"metadata": {},
"outputs": [],
"source": [
"vectors = torchtext.vocab.FastText()"
]
},
{
"cell_type": "code",
"execution_count": 31,
2021-07-08 01:04:25 +08:00
"id": "ead7be53",
"metadata": {},
"outputs": [],
"source": [
"hello_vector = vectors.get_vecs_by_tokens('hello')"
]
},
{
"cell_type": "code",
"execution_count": 32,
2021-07-08 01:04:25 +08:00
"id": "1a64ead7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([300])"
]
},
"execution_count": 32,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hello_vector.shape"
]
},
{
"cell_type": "code",
"execution_count": 33,
2021-07-08 01:04:25 +08:00
"id": "7ecc5d88",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-1.5945e-01, -1.8259e-01, 3.3443e-02, 1.8813e-01, -6.7903e-02,\n",
" -1.3663e-01, -2.5559e-01, 1.1000e-01, 1.7275e-01, 5.1971e-02,\n",
" -2.3302e-02, 3.8866e-02, -2.4515e-01, -2.1588e-01, 3.5925e-01,\n",
" -8.2526e-02, 1.2176e-01, -2.6775e-01, 1.0072e-01, -1.3639e-01,\n",
" -9.2658e-02, 5.1837e-01, 1.7736e-01, 9.4878e-02, -1.8461e-01,\n",
" -4.2829e-02, 1.4114e-02, 1.6811e-01, -1.8565e-01, 3.4976e-02,\n",
" -1.0293e-01, 1.7954e-01, -5.2766e-02, 7.2047e-02, -4.2704e-01,\n",
" -1.1616e-01, -9.4875e-03, 1.4199e-01, -2.2782e-01, -1.7292e-02,\n",
" 8.2802e-02, -4.4512e-01, -7.5935e-02, -1.4392e-01, -8.2461e-02,\n",
" 2.0123e-01, -9.5344e-02, -1.1042e-01, -4.6817e-01, 2.0362e-01,\n",
" -1.7140e-01, -4.9850e-01, 2.8963e-01, -1.0305e-01, 2.0393e-01,\n",
" 5.2971e-01, -2.5396e-01, -5.1891e-01, 2.9941e-01, 1.7933e-01,\n",
" 3.0683e-01, 2.5828e-01, -1.8168e-01, -1.0225e-01, -1.1435e-01,\n",
" -1.6304e-01, -1.2424e-01, 3.2814e-01, -2.3099e-01, 1.7912e-01,\n",
" 9.9206e-02, 1.8595e-01, 2.7996e-01, 1.8323e-01, -1.7397e-01,\n",
" 2.6633e-01, -1.8151e-02, 2.8386e-01, 1.7328e-01, 2.9131e-01,\n",
" 8.2289e-02, 1.8560e-01, -1.5544e-01, 2.3311e-01, 3.6578e-01,\n",
" -3.0802e-01, -1.5908e-01, 4.0382e-01, 1.5332e-01, -1.1630e-01,\n",
" 1.3978e-01, 6.4237e-02, 2.2087e-01, 8.2723e-02, 1.2785e-01,\n",
" -6.6854e-02, -2.3016e-02, -1.9224e-01, -5.4482e-02, 3.7509e-01,\n",
" 5.1194e-01, -2.3650e-01, -7.1224e-02, 8.1112e-02, -3.2017e-01,\n",
" 5.0264e-02, -3.3223e-01, 2.2167e-02, 9.9936e-02, -2.7215e-01,\n",
" -7.2833e-02, -3.6598e-01, 1.7541e-01, -3.1303e-01, -2.3134e-01,\n",
" -1.5491e-01, 3.2102e-01, 1.2347e-01, 7.3616e-02, 2.0575e-01,\n",
" 6.1732e-01, 7.1909e-02, -3.6930e-01, 4.7641e-01, 1.7456e-01,\n",
" 3.2928e-01, 2.8792e-01, -7.6989e-02, 2.7030e-01, 6.9828e-01,\n",
" 4.6247e-01, 4.1444e-01, -5.3405e-01, 4.4302e-01, 1.1631e-01,\n",
" -2.3425e-01, -1.5030e-01, -6.8092e-02, 3.3537e-01, 2.8618e-01,\n",
" -3.9781e-02, 2.3245e-01, 3.6262e-01, -1.7151e-01, -3.5204e-01,\n",
" 1.9951e-01, 1.1345e-01, -4.5134e-01, -3.9699e-03, -2.0620e-01,\n",
" -4.9251e-02, 1.0825e-01, 1.2571e-01, -2.8134e-01, 1.0355e-01,\n",
" 7.3498e-02, -2.6716e-01, -1.0001e-01, -2.2600e-01, 3.0784e-01,\n",
" 2.5934e-01, -1.8112e-03, -2.0522e-01, -2.5115e-01, -1.5368e-01,\n",
" 5.6060e-02, -6.4802e-02, 9.2786e-03, 2.6150e-01, -9.3972e-02,\n",
" -3.1032e-01, -2.6632e-01, -1.9598e-01, -4.5088e-02, -2.7611e-02,\n",
" -7.7027e-02, 1.5070e-01, 1.7185e-01, -8.5416e-02, -1.4448e-01,\n",
" -2.4800e-03, -3.2881e-01, -1.6913e-01, -1.2778e-01, -2.3352e-01,\n",
" 1.5178e-01, -6.9358e-01, -3.8922e-01, 3.7190e-01, 2.6020e-01,\n",
" -1.0232e-01, -6.0247e-01, -5.4548e-02, 6.6532e-01, -7.3208e-02,\n",
" -2.3644e-01, -2.5550e-01, 1.9755e-02, -4.8908e-01, -7.3706e-02,\n",
" 3.0545e-01, 2.4459e-01, 2.0426e-01, -3.0128e-01, 6.0666e-02,\n",
" 1.8107e-02, -9.6162e-02, -2.0348e-02, -1.9801e-04, 2.9652e-02,\n",
" 5.0787e-01, -2.0225e-01, -6.1565e-02, -2.7330e-01, -3.7789e-01,\n",
" -2.4373e-01, 9.4902e-02, -3.7236e-01, -8.5854e-02, 2.4096e-01,\n",
" -1.7998e-01, 7.3902e-02, -7.8217e-04, -1.8559e-01, -2.6445e-01,\n",
" -2.3306e-02, -1.8644e-01, -1.0638e-01, 8.9330e-02, 4.1039e-01,\n",
" 1.0452e-02, -9.8721e-03, -1.8335e-01, -2.8524e-01, -1.4771e-01,\n",
" -1.9499e-01, -1.0175e-01, 1.2292e-01, 8.3651e-02, -2.1228e-01,\n",
" 3.4773e-02, 6.1831e-02, 2.9237e-01, 1.4371e-01, -9.2354e-02,\n",
" 8.1267e-03, 2.7648e-01, 2.1753e-01, 2.6609e-01, -3.6083e-01,\n",
" 2.8347e-01, -2.9295e-01, -2.6441e-01, 2.1056e-01, 3.2068e-01,\n",
" -1.6156e-01, 1.5298e-01, -1.5577e-01, 2.2035e-01, -1.1888e-01,\n",
" 1.3766e-01, -9.9048e-02, 4.1584e-01, -3.6029e-02, -6.2504e-02,\n",
" 3.3177e-01, -1.3997e-01, 8.7884e-02, -2.1428e-01, -6.2643e-01,\n",
" -3.1293e-01, -3.4895e-01, 5.2294e-01, -1.2635e-01, -1.9371e-01,\n",
" -2.0631e-01, 5.3758e-01, -1.1522e-01, -2.3659e-01, 2.0457e-01,\n",
" 1.9534e-01, 3.3260e-01, -2.2254e-01, 8.1346e-02, -7.2798e-02,\n",
" -8.6357e-04, -1.0199e-01, 3.1601e-01, 2.0040e-01, 1.9014e-01,\n",
" -9.6766e-02, 2.5155e-01, -2.0484e-01, -4.5859e-01, 1.1687e-01,\n",
" -3.3574e-01, -3.3371e-01, 8.6787e-02, 2.4920e-01, 6.5367e-02])"
]
},
"execution_count": 33,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hello_vector"
]
},
{
"cell_type": "code",
"execution_count": 34,
2021-07-08 01:04:25 +08:00
"id": "e8540b4b",
"metadata": {},
"outputs": [],
"source": [
"pretrained_embedding = vectors.get_vecs_by_tokens(vocab.get_itos())"
]
},
{
"cell_type": "code",
"execution_count": 35,
2021-07-08 01:04:25 +08:00
"id": "9d31228e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([21543, 300])"
2021-07-08 01:04:25 +08:00
]
},
"execution_count": 35,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_embedding.shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
2021-07-08 01:04:25 +08:00
"id": "3a6f4173",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[-1.1258, -1.1524, -0.2506, ..., 0.8200, -0.6332, 1.2948],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.1483, 2.4187, 1.3279, ..., -1.0328, 1.1305, -0.5703],\n",
" ...,\n",
" [-0.9882, -0.5407, 1.2382, ..., 2.4935, 1.0714, -0.7917],\n",
" [-1.2230, 0.6308, 1.7523, ..., 0.9265, -0.1116, -0.3872],\n",
" [-1.6577, 0.1200, -0.0599, ..., -0.5380, 0.5277, -0.0379]],\n",
2021-07-08 01:04:25 +08:00
" requires_grad=True)"
]
},
"execution_count": 36,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.embedding.weight"
]
},
{
"cell_type": "code",
"execution_count": 37,
2021-07-08 01:04:25 +08:00
"id": "5c1cbd5c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0653, -0.0930, -0.0176, ..., 0.1664, -0.1308, 0.0354],\n",
" ...,\n",
" [-0.1329, 0.2494, -0.3875, ..., 0.3734, 0.4520, -0.2060],\n",
" [-0.6976, 0.2878, 0.0754, ..., 0.4601, -0.4200, -0.2361],\n",
2021-07-08 01:04:25 +08:00
" [ 0.1161, -0.0390, 0.1120, ..., 0.0925, -0.1058, 0.5641]])"
]
},
"execution_count": 37,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pretrained_embedding"
]
},
{
"cell_type": "code",
"execution_count": 38,
2021-07-08 01:04:25 +08:00
"id": "6ea34c9b",
"metadata": {},
"outputs": [],
"source": [
"# Copy the pretrained vectors into the existing parameter in place. Unlike rebinding `.data`,\n",
"# copy_ raises on an incompatible shape and keeps the parameter's own dtype and device.\n",
"with torch.no_grad():\n",
"    model.embedding.weight.copy_(pretrained_embedding)"
]
},
{
"cell_type": "code",
"execution_count": 39,
2021-07-08 01:04:25 +08:00
"id": "1332d9a6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Parameter containing:\n",
"tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n",
" [-0.0653, -0.0930, -0.0176, ..., 0.1664, -0.1308, 0.0354],\n",
" ...,\n",
" [-0.1329, 0.2494, -0.3875, ..., 0.3734, 0.4520, -0.2060],\n",
" [-0.6976, 0.2878, 0.0754, ..., 0.4601, -0.4200, -0.2361],\n",
2021-07-08 01:04:25 +08:00
" [ 0.1161, -0.0390, 0.1120, ..., 0.0925, -0.1058, 0.5641]],\n",
" requires_grad=True)"
]
},
"execution_count": 39,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.embedding.weight"
]
},
{
"cell_type": "code",
"execution_count": 40,
2021-07-08 01:04:25 +08:00
"id": "4fcb95e0",
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters())"
]
},
{
"cell_type": "code",
"execution_count": 41,
2021-07-08 01:04:25 +08:00
"id": "f8829cd4",
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 42,
2021-07-08 01:04:25 +08:00
"id": "7ed273e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 42,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
"device_name = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"device = torch.device(device_name)\n",
"\n",
"device"
]
},
{
"cell_type": "code",
"execution_count": 43,
2021-07-08 01:04:25 +08:00
"id": "3cdaf3b3",
"metadata": {},
"outputs": [],
"source": [
"model = model.to(device)\n",
"criterion = criterion.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 44,
2021-07-08 01:04:25 +08:00
"id": "c721ad5d",
"metadata": {},
"outputs": [],
"source": [
"def collate(batch, pad_index):\n",
"    \"\"\"Merge a list of examples into one batch dict.\n",
"\n",
"    Each example is indexed by 'ids' and 'label'. The 'ids' tensors are padded with\n",
"    pad_index to the longest one in the batch (batch dimension first); the 'label'\n",
"    tensors are stacked along a new first dimension.\n",
"    \"\"\"\n",
"    id_seqs = [example['ids'] for example in batch]\n",
"    labels = [example['label'] for example in batch]\n",
"    return {'ids': nn.utils.rnn.pad_sequence(id_seqs, padding_value=pad_index, batch_first=True),\n",
"            'label': torch.stack(labels)}"
]
},
{
"cell_type": "code",
"execution_count": 45,
2021-07-08 01:04:25 +08:00
"id": "adf5afb1",
"metadata": {},
"outputs": [],
"source": [
"batch_size = 512\n",
"\n",
"# Bind pad_index under a new name instead of rebinding `collate` itself, so `collate` keeps\n",
"# referring to the plain two-argument function and this cell stays idempotent on re-run.\n",
"collate_fn = functools.partial(collate, pad_index=pad_index)\n",
"\n",
"train_dataloader = torch.utils.data.DataLoader(train_data,\n",
"                                               batch_size=batch_size,\n",
"                                               collate_fn=collate_fn,\n",
"                                               shuffle=True)\n",
"\n",
"valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, collate_fn=collate_fn)\n",
"test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": 46,
2021-07-08 01:04:25 +08:00
"id": "729aa9c8",
"metadata": {},
"outputs": [],
"source": [
"def train(dataloader, model, criterion, optimizer, device):\n",
"    \"\"\"Run one epoch of training over dataloader, updating model in place.\n",
"\n",
"    Returns:\n",
"        (losses, accuracies): two lists of Python floats, one entry per batch.\n",
"    \"\"\"\n",
"    model.train()\n",
"    batch_losses, batch_accs = [], []\n",
"\n",
"    for batch in tqdm.tqdm(dataloader, desc='training...', file=sys.stdout):\n",
"        ids, label = batch['ids'].to(device), batch['label'].to(device)\n",
"        logits = model(ids)\n",
"        loss = criterion(logits, label)\n",
"        accuracy = get_accuracy(logits, label)\n",
"\n",
"        # Clear stale gradients, backpropagate, then apply the parameter update.\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        batch_losses.append(loss.item())\n",
"        batch_accs.append(accuracy.item())\n",
"\n",
"    return batch_losses, batch_accs"
2021-07-08 01:04:25 +08:00
]
},
{
"cell_type": "code",
"execution_count": 47,
2021-07-08 01:04:25 +08:00
"id": "e0a80c30",
"metadata": {},
"outputs": [],
"source": [
"def evaluate(dataloader, model, criterion, device):\n",
"    \"\"\"Score the model on every batch of dataloader without updating it.\n",
"\n",
"    Returns:\n",
"        (losses, accuracies): two lists of Python floats, one entry per batch. Averaging\n",
"        them weights a short final batch the same as a full one.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    batch_losses, batch_accs = [], []\n",
"\n",
"    with torch.no_grad():\n",
"        for batch in tqdm.tqdm(dataloader, desc='evaluating...', file=sys.stdout):\n",
"            ids, label = batch['ids'].to(device), batch['label'].to(device)\n",
"            logits = model(ids)\n",
"            batch_losses.append(criterion(logits, label).item())\n",
"            batch_accs.append(get_accuracy(logits, label).item())\n",
"\n",
"    return batch_losses, batch_accs"
2021-07-08 01:04:25 +08:00
]
},
{
"cell_type": "code",
"execution_count": 48,
2021-07-08 01:04:25 +08:00
"id": "703aa1e1",
"metadata": {},
"outputs": [],
"source": [
"def get_accuracy(prediction, label):\n",
"    \"\"\"Fraction of rows of the 2-D `prediction` whose arg-max equals `label`, as a 0-dim tensor.\"\"\"\n",
"    n_examples, _ = prediction.shape\n",
"    n_correct = (prediction.argmax(dim=-1) == label).sum()\n",
"    return n_correct / n_examples"
]
},
{
"cell_type": "code",
"execution_count": 49,
2021-07-08 01:04:25 +08:00
"id": "31343f1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.13it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.12it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 1\n",
"train_loss: 0.684, train_acc: 0.604\n",
"valid_loss: 0.671, valid_acc: 0.682\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.06it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.24it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 2\n",
"train_loss: 0.648, train_acc: 0.718\n",
"valid_loss: 0.627, valid_acc: 0.729\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.65it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 15.88it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 3\n",
"train_loss: 0.588, train_acc: 0.764\n",
"valid_loss: 0.567, valid_acc: 0.769\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.81it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 15.66it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 4\n",
"train_loss: 0.516, train_acc: 0.807\n",
"valid_loss: 0.504, valid_acc: 0.803\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.80it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 15.93it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 5\n",
"train_loss: 0.446, train_acc: 0.847\n",
"valid_loss: 0.450, valid_acc: 0.833\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.83it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.03it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 6\n",
"train_loss: 0.388, train_acc: 0.870\n",
"valid_loss: 0.411, valid_acc: 0.844\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.40it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.37it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 7\n",
"train_loss: 0.343, train_acc: 0.886\n",
"valid_loss: 0.384, valid_acc: 0.852\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 15.13it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.03it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 8\n",
"train_loss: 0.308, train_acc: 0.899\n",
"valid_loss: 0.364, valid_acc: 0.857\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.99it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.12it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 9\n",
"train_loss: 0.280, train_acc: 0.909\n",
"valid_loss: 0.349, valid_acc: 0.862\n",
2021-07-09 02:05:18 +08:00
"training...: 100%|██████████| 37/37 [00:02<00:00, 14.62it/s]\n",
"evaluating...: 100%|██████████| 13/13 [00:00<00:00, 16.37it/s]\n",
2021-07-08 01:04:25 +08:00
"epoch: 10\n",
"train_loss: 0.257, train_acc: 0.917\n",
"valid_loss: 0.336, valid_acc: 0.867\n"
2021-07-08 01:04:25 +08:00
]
}
],
"source": [
"n_epochs = 10\n",
"best_valid_loss = float('inf')\n",
"\n",
"# Per-batch histories across all epochs, used by the plots below.\n",
"train_losses, train_accs = [], []\n",
"valid_losses, valid_accs = [], []\n",
"\n",
"for epoch in range(n_epochs):\n",
"\n",
"    train_loss, train_acc = train(train_dataloader, model, criterion, optimizer, device)\n",
"    valid_loss, valid_acc = evaluate(valid_dataloader, model, criterion, device)\n",
"\n",
"    for history, values in ((train_losses, train_loss), (train_accs, train_acc),\n",
"                            (valid_losses, valid_loss), (valid_accs, valid_acc)):\n",
"        history.extend(values)\n",
"\n",
"    epoch_train_loss, epoch_train_acc = np.mean(train_loss), np.mean(train_acc)\n",
"    epoch_valid_loss, epoch_valid_acc = np.mean(valid_loss), np.mean(valid_acc)\n",
"\n",
"    # Checkpoint whenever validation loss improves; the test cell reloads 'nbow.pt'.\n",
"    if epoch_valid_loss < best_valid_loss:\n",
"        best_valid_loss = epoch_valid_loss\n",
"        torch.save(model.state_dict(), 'nbow.pt')\n",
"\n",
"    print(f'epoch: {epoch+1}')\n",
"    print(f'train_loss: {epoch_train_loss:.3f}, train_acc: {epoch_train_acc:.3f}')\n",
"    print(f'valid_loss: {epoch_valid_loss:.3f}, valid_acc: {epoch_valid_acc:.3f}')"
2021-07-08 01:04:25 +08:00
]
},
{
"cell_type": "code",
"execution_count": 50,
2021-07-08 22:11:54 +08:00
"id": "2d791c70",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFzCAYAAAB2A95GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAB+zElEQVR4nO3dd3xb1dnA8d/RlmVZ3juOs/cgi0DYe5TRAgUK3S0d9O2gAzrelq63dL59uwtdtKVQRqHsUSDsQBLI3ju2472HbI3z/nHvleUZJ7Esj+f7+eQj6d4r6Vgo8cNznvMcpbVGCCGEEEKMLFuyByCEEEIIMRFJECaEEEIIkQQShAkhhBBCJIEEYUIIIYQQSSBBmBBCCCFEEkgQJoQQQgiRBI5kD+BYZWdn69LS0mQPQwghhBDiqNavX1+rtc7p79yYC8JKS0tZt25dsochhBBCCHFUSqmDA52T6UghhBBCiCSQIEwIIYQQIgkkCBNCCCGESIIxVxMmhBBCiOETCoUoKysjGAwmeyhjmsfjobi4GKfTOeTnSBAmhBBCTGBlZWX4/X5KS0tRSiV7OGOS1pq6ujrKysqYMmXKkJ8n05FCCCHEBBYMBsnKypIA7AQopcjKyjrmbKIEYUIIIcQEJwHYiTuezzChQZhS6iKl1E6l1B6l1G39nP9fpdQG888upVRjIscjhBBCiNGlsbGR3/zmN8f13EsuuYTGxsYhX3/77bfzk5/85LjeKxESFoQppezAr4GLgbnA9UqpufHXaK2/oLVerLVeDPwS+FeixiOEEEKI0WewICwcDg/63CeffJL09PQEjGpkJDITtgLYo7Xep7XuAu4Drhjk+uuBexM4HiGEEEKMMrfddht79+5l8eLFfPnLX2b16tWcfvrpXH755cyda+RurrzySpYuXcq8efO48847Y88tLS2ltraWAwcOMGfOHD7+8Y8zb948LrjgAjo6OgZ93w0bNrBy5UoWLlzIu9/9bhoaGgD4xS9+wdy5c1m4cCHXXXcdAC+99BKLFy9m8eLFnHTSSbS0tAzLz57I1ZFFwOG4x2XAyf1dqJSaDEwBXkjgeIQQQggxiG8/tpVtFc3D+ppzC9P41mXzBjx/xx13sGXLFjZs2ADA6tWrefvtt9myZUtspeGf/vQnMjMz6ejoYPny5Vx11VVkZWX1eJ3du3dz7733ctddd/He976Xhx56iBtvvHHA9/3ABz7AL3/5S84880y++c1v8u1vf5uf//zn3HHHHezfvx+32x2b6vzJT37Cr3/9a1atWkVraysej+fEPhTTaCnMvw54UGsd6e+kUuompdQ6pdS6mpqahA4kFIny6MYKtNYJfR8hhBBC9G/FihU9Wj384he/YNGiRaxcuZLDhw+ze/fuPs+ZMmUKixcvBmDp0qUcOHBgwNdvamqisbGRM888E4APfvCDvPzyywAsXLiQG264gb///e84HEauatWqVdxyyy384he/oLGxMXb8RCUyE1YOTIp7XGwe6891wM0DvZDW+k7gToBly5YlNDr694YKvvTARrTWXLG4KJFvJYQQQowqg2WsRpLP54vdX716Nf/5z3944403SElJ4ayzzuq3FYTb7Y7dt9vtR52OHMgTTzzByy+/zGOPPcb3v/99Nm/ezG233call17Kk08+yapVq3jmmWeYPXv2cb1+vERmwtYCM5RSU5RSLoxA69HeFymlZgMZwBsJHMuQvfukIhZNSuc7j21jT3ULDW1dkhUTQgghEsTv9w9aY9XU1ERGRgYpKSns2LGDNWvWnPB7BgIBMjIyeOWVVwD429/+xplnnkk0GuXw4cOcffbZ/PCHP6SpqYnW1lb27t3LggULuPXWW1m+fDk7duw44TFAAjNhWuuwUuozwDOAHfiT1nqrUuo7wDqttRWQXQfcp0dJpGO3Ke54zwIu++WrnPczIzUZ8Dr58oWzaO0M8/SWSr552VyWlGQkeaRCCCHE2JeVlc
WqVauYP38+F198MZdeemmP8xdddBG/+93vmDNnDrNmzWLlypXD8r533303n/zkJ2lvb2fq1Kn8+c9/JhKJcOONN9LU1ITWms9+9rOkp6fz3//937z44ovYbDbmzZvHxRdfPCxjUKMk9hmyZcuW6XXr1iX8fTaXNbG9spnmjhAv7Kjm9b11APhcdjrDUX5+3WIuXVDAlvJm6to6WTwpnfQUV8LHJYQQQgyn7du3M2fOnGQPY1zo77NUSq3XWi/r73rZO3IAC4oDLCgOAPDhVVP406v7CXidXDg/n4/8ZS23PriJR96p4D/bqwDI8bv5xqVzmJ2fxsy8VOk+LIQQQohBSRA2BHab4uNnTI09/tX7TuKS/3uF/2yv4gvnzWThpADfe3wbn7tvAwCfOGMqboeNB9aX8aULZvGeJUUSlAkhhBCiBwnCjkNBwMu9N62kqT3EyVONPiWnTM1ia0UTD6wr4/cv7zOv8/DFBzbyyIZyVk3P5p1DDbxrYSGXLCjAbpOgTAghhJjIJAg7TrPz03o89jjtLJ2cyUmTMkhPcRHwOrnpjKnc8+ZBfvjUDl7ZXUumz8UzW6v47uPbuGRBAfkBD/e+dYj5hQF+ef1J2CQwE0IIISYMCcKGmc2muO3i7t4hHzillEsXFNASDDMpM4X/bK/igXWHufetQ3SGo5RkpvDE5iMsKA7wyTOnJXHkQgghhBhJEoSNgKxUN1mpRhO5C+flc+G8fMKRKFUtnRQGPHzmH+/ww6d30N4VYXa+n8b2EA674lBdOxcvyGdeYSDJP4EQQgghhtto2bZownHYbRSle1FK8ZNrFvHuxUX84vndfPqet/naw5v5yoOb+NWLe/ji/RuJRsdWGxEhhBAikVJTUwGoqKjg6quv7veas846i/5aWg10PBkkEzYKeF12fvreRVx/cgkeh52sVBehSJT1Bxu45X6jsH9uYRp3vryPyZk+3n/KZH7/0l4uXlDA4knpyR6+EEIIkRSFhYU8+OCDyR7GcZNM2CihlGJ5aSYLigMUpnuZnOXjysVFzClI45b7N3LRz1/h8Y1H+N//7OLUO57n9y/v49rfv8H9aw8TjkSTPXwhhBDiuNx22238+te/jj2+/fbb+clPfkJrayvnnnsuS5YsYcGCBfz73//u89wDBw4wf/58ADo6OrjuuuuYM2cO7373u4e0d+S9997LggULmD9/PrfeeisAkUiED33oQ8yfP58FCxbwv//7v4CxifjcuXNZuHAh11133XD86JIJG3ZttbD2j1C8DKafe0IvZbMp/vfaRTy1uZJsv5tL5ufzyIYKHt1YwRfOm8GvXtjDVx7axC9f3M0XzpvJlYuLZIWlEEKI4/fUbVC5eXhfM38BXHzHgKevvfZaPv/5z3PzzTcDcP/99/PMM8/g8Xh4+OGHSUtLo7a2lpUrV3L55ZcP2Hfzt7/9LSkpKWzfvp1NmzaxZMmSQYdVUVHBrbfeyvr168nIyOCCCy7gkUceYdKkSZSXl7NlyxYAGhsbAbjjjjvYv38/brc7duxESSasP723ctIaHvs8PPsNiEaNxxvuhXV/6nntrmfh/xbB6v+BN38/LEOZnZ/GF86fyftXTiYr1c1HT5vCv29exVmzcrn/E6fw+/cvJeB1csv9G7n+rjVUNvXdWV4IIYQYrU466SSqq6upqKhg48aNZGRkMGnSJLTWfO1rX2PhwoWcd955lJeXU1VVNeDrvPzyy9x4440ALFy4kIULFw76vmvXruWss84iJycHh8PBDTfcwMsvv8zUqVPZt28f//Vf/8XTTz9NWlpa7DVvuOEG/v73v+NwDE8OSzJhvVVvh3/dBFf8GgrM/4Cb/gnr/2zcr90DXa1w4BXz8W644Huw9WF4+BOQNw/CnRBqT/hQbTbFhfPyOX9OHg+sP8ztj27j8l+9yoOfPJWSrJSEv78QQohxZpCMVSJdc801PPjgg1RWVnLttdcCcM8991BTU8P69etxOp2UlpYSDC
Y+0ZCRkcHGjRt55pln+N3vfsf999/Pn/70J5544glefvllHnvsMb7//e+zefPmEw7GJBPWW7AZ2mrgD+ca04o1O+H
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-07-08 22:14:05 +08:00
"# Both histories hold one value per batch, but the loaders yield different numbers of batches\n",
"# per epoch (37 train vs 13 valid in the run above). Plotting both against their own index\n",
"# squeezes the validation curve to the left, so spread its points over the training-update\n",
"# axis: the values for epoch e then land inside epoch e's span.\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"valid_x = np.arange(1, len(valid_losses) + 1) * len(train_losses) / max(len(valid_losses), 1)\n",
"ax.plot(train_losses, label='train loss')\n",
"ax.plot(valid_x, valid_losses, label='valid loss')\n",
"ax.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('loss');"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "bc422190",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAFzCAYAAAB2A95GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAACXYUlEQVR4nOzdd3xdd33/8de5e+pqD8uy5W3Hduw4dibZCUmghAQICTv5MVoKlJZRAm3TQGmhZZTSAm1KGWGFMEMgkABZZNtJ7MROvC3b2vPq7n1+f3zPOfde6UqWbMny+DwfDz8k3aWjGyd65/P9fD9fTdd1hBBCCCHE8WWb6wsQQgghhDgdSQgTQgghhJgDEsKEEEIIIeaAhDAhhBBCiDkgIUwIIYQQYg5ICBNCCCGEmAOOub6A6aqvr9fb29vn+jKEEEIIIY7oueeeG9R1vaHSfSddCGtvb2fLli1zfRlCCCGEEEekadrBie6T5UghhBBCiDkgIUwIIYQQYg5ICBNCCCGEmAMnXU9YJdlsls7OTlKp1FxfijgCj8fD/PnzcTqdc30pQgghxJw6JUJYZ2cnwWCQ9vZ2NE2b68sRE9B1naGhITo7O1m0aNFcX44QQggxp06J5chUKkVdXZ0EsBOcpmnU1dVJxVIIIYTgFAlhgASwk4T8cxJCCCGUUyaEzaVwOMzXv/71o3rua17zGsLh8MxekBBCCCFOeBLCZsBkISyXy0363Pvvv5/q6upZuKpjo+s6hUJhri9DCCGEOGVJCJsBt912G/v27WP9+vV8/OMf55FHHuGiiy7iuuuu44wzzgDg+uuv5+yzz2b16tXceeed1nPb29sZHByko6ODVatW8d73vpfVq1fz6le/mmQyOe573XfffZx77rmcddZZXHnllfT19QEQi8W49dZbWbt2LWeeeSY/+9nPAPjd737Hhg0bWLduHVdccQUAd9xxB1/84het11yzZg0dHR10dHSwYsUK3vnOd7JmzRoOHz7M+9//fjZu3Mjq1av5x3/8R+s5mzdv5oILLmDdunWcc845RKNRLr74YrZu3Wo95lWvehXbtm2buTdaCCGEOIWcErsjS336vh283B2Z0dc8Y14V//i61RPe//nPf57t27dbAeSRRx7h+eefZ/v27dYuwG9961vU1taSTCbZtGkTb3zjG6mrqyt7nT179vCjH/2I//3f/+XNb34zP/vZz3j7299e9phXvepVPP3002iaxje/+U3+7d/+jS996Uv80z/9E6FQiJdeegmAkZERBgYGeO9738tjjz3GokWLGB4ePuLPumfPHr773e9y3nnnAfDP//zP1NbWks/nueKKK3jxxRdZuXIlN910Ez/+8Y/ZtGkTkUgEr9fLu9/9br7zne/wla98hd27d5NKpVi3bt2U32chhBDidHLKhbATxTnnnFM2huGrX/0qv/jFLwA4fPgwe/bsGRfCFi1axPr16wE4++yz6ejoGPe6nZ2d3HTTTfT09JDJZKzv8Yc//IG7777belxNTQ333XcfF198sfWY2traI173woULrQAGcM8993DnnXeSy+Xo6enh5ZdfRtM0Wlpa2LRpEwBVVVUA3HjjjfzTP/0TX/jCF/jWt77FLbfccsTvJ4QQQsymRCbHUCxDW61vri9lnFMuhE1WsTqe/H6/9fkjjzzCH/7wB5566il8Ph+XXnppxTENbrfb+txut1dcjvzQhz7ERz7yEa677joeeeQR7rjjjmlfm8PhKOv3Kr2W0us+cOAAX/ziF9m8eTM1NTXccsstk46X8Pl8XHXVVdx7773cc889PPfcc9O+NiGEEGIm/c+j+/nuUx288A9XnXA79KUnbAYEg0Gi0eiE94+OjlJTU4PP52Pnzp08/fTTR/29RkdHaW1tBeC73/2udftVV13F1772NevrkZERzjvvPB577DEOHDgAYC1Htre38/zzzwPw/PPPW/ePFYlE8Pv9hEIh+vr6+O1vfw
vAihUr6OnpYfPmzQBEo1FrA8J73vMe/uqv/opNmzZRU1Nz1D+nEEIIMRO6w0nCiSyR5OQb5eaChLAZUFdXx4UXXsiaNWv4+Mc/Pu7+a665hlwux6pVq7jtttvKlvum64477uDGG2/k7LPPpr6+3rr97//+7xkZGWHNmjWsW7eOhx9+mIaGBu68807e8IY3sG7dOm666SYA3vjGNzI8PMzq1av5r//6L5YvX17xe61bt46zzjqLlStX8ta3vpULL7wQAJfLxY9//GM+9KEPsW7dOq666iqrQnb22WdTVVXFrbfeetQ/oxBCCDFTRhJZAPqjJ96gcE3X9bm+hmnZuHGjvmXLlrLbXnnlFVatWjVHVyRKdXd3c+mll7Jz505stsoZX/55CSGEOF7e9I0n2XJwhB+851wuXFp/5CfMME3TntN1fWOl+6QSJmbMXXfdxbnnnss///M/TxjAhBBCiONpJJEBTsxK2CnXmC/mzjvf+U7e+c53zvVlCCGEEJawsRzZF0nP8ZWMJ+UKIYQQQpySdF0nnDR6wiJp+qMphuOZOb6qIglhQgghhDglpHP5sq+j6Rz5gup974+meN9dz/G3P31xLi6tIglhQgghhDjpZXIFzv2XP/K1h/dat4XjWevzwyNJtneNsrN3Zk/VORYSwoQQQggxLaOJLNf+x5+O+ZjAjsE4r/73R+mPHHvT/EgiQziR5T/+uIfDwwnrNgC/y85LnWFyBZ2ucHJcxWyuSAibI4FAAFAjHd70pjdVfMyll17K2HEcQgghTn2Fgs4nf/4izx088pm/c+GlrlFe6YnwYmf4mF5ny8ERdvfF2FEhzL3cHeHN//0Ub/rGkzx3cKTsvif3DnLHr3aU3WY24GdyBT7/u51AMYQtawpirEqi63BoKHFM1z1TJITNsXnz5vHTn/50ri+jInMKvhBCiONr70CMHz17mJ893zXXl1LRgcEYAKPJ7KSP29Ub5a6nOia8vzusjuerND7i8b0DPNsxzJaDIzy0s494Osd//GEPmVyB3+3o5TtPdrB/IGY93ryWpio3z3Wo0GYGsxVNwbLX3j8YP8JPeHxICJsBt912W9mRQXfccQdf/OIXicViXHHFFWzYsIG1a9dy7733jntuR0cHa9asASCZTHLzzTezatUqbrjhhopnRwJ85jOfYdOmTaxZs4b3ve99mAN39+7dy5VXXsm6devYsGED+/btA+Bf//VfWbt2LevWreO2224Dyqtsg4ODtLe3A/Cd73yH6667jssvv5wrrrhi0p/hrrvu4swzz2TdunW84x3vIBqNsmjRIrJZ9Zc+EomUfS2EEGJqNneoCtj2rtE5vpLKDgyqStKRQtiPNx/m9nt3sLuv8tF+XSPq91yl8REjiSwOm0ZTlZv+SJpHdw/w73/YzfOHRhgydjg+tLPfenzYqHotbwoyFE+j67pVCVverELYmtYqQC2DnghOvTlhv70Nel+a2ddsXgvXfn7Cu2+66Sb++q//mg984AMA3HPPPTzwwAN4PB5+8YtfUFVVxeDgIOeddx7XXXfdhAeIfuMb38Dn8/HKK6/w4osvsmHDhoqP++AHP8jtt98OwDve8Q5+/etf87rXvY63ve1t3Hbbbdxwww2kUikKhQK//e1vuffee3nmmWfw+XzW+ZGTef7553nxxRepra0ll8tV/BlefvllPvvZz/Lkk09SX1/P8PAwwWCQSy+9lN/85jdcf/313H333bzhDW/A6XQe8XsKIYQo2mJUcnb2RMnmCzjtx6dmsm8gxqI6PzZb+e8pXdfZNxBjaaMKM1OthIWTKgTd/exhbn/dGePu7x6duBIWTmSo9rloqvLQH03TM6oeMxzPMFISwt5z0WLje6lrWdIQ4E97BhlNZq0ji5Y2qhagC5bU0zuaYm9/jHd961neeu4Crl7dfKS3ZdZIJWwGnHXWWfT399Pd3c22bduoqamhra0NXdf51Kc+xZlnnsmVV15JV1cXfX19E77OY4
89xtvf/nYAzjzzTM4888yKj3v44Yc599xzWbt2LQ899BA7duwgGo3S1dXFDTfcAIDH48Hn8/GHP/yBW2+9FZ/PB0B
"text/plain": [
"<Figure size 720x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2021-07-08 22:14:05 +08:00
"# Same alignment as the loss plot: the training and validation loaders have different batch\n",
"# counts per epoch, so map validation points onto the training-update axis before plotting.\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
"valid_x = np.arange(1, len(valid_accs) + 1) * len(train_accs) / max(len(valid_accs), 1)\n",
"ax.plot(train_accs, label='train accuracy')\n",
"ax.plot(valid_x, valid_accs, label='valid accuracy')\n",
"ax.legend()\n",
"ax.set_xlabel('updates')\n",
"ax.set_ylabel('accuracy');"
]
},
{
"cell_type": "code",
"execution_count": 52,
2021-07-08 01:04:25 +08:00
"id": "cac26e8e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2021-07-09 02:05:18 +08:00
"evaluating...: 100%|██████████| 49/49 [00:03<00:00, 15.38it/s]\n",
"test_loss: 0.353, test_acc: 0.857\n"
2021-07-08 01:04:25 +08:00
]
}
],
"source": [
"# Reload the best checkpoint; map_location puts its tensors on the active device even if it\n",
"# was saved from a different one (e.g. trained on GPU, evaluated on CPU).\n",
"model.load_state_dict(torch.load('nbow.pt', map_location=device))\n",
"\n",
"test_loss, test_acc = evaluate(test_dataloader, model, criterion, device)\n",
"\n",
"epoch_test_loss = np.mean(test_loss)\n",
"epoch_test_acc = np.mean(test_acc)\n",
"\n",
"print(f'test_loss: {epoch_test_loss:.3f}, test_acc: {epoch_test_acc:.3f}')"
2021-07-08 01:04:25 +08:00
]
},
{
"cell_type": "code",
2021-07-08 22:11:54 +08:00
"execution_count": 53,
2021-07-08 01:04:25 +08:00
"id": "b22e040a",
"metadata": {},
"outputs": [],
"source": [
"def predict_sentiment(text, model, tokenizer, vocab, device):\n",
"    \"\"\"Classify one raw string.\n",
"\n",
"    Returns:\n",
"        (predicted_class, predicted_probability): the arg-max class index and its softmax probability.\n",
"    \"\"\"\n",
"    tokens = tokenizer(text)\n",
"    ids = [vocab[t] for t in tokens]\n",
"    tensor = torch.LongTensor(ids).unsqueeze(dim=0).to(device)\n",
"    # tensor = [1, seq len]\n",
"    model.eval()\n",
"    # Inference only: don't build an autograd graph for this forward pass.\n",
"    with torch.no_grad():\n",
"        prediction = model(tensor).squeeze(dim=0)\n",
"    probability = torch.softmax(prediction, dim=-1)\n",
"    predicted_class = prediction.argmax(dim=-1).item()\n",
"    predicted_probability = probability[predicted_class].item()\n",
"    return predicted_class, predicted_probability"
]
},
{
"cell_type": "code",
2021-07-08 22:11:54 +08:00
"execution_count": 54,
2021-07-08 01:04:25 +08:00
"id": "9cfa14eb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0, 0.9999740123748779)"
2021-07-08 01:04:25 +08:00
]
},
2021-07-08 22:11:54 +08:00
"execution_count": 54,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
2021-07-08 22:11:54 +08:00
"execution_count": 55,
2021-07-08 01:04:25 +08:00
"id": "1da60d90",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.9999997615814209)"
2021-07-08 01:04:25 +08:00
]
},
2021-07-08 22:11:54 +08:00
"execution_count": 55,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
2021-07-08 22:11:54 +08:00
"execution_count": 56,
2021-07-08 01:04:25 +08:00
"id": "4bee6190",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.765370786190033)"
2021-07-08 01:04:25 +08:00
]
},
2021-07-08 22:11:54 +08:00
"execution_count": 56,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is not terrible, it's great!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
},
{
"cell_type": "code",
2021-07-08 22:11:54 +08:00
"execution_count": 57,
2021-07-08 01:04:25 +08:00
"id": "e3d55c92",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 0.765370786190033)"
2021-07-08 01:04:25 +08:00
]
},
2021-07-08 22:11:54 +08:00
"execution_count": 57,
2021-07-08 01:04:25 +08:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = \"This film is not great, it's terrible!\"\n",
"\n",
"predict_sentiment(text, model, tokenizer, vocab, device)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}