first add model codes

This commit is contained in:
loujie0822 2020-02-14 13:49:44 +08:00
parent cad8b9b8fd
commit 144d94550a
19 changed files with 1494 additions and 0 deletions

0
models/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,120 @@
import torch
from torch import nn
from layers.encoders.transformers.bert.bert_model import BertModel
from layers.encoders.transformers.bert.bert_pretrain import BertPreTrainedModel
# class AttributeExtractNet(BertPreTrainedModel):
# """
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware and
# encoded by BERT
# """
#
# def __init__(self, config, args, attribute_conf):
# print('bert mpn baseline')
# super(AttributeExtractNet, self).__init__(config, args, attribute_conf)
#
# self.bert = BertModel(config)
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
# padding_idx=0)
#
# # sentence_encoder using transformer
# self.transformer_encoder_layer = TransformerEncoderLayer(config.hidden_size, args.nhead,
# dim_feedforward=args.dim_feedforward)
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
#
# self.classes_num = len(attribute_conf)
#
# # pointer net work
# self.attr_start = nn.Linear(config.hidden_size, self.classes_num)
# self.attr_end = nn.Linear(config.hidden_size, self.classes_num)
#
# self.apply(self.init_bert_weights)
#
# def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None, start_id=None,
# end_id=None, is_eval=False):
# mask = passage_id.eq(0)
# sent_mask = passage_id != 0
#
# context_encoder, _ = self.bert(passage_id, segment_id, attention_mask=sent_mask,
# output_all_encoded_layers=False)
#
# token_entity_emb = self.token_entity_emb(token_type_id)
#
# # sent encoder based entity-aware
# sent_entity_encoder = context_encoder + token_entity_emb
# transformer_encoder = self.transformer_encoder(sent_entity_encoder.transpose(1, 0),
# src_key_padding_mask=mask).transpose(0, 1)
#
# attr_start = self.attr_start(transformer_encoder)
# attr_end = self.attr_end(transformer_encoder)
#
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
#
# s1_loss = loss_fct(attr_start, start_id)
# s1_loss = torch.sum(s1_loss, 2)
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# s2_loss = loss_fct(attr_end, end_id)
# s2_loss = torch.sum(s2_loss, 2)
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# total_loss = s1_loss + s2_loss
# po1 = nn.Sigmoid()(attr_start)
# po2 = nn.Sigmoid()(attr_end)
#
# return total_loss, po1, po2
class AttributeExtractNet(BertPreTrainedModel):
    """
    Attribute Extract Net with a Multi-label Pointer Network (MPN),
    entity-aware, encoded by BERT.

    Two linear heads emit, for every token, per-attribute-class logits that
    mark whether the token starts / ends a value of that attribute class.
    """

    def __init__(self, config, args, attribute_conf):
        """
        config: BERT model configuration.
        args: run-time hyper-parameters (unused by this variant; kept for
              the shared constructor signature).
        attribute_conf: mapping attribute type -> class id; its size fixes
              the number of pointer classes.
        """
        super(AttributeExtractNet, self).__init__(config, args, attribute_conf)
        self.bert = BertModel(config)
        # 1 marks tokens belonging to the target entity, 0 otherwise.
        # NOTE(review): not referenced in forward() of this variant — confirm
        # whether it is intentionally kept for checkpoint compatibility.
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
                                             padding_idx=0)
        self.classes_num = len(attribute_conf)
        # pointer-network heads: per-class start / end logits for each token
        self.attr_start = nn.Linear(config.hidden_size, self.classes_num)
        self.attr_end = nn.Linear(config.hidden_size, self.classes_num)
        self.apply(self.init_bert_weights)

    def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None, start_id=None,
                end_id=None, is_eval=False):
        """
        Returns (total_loss, po1, po2); po1/po2 are sigmoid start/end
        probabilities of shape (batch, seq_len, classes_num).
        pos_start / pos_end / is_eval are accepted for interface
        compatibility but not used by this variant.
        """
        # 1 for real tokens, 0 for padding
        sent_mask = passage_id != 0
        context_encoder, _ = self.bert(passage_id, segment_id, attention_mask=sent_mask,
                                       output_all_encoded_layers=False)
        attr_start = self.attr_start(context_encoder)
        attr_end = self.attr_end(context_encoder)
        # element-wise BCE so padded positions can be masked out before reduction
        loss_fct = nn.BCEWithLogitsLoss(reduction='none')
        s1_loss = loss_fct(attr_start, start_id)
        s1_loss = torch.sum(s1_loss, 2)
        # mean over real tokens, normalised by the number of classes
        s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
        s2_loss = loss_fct(attr_end, end_id)
        s2_loss = torch.sum(s2_loss, 2)
        s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
        total_loss = s1_loss + s2_loss
        # functional sigmoid instead of constructing an nn.Sigmoid module per call
        po1 = torch.sigmoid(attr_start)
        po2 = torch.sigmoid(attr_end)
        return total_loss, po1, po2

View File

@ -0,0 +1,384 @@
import warnings
import torch
import torch.nn as nn
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer
from layers.encoders.rnns.stacked_rnn import StackedBRNN
warnings.filterwarnings("ignore")
class SentenceEncoder(nn.Module):
    """Bidirectional stacked-RNN sentence encoder (LSTM or GRU per args)."""

    def __init__(self, args, input_size):
        super(SentenceEncoder, self).__init__()
        # pick the recurrent cell requested on the command line
        if args.rnn_encoder == 'lstm':
            cell = nn.LSTM
        else:
            cell = nn.GRU
        self.encoder = StackedBRNN(
            input_size=input_size,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            dropout_rate=args.dropout,
            dropout_output=True,
            concat_layers=False,
            rnn_type=cell,
            padding=True,
        )

    def forward(self, input, mask):
        """Encode `input` (batch-first) honouring the padding `mask`."""
        encoded = self.encoder(input, mask)
        return encoded
class AttributeExtractNet(nn.Module):
    """
    Attribute Extract Net with a Multi-label Pointer Network (MPN),
    entity-aware. Entity awareness here means summing four embeddings:
    char_emb + token_entity_emb + pos_start_emb + pos_end_emb.
    """

    def __init__(self, args, char_emb, attribute_conf):
        super(AttributeExtractNet, self).__init__()
        print('实体感知方式:char_emb + token_entity_emb+pos_start_emb+pos_end_emb')
        # character embeddings: pretrained matrix when available, else random
        if char_emb is not None:
            self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
                                                         padding_idx=0)
        else:
            self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
                                         padding_idx=0)
        # positions relative to the entity start / end
        self.pos_start = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
                                      padding_idx=0)
        self.pos_end = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
                                    padding_idx=0)
        # 1 if the token belongs to the entity, else 0
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
                                             padding_idx=0)
        # BiRNN sentence encoders; the second one is not used by this
        # variant's forward() — presumably kept for the other variants.
        self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
        self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
        # transformer stack on top of the BiRNN output
        self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
                                                                 dim_feedforward=args.dim_feedforward)
        self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
        self.classes_num = len(attribute_conf)
        # pointer-network heads
        self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
        self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)

    def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None, start_id=None,
                end_id=None, is_eval=False):
        """Return (total_loss, po1, po2) with po1/po2 of shape (B, L, classes_num)."""
        pad_mask = passage_id.eq(0)          # True on padding positions
        valid_mask = passage_id != 0         # True on real tokens
        # entity-aware token representation: sum of the four embeddings
        fused = (self.char_emb(passage_id) + self.token_entity_emb(token_type_id)
                 + self.pos_start(pos_start) + self.pos_end(pos_end))
        rnn_out = self.first_sentence_encoder(fused, pad_mask).transpose(1, 0)
        encoded = self.transformer_encoder(rnn_out, src_key_padding_mask=pad_mask).transpose(0, 1)
        start_logits = self.attr_start(encoded)
        end_logits = self.attr_end(encoded)
        bce = nn.BCEWithLogitsLoss(reduction='none')
        denom = torch.sum(valid_mask.float())
        # masked mean over real tokens, normalised by the class count
        start_loss = torch.sum(torch.sum(bce(start_logits, start_id), 2) * valid_mask.float()) / denom / self.classes_num
        end_loss = torch.sum(torch.sum(bce(end_logits, end_id), 2) * valid_mask.float()) / denom / self.classes_num
        total_loss = start_loss + end_loss
        po1 = torch.sigmoid(start_logits)
        po2 = torch.sigmoid(end_logits)
        return total_loss, po1, po2
# class AttributeExtractNet(nn.Module):
# """
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
# 实体感知方式sent-encoder+pos_start+pos_end
# """
#
# def __init__(self, args, char_emb, attribute_conf):
# super(AttributeExtractNet, self).__init__()
# print('实体感知方式: sent-encoder+pos_start')
# if char_emb is not None:
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
# padding_idx=0)
# else:
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
# padding_idx=0)
#
# self.pos_start = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
# padding_idx=0)
# self.pos_end = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
# padding_idx=0)
#
# # token whether belong to a entity, 1 represent a entity token, else 0;
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
# padding_idx=0)
# # sentence_encoder using lstm
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
#
# # sentence_encoder using transformer
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
# dim_feedforward=args.dim_feedforward)
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
#
# self.classes_num = len(attribute_conf)
#
# # pointer net work
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
#
# def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None, start_id=None,
# end_id=None, is_eval=False):
# mask = passage_id.eq(0)
# sent_mask = passage_id != 0
#
# char_emb = self.char_emb(passage_id)
# pos_start_emb = self.pos_start(pos_start)
# pos_end_emb = self.pos_end(pos_end)
# token_entity_emb = self.token_entity_emb(token_type_id)
#
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask)
#
# # sent encoder based entity-aware
# sent_entity_aware = sent_first_encoder+pos_start_emb
#
# sent_second_encoder = self.second_sentence_encoder(sent_entity_aware, mask).transpose(1, 0)
# transformer_encoder = self.transformer_encoder(sent_second_encoder, src_key_padding_mask=mask).transpose(0, 1)
#
# attr_start = self.attr_start(transformer_encoder)
# attr_end = self.attr_end(transformer_encoder)
#
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
#
# s1_loss = loss_fct(attr_start, start_id)
# s1_loss = torch.sum(s1_loss, 2)
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# s2_loss = loss_fct(attr_end, end_id)
# s2_loss = torch.sum(s2_loss, 2)
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# total_loss = s1_loss + s2_loss
# po1 = nn.Sigmoid()(attr_start)
# po2 = nn.Sigmoid()(attr_end)
#
# return total_loss, po1, po2
# class AttributeExtractNet(nn.Module):
# """
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
# 实体感知方式:[sent-encoder, token_ent_type, global_repre]
# """
#
# def __init__(self, args, char_emb, attribute_conf):
# super(AttributeExtractNet, self).__init__()
# print('实体感知方式:[sent-encoder, token_ent_type, global_repre]')
# if char_emb is not None:
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
# padding_idx=0)
# else:
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
# padding_idx=0)
#
# # token whether belong to a entity, 1 represent a entity token, else 0;
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
# padding_idx=0)
# # sentence_encoder using lstm
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size * 3)
#
# # sentence_encoder using transformer
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
# dim_feedforward=args.dim_feedforward)
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
#
# self.classes_num = len(attribute_conf)
#
# # pointer net work
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
#
# def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None, start_id=None,
# end_id=None, is_eval=False):
# mask = passage_id.eq(0)
# sent_mask = passage_id != 0
#
# char_emb = self.char_emb(passage_id)
# token_entity_emb = self.token_entity_emb(token_type_id)
#
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask)
# global_encoder_, _ = torch.max(sent_first_encoder, 1)
# global_encoder = global_encoder_.unsqueeze(1).expand_as(sent_first_encoder)
# # sent encoder based entity-aware
# sent_entity_aware = torch.cat([sent_first_encoder, token_entity_emb, global_encoder], -1)
#
# sent_second_encoder = self.second_sentence_encoder(sent_entity_aware, mask).transpose(1, 0)
# transformer_encoder = self.transformer_encoder(sent_second_encoder, src_key_padding_mask=mask).transpose(0, 1)
#
# attr_start = self.attr_start(transformer_encoder)
# attr_end = self.attr_end(transformer_encoder)
#
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
#
# s1_loss = loss_fct(attr_start, start_id)
# s1_loss = torch.sum(s1_loss, 2)
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# s2_loss = loss_fct(attr_end, end_id)
# s2_loss = torch.sum(s2_loss, 2)
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# total_loss = s1_loss + s2_loss
# po1 = nn.Sigmoid()(attr_start)
# po2 = nn.Sigmoid()(attr_end)
#
# return total_loss, po1, po2
# class AttributeExtractNet(nn.Module):
# """
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
# 实体感知方式token_entity_embedding
# """
#
# def __init__(self, args, char_emb, attribute_conf):
# print('basline ')
# super(AttributeExtractNet, self).__init__()
# if char_emb is not None:
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
# padding_idx=0)
# else:
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
# padding_idx=0)
#
# # token whether belong to a entity, 1 represent a entity token, else 0;
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
# padding_idx=0)
# # sentence_encoder using lstm
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
#
# # sentence_encoder using transformer
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
# dim_feedforward=args.dim_feedforward)
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
#
# self.classes_num = len(attribute_conf)
#
# # pointer net work
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
#
# def forward(self, passage_id=None, token_type_id=None, segment_id=None,start_id=None, end_id=None, is_eval=False):
# mask = passage_id.eq(0)
# sent_mask = passage_id != 0
#
# char_emb = self.char_emb(passage_id)
# token_entity_emb = self.token_entity_emb(token_type_id)
#
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask)
# # sent encoder based entity-aware
# sent_entity_encoder = sent_first_encoder + token_entity_emb
# sent_second_encoder = self.second_sentence_encoder(sent_entity_encoder, mask).transpose(1, 0)
# transformer_encoder = self.transformer_encoder(sent_second_encoder, src_key_padding_mask=mask).transpose(0, 1)
#
# attr_start = self.attr_start(transformer_encoder)
# attr_end = self.attr_end(transformer_encoder)
#
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
#
# s1_loss = loss_fct(attr_start, start_id)
# s1_loss = torch.sum(s1_loss, 2)
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# s2_loss = loss_fct(attr_end, end_id)
# s2_loss = torch.sum(s2_loss, 2)
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# total_loss = s1_loss + s2_loss
# po1 = nn.Sigmoid()(attr_start)
# po2 = nn.Sigmoid()(attr_end)
#
# return total_loss, po1, po2
# class AttributeExtractNet(nn.Module):
# """
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
# 实体感知方式:显示编码,将实体替换为<MASK>
# """
#
# def __init__(self, args, char_emb, attribute_conf):
# super(AttributeExtractNet, self).__init__()
# print('实体感知方式:显示编码,将实体替换为<MASK>')
#
#
# if char_emb is not None:
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
# padding_idx=0)
# else:
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
# padding_idx=0)
#
# # token whether belong to a entity, 1 represent a entity token, else 0;
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
# padding_idx=0)
# # sentence_encoder using lstm
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
#
# # sentence_encoder using transformer
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
# dim_feedforward=args.dim_feedforward)
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
#
# self.classes_num = len(attribute_conf)
#
# # pointer net work
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
#
# def forward(self, passage_id=None, token_type_id=None, segment_id=None, start_id=None, end_id=None, is_eval=False):
# mask = passage_id.eq(0)
# sent_mask = passage_id != 0
#
# char_emb = self.char_emb(passage_id)
# # token_entity_emb = self.token_entity_emb(token_type_id)
#
# # char_emb = char_emb + token_entity_emb
#
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask).transpose(1, 0)
# # sent encoder based entity-aware
# # sent_entity_encoder = sent_first_encoder + token_entity_emb
# # sent_second_encoder = self.second_sentence_encoder(sent_first_encoder, mask).transpose(1, 0)
# transformer_encoder = self.transformer_encoder(sent_first_encoder, src_key_padding_mask=mask).transpose(0, 1)
#
# attr_start = self.attr_start(transformer_encoder)
# attr_end = self.attr_end(transformer_encoder)
#
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
#
# s1_loss = loss_fct(attr_start, start_id)
# s1_loss = torch.sum(s1_loss, 2)
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# s2_loss = loss_fct(attr_end, end_id)
# s2_loss = torch.sum(s2_loss, 2)
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
# total_loss = s1_loss + s2_loss
# po1 = nn.Sigmoid()(attr_start)
# po2 = nn.Sigmoid()(attr_end)
#
# return total_loss, po1, po2

0
run/__init__.py Normal file
View File

View File

View File

View File

@ -0,0 +1,392 @@
import json
import logging
from collections import Counter
from functools import partial
import jieba
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from layers.encoders.transformers.bert.bert_tokenization import BertTokenizer
from utils.data_util import padding, mpn_padding, _handle_pos_limit
# Attribute-type -> class-id tables, one per entity/task type.
# Keys are the Chinese "<entity><sep><attribute>" labels used in the datasets;
# the ids index the pointer-network output classes.
# NOTE(review): 'yingxiang_bingzao' mixes '_' and '-' separators
# ("病灶部位-诊断因子" vs "病灶部位_..."); verify against the annotation files.
config = {
    'drug': {
        '药品-用药频率': 0,
        '药品-持续时间': 1,
        '药品-用药剂量': 2,
        '药品-用药方法': 3,
        '药品-不良反应': 4,
    },
    'disease': {
        '疾病-检查方法': 0,
        '疾病-临床表现': 1,
        '疾病-非药治疗': 2,
        '疾病-药品名称': 3,
        '疾病-部位': 4,
    },
    'oncology_drug': {
        '药物_剂量': 0,
        '药物_给药日': 1,
        '药物_给药方式': 2,
    },
    'yingxiang_bingzao': {
        "病灶部位_异常描述因子": 0,
        "病灶部位_阴性描述因子": 1,
        "病灶部位-诊断因子": 2
    }
}
class Attribute(object):
    """One gold attribute annotation: value text, its character span, and its type."""

    def __init__(self, value, value_pos_start, value_pos_end, attr_type, attr_type_id):
        self.value = value                      # surface string of the attribute value
        self.value_pos_start = value_pos_start  # char offset where the value starts
        self.value_pos_end = value_pos_end      # char offset just past the value
        self.attr_type = attr_type              # attribute type label
        self.attr_type_id = attr_type_id        # class id from the task config
class Example(object):
    """One annotated passage: context, target entity, position features and gold attributes."""

    def __init__(self,
                 p_id=None,
                 context=None,
                 bert_tokens=None,
                 entity_name=None,
                 entity_position=None,
                 pos_start=None,
                 pos_end=None,
                 gold_attr_list=None,
                 gold_answer=None):
        self.p_id = p_id                        # passage index in the source file
        self.context = context                  # (possibly truncated/lowercased) text
        self.bert_tokens = bert_tokens          # filled in later by the BERT featurizer
        self.entity_name = entity_name          # target entity surface form
        self.entity_position = entity_position  # (start, end) char span of the entity
        self.pos_start = pos_start              # per-token offsets from entity start
        self.pos_end = pos_end                  # per-token offsets from entity end
        self.gold_attr_list = gold_attr_list    # list of Attribute objects
        self.gold_answer = gold_answer          # "type@value" strings for evaluation

    def __repr__(self):
        parts = ['{}-{}'.format(attr.attr_type, attr.value) for attr in self.gold_attr_list]
        return 'entity {} with attribute:\n{}'.format(self.entity_name, '\t'.join(parts))

    def __str__(self):
        return self.__repr__()
class InputFeature(object):
    """Model-ready id arrays for one example, plus its gold label list."""

    def __init__(self,
                 p_id=None,
                 passage_id=None,
                 token_type_id=None,
                 pos_start_id=None,
                 pos_end_id=None,
                 segment_id=None,
                 label=None):
        # store every argument under an attribute of the same name
        fields = (('p_id', p_id), ('passage_id', passage_id),
                  ('token_type_id', token_type_id), ('pos_start_id', pos_start_id),
                  ('pos_end_id', pos_end_id), ('segment_id', segment_id),
                  ('label', label))
        for name, val in fields:
            setattr(self, name, val)
class Reader(object):
    """Reads a JSON annotation file and converts each record into an Example."""

    def __init__(self, do_lowercase=False, seg_char=False, max_len=600, entity_type="药品名称"):
        # do_lowercase: lower-case contexts / entities / values when True
        # seg_char: character-level segmentation; otherwise jieba word segmentation
        # max_len: contexts longer than this are truncated
        # entity_type: selects the attribute-type -> id table from `config`
        # NOTE(review): the default "药品名称" is not a key of `config` as defined
        # in this module, so Reader() without an explicit entity_type would raise
        # KeyError — verify against callers.
        self.do_lowercase = do_lowercase
        self.seg_char = seg_char
        self.max_len = max_len
        self.entity_config = config[entity_type]
        if self.seg_char:
            logging.info("seg_char...")
        else:
            logging.info("seg_word using jieba ...")

    def read_examples(self, filename, data_type):
        """Load `filename` and return a list of Examples (`data_type` is log-only)."""
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def _read(self, filename, data_type):
        """Parse the JSON list in `filename`, validating entity/attribute spans."""
        with open(filename, 'r') as fh:
            source_data = json.load(fh)
        examples = []
        for p_id in tqdm(range(len(source_data))):
            data = source_data[p_id]
            para = data['text']
            # NOTE(review): ''.join(jieba.lcut(para)) re-joins the word pieces, so
            # the resulting string equals `para` either way — confirm intent.
            context = para if self.seg_char else ''.join(jieba.lcut(para))
            if len(context) > self.max_len:
                context = context[:self.max_len]
            context = context.lower() if self.do_lowercase else context
            _data_dict = dict()
            _data_dict['id'] = p_id
            entity_name, entity_position = data['entity'][0], data['entity'][1]
            entity_name = entity_name.lower() if self.do_lowercase else entity_name
            start, end = entity_position
            # span must exactly cover the entity text (fails on bad annotations
            # or when truncation cut the entity off)
            assert entity_name == context[start:end]
            # pos_start & pos_end: each token's position relative to the entity,
            # clipped to a limit — e.g. [-30, 30]; embedding later adds 31 so the
            # range becomes [1, 61], 62 position ids in total with 0 kept for pad.
            pos_start = list(map(lambda i: i - start, list(range(len(context)))))
            pos_end = list(map(lambda i: i - end, list(range(len(context)))))
            pos_start = _handle_pos_limit(pos_start)
            pos_end = _handle_pos_limit(pos_end)
            attribute_list = data['attribute_list']
            gold_attr_list = []
            for attribute in attribute_list:
                attr_type = attribute['type']
                value = attribute['value']
                # value record is [text, start, end]
                value, value_pos_start, value_pos_end = value[0], value[1], value[2]
                value = value.lower() if self.do_lowercase else value
                assert value == context[value_pos_start:value_pos_end]
                gold_attr_list.append(Attribute(
                    value=value,
                    value_pos_start=value_pos_start,
                    value_pos_end=value_pos_end,
                    attr_type=attr_type,
                    attr_type_id=self.entity_config[attr_type]
                ))
            # evaluation strings in "type@value" form
            gold_answer = [attr.attr_type + '@' + attr.value for attr in gold_attr_list]
            examples.append(
                Example(
                    p_id=p_id,
                    context=context,
                    entity_name=entity_name,
                    entity_position=entity_position,
                    pos_start=pos_start,
                    pos_end=pos_end,
                    gold_attr_list=gold_attr_list,
                    gold_answer=gold_answer
                )
            )
        logging.info("{} total size is {} ".format(data_type, len(examples)))
        return examples
class Vocabulary(object):
    """Builds a character vocabulary from Examples and materialises embeddings."""

    def __init__(self, special_tokens=None):
        """
        special_tokens: extra tokens inserted right after "<PAD>"; defaults to
        ["<OOV>", "<MASK>"]. (A None sentinel avoids the mutable-default-argument
        pitfall of the original signature.)
        """
        if special_tokens is None:
            special_tokens = ["<OOV>", "<MASK>"]
        self.char_vocab = None        # list of characters; index == id
        self.emb_mat = None           # optional embedding matrix
        self.char2idx = None          # char -> id mapping
        self.char_counter = Counter()
        self.special_tokens = special_tokens

    def build_vocab_only_with_char(self, examples, min_char_count=-1):
        """Count characters over all example contexts and build the id mapping."""
        logging.info("Building vocabulary only with character...")
        self.char_vocab = ["<PAD>"]
        if self.special_tokens is not None and isinstance(self.special_tokens, list):
            self.char_vocab.extend(self.special_tokens)
        for example in tqdm(examples):
            for char in example.context:
                self.char_counter[char] += 1
        for w, v in self.char_counter.most_common():
            if v >= min_char_count:
                self.char_vocab.append(w)
        self.char2idx = {token: idx for idx, token in enumerate(self.char_vocab)}
        logging.info("total char counter size is {} ".format(len(self.char_counter)))
        logging.info("total char vocabulary size is {} ".format(len(self.char_vocab)))

    def _load_embedding(self, embedding_file, embedding_dict):
        """Parse a whitespace-separated text embedding file into embedding_dict."""
        with open(embedding_file) as f:
            for line in f:
                # skip header lines / entries that are too short to be a vector
                if len(line.rstrip().split(" ")) <= 2:
                    continue
                token, vector = line.rstrip().split(" ", 1)
                # np.fromstring(..., sep=" ") is deprecated and the np.float
                # alias was removed in NumPy 1.24 — parse the text explicitly.
                embedding_dict[token] = np.array(vector.split(), dtype=np.float64)
        return embedding_dict

    def make_embedding(self, vocab, embedding_file, emb_size):
        """Return an embedding matrix aligned with `vocab`; unseen tokens get random vectors."""
        embedding_dict = dict()
        embedding_dict["<PAD>"] = np.array([0. for _ in range(emb_size)])
        self._load_embedding(embedding_file, embedding_dict)
        count = 0
        for token in tqdm(vocab):
            if token not in embedding_dict:
                count += 1
                embedding_dict[token] = np.array([np.random.normal(scale=0.1) for _ in range(emb_size)])
        logging.info(
            "{} / {} tokens have corresponding embedding vector".format(len(vocab) - count, len(vocab)))
        emb_mat = [embedding_dict[token] for token in vocab]
        return emb_mat
class Feature(object):
    """Turns Examples into InputFeatures and wraps them in an AttributeMPNDataset."""

    def __init__(self, args, token2idx_dict):
        # args.use_bert selects BERT wordpiece features vs. plain char-id features
        self.bert = args.use_bert
        self.token2idx_dict = token2idx_dict
        if self.bert:
            self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    def token2wid(self, token):
        """Map a token to its vocabulary id, falling back to the <OOV> id."""
        if token in self.token2idx_dict:
            return self.token2idx_dict[token]
        return self.token2idx_dict["<OOV>"]

    def __call__(self, examples, entity_type, data_type):
        if self.bert:
            return self.convert_examples_to_bert_features(examples, entity_type, data_type)
        else:
            return self.convert_examples_to_features(examples, entity_type, data_type)

    def convert_examples_to_features(self, examples, entity_type, data_type):
        """Char-id features: one position per context character (no special tokens)."""
        logging.info("Processing {} examples...".format(data_type))
        examples2features = list()
        for index, example in enumerate(examples):
            gold_attr_list = example.gold_attr_list
            ent_start, ent_end = example.entity_position[0], example.entity_position[1]
            # dtype=int: the np.int alias was removed in NumPy 1.24
            passage_id = np.zeros(len(example.context), dtype=int)
            token_type_id = np.zeros(len(example.context), dtype=int)
            pos_start_id = np.zeros(len(example.context), dtype=int)
            pos_end_id = np.zeros(len(example.context), dtype=int)
            for i, token in enumerate(example.context):
                if ent_start <= i < ent_end:
                    # token = "<MASK>"
                    token_type_id[i] = 1
                passage_id[i] = self.token2wid(token)
                pos_start_id[i] = example.pos_start[i]
                pos_end_id[i] = example.pos_end[i]
            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=token_type_id,
                    label=gold_attr_list
                ))
        logging.info("Built instances is Completed")
        return AttributeMPNDataset(examples2features, attribute_num=len(config[entity_type]))

    def convert_examples_to_bert_features(self, examples, entity_type, data_type):
        """BERT features: wraps the context in [CLS]/[SEP] and replaces entity chars with [unused1]."""
        logging.info("Processing {} examples...".format(data_type))
        examples2features = list()
        for index, example in enumerate(examples):
            gold_attr_list = example.gold_attr_list
            ent_start, ent_end = example.entity_position[0], example.entity_position[1]
            # +2 accounts for the [CLS] and [SEP] wrapper tokens
            segment_id = np.zeros(len(example.context) + 2, dtype=int)
            token_type_id = np.zeros(len(example.context) + 2, dtype=int)
            pos_start_id = np.zeros(len(example.context) + 2, dtype=int)
            pos_end_id = np.zeros(len(example.context) + 2, dtype=int)
            tokens = ["[CLS]"]
            raw_tokens = ["[CLS]"]
            for i, token in enumerate(example.context):
                raw_tokens.append(token)
                if ent_start <= i < ent_end:
                    # token_type_id[i + 1] = 1
                    # segment_id[i + 1] = 1
                    token = '[unused1]'
                tokens.append(token)
                # shift by 1 for the leading [CLS]
                pos_start_id[i + 1] = example.pos_start[i]
                pos_end_id[i + 1] = example.pos_end[i]
            tokens.append("[SEP]")
            raw_tokens.append("[SEP]")
            passage_id = self.tokenizer.convert_tokens_to_ids(tokens)
            example.bert_tokens = raw_tokens
            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=segment_id,
                    label=gold_attr_list
                ))
        logging.info("Built instances is Completed")
        return AttributeMPNDataset(examples2features, attribute_num=len(config[entity_type]), use_bert=True)
