Merge pull request #5 from loujie0822/jielou

Jielou
This commit is contained in:
loujie0822 2020-02-17 13:16:42 +08:00 committed by GitHub
commit c81a00937e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1445 additions and 2 deletions

3
.gitignore vendored
View File

@@ -6,7 +6,8 @@ settings.py
instance/
data/
cpt/
pyscripts/
.pytest_cache/
.coverage

View File

@@ -0,0 +1,51 @@
# Relation-name -> class-id map for the Baidu SPO (DuIE) dataset.
# NOTE(review): three keys are empty strings (ids 7, 22 and 44) -- the original
# Chinese relation names were lost in a bad encoding conversion.  Because dict
# keys must be unique, the later "" entries overwrite the earlier ones, so
# len(BAIDU_RELATION) == 47 even though the ids run 0..48.  Any model sized
# with len() of this dict will disagree with data labelled with 49 classes;
# restore the missing names from the original dataset schema.
BAIDU_RELATION = {
    "朝代": 0,
    "人口数量": 1,
    "出生地": 2,
    "连载网站": 3,
    "身高": 4,
    "占地面积": 5,
    "作者": 6,
    "": 7,
    "母亲": 8,
    "海拔": 9,
    "作词": 10,
    "嘉宾": 11,
    "总部地点": 12,
    "出版社": 13,
    "主持人": 14,
    "出生日期": 15,
    "所在城市": 16,
    "修业年限": 17,
    "祖籍": 18,
    "邮政编码": 19,
    "毕业院校": 20,
    "气候": 21,
    "": 22,
    "注册资本": 23,
    "丈夫": 24,
    "国籍": 25,
    "主角": 26,
    "主演": 27,
    "民族": 28,
    "董事长": 29,
    "所属专辑": 30,
    "专业代码": 31,
    "改编自": 32,
    "歌手": 33,
    "编剧": 34,
    "妻子": 35,
    "面积": 36,
    "作曲": 37,
    "官方语言": 38,
    "出品公司": 39,
    "成立日期": 40,
    "简称": 41,
    "首都": 42,
    "父亲": 43,
    "": 44,
    "制片人": 45,
    "上映时间": 46,
    "创始人": 47,
    "导演": 48
}

232
models/ere_net/bert_mpn.py Normal file
View File

@@ -0,0 +1,232 @@
# _*_ coding:utf-8 _*_
import warnings
import numpy as np
import torch
import torch.nn as nn
from layers.encoders.transformers.bert.bert_model import BertModel
warnings.filterwarnings("ignore")
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer
class EntityNET(nn.Module):
    """Subject tagger over BERT encodings.

    Two independent linear heads produce one logit per token scoring it as
    a subject start / end boundary (multi-label pointer network).
    """

    def __init__(self, args):
        super(EntityNET, self).__init__()
        # one logit per token for the start resp. end boundary
        self.sb1 = nn.Linear(args.bert_hidden_size, 1)
        self.sb2 = nn.Linear(args.bert_hidden_size, 1)

    def forward(self, sent_encoder, q_ids=None, eval_file=None, passages=None, s1=None, s2=None, is_eval=False):
        token_mask = passages != 0
        start_logits = self.sb1(sent_encoder).squeeze()
        end_logits = self.sb2(sent_encoder).squeeze()
        if is_eval:
            # decode spans directly from the logits
            return self.predict(eval_file, q_ids, start_logits, end_logits)
        # masked mean BCE over real (non-pad) tokens, start and end heads
        bce = nn.BCEWithLogitsLoss(reduction='none')
        denom = torch.sum(token_mask.float())
        start_loss = torch.sum(bce(start_logits, s1) * token_mask.float()) / denom
        end_loss = torch.sum(bce(end_logits, s2) * token_mask.float()) / denom
        return start_loss + end_loss

    def predict(self, eval_file, q_ids=None, sb1=None, sb2=None):
        """Greedily decode (start, end_exclusive) subject spans per example.

        A position fires when its logit is positive (sigmoid > 0.5);
        position 0 is [CLS] and is never considered a boundary.
        """
        batch_spans = list()
        for qid, start_scores, end_scores in zip(q_ids.cpu().numpy(),
                                                 sb1.cpu().numpy(),
                                                 sb2.cpu().numpy()):
            n_tokens = len(eval_file[qid].context)
            spans = list()
            open_start = open_end = None
            for pos in range(1, n_tokens):
                if open_start is None and start_scores[pos] > 0.0:
                    open_start = pos
                if open_end is None and end_scores[pos] > 0.0:
                    open_end = pos
                if open_start is not None and open_end is not None and open_start <= open_end:
                    spans.append((open_start, open_end + 1))
                    open_start = open_end = None
            batch_spans.append(spans)
        return batch_spans
class RelNET(nn.Module):
    """Relation/object tagger: a transformer over subject-aware encodings
    followed by per-relation start/end pointer heads.
    """

    def __init__(self, args, spo_conf):
        super(RelNET, self).__init__()
        # 0/1 marker embedding: does this token belong to the candidate subject?
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.bert_hidden_size,
                                             padding_idx=0)
        self.encoder_layer = TransformerEncoderLayer(args.bert_hidden_size, args.nhead)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, args.transformer_layers)
        self.classes_num = len(spo_conf)
        self.ob1 = nn.Linear(args.bert_hidden_size, self.classes_num)
        self.ob2 = nn.Linear(args.bert_hidden_size, self.classes_num)

    def forward(self, passages=None, sent_encoder=None, posit_ids=None, o1=None, o2=None, is_eval=False):
        pad_mask = passages.eq(0)
        # fuse the sentence encoding with the subject-position marker
        fused = sent_encoder + self.token_entity_emb(posit_ids)
        # nn.Transformer expects (seq, batch, dim)
        hidden = self.transformer_encoder(torch.transpose(fused, 1, 0), src_key_padding_mask=pad_mask)
        hidden = torch.transpose(hidden, 0, 1)
        obj_start = self.ob1(hidden)
        obj_end = self.ob2(hidden)
        if is_eval:
            return nn.Sigmoid()(obj_start), nn.Sigmoid()(obj_end)
        bce = nn.BCEWithLogitsLoss(reduction='none')
        valid = (passages != 0).float()
        denom = torch.sum(valid)

        def masked_loss(logits, target):
            # sum over relation classes, masked mean over real tokens
            per_token = torch.sum(bce(logits, target), 2)
            return torch.sum(per_token * valid) / denom / self.classes_num

        return masked_loss(obj_start, o1) + masked_loss(obj_end, o2)
