# -*- encoding:utf-8 -*-
import math

import torch
import torch.nn as nn

from uer.layers.layer_norm import LayerNorm
from uer.utils.act_fun import gelu


class LmTarget(nn.Module):
    """
    Language modeling target: projects encoder states to vocabulary logits
    and computes the negative log-likelihood loss with a full softmax.
    """
    def __init__(self, args, vocab_size):
        super(LmTarget, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = args.hidden_size

        self.softmax = nn.LogSoftmax(dim=-1)
        self.output_layer = nn.Linear(self.hidden_size, self.vocab_size)

    def forward(self, memory_bank, tgt):
        """
        Args:
            memory_bank: [batch_size x seq_length x hidden_size]
            tgt: [batch_size x seq_length]

        Returns:
            loss: Language modeling loss.
            correct: Number of words that are predicted correctly.
            denominator: Number of predicted words.
        """
        # Language modeling (LM) with full softmax prediction.
        output = self.output_layer(memory_bank)
        output = output.contiguous().view(-1, self.vocab_size)
        # Full log-probability distribution over the vocabulary.
        output = self.softmax(output)

        # Positions whose target id is 0 (padding) are excluded from the loss.
        tgt = tgt.contiguous().view(-1, 1)
        label_mask = (tgt > 0).float().to(torch.device(output.device))
        # One-hot encode the targets to pick out their log-probabilities.
        one_hot = torch.zeros(label_mask.size(0), self.vocab_size). \
                  to(torch.device(output.device)). \
                  scatter_(1, tgt, 1.0)

        # Negative log-likelihood of each target token, averaged over
        # non-padding positions.
        numerator = -torch.sum(output * one_hot, 1)
        label_mask = label_mask.contiguous().view(-1)
        tgt = tgt.contiguous().view(-1)
        numerator = torch.sum(label_mask * numerator)
        denominator = torch.sum(label_mask) + 1e-6
        loss = numerator / denominator
        correct = torch.sum(label_mask * (output.argmax(dim=-1).eq(tgt)).float())

        return loss, correct, denominator
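

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged example of driving LmTarget on random tensors. The `args`
# namespace below is an assumption: in UER-py the training scripts build a
# full configuration object, but this module only reads `args.hidden_size`.
# Target id 0 is treated as padding and is excluded from the loss by the
# `tgt > 0` mask in forward().
if __name__ == "__main__":
    from argparse import Namespace

    import torch.nn.functional as F

    args = Namespace(hidden_size=768)                    # hypothetical minimal config
    target = LmTarget(args, vocab_size=30000)

    memory_bank = torch.randn(2, 16, args.hidden_size)   # [batch_size x seq_length x hidden_size]
    tgt = torch.randint(1, 30000, (2, 16))               # [batch_size x seq_length]
    tgt[:, -2:] = 0                                      # pretend the last two positions are padding

    loss, correct, denominator = target(memory_bank, tgt)
    print(loss.item(), correct.item(), denominator.item())

    # Sanity check of the masked one-hot NLL: it should agree (up to the 1e-6
    # smoothing term in the denominator) with nll_loss using ignore_index=0.
    log_probs = target.softmax(target.output_layer(memory_bank)).reshape(-1, 30000)
    reference = F.nll_loss(log_probs, tgt.reshape(-1), ignore_index=0)
    print(loss.item(), reference.item())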