1029 lines
27 KiB
Plaintext
1029 lines
27 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 228
|
|
},
|
|
"colab_type": "code",
|
|
"id": "lIYdn1woOS1n",
|
|
"outputId": "f9419fe4-7c0e-4706-a9b9-30fbc836d9a9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"import torch.optim as optim\n",
|
|
"\n",
|
|
"import torchtext\n",
|
|
"import torchtext.experimental\n",
|
|
"import torchtext.experimental.vectors\n",
|
|
"from torchtext.experimental.datasets.raw.text_classification import RawTextIterableDataset\n",
|
|
"from torchtext.experimental.datasets.text_classification import TextClassificationDataset\n",
|
|
"from torchtext.experimental.functional import sequential_transforms, vocab_func, totensor\n",
|
|
"\n",
|
|
"import collections\n",
|
|
"import random\n",
|
|
"import time"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "II-XIfhSkZS-"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"seed = 1234\n",
|
|
"\n",
|
|
"torch.manual_seed(seed)\n",
|
|
"random.seed(seed)\n",
|
|
"torch.backends.cudnn.deterministic = True\n",
|
|
"torch.backends.cudnn.benchmark = False"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "kIkeEy2mkcT6"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"raw_train_data, raw_test_data = torchtext.experimental.datasets.raw.IMDB()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "_a5ucP1ZkeDv"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def get_train_valid_split(raw_train_data, split_ratio = 0.7):\n",
"    \"\"\"Shuffle the raw examples and split them into train and validation sets.\n",
"\n",
"    The first `split_ratio` fraction of the shuffled examples forms the training\n",
"    set and the remainder forms the validation set. Shuffling uses the global\n",
"    `random` state, so seed it beforehand for a reproducible split.\n",
"    \"\"\"\n",
"    examples = list(raw_train_data)\n",
"    random.shuffle(examples)\n",
"\n",
"    split_point = int(len(examples) * split_ratio)\n",
"    train_examples = examples[:split_point]\n",
"    valid_examples = examples[split_point:]\n",
"\n",
"    return RawTextIterableDataset(train_examples), RawTextIterableDataset(valid_examples)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "1WP4nz-_kf_0"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"raw_train_data, raw_valid_data = get_train_valid_split(raw_train_data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "pPvrMZlWkicJ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"class Tokenizer:\n",
"    \"\"\"Wrap a torchtext tokenizer with optional lowercasing and truncation.\"\"\"\n",
"\n",
"    def __init__(self, tokenize_fn = 'basic_english', lower = True, max_length = None):\n",
"        # `tokenize_fn` is resolved through torchtext's `get_tokenizer`.\n",
"        self.tokenize_fn = torchtext.data.utils.get_tokenizer(tokenize_fn)\n",
"        self.lower = lower\n",
"        self.max_length = max_length\n",
"\n",
"    def tokenize(self, s):\n",
"        \"\"\"Split `s` into tokens, then lowercase and truncate them as configured.\"\"\"\n",
"        pieces = self.tokenize_fn(s)\n",
"\n",
"        if self.lower:\n",
"            pieces = [piece.lower() for piece in pieces]\n",
"\n",
"        # `max_length = None` means no truncation.\n",
"        if self.max_length is not None:\n",
"            pieces = pieces[:self.max_length]\n",
"\n",
"        return pieces"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "SMsMQSuSkkt3"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"max_length = 500\n",
|
|
"\n",
|
|
"tokenizer = Tokenizer(max_length = max_length)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Yie7TKWKkmeK"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def build_vocab_from_data(raw_data, tokenizer, **vocab_kwargs):\n",
"    \"\"\"Count token frequencies over `raw_data` and build a torchtext `Vocab`.\n",
"\n",
"    `raw_data` yields `(label, text)` pairs; only the text is used. Extra keyword\n",
"    arguments are forwarded to `torchtext.vocab.Vocab`.\n",
"    \"\"\"\n",
"    counts = collections.Counter()\n",
"\n",
"    for _, text in raw_data:\n",
"        counts.update(tokenizer.tokenize(text))\n",
"\n",
"    return torchtext.vocab.Vocab(counts, **vocab_kwargs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "9jW7Ci7WkoSn"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"max_size = 25_000\n",
|
|
"\n",
|
|
"vocab = build_vocab_from_data(raw_train_data, tokenizer, max_size = max_size)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "cvSZt_iFkqkt"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def process_raw_data(raw_data, tokenizer, vocab):\n",
"    \"\"\"Turn raw `(label, text)` pairs into a `TextClassificationDataset`.\n",
"\n",
"    Text is tokenized, passed through `vocab_func(vocab)` and converted to a long\n",
"    tensor; labels are converted to long tensors.\n",
"    \"\"\"\n",
"    examples = [(label, text) for (label, text) in raw_data]\n",
"\n",
"    label_transform = sequential_transforms(totensor(dtype=torch.long))\n",
"    text_transform = sequential_transforms(\n",
"        tokenizer.tokenize,\n",
"        vocab_func(vocab),\n",
"        totensor(dtype=torch.long),\n",
"    )\n",
"\n",
"    # The dataset expects the transforms in (label, text) order.\n",
"    return TextClassificationDataset(examples, vocab, (label_transform, text_transform))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "bwsSiBdkktRk"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_data = process_raw_data(raw_train_data, tokenizer, vocab)\n",
|
|
"valid_data = process_raw_data(raw_valid_data, tokenizer, vocab)\n",
|
|
"test_data = process_raw_data(raw_test_data, tokenizer, vocab)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "5m3xRusSk8v3"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"class Collator:\n",
"    \"\"\"Batch `(label, token tensor)` examples, padding texts to a common length.\"\"\"\n",
"\n",
"    def __init__(self, pad_idx, batch_first):\n",
"        self.pad_idx = pad_idx\n",
"        self.batch_first = batch_first\n",
"\n",
"    def collate(self, batch):\n",
"        \"\"\"Return `(labels, text)` tensors for a list of `(label, text)` examples.\"\"\"\n",
"        label_seq, text_seq = zip(*batch)\n",
"\n",
"        padded_text = nn.utils.rnn.pad_sequence(text_seq,\n",
"                                                batch_first = self.batch_first,\n",
"                                                padding_value = self.pad_idx)\n",
"\n",
"        return torch.LongTensor(label_seq), padded_text"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "1ZMuZqZxk8-p"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"pad_token = '<pad>'\n",
|
|
"pad_idx = vocab[pad_token]\n",
|
|
"batch_first = True\n",
|
|
"\n",
|
|
"collator = Collator(pad_idx, batch_first)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "mxG97Si9lAI2"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch_size = 256\n",
|
|
"\n",
|
|
"train_iterator = torch.utils.data.DataLoader(train_data, \n",
|
|
" batch_size, \n",
|
|
" shuffle = True, \n",
|
|
" collate_fn = collator.collate)\n",
|
|
"\n",
|
|
"valid_iterator = torch.utils.data.DataLoader(valid_data, \n",
|
|
" batch_size, \n",
|
|
" shuffle = False, \n",
|
|
" collate_fn = collator.collate)\n",
|
|
"\n",
|
|
"test_iterator = torch.utils.data.DataLoader(test_data, \n",
|
|
" batch_size, \n",
|
|
" shuffle = False, \n",
|
|
" collate_fn = collator.collate)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "ty3NbheMlPYs"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"class CNN(nn.Module):\n",
"    \"\"\"Text classifier using parallel 1D convolutions with max-over-time pooling.\"\"\"\n",
"\n",
"    def __init__(self, input_dim, emb_dim, n_filters, filter_sizes, output_dim, dropout, pad_idx):\n",
"        super().__init__()\n",
"\n",
"        self.embedding = nn.Embedding(input_dim, emb_dim, padding_idx = pad_idx)\n",
"\n",
"        # One convolution per filter size, each sliding along the sequence axis.\n",
"        conv_layers = []\n",
"        for filter_size in filter_sizes:\n",
"            conv_layers.append(nn.Conv1d(in_channels = emb_dim,\n",
"                                         out_channels = n_filters,\n",
"                                         kernel_size = filter_size))\n",
"        self.convs = nn.ModuleList(conv_layers)\n",
"\n",
"        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)\n",
"        self.dropout = nn.Dropout(dropout)\n",
"\n",
"    def forward(self, text):\n",
"        \"\"\"Map token indices `[batch size, seq len]` to logits `[batch size, output dim]`.\"\"\"\n",
"\n",
"        # embedded = [batch size, seq len, emb dim]\n",
"        embedded = self.dropout(self.embedding(text))\n",
"\n",
"        # Conv1d wants channels first: [batch size, emb dim, seq len]\n",
"        embedded = embedded.permute(0, 2, 1)\n",
"\n",
"        pooled = []\n",
"        for conv in self.convs:\n",
"            # feature_map = [batch size, n filters, seq len - kernel size + 1]\n",
"            feature_map = F.relu(conv(embedded))\n",
"            # max over the whole sequence axis -> [batch size, n filters]\n",
"            pooled.append(F.max_pool1d(feature_map, feature_map.shape[-1]).squeeze(-1))\n",
"\n",
"        # features = [batch size, n filters * len(filter_sizes)]\n",
"        features = torch.cat(pooled, dim = -1)\n",
"\n",
"        # result = [batch size, output dim]\n",
"        return self.fc(self.dropout(features))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "trg6yTjBqOLZ"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"input_dim = len(vocab)\n",
"emb_dim = 100\n",
"n_filters = 100\n",
"filter_sizes = [3, 4, 5]\n",
"output_dim = 2\n",
"dropout = 0.5\n",
"\n",
"# pad_idx is reused from the collator setup; re-assigning it to itself was a no-op.\n",
"model = CNN(input_dim, emb_dim, n_filters, filter_sizes, output_dim, dropout, pad_idx)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "9dgdCRsqqQoD"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def count_parameters(model):\n",
"    \"\"\"Return the total number of elements across the model's trainable parameters.\"\"\"\n",
"    total = 0\n",
"    for param in model.parameters():\n",
"        if param.requires_grad:\n",
"            total += param.numel()\n",
"    return total"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 35
|
|
},
|
|
"colab_type": "code",
|
|
"id": "bfiGzjvnqV-s",
|
|
"outputId": "fffbb2a6-0a0a-432f-f182-7697a6903c75"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The model has 2,621,102 trainable parameters\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(f'The model has {count_parameters(model):,} trainable parameters')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"name: embedding.weight, shape: torch.Size([25002, 100])\n",
|
|
"name: convs.0.weight, shape: torch.Size([100, 100, 3])\n",
|
|
"name: convs.0.bias, shape: torch.Size([100])\n",
|
|
"name: convs.1.weight, shape: torch.Size([100, 100, 4])\n",
|
|
"name: convs.1.bias, shape: torch.Size([100])\n",
|
|
"name: convs.2.weight, shape: torch.Size([100, 100, 5])\n",
|
|
"name: convs.2.bias, shape: torch.Size([100])\n",
|
|
"name: fc.weight, shape: torch.Size([2, 300])\n",
|
|
"name: fc.bias, shape: torch.Size([2])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for n, p in model.named_parameters():\n",
|
|
" print(f'name: {n}, shape: {p.shape}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def initialize_parameters(m):\n",
"    \"\"\"Initialize a single module in place; meant to be passed to `model.apply`.\n",
"\n",
"    Embedding weights are drawn uniformly from [-0.05, 0.05]. Conv1d and Linear\n",
"    layers get Xavier-uniform weights and zero biases. Any other module is left\n",
"    untouched.\n",
"    \"\"\"\n",
"    if isinstance(m, nn.Embedding):\n",
"        nn.init.uniform_(m.weight, -0.05, 0.05)\n",
"        return\n",
"\n",
"    if isinstance(m, (nn.Conv1d, nn.Linear)):\n",
"        nn.init.xavier_uniform_(m.weight)\n",
"        nn.init.zeros_(m.bias)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"CNN(\n",
|
|
" (embedding): Embedding(25002, 100, padding_idx=1)\n",
|
|
" (convs): ModuleList(\n",
|
|
" (0): Conv1d(100, 100, kernel_size=(3,), stride=(1,))\n",
|
|
" (1): Conv1d(100, 100, kernel_size=(4,), stride=(1,))\n",
|
|
" (2): Conv1d(100, 100, kernel_size=(5,), stride=(1,))\n",
|
|
" )\n",
|
|
" (fc): Linear(in_features=300, out_features=2, bias=True)\n",
|
|
" (dropout): Dropout(p=0.5, inplace=False)\n",
|
|
")"
|
|
]
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.apply(initialize_parameters)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Sah17A41qW5d"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"glove = torchtext.experimental.vectors.GloVe(name = '6B',\n",
|
|
" dim = emb_dim)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "S1Dfcn2Nqabo"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def get_pretrained_embedding(initial_embedding, pretrained_vectors, vocab, unk_token):\n",
"    \"\"\"Build an embedding matrix seeded with pretrained vectors where available.\n",
"\n",
"    Args:\n",
"        initial_embedding: embedding layer whose current weights are the starting point.\n",
"        pretrained_vectors: pretrained vectors object; must expose\n",
"            `vectors.get_stoi()` and support lookup by token.\n",
"        vocab: vocabulary whose `itos` order defines the embedding rows.\n",
"        unk_token: unused; kept so existing callers keep working.\n",
"\n",
"    Returns:\n",
"        `(pretrained_embedding, unk_tokens)`: a detached copy of the initial weights\n",
"        with rows overwritten for tokens found in the pretrained vocabulary, and the\n",
"        list of tokens that were not found (their rows keep the initial values).\n",
"    \"\"\"\n",
"    # detach().clone() copies the weights without tracking gradients and keeps the\n",
"    # weight's own device and dtype, unlike the legacy torch.FloatTensor(...) constructor.\n",
"    pretrained_embedding = initial_embedding.weight.detach().clone()\n",
"    pretrained_vocab = pretrained_vectors.vectors.get_stoi()\n",
"\n",
"    unk_tokens = []\n",
"\n",
"    for idx, token in enumerate(vocab.itos):\n",
"        if token in pretrained_vocab:\n",
"            pretrained_embedding[idx] = pretrained_vectors[token]\n",
"        else:\n",
"            unk_tokens.append(token)\n",
"\n",
"    return pretrained_embedding, unk_tokens"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "sGyV94f7qvdr"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"unk_token = '<unk>'\n",
|
|
"\n",
|
|
"pretrained_embedding, unk_tokens = get_pretrained_embedding(model.embedding, glove, vocab, unk_token)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 139
|
|
},
|
|
"colab_type": "code",
|
|
"id": "KYnGxbVisUsk",
|
|
"outputId": "39d1354c-9a3a-4a6e-bf4a-8595d7f4eac9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor([[-0.0220, -0.0288, -0.0422, ..., 0.0103, 0.0218, -0.0141],\n",
|
|
" [ 0.0326, 0.0222, 0.0044, ..., 0.0249, 0.0163, 0.0052],\n",
|
|
" [-0.0382, -0.2449, 0.7281, ..., -0.1459, 0.8278, 0.2706],\n",
|
|
" ...,\n",
|
|
" [-0.2925, 0.1087, 0.7920, ..., -0.3641, 0.1822, -0.4104],\n",
|
|
" [-0.7250, 0.7545, 0.1637, ..., -0.0144, -0.1761, 0.3418],\n",
|
|
" [ 1.1753, 0.0460, -0.3542, ..., 0.4510, 0.0485, -0.4015]])"
|
|
]
|
|
},
|
|
"execution_count": 25,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.embedding.weight.data.copy_(pretrained_embedding)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.embedding.weight.data[pad_idx] = torch.zeros(emb_dim)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "DTwNU41WseMS"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"optimizer = optim.Adam(model.parameters())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Rxlx7a72s1ze"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"criterion = nn.CrossEntropyLoss()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "1CLimBxus2yX"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "108fm55ftBgO"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = model.to(device)\n",
|
|
"criterion = criterion.to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "IYCxbvXUvE5v"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def calculate_accuracy(predictions, labels):\n",
"    \"\"\"Return the fraction of rows whose highest-scoring class equals the label.\n",
"\n",
"    The result is a 0-dim float tensor.\n",
"    \"\"\"\n",
"    predicted_classes = predictions.argmax(1, keepdim = True)\n",
"    targets = labels.view_as(predicted_classes)\n",
"    n_correct = (predicted_classes == targets).sum()\n",
"    return n_correct.float() / labels.shape[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "Ik2JQo6TvGml"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def train(model, iterator, optimizer, criterion, device):\n",
"    \"\"\"Run one training epoch and return `(mean batch loss, mean batch accuracy)`.\n",
"\n",
"    Both values are averaged over batches, not over individual examples.\n",
"    \"\"\"\n",
"    model.train()\n",
"\n",
"    running_loss = 0\n",
"    running_acc = 0\n",
"\n",
"    for labels, text in iterator:\n",
"        labels, text = labels.to(device), text.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"\n",
"        predictions = model(text)\n",
"        loss = criterion(predictions, labels)\n",
"        acc = calculate_accuracy(predictions, labels)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        running_loss += loss.item()\n",
"        running_acc += acc.item()\n",
"\n",
"    n_batches = len(iterator)\n",
"    return running_loss / n_batches, running_acc / n_batches"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "aGy1Zk6jvIf8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def evaluate(model, iterator, criterion, device):\n",
"    \"\"\"Score the model on `iterator` without updating it.\n",
"\n",
"    Returns `(mean batch loss, mean batch accuracy)`; both values are averaged\n",
"    over batches, not over individual examples.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    running_loss = 0\n",
"    running_acc = 0\n",
"\n",
"    # No gradients are needed for evaluation.\n",
"    with torch.no_grad():\n",
"        for labels, text in iterator:\n",
"            labels, text = labels.to(device), text.to(device)\n",
"\n",
"            predictions = model(text)\n",
"\n",
"            running_loss += criterion(predictions, labels).item()\n",
"            running_acc += calculate_accuracy(predictions, labels).item()\n",
"\n",
"    n_batches = len(iterator)\n",
"    return running_loss / n_batches, running_acc / n_batches"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "9MyMRRzbvKPx"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def epoch_time(start_time, end_time):\n",
"    \"\"\"Split the time elapsed between two timestamps (in seconds) into whole minutes and seconds.\"\"\"\n",
"    elapsed = end_time - start_time\n",
"    whole_mins = int(elapsed / 60)\n",
"    leftover_secs = int(elapsed - 60 * whole_mins)\n",
"    return whole_mins, leftover_secs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 35,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 537
|
|
},
|
|
"colab_type": "code",
|
|
"id": "dRKwD51WvMa3",
|
|
"outputId": "935b7d4b-c396-42d8-8041-802ec9575cd6"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 01 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 1.370 | Train Acc: 53.26%\n",
|
|
"\t Val. Loss: 0.588 | Val. Acc: 69.31%\n",
|
|
"Epoch: 02 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.796 | Train Acc: 60.77%\n",
|
|
"\t Val. Loss: 0.562 | Val. Acc: 73.82%\n",
|
|
"Epoch: 03 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.620 | Train Acc: 67.86%\n",
|
|
"\t Val. Loss: 0.523 | Val. Acc: 78.67%\n",
|
|
"Epoch: 04 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.523 | Train Acc: 74.40%\n",
|
|
"\t Val. Loss: 0.459 | Val. Acc: 81.48%\n",
|
|
"Epoch: 05 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.459 | Train Acc: 78.51%\n",
|
|
"\t Val. Loss: 0.416 | Val. Acc: 83.35%\n",
|
|
"Epoch: 06 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.412 | Train Acc: 81.52%\n",
|
|
"\t Val. Loss: 0.381 | Val. Acc: 84.52%\n",
|
|
"Epoch: 07 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.374 | Train Acc: 83.71%\n",
|
|
"\t Val. Loss: 0.369 | Val. Acc: 84.95%\n",
|
|
"Epoch: 08 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.356 | Train Acc: 84.29%\n",
|
|
"\t Val. Loss: 0.356 | Val. Acc: 85.49%\n",
|
|
"Epoch: 09 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.339 | Train Acc: 85.20%\n",
|
|
"\t Val. Loss: 0.344 | Val. Acc: 85.92%\n",
|
|
"Epoch: 10 | Epoch Time: 0m 9s\n",
|
|
"\tTrain Loss: 0.318 | Train Acc: 86.43%\n",
|
|
"\t Val. Loss: 0.334 | Val. Acc: 86.28%\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"n_epochs = 10\n",
"\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(n_epochs):\n",
"\n",
"    # Time one full pass: training followed by validation.\n",
"    start_time = time.monotonic()\n",
"    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)\n",
"    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, device)\n",
"    end_time = time.monotonic()\n",
"\n",
"    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
"\n",
"    # Checkpoint only when the validation loss improves.\n",
"    if valid_loss < best_valid_loss:\n",
"        best_valid_loss = valid_loss\n",
"        torch.save(model.state_dict(), 'cnn-model.pt')\n",
"\n",
"    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s',\n",
"          f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%',\n",
"          f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%',\n",
"          sep = '\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 35
|
|
},
|
|
"colab_type": "code",
|
|
"id": "hKOg4oARvPHJ",
|
|
"outputId": "b5552b10-fcca-4c29-8d4b-4f5688ef53dd"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Test Loss: 0.338 | Test Acc: 85.99%\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# map_location lets a checkpoint saved on one device (e.g. a GPU) load onto the current one.\n",
"model.load_state_dict(torch.load('cnn-model.pt', map_location = device))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion, device)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"metadata": {
|
|
"colab": {},
|
|
"colab_type": "code",
|
|
"id": "tQ4Jsf_vvWgB"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
"def predict_sentiment(tokenizer, vocab, model, device, sentence, min_length = None, pad_token = '<pad>'):\n",
"    \"\"\"Return the model's softmax probability for the last output class of `sentence`.\n",
"\n",
"    This notebook treats the last class as the positive one.\n",
"\n",
"    Args:\n",
"        tokenizer: object with a `tokenize(str)` method returning a list of tokens.\n",
"        vocab: vocabulary exposing a `stoi` token-to-index mapping.\n",
"        model: classifier called with a `[1, seq len]` long tensor of token indices.\n",
"        device: device the input tensor is moved to before the forward pass.\n",
"        sentence: raw text to classify.\n",
"        min_length: shorter inputs are padded up to this many tokens. Defaults to the\n",
"            largest kernel size in `model.convs` (0 if the model has no `convs`),\n",
"            because a convolution cannot run on a sequence shorter than its kernel.\n",
"        pad_token: token whose vocabulary index is used for padding.\n",
"\n",
"    Returns:\n",
"        The probability as a Python float.\n",
"    \"\"\"\n",
"    model.eval()\n",
"\n",
"    tokens = tokenizer.tokenize(sentence)\n",
"    indexes = [vocab.stoi[token] for token in tokens]\n",
"\n",
"    if min_length is None:\n",
"        min_length = max((conv.kernel_size[0] for conv in getattr(model, 'convs', [])), default = 0)\n",
"\n",
"    # Short (or empty) sentences would otherwise make the convolutions raise.\n",
"    if len(indexes) < min_length:\n",
"        indexes += [vocab.stoi[pad_token]] * (min_length - len(indexes))\n",
"\n",
"    tensor = torch.LongTensor(indexes).unsqueeze(0).to(device)\n",
"\n",
"    # Inference only: skip building the autograd graph.\n",
"    with torch.no_grad():\n",
"        prediction = model(tensor)\n",
"\n",
"    probabilities = nn.functional.softmax(prediction, dim = -1)\n",
"    pos_probability = probabilities.squeeze()[-1].item()\n",
"    return pos_probability"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 35
|
|
},
|
|
"colab_type": "code",
|
|
"id": "Yy7_6rhovZTE",
|
|
"outputId": "4297c903-8ef3-4c94-8a9e-21fbb98a6be9"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.08827298134565353"
|
|
]
|
|
},
|
|
"execution_count": 38,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"sentence = 'the absolute worst movie of all time.'\n",
|
|
"\n",
|
|
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 35
|
|
},
|
|
"colab_type": "code",
|
|
"id": "L3LmQxrgvau9",
|
|
"outputId": "afee78c4-6c74-4900-dd3b-53ad1c1b7b26"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.6329940557479858"
|
|
]
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"sentence = 'one of the greatest films i have ever seen in my life.'\n",
|
|
"\n",
|
|
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 35
|
|
},
|
|
"colab_type": "code",
|
|
"id": "t7Qoy21Bvb7v",
|
|
"outputId": "d85a8a1b-b4dc-4aea-e58e-2597087b46c2"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.060872383415699005"
|
|
]
|
|
},
|
|
"execution_count": 40,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"sentence = \"i thought it was going to be one of the greatest films i have ever seen in my life, \\\n",
|
|
"but it was actually the absolute worst movie of all time.\"\n",
|
|
"\n",
|
|
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 35
|
|
},
|
|
"colab_type": "code",
|
|
"id": "EPGXBr18vdQT",
|
|
"outputId": "1b28c7d1-9e12-462f-d9ac-2b4876b3b6b4"
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.07820437103509903"
|
|
]
|
|
},
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"sentence = \"i thought it was going to be the absolute worst movie of all time, \\\n",
|
|
"but it was actually one of the greatest films i have ever seen in my life.\"\n",
|
|
"\n",
|
|
"predict_sentiment(tokenizer, vocab, model, device, sentence)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"accelerator": "GPU",
|
|
"colab": {
|
|
"machine_shape": "hm",
|
|
"name": "scratchpad",
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 1
|
|
}
|