class ERENet(nn.Module):
    """Joint entity + relation extraction on a shared BERT encoder.

    Training: returns the summed EntityNET and RelNET losses.
    Evaluation: decodes candidate subject spans first, then scores every
    candidate with RelNET, returning per-candidate-aligned tensors
    ``(qid, po1, po2, subject_start, subject_end)``.
    """

    def __init__(self, args, spo_conf):
        super(ERENet, self).__init__()
        print('joint entity relation extraction')
        self.bert_encoder = BertModel.from_pretrained(args.bert_model)
        self.entity_extraction = EntityNET(args)
        self.rel_extraction = RelNET(args, spo_conf)

    def forward(self, q_ids=None, eval_file=None, passages=None, token_type_ids=None, segment_ids=None, s1=None,
                s2=None, po1=None, po2=None, is_eval=False):
        sequence_mask = passages != 0
        # shared BERT sentence encoding (last layer only)
        sent_encoder, _ = self.bert_encoder(passages, token_type_ids=segment_ids, attention_mask=sequence_mask,
                                            output_all_encoded_layers=False)
        if not is_eval:
            # subject pointer loss
            ent_loss = self.entity_extraction(sent_encoder, passages=passages, s1=s1, s2=s2,
                                              is_eval=is_eval)
            # relation pointer loss (gold subject markers arrive via token_type_ids)
            rel_loss = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, posit_ids=token_type_ids,
                                           o1=po1,
                                           o2=po2, is_eval=False)
            total_loss = ent_loss + rel_loss
            return total_loss
        else:
            # 1) decode candidate subject spans for every example in the batch
            answer_list = self.entity_extraction(sent_encoder, q_ids=q_ids, eval_file=eval_file,
                                                 passages=passages, is_eval=is_eval)
            start_list, end_list = list(), list()
            qid_list, pass_list, posit_list, sent_list = list(), list(), list(), list()
            for i, ans_list in enumerate(answer_list):
                seq_len = passages.size(1)
                posit_ids = []
                for ans_tuple in ans_list:
                    # binary subject-span indicator over the passage.
                    # np.int was removed in NumPy 1.24 -- use int64 explicitly,
                    # which matches the torch.long tensor built below.
                    posit_array = np.zeros(seq_len, dtype=np.int64)
                    start, end = ans_tuple[0], ans_tuple[1]
                    start_list.append(start)
                    end_list.append(end)
                    posit_array[start:end] = 1
                    posit_ids.append(posit_array)
                if len(posit_ids) == 0:
                    continue  # no subject decoded for this example
                # replicate example-level tensors once per candidate subject
                qid_ = q_ids[i].unsqueeze(0).expand(len(posit_ids))
                sent_tensor = sent_encoder[i, :, :].unsqueeze(0).expand(len(posit_ids), sent_encoder.size(1),
                                                                        sent_encoder.size(2))
                pass_tensor = passages[i, :].unsqueeze(0).expand(len(posit_ids), passages.size(1))
                posit_tensor = torch.tensor(posit_ids, dtype=torch.long).to(sent_encoder.device)
                qid_list.append(qid_)
                pass_list.append(pass_tensor)
                posit_list.append(posit_tensor)
                sent_list.append(sent_tensor)
            if len(qid_list) == 0:
                # whole batch yielded no subject candidate: sentinel -1 qids
                qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device)
                return qid_tensor, qid_tensor, qid_tensor, qid_tensor, qid_tensor
            qid_tensor = torch.cat(qid_list).to(sent_encoder.device)
            sent_tensor = torch.cat(sent_list).to(sent_encoder.device)
            pass_tensor = torch.cat(pass_list).to(sent_encoder.device)
            posi_tensor = torch.cat(posit_list).to(sent_encoder.device)
            # 2) score candidates with RelNET in chunks to bound memory
            flag = False
            split_heads = 1024
            inputs = torch.split(pass_tensor, split_heads, dim=0)
            posits = torch.split(posi_tensor, split_heads, dim=0)
            sents = torch.split(sent_tensor, split_heads, dim=0)
            po1_list, po2_list = list(), list()
            for i in range(len(inputs)):
                passages = inputs[i]
                sent_encoder = sents[i]
                posit_ids = posits[i]
                if passages.size(0) == 1:
                    # duplicate size-1 chunks to avoid batch-dim squeeze issues;
                    # only the LAST chunk of torch.split can be that small, so
                    # the sticky `flag` cannot mis-trim earlier chunks
                    flag = True
                    passages = passages.expand(2, passages.size(1))
                    sent_encoder = sent_encoder.expand(2, sent_encoder.size(1), sent_encoder.size(2))
                    posit_ids = posit_ids.expand(2, posit_ids.size(1))
                po1, po2 = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, posit_ids=posit_ids,
                                               is_eval=is_eval)
                if flag:
                    # keep only one copy of the duplicated row
                    po1 = po1[1, :, :].unsqueeze(0)
                    po2 = po2[1, :, :].unsqueeze(0)
                po1_list.append(po1)
                po2_list.append(po2)
            po1_tensor = torch.cat(po1_list).to(sent_encoder.device)
            po2_tensor = torch.cat(po2_list).to(sent_encoder.device)
            s_tensor = torch.tensor(start_list, dtype=torch.long).to(sent_encoder.device)
            e_tensor = torch.tensor(end_list, dtype=torch.long).to(sent_encoder.device)
            return qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor

248
models/ere_net/mpn.py Normal file
View File

@@ -0,0 +1,248 @@
import warnings
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer
from layers.encoders.rnns.stacked_rnn import StackedBRNN
warnings.filterwarnings("ignore")
class SentenceEncoder(nn.Module):
    """Stacked bidirectional RNN sentence encoder (LSTM or GRU)."""

    def __init__(self, args, input_size):
        super(SentenceEncoder, self).__init__()
        # choose the recurrent cell from the config string
        recurrent_cls = nn.LSTM if args.rnn_encoder == 'lstm' else nn.GRU
        encoder_kwargs = dict(
            input_size=input_size,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            dropout_rate=args.dropout,
            dropout_output=True,
            concat_layers=False,
            rnn_type=recurrent_cls,
            padding=True,
        )
        self.encoder = StackedBRNN(**encoder_kwargs)

    def forward(self, input, mask):
        # mask flags padding positions; handled inside StackedBRNN
        return self.encoder(input, mask)
class EntityNET(nn.Module):
    """Subject extractor: char embedding -> BiRNN -> start/end pointer heads."""

    def __init__(self, args, char_emb):
        super(EntityNET, self).__init__()
        # pretrained char vectors when provided, otherwise learned from scratch
        if char_emb is not None:
            self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
                                                         padding_idx=0)
        else:
            self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
                                         padding_idx=0)
        self.sentence_encoder = SentenceEncoder(args, args.word_emb_size)
        # one logit per token for the start resp. end boundary
        self.s1 = nn.Linear(args.hidden_size * 2, 1)
        self.s2 = nn.Linear(args.hidden_size * 2, 1)

    def forward(self, q_ids=None, eval_file=None, passages=None, s1=None, s2=None, is_eval=False):
        pad_mask = passages.eq(0)
        encoded = self.sentence_encoder(self.char_emb(passages), pad_mask)
        start_logits = self.s1(encoded).squeeze()
        end_logits = self.s2(encoded).squeeze()
        if is_eval:
            # the encoding is also returned so the caller can reuse it
            return encoded, self.predict(eval_file, q_ids, start_logits, end_logits)
        # masked mean BCE over real (non-pad) tokens for both heads
        bce = nn.BCEWithLogitsLoss(reduction='none')
        valid = (passages != 0).float()
        denom = torch.sum(valid)
        start_loss = torch.sum(bce(start_logits, s1) * valid) / denom
        end_loss = torch.sum(bce(end_logits, s2) * valid) / denom
        return encoded, start_loss + end_loss

    def predict(self, eval_file, q_ids=None, s1=None, s2=None):
        """Greedily decode (start, end_exclusive) subject spans; a position
        fires when its logit is positive (sigmoid > 0.5)."""
        results = list()
        for qid, start_scores, end_scores in zip(q_ids.cpu().numpy(),
                                                 s1.cpu().numpy(),
                                                 s2.cpu().numpy()):
            spans = list()
            open_start = open_end = None
            for pos in range(len(eval_file[qid].context)):
                if open_start is None and start_scores[pos] > 0.0:
                    open_start = pos
                if open_end is None and end_scores[pos] > 0.0:
                    open_end = pos
                if open_start is not None and open_end is not None and open_start <= open_end:
                    spans.append((open_start, open_end + 1))
                    open_start = open_end = None
            results.append(spans)
        return results
class RelNET(nn.Module):
    """Relation/object tagger: subject-aware BiRNN + transformer encoder with
    per-relation start/end pointer heads."""

    def __init__(self, args, spo_conf):
        super(RelNET, self).__init__()
        # 0/1 marker embedding: does this token belong to the candidate subject?
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
                                             padding_idx=0)
        self.sentence_encoder = SentenceEncoder(args, args.word_emb_size)
        self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead)
        self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
        self.classes_num = len(spo_conf)
        self.po1 = nn.Linear(args.hidden_size * 2, self.classes_num)
        self.po2 = nn.Linear(args.hidden_size * 2, self.classes_num)

    def forward(self, passages=None, sent_encoder=None, token_type_id=None, po1=None, po2=None, is_eval=False):
        pad_mask = passages.eq(0)
        # fuse the sentence encoding with the subject-position marker
        fused = sent_encoder + self.token_entity_emb(token_type_id)
        # re-encode, then run the transformer in (seq, batch, dim) layout
        aware = self.sentence_encoder(fused, pad_mask).transpose(1, 0)
        hidden = self.transformer_encoder(aware, src_key_padding_mask=pad_mask).transpose(0,
                                                                                          1)
        start_logits = self.po1(hidden)
        end_logits = self.po2(hidden)
        if is_eval:
            return nn.Sigmoid()(start_logits), nn.Sigmoid()(end_logits)
        bce = nn.BCEWithLogitsLoss(reduction='none')
        valid = (passages != 0).float()
        denom = torch.sum(valid)

        def span_loss(logits, target):
            # sum over relation classes, masked mean over real tokens
            per_token = torch.sum(bce(logits, target), 2)
            return torch.sum(per_token * valid) / denom / self.classes_num

        return span_loss(start_logits, po1) + span_loss(end_logits, po2)
class ERENet(nn.Module):
    """Entity/relation joint extraction with a Multi-label Pointer Network
    (MPN), entity-aware via subject position markers.

    Training: returns the summed EntityNET and RelNET losses.
    Evaluation: decodes candidate subject spans, then scores every candidate
    with RelNET, returning per-candidate-aligned tensors
    ``(qid, po1, po2, subject_start, subject_end)``.
    """

    def __init__(self, args, char_emb, spo_conf):
        super(ERENet, self).__init__()
        print('joint entity relation extraction')
        self.entity_extraction = EntityNET(args, char_emb)
        self.rel_extraction = RelNET(args, spo_conf)

    def forward(self, q_ids=None, eval_file=None, passages=None, token_type_ids=None, segment_ids=None, s1=None,
                s2=None, po1=None, po2=None, is_eval=False):
        if not is_eval:
            # joint loss: subject pointer loss + relation pointer loss
            sent_encoder, ent_loss = self.entity_extraction(passages=passages, s1=s1, s2=s2, is_eval=is_eval)
            rel_loss = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, token_type_id=token_type_ids,
                                           po1=po1, po2=po2, is_eval=False)
            total_loss = ent_loss + rel_loss
            return total_loss
        else:
            # 1) decode candidate subject spans for every example in the batch
            sent_encoder, answer_list = self.entity_extraction(q_ids=q_ids, eval_file=eval_file,
                                                               passages=passages, is_eval=is_eval)
            start_list, end_list = list(), list()
            qid_list, pass_list, posit_list, sent_list = list(), list(), list(), list()
            for i, ans_list in enumerate(answer_list):
                seq_len = passages.size(1)
                posit_ids = []
                for ans_tuple in ans_list:
                    # binary subject-span indicator over the passage.
                    # np.int was removed in NumPy 1.24 -- use int64 explicitly,
                    # which matches the torch.long tensor built below.
                    posit_array = np.zeros(seq_len, dtype=np.int64)
                    start, end = ans_tuple[0], ans_tuple[1]
                    start_list.append(start)
                    end_list.append(end)
                    posit_array[start:end] = 1
                    posit_ids.append(posit_array)
                if len(posit_ids) == 0:
                    continue  # no subject decoded for this example
                # replicate example-level tensors once per candidate subject
                qid_ = q_ids[i].unsqueeze(0).expand(len(posit_ids))
                sent_tensor = sent_encoder[i, :, :].unsqueeze(0).expand(len(posit_ids), sent_encoder.size(1),
                                                                        sent_encoder.size(2))
                pass_tensor = passages[i, :].unsqueeze(0).expand(len(posit_ids), passages.size(1))
                posit_tensor = torch.tensor(posit_ids, dtype=torch.long).to(sent_encoder.device)
                qid_list.append(qid_)
                pass_list.append(pass_tensor)
                posit_list.append(posit_tensor)
                sent_list.append(sent_tensor)
            if len(qid_list) == 0:
                # whole batch yielded no subject candidate: sentinel -1 qids
                qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device)
                return qid_tensor, qid_tensor, qid_tensor, qid_tensor, qid_tensor
            qid_tensor = torch.cat(qid_list).to(sent_encoder.device)
            sent_tensor = torch.cat(sent_list).to(sent_encoder.device)
            pass_tensor = torch.cat(pass_list).to(sent_encoder.device)
            posi_tensor = torch.cat(posit_list).to(sent_encoder.device)
            # 2) score candidates with RelNET in chunks to bound memory
            flag = False
            split_heads = 1024
            inputs = torch.split(pass_tensor, split_heads, dim=0)
            posits = torch.split(posi_tensor, split_heads, dim=0)
            sents = torch.split(sent_tensor, split_heads, dim=0)
            po1_list, po2_list = list(), list()
            for i in range(len(inputs)):
                passages = inputs[i]
                sent_encoder = sents[i]
                posit_ids = posits[i]
                if passages.size(0) == 1:
                    # duplicate size-1 chunks to avoid batch-dim squeeze issues;
                    # only the LAST chunk of torch.split can be that small, so
                    # the sticky `flag` cannot mis-trim earlier chunks
                    flag = True
                    passages = passages.expand(2, passages.size(1))
                    sent_encoder = sent_encoder.expand(2, sent_encoder.size(1), sent_encoder.size(2))
                    posit_ids = posit_ids.expand(2, posit_ids.size(1))
                po1, po2 = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, token_type_id=posit_ids,
                                               is_eval=is_eval)
                if flag:
                    # keep only one copy of the duplicated row
                    po1 = po1[1, :, :].unsqueeze(0)
                    po2 = po2[1, :, :].unsqueeze(0)
                po1_list.append(po1)
                po2_list.append(po2)
            po1_tensor = torch.cat(po1_list).to(sent_encoder.device)
            po2_tensor = torch.cat(po2_list).to(sent_encoder.device)
            s_tensor = torch.tensor(start_list, dtype=torch.long).to(sent_encoder.device)
            e_tensor = torch.tensor(end_list, dtype=torch.long).to(sent_encoder.device)
            return qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor

0
pyscripts/__init__.py Normal file
View File

View File

@@ -3,7 +3,7 @@ import argparse
import logging
import os
from run.attribute_extract.mpn.data import Reader, Vocabulary, config, Feature
from run.attribute_extract.mpn.data_loader import Reader, Vocabulary, config, Feature
from run.attribute_extract.mpn.train import Trainer
from utils.file_util import save, load

View File

@@ -0,0 +1,475 @@
import codecs
import json
import logging
from collections import Counter
from functools import partial
import jieba
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from config.baidu_spo_config import BAIDU_RELATION
from layers.encoders.transformers.bert.bert_tokenization import BertTokenizer
from utils.data_util import padding, _handle_pos_limit, find_position, spo_padding
class PredictObject(object):
    """Plain record describing one gold/predicted object span and its relation."""

    def __init__(self, object_name, object_start, object_end, predict_type, predict_type_id):
        self.object_name = object_name          # object surface string
        self.object_start = object_start        # char offset, inclusive
        self.object_end = object_end            # char offset, exclusive
        self.predict_type = predict_type        # relation (predicate) name
        self.predict_type_id = predict_type_id  # relation id from BAIDU_RELATION
class Example(object):
    """One (sentence, subject) sample produced by Reader.

    All fields default to None; dev/test examples leave the subject-specific
    fields (sub_pos, relative positions, po_list) unset.
    """

    def __init__(self, p_id=None, context=None, bert_tokens=None, sub_pos=None, sub_entity_list=None,
                 relative_pos_start=None, relative_pos_end=None, po_list=None, gold_answer=None):
        self.p_id = p_id                              # example index
        self.context = context                        # (possibly truncated) sentence text
        self.bert_tokens = bert_tokens                # [CLS] + chars + [SEP], filled by Feature
        self.sub_pos = sub_pos                        # [start, end) of this example's subject
        self.sub_entity_list = sub_entity_list        # all gold subject spans of the sentence
        self.relative_pos_start = relative_pos_start  # clipped token offsets w.r.t. subject start
        self.relative_pos_end = relative_pos_end      # clipped token offsets w.r.t. subject end
        self.po_list = po_list                        # PredictObject list for this subject
        self.gold_answer = gold_answer                # gold (subject, predicate, object) triples
class InputFeature(object):
    """Numeric feature bundle for one Example, consumed by SPODataset."""

    def __init__(self, p_id=None, passage_id=None, token_type_id=None, pos_start_id=None,
                 pos_end_id=None, segment_id=None, po_label=None, s1=None, s2=None):
        self.p_id = p_id                    # example index
        self.passage_id = passage_id        # token/char ids of the passage
        self.token_type_id = token_type_id  # 1 on subject tokens, else 0
        self.pos_start_id = pos_start_id    # relative position w.r.t. subject start
        self.pos_end_id = pos_end_id        # relative position w.r.t. subject end
        self.segment_id = segment_id        # BERT segment ids (all zeros here)
        self.po_label = po_label            # PredictObject list (training targets)
        self.s1 = s1                        # subject-start indicator vector
        self.s2 = s2                        # subject-end indicator vector
class Reader(object):
    """Reads Baidu SPO (DuIE) json-lines files into `Example` objects.

    A training Example represents ONE subject of a sentence together with all
    of its (predicate, object) pairs; dev/test Examples keep the whole
    sentence and its gold spo list.
    """

    def __init__(self, do_lowercase=False, seg_char=False, max_len=600):
        self.do_lowercase = do_lowercase
        self.seg_char = seg_char
        self.max_len = max_len  # sentences are truncated to this many chars
        self.relation_config = BAIDU_RELATION
        if self.seg_char:
            logging.info("seg_char...")
        else:
            logging.info("seg_word using jieba ...")

    def read_examples(self, filename, data_type):
        """Public entry point: parse `filename` into a list of Examples."""
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def _data_process(self, filename, data_type='train'):
        """Parse raw json lines into dicts carrying context / spo metadata."""
        output_data = list()
        with codecs.open(filename, 'r') as f:
            gold_num = 0
            for line in tqdm(f):
                data_json = json.loads(line.strip())
                text = data_json['text'].lower()
                sub_po_dict, sub_ent_list, spo_list = dict(), list(), list()
                for spo in data_json['spo_list']:
                    # TODO .strip('《》').strip()
                    subject_name = spo['subject'].lower().strip('《》').strip()
                    object_name = spo['object'].lower().strip('《》').strip()
                    s_start, s_end = find_position(subject_name, text)
                    o_start, o_end = find_position(object_name, text)
                    if text[s_start:s_end] != subject_name:
                        # NOTE(review): both replace() calls are no-ops ('' -> '');
                        # the characters to strip were lost in an encoding mishap
                        # (probably whitespace or book-title marks) -- restore them.
                        subject_name = spo['subject'].lower().replace('', '').replace('', '')
                        s_start, s_end = find_position(subject_name, text)
                    if s_start != -1 and o_start != -1:
                        sub_ent_list.append((subject_name, s_start, s_end))
                        spo_list.append((subject_name, spo['predicate'], object_name))
                        # group (predicate, object) pairs by subject
                        if subject_name not in sub_po_dict:
                            sub_po_dict[subject_name] = {}
                            sub_po_dict[subject_name]['sub_pos'] = [s_start, s_end]
                            sub_po_dict[subject_name]['po_list'] = [
                                {'predict': spo['predicate'], 'object': (object_name, o_start, o_end)}]
                        else:
                            sub_po_dict[subject_name]['po_list'].append(
                                {'predict': spo['predicate'], 'object': (object_name, o_start, o_end)})
                text_spo = dict()
                text_spo['context'] = text
                text_spo['sub_po_dict'] = sub_po_dict
                text_spo['spo_list'] = list(set(spo_list))
                text_spo['sub_ent_list'] = list(set(sub_ent_list))
                gold_num += len(set(spo_list))
                output_data.append(text_spo)
        if data_type == 'train':
            return self._convert_train_data(output_data)
        return output_data

    @staticmethod
    def _convert_train_data(src_data):
        """
        Flatten per-sentence data into per-subject training rows: one row per
        subject, carrying 'sub_name', 'sub_pos' and 'po_list' next to the
        shared sentence fields.

        Bug fix: the previous version appended the SAME dict object once per
        subject while overwriting those keys, so every row of a multi-subject
        sentence ended up aliasing the LAST subject.  A shallow copy per row
        keeps the subjects distinct.
        :param src_data: output of _data_process
        :return: list of per-subject dicts
        """
        spo_data = []
        for data in src_data:
            for sub_ent, po_dict in data['sub_po_dict'].items():
                row = dict(data)  # shallow copy so rows don't alias each other
                row['sub_name'] = sub_ent
                row['sub_pos'] = po_dict['sub_pos']
                row['po_list'] = po_dict['po_list']
                spo_data.append(row)
        return spo_data

    def _read(self, filename, data_type):
        """Build Example objects; training rows carry relative-position features."""
        data_set = self._data_process(filename, data_type)
        logging.info("{} data_set total size is {} ".format(data_type, len(data_set)))
        examples = []
        for p_id in tqdm(range(len(data_set))):
            data = data_set[p_id]
            para = data['context']
            # NOTE(review): joining jieba tokens with '' reproduces `para`
            # unchanged, so the word-segmentation branch is currently a no-op --
            # confirm whether a space join was intended.
            context = para if self.seg_char else ''.join(jieba.lcut(para))
            if len(context) > self.max_len:
                context = context[:self.max_len]
            if data_type == 'train':
                start, end = data['sub_pos']
                if start >= self.max_len or end >= self.max_len:
                    continue  # subject truncated away
                assert data['sub_name'] == context[start:end]
                # pos_start & pos_end: each token's offset relative to the
                # subject entity.  e.g. offsets clipped to [-30, 30] become
                # [1, 61] after a +31 shift at embedding time: 62 position
                # tokens in total, 0 reserved for padding.
                pos_start = list(map(lambda i: i - start, list(range(len(context)))))
                pos_end = list(map(lambda i: i - end, list(range(len(context)))))
                relative_pos_start = _handle_pos_limit(pos_start)
                relative_pos_end = _handle_pos_limit(pos_end)
                po_list = []
                for predict_object in data['po_list']:
                    predict_type = predict_object['predict']
                    object_ = predict_object['object']
                    object_name, object_start, object_end = object_[0], object_[1], object_[2]
                    if object_start >= self.max_len or object_end >= self.max_len:
                        continue  # object truncated away
                    assert object_name == context[object_start:object_end]
                    po_list.append(PredictObject(
                        object_name=object_name,
                        object_start=object_start,
                        object_end=object_end,
                        predict_type=predict_type,
                        predict_type_id=self.relation_config[predict_type]
                    ))
                examples.append(
                    Example(
                        p_id=p_id,
                        context=context,
                        sub_pos=data['sub_pos'],
                        sub_entity_list=data['sub_ent_list'],
                        relative_pos_start=relative_pos_start,
                        relative_pos_end=relative_pos_end,
                        po_list=po_list,
                        gold_answer=data['spo_list']
                    )
                )
            else:
                examples.append(
                    Example(
                        p_id=p_id,
                        context=context,
                        sub_pos=None,
                        sub_entity_list=data['sub_ent_list'],
                        relative_pos_start=None,
                        relative_pos_end=None,
                        po_list=None,
                        gold_answer=data['spo_list']
                    )
                )
        logging.info("{} total size is {} ".format(data_type, len(examples)))
        return examples
