pytorch-sentiment-analysis/A - Using TorchText with Your Own Datasets.ipynb
2020-01-28 09:53:16 +00:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A - Using TorchText with Your Own Datasets\n",
"\n",
"In this series we have used the IMDb dataset included in TorchText. TorchText includes many canonical datasets for classification, language modelling, sequence tagging, etc. However, you will frequently want to use your own datasets. Luckily, TorchText has functions to help you do this.\n",
"\n",
"Recall that in the series we:\n",
"- defined the `Field`s\n",
"- loaded the dataset\n",
"- created the splits\n",
"\n",
"As a reminder, the code is shown below:\n",
"\n",
"```python\n",
"from torchtext import data, datasets\n",
"\n",
"TEXT = data.Field()\n",
"LABEL = data.LabelField()\n",
"\n",
"train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
"\n",
"train_data, valid_data = train_data.split()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TorchText can read three data formats: `json`, `tsv` (tab-separated values) and `csv` (comma-separated values).\n",
"\n",
"**In my opinion, the best format for TorchText is `json`, for reasons I'll explain later on.**\n",
"\n",
"## Reading JSON\n",
"\n",
"Starting with `json`, your data must be in the `json lines` format, i.e. it must be something like:\n",
"\n",
"```\n",
"{\"name\": \"John\", \"location\": \"United Kingdom\", \"age\": 42, \"quote\": [\"i\", \"love\", \"the\", \"united kingdom\"]}\n",
"{\"name\": \"Mary\", \"location\": \"United States\", \"age\": 36, \"quote\": [\"i\", \"want\", \"more\", \"telescopes\"]}\n",
"```\n",
"\n",
"That is, each line is a `json` object. See `data/train.json` for an example.\n",
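"\n",
"As a minimal sketch using only Python's standard `json` module (assuming your examples start out as a list of Python dictionaries; the in-memory `io.StringIO` buffer stands in for a real file such as `data/train.json`), data can be written in this format with one `json.dumps` call per example and read back line by line:\n",
"\n",
"```python\n",
"import io\n",
"import json\n",
"\n",
"# The two example records from above, with 'quote' already tokenized into a list.\n",
"records = [\n",
"    {'name': 'John', 'location': 'United Kingdom', 'age': 42,\n",
"     'quote': ['i', 'love', 'the', 'united kingdom']},\n",
"    {'name': 'Mary', 'location': 'United States', 'age': 36,\n",
"     'quote': ['i', 'want', 'more', 'telescopes']},\n",
"]\n",
"\n",
"# Write one json object per line.\n",
"buffer = io.StringIO()\n",
"for record in records:\n",
"    print(json.dumps(record), file=buffer)\n",
"\n",
"# Read back: every non-empty line must parse as a complete json object.\n",
"lines = buffer.getvalue().splitlines()\n",
"parsed = [json.loads(line) for line in lines if line.strip()]\n",
"print(len(lines), parsed[0]['quote'])  # 2 ['i', 'love', 'the', 'united kingdom']\n",
"```\n",
"\n",
"Because `quote` is stored as a list, TorchText treats it as already tokenized when reading the file.\n",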
"\n",
"We then define the fields:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from torchtext import data\n",
"from torchtext import datasets\n",
"\n",
"NAME = data.Field()\n",
"SAYING = data.Field()\n",
"PLACE = data.Field()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we must tell TorchText which fields apply to which elements of the `json` object. \n",
"\n",
"For `json` data, we must create a dictionary where:\n",
"- the key matches the key of the `json` object\n",
"- the value is a tuple where:\n",
"    - the first element becomes the batch object's attribute name\n",
"    - the second element is the `Field` object applied to that value\n",
"\n",
"What do we mean by \"becomes the batch object's attribute name\"? In the previous notebooks we accessed the `TEXT` and `LABEL` fields in the training/evaluation loop via `batch.text` and `batch.label`. This works because TorchText gives the batch object a `text` and a `label` attribute, each a tensor containing the text or the label respectively.\n",
"\n",
"A few notes:\n",
"\n",
"- The order of the keys in the `fields` dictionary does not matter, as long as its keys match the `json` data keys.\n",
"\n",
"- The `Field` name does not have to match the key in the `json` object, e.g. we use `PLACE` for the `\"location\"` field.\n",
"\n",
"- When dealing with `json` data, not all of the keys have to be used, e.g. we did not use the `\"age\"` field.\n",
"\n",
"- If the value of a `json` field is a string, the `Field`'s tokenization is applied (by default, the string is split on spaces); if the value is a list, no tokenization is applied. It is usually a good idea to store the data already tokenized as lists, as this saves time: TorchText does not have to tokenize it every time the data is loaded.\n",
"\n",
"- The values of a `json` field do not all have to be the same type: some examples can have their `\"quote\"` as a string and others as a list. Tokenization is only applied to the examples whose `\"quote\"` is a string.\n",
"\n",
"- If you are using a `json` field, every single example must have an instance of that field, e.g. in this example all examples must have a name, location and quote. However, as we are not using the age field, it does not matter if an example does not have it."
]
},
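{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the mapping concrete, here is a simplified sketch in plain Python of how a `fields` dictionary like the one below could be applied to a single `json` object. It is not TorchText's actual implementation: the helper `to_example` is hypothetical, `None` placeholders stand in for the `Field` objects, and whitespace splitting stands in for the default tokenizer.\n",
"\n",
"```python\n",
"def to_example(record, fields):\n",
"    # Hypothetical helper modelling the rules above, not TorchText's own code.\n",
"    example = {}\n",
"    for json_key, (attr_name, _field) in fields.items():\n",
"        value = record[json_key]  # KeyError if a used key is missing\n",
"        # Strings are tokenized; lists are assumed to be already tokenized.\n",
"        example[attr_name] = value.split() if isinstance(value, str) else value\n",
"    return example\n",
"\n",
"# None placeholders stand in for the NAME, PLACE and SAYING fields.\n",
"sketch_fields = {'name': ('n', None), 'location': ('p', None), 'quote': ('s', None)}\n",
"\n",
"record = {'name': 'John', 'location': 'United Kingdom', 'age': 42,\n",
"          'quote': ['i', 'love', 'the', 'united kingdom']}\n",
"print(to_example(record, sketch_fields))\n",
"# {'n': ['John'], 'p': ['United', 'Kingdom'], 's': ['i', 'love', 'the', 'united kingdom']}\n",
"\n",
"# The same example with 'quote' stored as a string gets tokenized instead.\n",
"print(to_example(dict(record, quote='i love the united kingdom'), sketch_fields)['s'])\n",
"# ['i', 'love', 'the', 'united', 'kingdom']\n",
"```\n",
"\n",
"Note that `age` is simply ignored because it does not appear in `sketch_fields`."
]
},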
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"fields = {'name': ('n', NAME), 'location': ('p', PLACE), 'quote': ('s', SAYING)}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, in a training loop we can iterate over the data iterator and access the name via `batch.n`, the location via `batch.p`, and the quote via `batch.s`.\n",
"\n",
"We then create our datasets (`train_data` and `test_data`) with the `TabularDataset.splits` function. \n",
"\n",
"The `path` argument specifies the top-level folder shared by both datasets, and the `train` and `test` arguments specify the filename of each dataset within it, e.g. here the train dataset is located at `data/train.json`.\n",
"\n",
"We tell the function we are using `json` data, and pass in our `fields` dictionary defined previously."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"train_data, test_data = data.TabularDataset.splits(\n",
" path = 'data',\n",
" train = 'train.json',\n",
" test = 'test.json',\n",
" format = 'json',\n",
" fields = fields\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you already have a validation dataset, its filename can be passed as the `validation` argument."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_data, valid_data, test_data = data.TabularDataset.splits(\n",
" path = 'data',\n",
" train = 'train.json',\n",
" validation = 'valid.json',\n",
" test = 'test.json',\n",
" format = 'json',\n",
" fields = fields\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can then view an example to make sure it has worked correctly.\n",
"\n",
"Notice how the field names (`n`, `p` and `s`) match up with what was defined in the `fields` dictionary.\n",
"\n",
"Also notice how the string `\"United Kingdom\"` in `p` has been split by tokenization, whereas `\"united kingdom\"` in `s` has not. As mentioned previously, this is because TorchText assumes that any `json` field whose value is a list is already tokenized, so no further tokenization is applied."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'n': ['John'], 'p': ['United', 'Kingdom'], 's': ['i', 'love', 'the', 'united kingdom']}\n"
]
}
],
"source": [
"print(vars(train_data[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now use `train_data`, `test_data` and `valid_data` to build a vocabulary and create iterators, as in the other notebooks. We can access all attributes by using `batch.n`, `batch.p` and `batch.s` for the names, places and sayings, respectively.\n",
"\n",
"## Reading CSV/TSV\n",
"\n",
"`csv` and `tsv` are very similar, except that `csv` separates elements with commas and `tsv` with tabs.\n",
"\n",
"Using the same example above, our `tsv` data will be in the form of:\n",
"\n",
"```\n",
"name\tlocation\tage\tquote\n",
"John\tUnited Kingdom\t42\ti love the united kingdom\n",
"Mary\tUnited States\t36\ti want more telescopes\n",
"```\n",
"\n",
"That is, each row holds one example, with its elements separated by tabs. The first row is usually a header (i.e. the name of each column), but your data may not have one.\n",
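"\n",
"As a minimal sketch using Python's standard `csv` module rather than TorchText (the in-memory buffer stands in for a file such as `data/train.tsv`), the example above can be written as `tsv` and read back, consuming the header row separately so it is not mistaken for an example. Note that the `quote` has to be stored as a plain string:\n",
"\n",
"```python\n",
"import csv\n",
"import io\n",
"\n",
"header = ['name', 'location', 'age', 'quote']\n",
"rows = [\n",
"    ['John', 'United Kingdom', '42', 'i love the united kingdom'],\n",
"    ['Mary', 'United States', '36', 'i want more telescopes'],\n",
"]\n",
"\n",
"# Write tab-separated rows, header first.\n",
"buffer = io.StringIO()\n",
"writer = csv.writer(buffer, delimiter='\\t', lineterminator='\\n')\n",
"writer.writerow(header)\n",
"writer.writerows(rows)\n",
"\n",
"# Read back, taking the header row off before collecting the examples.\n",
"buffer.seek(0)\n",
"reader = csv.reader(buffer, delimiter='\\t')\n",
"columns = next(reader)\n",
"examples = list(reader)\n",
"print(columns)      # ['name', 'location', 'age', 'quote']\n",
"print(examples[0])  # ['John', 'United Kingdom', '42', 'i love the united kingdom']\n",
"```\n",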
"\n",
"You cannot have lists within `tsv` or `csv` data.\n",
"\n",
"The fields are also defined differently from `json`: instead of a dictionary, we use a list of tuples. The first element of each tuple becomes the batch object's attribute name, and the second element is the `Field` object.\n",
"\n",
"Unlike with `json` data, the tuples must be in the same order as the columns in the `tsv` data. Because of this, a skipped column must be given a tuple of `None`s; otherwise our `SAYING` field would be applied to the `age` column of the `tsv` data and the `quote` column would not be used.\n",
"\n",
"However, if you only wanted to use the `name` and `location` columns, you could use just two tuples, as they are the first two columns.\n",
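"\n",
"The positional rules can be modelled with a short sketch in plain Python. The helper `from_row` is hypothetical and simplified, not TorchText's implementation; `None` placeholders stand in for the `Field` objects and whitespace splitting stands in for the default tokenizer. In this sketch `zip` stops at the shorter sequence, which models why trailing columns can be left out:\n",
"\n",
"```python\n",
"def from_row(row, fields):\n",
"    # Hypothetical helper: pair each column with the tuple at the same position.\n",
"    example = {}\n",
"    for (attr_name, _field), value in zip(fields, row):\n",
"        if attr_name is None:\n",
"            continue  # a (None, None) tuple skips this column\n",
"        example[attr_name] = value.split()\n",
"    return example\n",
"\n",
"row = ['John', 'United Kingdom', '42', 'i love the united kingdom']\n",
"\n",
"print(from_row(row, [('n', None), ('p', None), (None, None), ('s', None)]))\n",
"# {'n': ['John'], 'p': ['United', 'Kingdom'], 's': ['i', 'love', 'the', 'united', 'kingdom']}\n",
"\n",
"# Forgetting the (None, None) tuple: 's' silently receives the age column.\n",
"print(from_row(row, [('n', None), ('p', None), ('s', None)]))\n",
"# {'n': ['John'], 'p': ['United', 'Kingdom'], 's': ['42']}\n",
"\n",
"# Only the first two columns are needed, so two tuples suffice.\n",
"print(from_row(row, [('n', None), ('p', None)]))\n",
"# {'n': ['John'], 'p': ['United', 'Kingdom']}\n",
"```\n",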
"\n",
"We change our `TabularDataset` to read the correct `.tsv` files, and change the `format` argument to `'tsv'`.\n",
"\n",
"If your data has a header, which ours does, it must be skipped by passing `skip_header = True`; otherwise TorchText will treat the header as an example. By default, `skip_header` is `False`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"fields = [('n', NAME), ('p', PLACE), (None, None), ('s', SAYING)]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"train_data, valid_data, test_data = data.TabularDataset.splits(\n",
" path = 'data',\n",
" train = 'train.tsv',\n",
" validation = 'valid.tsv',\n",
" test = 'test.tsv',\n",
" format = 'tsv',\n",
" fields = fields,\n",
" skip_header = True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'n': ['John'], 'p': ['United', 'Kingdom'], 's': ['i', 'love', 'the', 'united', 'kingdom']}\n"
]
}
],
"source": [
"print(vars(train_data[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, we'll cover `csv` files.\n",
"\n",
"This is essentially the same as the `tsv` case, except with the `format` argument set to `'csv'`."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"fields = [('n', NAME), ('p', PLACE), (None, None), ('s', SAYING)]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"train_data, valid_data, test_data = data.TabularDataset.splits(\n",
" path = 'data',\n",
" train = 'train.csv',\n",
" validation = 'valid.csv',\n",
" test = 'test.csv',\n",
" format = 'csv',\n",
" fields = fields,\n",
" skip_header = True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'n': ['John'], 'p': ['United', 'Kingdom'], 's': ['i', 'love', 'the', 'united', 'kingdom']}\n"
]
}
],
"source": [
"print(vars(train_data[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Why JSON over CSV/TSV?\n",
"\n",
"1. Your `csv` or `tsv` data cannot store lists. This means the data cannot be stored already tokenized, so every time you run a Python script that reads it via TorchText, it has to be tokenized again. Advanced tokenizers, such as the `spaCy` tokenizer, take a non-negligible amount of time, so it is better to tokenize your datasets once and store them in the `json lines` format.\n",
"\n",
"2. If tabs appear in your `tsv` data, or commas appear in your `csv` data, TorchText will treat them as delimiters between columns unless the field is correctly quoted, which causes your data to be parsed incorrectly. Worst of all, TorchText will not alert you to this, as it cannot tell the difference between an unquoted tab/comma inside a field and a tab/comma used as a delimiter. As each `json` line is parsed into a dictionary and the data is accessed by key, you do not have to worry about \"surprise\" delimiters."
]
},
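{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch with Python's standard `csv` module (not TorchText) shows why a stray delimiter is dangerous: naively splitting a line on commas yields an extra column, whereas a quote-aware parser recovers the original row. This only works because the writer quoted the field that contains the comma; files produced by naively joining strings with commas have no such protection.\n",
"\n",
"```python\n",
"import csv\n",
"import io\n",
"\n",
"# A made-up quote that itself contains a comma.\n",
"row = ['John', 'United Kingdom', '42', 'i love the united kingdom, truly']\n",
"\n",
"buffer = io.StringIO()\n",
"csv.writer(buffer, lineterminator='\\n').writerow(row)\n",
"line = buffer.getvalue().rstrip('\\n')\n",
"\n",
"print(len(line.split(',')))  # 5: the comma inside the quote looks like a delimiter\n",
"print(next(csv.reader([line])) == row)  # True: the quoted field is parsed correctly\n",
"```"
]
},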
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Iterators \n",
"\n",
"Using any of the above datasets, we can then build the vocab and create the iterators."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"NAME.build_vocab(train_data)\n",
"SAYING.build_vocab(train_data)\n",
"PLACE.build_vocab(train_data)"
]
},
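{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough mental model of what `build_vocab` produces, here is a simplified sketch in plain Python (the function `toy_vocab` is hypothetical, not TorchText's `Vocab` class): it counts the tokens seen in the training data and maps each to an integer index, with special tokens first. The special tokens `<unk>` and `<pad>` and the frequency-then-alphabetical ordering are assumptions made for illustration.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"def toy_vocab(token_lists, specials=('<unk>', '<pad>')):\n",
"    # Count every token, then order by descending frequency (ties alphabetical).\n",
"    counts = Counter(token for tokens in token_lists for token in tokens)\n",
"    words = sorted((w for w in counts if w not in specials), key=lambda w: (-counts[w], w))\n",
"    itos = list(specials) + words  # index -> string\n",
"    stoi = {word: index for index, word in enumerate(itos)}  # string -> index\n",
"    return itos, stoi\n",
"\n",
"quotes = [['i', 'love', 'the', 'united kingdom'], ['i', 'want', 'more', 'telescopes']]\n",
"itos, stoi = toy_vocab(quotes)\n",
"print(itos)\n",
"# ['<unk>', '<pad>', 'i', 'love', 'more', 'telescopes', 'the', 'united kingdom', 'want']\n",
"\n",
"# Tokens never seen while building fall back to the <unk> index.\n",
"print([stoi.get(token, stoi['<unk>']) for token in ['i', 'love', 'cake']])  # [2, 3, 0]\n",
"```\n",
"\n",
"Each `Field` keeps its own vocabulary, which is why `build_vocab` is called separately on `NAME`, `SAYING` and `PLACE`."
]
},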
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we can create the iterators after defining our batch size and device.\n",
"\n",
"By default, the training data is shuffled each epoch, but the validation/test data is sorted. However, TorchText does not know how to sort our data, and it will throw an error if we don't tell it.\n",
"\n",
"There are two ways to handle this: either tell the iterator not to sort the validation/test data by passing `sort = False`, or tell it how to sort the data by passing a `sort_key`. A sort key is a function that returns the key each example is sorted on. For example, `lambda x: x.s` sorts the examples by their `s` attribute (their quote) itself, whereas `lambda x: len(x.s)` sorts them by the length of their quote. Ideally, you should use a length-based sort key, as the `BucketIterator` can then group examples of similar length and minimize the amount of padding within each batch.\n",
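"\n",
"A toy calculation in plain Python (made-up sequence lengths, not TorchText itself) illustrates why a length-based key helps: when every batch is padded to its longest sequence, grouping similar lengths together needs far fewer pad tokens.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def total_padding(lengths, batch_size):\n",
"    # Pad tokens needed when each batch is padded to its longest sequence.\n",
"    padding = 0\n",
"    for start in range(0, len(lengths), batch_size):\n",
"        batch = lengths[start:start + batch_size]\n",
"        padding += sum(max(batch) - length for length in batch)\n",
"    return padding\n",
"\n",
"random.seed(0)\n",
"lengths = [random.randint(1, 50) for _ in range(1000)]  # made-up quote lengths\n",
"\n",
"print('unsorted:', total_padding(lengths, batch_size=32))\n",
"print('sorted by length:', total_padding(sorted(lengths), batch_size=32))\n",
"```\n",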
"\n",
"We can then iterate over our iterators to get batches of data. Note that by default TorchText puts the batch dimension second, i.e. each tensor has shape `[sequence length, batch size]`."
]
},
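{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the layout concrete, here is a small sketch using PyTorch's `pad_sequence` rather than TorchText: padding a batch of index sequences with `batch_first=False` gives the same `[sequence length, batch size]` shape, which is why the output below reports sizes such as `5x1` for a batch of one. The index values and the padding index of 1 are made up for illustration.\n",
"\n",
"```python\n",
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"PAD_INDEX = 1  # assumed padding index, for illustration only\n",
"\n",
"# Three numericalized quotes of different lengths (made-up indices).\n",
"quotes = [torch.tensor([2, 3, 6, 7]), torch.tensor([2, 8]), torch.tensor([2, 4, 5])]\n",
"\n",
"# batch_first=False pads to the longest quote and stacks along dimension 1.\n",
"batch = pad_sequence(quotes, batch_first=False, padding_value=PAD_INDEX)\n",
"print(batch.shape)  # torch.Size([4, 3])\n",
"print(batch[:, 1])  # tensor([2, 8, 1, 1])\n",
"\n",
"# A model expecting [batch size, sequence length] can permute the dimensions.\n",
"print(batch.permute(1, 0).shape)  # torch.Size([3, 4])\n",
"```\n",
"\n",
"If you prefer batch-first tensors from the start, TorchText's `Field` also accepts `batch_first = True`."
]
},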
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train:\n",
"\n",
"[torchtext.data.batch.Batch of size 1]\n",
"\t[.n]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.p]:[torch.cuda.LongTensor of size 2x1 (GPU 0)]\n",
"\t[.s]:[torch.cuda.LongTensor of size 5x1 (GPU 0)]\n",
"\n",
"[torchtext.data.batch.Batch of size 1]\n",
"\t[.n]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.p]:[torch.cuda.LongTensor of size 2x1 (GPU 0)]\n",
"\t[.s]:[torch.cuda.LongTensor of size 4x1 (GPU 0)]\n",
"Valid:\n",
"\n",
"[torchtext.data.batch.Batch of size 1]\n",
"\t[.n]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.p]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.s]:[torch.cuda.LongTensor of size 2x1 (GPU 0)]\n",
"\n",
"[torchtext.data.batch.Batch of size 1]\n",
"\t[.n]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.p]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.s]:[torch.cuda.LongTensor of size 4x1 (GPU 0)]\n",
"Test:\n",
"\n",
"[torchtext.data.batch.Batch of size 1]\n",
"\t[.n]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.p]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.s]:[torch.cuda.LongTensor of size 3x1 (GPU 0)]\n",
"\n",
"[torchtext.data.batch.Batch of size 1]\n",
"\t[.n]:[torch.cuda.LongTensor of size 1x1 (GPU 0)]\n",
"\t[.p]:[torch.cuda.LongTensor of size 2x1 (GPU 0)]\n",
"\t[.s]:[torch.cuda.LongTensor of size 3x1 (GPU 0)]\n"
]
}
],
"source": [
"import torch\n",
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"BATCH_SIZE = 1\n",
"\n",
"# option 1: don't sort the validation/test data\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data),\n",
" sort = False, # don't sort test/validation data\n",
" batch_size=BATCH_SIZE,\n",
" device=device)\n",
"\n",
"# option 2: sort by the length of the s attribute (quote) to minimize padding\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
" (train_data, valid_data, test_data),\n",
" sort_key = lambda x: len(x.s), # sort by length of s attribute (quote)\n",
" batch_size=BATCH_SIZE,\n",
" device=device)\n",
"\n",
"print('Train:')\n",
"for batch in train_iterator:\n",
" print(batch)\n",
" \n",
"print('Valid:')\n",
"for batch in valid_iterator:\n",
" print(batch)\n",
" \n",
"print('Test:')\n",
"for batch in test_iterator:\n",
" print(batch)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}