class AttributeMPNDataset(Dataset):
    """Dataset over attribute-MPN InputFeatures with a padding collate helper."""

    def __init__(self, features, attribute_num, use_bert=False):
        super(AttributeMPNDataset, self).__init__()
        self.use_bert = use_bert
        self.attribute_num = attribute_num
        # column-wise views over the feature list
        self.q_ids = [feat.p_id for feat in features]
        self.passages = [feat.passage_id for feat in features]
        self.token_type = [feat.token_type_id for feat in features]
        self.label = [feat.label for feat in features]
        self.segment_ids = [feat.segment_id for feat in features]
        self.pos_start_ids = [feat.pos_start_id for feat in features]
        self.pos_end_ids = [feat.pos_end_id for feat in features]

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, index):
        return (self.q_ids[index], self.passages[index], self.token_type[index],
                self.segment_ids[index], self.label[index],
                self.pos_start_ids[index], self.pos_end_ids[index])

    def _create_collate_fn(self, batch_first=False):
        def collate(examples):
            # transpose the batch of per-example tuples into per-field tuples
            p_ids, passages, token_type, segment_ids, label, pos_start_ids, pos_end_ids = zip(*examples)
            p_ids = torch.tensor(list(p_ids), dtype=torch.long)
            passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first)
            pos_start_tensor, _ = padding(pos_start_ids, is_float=False, batch_first=batch_first)
            pos_end_tensor, _ = padding(pos_end_ids, is_float=False, batch_first=batch_first)
            token_type_tensor, _ = padding(token_type, is_float=False, batch_first=batch_first)
            segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first)
            # multi-label start/end targets, one channel per attribute class
            o1_tensor, o2_tensor = mpn_padding(passages, label, class_num=self.attribute_num, is_float=True,
                                               use_bert=self.use_bert)
            return (p_ids, passages_tensor, token_type_tensor, segment_tensor,
                    pos_start_tensor, pos_end_tensor, o1_tensor, o2_tensor)

        return partial(collate)

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, batch_first=True, pin_memory=False,
                       drop_last=False):
        """Build a DataLoader that pads each batch via the collate helper above."""
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(batch_first),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)

View File

@ -0,0 +1,181 @@
# _*_ coding:utf-8 _*_
import argparse
import logging
import os
from run.attribute_extract.mpn.data import Reader, Vocabulary, config, Feature
from run.attribute_extract.mpn.train import Trainer
from utils.file_util import save, load
# Root logger configured at import time: INFO level with timestamped records.
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
def _str2bool(value):
    """Parse a command-line boolean string.

    argparse's ``type=bool`` is a classic trap: any non-empty string is truthy,
    so ``--flag False`` silently becomes True.  This helper accepts the usual
    spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if value.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


def get_args():
    """Build the CLI parser, parse arguments and derive the cache directory.

    Returns:
        argparse.Namespace with all hyper-parameters plus ``cache_data``, the
        cache path derived from the chosen featurisation (word2vec / BERT / char).
    """
    parser = argparse.ArgumentParser()
    # file parameters
    parser.add_argument("--input", default=None, type=str, required=True)
    parser.add_argument("--output", default=None, type=str, required=False,
                        help="The output directory where the model checkpoints and predictions will be written.")
    # "cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5"
    # 'cpt/baidu_w2v/w2v.txt'
    parser.add_argument('--embedding_file', type=str,
                        default='cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5')
    # choice parameters
    parser.add_argument('--entity_type', type=str, default='disease')
    # BUG FIX: these used type=bool, under which "--use_word2vec False"
    # evaluated to True (bool("False") is True). _str2bool parses correctly.
    parser.add_argument('--use_word2vec', type=_str2bool, default=False)
    parser.add_argument('--use_bert', type=_str2bool, default=False)
    parser.add_argument('--seg_char', type=_str2bool, default=True)
    # train parameters
    parser.add_argument('--train_mode', type=str, default="train")
    parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epoch_num", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--patience_stop', type=int, default=10, help='Patience for learning early stop')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    # bert parameters
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    # model parameters
    parser.add_argument("--max_len", default=1000, type=int)
    parser.add_argument('--word_emb_size', type=int, default=300)
    parser.add_argument('--char_emb_size', type=int, default=300)
    parser.add_argument('--entity_emb_size', type=int, default=300)
    parser.add_argument('--pos_limit', type=int, default=30)
    parser.add_argument('--pos_dim', type=int, default=300)
    parser.add_argument('--pos_size', type=int, default=62)
    parser.add_argument('--hidden_size', type=int, default=150)
    parser.add_argument('--num_layers', type=int, default=2)
    # BUG FIX: dropout is a probability; type=int would crash on "--dropout 0.5".
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rnn_encoder', type=str, default='lstm', help="must choose in blow: lstm or gru")
    parser.add_argument('--bidirectional', type=_str2bool, default=True)
    parser.add_argument('--pin_memory', type=_str2bool, default=False)
    parser.add_argument('--transformer_layers', type=int, default=1)
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--dim_feedforward', type=int, default=2048)
    args = parser.parse_args()

    # Encode the featurisation in the cache path so cached examples from
    # different settings never collide.
    if args.use_word2vec:
        args.cache_data = args.input + '/char2v_cache_data/'
    elif args.use_bert:
        args.cache_data = args.input + '/char_bert_cache_data/'
    else:
        args.cache_data = args.input + '/char_cache_data/'
    return args
def bulid_dataset(args, reader, vocab, debug=False):
    """Read (or load cached) examples, build the vocab/embeddings, and return
    data loaders for training.

    NOTE: the misspelled name ("bulid") is kept because callers use it.

    Args:
        args: parsed CLI namespace; reads input/cache_data/use_bert/use_word2vec
            and writes args.vocab_size when not using BERT.
        reader: example reader with a read_examples(path, data_type) method.
        vocab: vocabulary builder (used only in the non-BERT path).
        debug: when True, truncate both splits to 10 examples.

    Returns:
        (train_examples, dev_examples), (train_loader, dev_loader), char_emb
        where char_emb is None unless use_word2vec with an embedding file.
    """
    char2idx, char_emb = None, None
    train_src = args.input + "/train.json"
    dev_src = args.input + "/dev.json"
    train_examples_file = args.cache_data + "/train-examples.pkl"
    dev_examples_file = args.cache_data + "/dev-examples.pkl"
    char_emb_file = args.cache_data + "/char_emb.pkl"
    char_dictionary = args.cache_data + "/char_dict.pkl"
    # Cold cache: read raw JSON, build vocab/embeddings, persist everything.
    if not os.path.exists(train_examples_file):
        train_examples = reader.read_examples(train_src, data_type='train')
        dev_examples = reader.read_examples(dev_src, data_type='dev')
        if not args.use_bert:
            # todo : min_word_count=3 ?
            vocab.build_vocab_only_with_char(train_examples, min_char_count=1)
            if args.use_word2vec and args.embedding_file:
                char_emb = vocab.make_embedding(vocab=vocab.char_vocab,
                                                embedding_file=args.embedding_file,
                                                emb_size=args.word_emb_size)
                save(char_emb_file, char_emb, message="char embedding")
            save(char_dictionary, vocab.char2idx, message="char dictionary")
            char2idx = vocab.char2idx
        save(train_examples_file, train_examples, message="train examples")
        save(dev_examples_file, dev_examples, message="dev examples")
    # Warm cache: load the pickled examples (and vocab/embeddings if needed).
    else:
        if not args.use_bert:
            if args.use_word2vec and args.embedding_file:
                char_emb = load(char_emb_file)
            char2idx = load(char_dictionary)
            logging.info("total char vocabulary size is {} ".format(len(char2idx)))
        train_examples, dev_examples = load(train_examples_file), load(dev_examples_file)
    logging.info('train examples size is {}'.format(len(train_examples)))
    logging.info('dev examples size is {}'.format(len(dev_examples)))
    if not args.use_bert:
        args.vocab_size = len(char2idx)
    # Convert raw examples to model features / torch datasets.
    convert_examples_features = Feature(args, token2idx_dict=char2idx)
    train_examples = train_examples[:10] if debug else train_examples
    dev_examples = dev_examples[:10] if debug else dev_examples
    train_data_set = convert_examples_features(train_examples, entity_type=args.entity_type,
                                               data_type='train')
    dev_data_set = convert_examples_features(dev_examples, entity_type=args.entity_type,
                                             data_type='dev')
    train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory)
    dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size)
    data_loaders = train_data_loader, dev_data_loader
    eval_examples = train_examples, dev_examples
    return eval_examples, data_loaders, char_emb
# TODO
'''
1. Add logic to automatically build word embeddings.
2. Compare different entity-aware construction schemes.
3. Add an ELMo-based variant.
4. Add adversarial training.
'''
def main():
    """Entry point: parse args, build the dataset, then train / eval / analyse
    according to --train_mode."""
    args = get_args()

    # Make sure the checkpoint and cache directories exist.
    for directory in (args.output, args.cache_data):
        if not os.path.exists(directory):
            print('mkdir {}'.format(directory))
            os.makedirs(directory)

    logger.info("** ** * bulid dataset ** ** * ")
    reader = Reader(seg_char=args.seg_char, max_len=args.max_len, entity_type=args.entity_type)
    vocab = Vocabulary()
    eval_examples, data_loaders, char_emb = bulid_dataset(args, reader, vocab, debug=False)
    trainer = Trainer(args, data_loaders, eval_examples, char_emb, attribute_conf=config[args.entity_type])

    mode = args.train_mode
    if mode == "train":
        trainer.train(args)
    elif mode == "eval":
        trainer.resume(args)
        trainer.eval_data_set("train")
        trainer.eval_data_set("dev")
    elif mode == "resume":
        # trainer.resume(args)
        trainer.show("dev")  # bad case analysis
# Script entry point.
if __name__ == '__main__':
    main()

View File

@ -0,0 +1,295 @@
# _*_ coding:utf-8 _*_
import logging
import random
import sys
import time
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import models.attribute_extract_net.bert_mpn as bert_mpn
import models.attribute_extract_net.mpn as mpn
from utils.optimizer_util import set_optimizer
logger = logging.getLogger(__name__)
class Trainer(object):
    """Training/evaluation driver for the attribute-extraction pointer network.

    Owns the model (BERT-based or char/word2vec-based MPN), the optimizer, the
    train/dev data loaders, and the matching raw example lists used to decode
    predictions back into 'type@value' attribute strings.
    """

    def __init__(self, args, data_loaders, examples, char_emb, attribute_conf):
        # Choose the model flavour: BERT encoder vs. plain embedding encoder.
        if args.use_bert:
            self.model = bert_mpn.AttributeExtractNet.from_pretrained(args.bert_model, args, attribute_conf)
        else:
            self.model = mpn.AttributeExtractNet(args, char_emb, attribute_conf)
        self.args = args
        self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()
        # Attribute-type name <-> id lookups used when decoding predictions.
        self.id2rel = {item: key for key, item in attribute_conf.items()}
        self.rel2id =attribute_conf
        # Seed every RNG so runs are reproducible.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
        self.model.to(self.device)
        # self.resume(args)
        logging.info('total gpu num is {}'.format(self.n_gpu))
        if self.n_gpu > 1:
            # NOTE(review): device_ids is hard-coded to [0, 1]; machines with
            # more GPUs will only use the first two — confirm this is intended.
            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
        train_dataloader, dev_dataloader = data_loaders
        train_eval, dev_eval = examples
        self.eval_file_choice = {
            "train": train_eval,
            "dev": dev_eval,
        }
        self.data_loader_choice = {
            "train": train_dataloader,
            "dev": dev_dataloader,
        }
        self.optimizer = set_optimizer(args, self.model,
                                       train_steps=(int(len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)

    def train(self, args):
        """Epoch loop with dev-F1-based checkpointing and early stopping."""
        best_f1 = 0.0
        patience_stop = 0
        self.model.train()
        step_gap = 20
        for epoch in range(int(args.epoch_num)):
            global_loss = 0.0
            for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5,
                                    desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):
                loss, answer_dict_ = self.forward(batch)
                # NOTE(review): loss is added to global_loss only on reporting
                # steps (step % step_gap == 0) yet divided by step_gap, so the
                # printed "train/loss" is one batch's loss / 20, not a running
                # mean over the gap — confirm whether accumulation was meant
                # to happen on every step.
                if step % step_gap == 0:
                    global_loss += loss
                    current_loss = global_loss / step_gap
                    print(
                        u"step {} / {} of epoch {}, train/loss: {}".format(step, len(self.data_loader_choice["train"]),
                                                                           epoch, current_loss))
                    global_loss = 0.0
            # Evaluate once per epoch; keep only the best-on-dev checkpoint.
            res_dev = self.eval_data_set("dev")
            if res_dev['f1'] >= best_f1:
                best_f1 = res_dev['f1']
                logging.info("** ** * Saving fine-tuned model ** ** * ")
                model_to_save = self.model.module if hasattr(self.model,
                                                             'module') else self.model  # Only save the model it-self
                output_model_file = args.output + "/pytorch_model.bin"
                torch.save(model_to_save.state_dict(), str(output_model_file))
                patience_stop = 0
            else:
                patience_stop += 1
            if patience_stop >= args.patience_stop:
                return

    def resume(self, args):
        """Load model weights from the checkpoint saved under args.output."""
        resume_model_file = args.output + "/pytorch_model.bin"
        logging.info("=> loading checkpoint '{}'".format(resume_model_file))
        checkpoint = torch.load(resume_model_file, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def forward(self, batch, chosen=u'train', grad=True, eval=False, detail=False):
        """Run one batch through the model.

        Args:
            batch: tensor tuple from the DataLoader (moved to self.device here).
            chosen: which example split to decode against ('train' or 'dev').
            grad: when True, backprop and take an optimizer step.
            eval: when True, decode pointer outputs into an answer dict.
                (Shadows the builtin ``eval``; kept for API compatibility.)
            detail: accepted but unused inside this method.

        Returns:
            (loss, answer_dict_or_None)
        """
        batch = tuple(t.to(self.device) for t in batch)
        p_ids, passage_id, token_type_id, segment_id, pos_start,pos_end,start_id, end_id = batch
        loss, po1, po2 = self.model(passage_id=passage_id, token_type_id=token_type_id, segment_id=segment_id,
                                    pos_start=pos_start,pos_end=pos_end,start_id=start_id, end_id=end_id, is_eval=eval)
        if self.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        if grad:
            loss.backward()
            loss = loss.item()
            self.optimizer.step()
            self.optimizer.zero_grad()
        if eval:
            eval_file = self.eval_file_choice[chosen]
            # Decode start/end probability maps into attribute strings.
            answer_dict_ = convert_pointer_net_contour(eval_file, p_ids, po1, po2, self.id2rel,
                                                       use_bert=self.args.use_bert)
        else:
            answer_dict_ = None
        return loss, answer_dict_

    def eval_data_set(self, chosen="dev"):
        """Evaluate on the chosen split and return the metric dict from evaluate()."""
        self.model.eval()
        answer_dict = {}
        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        last_time = time.time()
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True)
                answer_dict.update(answer_dict_)
        used_time = time.time() - last_time
        logging.info('chosen {} took : {} sec'.format(chosen, used_time))
        res = self.evaluate(eval_file, answer_dict, chosen)
        self.detail_evaluate(eval_file, answer_dict, chosen)
        # Restore training mode for the caller's loop.
        self.model.train()
        return res

    def show(self, chosen="dev"):
        """Run inference on a split and dump mismatching predictions (bad cases)."""
        self.model.eval()
        answer_dict = {}
        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True, detail=True)
                answer_dict.update(answer_dict_)
        self.badcase_analysis(eval_file, answer_dict, chosen)

    @staticmethod
    def evaluate(eval_file, answer_dict, chosen):
        """Micro precision/recall/F1 over predicted 'type@value' attribute strings."""
        em = 0
        pre = 0
        gold = 0
        for key, value in answer_dict.items():
            ground_truths = eval_file[int(key)].gold_answer
            value, l1, l2 = value
            prediction = list(value) if len(value) else []
            assert type(prediction) == type(ground_truths)
            # Exact-match counting over de-duplicated predictions/golds.
            intersection = set(prediction) & set(ground_truths)
            em += len(intersection)
            pre += len(set(prediction))
            gold += len(set(ground_truths))
        precision = 100.0 * em / pre if pre > 0 else 0.
        recall = 100.0 * em / gold if gold > 0 else 0.
        f1 = 2 * recall * precision / (recall + precision) if (recall + precision) != 0 else 0.0
        print('============================================')
        print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, em, pre, gold))
        print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1, precision,
                                                                recall))
        return {'f1': f1, "recall": recall, "precision": precision, 'em': em, 'pre': pre, 'gold': gold}

    def detail_evaluate(self, eval_file, answer_dict, chosen):
        """Per-attribute-type precision/recall/F1 breakdown, printed per type."""
        def generate_detail_dict(spo_list):
            # Group 'type@value' strings by their attribute-type prefix.
            dict_detail = dict()
            for i, tag in enumerate(spo_list):
                detail_name = tag.split('@')[0]
                if detail_name not in dict_detail:
                    dict_detail[detail_name] = [tag]
                else:
                    dict_detail[detail_name].append(tag)
            return dict_detail
        total_detail = {}
        for key, value in answer_dict.items():
            ground_truths = eval_file[int(key)].gold_answer
            value, l1, l2 = value
            prediction = list(value) if len(value) else []
            gold_detail = generate_detail_dict(ground_truths)
            pred_detail = generate_detail_dict(prediction)
            # NOTE(review): this loop variable shadows the outer 'key' (the
            # example id); harmless today but worth renaming.
            for key in self.rel2id.keys():
                pred = pred_detail.get(key, [])
                gold = gold_detail.get(key, [])
                em = len(set(pred) & set(gold))
                pred_num = len(set(pred))
                gold_num = len(set(gold))
                if key not in total_detail:
                    total_detail[key] = dict()
                    total_detail[key]['em'] = em
                    total_detail[key]['pred_num'] = pred_num
                    total_detail[key]['gold_num'] = gold_num
                else:
                    total_detail[key]['em'] += em
                    total_detail[key]['pred_num'] += pred_num
                    total_detail[key]['gold_num'] += gold_num
        # Convert accumulated counts into per-type P/R/F1.
        for key, res_dict_ in total_detail.items():
            res_dict_['p'] = 100.0 * res_dict_['em'] / res_dict_['pred_num'] if res_dict_['pred_num'] != 0 else 0.0
            res_dict_['r'] = 100.0 * res_dict_['em'] / res_dict_['gold_num'] if res_dict_['gold_num'] != 0 else 0.0
            res_dict_['f'] = 2 * res_dict_['p'] * res_dict_['r'] / (res_dict_['p'] + res_dict_['r']) if res_dict_['p'] + \
                                                                                                       res_dict_[
                                                                                                           'r'] != 0 else 0.0
        for gold_key, res_dict_ in total_detail.items():
            print('===============================================================')
            print("{}/em: {},\tpred_num&gold_num: {}\t{} ".format(gold_key, res_dict_['em'], res_dict_['pred_num'],
                                                                  res_dict_['gold_num']))
            print(
                "{}/f1: {},\tprecison&recall: {}\t{}".format(gold_key, res_dict_['f'], res_dict_['p'], res_dict_['r']))

    @staticmethod
    def badcase_analysis(eval_file, answer_dict, chosen):
        """Print and dump examples where prediction != gold to badcase_<split>.txt."""
        em = 0
        pre = 0
        gold = 0
        content = ''
        for key, value in answer_dict.items():
            entity_name = eval_file[int(key)].entity_name
            context = eval_file[int(key)].context
            ground_truths = eval_file[int(key)].gold_answer
            value, l1, l2 = value
            prediction = list(value) if len(value) else ['']
            assert type(prediction) == type(ground_truths)
            intersection = set(prediction) & set(ground_truths)
            # Skip examples where both sides are empty.
            if prediction == ground_truths == ['']:
                continue
            if set(prediction) != set(ground_truths):
                ground_truths = list(sorted(set(ground_truths)))
                prediction = list(sorted(set(prediction)))
                print('raw context is:\t' + context)
                print('subject_name is:\t' + entity_name)
                print('pred_text is:\t' + '\t'.join(prediction))
                print('gold_text is:\t' + '\t'.join(ground_truths))
                content += 'raw context is:\t' + context + '\n'
                content += 'subject_name is:\t' + entity_name + '\n'
                content += 'pred_text is:\t' + '\t'.join(prediction) + '\n'
                content += 'gold_text is:\t' + '\t'.join(ground_truths) + '\n'
                content += '==============================='
            em += len(intersection)
            pre += len(set(prediction))
            gold += len(set(ground_truths))
        with open('badcase_{}.txt'.format(chosen), 'w') as f:
            f.write(content)
def convert_pointer_net_contour(eval_file, q_ids, po1, po2, id2rel, use_bert=False):
    """Decode pointer-network start/end probability maps into attribute strings.

    Args:
        eval_file: mapping from example id to example objects exposing a
            ``context`` attribute (or ``bert_tokens`` when use_bert is True).
        q_ids: 1-D tensor of example ids for the batch.
        po1, po2: (batch, seq_len, class_num) start/end probability tensors.
        id2rel: attribute-type id -> attribute-type name.
        use_bert: contexts are BERT token lists (position 0 is [CLS];
            values are joined rather than sliced).

    Returns:
        dict mapping str(example id) -> [answers, start_probs, end_probs],
        where answers is a list of 'type@value' strings.
    """
    answer_dict = dict()
    for qid, o1, o2 in zip(q_ids, po1.data.cpu().numpy(), po2.data.cpu().numpy()):
        context = eval_file[qid.item()].context if not use_bert else eval_file[qid.item()].bert_tokens
        # BUG FIX: the gold_attr_list/gold_answer locals computed here were
        # never used; removing them also drops the spurious requirement that
        # every example carry a gold_attr_list attribute.
        answers = list()
        # Positions where predicted probability crosses the 0.5 threshold.
        start, end = np.where(o1 > 0.5), np.where(o2 > 0.5)
        for _start, _attr_type_id_start in zip(*start):
            # Skip starts beyond the context, and position 0 ([CLS]) for BERT.
            if _start > len(context) or (_start == 0 and use_bert):
                continue
            # Greedy decode: take the first end at/after the start whose
            # attribute type matches.
            for _end, _attr_type_id_end in zip(*end):
                if _start <= _end < len(context) and _attr_type_id_start == _attr_type_id_end:
                    _attr_value = ''.join(context[_start: _end + 1]) if use_bert else context[_start: _end + 1]
                    _attr_type = id2rel[_attr_type_id_start]
                    _attr = _attr_type + '@' + _attr_value
                    answers.append(_attr)
                    break
        answer_dict[str(qid.item())] = [answers, o1, o2]
    return answer_dict

View File

View File

View File

View File

0
utils/__init__.py Normal file
View File

51
utils/data_util.py Normal file
View File

@ -0,0 +1,51 @@
import torch
def padding(seqs, is_float=False, batch_first=False):
    """Pad variable-length sequences into one tensor, zero-filled.

    Args:
        seqs: iterable of sequences of numbers.
        is_float: build a FloatTensor instead of a LongTensor.
        batch_first: return (batch, max_len) instead of (max_len, batch).

    Returns:
        (padded_tensor, lengths) where lengths are the original sizes.
    """
    lengths = [len(seq) for seq in seqs]
    as_tensors = [torch.Tensor(seq) for seq in seqs]
    max_len = max(lengths)
    ctor = torch.FloatTensor if is_float else torch.LongTensor
    padded = ctor(max_len, len(as_tensors)).fill_(0)
    for col, (seq, size) in enumerate(zip(as_tensors, lengths)):
        padded[:size, col].copy_(seq[:size])
    return (padded.t(), lengths) if batch_first else (padded, lengths)
def mpn_padding(seqs, label, class_num, is_float=False, use_bert=False):
    """Build one-hot start/end target tensors for the multi-label pointer net.

    Args:
        seqs: batch of token sequences (only their lengths matter here).
        label: per-example lists of attribute objects with value_pos_start,
            value_pos_end and attr_type_id.
        class_num: number of attribute classes (last tensor dimension).
        is_float: build FloatTensors instead of LongTensors.
        use_bert: shift starts by one to account for the prepended [CLS].

    Returns:
        (start_targets, end_targets), each (batch, max_len, class_num).
    """
    max_len = max(len(seq) for seq in seqs)
    batch = len(seqs)
    ctor = torch.FloatTensor if is_float else torch.LongTensor
    start_targets = ctor(batch, max_len, class_num).fill_(0)
    end_targets = ctor(batch, max_len, class_num).fill_(0)
    for row, attrs in enumerate(label):
        for attr in attrs:
            if use_bert:
                # +1 on the start accounts for the [CLS] token BERT prepends.
                start_targets[row, attr.value_pos_start + 1, attr.attr_type_id] = 1
                end_targets[row, attr.value_pos_end, attr.attr_type_id] = 1
            else:
                start_targets[row, attr.value_pos_start, attr.attr_type_id] = 1
                end_targets[row, attr.value_pos_end - 1, attr.attr_type_id] = 1
    return start_targets, end_targets
def _handle_pos_limit(pos, limit=30):
for i, p in enumerate(pos):
if p > limit:
pos[i] = limit
if p < -limit:
pos[i] = -limit
return [p + limit + 1 for p in pos]

46
utils/file_util.py Normal file
View File

@ -0,0 +1,46 @@
import json
import logging
import os
import pickle
import sys
def pickle_dump_large_file(obj, filepath):
    """Pickle *obj* to *filepath*, writing in chunks of under 2 GiB.

    Works around the historical CPython/macOS limitation where a single
    write() of 2**31 bytes or more could fail.
    """
    chunk_size = 2 ** 31 - 1
    payload = pickle.dumps(obj)
    total = sys.getsizeof(payload)
    with open(filepath, 'wb') as sink:
        for offset in range(0, total, chunk_size):
            sink.write(payload[offset:offset + chunk_size])
def pickle_load_large_file(filepath):
    """Unpickle *filepath*, reading in chunks of under 2 GiB.

    Mirror of pickle_dump_large_file: avoids single read() calls of
    2**31 bytes or more.
    """
    chunk_size = 2 ** 31 - 1
    total = os.path.getsize(filepath)
    buffer = bytearray()
    with open(filepath, 'rb') as source:
        for _ in range(0, total, chunk_size):
            buffer += source.read(chunk_size)
    return pickle.loads(buffer)
def save(filepath, obj, message=None):
    """Pickle *obj* to *filepath* via the chunked writer.

    Args:
        filepath: destination path.
        obj: any picklable object.
        message: optional human-readable label logged before saving.
    """
    if message is not None:
        logging.info("Saving {}...".format(message))
    pickle_dump_large_file(obj, filepath)
def load(filepath):
    """Unpickle and return the object stored at *filepath*."""
    return pickle_load_large_file(filepath)
def read_json(path):
    """Load a JSON file and return the parsed object.

    BUG FIX: the encoding is pinned to UTF-8 because write_json emits UTF-8
    bytes; relying on the platform default (e.g. cp1252 on Windows) would
    break on non-ASCII text such as the Chinese corpora this project uses.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
def write_json(obj, path):
    """Dump *obj* as pretty-printed, UTF-8 encoded JSON to *path*."""
    serialised = json.dumps(obj, indent=2, ensure_ascii=False)
    with open(path, 'wb') as f:
        f.write(serialised.encode('utf-8'))

25
utils/optimizer_util.py Normal file
View File

@ -0,0 +1,25 @@
from torch import optim
from layers.encoders.transformers.bert.bert_optimization import BertAdam
def set_optimizer(args, model, train_steps=None):
    """Create the optimizer for a run.

    BERT runs get BertAdam with the standard two weight-decay groups
    (no decay on biases and LayerNorm parameters, pooler excluded);
    everything else gets plain Adam over the trainable parameters.

    Args:
        args: namespace with use_bert, learning_rate and (for BERT)
            warmup_proportion.
        model: the torch model whose parameters are optimized.
        train_steps: total optimization steps, required for BertAdam warmup.

    Returns:
        A configured torch optimizer instance.
    """
    if not args.use_bert:
        trainable = [p for p in model.parameters() if p.requires_grad]
        return optim.Adam(trainable, lr=args.learning_rate)

    named_params = [(n, p) for n, p in model.named_parameters() if 'pooler' not in n]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    decay_params = [p for n, p in named_params if not any(nd in n for nd in no_decay)]
    plain_params = [p for n, p in named_params if any(nd in n for nd in no_decay)]
    grouped = [
        {'params': decay_params, 'weight_decay': 0.01},
        {'params': plain_params, 'weight_decay': 0.0},
    ]
    return BertAdam(grouped,
                    lr=args.learning_rate,
                    warmup=args.warmup_proportion,
                    t_total=train_steps)