class Vocabulary(object):
    """Builds a character vocabulary and its pretrained embedding matrix."""

    def __init__(self, special_tokens=None):
        # avoid a shared mutable default argument; behavior is unchanged for
        # callers that pass nothing
        self.special_tokens = ["<OOV>", "<MASK>"] if special_tokens is None else special_tokens
        self.char_vocab = None
        self.emb_mat = None
        self.char2idx = None
        self.char_counter = Counter()

    def build_vocab_only_with_char(self, examples, min_char_count=-1):
        """Count characters over all example contexts and build
        char_vocab / char2idx; <PAD> is index 0."""
        logging.info("Building vocabulary only with character...")
        self.char_vocab = ["<PAD>"]
        if self.special_tokens is not None and isinstance(self.special_tokens, list):
            self.char_vocab.extend(self.special_tokens)
        for example in tqdm(examples):
            for char in example.context:
                self.char_counter[char] += 1
        for w, v in self.char_counter.most_common():
            if v >= min_char_count:
                self.char_vocab.append(w)
        self.char2idx = {token: idx for idx, token in enumerate(self.char_vocab)}
        logging.info("total char counter size is {} ".format(len(self.char_counter)))
        logging.info("total char vocabulary size is {} ".format(len(self.char_vocab)))

    def _load_embedding(self, embedding_file, embedding_dict):
        """Read a word2vec-style text file into embedding_dict (in place)."""
        with open(embedding_file, encoding="utf-8") as f:
            for line in f:
                # skip the "<count> <dim>" header and malformed lines
                if len(line.rstrip().split(" ")) <= 2: continue
                token, vector = line.rstrip().split(" ", 1)
                # np.fromstring(text) is deprecated (removed for this use in
                # recent NumPy); parse the whitespace-separated floats directly
                embedding_dict[token] = np.array(vector.split(), dtype=float)
        return embedding_dict

    def make_embedding(self, vocab, embedding_file, emb_size):
        """Build the embedding matrix for `vocab`; tokens missing from the
        pretrained file get small random normal vectors."""
        embedding_dict = dict()
        embedding_dict["<PAD>"] = np.zeros(emb_size, dtype=float)
        self._load_embedding(embedding_file, embedding_dict)
        count = 0
        for token in tqdm(vocab):
            if token not in embedding_dict:
                count += 1
                embedding_dict[token] = np.random.normal(scale=0.1, size=emb_size)
        logging.info(
            "{} / {} tokens have corresponding in embedding vector".format(len(vocab) - count, len(vocab)))
        emb_mat = [embedding_dict[token] for token in vocab]
        return emb_mat
class Feature(object):
    """Converts `Example`s into `InputFeature`s wrapped in a SPODataset.

    Dispatches to the char-embedding or the BERT feature path according to
    ``args.use_bert``.
    """

    def __init__(self, args, token2idx_dict):
        self.bert = args.use_bert
        self.token2idx_dict = token2idx_dict
        if self.bert:
            self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    def token2wid(self, token):
        """Return the vocab id of `token`, falling back to the <OOV> id."""
        if token in self.token2idx_dict:
            return self.token2idx_dict[token]
        return self.token2idx_dict["<OOV>"]

    def __call__(self, examples, data_type):
        if self.bert:
            return self.convert_examples_to_bert_features(examples, data_type)
        return self.convert_examples_to_features(examples, data_type)

    def convert_examples_to_features(self, examples, data_type):
        """Char-embedding path: token ids, subject markers, boundary targets."""
        logging.info("convert {} examples to features .".format(data_type))
        examples2features = list()
        for index, example in enumerate(examples):
            seq_len = len(example.context)
            # np.int / np.float were removed in NumPy 1.24; the builtins are
            # the exact same dtypes (C long / float64).
            passage_id = np.zeros(seq_len, dtype=int)
            segment_id = np.zeros(seq_len, dtype=int)
            token_type_id = np.zeros(seq_len, dtype=int)
            pos_start_id = np.zeros(seq_len, dtype=int)
            pos_end_id = np.zeros(seq_len, dtype=int)
            s1 = np.zeros(seq_len, dtype=float)
            s2 = np.zeros(seq_len, dtype=float)
            # multi-label subject start/end targets over all gold subjects
            for (_, start, end) in example.sub_entity_list:
                if start >= seq_len or end >= seq_len:
                    continue
                s1[start] = 1.0
                s2[end - 1] = 1.0  # `end` is exclusive
            for i, token in enumerate(example.context):
                passage_id[i] = self.token2wid(token)
            if data_type == 'train':
                sub_start, sub_end = example.sub_pos[0], example.sub_pos[1]
                for i, token in enumerate(example.context):
                    if sub_start <= i < sub_end:
                        token_type_id[i] = 1  # token lies inside the subject
                    pos_start_id[i] = example.relative_pos_start[i]
                    pos_end_id[i] = example.relative_pos_end[i]
            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=segment_id,
                    po_label=example.po_list,
                    s1=s1,
                    s2=s2
                ))
        logging.info("Built instances is Completed")
        return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), data_type=data_type)

    def convert_examples_to_bert_features(self, examples, data_type):
        """BERT path: adds [CLS]/[SEP] and shifts every char index by +1."""
        logging.info("Processing {} examples...".format(data_type))
        examples2features = list()
        for index, example in enumerate(examples):
            seq_len = len(example.context) + 2  # +2 for [CLS] and [SEP]
            segment_id = np.zeros(seq_len, dtype=int)
            token_type_id = np.zeros(seq_len, dtype=int)
            pos_start_id = np.zeros(seq_len, dtype=int)
            pos_end_id = np.zeros(seq_len, dtype=int)
            s1 = np.zeros(seq_len, dtype=float)
            s2 = np.zeros(seq_len, dtype=float)
            for (_, start, end) in example.sub_entity_list:
                if start >= len(example.context) or end >= len(example.context):
                    continue
                s1[start + 1] = 1.0  # +1: positions shift past [CLS]
                s2[end] = 1.0        # exclusive end: (end - 1) + 1 == end
            tokens = ["[CLS]"] + list(example.context) + ["[SEP]"]
            passage_id = self.tokenizer.convert_tokens_to_ids(tokens)
            example.bert_tokens = tokens
            if data_type == 'train':
                sub_start, sub_end = example.sub_pos[0], example.sub_pos[1]
                for i, token in enumerate(example.context):
                    if sub_start <= i < sub_end:
                        token_type_id[i + 1] = 1
                    pos_start_id[i + 1] = example.relative_pos_start[i]
                    pos_end_id[i + 1] = example.relative_pos_end[i]
            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=segment_id,
                    po_label=example.po_list,
                    s1=s1,
                    s2=s2
                ))
        logging.info("Built instances is Completed")
        return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), use_bert=True, data_type=data_type)
class SPODataset(Dataset):
    """Torch dataset over InputFeatures with a padding collate function.

    Training items carry supervision targets; dev/test items are just
    (qid, passage, segment).
    """

    def __init__(self, features, predict_num, data_type, use_bert=False):
        super(SPODataset, self).__init__()
        self.use_bert = use_bert
        self.is_train = data_type == 'train'
        self.q_ids = [f.p_id for f in features]
        self.passages = [f.passage_id for f in features]
        self.segment_ids = [f.segment_id for f in features]
        self.predict_num = predict_num
        if self.is_train:
            # supervision targets only exist for training features
            self.token_type = [f.token_type_id for f in features]
            self.pos_start_ids = [f.pos_start_id for f in features]
            self.pos_end_ids = [f.pos_end_id for f in features]
            self.s1 = [f.s1 for f in features]
            self.s2 = [f.s2 for f in features]
            self.po_label = [f.po_label for f in features]

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, index):
        if not self.is_train:
            return self.q_ids[index], self.passages[index], self.segment_ids[index]
        return (self.q_ids[index], self.passages[index], self.segment_ids[index], self.token_type[index],
                self.pos_start_ids[index], self.pos_end_ids[index], self.s1[index], self.s2[index],
                self.po_label[index])

    def _create_collate_fn(self, batch_first=False):
        def collate(examples):
            if self.is_train:
                (p_ids, passages, segment_ids, token_type, pos_start_ids,
                 pos_end_ids, s1, s2, label) = zip(*examples)
                p_ids = torch.tensor(list(p_ids), dtype=torch.long)
                passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first)
                segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first)
                token_type_tensor, _ = padding(token_type, is_float=False, batch_first=batch_first)
                # computed for parity with the feature pipeline; note they are
                # not part of the returned batch
                pos_start_tensor, _ = padding(pos_start_ids, is_float=False, batch_first=batch_first)
                pos_end_tensor, _ = padding(pos_end_ids, is_float=False, batch_first=batch_first)
                s1_tensor, _ = padding(s1, is_float=True, batch_first=batch_first)
                s2_tensor, _ = padding(s2, is_float=True, batch_first=batch_first)
                po1_tensor, po2_tensor = spo_padding(passages, label, class_num=self.predict_num, is_float=True,
                                                    use_bert=self.use_bert)
                return (p_ids, passages_tensor, segment_tensor, token_type_tensor, s1_tensor, s2_tensor,
                        po1_tensor, po2_tensor)
            p_ids, passages, segment_ids = zip(*examples)
            p_ids = torch.tensor(list(p_ids), dtype=torch.long)
            passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first)
            segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first)
            return p_ids, passages_tensor, segment_tensor

        return collate

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, batch_first=True, pin_memory=False,
                       drop_last=False):
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(batch_first),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)

View File

