{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchtext import data\n",
    "from torchtext import datasets\n",
    "import random\n",
    "\n",
    "SEED = 1234\n",
    "\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "TEXT = data.Field(tokenize='spacy')\n",
    "LABEL = data.LabelField()\n",
    "\n",
    "train_data, test_data = datasets.TREC.splits(TEXT, LABEL, fine_grained=False)\n",
    "\n",
    "train_data, valid_data = train_data.split(random_state=random.seed(SEED))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['Who', 'is', 'Peter', 'Weir', '?'], 'label': 'HUM'}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vars(train_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "TEXT.build_vocab(train_data, max_size=25000, vectors=\"glove.6B.100d\")\n",
    "LABEL.build_vocab(train_data)"
   ]
  },
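  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check, sketched in after the fact (it was not part of the recorded run, so it carries no saved output): how large did the vocabularies come out, and which tokens are most frequent?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TEXT gets <unk> and <pad> on top of the max_size most frequent tokens\n",
    "print(f'Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}')\n",
    "print(f'Unique tokens in LABEL vocabulary: {len(LABEL.vocab)}')\n",
    "print(TEXT.vocab.freqs.most_common(10))"
   ]
  },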
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<function _default_unk_index at 0x7f39f3fd0f28>, {'HUM': 0, 'ENTY': 1, 'DESC': 2, 'NUM': 3, 'LOC': 4, 'ABBR': 5})\n"
     ]
    }
   ],
   "source": [
    "print(LABEL.vocab.stoi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
    "    (train_data, valid_data, test_data), \n",
    "    batch_size=BATCH_SIZE, \n",
    "    device=device)"
   ]
  },
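  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch (not part of the recorded run): pull a single batch to confirm the tensor layout. With the default `batch_first=False`, `batch.text` is `[sent len, batch size]`, which is why the model's `forward` permutes it below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grab one batch and inspect shapes; text is [sent len, batch size]\n",
    "batch = next(iter(train_iterator))\n",
    "print(batch.text.shape, batch.label.shape)"
   ]
  },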
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim)\n",
    "        self.convs = nn.ModuleList([nn.Conv2d(in_channels=1,\n",
    "                                              out_channels=n_filters,\n",
    "                                              kernel_size=(fs, embedding_dim))\n",
    "                                    for fs in filter_sizes])\n",
    "        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        \n",
    "        #x = [sent len, batch size]\n",
    "        \n",
    "        x = x.permute(1, 0)\n",
    "        \n",
    "        #x = [batch size, sent len]\n",
    "        \n",
    "        embedded = self.embedding(x)\n",
    "        \n",
    "        #embedded = [batch size, sent len, emb dim]\n",
    "        \n",
    "        embedded = embedded.unsqueeze(1)\n",
    "        \n",
    "        #embedded = [batch size, 1, sent len, emb dim]\n",
    "        \n",
    "        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]\n",
    "        \n",
    "        #conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]\n",
    "        \n",
    "        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]\n",
    "        \n",
    "        #pooled_n = [batch size, n_filters]\n",
    "        \n",
    "        cat = self.dropout(torch.cat(pooled, dim=1))\n",
    "\n",
    "        #cat = [batch size, n_filters * len(filter_sizes)]\n",
    "        \n",
    "        return self.fc(cat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "INPUT_DIM = len(TEXT.vocab)\n",
    "EMBEDDING_DIM = 100\n",
    "N_FILTERS = 100\n",
    "FILTER_SIZES = [2,3,4]\n",
    "OUTPUT_DIM = len(LABEL.vocab)\n",
    "DROPOUT = 0.5\n",
    "\n",
    "model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM, DROPOUT)"
   ]
  },
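  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small helper, added here as a sketch (it is not used anywhere else in this notebook): count the trainable parameters. Most of them sit in the embedding layer, since every vocabulary entry carries a 100-dimensional vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(model):\n",
    "    # Sum element counts over all tensors that receive gradients\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "print(f'The model has {count_parameters(model):,} trainable parameters')"
   ]
  },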
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "        [ 0.1638,  0.6046,  1.0789,  ..., -0.3140,  0.1844,  0.3624],\n",
       "        ...,\n",
       "        [-0.3110, -0.3398,  1.0308,  ...,  0.5317,  0.2836, -0.0640],\n",
       "        [ 0.0091,  0.2810,  0.7356,  ..., -0.7508,  0.8967, -0.7631],\n",
       "        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pretrained_embeddings = TEXT.vocab.vectors\n",
    "\n",
    "model.embedding.weight.data.copy_(pretrained_embeddings)"
   ]
  },
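  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optional follow-up, sketched here and not performed in the recorded run above: explicitly zero the embedding rows for `<unk>` and `<pad>`, so that tokens without a pretrained vector start from a neutral point rather than whatever the copy left behind."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero the <unk> and <pad> rows (assumes the default specials from Field)\n",
    "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
    "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
    "\n",
    "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
    "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
   ]
  },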
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "model = model.to(device)\n",
    "criterion = criterion.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def categorical_accuracy(preds, y):\n",
    "    \"\"\"\n",
    "    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
    "    \"\"\"\n",
    "    max_preds = preds.argmax(dim=1, keepdim=True) # get the index of the max probability\n",
    "    correct = max_preds.squeeze(1).eq(y)\n",
    "    return correct.sum()/torch.FloatTensor([y.shape[0]])"
   ]
  },
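  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy check of `categorical_accuracy`, sketched in with made-up logits (not part of the recorded run): two of the three rows below argmax to the true label, so it should return roughly 0.667."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows 0 and 1 argmax to the true labels; row 2 does not -> 2/3\n",
    "dummy_preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])\n",
    "dummy_y = torch.tensor([0, 1, 1])\n",
    "print(categorical_accuracy(dummy_preds, dummy_y))"
   ]
  },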
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, iterator, optimizer, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.train()\n",
    "    \n",
    "    for batch in iterator:\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        predictions = model(batch.text)\n",
    "        \n",
    "        loss = criterion(predictions, batch.label)\n",
    "        \n",
    "        acc = categorical_accuracy(predictions, batch.label)\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        epoch_loss += loss.item()\n",
    "        epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, iterator, criterion):\n",
    "    \n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "    \n",
    "    model.eval()\n",
    "    \n",
    "    with torch.no_grad():\n",
    "    \n",
    "        for batch in iterator:\n",
    "\n",
    "            predictions = model(batch.text)\n",
    "            \n",
    "            loss = criterion(predictions, batch.label)\n",
    "            \n",
    "            acc = categorical_accuracy(predictions, batch.label)\n",
    "\n",
    "            epoch_loss += loss.item()\n",
    "            epoch_acc += acc.item()\n",
    "        \n",
    "    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Epoch: 01 | Train Loss: 1.475 | Train Acc: 41.20% | Val. Loss: 1.211 | Val. Acc: 57.82% |\n",
      "| Epoch: 02 | Train Loss: 1.037 | Train Acc: 63.54% | Val. Loss: 0.865 | Val. Acc: 69.62% |\n",
      "| Epoch: 03 | Train Loss: 0.703 | Train Acc: 77.10% | Val. Loss: 0.608 | Val. Acc: 79.60% |\n",
      "| Epoch: 04 | Train Loss: 0.460 | Train Acc: 85.67% | Val. Loss: 0.484 | Val. Acc: 82.65% |\n",
      "| Epoch: 05 | Train Loss: 0.301 | Train Acc: 91.20% | Val. Loss: 0.429 | Val. Acc: 84.56% |\n"
     ]
    }
   ],
   "source": [
    "N_EPOCHS = 5\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "\n",
    "    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
    "    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
    "    \n",
    "    print(f'| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}% |')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Test Loss: 0.345 | Test Acc: 90.11% |\n"
     ]
    }
   ],
   "source": [
    "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
    "\n",
    "print(f'| Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}% |')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import spacy\n",
    "nlp = spacy.load('en')\n",
    "\n",
    "def predict_class(sentence, min_len=4):\n",
    "    model.eval() # disable dropout while predicting\n",
    "    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
    "    if len(tokenized) < min_len:\n",
    "        tokenized += ['<pad>'] * (min_len - len(tokenized))\n",
    "    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
    "    tensor = torch.LongTensor(indexed).to(device)\n",
    "    tensor = tensor.unsqueeze(1)\n",
    "    preds = model(tensor)\n",
    "    max_preds = preds.argmax(dim=1)\n",
    "    return max_preds.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 0 = HUM\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(\"Who is Donald Trump?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An example asking about an abbreviation..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted class is: 5 = ABBR\n"
     ]
    }
   ],
   "source": [
    "pred_class = predict_class(\"What does USSR stand for?\")\n",
    "print(f'Predicted class is: {pred_class} = {LABEL.vocab.itos[pred_class]}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}