From d14f5e2047fe22c7dbfba5b8c778c25f9034cafa Mon Sep 17 00:00:00 2001 From: loujie0822 Date: Tue, 12 May 2020 20:49:25 +0800 Subject: [PATCH] add some layers --- layers/decoders/crf.py | 497 +++++++++--------- layers/decoders/pytorch_crf.py | 309 +++++++++++ layers/encoders/ner_layers.py | 241 +++++++++ layers/ner_layers/crf.py | 290 ---------- layers/ner_layers/layers.py | 261 --------- .../{bert_ner.py => bert_finetune_ner.py} | 0 .../{augment_ner.py => general_ner.py} | 0 .../augmentedNER/__init__.py | 0 .../entity_extraction/generalNER}/__init__.py | 0 .../{augmentedNER => generalNER}/main.py | 2 +- .../utils}/alphabet.py | 0 .../utils}/data.py | 4 +- .../utils}/functions.py | 0 .../utils}/gazetteer.py | 2 +- .../utils}/metric.py | 0 .../utils}/trie.py | 0 16 files changed, 793 insertions(+), 813 deletions(-) create mode 100644 layers/decoders/pytorch_crf.py create mode 100644 layers/encoders/ner_layers.py delete mode 100644 layers/ner_layers/crf.py delete mode 100644 layers/ner_layers/layers.py rename models/ner_net/{bert_ner.py => bert_finetune_ner.py} (100%) rename models/ner_net/{augment_ner.py => general_ner.py} (100%) delete mode 100644 run/entity_extraction/augmentedNER/__init__.py rename {layers/ner_layers => run/entity_extraction/generalNER}/__init__.py (100%) rename run/entity_extraction/{augmentedNER => generalNER}/main.py (97%) rename run/entity_extraction/{augmentedNER => generalNER/utils}/alphabet.py (100%) rename run/entity_extraction/{augmentedNER => generalNER/utils}/data.py (98%) rename run/entity_extraction/{augmentedNER => generalNER/utils}/functions.py (100%) rename run/entity_extraction/{augmentedNER => generalNER/utils}/gazetteer.py (96%) rename run/entity_extraction/{augmentedNER => generalNER/utils}/metric.py (100%) rename run/entity_extraction/{augmentedNER => generalNER/utils}/trie.py (100%) diff --git a/layers/decoders/crf.py b/layers/decoders/crf.py index 902f2df..0a5b54b 100644 --- a/layers/decoders/crf.py +++ b/layers/decoders/crf.py @@ -1,309 +1,290 @@ -# _*_ coding:utf-8 _*_ - +# -*- coding: utf-8 -*- +# @Author: Jie Yang +# @Date: 2017-12-04 23:19:38 +# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com +# @Last Modified time: 2018-05-27 22:48:17 import torch +import torch.autograd as autograd import torch.nn as nn +import torch.nn.functional as F +import numpy as np +START_TAG = -2 +STOP_TAG = -1 +# Compute log sum exp in a numerically stable way for the forward algorithm +def log_sum_exp(vec, m_size): + """ + calculate log of exp sum + args: + vec (batch_size, vanishing_dim, hidden_dim) : input tensor (b,t,t) + m_size : hidden_dim + return: + batch_size, hidden_dim + """ + _, idx = torch.max(vec, 1) # B * 1 * M + max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M + + return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) + class CRF(nn.Module): - """Conditional random field. - This module implements a conditional random field [LMP01]_. The forward computation - of this class computes the log likelihood of the given sequence of tags and - emission score tensor. This class also has `~CRF.decode` method which finds - the best tag sequence given an emission score tensor using `Viterbi algorithm`_. - - Args: - num_tags: Number of tags. - batch_first: Whether the first dimension corresponds to the size of a minibatch. - - Attributes: - start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size - ``(num_tags,)``. 
- end_transitions (`~torch.nn.Parameter`): End transition score tensor of size - ``(num_tags,)``. - transitions (`~torch.nn.Parameter`): Transition score tensor of size - ``(num_tags, num_tags)``. - """ - - def __init__(self, num_tags, batch_first): - if num_tags <= 0: - raise ValueError('invalid number of tags: {}'.format(num_tags)) + def __init__(self, tagset_size, gpu): super(CRF, self).__init__() - self.num_tags = num_tags - self.batch_first = batch_first - self.start_transitions = nn.Parameter(torch.empty(num_tags)) - self.end_transitions = nn.Parameter(torch.empty(num_tags)) - self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) + print ("build batched crf...") + self.gpu = gpu + # Matrix of transition parameters. Entry i,j is the score of transitioning *to* i *from* j. + self.average_batch = False + self.tagset_size = tagset_size + # # We add 2 here, because of START_TAG and STOP_TAG + # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag + init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) + # init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) + # init_transitions[:,START_TAG] = -1000.0 + # init_transitions[STOP_TAG,:] = -1000.0 + # init_transitions[:,0] = -1000.0 + # init_transitions[0,:] = -1000.0 + if self.gpu: + init_transitions = init_transitions.cuda() + self.transitions = nn.Parameter(init_transitions) #(t+2,t+2) - self.reset_parameters() + # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2)) + # self.transitions.data.zero_() - def reset_parameters(self): - """Initialize the transition parameters. - - The parameters will be initialized randomly from a uniform distribution - between -0.1 and 0.1. + def _calculate_PZ(self, feats, mask): """ - nn.init.uniform_(self.start_transitions, -0.1, 0.1) - nn.init.uniform_(self.end_transitions, -0.1, 0.1) - nn.init.uniform_(self.transitions, -0.1, 0.1) - - def __repr__(self): - return 'num_tags={}'.format(self.num_tags) - - def forward(self, emissions, tags, mask, reduction): - """Compute the conditional log likelihood of a sequence of tags given emission scores. - - Args: - emissions (`~torch.Tensor`): Emission score tensor of size - ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, - ``(batch_size, seq_length, num_tags)`` otherwise. - tags (`~torch.LongTensor`): Sequence of tags tensor of size - ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, - ``(batch_size, seq_length)`` otherwise. - mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` - if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. - reduction: Specifies the reduction to apply to the output: - ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. - ``sum``: the output will be summed over batches. ``mean``: the output will be - averaged over batches. ``token_mean``: the output will be averaged over tokens. - - Returns: - `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if - reduction is ``none``, ``()`` otherwise. 
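+        Accumulates the forward-algorithm partition function log(Z) using the
+        numerically stable ``log_sum_exp`` helper defined at the top of this
+        file (equivalent to ``torch.logsumexp`` over dim 1), e.g.::
+
+            >>> vec = torch.randn(4, 6, 6)  # (batch, from_tag, to_tag)
+            >>> torch.allclose(log_sum_exp(vec, 6), torch.logsumexp(vec, dim=1))
+            True
+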
+ input: + feats: (batch, seq_len, self.tag_size+2) (b,m,t+2) + masks: (batch, seq_len) (b,m) """ - self._validate(emissions, tags=tags, mask=mask) - if reduction not in ('none', 'sum', 'mean', 'token_mean'): - raise ValueError('invalid reduction: {}'.format(reduction)) - if mask is None: - mask = torch.ones_like(tags, dtype=torch.uint8) + batch_size = feats.size(0) + seq_len = feats.size(1) + tag_size = feats.size(2) + # print feats.view(seq_len, tag_size) + assert(tag_size == self.tagset_size+2) + mask = mask.transpose(1,0).contiguous() #(m,b) + ins_num = seq_len * batch_size + ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) + feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size) + ## need to consider start + scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) + scores = scores.view(seq_len, batch_size, tag_size, tag_size) + # build iter + seq_iter = enumerate(scores) # (index,matrix) index is among the first dim: seqlen + _, inivalues = seq_iter.__next__() + # only need start from start_tag + partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) - if self.batch_first: - emissions = emissions.transpose(0, 1) - tags = tags.transpose(0, 1) - mask = mask.transpose(0, 1) + ## add start score (from start to all tag, duplicate to batch_size) + # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) + # iter over last scores + for idx, cur_values in seq_iter: + # previous to_target is current from_target + # partition: previous results log(exp(from_target)), #(batch_size * from_target) + # cur_values: bat_size * from_target * to_target + + cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) + cur_partition = log_sum_exp(cur_values, tag_size) #(b,t) + # print cur_partition.data + + # (bat_size * from_target * to_target) -> (bat_size * to_target) + # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1) + mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) + + ## effective updated partition part, only keep the partition value of mask value = 1 + masked_cur_partition = cur_partition.masked_select(mask_idx) + ## let mask_idx broadcastable, to disable warning + mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) - # shape: (batch_size,) - numerator = self._compute_score(emissions, tags, mask) - # shape: (batch_size,) - denominator = self._compute_normalizer(emissions, mask) - # shape: (batch_size,) - llh = numerator - denominator + ## replace the partition where the maskvalue=1, other partition value keeps the same + partition.masked_scatter_(mask_idx, masked_cur_partition) + # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG + cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) + cur_partition = log_sum_exp(cur_values, tag_size) #(batch_size,hidden_dim) + final_partition = cur_partition[:, STOP_TAG] #(batch_size) + return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size) - if reduction == 'none': - return llh - if reduction == 'sum': - return llh.sum() - if reduction == 'mean': - 
return llh.mean()
-        assert reduction == 'token_mean'
-        return llh.sum() / mask.float().sum()
-
-    def decode(self, emissions, mask=None):
-        """Find the most likely tag sequence using Viterbi algorithm.
-
-        Args:
-            emissions (`~torch.Tensor`): Emission score tensor of size
-                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
-                ``(batch_size, seq_length, num_tags)`` otherwise.
-            mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
-                if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
-
-        Returns:
-            List of list containing the best tag sequence for each batch.
+    def _viterbi_decode(self, feats, mask):
         """
-        self._validate(emissions, tags=None, mask=mask)
-        if mask is None:
-            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
+        input:
+            feats: (batch, seq_len, self.tag_size+2)
+            mask: (batch, seq_len)
+        output:
+            decode_idx: (batch, seq_len) decoded sequence
+            path_score: (batch, 1) corresponding score for each sequence (to be implemented)
+        """
+        batch_size = feats.size(0)
+        seq_len = feats.size(1)
+        tag_size = feats.size(2)
+        assert(tag_size == self.tagset_size+2)
+        ## calculate the sentence length for each sentence
+        length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
+        ## mask to (seq_len, batch_size)
+        mask = mask.transpose(1, 0).contiguous()  #(seq_len, b)
+        ins_num = seq_len * batch_size
+        ## note the view shape: it is .view(ins_num, 1, tag_size), not .view(ins_num, tag_size, 1)
+        feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)  #(ins_num, tag_size, tag_size)
+        ## need to consider start
+        scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
+        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
-        if self.batch_first:
-            emissions = emissions.transpose(0, 1)
-            mask = mask.transpose(0, 1)
+        # build iter
+        seq_iter = enumerate(scores)
+        ## record the position of the best score
+        back_points = list()
+        partition_history = list()
+
+
+        ## invert the mask (mask = 1 - mask is buggy for this tensor type; use this alternative)
+        # mask = 1 + (-1)*mask
+        mask = (1 - mask.long()).byte()
+        _, inivalues = seq_iter.__next__()  # bat_size * from_target_size * to_target_size
+        # only need to start from START_TAG
+        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)  # bat_size * to_target_size
+        #print(partition.size())
+        partition_history.append(partition)  #(seq_len, batch_size, tag_size, 1)
+        # iterate over the remaining scores
+        for idx, cur_values in seq_iter:
+            # previous to_target is current from_target
+            # partition: previous results log(exp(from_target))  #(batch_size * from_target)
+            # cur_values: batch_size * from_target * to_target
+            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
+            ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1)  # do not consider START_TAG/STOP_TAG
+            partition, cur_bp = torch.max(cur_values, dim=1)
+            #print(partition.size())
+            partition_history.append(partition.unsqueeze(2))
+            ## cur_bp: (batch_size, tag_size), best source tag index for each current tag
+            ## set padded labels to 0; they are filtered out in post-processing
+            cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
+            back_points.append(cur_bp)
+        ### add score to final STOP_TAG
+        partition_history = torch.cat(partition_history, dim=0).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()  ## (batch_size, seq_len, tag_size)
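+        ## back_points[t][b][j] now holds the best previous tag for reaching tag j
+        ## at step t+1; the decoding below starts from each sentence's true last
+        ## position (via length_mask) and follows these pointers backwards.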
+        ### get the last position for each sentence, and select the last partitions using gather()
+        last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
+        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)
+        ### calculate the score from the last partition to the end state (then select STOP_TAG from it)
+        last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
+        _, last_bp = torch.max(last_values, 1)  #(batch_size, tag_size)
+        pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
+        if self.gpu:
+            pad_zero = pad_zero.cuda()
+        back_points.append(pad_zero)
+        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
+
+        ## select end ids in STOP_TAG
+        pointer = last_bp[:, STOP_TAG]  #(batch_size)
+        insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
+        back_points = back_points.transpose(1, 0).contiguous()  #(batch_size, seq_len, tag_size)
+        ## write the end ids (expanded to tag_size) into back_points at each sentence's last position, replacing the 0 values
+        # print "lp:", last_position
+        # print "il:", insert_last
+        back_points.scatter_(1, last_position, insert_last)  ##(batch_size, seq_len, tag_size)
+        # print "bp:", back_points
+        # exit(0)
+        back_points = back_points.transpose(1, 0).contiguous()  #(seq_len, batch_size, tag_size)
+        ## decode from the end; padded positions decode to 0 and are filtered out in later evaluation
+        decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
+        if self.gpu:
+            decode_idx = decode_idx.cuda()
+        decode_idx[-1] = pointer.data
+        for idx in range(len(back_points)-2, -1, -1):
+            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))  #pointer: (batch_size, 1)
+            decode_idx[idx] = pointer.squeeze(1).data
+        path_score = None
+        decode_idx = decode_idx.transpose(1, 0)  #(batch_size, sent_len)
+        return path_score, decode_idx
-        return self._viterbi_decode(emissions, mask)
-
-    def _validate(self, emissions, tags, mask):
-        if emissions.dim() != 3:
-            raise ValueError('emissions must have dimension of 3, got {}'.format(emissions.dim()))
-        if emissions.size(2) != self.num_tags:
-            raise ValueError(
-                'expected last dimension of emissions is {}, got {}'.format(self.num_tags, emissions.size(2)))
-        if tags is not None:
-            if emissions.shape[:2] != tags.shape:
-                raise ValueError(
-                    'the first two dimensions of emissions and tags must match,got {} and {}'.format(
-                        tuple(emissions.shape[:2]), tuple(tags.shape)))
+    def forward(self, feats, mask):
+        path_score, best_path = self._viterbi_decode(feats, mask)
+        return path_score, best_path
+
-        if mask is not None:
-            if emissions.shape[:2] != mask.shape:
-                raise ValueError(
-                    'the first two dimensions of emissions and mask must match, got {} and {}'.format(
-                        tuple(emissions.shape[:2]), tuple(mask.shape)))
-            no_empty_seq = not self.batch_first and mask[0].all()
-            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
-            if not no_empty_seq and not no_empty_seq_bf:
-                raise ValueError('mask of the first timestep must all be on')
+    def _score_sentence(self, scores, mask, tags):
+        """
+        input:
+            scores: variable (seq_len, batch, tag_size, tag_size)
+            mask: (batch, seq_len)
+            tags: tensor (batch, seq_len)
+        output:
+            score: sum of scores of the gold sequences over the whole batch
+        """
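+        ## scores is (seq_len, batch, from_tag, to_tag); the loop below folds each
+        ## gold tag pair into a flat bigram index prev_tag * tag_size + cur_tag, so
+        ## a single gather per timestep picks out the gold transition+emission score.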
+        # Gives the score of a provided tag sequence
+        batch_size = scores.size(1)
+        seq_len = scores.size(0)
+        tag_size = scores.size(2)
+        ## convert tags into bigram indices: index = previous_tag * tag_size + current_tag
+        new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
+        if self.gpu:
+            new_tags = new_tags.cuda()
+        for idx in range(seq_len):
+            if idx == 0:
+                ## start -> first score
+                new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0]
-    def _compute_score(
-            self, emissions, tags, mask):
-        # emissions: (seq_length, batch_size, num_tags)
-        # tags: (seq_length, batch_size)
-        # mask: (seq_length, batch_size)
-        assert emissions.dim() == 3 and tags.dim() == 2
-        assert emissions.shape[:2] == tags.shape
-        assert emissions.size(2) == self.num_tags
-        assert mask.shape == tags.shape
-        assert mask[0].all()
+            else:
+                new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx]
-        seq_length, batch_size = tags.shape
-        mask = mask.float()
+        ## transition from each label to STOP_TAG
+        end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
+        ## length for each batch entry; last word position = length - 1
+        length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
+        ## index the label id of the last word
+        end_ids = torch.gather(tags, 1, length_mask - 1)
-        # Start transition score and first emission
-        # shape: (batch_size,)
-        score = self.start_transitions[tags[0]]
-        score += emissions[0, torch.arange(batch_size), tags[0]]
+        ## index the transition score for end_id to STOP_TAG
+        end_energy = torch.gather(end_transition, 1, end_ids)
-        for i in range(1, seq_length):
-            # Transition score to next tag, only added if next timestep is valid (mask == 1)
-            # shape: (batch_size,)
-            score += self.transitions[tags[i - 1], tags[i]] * mask[i]
+        ## reshape tags to (seq_len, batch_size, 1)
+        new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
+        ### the bigram indices select from the tag_size * tag_size score entries per timestep
+        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size
+        ## mask transposed to (seq_len, batch_size)
+        tg_energy = tg_energy.masked_select(mask.transpose(1, 0))
+
+        # ## calculate the score from START_TAG to the first label
+        # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
+        # start_energy = torch.gather(start_transition, 1, tags[0,:])
-            # Emission score for next tag, only added if next timestep is valid (mask == 1)
-            # shape: (batch_size,)
-            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
+        ## add all scores together
+        # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
+        gold_score = tg_energy.sum() + end_energy.sum()
+        return gold_score
-        # End transition score
-        # shape: (batch_size,)
-        seq_ends = mask.long().sum(dim=0) - 1
-        # shape: (batch_size,)
-        last_tags = tags[seq_ends, torch.arange(batch_size)]
-        # shape: (batch_size,)
-        score += self.end_transitions[last_tags]
+    def neg_log_likelihood_loss(self, feats, mask, tags):
+        # negative log likelihood
+        batch_size = feats.size(0)
+        forward_score, scores = self._calculate_PZ(feats, mask)  #forward_score: scalar, scores: (seq_len, batch, tag_size, tag_size)
+        gold_score = self._score_sentence(scores, mask, tags)
+        #print ("batch, f:", forward_score.data, " g:", gold_score.data, " dis:", forward_score.data - gold_score.data)
+        # exit(0)
+        if self.average_batch:
+            return (forward_score - gold_score)/batch_size
+        else:
+            return forward_score - gold_score
-        return score
-    def _compute_normalizer(self, emissions, mask):
-        # emissions: (seq_length, 
batch_size, num_tags) - # mask: (seq_length, batch_size) - assert emissions.dim() == 3 and mask.dim() == 2 - assert emissions.shape[:2] == mask.shape - assert emissions.size(2) == self.num_tags - assert mask[0].all() - seq_length = emissions.size(0) - # Start transition score and first emission; score has size of - # (batch_size, num_tags) where for each batch, the j-th column stores - # the score that the first timestep has tag j - # shape: (batch_size, num_tags) - score = self.start_transitions + emissions[0] - for i in range(1, seq_length): - # Broadcast score for every possible next tag - # shape: (batch_size, num_tags, 1) - broadcast_score = score.unsqueeze(2) - # Broadcast emission score for every possible current tag - # shape: (batch_size, 1, num_tags) - broadcast_emissions = emissions[i].unsqueeze(1) - # Compute the score tensor of size (batch_size, num_tags, num_tags) where - # for each sample, entry at row i and column j stores the sum of scores of all - # possible tag sequences so far that end with transitioning from tag i to tag j - # and emitting - # shape: (batch_size, num_tags, num_tags) - next_score = broadcast_score + self.transitions + broadcast_emissions - # Sum over all possible current tags, but we're in score space, so a sum - # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of - # all possible tag sequences so far, that end in tag i - # shape: (batch_size, num_tags) - next_score = torch.logsumexp(next_score, dim=1) - # Set score to the next score if this timestep is valid (mask == 1) - # shape: (batch_size, num_tags) - score = torch.where(mask[i].unsqueeze(1), next_score, score) - # End transition score - # shape: (batch_size, num_tags) - score += self.end_transitions - # Sum (log-sum-exp) over all possible tags - # shape: (batch_size,) - return torch.logsumexp(score, dim=1) - def _viterbi_decode(self, emissions, mask): - # emissions: (seq_length, batch_size, num_tags) - # mask: (seq_length, batch_size) - assert emissions.dim() == 3 and mask.dim() == 2 - assert emissions.shape[:2] == mask.shape - assert emissions.size(2) == self.num_tags - assert mask[0].all() - seq_length, batch_size = mask.shape - # Start transition and first emission - # shape: (batch_size, num_tags) - score = self.start_transitions + emissions[0] - history = [] - # score is a tensor of size (batch_size, num_tags) where for every batch, - # value at column j stores the score of the best tag sequence so far that ends - # with tag j - # history saves where the best tags candidate transitioned from; this is used - # when we trace back the best tag sequence - # Viterbi algorithm recursive case: we compute the score of the best tag sequence - # for every possible next tag - for i in range(1, seq_length): - # Broadcast viterbi score for every possible next tag - # shape: (batch_size, num_tags, 1) - broadcast_score = score.unsqueeze(2) - # Broadcast emission score for every possible current tag - # shape: (batch_size, 1, num_tags) - broadcast_emission = emissions[i].unsqueeze(1) - # Compute the score tensor of size (batch_size, num_tags, num_tags) where - # for each sample, entry at row i and column j stores the score of the best - # tag sequence so far that ends with transitioning from tag i to tag j and emitting - # shape: (batch_size, num_tags, num_tags) - next_score = broadcast_score + self.transitions + broadcast_emission - # Find the maximum score over all possible current tag - # shape: (batch_size, num_tags) - next_score, indices = next_score.max(dim=1) - # Set score to 
the next score if this timestep is valid (mask == 1) - # and save the index that produces the next score - # shape: (batch_size, num_tags) - score = torch.where(mask[i].unsqueeze(1), next_score, score) - history.append(indices) - # End transition score - # shape: (batch_size, num_tags) - score += self.end_transitions - # Now, compute the best path for each sample - # shape: (batch_size,) - seq_ends = mask.long().sum(dim=0) - 1 - best_tags_list = [] - for idx in range(batch_size): - # Find the tag which maximizes the score at the last timestep; this is our best tag - # for the last timestep - _, best_last_tag = score[idx].max(dim=0) - best_tags = [best_last_tag.item()] - - # We trace back where the best last tag comes from, append that to our best tag - # sequence, and trace it back again, and so on - for hist in reversed(history[:seq_ends[idx]]): - best_last_tag = hist[idx][best_tags[-1]] - best_tags.append(best_last_tag.item()) - - # Reverse the order because we start from the last timestep - best_tags.reverse() - best_tags_list.append(best_tags) - - return best_tags_list diff --git a/layers/decoders/pytorch_crf.py b/layers/decoders/pytorch_crf.py new file mode 100644 index 0000000..902f2df --- /dev/null +++ b/layers/decoders/pytorch_crf.py @@ -0,0 +1,309 @@ +# _*_ coding:utf-8 _*_ + +import torch +import torch.nn as nn + + +class CRF(nn.Module): + """Conditional random field. + + This module implements a conditional random field [LMP01]_. The forward computation + of this class computes the log likelihood of the given sequence of tags and + emission score tensor. This class also has `~CRF.decode` method which finds + the best tag sequence given an emission score tensor using `Viterbi algorithm`_. + + Args: + num_tags: Number of tags. + batch_first: Whether the first dimension corresponds to the size of a minibatch. + + Attributes: + start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size + ``(num_tags,)``. + end_transitions (`~torch.nn.Parameter`): End transition score tensor of size + ``(num_tags,)``. + transitions (`~torch.nn.Parameter`): Transition score tensor of size + ``(num_tags, num_tags)``. + """ + + def __init__(self, num_tags, batch_first): + if num_tags <= 0: + raise ValueError('invalid number of tags: {}'.format(num_tags)) + super(CRF, self).__init__() + self.num_tags = num_tags + self.batch_first = batch_first + self.start_transitions = nn.Parameter(torch.empty(num_tags)) + self.end_transitions = nn.Parameter(torch.empty(num_tags)) + self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) + + self.reset_parameters() + + def reset_parameters(self): + """Initialize the transition parameters. + + The parameters will be initialized randomly from a uniform distribution + between -0.1 and 0.1. + """ + nn.init.uniform_(self.start_transitions, -0.1, 0.1) + nn.init.uniform_(self.end_transitions, -0.1, 0.1) + nn.init.uniform_(self.transitions, -0.1, 0.1) + + def __repr__(self): + return 'num_tags={}'.format(self.num_tags) + + def forward(self, emissions, tags, mask, reduction): + """Compute the conditional log likelihood of a sequence of tags given emission scores. + + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + tags (`~torch.LongTensor`): Sequence of tags tensor of size + ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length)`` otherwise. 
+ mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + reduction: Specifies the reduction to apply to the output: + ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. + ``sum``: the output will be summed over batches. ``mean``: the output will be + averaged over batches. ``token_mean``: the output will be averaged over tokens. + + Returns: + `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if + reduction is ``none``, ``()`` otherwise. + """ + self._validate(emissions, tags=tags, mask=mask) + if reduction not in ('none', 'sum', 'mean', 'token_mean'): + raise ValueError('invalid reduction: {}'.format(reduction)) + if mask is None: + mask = torch.ones_like(tags, dtype=torch.uint8) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + tags = tags.transpose(0, 1) + mask = mask.transpose(0, 1) + + # shape: (batch_size,) + numerator = self._compute_score(emissions, tags, mask) + # shape: (batch_size,) + denominator = self._compute_normalizer(emissions, mask) + # shape: (batch_size,) + llh = numerator - denominator + + if reduction == 'none': + return llh + if reduction == 'sum': + return llh.sum() + if reduction == 'mean': + return llh.mean() + assert reduction == 'token_mean' + return llh.sum() / mask.float().sum() + + def decode(self, emissions, mask=None): + """Find the most likely tag sequence using Viterbi algorithm. + + Args: + emissions (`~torch.Tensor`): Emission score tensor of size + ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, + ``(batch_size, seq_length, num_tags)`` otherwise. + mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` + if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. + + Returns: + List of list containing the best tag sequence for each batch. 
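+
+        Example (a minimal usage sketch; tensors here are illustrative)::
+
+            >>> crf = CRF(num_tags=5, batch_first=True)
+            >>> emissions = torch.randn(2, 7, 5)            # (batch, seq_len, num_tags)
+            >>> mask = torch.ones(2, 7, dtype=torch.uint8)  # all timesteps valid
+            >>> best = crf.decode(emissions, mask)          # two lists of 7 tag ids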
+ """ + self._validate(emissions, tags=None, mask=mask) + if mask is None: + mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8) + + if self.batch_first: + emissions = emissions.transpose(0, 1) + mask = mask.transpose(0, 1) + + return self._viterbi_decode(emissions, mask) + + def _validate(self, emissions, tags, mask): + if emissions.dim() != 3: + raise ValueError('emissions must have dimension of 3, got {}'.format(emissions.dim())) + if emissions.size(2) != self.num_tags: + raise ValueError( + 'expected last dimension of emissions is {}, got {}'.format(self.num_tags, emissions.size(2))) + + if tags is not None: + if emissions.shape[:2] != tags.shape: + raise ValueError( + 'the first two dimensions of emissions and tags must match,got {} and {}'.format( + tuple(emissions.shape[:2]), tuple(tags.shape))) + + if mask is not None: + if emissions.shape[:2] != mask.shape: + raise ValueError( + 'the first two dimensions of emissions and mask must match, got {} and {}'.format( + tuple(emissions.shape[:2]), tuple(mask.shape))) + no_empty_seq = not self.batch_first and mask[0].all() + no_empty_seq_bf = self.batch_first and mask[:, 0].all() + if not no_empty_seq and not no_empty_seq_bf: + raise ValueError('mask of the first timestep must all be on') + + def _compute_score( + self, emissions, tags, mask): + # emissions: (seq_length, batch_size, num_tags) + # tags: (seq_length, batch_size) + # mask: (seq_length, batch_size) + assert emissions.dim() == 3 and tags.dim() == 2 + assert emissions.shape[:2] == tags.shape + assert emissions.size(2) == self.num_tags + assert mask.shape == tags.shape + assert mask[0].all() + + seq_length, batch_size = tags.shape + mask = mask.float() + + # Start transition score and first emission + # shape: (batch_size,) + score = self.start_transitions[tags[0]] + score += emissions[0, torch.arange(batch_size), tags[0]] + + for i in range(1, seq_length): + # Transition score to next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += self.transitions[tags[i - 1], tags[i]] * mask[i] + + # Emission score for next tag, only added if next timestep is valid (mask == 1) + # shape: (batch_size,) + score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] + + # End transition score + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + # shape: (batch_size,) + last_tags = tags[seq_ends, torch.arange(batch_size)] + # shape: (batch_size,) + score += self.end_transitions[last_tags] + + return score + + def _compute_normalizer(self, emissions, mask): + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + assert emissions.dim() == 3 and mask.dim() == 2 + assert emissions.shape[:2] == mask.shape + assert emissions.size(2) == self.num_tags + assert mask[0].all() + + seq_length = emissions.size(0) + + # Start transition score and first emission; score has size of + # (batch_size, num_tags) where for each batch, the j-th column stores + # the score that the first timestep has tag j + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + + for i in range(1, seq_length): + # Broadcast score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emissions = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j 
stores the sum of scores of all + # possible tag sequences so far that end with transitioning from tag i to tag j + # and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emissions + + # Sum over all possible current tags, but we're in score space, so a sum + # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of + # all possible tag sequences so far, that end in tag i + # shape: (batch_size, num_tags) + next_score = torch.logsumexp(next_score, dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(1), next_score, score) + + # End transition score + # shape: (batch_size, num_tags) + score += self.end_transitions + + # Sum (log-sum-exp) over all possible tags + # shape: (batch_size,) + return torch.logsumexp(score, dim=1) + + def _viterbi_decode(self, emissions, mask): + # emissions: (seq_length, batch_size, num_tags) + # mask: (seq_length, batch_size) + assert emissions.dim() == 3 and mask.dim() == 2 + assert emissions.shape[:2] == mask.shape + assert emissions.size(2) == self.num_tags + assert mask[0].all() + + seq_length, batch_size = mask.shape + + # Start transition and first emission + # shape: (batch_size, num_tags) + score = self.start_transitions + emissions[0] + history = [] + + # score is a tensor of size (batch_size, num_tags) where for every batch, + # value at column j stores the score of the best tag sequence so far that ends + # with tag j + # history saves where the best tags candidate transitioned from; this is used + # when we trace back the best tag sequence + + # Viterbi algorithm recursive case: we compute the score of the best tag sequence + # for every possible next tag + for i in range(1, seq_length): + # Broadcast viterbi score for every possible next tag + # shape: (batch_size, num_tags, 1) + broadcast_score = score.unsqueeze(2) + + # Broadcast emission score for every possible current tag + # shape: (batch_size, 1, num_tags) + broadcast_emission = emissions[i].unsqueeze(1) + + # Compute the score tensor of size (batch_size, num_tags, num_tags) where + # for each sample, entry at row i and column j stores the score of the best + # tag sequence so far that ends with transitioning from tag i to tag j and emitting + # shape: (batch_size, num_tags, num_tags) + next_score = broadcast_score + self.transitions + broadcast_emission + + # Find the maximum score over all possible current tag + # shape: (batch_size, num_tags) + next_score, indices = next_score.max(dim=1) + + # Set score to the next score if this timestep is valid (mask == 1) + # and save the index that produces the next score + # shape: (batch_size, num_tags) + score = torch.where(mask[i].unsqueeze(1), next_score, score) + history.append(indices) + + # End transition score + # shape: (batch_size, num_tags) + score += self.end_transitions + + # Now, compute the best path for each sample + + # shape: (batch_size,) + seq_ends = mask.long().sum(dim=0) - 1 + best_tags_list = [] + + for idx in range(batch_size): + # Find the tag which maximizes the score at the last timestep; this is our best tag + # for the last timestep + _, best_last_tag = score[idx].max(dim=0) + best_tags = [best_last_tag.item()] + + # We trace back where the best last tag comes from, append that to our best tag + # sequence, and trace it back again, and so on + for hist in reversed(history[:seq_ends[idx]]): + best_last_tag = hist[idx][best_tags[-1]] + 
best_tags.append(best_last_tag.item()) + + # Reverse the order because we start from the last timestep + best_tags.reverse() + best_tags_list.append(best_tags) + + return best_tags_list diff --git a/layers/encoders/ner_layers.py b/layers/encoders/ner_layers.py new file mode 100644 index 0000000..987061a --- /dev/null +++ b/layers/encoders/ner_layers.py @@ -0,0 +1,241 @@ +# -*- coding: utf-8 -*- +import torch +import torch.autograd as autograd +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math, copy, time + +class CNNmodel(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layer, dropout, gpu=True): + super(CNNmodel, self).__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_layer = num_layer + self.gpu = gpu + + self.cnn_layer0 = nn.Conv1d(self.input_dim, self.hidden_dim, kernel_size=1, padding=0) + self.cnn_layers = [nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) for i in range(self.num_layer-1)] + self.drop = nn.Dropout(dropout) + + if self.gpu: + self.cnn_layer0 = self.cnn_layer0.cuda() + for i in range(self.num_layer-1): + self.cnn_layers[i] = self.cnn_layers[i].cuda() + + def forward(self, input_feature): + + batch_size = input_feature.size(0) + seq_len = input_feature.size(1) + + input_feature = input_feature.transpose(2,1).contiguous() + cnn_output = self.cnn_layer0(input_feature) #(b,h,l) + cnn_output = self.drop(cnn_output) + cnn_output = torch.tanh(cnn_output) + + for layer in range(self.num_layer-1): + cnn_output = self.cnn_layers[layer](cnn_output) + cnn_output = self.drop(cnn_output) + cnn_output = torch.tanh(cnn_output) + + cnn_output = cnn_output.transpose(2,1).contiguous() + return cnn_output + + + +def clones(module, N): + "Produce N identical layers." + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + +class LayerNorm(nn.Module): + "Construct a layernorm module (See citation for details)." + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward (defined below)" + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "Follow Figure 1 (left) for connections." 
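+        # Each SublayerConnection computes x + dropout(sublayer(norm(x))): a
+        # pre-norm residual, so sublayer[0] wraps self-attention and sublayer[1]
+        # wraps the position-wise feed-forward network.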
+ x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + + +def attention(query, key, value, mask=None, dropout=None): + "Compute 'Scaled Dot Product Attention'" + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) \ + / math.sqrt(d_k) ## (b,h,l,d) * (b,h,d,l) + if mask is not None: + # scores = scores.masked_fill(mask == 0, -1e9) + scores = scores.masked_fill(mask, -1e9) + p_attn = F.softmax(scores, dim = -1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn ##(b,h,l,l) * (b,h,l,d) = (b,h,l,d) + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1): + "Take in model size and number of heads." + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + "Implements Figure 2" + if mask is not None: + # Same mask applied to all h heads. + mask = mask.unsqueeze(1) + nbatches = query.size(0) + + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = \ + [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for l, x in zip(self.linears, (query, key, value))] + + # 2) Apply attention on all the projected vectors in batch. + x, self.attn = attention(query, key, value, mask=mask, + dropout=self.dropout) + + # 3) "Concat" using a view and apply a final linear. + x = x.transpose(1, 2).contiguous() \ + .view(nbatches, -1, self.h * self.d_k) + return self.linears[-1](x) + + +class PositionwiseFeedForward(nn.Module): + "Implements FFN equation." + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class PositionalEncoding(nn.Module): + "Implement the PE function." + def __init__(self, d_model, dropout, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. 
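+        # div_term below equals 1 / 10000^(2i/d_model); taking the exponent in
+        # log space, exp(2i * -(ln 10000 / d_model)), is numerically safer than
+        # evaluating the power directly.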
+ pe = torch.zeros(max_len, d_model) + position = torch.arange(0., max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0., d_model, 2) * + -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + autograd.Variable(self.pe[:, :x.size(1)], + requires_grad=False) + return self.dropout(x) + + +class AttentionModel(nn.Module): + "Core encoder is a stack of N layers" + def __init__(self, d_input, d_model, d_ff, head, num_layer, dropout): + super(AttentionModel, self).__init__() + c = copy.deepcopy + # attn0 = MultiHeadedAttention(head, d_input, d_model) + attn = MultiHeadedAttention(head, d_model, dropout) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + # position = PositionalEncoding(d_model, dropout) + # layer0 = EncoderLayer(d_model, c(attn0), c(ff), dropout) + layer = EncoderLayer(d_model, c(attn), c(ff), dropout) + self.layers = clones(layer, num_layer) + # layerlist = [copy.deepcopy(layer0),] + # for _ in range(num_layer-1): + # layerlist.append(copy.deepcopy(layer)) + # self.layers = nn.ModuleList(layerlist) + self.norm = LayerNorm(layer.size) + self.posi = PositionalEncoding(d_model, dropout) + self.input2model = nn.Linear(d_input, d_model) + + def forward(self, x, mask): + "Pass the input (and mask) through each layer in turn." + # x: embedding (b,l,we) + x = self.posi(self.input2model(x)) + for layer in self.layers: + x = layer(x, mask) + return self.norm(x) + + + + +class NERmodel(nn.Module): + + def __init__(self, model_type, input_dim, hidden_dim, num_layer, dropout=0.5, gpu=True, biflag=True): + super(NERmodel, self).__init__() + self.model_type = model_type + + if self.model_type == 'lstm': + self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layer, batch_first=True, bidirectional=biflag) + self.drop = nn.Dropout(dropout) + + if self.model_type == 'cnn': + self.cnn = CNNmodel(input_dim, hidden_dim, num_layer, dropout, gpu) + + ## attention model + if self.model_type == 'transformer': + self.attention_model = AttentionModel(d_input=input_dim, d_model=hidden_dim, d_ff=2*hidden_dim, head=4, num_layer=num_layer, dropout=dropout) + for p in self.attention_model.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + + def forward(self, input, mask=None): + + if self.model_type == 'lstm': + hidden = None + feature_out, hidden = self.lstm(input, hidden) + + feature_out_d = self.drop(feature_out) + + if self.model_type == 'cnn': + feature_out_d = self.cnn(input) + + if self.model_type == 'transformer': + feature_out_d = self.attention_model(input, mask) + + return feature_out_d + diff --git a/layers/ner_layers/crf.py b/layers/ner_layers/crf.py deleted file mode 100644 index 0a5b54b..0000000 --- a/layers/ner_layers/crf.py +++ /dev/null @@ -1,290 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author: Jie Yang -# @Date: 2017-12-04 23:19:38 -# @Last Modified by: Jie Yang, Contact: jieynlp@gmail.com -# @Last Modified time: 2018-05-27 22:48:17 -import torch -import torch.autograd as autograd -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -START_TAG = -2 -STOP_TAG = -1 - - -# Compute log sum exp in a numerically stable way for the forward algorithm -def log_sum_exp(vec, m_size): - """ - calculate log of exp sum - args: - vec (batch_size, vanishing_dim, hidden_dim) : input tensor (b,t,t) - m_size : hidden_dim - return: - batch_size, hidden_dim - """ - _, idx = torch.max(vec, 1) # B * 1 * M 
- max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size) # B * M - - return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size) - -class CRF(nn.Module): - - def __init__(self, tagset_size, gpu): - super(CRF, self).__init__() - print ("build batched crf...") - self.gpu = gpu - # Matrix of transition parameters. Entry i,j is the score of transitioning *to* i *from* j. - self.average_batch = False - self.tagset_size = tagset_size - # # We add 2 here, because of START_TAG and STOP_TAG - # # transitions (f_tag_size, t_tag_size), transition value from f_tag to t_tag - init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) - # init_transitions = torch.zeros(self.tagset_size+2, self.tagset_size+2) - # init_transitions[:,START_TAG] = -1000.0 - # init_transitions[STOP_TAG,:] = -1000.0 - # init_transitions[:,0] = -1000.0 - # init_transitions[0,:] = -1000.0 - if self.gpu: - init_transitions = init_transitions.cuda() - self.transitions = nn.Parameter(init_transitions) #(t+2,t+2) - - # self.transitions = nn.Parameter(torch.Tensor(self.tagset_size+2, self.tagset_size+2)) - # self.transitions.data.zero_() - - def _calculate_PZ(self, feats, mask): - """ - input: - feats: (batch, seq_len, self.tag_size+2) (b,m,t+2) - masks: (batch, seq_len) (b,m) - """ - batch_size = feats.size(0) - seq_len = feats.size(1) - tag_size = feats.size(2) - # print feats.view(seq_len, tag_size) - assert(tag_size == self.tagset_size+2) - mask = mask.transpose(1,0).contiguous() #(m,b) - ins_num = seq_len * batch_size - ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) - feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size) - ## need to consider start - scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) - scores = scores.view(seq_len, batch_size, tag_size, tag_size) - # build iter - seq_iter = enumerate(scores) # (index,matrix) index is among the first dim: seqlen - _, inivalues = seq_iter.__next__() - # only need start from start_tag - partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) - - ## add start score (from start to all tag, duplicate to batch_size) - # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) - # iter over last scores - for idx, cur_values in seq_iter: - # previous to_target is current from_target - # partition: previous results log(exp(from_target)), #(batch_size * from_target) - # cur_values: bat_size * from_target * to_target - - cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) - cur_partition = log_sum_exp(cur_values, tag_size) #(b,t) - # print cur_partition.data - - # (bat_size * from_target * to_target) -> (bat_size * to_target) - # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1) - mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) - - ## effective updated partition part, only keep the partition value of mask value = 1 - masked_cur_partition = cur_partition.masked_select(mask_idx) - ## let mask_idx broadcastable, to disable warning - mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1) - - ## replace the partition where the maskvalue=1, other partition value keeps the same - 
partition.masked_scatter_(mask_idx, masked_cur_partition) - # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG - cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) - cur_partition = log_sum_exp(cur_values, tag_size) #(batch_size,hidden_dim) - final_partition = cur_partition[:, STOP_TAG] #(batch_size) - return final_partition.sum(), scores #scores: (seq_len, batch, tag_size, tag_size) - - - def _viterbi_decode(self, feats, mask): - """ - input: - feats: (batch, seq_len, self.tag_size+2) - mask: (batch, seq_len) - output: - decode_idx: (batch, seq_len) decoded sequence - path_score: (batch, 1) corresponding score for each sequence (to be implementated) - """ - batch_size = feats.size(0) - seq_len = feats.size(1) - tag_size = feats.size(2) - assert(tag_size == self.tagset_size+2) - ## calculate sentence length for each sentence - length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() - ## mask to (seq_len, batch_size) - mask = mask.transpose(1,0).contiguous() #(seq_len,b) - ins_num = seq_len * batch_size - ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) - feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) #(ins_num, tag_size, tag_size) - ## need to consider start - scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) - scores = scores.view(seq_len, batch_size, tag_size, tag_size) - - # build iter - seq_iter = enumerate(scores) - ## record the position of best score - back_points = list() - partition_history = list() - - - ## reverse mask (bug for mask = 1- mask, use this as alternative choice) - # mask = 1 + (-1)*mask - mask = (1 - mask.long()).byte() - _, inivalues = seq_iter.__next__() # bat_size * from_target_size * to_target_size - # only need start from start_tag - partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size - #print(partition.size()) - partition_history.append(partition) #(seqlen,batch_size,tag_size,1) - # iter over last scores - for idx, cur_values in seq_iter: - # previous to_target is current from_target - # partition: previous results log(exp(from_target)), #(batch_size * from_target) - # cur_values: batch_size * from_target * to_target - cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) - ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG - partition, cur_bp = torch.max(cur_values,dim=1) - #print(partition.size()) - partition_history.append(partition.unsqueeze(2)) - ## cur_bp: (batch_size, tag_size) max source score position in current tag - ## set padded label as 0, which will be filtered in post processing - cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) - back_points.append(cur_bp) - ### add score to final STOP_TAG - partition_history = torch.cat(partition_history,dim=0).view(seq_len, batch_size,-1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size) - ### get the last position for each setences, and select the last partitions using gather() - last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1 - last_partition = torch.gather(partition_history, 1, 
last_position).view(batch_size,tag_size,1) - ### calculate the score from last partition to end state (and then select the STOP_TAG from it) - last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) - _, last_bp = torch.max(last_values, 1) #(batch_size,tag_size) - pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long() - if self.gpu: - pad_zero = pad_zero.cuda() - back_points.append(pad_zero) - back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) - - ## select end ids in STOP_TAG - pointer = last_bp[:, STOP_TAG] #(batch_size) - insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size) - back_points = back_points.transpose(1,0).contiguous() #(batch_size,sq_len,tag_size) - ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values - # print "lp:",last_position - # print "il:",insert_last - back_points.scatter_(1, last_position, insert_last) ##(batch_size,sq_len,tag_size) - # print "bp:",back_points - # exit(0) - back_points = back_points.transpose(1,0).contiguous() #(seq_len, batch_size, tag_size) - ## decode from the end, padded position ids are 0, which will be filtered if following evaluation - decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size)) - if self.gpu: - decode_idx = decode_idx.cuda() - decode_idx[-1] = pointer.data - for idx in range(len(back_points)-2, -1, -1): - pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) #pointer's size:(batch_size,1) - decode_idx[idx] = pointer.squeeze(1).data - path_score = None - decode_idx = decode_idx.transpose(1,0) #(batch_size, sent_len) - return path_score, decode_idx # - - - - def forward(self, feats): - path_score, best_path = self._viterbi_decode(feats) - return path_score, best_path - - - def _score_sentence(self, scores, mask, tags): - """ - input: - scores: variable (seq_len, batch, tag_size, tag_size) - mask: (batch, seq_len) - tags: tensor (batch, seq_len) - output: - score: sum of score for gold sequences within whole batch - """ - # Gives the score of a provided tag sequence - batch_size = scores.size(1) - seq_len = scores.size(0) - tag_size = scores.size(2) - ## convert tag value into a new format, recorded label bigram information to index - new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len)) - if self.gpu: - new_tags = new_tags.cuda() - for idx in range(seq_len): - if idx == 0: - ## start -> first score - new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0] - - else: - new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx] - - ## transition for label to STOP_TAG - end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size) - ## length for batch, last word position = length - 1 - length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() - ## index the label id of last word - end_ids = torch.gather(tags, 1, length_mask - 1) - - ## index the transition score for end_id to STOP_TAG - end_energy = torch.gather(end_transition, 1, end_ids) - - ## convert tag as (seq_len, batch_size, 1) - new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1) - ### need convert tags id to search from 400 positions of scores - tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size - ## mask transpose to (seq_len, batch_size) - tg_energy 
= tg_energy.masked_select(mask.transpose(1,0)) - - # ## calculate the score from START_TAG to first label - # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size) - # start_energy = torch.gather(start_transition, 1, tags[0,:]) - - ## add all score together - # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum() - gold_score = tg_energy.sum() + end_energy.sum() - return gold_score - - def neg_log_likelihood_loss(self, feats, mask, tags): - # nonegative log likelihood - batch_size = feats.size(0) - forward_score, scores = self._calculate_PZ(feats, mask) #forward_score:long, scores: (seq_len, batch, tag_size, tag_size) - gold_score = self._score_sentence(scores, mask, tags) - #print ("batch, f:", forward_score.data, " g:", gold_score.data, " dis:", forward_score.data - gold_score.data) - # exit(0) - if self.average_batch: - return (forward_score - gold_score)/batch_size - else: - return forward_score - gold_score - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/layers/ner_layers/layers.py b/layers/ner_layers/layers.py deleted file mode 100644 index 4159378..0000000 --- a/layers/ner_layers/layers.py +++ /dev/null @@ -1,261 +0,0 @@ -# -*- coding: utf-8 -*- -import torch -import torch.autograd as autograd -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import math, copy, time - -# class CNNmodel(nn.Module): -# def __init__(self, input_dim, hidden_dim, num_layer, dropout, gpu=True): -# super(CNNmodel, self).__init__() -# self.input_dim = input_dim -# self.hidden_dim = hidden_dim -# self.num_layer = num_layer -# self.gpu = gpu -# -# self.cnn_layer0 = nn.Conv1d(self.input_dim, self.hidden_dim, kernel_size=1, padding=0) -# self.cnn_layers = [nn.Conv1d(self.hidden_dim, self.hidden_dim, kernel_size=3, padding=1) for i in range(self.num_layer-1)] -# self.drop = nn.Dropout(dropout) -# -# if self.gpu: -# self.cnn_layer0 = self.cnn_layer0.cuda() -# for i in range(self.num_layer-1): -# self.cnn_layers[i] = self.cnn_layers[i].cuda() -# -# def forward(self, input_feature): -# -# batch_size = input_feature.size(0) -# seq_len = input_feature.size(1) -# -# input_feature = input_feature.transpose(2,1).contiguous() -# cnn_output = self.cnn_layer0(input_feature) #(b,h,l) -# cnn_output = self.drop(cnn_output) -# cnn_output = torch.tanh(cnn_output) -# -# for layer in range(self.num_layer-1): -# cnn_output = self.cnn_layers[layer](cnn_output) -# cnn_output = self.drop(cnn_output) -# cnn_output = torch.tanh(cnn_output) -# -# cnn_output = cnn_output.transpose(2,1).contiguous() -# return cnn_output -# -# -# -# def clones(module, N): -# "Produce N identical layers." -# return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) -# -# class LayerNorm(nn.Module): -# "Construct a layernorm module (See citation for details)." -# def __init__(self, features, eps=1e-6): -# super(LayerNorm, self).__init__() -# self.a_2 = nn.Parameter(torch.ones(features)) -# self.b_2 = nn.Parameter(torch.zeros(features)) -# self.eps = eps -# -# def forward(self, x): -# mean = x.mean(-1, keepdim=True) -# std = x.std(-1, keepdim=True) -# return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 -# -# class SublayerConnection(nn.Module): -# """ -# A residual connection followed by a layer norm. -# Note for code simplicity the norm is first as opposed to last. 
-#     """
-#     def __init__(self, size, dropout):
-#         super(SublayerConnection, self).__init__()
-#         self.norm = LayerNorm(size)
-#         self.dropout = nn.Dropout(dropout)
-#
-#     def forward(self, x, sublayer):
-#         "Apply residual connection to any sublayer with the same size."
-#         return x + self.dropout(sublayer(self.norm(x)))
-#
-# class EncoderLayer(nn.Module):
-#     "Encoder is made up of self-attn and feed forward (defined below)"
-#     def __init__(self, size, self_attn, feed_forward, dropout):
-#         super(EncoderLayer, self).__init__()
-#         self.self_attn = self_attn
-#         self.feed_forward = feed_forward
-#         self.sublayer = clones(SublayerConnection(size, dropout), 2)
-#         self.size = size
-#
-#     def forward(self, x, mask):
-#         "Follow Figure 1 (left) for connections."
-#         x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
-#         return self.sublayer[1](x, self.feed_forward)
-#
-#
-# def attention(query, key, value, mask=None, dropout=None):
-#     "Compute 'Scaled Dot Product Attention'"
-#     d_k = query.size(-1)
-#     scores = torch.matmul(query, key.transpose(-2, -1)) \
-#              / math.sqrt(d_k)  ## (b,h,l,d) * (b,h,d,l)
-#     if mask is not None:
-#         # scores = scores.masked_fill(mask == 0, -1e9)
-#         scores = scores.masked_fill(mask, -1e9)
-#     p_attn = F.softmax(scores, dim = -1)
-#     if dropout is not None:
-#         p_attn = dropout(p_attn)
-#     return torch.matmul(p_attn, value), p_attn  ## (b,h,l,l) * (b,h,l,d) = (b,h,l,d)
-#
-#
-# class MultiHeadedAttention(nn.Module):
-#     def __init__(self, h, d_model, dropout=0.1):
-#         "Take in model size and number of heads."
-#         super(MultiHeadedAttention, self).__init__()
-#         assert d_model % h == 0
-#         # We assume d_v always equals d_k
-#         self.d_k = d_model // h
-#         self.h = h
-#         self.linears = clones(nn.Linear(d_model, d_model), 4)
-#         self.attn = None
-#         self.dropout = nn.Dropout(p=dropout)
-#
-#     def forward(self, query, key, value, mask=None):
-#         "Implements Figure 2"
-#         if mask is not None:
-#             # Same mask applied to all h heads.
-#             mask = mask.unsqueeze(1)
-#         nbatches = query.size(0)
-#
-#         # 1) Do all the linear projections in batch from d_model => h x d_k
-#         query, key, value = \
-#             [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
-#              for l, x in zip(self.linears, (query, key, value))]
-#
-#         # 2) Apply attention on all the projected vectors in batch.
-#         x, self.attn = attention(query, key, value, mask=mask,
-#                                  dropout=self.dropout)
-#
-#         # 3) "Concat" using a view and apply a final linear.
-#         x = x.transpose(1, 2).contiguous() \
-#              .view(nbatches, -1, self.h * self.d_k)
-#         return self.linears[-1](x)
-#
-#
-# class PositionwiseFeedForward(nn.Module):
-#     "Implements FFN equation."
-#     def __init__(self, d_model, d_ff, dropout=0.1):
-#         super(PositionwiseFeedForward, self).__init__()
-#         self.w_1 = nn.Linear(d_model, d_ff)
-#         self.w_2 = nn.Linear(d_ff, d_model)
-#         self.dropout = nn.Dropout(dropout)
-#
-#     def forward(self, x):
-#         return self.w_2(self.dropout(F.relu(self.w_1(x))))
-#
-#
-# class PositionalEncoding(nn.Module):
-#     "Implement the PE function."
-#     def __init__(self, d_model, dropout, max_len=5000):
-#         super(PositionalEncoding, self).__init__()
-#         self.dropout = nn.Dropout(p=dropout)
-#
-#         # Compute the positional encodings once in log space.
-#         pe = torch.zeros(max_len, d_model)
-#         position = torch.arange(0., max_len).unsqueeze(1)
-#         div_term = torch.exp(torch.arange(0., d_model, 2) *
-#                              -(math.log(10000.0) / d_model))
-#         pe[:, 0::2] = torch.sin(position * div_term)
-#         pe[:, 1::2] = torch.cos(position * div_term)
-#         pe = pe.unsqueeze(0)
-#         self.register_buffer('pe', pe)
-#
-#     def forward(self, x):
-#         x = x + autograd.Variable(self.pe[:, :x.size(1)],
-#                                   requires_grad=False)
-#         return self.dropout(x)
-#
-#
-# class AttentionModel(nn.Module):
-#     "Core encoder is a stack of N layers"
-#     def __init__(self, d_input, d_model, d_ff, head, num_layer, dropout):
-#         super(AttentionModel, self).__init__()
-#         c = copy.deepcopy
-#         # attn0 = MultiHeadedAttention(head, d_input, d_model)
-#         attn = MultiHeadedAttention(head, d_model, dropout)
-#         ff = PositionwiseFeedForward(d_model, d_ff, dropout)
-#         # position = PositionalEncoding(d_model, dropout)
-#         # layer0 = EncoderLayer(d_model, c(attn0), c(ff), dropout)
-#         layer = EncoderLayer(d_model, c(attn), c(ff), dropout)
-#         self.layers = clones(layer, num_layer)
-#         # layerlist = [copy.deepcopy(layer0),]
-#         # for _ in range(num_layer-1):
-#         #     layerlist.append(copy.deepcopy(layer))
-#         # self.layers = nn.ModuleList(layerlist)
-#         self.norm = LayerNorm(layer.size)
-#         self.posi = PositionalEncoding(d_model, dropout)
-#         self.input2model = nn.Linear(d_input, d_model)
-#
-#     def forward(self, x, mask):
-#         "Pass the input (and mask) through each layer in turn."
-#         # x: embedding (b,l,we)
-#         x = self.posi(self.input2model(x))
-#         for layer in self.layers:
-#             x = layer(x, mask)
-#         return self.norm(x)
-from layers.encoders.transformers.transformer import TransformerEncoder
-
-
-class NERmodel(nn.Module):
-
-    def __init__(self, model_type, input_dim, hidden_dim, num_layer, dropout=0.5, gpu=True, biflag=True):
-        super(NERmodel, self).__init__()
-        self.model_type = model_type
-
-        if self.model_type == 'lstm':
-            self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layer, batch_first=True, bidirectional=biflag)
-            self.drop = nn.Dropout(dropout)
-
-        # if self.model_type == 'cnn':
-        #     self.cnn = CNNmodel(input_dim, hidden_dim, num_layer, dropout, gpu)
-        #
-        # ## attention model
-        if self.model_type == 'transformer':
-            n_head = 6
-            head_dims = 80
-            num_layers = 2
-            d_model = n_head * head_dims
-            feedforward_dim = int(2 * d_model)
-            dropout = 0.15
-            fc_dropout = 0.4
-            after_norm = 1
-            attn_type = 'adatrans'
-            scale = attn_type == 'transformer'
-
-            self.in_fc = nn.Linear(input_dim, d_model)
-
-            self.transformer = TransformerEncoder(num_layers, d_model, n_head, feedforward_dim, dropout,
-                                                  after_norm=after_norm, attn_type=attn_type,
-                                                  scale=False, dropout_attn=scale,
-                                                  pos_embed=None)
-            self.fc_dropout = nn.Dropout(fc_dropout)
-
-        # self.attention_model = AttentionModel(d_input=input_dim, d_model=hidden_dim, d_ff=2*hidden_dim, head=4, num_layer=num_layer, dropout=dropout)
-        # for p in self.attention_model.parameters():
-        #     if p.dim() > 1:
-        #         nn.init.xavier_uniform_(p)
-
-    def forward(self, input, mask=None):
-
-        if self.model_type == 'lstm':
-            hidden = None
-            feature_out, hidden = self.lstm(input, hidden)
-
-            feature_out_d = self.drop(feature_out)
-
-        # if self.model_type == 'cnn':
-        #     feature_out_d = self.cnn(input)
-        #
-        if self.model_type == 'transformer':
-            chars = self.in_fc(input)
-            chars = self.transformer(chars, mask)
-            feature_out_d = self.fc_dropout(chars)
-
-        return feature_out_d
-
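The NERmodel wrapper deleted above is re-homed by this patch in layers/encoders/ner_layers.py (see the file list at the top of the commit). A minimal sketch of its LSTM branch under that assumption; note the transformer branch ignores hidden_dim and instead projects input_dim to d_model = 480 (6 heads x 80 head dims) before the TransformerEncoder:

    import torch
    from layers.encoders.ner_layers import NERmodel  # assumed new home per this patch

    encoder = NERmodel(model_type='lstm', input_dim=100, hidden_dim=128,
                       num_layer=1, dropout=0.5, gpu=False, biflag=True)

    x = torch.randn(2, 7, 100)   # (batch, seq_len, input_dim)
    out = encoder(x)             # (2, 7, 256): a bidirectional LSTM doubles hidden_dim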
diff --git a/models/ner_net/bert_ner.py b/models/ner_net/bert_finetune_ner.py
similarity index 100%
rename from models/ner_net/bert_ner.py
rename to models/ner_net/bert_finetune_ner.py
diff --git a/models/ner_net/augment_ner.py b/models/ner_net/general_ner.py
similarity index 100%
rename from models/ner_net/augment_ner.py
rename to models/ner_net/general_ner.py
diff --git a/run/entity_extraction/augmentedNER/__init__.py b/run/entity_extraction/augmentedNER/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/layers/ner_layers/__init__.py b/run/entity_extraction/generalNER/__init__.py
similarity index 100%
rename from layers/ner_layers/__init__.py
rename to run/entity_extraction/generalNER/__init__.py
diff --git a/run/entity_extraction/augmentedNER/main.py b/run/entity_extraction/generalNER/main.py
similarity index 97%
rename from run/entity_extraction/augmentedNER/main.py
rename to run/entity_extraction/generalNER/main.py
index 766d647..129f4f4 100644
--- a/run/entity_extraction/augmentedNER/main.py
+++ b/run/entity_extraction/generalNER/main.py
@@ -20,7 +20,7 @@
 import torch.autograd as autograd
 import torch.optim as optim
 
-from models.ner_net.augment_ner import GazLSTM as SeqModel
-from models.ner_net.bert_ner import BertNER
-from run.entity_extraction.augmentedNER.data import Data
-from run.entity_extraction.augmentedNER.metric import get_ner_fmeasure
+from models.ner_net.general_ner import GazLSTM as SeqModel
+from models.ner_net.transformers_ner import BertNER
+from run.entity_extraction.generalNER.utils.data import Data
+from run.entity_extraction.generalNER.utils.metric import get_ner_fmeasure
diff --git a/run/entity_extraction/augmentedNER/alphabet.py b/run/entity_extraction/generalNER/utils/alphabet.py
similarity index 100%
rename from run/entity_extraction/augmentedNER/alphabet.py
rename to run/entity_extraction/generalNER/utils/alphabet.py
diff --git a/run/entity_extraction/augmentedNER/data.py b/run/entity_extraction/generalNER/utils/data.py
similarity index 98%
rename from run/entity_extraction/augmentedNER/data.py
rename to run/entity_extraction/generalNER/utils/data.py
index c56bcda..7445d83 100644
--- a/run/entity_extraction/augmentedNER/data.py
+++ b/run/entity_extraction/generalNER/utils/data.py
@@ -3,8 +3,8 @@
 import sys
 
 from tqdm import tqdm
-from run.entity_extraction.augmentedNER.alphabet import Alphabet
-from run.entity_extraction.augmentedNER.functions import *
+from run.entity_extraction.generalNER.utils.alphabet import Alphabet
+from run.entity_extraction.generalNER.utils.functions import *
 
 START = "</s>"
 UNKNOWN = "</unk>"
diff --git a/run/entity_extraction/augmentedNER/functions.py b/run/entity_extraction/generalNER/utils/functions.py
similarity index 100%
rename from run/entity_extraction/augmentedNER/functions.py
rename to run/entity_extraction/generalNER/utils/functions.py
diff --git a/run/entity_extraction/augmentedNER/gazetteer.py b/run/entity_extraction/generalNER/utils/gazetteer.py
similarity index 96%
rename from run/entity_extraction/augmentedNER/gazetteer.py
rename to run/entity_extraction/generalNER/utils/gazetteer.py
index 0a474c2..d958b1b 100644
--- a/run/entity_extraction/augmentedNER/gazetteer.py
+++ b/run/entity_extraction/generalNER/utils/gazetteer.py
@@ -1,4 +1,4 @@
-from run.entity_extraction.augmentedNER.trie import Trie
+from run.entity_extraction.generalNER.utils.trie import Trie
 
 
 class Gazetteer:
diff --git a/run/entity_extraction/augmentedNER/metric.py b/run/entity_extraction/generalNER/utils/metric.py
similarity index 100%
rename from run/entity_extraction/augmentedNER/metric.py
rename to run/entity_extraction/generalNER/utils/metric.py
diff --git a/run/entity_extraction/augmentedNER/trie.py b/run/entity_extraction/generalNER/utils/trie.py
similarity index 100%
rename from run/entity_extraction/augmentedNER/trie.py
rename to run/entity_extraction/generalNER/utils/trie.py
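A quick smoke test for the renames and import rewrites above: every moved module should resolve under the new layout (sketch; assumes the package structure exactly as listed in this patch):

    from models.ner_net.general_ner import GazLSTM
    from models.ner_net.bert_finetune_ner import BertNER
    from run.entity_extraction.generalNER.utils.data import Data
    from run.entity_extraction.generalNER.utils.metric import get_ner_fmeasure
    from run.entity_extraction.generalNER.utils.gazetteer import Gazetteer
    from run.entity_extraction.generalNER.utils.trie import Trie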