@@ -0,0 +1,171 @@
# _*_ coding:utf-8 _*_
import argparse
import logging
import os
from config.baidu_spo_config import BAIDU_RELATION
from run.entity_relation_jointed_extraction.mpn.data_loader import Reader, Vocabulary, Feature
from run.entity_relation_jointed_extraction.mpn.train import Trainer
from utils.file_util import save, load
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
def _str2bool(value):
    """Parse a command-line boolean.

    ``argparse``'s ``type=bool`` is a trap: ``bool("False")`` is ``True``
    because any non-empty string is truthy. This converter accepts the usual
    true/false spellings and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognizable boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got {!r}".format(value))


def get_args():
    """Build and parse the command-line arguments for the joint ERE run script.

    Returns:
        argparse.Namespace with all options plus ``args.cache_data``, a
        cache directory under ``args.input`` derived from the chosen text
        representation (word2vec / bert / plain char).
    """
    parser = argparse.ArgumentParser()

    # file parameters
    parser.add_argument("--input", default=None, type=str, required=True)
    parser.add_argument("--output", default=None, type=str, required=False,
                        help="The output directory where the model checkpoints and predictions will be written.")
    # "cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5"
    # 'cpt/baidu_w2v/w2v.txt'
    parser.add_argument('--embedding_file', type=str,
                        default='cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5')

    # choice parameters
    parser.add_argument('--entity_type', type=str, default='disease')
    # NOTE: type=_str2bool (not type=bool) so "--use_word2vec False" really parses as False.
    parser.add_argument('--use_word2vec', type=_str2bool, default=False)
    parser.add_argument('--use_bert', type=_str2bool, default=False)
    parser.add_argument('--seg_char', type=_str2bool, default=True)

    # train parameters
    parser.add_argument('--train_mode', type=str, default="train")
    parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epoch_num", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--patience_stop', type=int, default=10, help='Patience for learning early stop')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")

    # bert parameters
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")

    # model parameters
    parser.add_argument("--max_len", default=1000, type=int)
    parser.add_argument('--word_emb_size', type=int, default=300)
    parser.add_argument('--char_emb_size', type=int, default=300)
    parser.add_argument('--entity_emb_size', type=int, default=300)
    parser.add_argument('--pos_limit', type=int, default=30)
    parser.add_argument('--pos_dim', type=int, default=300)
    parser.add_argument('--pos_size', type=int, default=62)
    parser.add_argument('--hidden_size', type=int, default=150)
    parser.add_argument('--bert_hidden_size', type=int, default=768)
    parser.add_argument('--num_layers', type=int, default=2)
    # was type=int, which made "--dropout 0.3" raise ValueError; dropout is a float rate
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rnn_encoder', type=str, default='lstm', help="must choose in blow: lstm or gru")
    parser.add_argument('--bidirectional', type=_str2bool, default=True)
    parser.add_argument('--pin_memory', type=_str2bool, default=False)
    parser.add_argument('--transformer_layers', type=int, default=1)
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--dim_feedforward', type=int, default=2048)

    args = parser.parse_args()

    # The cache directory depends on the representation in use so that
    # differently-built features never collide.
    if args.use_word2vec:
        args.cache_data = args.input + '/char2v_cache_data/'
    elif args.use_bert:
        args.cache_data = args.input + '/char_bert_cache_data/'
    else:
        args.cache_data = args.input + '/char_cache_data/'
    return args
def bulid_dataset(args, reader, vocab, debug=False):
    """Build (or reload from cache) train/dev examples, the char vocabulary
    and optional pretrained embeddings, then wrap everything in data loaders.

    Args:
        args: parsed CLI namespace (must provide input, cache_data, use_bert,
            use_word2vec, embedding_file, word_emb_size, train_batch_size,
            pin_memory).
        reader: example reader with ``read_examples(path, data_type)``.
        vocab: vocabulary builder used only in the non-BERT path.
        debug: when True, keep only the first 10 examples of each split.

    Returns:
        ((train_examples, dev_examples), (train_loader, dev_loader), char_emb)
        where char_emb is None unless use_word2vec with an embedding file.
    """
    char2idx, char_emb = None, None
    train_src = args.input + "/train_data.json"
    dev_src = args.input + "/dev_data.json"
    # Cached artifacts live under args.cache_data (derived in get_args).
    train_examples_file = args.cache_data + "/train-examples.pkl"
    dev_examples_file = args.cache_data + "/dev-examples.pkl"
    char_emb_file = args.cache_data + "/char_emb.pkl"
    char_dictionary = args.cache_data + "/char_dict.pkl"
    if not os.path.exists(train_examples_file):
        # Cold start: parse the raw JSON and build all artifacts from scratch.
        train_examples = reader.read_examples(train_src, data_type='train')
        dev_examples = reader.read_examples(dev_src, data_type='dev')
        if not args.use_bert:
            # todo : min_word_count=3 ?
            vocab.build_vocab_only_with_char(train_examples, min_char_count=1)
            if args.use_word2vec and args.embedding_file:
                char_emb = vocab.make_embedding(vocab=vocab.char_vocab,
                                                embedding_file=args.embedding_file,
                                                emb_size=args.word_emb_size)
                save(char_emb_file, char_emb, message="char embedding")
            save(char_dictionary, vocab.char2idx, message="char dictionary")
            char2idx = vocab.char2idx
        save(train_examples_file, train_examples, message="train examples")
        save(dev_examples_file, dev_examples, message="dev examples")
    else:
        # Warm start: reuse artifacts serialized by a previous run.
        if not args.use_bert:
            if args.use_word2vec and args.embedding_file:
                char_emb = load(char_emb_file)
            char2idx = load(char_dictionary)
            logging.info("total char vocabulary size is {} ".format(len(char2idx)))
        train_examples, dev_examples = load(train_examples_file), load(dev_examples_file)
        logging.info('train examples size is {}'.format(len(train_examples)))
        logging.info('dev examples size is {}'.format(len(dev_examples)))
    if not args.use_bert:
        # Feature indexing needs the vocabulary size for the embedding layer.
        args.vocab_size = len(char2idx)
    convert_examples_features = Feature(args, token2idx_dict=char2idx)
    # Debug mode keeps only a handful of examples for a fast smoke run.
    train_examples = train_examples[:10] if debug else train_examples
    dev_examples = dev_examples[:10] if debug else dev_examples
    train_data_set = convert_examples_features(train_examples, data_type='train')
    dev_data_set = convert_examples_features(dev_examples, data_type='dev')
    train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory)
    dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size)
    data_loaders = train_data_loader, dev_data_loader
    eval_examples = train_examples, dev_examples
    return eval_examples, data_loaders, char_emb
def main():
    """Script entry point: parse CLI args, create the output/cache directories,
    build the dataset, and dispatch on ``--train_mode`` (train / eval / resume)."""
    args = get_args()

    # Make sure both target directories exist before any artifact is written.
    for directory in (args.output, args.cache_data):
        if not os.path.exists(directory):
            print('mkdir {}'.format(directory))
            os.makedirs(directory)

    logger.info("** ** * bulid dataset ** ** * ")
    reader = Reader(seg_char=args.seg_char, max_len=args.max_len)
    vocab = Vocabulary()
    eval_examples, data_loaders, char_emb = bulid_dataset(args, reader, vocab, debug=False)
    trainer = Trainer(args, data_loaders, eval_examples, char_emb, spo_conf=BAIDU_RELATION)

    if args.train_mode == "train":
        trainer.train(args)
    elif args.train_mode == "eval":
        # trainer.resume(args)
        # trainer.eval_data_set("train")
        trainer.eval_data_set("dev")
    elif args.train_mode == "resume":
        # trainer.resume(args)
        trainer.show("dev")  # bad case analysis
# Run the full train/eval pipeline only when executed directly as a script.
if __name__ == '__main__':
    main()

View File

@ -0,0 +1,265 @@
# _*_ coding:utf-8 _*_
import logging
import random
import sys
import time
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import models.ere_net.bert_mpn as bert_mpn
import models.ere_net.mpn as mpn
from utils.optimizer_util import set_optimizer
logger = logging.getLogger(__name__)
class Trainer(object):
def __init__(self, args, data_loaders, examples, char_emb, spo_conf):
    """Set up the joint entity/relation extraction trainer.

    Args:
        args: parsed CLI namespace (see the run script's ``get_args``).
        data_loaders: (train_dataloader, dev_dataloader) pair.
        examples: (train_examples, dev_examples) used to look up gold answers.
        char_emb: pretrained char embedding matrix, or None (unused with BERT).
        spo_conf: mapping from relation name to relation id.
    """
    # Choose the BERT-based or char/word2vec-based ERE network.
    if args.use_bert:
        self.model = bert_mpn.ERENet(args, spo_conf)
    else:
        self.model = mpn.ERENet(args, char_emb, spo_conf)
    self.args = args
    self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
    self.n_gpu = torch.cuda.device_count()
    # Inverse relation map (id -> name), used when decoding predictions.
    self.id2rel = {item: key for key, item in spo_conf.items()}
    self.rel2id = spo_conf
    # Seed every RNG for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if self.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    self.model.to(self.device)
    # self.resume(args)
    logging.info('total gpu num is {}'.format(self.n_gpu))
    if self.n_gpu > 1:
        # NOTE(review): device ids are hard-coded to [0, 1] — confirm for machines
        # with a different multi-GPU layout.
        self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
    train_dataloader, dev_dataloader = data_loaders
    train_eval, dev_eval = examples
    self.eval_file_choice = {
        "train": train_eval,
        "dev": dev_eval,
    }
    self.data_loader_choice = {
        "train": train_dataloader,
        "dev": dev_dataloader,
    }
    # Total optimization steps = batches per epoch * epochs (used by the LR schedule).
    self.optimizer = set_optimizer(args, self.model,
                                   train_steps=(int(len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)
def train(self, args):
    """Main training loop.

    Optimizes on the train split, evaluates on dev after every epoch, keeps
    the checkpoint with the best dev F1, and stops early after
    ``args.patience_stop`` epochs without improvement.

    Bug fix: the running loss was only accumulated on every ``step_gap``-th
    step but still divided by ``step_gap``, so the printed average was about
    1/step_gap of the true training loss. It now accumulates every step and
    averages over the number of steps actually seen since the last print.
    """
    best_f1 = 0.0
    patience_stop = 0
    self.model.train()
    step_gap = 20  # print the averaged training loss every `step_gap` steps
    for epoch in range(int(args.epoch_num)):
        global_loss = 0.0
        steps_since_log = 0
        for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5,
                                desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):
            loss = self.forward(batch)
            global_loss += loss
            steps_since_log += 1
            if step % step_gap == 0:
                current_loss = global_loss / steps_since_log
                print(
                    u"step {} / {} of epoch {}, train/loss: {}".format(step, len(self.data_loader_choice["train"]),
                                                                       epoch, current_loss))
                global_loss = 0.0
                steps_since_log = 0
        res_dev = self.eval_data_set("dev")
        if res_dev['f1'] >= best_f1:
            best_f1 = res_dev['f1']
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            # Unwrap DataParallel before saving so the state-dict keys are clean.
            model_to_save = self.model.module if hasattr(self.model,
                                                         'module') else self.model  # Only save the model it-self
            output_model_file = args.output + "/pytorch_model.bin"
            torch.save(model_to_save.state_dict(), str(output_model_file))
            patience_stop = 0
        else:
            patience_stop += 1
            if patience_stop >= args.patience_stop:
                return
def resume(self, args):
    """Load the weights saved by ``train`` back into ``self.model``."""
    ckpt_path = args.output + "/pytorch_model.bin"
    logging.info("=> loading checkpoint '{}'".format(ckpt_path))
    # Load on CPU first; the model was already moved to the right device in __init__.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    self.model.load_state_dict(state_dict)
def forward(self, batch, chosen=u'train', eval=False, answer_dict=None):
    """Run one batch through the model.

    Training mode (``eval=False``): computes the loss, performs a full
    backward + optimizer step, and returns the scalar loss value.

    Eval mode (``eval=True``): decodes the model outputs into ``answer_dict``
    (mutated in place by ``convert_spo_contour``) and returns its result.

    Args:
        batch: tensors produced by the collate function (moved to self.device here).
        chosen: which split's eval_file to use for decoding ('train' or 'dev').
        eval: switches between the training and inference paths.
        answer_dict: per-example accumulator filled during decoding (eval only).
    """
    batch = tuple(t.to(self.device) for t in batch)
    if not eval:
        p_ids, input_ids, segment_ids, token_type_ids, s1, s2, po1, po2 = batch
        loss = self.model(passages=input_ids, token_type_ids=token_type_ids, segment_ids=segment_ids, s1=s1, s2=s2,
                          po1=po1, po2=po2,
                          is_eval=eval)
        if self.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu.
        # Backward first, then read the scalar, then step and clear gradients.
        loss.backward()
        loss = loss.item()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss
    else:
        p_ids, input_ids, segment_ids = batch
        eval_file = self.eval_file_choice[chosen]
        # Model returns example ids plus object start/end scores and subject start/end.
        qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor = self.model(q_ids=p_ids, eval_file=eval_file,
                                                                            passages=input_ids, is_eval=eval)
        # NOTE(review): convert_spo_contour mutates answer_dict in place and has no
        # return statement, so ans_dict is None — callers rely on the mutation.
        ans_dict = self.convert_spo_contour(qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor, eval_file,
                                            answer_dict, use_bert=self.args.use_bert)
        return ans_dict
def eval_data_set(self, chosen="dev"):
    """Run inference over the *chosen* split and report P/R/F1.

    Returns the metrics dict produced by ``evaluate``; the model is put back
    into training mode before returning.
    """
    self.model.eval()
    loader = self.data_loader_choice[chosen]
    eval_file = self.eval_file_choice[chosen]
    # One slot per example: ([predicted subjects], [predicted triples]).
    answer_dict = {idx: [[], []] for idx in range(len(eval_file))}
    started = time.time()
    with torch.no_grad():
        for _, batch in tqdm(enumerate(loader), mininterval=5, leave=False, file=sys.stdout):
            self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
    logging.info('chosen {} took : {} sec'.format(chosen, time.time() - started))
    res = self.evaluate(eval_file, answer_dict, chosen)
    # self.detail_evaluate(eval_file, answer_dict, chosen)
    self.model.train()
    return res
def show(self, chosen="dev"):
    """Run inference over the *chosen* split and dump bad cases for inspection.

    Bug fix: the eval branch of ``forward`` returns a single value and fills
    ``answer_dict`` in place, but this method unpacked two values
    (``loss, answer_dict_``) and passed no ``answer_dict`` at all — so the
    first batch crashed on unpacking (and ``convert_spo_contour`` would have
    hit a ``None`` accumulator). It now pre-allocates the accumulator exactly
    like ``eval_data_set`` and lets ``forward`` fill it.
    """
    self.model.eval()
    data_loader = self.data_loader_choice[chosen]
    eval_file = self.eval_file_choice[chosen]
    # One slot per example: ([predicted subjects], [predicted triples]).
    answer_dict = {i: [[], []] for i in range(len(eval_file))}
    with torch.no_grad():
        for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
            self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
    # NOTE(review): badcase_analysis unpacks 3-element values per example while this
    # accumulator holds 2-element lists — confirm its expected shape before relying
    # on the generated report.
    self.badcase_analysis(eval_file, answer_dict, chosen)
@staticmethod
def evaluate(eval_file, answer_dict, chosen):
entity_em = 0
entity_pred_num = 0
entity_gold_num = 0
triple_em = 0
triple_pred_num = 0
triple_gold_num = 0
for key, value in answer_dict.items():
triple_gold = eval_file[key].gold_answer
entity_gold = eval_file[key].sub_entity_list
entity_pred, triple_pred = value
entity_em += len(set(entity_pred) & set(entity_gold))
entity_pred_num += len(set(entity_pred))
entity_gold_num += len(set(entity_gold))
triple_em += len(set(triple_pred) & set(triple_gold))
triple_pred_num += len(set(triple_pred))
triple_gold_num += len(set(triple_gold))
entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0.
entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0.
entity_f1 = 2 * entity_recall * entity_precision / (entity_recall + entity_precision) if (
entity_recall + entity_precision) != 0 else 0.0
precision = 100.0 * triple_em / triple_pred_num if triple_pred_num > 0 else 0.
recall = 100.0 * triple_em / triple_gold_num if triple_gold_num > 0 else 0.
f1 = 2 * recall * precision / (recall + precision) if (recall + precision) != 0 else 0.0
print('============================================')
print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".format(chosen, entity_em, entity_pred_num,
entity_gold_num))
print(
"{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".format(chosen, entity_f1, entity_precision,
entity_recall))
print('============================================')
print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, triple_em, triple_pred_num, triple_gold_num))
print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1, precision,
recall))
return {'f1': f1, "recall": recall, "precision": precision}
@staticmethod
def badcase_analysis(eval_file, answer_dict, chosen):
    """Print and dump examples whose predictions differ from gold answers.

    Writes the mismatches to ``badcase_{chosen}.txt`` in the current working
    directory. Returns None; ``em``/``pre``/``gold`` counts are accumulated
    but not reported.
    """
    em = 0      # count of predictions exactly matching a gold answer
    pre = 0     # total distinct predictions
    gold = 0    # total distinct gold answers
    content = ''
    for key, value in answer_dict.items():
        entity_name = eval_file[int(key)].entity_name
        context = eval_file[int(key)].context
        ground_truths = eval_file[int(key)].gold_answer
        # NOTE(review): expects 3-element values (prediction, l1, l2), while
        # eval_data_set builds 2-element slots — confirm which caller produces
        # this shape before trusting the report.
        value, l1, l2 = value
        prediction = list(value) if len(value) else ['']
        assert type(prediction) == type(ground_truths)
        intersection = set(prediction) & set(ground_truths)
        # Both sides empty: nothing worth reporting.
        if prediction == ground_truths == ['']:
            continue
        if set(prediction) != set(ground_truths):
            # Mismatch: echo the case to stdout and append it to the report.
            ground_truths = list(sorted(set(ground_truths)))
            prediction = list(sorted(set(prediction)))
            print('raw context is:\t' + context)
            print('subject_name is:\t' + entity_name)
            print('pred_text is:\t' + '\t'.join(prediction))
            print('gold_text is:\t' + '\t'.join(ground_truths))
            content += 'raw context is:\t' + context + '\n'
            content += 'subject_name is:\t' + entity_name + '\n'
            content += 'pred_text is:\t' + '\t'.join(prediction) + '\n'
            content += 'gold_text is:\t' + '\t'.join(ground_truths) + '\n'
            content += '==============================='
        em += len(intersection)
        pre += len(set(prediction))
        gold += len(set(ground_truths))
    with open('badcase_{}.txt'.format(chosen), 'w') as f:
        f.write(content)
def convert_spo_contour(self, qid_tensor, po1, po2, s_tensor, e_tensor, eval_file, answer_dict, use_bert=False):
    """Decode model outputs into (subject, predicate, object) triples per example.

    Mutates ``answer_dict`` in place: slot 0 collects predicted subject spans,
    slot 1 collects predicted triples. Returns None.

    Args:
        qid_tensor: example ids (-1 marks entries with no prediction).
        po1, po2: per-position, per-relation object start/end scores.
        s_tensor, e_tensor: predicted subject start/end positions.
        eval_file: example lookup providing context / bert_tokens.
        answer_dict: accumulator keyed by example id.
        use_bert: whether contexts are BERT token lists (joined on output).
    """
    for qid, s, e, o1, o2 in zip(qid_tensor.data.cpu().numpy(), s_tensor.data.cpu().numpy(),
                                 e_tensor.data.cpu().numpy(), po1.data.cpu().numpy(), po2.data.cpu().numpy()):
        # qid == -1 marks padding entries without a prediction.
        if qid == -1:
            continue
        context = eval_file[qid.item()].context if not use_bert else eval_file[qid.item()].bert_tokens
        # NOTE(review): gold_answer is fetched but never used in this method.
        gold_answer = eval_file[qid].gold_answer
        _subject = ''.join(context[s:e]) if use_bert else context[s:e]
        answers = list()
        # Positions where any relation's object start/end score exceeds 0.5.
        start, end = np.where(o1 > 0.5), np.where(o2 > 0.5)
        for _start, _predict_id_start in zip(*start):
            # Skip out-of-range starts; position 0 is skipped under BERT.
            if _start > len(context) or (_start == 0 and use_bert):
                continue
            for _end, _predict_id_end in zip(*end):
                # Pair a start with the first valid end carrying the same relation id.
                if _start <= _end < len(context) and _predict_id_start == _predict_id_end:
                    _obeject = ''.join(context[_start: _end + 1]) if use_bert else context[_start: _end + 1]
                    _predicate = self.id2rel[_predict_id_start]
                    answers.append((_subject, _predicate, _obeject))
                    break
        if qid not in answer_dict:
            print('erro in answer_dict ')
        else:
            # NOTE(review): subject offsets are shifted by -1 — presumably to undo
            # a leading special-token offset; confirm against the feature builder.
            answer_dict[qid][0].append((_subject, s-1, e-1))
            answer_dict[qid][1].extend(answers)