add some layers
This commit is contained in:
parent 32d82648f1
commit d14f5e2047
@@ -1,309 +1,290 @@
|
||||
# _*_ coding:utf-8 _*_
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Author: Jie Yang
|
||||
# @Date: 2017-12-04 23:19:38
|
||||
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
|
||||
# @Last Modified time: 2018-05-27 22:48:17
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
START_TAG = -2
|
||||
STOP_TAG = -1
|
||||
|
||||
|
||||
# Compute log sum exp in a numerically stable way for the forward algorithm
|
||||
def log_sum_exp(vec, m_size):
|
||||
"""
|
||||
compute the log of the summed exponentials
|
||||
args:
|
||||
vec (batch_size, vanishing_dim, hidden_dim) : input tensor (b,t,t)
|
||||
m_size : hidden_dim
|
||||
return:
|
||||
batch_size, hidden_dim
|
||||
"""
|
||||
_, idx = torch.max(vec, 1) # B * 1 * M
|
||||
max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M
|
||||
|
||||
return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)
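# Illustrative check (editorial, not part of the original file; the helper name below is
# hypothetical). The max-shift above implements the identity
#   log(sum_j exp(v_j)) = max(v) + log(sum_j exp(v_j - max(v))),
# which avoids overflow when scores are large, and should match torch.logsumexp over dim=1.
def _demo_log_sum_exp():
    import torch
    vec = torch.randn(2, 5, 5) * 100.0        # large magnitudes to stress exp()
    ours = log_sum_exp(vec, 5)                 # (2, 5), reduces over dim=1
    ref = torch.logsumexp(vec, dim=1)          # (2, 5)
    assert torch.allclose(ours, ref, atol=1e-4)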
|
||||
|
||||
class CRF(nn.Module):
|
||||
"""Conditional random field.
|
||||
|
||||
This module implements a conditional random field [LMP01]_. The forward computation
|
||||
of this class computes the log likelihood of the given sequence of tags and
|
||||
emission score tensor. This class also has `~CRF.decode` method which finds
|
||||
the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
|
||||
|
||||
Args:
|
||||
num_tags: Number of tags.
|
||||
batch_first: Whether the first dimension corresponds to the size of a minibatch.
|
||||
|
||||
Attributes:
|
||||
start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
|
||||
``(num_tags,)``.
|
||||
end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
|
||||
``(num_tags,)``.
|
||||
transitions (`~torch.nn.Parameter`): Transition score tensor of size
|
||||
``(num_tags, num_tags)``.
|
||||
"""
|
||||
|
||||
def __init__(self, num_tags, batch_first):
|
||||
if num_tags <= 0:
|
||||
raise ValueError('invalid number of tags: {}'.format(num_tags))
|
||||
def __init__(self, tagset_size, gpu):
|
||||
super(CRF, self).__init__()
|
||||
self.num_tags = num_tags
|
||||
self.batch_first = batch_first
|
||||
self.start_transitions = nn.Parameter(torch.empty(num_tags))
|
||||
self.end_transitions = nn.Parameter(torch.empty(num_tags))
|
||||
self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
|
||||
print ("build batched crf...")
|
||||
self.gpu = gpu
|
||||
# Matrix of transition parameters. Entry i,j is the score of transitioning *to* i *from* j.
|
||||
self.average_batch = False
|
||||
self.tagset_size = tagset_size
|
||||
# # We add 2 here, because of START_TAG and STOP_TAG
|
||||
# # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag
|
||||
init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
|
||||
# init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
|
||||
# init_transitions[:,START_TAG] = -1000.0
|
||||
# init_transitions[STOP_TAG,:] = -1000.0
|
||||
# init_transitions[:,0] = -1000.0
|
||||
# init_transitions[0,:] = -1000.0
|
||||
if self.gpu:
|
||||
init_transitions = init_transitions.cuda()
|
||||
self.transitions = nn.Parameter(init_transitions) #(t+2,t+2)
|
||||
|
||||
self.reset_parameters()
|
||||
# self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2))
|
||||
# self.transitions.data.zero_()
|
||||
|
||||
def reset_parameters(self):
|
||||
"""Initialize the transition parameters.
|
||||
|
||||
The parameters will be initialized randomly from a uniform distribution
|
||||
between -0.1 and 0.1.
|
||||
def _calculate_PZ(self, feats, mask):
|
||||
"""
|
||||
nn.init.uniform_(self.start_transitions, -0.1, 0.1)
|
||||
nn.init.uniform_(self.end_transitions, -0.1, 0.1)
|
||||
nn.init.uniform_(self.transitions, -0.1, 0.1)
|
||||
|
||||
def __repr__(self):
|
||||
return 'num_tags={}'.format(self.num_tags)
|
||||
|
||||
def forward(self, emissions, tags, mask, reduction):
|
||||
"""Compute the conditional log likelihood of a sequence of tags given emission scores.
|
||||
|
||||
Args:
|
||||
emissions (`~torch.Tensor`): Emission score tensor of size
|
||||
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
|
||||
``(batch_size, seq_length, num_tags)`` otherwise.
|
||||
tags (`~torch.LongTensor`): Sequence of tags tensor of size
|
||||
``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
|
||||
``(batch_size, seq_length)`` otherwise.
|
||||
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
|
||||
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
|
||||
reduction: Specifies the reduction to apply to the output:
|
||||
``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
|
||||
``sum``: the output will be summed over batches. ``mean``: the output will be
|
||||
averaged over batches. ``token_mean``: the output will be averaged over tokens.
|
||||
|
||||
Returns:
|
||||
`~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
|
||||
reduction is ``none``, ``()`` otherwise.
|
||||
input:
|
||||
feats: (batch, seq_len, self.tag_size+2) (b,m,t+2)
|
||||
masks: (batch, seq_len) (b,m)
|
||||
"""
|
||||
self._validate(emissions, tags=tags, mask=mask)
|
||||
if reduction not in ('none', 'sum', 'mean', 'token_mean'):
|
||||
raise ValueError('invalid reduction: {}'.format(reduction))
|
||||
if mask is None:
|
||||
mask = torch.ones_like(tags, dtype=torch.uint8)
|
||||
batch_size = feats.size(0)
|
||||
seq_len = feats.size(1)
|
||||
tag_size = feats.size(2)
|
||||
# print feats.view(seq_len, tag_size)
|
||||
assert(tag_size == self.tagset_size+2)
|
||||
mask = mask.transpose(1,0).contiguous() #(m,b)
|
||||
ins_num = seq_len * batch_size
|
||||
## be careful with the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
|
||||
feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size)
|
||||
## need to consider start
|
||||
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
|
||||
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
|
||||
# build iter
|
||||
seq_iter = enumerate(scores) # (index,matrix) index is among the first dim: seqlen
|
||||
_, inivalues = seq_iter.__next__()
|
||||
# only need start from start_tag
|
||||
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)
|
||||
|
||||
if self.batch_first:
|
||||
emissions = emissions.transpose(0, 1)
|
||||
tags = tags.transpose(0, 1)
|
||||
mask = mask.transpose(0, 1)
|
||||
## add start score (from start to all tag, duplicate to batch_size)
|
||||
# partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
|
||||
# iter over last scores
|
||||
for idx, cur_values in seq_iter:
|
||||
# previous to_target is current from_target
|
||||
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
|
||||
# cur_values: bat_size * from_target * to_target
|
||||
|
||||
cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
|
||||
cur_partition = log_sum_exp(cur_values, tag_size) #(b,t)
|
||||
# print cur_partition.data
|
||||
|
||||
# (bat_size * from_target * to_target) -> (bat_size * to_target)
|
||||
# partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
|
||||
mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
|
||||
|
||||
## effectively update the partition: only keep the partition values where mask == 1
|
||||
masked_cur_partition = cur_partition.masked_select(mask_idx)
|
||||
## make mask_idx broadcastable to avoid a warning
|
||||
mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
|
||||
|
||||
# shape: (batch_size,)
|
||||
numerator = self._compute_score(emissions, tags, mask)
|
||||
# shape: (batch_size,)
|
||||
denominator = self._compute_normalizer(emissions, mask)
|
||||
# shape: (batch_size,)
|
||||
llh = numerator - denominator
|
||||
## replace the partition where mask == 1; the other partition values stay the same
|
||||
partition.masked_scatter_(mask_idx, masked_cur_partition)
|
||||
# until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG
|
||||
cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
|
||||
cur_partition = log_sum_exp(cur_values, tag_size) #(batch_size,hidden_dim)
|
||||
final_partition = cur_partition[:, STOP_TAG] #(batch_size)
|
||||
return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size)
|
||||
|
||||
if reduction == 'none':
|
||||
return llh
|
||||
if reduction == 'sum':
|
||||
return llh.sum()
|
||||
if reduction == 'mean':
|
||||
return llh.mean()
|
||||
assert reduction == 'token_mean'
|
||||
return llh.sum() / mask.float().sum()
|
||||
|
||||
def decode(self, emissions, mask=None):
|
||||
"""Find the most likely tag sequence using Viterbi algorithm.
|
||||
|
||||
Args:
|
||||
emissions (`~torch.Tensor`): Emission score tensor of size
|
||||
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
|
||||
``(batch_size, seq_length, num_tags)`` otherwise.
|
||||
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
|
||||
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
|
||||
|
||||
Returns:
|
||||
List of list containing the best tag sequence for each batch.
|
||||
def _viterbi_decode(self, feats, mask):
|
||||
"""
|
||||
self._validate(emissions, tags=None, mask=mask)
|
||||
if mask is None:
|
||||
mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
|
||||
input:
|
||||
feats: (batch, seq_len, self.tag_size+2)
|
||||
mask: (batch, seq_len)
|
||||
output:
|
||||
decode_idx: (batch, seq_len) decoded sequence
|
||||
path_score: (batch, 1) corresponding score for each sequence (to be implemented)
|
||||
"""
|
||||
batch_size = feats.size(0)
|
||||
seq_len = feats.size(1)
|
||||
tag_size = feats.size(2)
|
||||
assert(tag_size == self.tagset_size+2)
|
||||
## calculate sentence length for each sentence
|
||||
length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
|
||||
## mask to (seq_len, batch_size)
|
||||
mask = mask.transpose(1,0).contiguous() #(seq_len,b)
|
||||
ins_num = seq_len * batch_size
|
||||
## be careful with the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
|
||||
feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) #(ins_num, tag_size, tag_size)
|
||||
## need to consider start
|
||||
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
|
||||
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
|
||||
|
||||
if self.batch_first:
|
||||
emissions = emissions.transpose(0, 1)
|
||||
mask = mask.transpose(0, 1)
|
||||
# build iter
|
||||
seq_iter = enumerate(scores)
|
||||
## record the position of best score
|
||||
back_points = list()
|
||||
partition_history = list()
|
||||
|
||||
|
||||
## reverse the mask (mask = 1 - mask is buggy here; use this as an alternative)
|
||||
# mask = 1 + (-1)*mask
|
||||
mask = (1 - mask.long()).byte()
|
||||
_, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size
|
||||
# only need start from start_tag
|
||||
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size
|
||||
#print(partition.size())
|
||||
partition_history.append(partition) #(seqlen,batch_size,tag_size,1)
|
||||
# iter over last scores
|
||||
for idx, cur_values in seq_iter:
|
||||
# previous to_target is current from_target
|
||||
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
|
||||
# cur_values: batch_size * from_target * to_target
|
||||
cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
|
||||
## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG
|
||||
partition, cur_bp = torch.max(cur_values,dim=1)
|
||||
#print(partition.size())
|
||||
partition_history.append(partition.unsqueeze(2))
|
||||
## cur_bp: (batch_size, tag_size) max source score position in current tag
|
||||
## set padded label as 0, which will be filtered in post processing
|
||||
cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
|
||||
back_points.append(cur_bp)
|
||||
### add score to final STOP_TAG
|
||||
partition_history = torch.cat(partition_history,dim=0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size)
|
||||
### get the last position for each sentence, and select the last partitions using gather()
|
||||
last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
|
||||
last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
|
||||
### calculate the score from last partition to end state (and then select the STOP_TAG from it)
|
||||
last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
|
||||
_, last_bp = torch.max(last_values, 1) #(batch_size,tag_size)
|
||||
pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
|
||||
if self.gpu:
|
||||
pad_zero = pad_zero.cuda()
|
||||
back_points.append(pad_zero)
|
||||
back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
|
||||
|
||||
## select end ids in STOP_TAG
|
||||
pointer = last_bp[:, STOP_TAG] #(batch_size)
|
||||
insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
|
||||
back_points = back_points.transpose(1,0).contiguous() #(batch_size,sq_len,tag_size)
|
||||
## move the end ids (expanded to tag_size) to the corresponding positions in back_points to replace the 0 values
|
||||
# print "lp:",last_position
|
||||
# print "il:",insert_last
|
||||
back_points.scatter_(1, last_position, insert_last) ##(batch_size,sq_len,tag_size)
|
||||
# print "bp:",back_points
|
||||
# exit(0)
|
||||
back_points = back_points.transpose(1,0).contiguous() #(seq_len, batch_size, tag_size)
|
||||
## decode from the end; padded position ids are 0 and will be filtered in the following evaluation
|
||||
decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
|
||||
if self.gpu:
|
||||
decode_idx = decode_idx.cuda()
|
||||
decode_idx[-1] = pointer.data
|
||||
for idx in range(len(back_points)-2, -1, -1):
|
||||
pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) #pointer's size:(batch_size,1)
|
||||
decode_idx[idx] = pointer.squeeze(1).data
|
||||
path_score = None
|
||||
decode_idx = decode_idx.transpose(1,0) #(batch_size, sent_len)
|
||||
return path_score, decode_idx #
|
||||
|
||||
return self._viterbi_decode(emissions, mask)
|
||||
|
||||
def _validate(self, emissions, tags, mask):
|
||||
if emissions.dim() != 3:
|
||||
raise ValueError('emissions must have dimension of 3, got {}'.format(emissions.dim()))
|
||||
if emissions.size(2) != self.num_tags:
|
||||
raise ValueError(
|
||||
'expected last dimension of emissions is {}, got {}'.format(self.num_tags, emissions.size(2)))
|
||||
|
||||
if tags is not None:
|
||||
if emissions.shape[:2] != tags.shape:
|
||||
raise ValueError(
|
||||
'the first two dimensions of emissions and tags must match, got {} and {}'.format(
|
||||
tuple(emissions.shape[:2]), tuple(tags.shape)))
|
||||
def forward(self, feats):
|
||||
path_score, best_path = self._viterbi_decode(feats)
|
||||
return path_score, best_path
|
||||
|
||||
|
||||
if mask is not None:
|
||||
if emissions.shape[:2] != mask.shape:
|
||||
raise ValueError(
|
||||
'the first two dimensions of emissions and mask must match, got {} and {}'.format(
|
||||
tuple(emissions.shape[:2]), tuple(mask.shape)))
|
||||
no_empty_seq = not self.batch_first and mask[0].all()
|
||||
no_empty_seq_bf = self.batch_first and mask[:, 0].all()
|
||||
if not no_empty_seq and not no_empty_seq_bf:
|
||||
raise ValueError('mask of the first timestep must all be on')
|
||||
def _score_sentence(self, scores, mask, tags):
|
||||
"""
|
||||
input:
|
||||
scores: variable (seq_len, batch, tag_size, tag_size)
|
||||
mask: (batch, seq_len)
|
||||
tags: tensor (batch, seq_len)
|
||||
output:
|
||||
score: sum of score for gold sequences within whole batch
|
||||
"""
|
||||
# Gives the score of a provided tag sequence
|
||||
batch_size = scores.size(1)
|
||||
seq_len = scores.size(0)
|
||||
tag_size = scores.size(2)
|
||||
## convert tag values into a new format that encodes label bigram information as an index
|
||||
new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
|
||||
if self.gpu:
|
||||
new_tags = new_tags.cuda()
|
||||
for idx in range(seq_len):
|
||||
if idx == 0:
|
||||
## start -> first score
|
||||
new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0]
|
||||
|
||||
def _compute_score(
|
||||
self, emissions, tags, mask):
|
||||
# emissions: (seq_length, batch_size, num_tags)
|
||||
# tags: (seq_length, batch_size)
|
||||
# mask: (seq_length, batch_size)
|
||||
assert emissions.dim() == 3 and tags.dim() == 2
|
||||
assert emissions.shape[:2] == tags.shape
|
||||
assert emissions.size(2) == self.num_tags
|
||||
assert mask.shape == tags.shape
|
||||
assert mask[0].all()
|
||||
else:
|
||||
new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx]
|
||||
|
||||
seq_length, batch_size = tags.shape
|
||||
mask = mask.float()
|
||||
## transition for label to STOP_TAG
|
||||
end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
|
||||
## length for batch, last word position = length - 1
|
||||
length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
|
||||
## index the label id of last word
|
||||
end_ids = torch.gather(tags, 1, length_mask - 1)
|
||||
|
||||
# Start transition score and first emission
|
||||
# shape: (batch_size,)
|
||||
score = self.start_transitions[tags[0]]
|
||||
score += emissions[0, torch.arange(batch_size), tags[0]]
|
||||
## index the transition score for end_id to STOP_TAG
|
||||
end_energy = torch.gather(end_transition, 1, end_ids)
|
||||
|
||||
for i in range(1, seq_length):
|
||||
# Transition score to next tag, only added if next timestep is valid (mask == 1)
|
||||
# shape: (batch_size,)
|
||||
score += self.transitions[tags[i - 1], tags[i]] * mask[i]
|
||||
## convert tag as (seq_len, batch_size, 1)
|
||||
new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)
|
||||
### need convert tags id to search from 400 positions of scores
|
||||
tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size
|
||||
## mask transpose to (seq_len, batch_size)
|
||||
tg_energy = tg_energy.masked_select(mask.transpose(1,0))
|
||||
|
||||
# ## calculate the score from START_TAG to first label
|
||||
# start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
|
||||
# start_energy = torch.gather(start_transition, 1, tags[0,:])
|
||||
|
||||
# Emission score for next tag, only added if next timestep is valid (mask == 1)
|
||||
# shape: (batch_size,)
|
||||
score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
|
||||
## add all score together
|
||||
# gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
|
||||
gold_score = tg_energy.sum() + end_energy.sum()
|
||||
return gold_score
|
||||
|
||||
# End transition score
|
||||
# shape: (batch_size,)
|
||||
seq_ends = mask.long().sum(dim=0) - 1
|
||||
# shape: (batch_size,)
|
||||
last_tags = tags[seq_ends, torch.arange(batch_size)]
|
||||
# shape: (batch_size,)
|
||||
score += self.end_transitions[last_tags]
|
||||
def neg_log_likelihood_loss(self, feats, mask, tags):
|
||||
# negative log likelihood
|
||||
batch_size = feats.size(0)
|
||||
forward_score, scores = self._calculate_PZ(feats, mask) #forward_score:long, scores: (seq_len, batch, tag_size, tag_size)
|
||||
gold_score = self._score_sentence(scores, mask, tags)
|
||||
#print ("batch, f:", forward_score.data, " g:", gold_score.data, " dis:", forward_score.data - gold_score.data)
|
||||
# exit(0)
|
||||
if self.average_batch:
|
||||
return (forward_score - gold_score)/batch_size
|
||||
else:
|
||||
return forward_score - gold_score
|
||||
|
||||
return score
|
||||
|
||||
def _compute_normalizer(self, emissions, mask):
|
||||
# emissions: (seq_length, batch_size, num_tags)
|
||||
# mask: (seq_length, batch_size)
|
||||
assert emissions.dim() == 3 and mask.dim() == 2
|
||||
assert emissions.shape[:2] == mask.shape
|
||||
assert emissions.size(2) == self.num_tags
|
||||
assert mask[0].all()
|
||||
|
||||
seq_length = emissions.size(0)
|
||||
|
||||
# Start transition score and first emission; score has size of
|
||||
# (batch_size, num_tags) where for each batch, the j-th column stores
|
||||
# the score that the first timestep has tag j
|
||||
# shape: (batch_size, num_tags)
|
||||
score = self.start_transitions + emissions[0]
|
||||
|
||||
for i in range(1, seq_length):
|
||||
# Broadcast score for every possible next tag
|
||||
# shape: (batch_size, num_tags, 1)
|
||||
broadcast_score = score.unsqueeze(2)
|
||||
|
||||
# Broadcast emission score for every possible current tag
|
||||
# shape: (batch_size, 1, num_tags)
|
||||
broadcast_emissions = emissions[i].unsqueeze(1)
|
||||
|
||||
# Compute the score tensor of size (batch_size, num_tags, num_tags) where
|
||||
# for each sample, entry at row i and column j stores the sum of scores of all
|
||||
# possible tag sequences so far that end with transitioning from tag i to tag j
|
||||
# and emitting
|
||||
# shape: (batch_size, num_tags, num_tags)
|
||||
next_score = broadcast_score + self.transitions + broadcast_emissions
|
||||
|
||||
# Sum over all possible current tags, but we're in score space, so a sum
|
||||
# becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
|
||||
# all possible tag sequences so far, that end in tag i
|
||||
# shape: (batch_size, num_tags)
|
||||
next_score = torch.logsumexp(next_score, dim=1)
|
||||
|
||||
# Set score to the next score if this timestep is valid (mask == 1)
|
||||
# shape: (batch_size, num_tags)
|
||||
score = torch.where(mask[i].unsqueeze(1), next_score, score)
|
||||
|
||||
# End transition score
|
||||
# shape: (batch_size, num_tags)
|
||||
score += self.end_transitions
|
||||
|
||||
# Sum (log-sum-exp) over all possible tags
|
||||
# shape: (batch_size,)
|
||||
return torch.logsumexp(score, dim=1)
|
||||
|
||||
def _viterbi_decode(self, emissions, mask):
|
||||
# emissions: (seq_length, batch_size, num_tags)
|
||||
# mask: (seq_length, batch_size)
|
||||
assert emissions.dim() == 3 and mask.dim() == 2
|
||||
assert emissions.shape[:2] == mask.shape
|
||||
assert emissions.size(2) == self.num_tags
|
||||
assert mask[0].all()
|
||||
|
||||
seq_length, batch_size = mask.shape
|
||||
|
||||
# Start transition and first emission
|
||||
# shape: (batch_size, num_tags)
|
||||
score = self.start_transitions + emissions[0]
|
||||
history = []
|
||||
|
||||
# score is a tensor of size (batch_size, num_tags) where for every batch,
|
||||
# value at column j stores the score of the best tag sequence so far that ends
|
||||
# with tag j
|
||||
# history saves where the best tags candidate transitioned from; this is used
|
||||
# when we trace back the best tag sequence
|
||||
|
||||
# Viterbi algorithm recursive case: we compute the score of the best tag sequence
|
||||
# for every possible next tag
|
||||
for i in range(1, seq_length):
|
||||
# Broadcast viterbi score for every possible next tag
|
||||
# shape: (batch_size, num_tags, 1)
|
||||
broadcast_score = score.unsqueeze(2)
|
||||
|
||||
# Broadcast emission score for every possible current tag
|
||||
# shape: (batch_size, 1, num_tags)
|
||||
broadcast_emission = emissions[i].unsqueeze(1)
|
||||
|
||||
# Compute the score tensor of size (batch_size, num_tags, num_tags) where
|
||||
# for each sample, entry at row i and column j stores the score of the best
|
||||
# tag sequence so far that ends with transitioning from tag i to tag j and emitting
|
||||
# shape: (batch_size, num_tags, num_tags)
|
||||
next_score = broadcast_score + self.transitions + broadcast_emission
|
||||
|
||||
# Find the maximum score over all possible current tag
|
||||
# shape: (batch_size, num_tags)
|
||||
next_score, indices = next_score.max(dim=1)
|
||||
|
||||
# Set score to the next score if this timestep is valid (mask == 1)
|
||||
# and save the index that produces the next score
|
||||
# shape: (batch_size, num_tags)
|
||||
score = torch.where(mask[i].unsqueeze(1), next_score, score)
|
||||
history.append(indices)
|
||||
|
||||
# End transition score
|
||||
# shape: (batch_size, num_tags)
|
||||
score += self.end_transitions
|
||||
|
||||
# Now, compute the best path for each sample
|
||||
|
||||
# shape: (batch_size,)
|
||||
seq_ends = mask.long().sum(dim=0) - 1
|
||||
best_tags_list = []
|
||||
|
||||
for idx in range(batch_size):
|
||||
# Find the tag which maximizes the score at the last timestep; this is our best tag
|
||||
# for the last timestep
|
||||
_, best_last_tag = score[idx].max(dim=0)
|
||||
best_tags = [best_last_tag.item()]
|
||||
|
||||
# We trace back where the best last tag comes from, append that to our best tag
|
||||
# sequence, and trace it back again, and so on
|
||||
for hist in reversed(history[:seq_ends[idx]]):
|
||||
best_last_tag = hist[idx][best_tags[-1]]
|
||||
best_tags.append(best_last_tag.item())
|
||||
|
||||
# Reverse the order because we start from the last timestep
|
||||
best_tags.reverse()
|
||||
best_tags_list.append(best_tags)
|
||||
|
||||
return best_tags_list
|
||||
|
309 layers/decoders/pytorch_crf.py (new file)
@@ -0,0 +1,309 @@
|
||||
# _*_ coding:utf-8 _*_
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class CRF(nn.Module):
|
||||
"""Conditional random field.
|
||||
|
||||
This module implements a conditional random field [LMP01]_. The forward computation
|
||||
of this class computes the log likelihood of the given sequence of tags and
|
||||
emission score tensor. This class also has `~CRF.decode` method which finds
|
||||
the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
|
||||
|
||||
Args:
|
||||
num_tags: Number of tags.
|
||||
batch_first: Whether the first dimension corresponds to the size of a minibatch.
|
||||
|
||||
Attributes:
|
||||
start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
|
||||
``(num_tags,)``.
|
||||
end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
|
||||
``(num_tags,)``.
|
||||
transitions (`~torch.nn.Parameter`): Transition score tensor of size
|
||||
``(num_tags, num_tags)``.
|
||||
"""
|
||||
|
||||
def __init__(self, num_tags, batch_first):
|
||||
if num_tags <= 0:
|
||||
raise ValueError('invalid number of tags: {}'.format(num_tags))
|
||||
super(CRF, self).__init__()
|
||||
self.num_tags = num_tags
|
||||
self.batch_first = batch_first
|
||||
self.start_transitions = nn.Parameter(torch.empty(num_tags))
|
||||
self.end_transitions = nn.Parameter(torch.empty(num_tags))
|
||||
self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
"""Initialize the transition parameters.
|
||||
|
||||
The parameters will be initialized randomly from a uniform distribution
|
||||
between -0.1 and 0.1.
|
||||
"""
|
||||
nn.init.uniform_(self.start_transitions, -0.1, 0.1)
|
||||
nn.init.uniform_(self.end_transitions, -0.1, 0.1)
|
||||
nn.init.uniform_(self.transitions, -0.1, 0.1)
|
||||
|
||||
def __repr__(self):
|
||||
return 'num_tags={}'.format(self.num_tags)
|
||||
|
||||
def forward(self, emissions, tags, mask, reduction):
|
||||
"""Compute the conditional log likelihood of a sequence of tags given emission scores.
|
||||
|
||||
Args:
|
||||
emissions (`~torch.Tensor`): Emission score tensor of size
|
||||
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
|
||||
``(batch_size, seq_length, num_tags)`` otherwise.
|
||||
tags (`~torch.LongTensor`): Sequence of tags tensor of size
|
||||
``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
|
||||
``(batch_size, seq_length)`` otherwise.
|
||||
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
|
||||
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
|
||||
reduction: Specifies the reduction to apply to the output:
|
||||
``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
|
||||
``sum``: the output will be summed over batches. ``mean``: the output will be
|
||||
averaged over batches. ``token_mean``: the output will be averaged over tokens.
|
||||
|
||||
Returns:
|
||||
`~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
|
||||
reduction is ``none``, ``()`` otherwise.
|
||||
"""
|
||||
self._validate(emissions, tags=tags, mask=mask)
|
||||
if reduction not in ('none', 'sum', 'mean', 'token_mean'):
|
||||
raise ValueError('invalid reduction: {}'.format(reduction))
|
||||
if mask is None:
|
||||
mask = torch.ones_like(tags, dtype=torch.uint8)
|
||||
|
||||
if self.batch_first:
|
||||
emissions = emissions.transpose(0, 1)
|
||||
tags = tags.transpose(0, 1)
|
||||
mask = mask.transpose(0, 1)
|
||||
|
||||
# shape: (batch_size,)
|
||||
numerator = self._compute_score(emissions, tags, mask)
|
||||
# shape: (batch_size,)
|
||||
denominator = self._compute_normalizer(emissions, mask)
|
||||
# shape: (batch_size,)
|
||||
llh = numerator - denominator
|
||||
|
||||
if reduction == 'none':
|
||||
return llh
|
||||
if reduction == 'sum':
|
||||
return llh.sum()
|
||||
if reduction == 'mean':
|
||||
return llh.mean()
|
||||
assert reduction == 'token_mean'
|
||||
return llh.sum() / mask.float().sum()
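# Editorial note (not in the original file): forward() returns the log likelihood, so a
# training loss is its negation, e.g. loss = -crf(emissions, tags, mask, reduction='mean').
# With reduction='none' the result keeps shape (batch_size,), one value per sequence.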
|
||||
|
||||
def decode(self, emissions, mask=None):
|
||||
"""Find the most likely tag sequence using Viterbi algorithm.
|
||||
|
||||
Args:
|
||||
emissions (`~torch.Tensor`): Emission score tensor of size
|
||||
``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
|
||||
``(batch_size, seq_length, num_tags)`` otherwise.
|
||||
mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
|
||||
if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
|
||||
|
||||
Returns:
|
||||
List of list containing the best tag sequence for each batch.
|
||||
"""
|
||||
self._validate(emissions, tags=None, mask=mask)
|
||||
if mask is None:
|
||||
mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
|
||||
|
||||
if self.batch_first:
|
||||
emissions = emissions.transpose(0, 1)
|
||||
mask = mask.transpose(0, 1)
|
||||
|
||||
return self._viterbi_decode(emissions, mask)
|
||||
|
||||
def _validate(self, emissions, tags, mask):
|
||||
if emissions.dim() != 3:
|
||||
raise ValueError('emissions must have dimension of 3, got {}'.format(emissions.dim()))
|
||||
if emissions.size(2) != self.num_tags:
|
||||
raise ValueError(
|
||||
'expected last dimension of emissions is {}, got {}'.format(self.num_tags, emissions.size(2)))
|
||||
|
||||
if tags is not None:
|
||||
if emissions.shape[:2] != tags.shape:
|
||||
raise ValueError(
|
||||
'the first two dimensions of emissions and tags must match, got {} and {}'.format(
|
||||
tuple(emissions.shape[:2]), tuple(tags.shape)))
|
||||
|
||||
if mask is not None:
|
||||
if emissions.shape[:2] != mask.shape:
|
||||
raise ValueError(
|
||||
'the first two dimensions of emissions and mask must match, got {} and {}'.format(
|
||||
tuple(emissions.shape[:2]), tuple(mask.shape)))
|
||||
no_empty_seq = not self.batch_first and mask[0].all()
|
||||
no_empty_seq_bf = self.batch_first and mask[:, 0].all()
|
||||
if not no_empty_seq and not no_empty_seq_bf:
|
||||
raise ValueError('mask of the first timestep must all be on')
|
||||
|
||||
def _compute_score(
|
||||
self, emissions, tags, mask):
|
||||
# emissions: (seq_length, batch_size, num_tags)
|
||||
# tags: (seq_length, batch_size)
|
||||
# mask: (seq_length, batch_size)
|
||||
assert emissions.dim() == 3 and tags.dim() == 2
|
||||
assert emissions.shape[:2] == tags.shape
|
||||
assert emissions.size(2) == self.num_tags
|
||||
assert mask.shape == tags.shape
|
||||
assert mask[0].all()
|
||||
|
||||
seq_length, batch_size = tags.shape
|
||||
mask = mask.float()
|
||||
|
||||
# Start transition score and first emission
|
||||
# shape: (batch_size,)
|
||||
score = self.start_transitions[tags[0]]
|
||||
score += emissions[0, torch.arange(batch_size), tags[0]]
|
||||
|
||||
for i in range(1, seq_length):
|
||||
# Transition score to next tag, only added if next timestep is valid (mask == 1)
|
||||
# shape: (batch_size,)
|
||||
score += self.transitions[tags[i - 1], tags[i]] * mask[i]
|
||||
|
||||
# Emission score for next tag, only added if next timestep is valid (mask == 1)
|
||||
# shape: (batch_size,)
|
||||
score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
|
||||
|
||||
# End transition score
|
||||
# shape: (batch_size,)
|
||||
seq_ends = mask.long().sum(dim=0) - 1
|
||||
# shape: (batch_size,)
|
||||
last_tags = tags[seq_ends, torch.arange(batch_size)]
|
||||
# shape: (batch_size,)
|
||||
score += self.end_transitions[last_tags]
|
||||
|
||||
return score
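# Worked example (editorial, not in the original file): for one sequence whose mask is all
# ones and whose tags are [t1, t2, t3], the score returned above is
#   start_transitions[t1] + emissions[0, t1]
#   + transitions[t1, t2] + emissions[1, t2]
#   + transitions[t2, t3] + emissions[2, t3]
#   + end_transitions[t3]
# i.e. the unnormalized log-score of that tag path under the linear-chain CRF.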
|
||||
|
||||
def _compute_normalizer(self, emissions, mask):
|
||||
# emissions: (seq_length, batch_size, num_tags)
|
||||
# mask: (seq_length, batch_size)
|
||||
assert emissions.dim() == 3 and mask.dim() == 2
|
||||
assert emissions.shape[:2] == mask.shape
|
||||
assert emissions.size(2) == self.num_tags
|
||||
assert mask[0].all()
|
||||
|
||||
seq_length = emissions.size(0)
|
||||
|
||||
# Start transition score and first emission; score has size of
|
||||
# (batch_size, num_tags) where for each batch, the j-th column stores
|
||||
# the score that the first timestep has tag j
|
||||
# shape: (batch_size, num_tags)
|
||||
score = self.start_transitions + emissions[0]
|
||||
|
||||
for i in range(1, seq_length):
|
||||
# Broadcast score for every possible next tag
|
||||
# shape: (batch_size, num_tags, 1)
|
||||
broadcast_score = score.unsqueeze(2)
|
||||
|
||||
# Broadcast emission score for every possible current tag
|
||||
# shape: (batch_size, 1, num_tags)
|
||||
broadcast_emissions = emissions[i].unsqueeze(1)
|
||||
|
||||
# Compute the score tensor of size (batch_size, num_tags, num_tags) where
|
||||
# for each sample, entry at row i and column j stores the sum of scores of all
|
||||
# possible tag sequences so far that end with transitioning from tag i to tag j
|
||||
# and emitting
|
||||
# shape: (batch_size, num_tags, num_tags)
|
||||
next_score = broadcast_score + self.transitions + broadcast_emissions
|
||||
|
||||
# Sum over all possible current tags, but we're in score space, so a sum
|
||||
# becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
|
||||
# all possible tag sequences so far, that end in tag i
|
||||
# shape: (batch_size, num_tags)
|
||||
next_score = torch.logsumexp(next_score, dim=1)
|
||||
|
||||
# Set score to the next score if this timestep is valid (mask == 1)
|
||||
# shape: (batch_size, num_tags)
|
||||
score = torch.where(mask[i].unsqueeze(1), next_score, score)
|
||||
|
||||
# End transition score
|
||||
# shape: (batch_size, num_tags)
|
||||
score += self.end_transitions
|
||||
|
||||
# Sum (log-sum-exp) over all possible tags
|
||||
# shape: (batch_size,)
|
||||
return torch.logsumexp(score, dim=1)
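# Editorial note (not in the original file): the loop above is the forward algorithm in
# log space. Writing alpha_t for `score` at timestep t,
#   alpha_0[j] = start_transitions[j] + emissions[0, j]
#   alpha_t[j] = logsumexp_i(alpha_{t-1}[i] + transitions[i, j]) + emissions[t, j]
# and the returned normalizer is logsumexp_j(alpha_T[j] + end_transitions[j]).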
|
||||
|
||||
def _viterbi_decode(self, emissions, mask):
|
||||
# emissions: (seq_length, batch_size, num_tags)
|
||||
# mask: (seq_length, batch_size)
|
||||
assert emissions.dim() == 3 and mask.dim() == 2
|
||||
assert emissions.shape[:2] == mask.shape
|
||||
assert emissions.size(2) == self.num_tags
|
||||
assert mask[0].all()
|
||||
|
||||
seq_length, batch_size = mask.shape
|
||||
|
||||
# Start transition and first emission
|
||||
# shape: (batch_size, num_tags)
|
||||
score = self.start_transitions + emissions[0]
|
||||
history = []
|
||||
|
||||
# score is a tensor of size (batch_size, num_tags) where for every batch,
|
||||
# value at column j stores the score of the best tag sequence so far that ends
|
||||
# with tag j
|
||||
# history saves where the best tags candidate transitioned from; this is used
|
||||
# when we trace back the best tag sequence
|
||||
|
||||
# Viterbi algorithm recursive case: we compute the score of the best tag sequence
|
||||
# for every possible next tag
|
||||
for i in range(1, seq_length):
|
||||
# Broadcast viterbi score for every possible next tag
|
||||
# shape: (batch_size, num_tags, 1)
|
||||
broadcast_score = score.unsqueeze(2)
|
||||
|
||||
# Broadcast emission score for every possible current tag
|
||||
# shape: (batch_size, 1, num_tags)
|
||||
broadcast_emission = emissions[i].unsqueeze(1)
|
||||
|
||||
# Compute the score tensor of size (batch_size, num_tags, num_tags) where
|
||||
# for each sample, entry at row i and column j stores the score of the best
|
||||
# tag sequence so far that ends with transitioning from tag i to tag j and emitting
|
||||
# shape: (batch_size, num_tags, num_tags)
|
||||
next_score = broadcast_score + self.transitions + broadcast_emission
|
||||
|
||||
# Find the maximum score over all possible current tag
|
||||
# shape: (batch_size, num_tags)
|
||||
next_score, indices = next_score.max(dim=1)
|
||||
|
||||
# Set score to the next score if this timestep is valid (mask == 1)
|
||||
# and save the index that produces the next score
|
||||
# shape: (batch_size, num_tags)
|
||||
score = torch.where(mask[i].unsqueeze(1), next_score, score)
|
||||
history.append(indices)
|
||||
|
||||
# End transition score
|
||||
# shape: (batch_size, num_tags)
|
||||
score += self.end_transitions
|
||||
|
||||
# Now, compute the best path for each sample
|
||||
|
||||
# shape: (batch_size,)
|
||||
seq_ends = mask.long().sum(dim=0) - 1
|
||||
best_tags_list = []
|
||||
|
||||
for idx in range(batch_size):
|
||||
# Find the tag which maximizes the score at the last timestep; this is our best tag
|
||||
# for the last timestep
|
||||
_, best_last_tag = score[idx].max(dim=0)
|
||||
best_tags = [best_last_tag.item()]
|
||||
|
||||
# We trace back where the best last tag comes from, append that to our best tag
|
||||
# sequence, and trace it back again, and so on
|
||||
for hist in reversed(history[:seq_ends[idx]]):
|
||||
best_last_tag = hist[idx][best_tags[-1]]
|
||||
best_tags.append(best_last_tag.item())
|
||||
|
||||
# Reverse the order because we start from the last timestep
|
||||
best_tags.reverse()
|
||||
best_tags_list.append(best_tags)
|
||||
|
||||
return best_tags_list
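# Illustrative usage sketch (editorial, not part of the committed file). The sizes and
# variable names below are made up for the example; shapes follow the docstrings above
# with batch_first=False, i.e. (seq_length, batch_size, ...). The mask dtype follows the
# docstring's ByteTensor; on newer PyTorch use torch.bool instead.
if __name__ == '__main__':
    import torch
    num_tags, seq_length, batch_size = 5, 7, 2
    crf = CRF(num_tags, batch_first=False)
    emissions = torch.randn(seq_length, batch_size, num_tags)   # e.g. encoder outputs
    tags = torch.randint(0, num_tags, (seq_length, batch_size))
    mask = torch.ones(seq_length, batch_size, dtype=torch.uint8)
    loss = -crf(emissions, tags, mask, reduction='mean')        # negative log likelihood
    best_paths = crf.decode(emissions, mask)                    # list of tag-id lists
    print(loss.item(), best_paths)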
|
241 layers/encoders/ner_layers.py (new file)
@@ -0,0 +1,241 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import math, copy, time
|
||||
|
||||
class CNNmodel(nn.Module):
|
||||
def __init__(self, input_dim, hidden_dim, num_layer, dropout, gpu=True):
|
||||
super(CNNmodel, self).__init__()
|
||||
self.input_dim = input_dim
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_layer = num_layer
|
||||
self.gpu = gpu
|
||||
|
||||
self.cnn_layer0 = nn.Conv1d(self.input_dim, self.hidden_dim, kernel_size=1, padding=0)
|
||||
self.cnn_layers = [nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) for i in range(self.num_layer-1)]
|
||||
self.drop = nn.Dropout(dropout)
|
||||
|
||||
if self.gpu:
|
||||
self.cnn_layer0 = self.cnn_layer0.cuda()
|
||||
for i in range(self.num_layer-1):
|
||||
self.cnn_layers[i] = self.cnn_layers[i].cuda()
|
||||
|
||||
def forward(self, input_feature):
|
||||
|
||||
batch_size = input_feature.size(0)
|
||||
seq_len = input_feature.size(1)
|
||||
|
||||
input_feature = input_feature.transpose(2,1).contiguous()
|
||||
cnn_output = self.cnn_layer0(input_feature) #(b,h,l)
|
||||
cnn_output = self.drop(cnn_output)
|
||||
cnn_output = torch.tanh(cnn_output)
|
||||
|
||||
for layer in range(self.num_layer-1):
|
||||
cnn_output = self.cnn_layers[layer](cnn_output)
|
||||
cnn_output = self.drop(cnn_output)
|
||||
cnn_output = torch.tanh(cnn_output)
|
||||
|
||||
cnn_output = cnn_output.transpose(2,1).contiguous()
|
||||
return cnn_output
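# Editorial note (not part of the committed file): `self.cnn_layers` above is a plain
# Python list, so those Conv1d layers are not registered as submodules and will not show
# up in CNNmodel.parameters() (and hence will not be updated by an optimizer built from
# it). A sketch of the usual fix is to wrap them in nn.ModuleList, e.g.
#     self.cnn_layers = nn.ModuleList(
#         [nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1)
#          for _ in range(self.num_layer - 1)])
# after which the explicit per-layer .cuda() calls are no longer needed.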
|
||||
|
||||
|
||||
|
||||
def clones(module, N):
|
||||
"Produce N identical layers."
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
"Construct a layernorm module (See citation for details)."
|
||||
def __init__(self, features, eps=1e-6):
|
||||
super(LayerNorm, self).__init__()
|
||||
self.a_2 = nn.Parameter(torch.ones(features))
|
||||
self.b_2 = nn.Parameter(torch.zeros(features))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
mean = x.mean(-1, keepdim=True)
|
||||
std = x.std(-1, keepdim=True)
|
||||
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
|
||||
|
||||
class SublayerConnection(nn.Module):
|
||||
"""
|
||||
A residual connection followed by a layer norm.
|
||||
Note for code simplicity the norm is first as opposed to last.
|
||||
"""
|
||||
def __init__(self, size, dropout):
|
||||
super(SublayerConnection, self).__init__()
|
||||
self.norm = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x, sublayer):
|
||||
"Apply residual connection to any sublayer with the same size."
|
||||
return x + self.dropout(sublayer(self.norm(x)))
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"Encoder is made up of self-attn and feed forward (defined below)"
|
||||
def __init__(self, size, self_attn, feed_forward, dropout):
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.sublayer = clones(SublayerConnection(size, dropout), 2)
|
||||
self.size = size
|
||||
|
||||
def forward(self, x, mask):
|
||||
"Follow Figure 1 (left) for connections."
|
||||
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
|
||||
return self.sublayer[1](x, self.feed_forward)
|
||||
|
||||
|
||||
def attention(query, key, value, mask=None, dropout=None):
|
||||
"Compute 'Scaled Dot Product Attention'"
|
||||
d_k = query.size(-1)
|
||||
scores = torch.matmul(query, key.transpose(-2, -1)) \
|
||||
/ math.sqrt(d_k) ## (b,h,l,d) * (b,h,d,l)
|
||||
if mask is not None:
|
||||
# scores = scores.masked_fill(mask == 0, -1e9)
|
||||
scores = scores.masked_fill(mask, -1e9)
|
||||
p_attn = F.softmax(scores, dim = -1)
|
||||
if dropout is not None:
|
||||
p_attn = dropout(p_attn)
|
||||
return torch.matmul(p_attn, value), p_attn ##(b,h,l,l) * (b,h,l,d) = (b,h,l,d)
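# Illustrative example (editorial, not part of the committed file; the helper name is
# hypothetical). Note the mask convention in attention() above: positions where `mask`
# is True (nonzero) are filled with -1e9, i.e. the mask marks padding to be ignored,
# unlike the commented-out `mask == 0` variant.
def _demo_attention():
    import torch
    b, h, l, d_k = 2, 4, 6, 8
    q = torch.randn(b, h, l, d_k)
    k = torch.randn(b, h, l, d_k)
    v = torch.randn(b, h, l, d_k)
    pad = torch.zeros(b, 1, 1, l, dtype=torch.bool)   # True = padded position
    out, p_attn = attention(q, k, v, mask=pad)
    assert out.shape == (b, h, l, d_k) and p_attn.shape == (b, h, l, l)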
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
def __init__(self, h, d_model, dropout=0.1):
|
||||
"Take in model size and number of heads."
|
||||
super(MultiHeadedAttention, self).__init__()
|
||||
assert d_model % h == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = d_model // h
|
||||
self.h = h
|
||||
self.linears = clones(nn.Linear(d_model, d_model), 4)
|
||||
self.attn = None
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
def forward(self, query, key, value, mask=None):
|
||||
"Implements Figure 2"
|
||||
if mask is not None:
|
||||
# Same mask applied to all h heads.
|
||||
mask = mask.unsqueeze(1)
|
||||
nbatches = query.size(0)
|
||||
|
||||
# 1) Do all the linear projections in batch from d_model => h x d_k
|
||||
query, key, value = \
|
||||
[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
|
||||
for l, x in zip(self.linears, (query, key, value))]
|
||||
|
||||
# 2) Apply attention on all the projected vectors in batch.
|
||||
x, self.attn = attention(query, key, value, mask=mask,
|
||||
dropout=self.dropout)
|
||||
|
||||
# 3) "Concat" using a view and apply a final linear.
|
||||
x = x.transpose(1, 2).contiguous() \
|
||||
.view(nbatches, -1, self.h * self.d_k)
|
||||
return self.linears[-1](x)
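# Editorial shape note (not in the original file): with h heads and d_model features,
# MultiHeadedAttention maps (nbatches, seq_len, d_model) -> (nbatches, seq_len, d_model);
# each head attends over d_k = d_model // h features. A mask passed in as
# (nbatches, 1, seq_len) or (nbatches, seq_len, seq_len) is unsqueezed once so it
# broadcasts across the head dimension.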
|
||||
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
|
||||
"Implements FFN equation."
|
||||
def __init__(self, d_model, d_ff, dropout=0.1):
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.w_1 = nn.Linear(d_model, d_ff)
|
||||
self.w_2 = nn.Linear(d_ff, d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w_2(self.dropout(F.relu(self.w_1(x))))
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"Implement the PE function."
|
||||
def __init__(self, d_model, dropout, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
# Compute the positional encodings once in log space.
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0., max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0., d_model, 2) *
|
||||
-(math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + autograd.Variable(self.pe[:, :x.size(1)],
|
||||
requires_grad=False)
|
||||
return self.dropout(x)
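# Editorial note (not in the original file): the buffer computed above is the standard
# sinusoidal encoding
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
# stored with a leading batch dimension so that forward() can slice pe[:, :seq_len] and
# add it to the (batch, seq_len, d_model) input before dropout.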
|
||||
|
||||
|
||||
class AttentionModel(nn.Module):
|
||||
"Core encoder is a stack of N layers"
|
||||
def __init__(self, d_input, d_model, d_ff, head, num_layer, dropout):
|
||||
super(AttentionModel, self).__init__()
|
||||
c = copy.deepcopy
|
||||
# attn0 = MultiHeadedAttention(head, d_input, d_model)
|
||||
attn = MultiHeadedAttention(head, d_model, dropout)
|
||||
ff = PositionwiseFeedForward(d_model, d_ff, dropout)
|
||||
# position = PositionalEncoding(d_model, dropout)
|
||||
# layer0 = EncoderLayer(d_model, c(attn0), c(ff), dropout)
|
||||
layer = EncoderLayer(d_model, c(attn), c(ff), dropout)
|
||||
self.layers = clones(layer, num_layer)
|
||||
# layerlist = [copy.deepcopy(layer0),]
|
||||
# for _ in range(num_layer-1):
|
||||
# layerlist.append(copy.deepcopy(layer))
|
||||
# self.layers = nn.ModuleList(layerlist)
|
||||
self.norm = LayerNorm(layer.size)
|
||||
self.posi = PositionalEncoding(d_model, dropout)
|
||||
self.input2model = nn.Linear(d_input, d_model)
|
||||
|
||||
def forward(self, x, mask):
|
||||
"Pass the input (and mask) through each layer in turn."
|
||||
# x: embedding (b,l,we)
|
||||
x = self.posi(self.input2model(x))
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask)
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
|
||||
|
||||
class NERmodel(nn.Module):
|
||||
|
||||
def __init__(self, model_type, input_dim, hidden_dim, num_layer, dropout=0.5, gpu=True, biflag=True):
|
||||
super(NERmodel, self).__init__()
|
||||
self.model_type = model_type
|
||||
|
||||
if self.model_type == 'lstm':
|
||||
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layer, batch_first=True, bidirectional=biflag)
|
||||
self.drop = nn.Dropout(dropout)
|
||||
|
||||
if self.model_type == 'cnn':
|
||||
self.cnn = CNNmodel(input_dim, hidden_dim, num_layer, dropout, gpu)
|
||||
|
||||
## attention model
|
||||
if self.model_type == 'transformer':
|
||||
self.attention_model = AttentionModel(d_input=input_dim, d_model=hidden_dim, d_ff=2*hidden_dim, head=4, num_layer=num_layer, dropout=dropout)
|
||||
for p in self.attention_model.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
|
||||
def forward(self, input, mask=None):
|
||||
|
||||
if self.model_type == 'lstm':
|
||||
hidden = None
|
||||
feature_out, hidden = self.lstm(input, hidden)
|
||||
|
||||
feature_out_d = self.drop(feature_out)
|
||||
|
||||
if self.model_type == 'cnn':
|
||||
feature_out_d = self.cnn(input)
|
||||
|
||||
if self.model_type == 'transformer':
|
||||
feature_out_d = self.attention_model(input, mask)
|
||||
|
||||
return feature_out_d
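# Illustrative usage sketch (editorial, not part of the committed file). Variable names
# and sizes are made up; gpu=False keeps the example CPU-only.
if __name__ == '__main__':
    import torch
    batch, seq_len, emb_dim, hidden = 2, 10, 100, 128
    embeddings = torch.randn(batch, seq_len, emb_dim)
    lstm_encoder = NERmodel('lstm', emb_dim, hidden // 2, num_layer=1, dropout=0.5,
                            gpu=False, biflag=True)
    out = lstm_encoder(embeddings)                               # (batch, seq_len, hidden)
    pad_mask = torch.zeros(batch, 1, seq_len, dtype=torch.bool)  # True = padding
    attn_encoder = NERmodel('transformer', emb_dim, hidden, num_layer=2,
                            dropout=0.5, gpu=False)
    out2 = attn_encoder(embeddings, pad_mask)                    # (batch, seq_len, hidden)
    print(out.shape, out2.shape)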
|
||||
|
@@ -1,290 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Author: Jie Yang
|
||||
# @Date: 2017-12-04 23:19:38
|
||||
# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com
|
||||
# @Last Modified time: 2018-05-27 22:48:17
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
START_TAG = -2
|
||||
STOP_TAG = -1
|
||||
|
||||
|
||||
# Compute log sum exp in a numerically stable way for the forward algorithm
|
||||
def log_sum_exp(vec, m_size):
|
||||
"""
|
||||
compute the log of the summed exponentials
|
||||
args:
|
||||
vec (batch_size, vanishing_dim, hidden_dim) : input tensor (b,t,t)
|
||||
m_size : hidden_dim
|
||||
return:
|
||||
batch_size, hidden_dim
|
||||
"""
|
||||
_, idx = torch.max(vec, 1) # B * 1 * M
|
||||
max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M
|
||||
|
||||
return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)
|
||||
|
||||
class CRF(nn.Module):
|
||||
|
||||
def __init__(self, tagset_size, gpu):
|
||||
super(CRF, self).__init__()
|
||||
print ("build batched crf...")
|
||||
self.gpu = gpu
|
||||
# Matrix of transition parameters. Entry i,j is the score of transitioning *to* i *from* j.
|
||||
self.average_batch = False
|
||||
self.tagset_size = tagset_size
|
||||
# # We add 2 here, because of START_TAG and STOP_TAG
|
||||
# # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag
|
||||
init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
|
||||
# init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2)
|
||||
# init_transitions[:,START_TAG] = -1000.0
|
||||
# init_transitions[STOP_TAG,:] = -1000.0
|
||||
# init_transitions[:,0] = -1000.0
|
||||
# init_transitions[0,:] = -1000.0
|
||||
if self.gpu:
|
||||
init_transitions = init_transitions.cuda()
|
||||
self.transitions = nn.Parameter(init_transitions) #(t+2,t+2)
|
||||
|
||||
# self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2))
|
||||
# self.transitions.data.zero_()
|
||||
|
||||
def _calculate_PZ(self, feats, mask):
|
||||
"""
|
||||
input:
|
||||
feats: (batch, seq_len, self.tag_size+2) (b,m,t+2)
|
||||
masks: (batch, seq_len) (b,m)
|
||||
"""
|
||||
batch_size = feats.size(0)
|
||||
seq_len = feats.size(1)
|
||||
tag_size = feats.size(2)
|
||||
# print feats.view(seq_len, tag_size)
|
||||
assert(tag_size == self.tagset_size+2)
|
||||
mask = mask.transpose(1,0).contiguous() #(m,b)
|
||||
ins_num = seq_len * batch_size
|
||||
## be careful with the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
|
||||
feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size)
|
||||
## need to consider start
|
||||
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
|
||||
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
|
||||
# build iter
|
||||
seq_iter = enumerate(scores) # (index,matrix) index is among the first dim: seqlen
|
||||
_, inivalues = seq_iter.__next__()
|
||||
# only need start from start_tag
|
||||
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)
|
||||
|
||||
## add start score (from start to all tag, duplicate to batch_size)
|
||||
# partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
|
||||
# iter over last scores
|
||||
for idx, cur_values in seq_iter:
|
||||
# previous to_target is current from_target
|
||||
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
|
||||
# cur_values: bat_size * from_target * to_target
|
||||
|
||||
cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
|
||||
cur_partition = log_sum_exp(cur_values, tag_size) #(b,t)
|
||||
# print cur_partition.data
|
||||
|
||||
# (bat_size * from_target * to_target) -> (bat_size * to_target)
|
||||
# partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
|
||||
mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
|
||||
|
||||
## effectively update the partition: only keep the partition values where mask == 1
|
||||
masked_cur_partition = cur_partition.masked_select(mask_idx)
|
||||
## make mask_idx broadcastable to avoid a warning
|
||||
mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
|
||||
|
||||
## replace the partition where mask == 1; the other partition values stay the same
|
||||
partition.masked_scatter_(mask_idx, masked_cur_partition)
|
||||
# until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG
|
||||
cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
|
||||
cur_partition = log_sum_exp(cur_values, tag_size) #(batch_size,hidden_dim)
|
||||
final_partition = cur_partition[:, STOP_TAG] #(batch_size)
|
||||
return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size)
|
||||
|
||||
|
||||
def _viterbi_decode(self, feats, mask):
|
||||
"""
|
||||
input:
|
||||
feats: (batch, seq_len, self.tag_size+2)
|
||||
mask: (batch, seq_len)
|
||||
output:
|
||||
decode_idx: (batch, seq_len) decoded sequence
|
||||
path_score: (batch, 1) corresponding score for each sequence (to be implemented)
|
||||
"""
|
||||
batch_size = feats.size(0)
|
||||
seq_len = feats.size(1)
|
||||
tag_size = feats.size(2)
|
||||
assert(tag_size == self.tagset_size+2)
|
||||
## calculate sentence length for each sentence
|
||||
length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
|
||||
## mask to (seq_len, batch_size)
|
||||
mask = mask.transpose(1,0).contiguous() #(seq_len,b)
|
||||
ins_num = seq_len * batch_size
|
||||
## be careful with the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
|
||||
feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) #(ins_num, tag_size, tag_size)
|
||||
## need to consider start
|
||||
scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
|
||||
scores = scores.view(seq_len, batch_size, tag_size, tag_size)
|
||||
|
||||
# build iter
|
||||
seq_iter = enumerate(scores)
|
||||
## record the position of best score
|
||||
back_points = list()
|
||||
partition_history = list()
|
||||
|
||||
|
||||
## reverse the mask (mask = 1 - mask is buggy here; use this as an alternative)
|
||||
# mask = 1 + (-1)*mask
|
||||
mask = (1 - mask.long()).byte()
|
||||
_, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size
|
||||
# only need start from start_tag
|
||||
partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size
|
||||
#print(partition.size())
|
||||
partition_history.append(partition) #(seqlen,batch_size,tag_size,1)
|
||||
# iter over last scores
|
||||
for idx, cur_values in seq_iter:
|
||||
# previous to_target is current from_target
|
||||
# partition: previous results log(exp(from_target)), #(batch_size * from_target)
|
||||
# cur_values: batch_size * from_target * to_target
|
||||
cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
|
||||
## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG
|
||||
partition, cur_bp = torch.max(cur_values,dim=1)
|
||||
#print(partition.size())
|
||||
partition_history.append(partition.unsqueeze(2))
|
||||
## cur_bp: (batch_size, tag_size) max source score position in current tag
|
||||
## set padded label as 0, which will be filtered in post processing
|
||||
cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
|
||||
back_points.append(cur_bp)
|
||||
### add score to final STOP_TAG
|
||||
partition_history = torch.cat(partition_history,dim=0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size)
|
||||
### get the last position for each setences, and select the last partitions using gather()
|
||||
last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
|
||||
last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
|
||||
### calculate the score from last partition to end state (and then select the STOP_TAG from it)
|
||||
last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
|
||||
_, last_bp = torch.max(last_values, 1) #(batch_size,tag_size)
|
||||
pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
|
||||
if self.gpu:
|
||||
pad_zero = pad_zero.cuda()
|
||||
back_points.append(pad_zero)
|
||||
back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
|
||||
|
||||
## select end ids in STOP_TAG
|
||||
pointer = last_bp[:, STOP_TAG] #(batch_size)
|
||||
insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
|
||||
back_points = back_points.transpose(1,0).contiguous() #(batch_size,sq_len,tag_size)
|
||||
## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values
|
||||
# print "lp:",last_position
|
||||
# print "il:",insert_last
|
||||
back_points.scatter_(1, last_position, insert_last) ##(batch_size,sq_len,tag_size)
|
||||
# print "bp:",back_points
|
||||
# exit(0)
|
||||
back_points = back_points.transpose(1,0).contiguous() #(seq_len, batch_size, tag_size)
|
||||
## decode from the end, padded position ids are 0, which will be filtered if following evaluation
|
||||
decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
|
||||
if self.gpu:
|
||||
decode_idx = decode_idx.cuda()
|
||||
decode_idx[-1] = pointer.data
|
||||
for idx in range(len(back_points)-2, -1, -1):
|
||||
pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) #pointer's size:(batch_size,1)
|
||||
decode_idx[idx] = pointer.squeeze(1).data
|
||||
path_score = None
|
||||
decode_idx = decode_idx.transpose(1,0) #(batch_size, sent_len)
|
||||
return path_score, decode_idx #
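# Hedged usage sketch (illustrative only, not part of this diff): decoding a padded batch
# with the method above. The sizes and CPU setting are made up; mask is a byte tensor, as
# this autograd.Variable-era module expects (older-style PyTorch masking).
crf = CRF(tagset_size=5, gpu=False)
feats = torch.randn(2, 6, 5 + 2)                      # (batch, seq_len, tagset_size+2)
mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1]]).byte()      # (batch, seq_len), 0 = padding
_, best_paths = crf._viterbi_decode(feats, mask)      # best_paths: (batch, seq_len)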
def forward(self, feats, mask):
# mask is required by _viterbi_decode; the previous signature omitted it and would fail at call time
path_score, best_path = self._viterbi_decode(feats, mask)
return path_score, best_path
def _score_sentence(self, scores, mask, tags):
"""
input:
scores: variable (seq_len, batch, tag_size, tag_size)
mask: (batch, seq_len)
tags: tensor (batch, seq_len)
output:
score: sum of the scores of the gold sequences over the whole batch
"""
# Gives the score of a provided tag sequence
batch_size = scores.size(1)
seq_len = scores.size(0)
tag_size = scores.size(2)
## convert tag values into a new format that records label bigram information as a flat index
new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
if self.gpu:
new_tags = new_tags.cuda()
for idx in range(seq_len):
if idx == 0:
## START_TAG (index tag_size - 2) -> first label
new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0]
else:
new_tags[:, idx] = tags[:, idx-1] * tag_size + tags[:, idx]

## transition scores from each label to STOP_TAG
end_transition = self.transitions[:, STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
## lengths for the batch; last word position = length - 1
length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
## index the label id of the last word
end_ids = torch.gather(tags, 1, length_mask - 1)

## index the transition score from end_id to STOP_TAG
end_energy = torch.gather(end_transition, 1, end_ids)

## convert tags to (seq_len, batch_size, 1)
new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
### the bigram ids index into the flattened tag_size * tag_size positions of scores
tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # (seq_len, batch_size)
## mask transposed to (seq_len, batch_size)
tg_energy = tg_energy.masked_select(mask.transpose(1, 0))

# ## the score from START_TAG to the first label is already folded into new_tags[:, 0] above:
# start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
# start_energy = torch.gather(start_transition, 1, tags[0,:])

## add all scores together
# gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
gold_score = tg_energy.sum() + end_energy.sum()
return gold_score
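# Worked illustration (not part of this diff) of the bigram indexing used above: with
# tagset_size = 5 the padded tag_size is 7, and a gold transition from tag 2 to tag 3
# selects flat position 2 * 7 + 3 = 17 inside the flattened 7 * 7 = 49 (from_tag, to_tag) grid.
tag_size_demo = 7
grid = torch.arange(tag_size_demo * tag_size_demo).view(tag_size_demo, tag_size_demo)
assert grid.view(-1)[2 * tag_size_demo + 3].item() == grid[2, 3].item()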
def neg_log_likelihood_loss(self, feats, mask, tags):
# negative log likelihood: log partition (forward score) minus gold path score
batch_size = feats.size(0)
forward_score, scores = self._calculate_PZ(feats, mask)  # forward_score: scalar; scores: (seq_len, batch, tag_size, tag_size)
gold_score = self._score_sentence(scores, mask, tags)
if self.average_batch:
return (forward_score - gold_score) / batch_size
else:
return forward_score - gold_score
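# Hedged end-to-end sketch (illustrative only; sizes are made up): one gradient step with
# the loss above. Assumes a PyTorch version that still accepts byte masks, matching the
# autograd.Variable-era style of this module.
crf = CRF(tagset_size=5, gpu=False)
feats = torch.randn(2, 4, 5 + 2, requires_grad=True)    # (batch, seq_len, tagset_size+2)
mask = torch.ones(2, 4).byte()                          # no padding in this toy batch
tags = torch.randint(0, 5, (2, 4))                      # gold labels in [0, tagset_size)
loss = crf.neg_log_likelihood_loss(feats, mask, tags)
loss.backward()                                         # gradients flow into feats and crf.transitions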
@@ -1,261 +0,0 @@
# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math, copy, time

# class CNNmodel(nn.Module):
# def __init__(self, input_dim, hidden_dim, num_layer, dropout, gpu=True):
# super(CNNmodel, self).__init__()
# self.input_dim = input_dim
# self.hidden_dim = hidden_dim
# self.num_layer = num_layer
# self.gpu = gpu
#
# self.cnn_layer0 = nn.Conv1d(self.input_dim, self.hidden_dim, kernel_size=1, padding=0)
# self.cnn_layers = [nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) for i in range(self.num_layer-1)]
# self.drop = nn.Dropout(dropout)
#
# if self.gpu:
# self.cnn_layer0 = self.cnn_layer0.cuda()
# for i in range(self.num_layer-1):
# self.cnn_layers[i] = self.cnn_layers[i].cuda()
#
# def forward(self, input_feature):
#
# batch_size = input_feature.size(0)
# seq_len = input_feature.size(1)
#
# input_feature = input_feature.transpose(2,1).contiguous()
# cnn_output = self.cnn_layer0(input_feature) #(b,h,l)
# cnn_output = self.drop(cnn_output)
# cnn_output = torch.tanh(cnn_output)
#
# for layer in range(self.num_layer-1):
# cnn_output = self.cnn_layers[layer](cnn_output)
# cnn_output = self.drop(cnn_output)
# cnn_output = torch.tanh(cnn_output)
#
# cnn_output = cnn_output.transpose(2,1).contiguous()
# return cnn_output
#
#
#
# def clones(module, N):
# "Produce N identical layers."
# return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
#
# class LayerNorm(nn.Module):
# "Construct a layernorm module (See citation for details)."
# def __init__(self, features, eps=1e-6):
# super(LayerNorm, self).__init__()
# self.a_2 = nn.Parameter(torch.ones(features))
# self.b_2 = nn.Parameter(torch.zeros(features))
# self.eps = eps
#
# def forward(self, x):
# mean = x.mean(-1, keepdim=True)
# std = x.std(-1, keepdim=True)
# return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
#
# class SublayerConnection(nn.Module):
# """
# A residual connection followed by a layer norm.
# Note for code simplicity the norm is first as opposed to last.
# """
# def __init__(self, size, dropout):
# super(SublayerConnection, self).__init__()
# self.norm = LayerNorm(size)
# self.dropout = nn.Dropout(dropout)
#
# def forward(self, x, sublayer):
# "Apply residual connection to any sublayer with the same size."
# return x + self.dropout(sublayer(self.norm(x)))
#
# class EncoderLayer(nn.Module):
# "Encoder is made up of self-attn and feed forward (defined below)"
# def __init__(self, size, self_attn, feed_forward, dropout):
# super(EncoderLayer, self).__init__()
# self.self_attn = self_attn
# self.feed_forward = feed_forward
# self.sublayer = clones(SublayerConnection(size, dropout), 2)
# self.size = size
#
# def forward(self, x, mask):
# "Follow Figure 1 (left) for connections."
# x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
# return self.sublayer[1](x, self.feed_forward)
#
#
# def attention(query, key, value, mask=None, dropout=None):
# "Compute 'Scaled Dot Product Attention'"
# d_k = query.size(-1)
# scores = torch.matmul(query, key.transpose(-2, -1)) \
# / math.sqrt(d_k) ## (b,h,l,d) * (b,h,d,l)
# if mask is not None:
# # scores = scores.masked_fill(mask == 0, -1e9)
# scores = scores.masked_fill(mask, -1e9)
# p_attn = F.softmax(scores, dim = -1)
# if dropout is not None:
# p_attn = dropout(p_attn)
# return torch.matmul(p_attn, value), p_attn ##(b,h,l,l) * (b,h,l,d) = (b,h,l,d)
#
#
# class MultiHeadedAttention(nn.Module):
# def __init__(self, h, d_model, dropout=0.1):
# "Take in model size and number of heads."
# super(MultiHeadedAttention, self).__init__()
# assert d_model % h == 0
# # We assume d_v always equals d_k
# self.d_k = d_model // h
# self.h = h
# self.linears = clones(nn.Linear(d_model, d_model), 4)
# self.attn = None
# self.dropout = nn.Dropout(p=dropout)
#
# def forward(self, query, key, value, mask=None):
# "Implements Figure 2"
# if mask is not None:
# # Same mask applied to all h heads.
# mask = mask.unsqueeze(1)
# nbatches = query.size(0)
#
# # 1) Do all the linear projections in batch from d_model => h x d_k
# query, key, value = \
# [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
# for l, x in zip(self.linears, (query, key, value))]
#
# # 2) Apply attention on all the projected vectors in batch.
# x, self.attn = attention(query, key, value, mask=mask,
# dropout=self.dropout)
#
# # 3) "Concat" using a view and apply a final linear.
# x = x.transpose(1, 2).contiguous() \
# .view(nbatches, -1, self.h * self.d_k)
# return self.linears[-1](x)
#
#
# class PositionwiseFeedForward(nn.Module):
# "Implements FFN equation."
# def __init__(self, d_model, d_ff, dropout=0.1):
# super(PositionwiseFeedForward, self).__init__()
# self.w_1 = nn.Linear(d_model, d_ff)
# self.w_2 = nn.Linear(d_ff, d_model)
# self.dropout = nn.Dropout(dropout)
#
# def forward(self, x):
# return self.w_2(self.dropout(F.relu(self.w_1(x))))
#
#
# class PositionalEncoding(nn.Module):
# "Implement the PE function."
# def __init__(self, d_model, dropout, max_len=5000):
# super(PositionalEncoding, self).__init__()
# self.dropout = nn.Dropout(p=dropout)
#
# # Compute the positional encodings once in log space.
# pe = torch.zeros(max_len, d_model)
# position = torch.arange(0., max_len).unsqueeze(1)
# div_term = torch.exp(torch.arange(0., d_model, 2) *
# -(math.log(10000.0) / d_model))
# pe[:, 0::2] = torch.sin(position * div_term)
# pe[:, 1::2] = torch.cos(position * div_term)
# pe = pe.unsqueeze(0)
# self.register_buffer('pe', pe)
#
# def forward(self, x):
# x = x + autograd.Variable(self.pe[:, :x.size(1)],
# requires_grad=False)
# return self.dropout(x)
#
#
# class AttentionModel(nn.Module):
# "Core encoder is a stack of N layers"
# def __init__(self, d_input, d_model, d_ff, head, num_layer, dropout):
# super(AttentionModel, self).__init__()
# c = copy.deepcopy
# # attn0 = MultiHeadedAttention(head, d_input, d_model)
# attn = MultiHeadedAttention(head, d_model, dropout)
# ff = PositionwiseFeedForward(d_model, d_ff, dropout)
# # position = PositionalEncoding(d_model, dropout)
# # layer0 = EncoderLayer(d_model, c(attn0), c(ff), dropout)
# layer = EncoderLayer(d_model, c(attn), c(ff), dropout)
# self.layers = clones(layer, num_layer)
# # layerlist = [copy.deepcopy(layer0),]
# # for _ in range(num_layer-1):
# # layerlist.append(copy.deepcopy(layer))
# # self.layers = nn.ModuleList(layerlist)
# self.norm = LayerNorm(layer.size)
# self.posi = PositionalEncoding(d_model, dropout)
# self.input2model = nn.Linear(d_input, d_model)
#
# def forward(self, x, mask):
# "Pass the input (and mask) through each layer in turn."
# # x: embedding (b,l,we)
# x = self.posi(self.input2model(x))
# for layer in self.layers:
# x = layer(x, mask)
# return self.norm(x)
from layers.encoders.transformers.transformer import TransformerEncoder


class NERmodel(nn.Module):

def __init__(self, model_type, input_dim, hidden_dim, num_layer, dropout=0.5, gpu=True, biflag=True):
super(NERmodel, self).__init__()
self.model_type = model_type

if self.model_type == 'lstm':
self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layer, batch_first=True, bidirectional=biflag)
self.drop = nn.Dropout(dropout)

# if self.model_type == 'cnn':
# self.cnn = CNNmodel(input_dim, hidden_dim, num_layer, dropout, gpu)
#
# ## attention model
if self.model_type == 'transformer':
n_head = 6
head_dims = 80
num_layers = 2
d_model = n_head * head_dims  # 6 * 80 = 480
feedforward_dim = int(2 * d_model)
dropout = 0.15  # overrides the constructor argument for the transformer branch
fc_dropout = 0.4
after_norm = 1
attn_type = 'adatrans'
scale = attn_type == 'transformer'

self.in_fc = nn.Linear(input_dim, d_model)

self.transformer = TransformerEncoder(num_layers, d_model, n_head, feedforward_dim, dropout,
after_norm=after_norm, attn_type=attn_type,
scale=False, dropout_attn=scale,
pos_embed=None)
self.fc_dropout = nn.Dropout(fc_dropout)

# self.attention_model = AttentionModel(d_input=input_dim, d_model=hidden_dim, d_ff=2*hidden_dim, head=4, num_layer=num_layer, dropout=dropout)
# for p in self.attention_model.parameters():
# if p.dim() > 1:
# nn.init.xavier_uniform_(p)


def forward(self, input, mask=None):

if self.model_type == 'lstm':
hidden = None
feature_out, hidden = self.lstm(input, hidden)

feature_out_d = self.drop(feature_out)

# if self.model_type == 'cnn':
# feature_out_d = self.cnn(input)
#
if self.model_type == 'transformer':
chars = self.in_fc(input)
chars = self.transformer(chars, mask)
feature_out_d = self.fc_dropout(chars)

return feature_out_d
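# Hedged usage sketch (illustrative only; dimensions are made up): the 'lstm' branch of
# NERmodel on a toy batch. The 'transformer' branch additionally needs the external
# TransformerEncoder imported above and builds a fixed d_model of 6 * 80 = 480.
encoder = NERmodel(model_type='lstm', input_dim=100, hidden_dim=128, num_layer=1,
                   dropout=0.5, gpu=False, biflag=True)
embeddings = torch.randn(2, 10, 100)    # (batch, seq_len, input_dim)
features = encoder(embeddings)          # (batch, seq_len, 2 * hidden_dim) since biflag=True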
@@ -20,7 +20,7 @@ import torch.autograd as autograd
import torch.optim as optim

from models.ner_net.augment_ner import GazLSTM as SeqModel
from models.ner_net.bert_ner import BertNER
from models.ner_net.transformers_ner import BertNER
from run.entity_extraction.augmentedNER.data import Data
from run.entity_extraction.augmentedNER.metric import get_ner_fmeasure

@@ -3,8 +3,8 @@ import sys

from tqdm import tqdm

from run.entity_extraction.augmentedNER.alphabet import Alphabet
from run.entity_extraction.augmentedNER.functions import *
from run.entity_extraction.generalNER.alphabet import Alphabet
from run.entity_extraction.generalNER.functions import *

START = "</s>"
UNKNOWN = "</unk>"
@@ -1,4 +1,4 @@
from run.entity_extraction.augmentedNER.trie import Trie
from run.entity_extraction.generalNER.trie import Trie


class Gazetteer: