# -*- coding: utf-8 -*-
# @Author: Jie Yang
# @Date:   2017-12-04 23:19:38
# @Last Modified by:   Jie Yang,     Contact: jieynlp@gmail.com
# @Last Modified time: 2018-05-27 22:48:17
import torch
import torch.autograd as autograd
import torch.nn as nn

# Indices of the two extra tags appended after the real tagset. Negative
# indices address the last two rows/columns of the (t+2, t+2) tables.
START_TAG = -2
STOP_TAG = -1


def log_sum_exp(vec, m_size):
    """Compute log(sum(exp(vec), dim=1)) in a numerically stable way.

    Args:
        vec: tensor of shape (batch_size, from_dim, m_size).
        m_size: size of the last dimension of ``vec``.

    Returns:
        Tensor of shape (batch_size, m_size).
    """
    # Subtract the per-column maximum before exponentiating to avoid overflow.
    max_score, _ = torch.max(vec, 1, keepdim=True)  # (B, 1, M)
    summed = torch.sum(torch.exp(vec - max_score), 1)  # (B, M)
    return max_score.view(-1, m_size) + torch.log(summed).view(-1, m_size)


class CRF(nn.Module):
    """Batched linear-chain CRF with two extra tags, START_TAG and STOP_TAG.

    ``self.transitions[i, j]`` is the score of moving *from* tag ``i`` *to*
    tag ``j``. The table has ``tagset_size + 2`` rows and columns; the last
    two are START_TAG and STOP_TAG. Emission tensors (``feats``) must
    therefore have ``tagset_size + 2`` columns as well.
    """

    def __init__(self, tagset_size, gpu, device=0):
        super(CRF, self).__init__()
        print("build batched crf...")
        self.gpu = gpu
        self.device = device
        # When True, neg_log_likelihood_loss divides the summed loss by batch size.
        self.average_batch = False
        self.tagset_size = tagset_size
        # (t+2, t+2) transition table, zero-initialised. No transitions are
        # masked out, so START/STOP also take part in the sums over labels.
        init_transitions = torch.zeros(self.tagset_size + 2, self.tagset_size + 2)
        if self.gpu:
            init_transitions = init_transitions.cuda(device)
        self.transitions = nn.Parameter(init_transitions)

    def _transition_scores(self, feats):
        """Combine emissions and transitions into one score table.

        Args:
            feats: (batch, seq_len, tag_size) emission scores.

        Returns:
            Tensor of shape (seq_len, batch, tag_size, tag_size) with
            ``scores[t, b, i, j] = feats[b, t, j] + transitions[i, j]``.
        """
        batch_size, seq_len, tag_size = feats.size()
        ins_num = seq_len * batch_size
        # View as (ins_num, 1, tag_size), NOT (ins_num, tag_size, 1): the
        # emission belongs to the destination tag j and is shared by every i.
        feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size)
        feats = feats.expand(ins_num, tag_size, tag_size)
        scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        return scores.view(seq_len, batch_size, tag_size, tag_size)

    def _calculate_PZ(self, feats, mask):
        """Run the forward algorithm and return the batch-summed log partition.

        Args:
            feats: (batch, seq_len, tagset_size + 2) emission scores.
            mask: (batch, seq_len); non-zero marks real tokens. Bool, uint8
                and integer masks are all accepted.

        Returns:
            ``(total_log_z, scores)``. ``total_log_z`` is a 0-dim tensor
            holding the sum of per-sequence log partition values.
            ``scores`` is the (seq_len, batch, tag_size, tag_size) table
            produced by ``_transition_scores``.
        """
        batch_size = feats.size(0)
        tag_size = feats.size(2)
        assert tag_size == self.tagset_size + 2
        # Masked tensor ops require bool masks in current PyTorch.
        mask = mask.bool().transpose(1, 0).contiguous()  # (seq_len, batch)
        scores = self._transition_scores(feats)
        seq_iter = enumerate(scores)
        _, inivalues = next(seq_iter)
        # Position 0: every path starts from START_TAG.
        partition = inivalues[:, START_TAG, :].reshape(batch_size, tag_size, 1)
        for idx, cur_values in seq_iter:
            # cur_values[b, i, j] + partition[b, i]: extend paths ending in i by i -> j.
            cur_values = cur_values + partition.reshape(batch_size, tag_size, 1).expand(
                batch_size, tag_size, tag_size)
            cur_partition = log_sum_exp(cur_values, tag_size)  # (batch, tag_size)
            # Only advance sequences that still have a real token at idx;
            # padded sequences keep their previous partition.
            keep = mask[idx].view(batch_size, 1, 1)
            partition = torch.where(keep, cur_partition.view(batch_size, tag_size, 1), partition)
        # Close every path with its transition into STOP_TAG.
        cur_values = self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size) \
            + partition.reshape(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        cur_partition = log_sum_exp(cur_values, tag_size)  # (batch, tag_size)
        final_partition = cur_partition[:, STOP_TAG]  # (batch,)
        return final_partition.sum(), scores

    def _viterbi_decode(self, feats, mask):
        """Find the highest-scoring tag sequence for each batch element.

        Args:
            feats: (batch, seq_len, tagset_size + 2) emission scores.
            mask: (batch, seq_len); non-zero marks real tokens. Every
                sequence must contain at least one real token.

        Returns:
            ``(path_score, decode_idx)``. ``path_score`` is always None
            (not implemented). ``decode_idx`` is a (batch, seq_len) long
            tensor of tag ids. Entries at padded positions are filler
            values and must be filtered by the caller.
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        assert tag_size == self.tagset_size + 2
        mask = mask.bool()
        # Real length of each sequence, shape (batch, 1).
        length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
        pad_mask = ~mask.transpose(1, 0).contiguous()  # (seq_len, batch), True at padding
        scores = self._transition_scores(feats)
        seq_iter = enumerate(scores)
        back_points = []
        partition_history = []
        _, inivalues = next(seq_iter)
        # Position 0: every path starts from START_TAG.
        partition = inivalues[:, START_TAG, :].reshape(batch_size, tag_size, 1)
        partition_history.append(partition)
        for idx, cur_values in seq_iter:
            cur_values = cur_values + partition.reshape(batch_size, tag_size, 1).expand(
                batch_size, tag_size, tag_size)
            # Best previous tag (cur_bp) and its score for every current tag.
            partition, cur_bp = torch.max(cur_values, dim=1)  # (batch, tag_size)
            partition_history.append(partition.unsqueeze(2))
            # Back-pointers at padded positions are zeroed; the real end is patched below.
            cur_bp = cur_bp.masked_fill(pad_mask[idx].view(batch_size, 1), 0)
            back_points.append(cur_bp)
        # (batch, seq_len, tag_size): best score of any path ending in each tag.
        partition_history = torch.cat(partition_history, dim=0).view(
            seq_len, batch_size, -1).transpose(1, 0).contiguous()
        # Select each sequence's partition at its last real position.
        last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)
        # Add the transition into STOP_TAG and take the best final tag.
        last_values = last_partition.expand(batch_size, tag_size, tag_size) \
            + self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
        _, last_bp = torch.max(last_values, 1)  # (batch, tag_size)
        pad_zero = torch.zeros(batch_size, tag_size, dtype=torch.long, device=feats.device)
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
        pointer = last_bp[:, STOP_TAG]  # (batch,) best final tag
        insert_last = pointer.reshape(batch_size, 1, 1).expand(batch_size, 1, tag_size)
        # Overwrite the back-pointer row at each sequence's last real position
        # with its best final tag, so decoding from the end lands on it.
        back_points = back_points.transpose(1, 0).contiguous()  # (batch, seq_len, tag_size)
        back_points.scatter_(1, last_position, insert_last)
        back_points = back_points.transpose(1, 0).contiguous()  # (seq_len, batch, tag_size)
        decode_idx = torch.empty(seq_len, batch_size, dtype=torch.long, device=feats.device)
        decode_idx[-1] = pointer
        for idx in range(len(back_points) - 2, -1, -1):
            pointer = torch.gather(back_points[idx], 1, pointer.reshape(batch_size, 1))  # (batch, 1)
            decode_idx[idx] = pointer.squeeze(1)
        path_score = None
        decode_idx = decode_idx.transpose(1, 0)  # (batch, seq_len)
        return path_score, decode_idx

    def forward(self, feats, mask=None):
        """Viterbi-decode the best tag sequence for each batch element.

        Args:
            feats: (batch, seq_len, tagset_size + 2) emission scores.
            mask: optional (batch, seq_len) mask of real tokens. When
                omitted, every position is treated as real.

        Returns:
            ``(path_score, best_path)`` as returned by ``_viterbi_decode``.
        """
        if mask is None:
            mask = torch.ones(feats.size(0), feats.size(1), dtype=torch.bool, device=feats.device)
        path_score, best_path = self._viterbi_decode(feats, mask)
        return path_score, best_path

    def _score_sentence(self, scores, mask, tags):
        """Sum the scores of the gold tag sequences over the whole batch.

        Args:
            scores: (seq_len, batch, tag_size, tag_size) table from
                ``_transition_scores``.
            mask: (batch, seq_len); non-zero marks real tokens.
            tags: (batch, seq_len) integer gold tag ids.

        Returns:
            0-dim tensor: the START transitions, transitions, emissions and
            STOP transitions of every gold path, summed over the batch.
        """
        seq_len = scores.size(0)
        batch_size = scores.size(1)
        tag_size = scores.size(2)
        mask = mask.bool()
        tags = tags.long()
        # Encode each (previous, current) tag pair as the flat index
        # prev * tag_size + cur into the tag_size x tag_size table.
        # Position 0 uses START_TAG (index tag_size - 2) as previous.
        start_ids = torch.full_like(tags[:, :1], tag_size - 2)
        prev_tags = torch.cat([start_ids, tags[:, :-1]], dim=1)
        new_tags = prev_tags * tag_size + tags  # (batch, seq_len)
        # Transition score from each tag into STOP_TAG.
        end_transition = self.transitions[:, STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
        # Last real position of each sequence is length - 1.
        length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
        end_ids = torch.gather(tags, 1, length_mask - 1)
        end_energy = torch.gather(end_transition, 1, end_ids)
        new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)
        # Drop energies at padded positions.
        tg_energy = tg_energy.masked_select(mask.transpose(1, 0))
        gold_score = tg_energy.sum() + end_energy.sum()
        return gold_score

    def neg_log_likelihood_loss(self, feats, mask, tags):
        """Negative log-likelihood of the gold tags, summed over the batch.

        The result is divided by the batch size when ``self.average_batch``
        is True.

        Args:
            feats: (batch, seq_len, tagset_size + 2) emission scores.
            mask: (batch, seq_len); non-zero marks real tokens.
            tags: (batch, seq_len) integer gold tag ids.

        Returns:
            0-dim tensor loss (log Z minus gold score).
        """
        batch_size = feats.size(0)
        forward_score, scores = self._calculate_PZ(feats, mask)
        gold_score = self._score_sentence(scores, mask, tags)
        if self.average_batch:
            return (forward_score - gold_score) / batch_size
        return forward_score - gold_score