{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 400000 words in the vocabulary\n"
]
}
],
"source": [
"import torchtext.vocab\n",
"\n",
"glove = torchtext.vocab.GloVe(name='6B', dim=100)\n",
"\n",
"print(f'There are {len(glove.itos)} words in the vocabulary')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([400000, 100])"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove.vectors.shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['the', ',', '.', 'of', 'to', 'and', 'in', 'a', '\"', \"'s\"]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove.itos[:10]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove.stoi['the']"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([100])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"glove.vectors[glove.stoi['the']].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Words that are not in the vocabulary raise an error."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get a vector from a word:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def get_vector(embeddings, word):\n",
"    \"\"\"Look up the embedding vector for `word`.\n",
"\n",
"    Raises an AssertionError if the word is not in the vocabulary.\n",
"    \"\"\"\n",
"    assert word in embeddings.stoi, f'{word} not in vocab!'\n",
"    index = embeddings.stoi[word]\n",
"    return embeddings.vectors[index]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([100])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_vector(glove, 'the').shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def closest_words(embeddings, vector, n=10):\n",
"    \"\"\"Return the `n` (word, distance) pairs closest to `vector`.\n",
"\n",
"    Distances are Euclidean (L2), smallest first. The distances are\n",
"    computed in one batched tensor operation over the whole embedding\n",
"    matrix instead of one `torch.dist` call per vocabulary word, which\n",
"    is dramatically faster for large vocabularies (e.g. 400k words).\n",
"    \"\"\"\n",
"    # (V, d) - (d,) broadcasts to per-word difference vectors.\n",
"    dists = torch.norm(embeddings.vectors - vector, dim=1)\n",
"    distances = list(zip(embeddings.itos, dists.tolist()))\n",
"    # Stable sort keeps vocabulary order for tied distances,\n",
"    # matching the original implementation's tie-breaking.\n",
"    return sorted(distances, key=lambda pair: pair[1])[:n]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('japan', 0.0),\n",
" ('japanese', 4.091249465942383),\n",
" ('korea', 4.551243782043457),\n",
" ('tokyo', 4.565995216369629),\n",
" ('china', 4.857661247253418),\n",
" ('thailand', 5.292530536651611),\n",
" ('indonesia', 5.313706874847412),\n",
" ('philippines', 5.3697509765625),\n",
" ('asia', 5.389328479766846),\n",
" ('vietnam', 5.42373514175415)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"closest_words(glove, get_vector(glove, 'japan'))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def print_tuples(tuples):\n",
"    \"\"\"Print (word, distance) pairs, one per line, as '(d.dddd) word'.\"\"\"\n",
"    for word, distance in tuples:\n",
"        # '.4f' renders identically to the original '02.04f' spec:\n",
"        # a zero-padded minimum width of 2 can never bind on a float\n",
"        # printed with four decimal places.\n",
"        print(f'({distance:.4f}) {word}')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.0000) japan\n",
"(4.0912) japanese\n",
"(4.5512) korea\n",
"(4.5660) tokyo\n",
"(4.8577) china\n",
"(5.2925) thailand\n",
"(5.3137) indonesia\n",
"(5.3698) philippines\n",
"(5.3893) asia\n",
"(5.4237) vietnam\n"
]
}
],
"source": [
"print_tuples(closest_words(glove, get_vector(glove, 'japan')))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"def analogy(embeddings, word1, word2, word3, n=5):\n",
"    \"\"\"Solve the analogy word1 : word2 :: word3 : ? by vector arithmetic.\n",
"\n",
"    Prints the analogy being solved and returns the top `n`\n",
"    (word, distance) candidates, excluding the three query words.\n",
"    \"\"\"\n",
"    # The offset word2 - word1 captures the relationship, which is\n",
"    # then applied to word3.\n",
"    target = (get_vector(embeddings, word2)\n",
"              - get_vector(embeddings, word1)\n",
"              + get_vector(embeddings, word3))\n",
"\n",
"    # Over-fetch by 3 so that dropping the query words themselves\n",
"    # still leaves `n` candidates.\n",
"    queries = {word1, word2, word3}\n",
"    neighbours = closest_words(embeddings, target, n + 3)\n",
"    candidates = [pair for pair in neighbours if pair[0] not in queries][:n]\n",
"\n",
"    print(f'\\n{word1} is to {word2} as {word3} is to...')\n",
"\n",
"    return candidates"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"man is to king as woman is to...\n",
"(4.0811) queen\n",
"(4.6429) monarch\n",
"(4.9055) throne\n",
"(4.9216) elizabeth\n",
"(4.9811) prince\n"
]
}
],
"source": [
"print_tuples(analogy(glove, 'man', 'king', 'woman'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can think of vector('King') - vector('Man') as a \"royalty vector\", thus when you add this \"royality vector\" to woman, you get queen. If you add it to \"boy\" you should get \"prince\" and if you add to \"girl\" you should get princess. Let's test:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"man is to king as boy is to...\n",
"(5.3084) queen\n",
"(5.4616) prince\n",
"(5.5430) uncle\n",
"(5.6069) brother\n",
"(5.6418) son\n",
"\n",
"man is to king as girl is to...\n",
"(4.6916) queen\n",
"(5.3437) princess\n",
"(5.4683) prince\n",
"(5.5591) daughter\n",
"(5.5735) sister\n"
]
}
],
"source": [
"print_tuples(analogy(glove, 'man', 'king', 'boy'))\n",
"print_tuples(analogy(glove, 'man', 'king', 'girl'))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"man is to actor as woman is to...\n",
"(2.8133) actress\n",
"(5.0039) comedian\n",
"(5.1399) actresses\n",
"(5.2773) starred\n",
"(5.3085) screenwriter\n",
"\n",
"cat is to kitten as dog is to...\n",
"(3.8146) puppy\n",
"(4.2944) rottweiler\n",
"(4.5888) puppies\n",
"(4.6086) pooch\n",
"(4.6520) pug\n",
"\n",
"dog is to puppy as cat is to...\n",
"(3.8146) kitten\n",
"(4.0255) puppies\n",
"(4.1575) kittens\n",
"(4.1882) pterodactyl\n",
"(4.1945) scaredy\n",
"\n",
"russia is to moscow as france is to...\n",
"(3.2697) paris\n",
"(4.6857) french\n",
"(4.7085) lyon\n",
"(4.9087) strasbourg\n",
"(5.0362) marseille\n",
"\n",
"obama is to president as trump is to...\n",
"(6.4302) executive\n",
"(6.5149) founder\n",
"(6.6997) ceo\n",
"(6.7524) hilton\n",
"(6.7729) walt\n",
"\n",
"rich is to mansion as poor is to...\n",
"(5.8262) residence\n",
"(5.9444) riverside\n",
"(6.0283) hillside\n",
"(6.0328) abandoned\n",
"(6.0681) bungalow\n",
"\n",
"elvis is to rock as eminem is to...\n",
"(5.6597) rap\n",
"(6.2057) rappers\n",
"(6.2161) rapper\n",
"(6.2444) punk\n",
"(6.2690) hop\n",
"\n",
"paper is to newspaper as screen is to...\n",
"(4.7810) tv\n",
"(5.1049) television\n",
"(5.3818) cinema\n",
"(5.5524) feature\n",
"(5.5646) shows\n",
"\n",
"monet is to paint as michelangelo is to...\n",
"(6.0782) plaster\n",
"(6.3768) mold\n",
"(6.3922) tile\n",
"(6.5819) marble\n",
"(6.6524) image\n",
"\n",
"beer is to barley as wine is to...\n",
"(5.6021) grape\n",
"(5.6760) beans\n",
"(5.8174) grapes\n",
"(5.9035) lentils\n",
"(5.9454) figs\n",
"\n",
"earth is to moon as sun is to...\n",
"(6.2294) lee\n",
"(6.4125) kang\n",
"(6.4644) tan\n",
"(6.4757) yang\n",
"(6.4853) lin\n",
"\n",
"house is to roof as castle is to...\n",
"(6.2919) stonework\n",
"(6.3779) masonry\n",
"(6.4773) canopy\n",
"(6.4954) fortress\n",
"(6.5259) battlements\n",
"\n",
"building is to architect as software is to...\n",
"(5.8369) programmer\n",
"(6.8881) entrepreneur\n",
"(6.9240) inventor\n",
"(6.9730) developer\n",
"(6.9949) innovator\n",
"\n",
"boston is to bruins as phoenix is to...\n",
"(3.8546) suns\n",
"(4.1968) mavericks\n",
"(4.6126) coyotes\n",
"(4.6894) mavs\n",
"(4.6971) knicks\n",
"\n",
"good is to heaven as bad is to...\n",
"(4.3959) hell\n",
"(5.2864) ghosts\n",
"(5.2898) hades\n",
"(5.3414) madness\n",
"(5.3520) purgatory\n",
"\n",
"jordan is to basketball as woods is to...\n",
"(5.8607) golf\n",
"(6.4110) golfers\n",
"(6.4418) tournament\n",
"(6.4592) tennis\n",
"(6.6560) collegiate\n"
]
}
],
"source": [
"print_tuples(analogy(glove, 'man', 'actor', 'woman'))\n",
"print_tuples(analogy(glove, 'cat', 'kitten', 'dog'))\n",
"print_tuples(analogy(glove, 'dog', 'puppy', 'cat'))\n",
"print_tuples(analogy(glove, 'russia', 'moscow', 'france'))\n",
"print_tuples(analogy(glove, 'obama', 'president', 'trump'))\n",
"print_tuples(analogy(glove, 'rich', 'mansion', 'poor'))\n",
"print_tuples(analogy(glove, 'elvis', 'rock', 'eminem'))\n",
"print_tuples(analogy(glove, 'paper', 'newspaper', 'screen'))\n",
"print_tuples(analogy(glove, 'monet', 'paint', 'michelangelo'))\n",
"print_tuples(analogy(glove, 'beer', 'barley', 'wine'))\n",
"print_tuples(analogy(glove, 'earth', 'moon', 'sun'))\n",
"print_tuples(analogy(glove, 'house', 'roof', 'castle'))\n",
"print_tuples(analogy(glove, 'building', 'architect', 'software'))\n",
"print_tuples(analogy(glove, 'boston', 'bruins', 'phoenix'))\n",
"print_tuples(analogy(glove, 'good', 'heaven', 'bad'))\n",
"print_tuples(analogy(glove, 'jordan', 'basketball', 'woods'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"http://forums.fast.ai/t/nlp-any-libraries-dictionaries-out-there-for-fixing-common-spelling-errors/16411"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"glove = torchtext.vocab.GloVe(name='840B', dim=300)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0.0000) relieable\n",
"(5.0366) relyable\n",
"(5.2610) realible\n",
"(5.4719) realiable\n",
"(5.5402) relable\n",
"(5.5917) relaible\n",
"(5.6412) reliabe\n",
"(5.8802) relaiable\n",
"(5.9593) stabel\n",
"(5.9981) consitant\n"
]
}
],
"source": [
"print_tuples(closest_words(glove, get_vector(glove, 'relieable')))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"reliable_vector = get_vector(glove, 'reliable')\n",
"\n",
"reliable_misspellings = ['relieable', 'relyable', 'realible', 'realiable', 'relable', 'relaible', 'reliabe', 'relaiable']\n",
"\n",
"diff_reliable = [(reliable_vector - get_vector(glove, s)).unsqueeze(0) for s in reliable_misspellings]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"misspelling_vector = torch.cat(diff_reliable, dim=0).mean(dim=0)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(6.1090) because\n",
"(6.4250) even\n",
"(6.4358) fact\n",
"(6.4914) sure\n",
"(6.5094) though\n",
"(6.5601) obviously\n",
"(6.5682) reason\n",
"(6.5856) if\n",
"(6.6099) but\n",
"(6.6415) why\n"
]
}
],
"source": [
"#misspelling of because\n",
"\n",
"print_tuples(closest_words(glove, get_vector(glove, 'becuase') + misspelling_vector))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(5.4070) definitely\n",
"(5.5643) certainly\n",
"(5.7192) sure\n",
"(5.8152) well\n",
"(5.8588) always\n",
"(5.8812) also\n",
"(5.9557) simply\n",
"(5.9667) consider\n",
"(5.9821) probably\n",
"(5.9948) definately\n"
]
}
],
"source": [
"#misspelling of definitely\n",
"\n",
"print_tuples(closest_words(glove, get_vector(glove, 'defintiely') + misspelling_vector))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(5.9641) consistent\n",
"(6.3674) reliable\n",
"(7.0195) consistant\n",
"(7.0299) consistently\n",
"(7.1605) accurate\n",
"(7.2737) fairly\n",
"(7.3037) good\n",
"(7.3520) reasonable\n",
"(7.3801) dependable\n",
"(7.4027) ensure\n"
]
}
],
"source": [
"#misspelling of consistent\n",
"\n",
"print_tuples(closest_words(glove, get_vector(glove, 'consistant') + misspelling_vector))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(6.6117) package\n",
"(6.9315) packages\n",
"(7.0195) pakage\n",
"(7.0911) comes\n",
"(7.1241) provide\n",
"(7.1469) offer\n",
"(7.1861) reliable\n",
"(7.2431) well\n",
"(7.2434) choice\n",
"(7.2453) offering\n"
]
}
],
"source": [
"#misspelling of package\n",
"\n",
"print_tuples(closest_words(glove, get_vector(glove, 'pakage') + misspelling_vector))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}