{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# C - Loading, Saving and Freezing Embeddings\n", "\n", "This notebook will cover: how to load custom word embeddings in TorchText, how to save all the embeddings we learn during training, and how to freeze/unfreeze embeddings during training.\n", "\n", "## Loading Custom Embeddings\n", "\n", "First, let's look at loading a custom set of embeddings.\n", "\n", "Your embeddings need to be formatted so each line starts with the word followed by the values of the embedding vector, all space separated. All vectors need to have the same number of elements.\n", "\n", "Let's look at the custom embeddings provided by these tutorials. These are 20-dimensional embeddings for 7 words." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "good 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n", "great 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n", "awesome 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0\n", "bad -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0\n", "terrible -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0\n", "awful -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0 -1.0\n", "kwyjibo 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5 0.5 -0.5\n", "\n" ] } ], "source": [ "with open('custom_embeddings/embeddings.txt', 'r') as f:\n", "    print(f.read())" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's set up the fields." ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torchtext.legacy import data\n", "\n", "SEED = 1234\n", "\n", "torch.manual_seed(SEED)\n", "torch.backends.cudnn.deterministic = True\n", "\n", "TEXT = data.Field(tokenize = 'spacy')\n", "LABEL = data.LabelField(dtype = torch.float)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Then, we'll load our dataset and create the validation set." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from torchtext.legacy import datasets\n", "import random\n", "\n", "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n", "\n", "train_data, valid_data = train_data.split(random_state = random.seed(SEED))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We can only load our custom embeddings after they have been turned into a `Vectors` object.\n", "\n", "We create a `Vectors` object by passing it the location of the embeddings (`name`), a location for the cached embeddings (`cache`) and a function used to initialize the embeddings of any tokens in our vocabulary that are not in the loaded embeddings (`unk_init`). As we have done in previous notebooks, we initialize these from $\\mathcal{N}(0,1)$."
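, "\n", "\n", "As a quick sanity check, we can see exactly what `torch.Tensor.normal_` will do here: called on a tensor, it fills it in-place with samples drawn from a unit normal distribution. The cell below is just an illustration with a small throwaway tensor; it is not needed for the rest of the notebook." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#just an illustration of the unk_init function we are about to use:\n", "#normal_ overwrites the tensor in-place with samples from N(0, 1)\n", "example = torch.zeros(3, 5)\n", "\n", "print(torch.Tensor.normal_(example))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we create the `Vectors` object for our custom embeddings."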
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torchtext.legacy import vocab\n", "\n", "custom_embeddings = vocab.Vectors(name = 'custom_embeddings/embeddings.txt',\n", "                                  cache = 'custom_embeddings',\n", "                                  unk_init = torch.Tensor.normal_)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The rest of the set-up is the same as in previous notebooks: we build the vocabulary, this time passing `custom_embeddings` as the pre-trained vectors, define our model and copy the custom embeddings into its embedding layer.\n", "\n", "To freeze the embedding layer during training, we set `model.embedding.weight.requires_grad` to `False`. We then train as normal, saving the model whenever the validation loss improves, and unfreeze the embeddings once the epoch counter reaches `FREEZE_FOR` by running the following at the end of each epoch:\n", "\n", "```python\n", "if epoch >= FREEZE_FOR:\n", "    #unfreeze embeddings\n", "    model.embedding.weight.requires_grad = unfrozen = True\n", "```" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Another option would be to unfreeze the embeddings whenever the validation loss stops improving, using the following code snippet instead of the `FREEZE_FOR` condition:\n", " \n", "```python\n", "if valid_loss < best_valid_loss:\n", "    best_valid_loss = valid_loss\n", "    torch.save(model.state_dict(), 'tutC-model.pt')\n", "else:\n", "    #unfreeze embeddings\n", "    model.embedding.weight.requires_grad = unfrozen = True\n", "```" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Loss: 0.396 | Test Acc: 82.36%\n" ] } ], "source": [ "model.load_state_dict(torch.load('tutC-model.pt'))\n", "\n", "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n", "\n", "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Saving Embeddings\n", "\n", "We might want to re-use the embeddings we have trained here with another model. To do this, we'll write a function that loops through our vocabulary, gets the word and embedding for each entry, and writes them to a text file in the same format as our custom embeddings, so they can be used with TorchText again.\n", "\n", "Currently, TorchText Vectors seem to have issues loading certain unicode words, so we skip these by only writing words without unicode symbols. **If you know a better solution to this then let me know.**" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "from tqdm import tqdm\n", "\n", "def write_embeddings(path, embeddings, vocab):\n", "    \n", "    with open(path, 'w') as f:\n", "        for i, embedding in enumerate(tqdm(embeddings)):\n", "            word = vocab.itos[i]\n", "            #skip words with unicode symbols\n", "            if len(word) != len(word.encode()):\n", "                continue\n", "            vector = ' '.join([str(i) for i in embedding.tolist()])\n", "            f.write(f'{word} {vector}\\n')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "We'll write our embeddings to `trained_embeddings.txt`." ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 25002/25002 [00:00<00:00, 38085.03it/s]\n" ] } ], "source": [ "write_embeddings('custom_embeddings/trained_embeddings.txt',\n", "                 model.embedding.weight.data,\n", "                 TEXT.vocab)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To double-check they've been written correctly, we can load them back in as a `Vectors` object." ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 70%|███████   | 17550/24946 [00:00<00:00, 87559.48it/s]\n" ] } ], "source": [ "trained_embeddings = vocab.Vectors(name = 'custom_embeddings/trained_embeddings.txt',\n", "                                   cache = 'custom_embeddings',\n", "                                   unk_init = torch.Tensor.normal_)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let's print out the first 5 rows of our loaded vectors and the same rows from our model's embedding weights, checking that they contain the same values."
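, "\n", "\n", "We can also compare them numerically. The check below is only a quick sketch: it assumes the first few entries of the vocabulary contain none of the skipped unicode words, so that these rows still line up between the reloaded vectors and the model's embedding weights (which we move to the CPU for the comparison)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#quick numerical check that the first few reloaded vectors match the model's\n", "#embedding weights (assumes no skipped unicode words among these rows, so\n", "#the row order of the saved file still matches the model's vocabulary)\n", "print(torch.allclose(trained_embeddings.vectors[:5],\n", "                     model.embedding.weight.data[:5].cpu(),\n", "                     atol = 1e-4))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "And the rows themselves:"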
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.2573, -0.2088, 0.2413, -0.1549, 0.1940, -0.1466, -0.2195, -0.1011,\n", " -0.1327, 0.1803, 0.2369, -0.2182, 0.1543, -0.2150, -0.0699, -0.0430,\n", " -0.1958, -0.0506, -0.0059, -0.0024],\n", " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [-0.1427, -0.4414, 0.7181, -0.5751, -0.3183, 0.0552, -1.6764, -0.3177,\n", " 0.6592, 1.6143, -0.1920, -0.1881, -0.4321, -0.8578, 0.5266, 0.5243,\n", " -0.7083, -0.0048, -1.4680, 1.1425],\n", " [-0.4700, -0.0363, 0.0560, -0.7394, -0.2412, -0.4197, -1.7096, 0.9444,\n", " 0.9633, 0.3703, -0.2243, -1.5279, -1.9086, 0.5718, -0.5721, -0.6015,\n", " 0.3579, -0.3834, 0.8079, 1.0553],\n", " [-0.7055, 0.0954, 0.4646, -1.6595, 0.1138, 0.2208, -0.0220, 0.7397,\n", " -0.1153, 0.3586, 0.3040, -0.6414, -0.1579, -0.2738, -0.6942, 0.0083,\n", " 1.4097, 1.5225, 0.6409, 0.0076]])\n" ] } ], "source": [ "print(trained_embeddings.vectors[:5])" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-0.2573, -0.2088, 0.2413, -0.1549, 0.1940, -0.1466, -0.2195, -0.1011,\n", " -0.1327, 0.1803, 0.2369, -0.2182, 0.1543, -0.2150, -0.0699, -0.0430,\n", " -0.1958, -0.0506, -0.0059, -0.0024],\n", " [ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n", " 0.0000, 0.0000, 0.0000, 0.0000],\n", " [-0.1427, -0.4414, 0.7181, -0.5751, -0.3183, 0.0552, -1.6764, -0.3177,\n", " 0.6592, 1.6143, -0.1920, -0.1881, -0.4321, -0.8578, 0.5266, 0.5243,\n", " -0.7083, -0.0048, -1.4680, 1.1425],\n", " [-0.4700, -0.0363, 0.0560, -0.7394, -0.2412, -0.4197, -1.7096, 0.9444,\n", " 0.9633, 0.3703, -0.2243, -1.5279, -1.9086, 0.5718, -0.5721, -0.6015,\n", " 0.3579, -0.3834, 0.8079, 1.0553],\n", " [-0.7055, 0.0954, 0.4646, -1.6595, 0.1138, 0.2208, -0.0220, 0.7397,\n", " -0.1153, 0.3586, 0.3040, -0.6414, -0.1579, -0.2738, -0.6942, 0.0083,\n", " 1.4097, 1.5225, 0.6409, 0.0076]], device='cuda:0')\n" ] } ], "source": [ "print(model.embedding.weight.data[:5])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All looks good! The only difference between the two is the removal of the ~50 words in the vocabulary that contain unicode symbols." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }