K-BERT/uer/targets/lm_target.py

# -*- encoding:utf-8 -*-
import math
import torch
import torch.nn as nn
from uer.layers.layer_norm import LayerNorm
from uer.utils.act_fun import gelu


class LmTarget(nn.Module):
    """
    Language modeling (LM) target: predicts every token of the target
    sequence with a full-softmax output layer over the vocabulary.
    """
    def __init__(self, args, vocab_size):
        super(LmTarget, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = args.hidden_size
        self.softmax = nn.LogSoftmax(dim=-1)
        self.output_layer = nn.Linear(self.hidden_size, self.vocab_size)

    def forward(self, memory_bank, tgt):
        """
        Args:
            memory_bank: [batch_size x seq_length x hidden_size]
            tgt: [batch_size x seq_length]

        Returns:
            loss: Language modeling loss.
            correct: Number of words that are predicted correctly.
            denominator: Number of predicted words.
        """
        # Language modeling (LM) with full softmax prediction.
        output = self.output_layer(memory_bank)
        output = output.contiguous().view(-1, self.vocab_size)
        # Full log-probability distribution over the vocabulary.
        output = self.softmax(output)

        tgt = tgt.contiguous().view(-1, 1)
        # Padding tokens (id 0) are excluded from the loss.
        label_mask = (tgt > 0).float().to(output.device)
        # One-hot encoding of the gold tokens.
        one_hot = torch.zeros(label_mask.size(0), self.vocab_size). \
                  to(output.device). \
                  scatter_(1, tgt, 1.0)

        # Negative log-likelihood of the gold token at each position,
        # averaged over the non-padding positions.
        numerator = -torch.sum(output * one_hot, 1)
        label_mask = label_mask.contiguous().view(-1)
        tgt = tgt.contiguous().view(-1)
        numerator = torch.sum(label_mask * numerator)
        denominator = torch.sum(label_mask) + 1e-6
        loss = numerator / denominator
        correct = torch.sum(label_mask * (output.argmax(dim=-1).eq(tgt)).float())

        return loss, correct, denominator
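

# A minimal usage sketch (not part of the original file): it builds LmTarget with
# a dummy `args` namespace and random tensors to show the expected input shapes.
# The hidden size, vocabulary size, batch size, and sequence length below are
# illustrative assumptions, not values taken from the K-BERT configuration.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(hidden_size=768)          # assumed hidden size
    target = LmTarget(args, vocab_size=21128)  # assumed vocabulary size

    memory_bank = torch.randn(2, 8, args.hidden_size)  # [batch_size x seq_length x hidden_size]
    tgt = torch.randint(0, 21128, (2, 8))              # [batch_size x seq_length], id 0 is padding
    loss, correct, denominator = target(memory_bank, tgt)
    print(loss.item(), correct.item(), denominator.item())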