commit c81a00937e

3  .gitignore  vendored

@@ -6,7 +6,8 @@ settings.py

instance/
data/

cpt/
pyscripts/
.pytest_cache/
.coverage
51  config/baidu_spo_config.py  Normal file

@@ -0,0 +1,51 @@
BAIDU_RELATION = {
    "朝代": 0,
    "人口数量": 1,
    "出生地": 2,
    "连载网站": 3,
    "身高": 4,
    "占地面积": 5,
    "作者": 6,
    "目": 7,
    "母亲": 8,
    "海拔": 9,
    "作词": 10,
    "嘉宾": 11,
    "总部地点": 12,
    "出版社": 13,
    "主持人": 14,
    "出生日期": 15,
    "所在城市": 16,
    "修业年限": 17,
    "祖籍": 18,
    "邮政编码": 19,
    "毕业院校": 20,
    "气候": 21,
    "号": 22,
    "注册资本": 23,
    "丈夫": 24,
    "国籍": 25,
    "主角": 26,
    "主演": 27,
    "民族": 28,
    "董事长": 29,
    "所属专辑": 30,
    "专业代码": 31,
    "改编自": 32,
    "歌手": 33,
    "编剧": 34,
    "妻子": 35,
    "面积": 36,
    "作曲": 37,
    "官方语言": 38,
    "出品公司": 39,
    "成立日期": 40,
    "简称": 41,
    "首都": 42,
    "父亲": 43,
    "字": 44,
    "制片人": 45,
    "上映时间": 46,
    "创始人": 47,
    "导演": 48
}
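Note (editor's sketch, not part of the commit): BAIDU_RELATION maps each of the 49 Baidu SPO predicates to a class id. RelNET sizes its output layers from len(spo_conf), and Trainer inverts the map for decoding. A minimal round trip:

    from config.baidu_spo_config import BAIDU_RELATION

    # same inversion as Trainer.__init__ below
    id2rel = {idx: rel for rel, idx in BAIDU_RELATION.items()}
    assert len(BAIDU_RELATION) == 49   # one output channel per predicate in RelNET
    assert id2rel[48] == "导演"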
232  models/ere_net/bert_mpn.py  Normal file

@@ -0,0 +1,232 @@
# _*_ coding:utf-8 _*_
import warnings

import numpy as np
import torch
import torch.nn as nn

from layers.encoders.transformers.bert.bert_model import BertModel

warnings.filterwarnings("ignore")
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer


class EntityNET(nn.Module):
    """
    EntityNET : subject (entity) extraction with a pointer network
    """

    def __init__(self, args):
        super(EntityNET, self).__init__()

        self.sb1 = nn.Linear(args.bert_hidden_size, 1)
        self.sb2 = nn.Linear(args.bert_hidden_size, 1)

    def forward(self, sent_encoder, q_ids=None, eval_file=None, passages=None, s1=None, s2=None, is_eval=False):

        sequence_mask = passages != 0
        sb1 = self.sb1(sent_encoder).squeeze()
        sb2 = self.sb2(sent_encoder).squeeze()

        if not is_eval:
            loss_fct = nn.BCEWithLogitsLoss(reduction='none')

            sb1_loss = loss_fct(sb1, s1)
            s1_loss = torch.sum(sb1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float())

            s2_loss = loss_fct(sb2, s2)
            s2_loss = torch.sum(s2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float())

            ent_loss = s1_loss + s2_loss
            return ent_loss
        else:
            answer_list = self.predict(eval_file, q_ids, sb1, sb2)
            return answer_list

    def predict(self, eval_file, q_ids=None, sb1=None, sb2=None):
        answer_list = list()
        for qid, p1, p2 in zip(q_ids.cpu().numpy(),
                               sb1.cpu().numpy(),
                               sb2.cpu().numpy()):

            context = eval_file[qid].context
            start = None
            end = None
            threshold = 0.0
            positions = list()
            for idx in range(0, len(context)):
                if idx == 0:
                    continue
                if p1[idx] > threshold and start is None:
                    start = idx
                if p2[idx] > threshold and end is None:
                    end = idx
                if start is not None and end is not None and start <= end:
                    positions.append((start, end + 1))
                    start = None
                    end = None
            answer_list.append(positions)

        return answer_list


class RelNET(nn.Module):
    """
    RelNET : relation (predicate-object) extraction conditioned on a predicted subject
    """

    def __init__(self, args, spo_conf):
        super(RelNET, self).__init__()
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.bert_hidden_size,
                                             padding_idx=0)
        self.encoder_layer = TransformerEncoderLayer(args.bert_hidden_size, args.nhead)
        self.transformer_encoder = TransformerEncoder(self.encoder_layer, args.transformer_layers)

        self.classes_num = len(spo_conf)
        self.ob1 = nn.Linear(args.bert_hidden_size, self.classes_num)
        self.ob2 = nn.Linear(args.bert_hidden_size, self.classes_num)

    def forward(self, passages=None, sent_encoder=None, posit_ids=None, o1=None, o2=None, is_eval=False):
        mask = passages.eq(0)

        subject_encoder = sent_encoder + self.token_entity_emb(posit_ids)

        # TransformerEncoder expects (seq_len, batch, hidden)
        subject_encoder = torch.transpose(subject_encoder, 1, 0)
        transformer_encoder = self.transformer_encoder(subject_encoder, src_key_padding_mask=mask)
        transformer_encoder = torch.transpose(transformer_encoder, 0, 1)

        po1 = self.ob1(transformer_encoder)
        po2 = self.ob2(transformer_encoder)

        if not is_eval:
            loss_fct = nn.BCEWithLogitsLoss(reduction='none')

            sequence_mask = passages != 0

            s1_loss = loss_fct(po1, o1)
            s1_loss = torch.sum(s1_loss, 2)
            s1_loss = torch.sum(s1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num

            s2_loss = loss_fct(po2, o2)
            s2_loss = torch.sum(s2_loss, 2)
            s2_loss = torch.sum(s2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num

            rel_loss = s1_loss + s2_loss

            return rel_loss

        else:
            po1 = nn.Sigmoid()(po1)
            po2 = nn.Sigmoid()(po2)
            return po1, po2


class ERENet(nn.Module):
    """
    ERENet : joint entity relation extraction
    """

    def __init__(self, args, spo_conf):
        super(ERENet, self).__init__()
        print('joint entity relation extraction')
        self.bert_encoder = BertModel.from_pretrained(args.bert_model)
        self.entity_extraction = EntityNET(args)
        self.rel_extraction = RelNET(args, spo_conf)

    def forward(self, q_ids=None, eval_file=None, passages=None, token_type_ids=None, segment_ids=None, s1=None,
                s2=None, po1=None, po2=None, is_eval=False):

        sequence_mask = passages != 0
        sent_encoder, _ = self.bert_encoder(passages, token_type_ids=segment_ids, attention_mask=sequence_mask,
                                            output_all_encoded_layers=False)

        if not is_eval:
            # entity extraction
            ent_loss = self.entity_extraction(sent_encoder, passages=passages, s1=s1, s2=s2,
                                              is_eval=is_eval)

            # relation extraction
            rel_loss = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, posit_ids=token_type_ids,
                                           o1=po1,
                                           o2=po2, is_eval=False)

            # total loss
            total_loss = ent_loss + rel_loss

            return total_loss

        else:

            answer_list = self.entity_extraction(sent_encoder, q_ids=q_ids, eval_file=eval_file,
                                                 passages=passages, is_eval=is_eval)
            start_list, end_list = list(), list()
            qid_list, pass_list, posit_list, sent_list = list(), list(), list(), list()
            for i, ans_list in enumerate(answer_list):
                seq_len = passages.size(1)
                posit_ids = []
                for ans_tuple in ans_list:
                    posit_array = np.zeros(seq_len, dtype=int)
                    start, end = ans_tuple[0], ans_tuple[1]
                    start_list.append(start)
                    end_list.append(end)
                    posit_array[start:end] = 1
                    posit_ids.append(posit_array)

                if len(posit_ids) == 0:
                    continue
                qid_ = q_ids[i].unsqueeze(0).expand(len(posit_ids))
                sent_tensor = sent_encoder[i, :, :].unsqueeze(0).expand(len(posit_ids), sent_encoder.size(1),
                                                                        sent_encoder.size(2))
                pass_tensor = passages[i, :].unsqueeze(0).expand(len(posit_ids), passages.size(1))
                posit_tensor = torch.tensor(posit_ids, dtype=torch.long).to(sent_encoder.device)

                qid_list.append(qid_)
                pass_list.append(pass_tensor)
                posit_list.append(posit_tensor)
                sent_list.append(sent_tensor)

            if len(qid_list) == 0:
                qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device)
                return qid_tensor, qid_tensor, qid_tensor, qid_tensor, qid_tensor

            qid_tensor = torch.cat(qid_list).to(sent_encoder.device)
            sent_tensor = torch.cat(sent_list).to(sent_encoder.device)
            pass_tensor = torch.cat(pass_list).to(sent_encoder.device)
            posi_tensor = torch.cat(posit_list).to(sent_encoder.device)

            flag = False
            split_heads = 1024

            inputs = torch.split(pass_tensor, split_heads, dim=0)
            posits = torch.split(posi_tensor, split_heads, dim=0)
            sents = torch.split(sent_tensor, split_heads, dim=0)

            po1_list, po2_list = list(), list()
            for i in range(len(inputs)):
                passages = inputs[i]
                sent_encoder = sents[i]
                posit_ids = posits[i]

                if passages.size(0) == 1:
                    # a batch of one is duplicated to avoid batch-size-1 edge cases;
                    # the duplicate row is dropped again below
                    flag = True
                    passages = passages.expand(2, passages.size(1))
                    sent_encoder = sent_encoder.expand(2, sent_encoder.size(1), sent_encoder.size(2))
                    posit_ids = posit_ids.expand(2, posit_ids.size(1))

                po1, po2 = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, posit_ids=posit_ids,
                                               is_eval=is_eval)
                if flag:
                    po1 = po1[1, :, :].unsqueeze(0)
                    po2 = po2[1, :, :].unsqueeze(0)

                po1_list.append(po1)
                po2_list.append(po2)

            po1_tensor = torch.cat(po1_list).to(sent_encoder.device)
            po2_tensor = torch.cat(po2_list).to(sent_encoder.device)

            s_tensor = torch.tensor(start_list, dtype=torch.long).to(sent_encoder.device)
            e_tensor = torch.tensor(end_list, dtype=torch.long).to(sent_encoder.device)

            return qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor
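Note (editor's sketch, not part of the commit): the start/end pairing inside EntityNET.predict, pulled out as a self-contained function for clarity. The 0.0 threshold on the raw logits matches the code above; the helper name is illustrative only:

    def decode_spans(p1, p2, length, threshold=0.0):
        """Pair the first unmatched start with the first end at or after it."""
        spans, start, end = [], None, None
        for idx in range(1, length):            # position 0 is skipped, as above
            if p1[idx] > threshold and start is None:
                start = idx
            if p2[idx] > threshold and end is None:
                end = idx
            if start is not None and end is not None and start <= end:
                spans.append((start, end + 1))  # half-open span
                start = end = None
        return spans

    # logits firing at position 2 (start) and 4 (end) yield one span:
    print(decode_spans([0, -1, 1, -1, -1], [0, -1, -1, -1, 1], 5))  # [(2, 5)]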
248  models/ere_net/mpn.py  Normal file

@@ -0,0 +1,248 @@
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer

from layers.encoders.rnns.stacked_rnn import StackedBRNN

warnings.filterwarnings("ignore")


class SentenceEncoder(nn.Module):
    def __init__(self, args, input_size):
        super(SentenceEncoder, self).__init__()
        rnn_type = nn.LSTM if args.rnn_encoder == 'lstm' else nn.GRU
        self.encoder = StackedBRNN(
            input_size=input_size,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            dropout_rate=args.dropout,
            dropout_output=True,
            concat_layers=False,
            rnn_type=rnn_type,
            padding=True
        )

    def forward(self, input, mask):
        return self.encoder(input, mask)


class EntityNET(nn.Module):
    """
    EntityNET : entity extraction using a pointer network
    """

    def __init__(self, args, char_emb):
        super(EntityNET, self).__init__()

        if char_emb is not None:
            self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
                                                         padding_idx=0)
        else:
            self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
                                         padding_idx=0)

        self.sentence_encoder = SentenceEncoder(args, args.word_emb_size)
        self.s1 = nn.Linear(args.hidden_size * 2, 1)
        self.s2 = nn.Linear(args.hidden_size * 2, 1)

    def forward(self, q_ids=None, eval_file=None, passages=None, s1=None, s2=None, is_eval=False):
        mask = passages.eq(0)
        sequence_mask = passages != 0

        char_emb = self.char_emb(passages)

        sent_encoder = self.sentence_encoder(char_emb, mask)

        s1_ = self.s1(sent_encoder).squeeze()
        s2_ = self.s2(sent_encoder).squeeze()

        if not is_eval:
            loss_fct = nn.BCEWithLogitsLoss(reduction='none')

            sb1_loss = loss_fct(s1_, s1)
            s1_loss = torch.sum(sb1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float())

            s2_loss = loss_fct(s2_, s2)
            s2_loss = torch.sum(s2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float())

            ent_loss = s1_loss + s2_loss
            return sent_encoder, ent_loss
        else:
            answer_list = self.predict(eval_file, q_ids, s1_, s2_)
            return sent_encoder, answer_list

    def predict(self, eval_file, q_ids=None, s1=None, s2=None):
        sub_ans_list = list()
        for qid, p1, p2 in zip(q_ids.cpu().numpy(),
                               s1.cpu().numpy(),
                               s2.cpu().numpy()):
            start = None
            end = None
            threshold = 0.0
            positions = list()
            for idx in range(0, len(eval_file[qid].context)):
                if p1[idx] > threshold and start is None:
                    start = idx
                if p2[idx] > threshold and end is None:
                    end = idx
                if start is not None and end is not None and start <= end:
                    positions.append((start, end + 1))
                    start = None
                    end = None
            sub_ans_list.append(positions)

        return sub_ans_list


class RelNET(nn.Module):
    """
    RelNET : subject-aware relation (predicate-object) extraction
    """

    def __init__(self, args, spo_conf):
        super(RelNET, self).__init__()
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
                                             padding_idx=0)
        self.sentence_encoder = SentenceEncoder(args, args.word_emb_size)
        self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead)

        self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)

        self.classes_num = len(spo_conf)
        self.po1 = nn.Linear(args.hidden_size * 2, self.classes_num)
        self.po2 = nn.Linear(args.hidden_size * 2, self.classes_num)

    def forward(self, passages=None, sent_encoder=None, token_type_id=None, po1=None, po2=None, is_eval=False):
        mask = passages.eq(0)
        sequence_mask = passages != 0

        subject_encoder = sent_encoder + self.token_entity_emb(token_type_id)
        sent_sub_aware_encoder = self.sentence_encoder(subject_encoder, mask).transpose(1, 0)

        transformer_encoder = self.transformer_encoder(sent_sub_aware_encoder,
                                                       src_key_padding_mask=mask).transpose(0, 1)

        po1_ = self.po1(transformer_encoder)
        po2_ = self.po2(transformer_encoder)

        if not is_eval:
            loss_fct = nn.BCEWithLogitsLoss(reduction='none')

            po1_loss = loss_fct(po1_, po1)
            po1_loss = torch.sum(po1_loss, 2)
            po1_loss = torch.sum(po1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num

            po2_loss = loss_fct(po2_, po2)
            po2_loss = torch.sum(po2_loss, 2)
            po2_loss = torch.sum(po2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num

            rel_loss = po1_loss + po2_loss

            return rel_loss

        else:
            po1 = nn.Sigmoid()(po1_)
            po2 = nn.Sigmoid()(po2_)
            return po1, po2


class ERENet(nn.Module):
    """
    ERENet : joint entity-relation extraction with an entity-aware Multi-label Pointer Network (MPN)
    """

    def __init__(self, args, char_emb, spo_conf):
        super(ERENet, self).__init__()
        print('joint entity relation extraction')
        self.entity_extraction = EntityNET(args, char_emb)
        self.rel_extraction = RelNET(args, spo_conf)

    def forward(self, q_ids=None, eval_file=None, passages=None, token_type_ids=None, segment_ids=None, s1=None,
                s2=None, po1=None, po2=None, is_eval=False):

        if not is_eval:
            sent_encoder, ent_loss = self.entity_extraction(passages=passages, s1=s1, s2=s2, is_eval=is_eval)
            rel_loss = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, token_type_id=token_type_ids,
                                           po1=po1, po2=po2, is_eval=False)
            total_loss = ent_loss + rel_loss

            return total_loss
        else:

            sent_encoder, answer_list = self.entity_extraction(q_ids=q_ids, eval_file=eval_file,
                                                               passages=passages, is_eval=is_eval)
            start_list, end_list = list(), list()
            qid_list, pass_list, posit_list, sent_list = list(), list(), list(), list()
            for i, ans_list in enumerate(answer_list):
                seq_len = passages.size(1)
                posit_ids = []
                for ans_tuple in ans_list:
                    posit_array = np.zeros(seq_len, dtype=int)
                    start, end = ans_tuple[0], ans_tuple[1]
                    start_list.append(start)
                    end_list.append(end)
                    posit_array[start:end] = 1
                    posit_ids.append(posit_array)

                if len(posit_ids) == 0:
                    continue
                qid_ = q_ids[i].unsqueeze(0).expand(len(posit_ids))
                sent_tensor = sent_encoder[i, :, :].unsqueeze(0).expand(len(posit_ids), sent_encoder.size(1),
                                                                        sent_encoder.size(2))
                pass_tensor = passages[i, :].unsqueeze(0).expand(len(posit_ids), passages.size(1))
                posit_tensor = torch.tensor(posit_ids, dtype=torch.long).to(sent_encoder.device)

                qid_list.append(qid_)
                pass_list.append(pass_tensor)
                posit_list.append(posit_tensor)
                sent_list.append(sent_tensor)

            if len(qid_list) == 0:
                qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device)
                return qid_tensor, qid_tensor, qid_tensor, qid_tensor, qid_tensor

            qid_tensor = torch.cat(qid_list).to(sent_encoder.device)
            sent_tensor = torch.cat(sent_list).to(sent_encoder.device)
            pass_tensor = torch.cat(pass_list).to(sent_encoder.device)
            posi_tensor = torch.cat(posit_list).to(sent_encoder.device)

            flag = False
            split_heads = 1024

            inputs = torch.split(pass_tensor, split_heads, dim=0)
            posits = torch.split(posi_tensor, split_heads, dim=0)
            sents = torch.split(sent_tensor, split_heads, dim=0)

            po1_list, po2_list = list(), list()
            for i in range(len(inputs)):
                passages = inputs[i]
                sent_encoder = sents[i]
                posit_ids = posits[i]

                if passages.size(0) == 1:
                    # duplicate a singleton batch; the duplicate row is dropped below
                    flag = True
                    passages = passages.expand(2, passages.size(1))
                    sent_encoder = sent_encoder.expand(2, sent_encoder.size(1), sent_encoder.size(2))
                    posit_ids = posit_ids.expand(2, posit_ids.size(1))

                po1, po2 = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, token_type_id=posit_ids,
                                               is_eval=is_eval)
                if flag:
                    po1 = po1[1, :, :].unsqueeze(0)
                    po2 = po2[1, :, :].unsqueeze(0)

                po1_list.append(po1)
                po2_list.append(po2)

            po1_tensor = torch.cat(po1_list).to(sent_encoder.device)
            po2_tensor = torch.cat(po2_list).to(sent_encoder.device)

            s_tensor = torch.tensor(start_list, dtype=torch.long).to(sent_encoder.device)
            e_tensor = torch.tensor(end_list, dtype=torch.long).to(sent_encoder.device)

            return qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor
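Note (editor's sketch, not part of the commit): wiring the non-BERT ERENet up by hand, assuming the repository root is on PYTHONPATH and using the defaults that main.py (below) exposes. The Namespace stands in for parsed arguments; vocab_size is a placeholder value:

    from argparse import Namespace

    import models.ere_net.mpn as mpn
    from config.baidu_spo_config import BAIDU_RELATION

    # entity_emb_size must equal 2 * hidden_size in this configuration, so the
    # subject-position embeddings add onto the BiRNN output elementwise
    args = Namespace(vocab_size=5000, char_emb_size=300, word_emb_size=300,
                     entity_emb_size=300, hidden_size=150, num_layers=2, dropout=0.5,
                     rnn_encoder='lstm', nhead=4, transformer_layers=1)

    # char_emb=None -> a randomly initialised nn.Embedding is used
    model = mpn.ERENet(args, char_emb=None, spo_conf=BAIDU_RELATION)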
0  pyscripts/__init__.py  Normal file

@@ -3,7 +3,7 @@ import argparse
 import logging
 import os
 
-from run.attribute_extract.mpn.data import Reader, Vocabulary, config, Feature
+from run.attribute_extract.mpn.data_loader import Reader, Vocabulary, config, Feature
 from run.attribute_extract.mpn.train import Trainer
 from utils.file_util import save, load
 
0  run/entity_relation_jointed_extraction/__init__.py  Normal file

475  run/entity_relation_jointed_extraction/mpn/data_loader.py  Normal file

@@ -0,0 +1,475 @@
import codecs
import json
import logging
from collections import Counter
from functools import partial

import jieba
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from config.baidu_spo_config import BAIDU_RELATION
from layers.encoders.transformers.bert.bert_tokenization import BertTokenizer
from utils.data_util import padding, _handle_pos_limit, find_position, spo_padding


class PredictObject(object):
    def __init__(self,
                 object_name,
                 object_start,
                 object_end,
                 predict_type,
                 predict_type_id
                 ):
        self.object_name = object_name
        self.object_start = object_start
        self.object_end = object_end
        self.predict_type = predict_type
        self.predict_type_id = predict_type_id


class Example(object):
    def __init__(self,
                 p_id=None,
                 context=None,
                 bert_tokens=None,
                 sub_pos=None,
                 sub_entity_list=None,
                 relative_pos_start=None,
                 relative_pos_end=None,
                 po_list=None,
                 gold_answer=None):
        self.p_id = p_id
        self.context = context
        self.bert_tokens = bert_tokens
        self.sub_pos = sub_pos
        self.sub_entity_list = sub_entity_list
        self.relative_pos_start = relative_pos_start
        self.relative_pos_end = relative_pos_end
        self.po_list = po_list
        self.gold_answer = gold_answer


class InputFeature(object):

    def __init__(self,
                 p_id=None,
                 passage_id=None,
                 token_type_id=None,
                 pos_start_id=None,
                 pos_end_id=None,
                 segment_id=None,
                 po_label=None,
                 s1=None,
                 s2=None):
        self.p_id = p_id
        self.passage_id = passage_id
        self.token_type_id = token_type_id
        self.pos_start_id = pos_start_id
        self.pos_end_id = pos_end_id
        self.segment_id = segment_id
        self.po_label = po_label
        self.s1 = s1
        self.s2 = s2


class Reader(object):
    def __init__(self, do_lowercase=False, seg_char=False, max_len=600):

        self.do_lowercase = do_lowercase
        self.seg_char = seg_char
        self.max_len = max_len
        self.relation_config = BAIDU_RELATION

        if self.seg_char:
            logging.info("seg_char...")
        else:
            logging.info("seg_word using jieba ...")

    def read_examples(self, filename, data_type):
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def _data_process(self, filename, data_type='train'):
        output_data = list()
        with codecs.open(filename, 'r') as f:
            gold_num = 0
            for line in tqdm(f):
                data_json = json.loads(line.strip())
                text = data_json['text'].lower()
                sub_po_dict, sub_ent_list, spo_list = dict(), list(), list()

                for spo in data_json['spo_list']:
                    # TODO .strip('《》').strip()
                    subject_name = spo['subject'].lower().strip('《》').strip()
                    object_name = spo['object'].lower().strip('《》').strip()
                    s_start, s_end = find_position(subject_name, text)
                    o_start, o_end = find_position(object_name, text)

                    if text[s_start:s_end] != subject_name:
                        # print(subject_name)
                        subject_name = spo['subject'].lower().replace('》', '').replace('《', '')
                        s_start, s_end = find_position(subject_name, text)
                    if s_start != -1 and o_start != -1:
                        sub_ent_list.append((subject_name, s_start, s_end))
                        spo_list.append((subject_name, spo['predicate'], object_name))

                        if subject_name not in sub_po_dict:
                            sub_po_dict[subject_name] = {}
                            sub_po_dict[subject_name]['sub_pos'] = [s_start, s_end]
                            sub_po_dict[subject_name]['po_list'] = [
                                {'predict': spo['predicate'], 'object': (object_name, o_start, o_end)}]
                        else:
                            sub_po_dict[subject_name]['po_list'].append(
                                {'predict': spo['predicate'], 'object': (object_name, o_start, o_end)})
                text_spo = dict()
                text_spo['context'] = text
                text_spo['sub_po_dict'] = sub_po_dict
                text_spo['spo_list'] = list(set(spo_list))
                text_spo['sub_ent_list'] = list(set(sub_ent_list))
                gold_num += len(set(spo_list))
                output_data.append(text_spo)

        if data_type == 'train':
            return self._convert_train_data(output_data)
        # print(f'total gold num is {gold_num}')
        return output_data

    @staticmethod
    def _convert_train_data(src_data):
        """
        Convert train_data into the form required for training, i.e.
        one example per subject, pairing that subject with its
        corresponding (predicate, object) list --> sub_po_dict
        :param src_data:
        :return:
        """
        spo_data = []
        for data in src_data:
            for sub_ent, po_dict in data['sub_po_dict'].items():
                example = dict(data)  # copy, so examples for different subjects do not alias
                example['sub_name'] = sub_ent
                example['sub_pos'] = po_dict['sub_pos']
                example['po_list'] = po_dict['po_list']

                spo_data.append(example)
        return spo_data

    def _read(self, filename, data_type):

        data_set = self._data_process(filename, data_type)
        logging.info("{} data_set total size is {} ".format(data_type, len(data_set)))
        examples = []
        for p_id in tqdm(range(len(data_set))):
            data = data_set[p_id]
            para = data['context']
            context = para if self.seg_char else ''.join(jieba.lcut(para))
            if len(context) > self.max_len:
                context = context[:self.max_len]

            if data_type == 'train':

                start, end = data['sub_pos']
                if start >= self.max_len or end >= self.max_len:
                    continue
                assert data['sub_name'] == context[start:end]

                # pos_start & pos_end: position of each token relative to the subject
                # entity (a relative distance), e.g. offsets in [-30, 30]; at embedding
                # time the whole range is shifted by +31 into [1, 61], giving 62
                # position tokens in total, with 0 reserved for padding
                pos_start = list(map(lambda i: i - start, list(range(len(context)))))
                pos_end = list(map(lambda i: i - end, list(range(len(context)))))
                relative_pos_start = _handle_pos_limit(pos_start)
                relative_pos_end = _handle_pos_limit(pos_end)

                po_list = []
                for predict_object in data['po_list']:
                    predict_type = predict_object['predict']
                    object_ = predict_object['object']
                    object_name, object_start, object_end = object_[0], object_[1], object_[2]

                    if object_start >= self.max_len or object_end >= self.max_len:
                        continue
                    assert object_name == context[object_start:object_end]

                    po_list.append(PredictObject(
                        object_name=object_name,
                        object_start=object_start,
                        object_end=object_end,
                        predict_type=predict_type,
                        predict_type_id=self.relation_config[predict_type]
                    ))

                examples.append(
                    Example(
                        p_id=p_id,
                        context=context,
                        sub_pos=data['sub_pos'],
                        sub_entity_list=data['sub_ent_list'],
                        relative_pos_start=relative_pos_start,
                        relative_pos_end=relative_pos_end,
                        po_list=po_list,
                        gold_answer=data['spo_list']
                    )
                )
            else:
                examples.append(
                    Example(
                        p_id=p_id,
                        context=context,
                        sub_pos=None,
                        sub_entity_list=data['sub_ent_list'],
                        relative_pos_start=None,
                        relative_pos_end=None,
                        po_list=None,
                        gold_answer=data['spo_list']
                    )
                )

        logging.info("{} total size is {} ".format(data_type, len(examples)))

        return examples


class Vocabulary(object):

    def __init__(self, special_tokens=["<OOV>", "<MASK>"]):

        self.char_vocab = None
        self.emb_mat = None
        self.char2idx = None
        self.char_counter = Counter()
        self.special_tokens = special_tokens

    def build_vocab_only_with_char(self, examples, min_char_count=-1):

        logging.info("Building vocabulary only with character...")

        self.char_vocab = ["<PAD>"]

        if self.special_tokens is not None and isinstance(self.special_tokens, list):
            self.char_vocab.extend(self.special_tokens)

        for example in tqdm(examples):
            for char in example.context:
                self.char_counter[char] += 1

        for w, v in self.char_counter.most_common():
            if v >= min_char_count:
                self.char_vocab.append(w)

        self.char2idx = {token: idx for idx, token in enumerate(self.char_vocab)}

        logging.info("total char counter size is {} ".format(len(self.char_counter)))
        logging.info("total char vocabulary size is {} ".format(len(self.char_vocab)))

    def _load_embedding(self, embedding_file, embedding_dict):

        with open(embedding_file) as f:
            for line in f:
                if len(line.rstrip().split(" ")) <= 2: continue
                token, vector = line.rstrip().split(" ", 1)
                embedding_dict[token] = np.array(vector.split(), dtype=float)
        return embedding_dict

    def make_embedding(self, vocab, embedding_file, emb_size):

        embedding_dict = dict()
        embedding_dict["<PAD>"] = np.array([0. for _ in range(emb_size)])

        self._load_embedding(embedding_file, embedding_dict)

        count = 0
        for token in tqdm(vocab):
            if token not in embedding_dict:
                count += 1
                embedding_dict[token] = np.array([np.random.normal(scale=0.1) for _ in range(emb_size)])
        logging.info(
            "{} / {} tokens have a corresponding embedding vector".format(len(vocab) - count, len(vocab)))

        emb_mat = [embedding_dict[token] for idx, token in enumerate(vocab)]

        return emb_mat


class Feature(object):
    def __init__(self, args, token2idx_dict):
        self.bert = args.use_bert
        self.token2idx_dict = token2idx_dict
        if self.bert:
            self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    def token2wid(self, token):
        if token in self.token2idx_dict:
            return self.token2idx_dict[token]
        return self.token2idx_dict["<OOV>"]

    def __call__(self, examples, data_type):

        if self.bert:
            return self.convert_examples_to_bert_features(examples, data_type)
        else:
            return self.convert_examples_to_features(examples, data_type)

    def convert_examples_to_features(self, examples, data_type):

        logging.info("convert {} examples to features .".format(data_type))

        examples2features = list()
        for index, example in enumerate(examples):

            passage_id = np.zeros(len(example.context), dtype=int)
            segment_id = np.zeros(len(example.context), dtype=int)
            token_type_id = np.zeros(len(example.context), dtype=int)
            pos_start_id = np.zeros(len(example.context), dtype=int)
            pos_end_id = np.zeros(len(example.context), dtype=int)
            s1 = np.zeros(len(example.context), dtype=float)
            s2 = np.zeros(len(example.context), dtype=float)

            for (_, start, end) in example.sub_entity_list:
                if start >= len(example.context) or end >= len(example.context):
                    continue
                s1[start] = 1.0
                s2[end - 1] = 1.0

            for i, token in enumerate(example.context):
                passage_id[i] = self.token2wid(token)

            if data_type == 'train':
                sub_start, sub_end = example.sub_pos[0], example.sub_pos[1]
                for i, token in enumerate(example.context):
                    if sub_start <= i < sub_end:
                        # token = "<MASK>"
                        token_type_id[i] = 1
                    pos_start_id[i] = example.relative_pos_start[i]
                    pos_end_id[i] = example.relative_pos_end[i]

            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=segment_id,
                    po_label=example.po_list,
                    s1=s1,
                    s2=s2
                ))

        logging.info("Building instances completed")
        return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), data_type=data_type)

    def convert_examples_to_bert_features(self, examples, data_type):

        logging.info("Processing {} examples...".format(data_type))

        examples2features = list()
        for index, example in enumerate(examples):

            segment_id = np.zeros(len(example.context) + 2, dtype=int)
            token_type_id = np.zeros(len(example.context) + 2, dtype=int)
            pos_start_id = np.zeros(len(example.context) + 2, dtype=int)
            pos_end_id = np.zeros(len(example.context) + 2, dtype=int)
            s1 = np.zeros(len(example.context) + 2, dtype=float)
            s2 = np.zeros(len(example.context) + 2, dtype=float)

            for (_, start, end) in example.sub_entity_list:
                if start >= len(example.context) or end >= len(example.context):
                    continue
                s1[start + 1] = 1.0
                s2[end] = 1.0

            tokens = ["[CLS]"]
            for i, token in enumerate(example.context):
                tokens.append(token)
            tokens.append("[SEP]")
            passage_id = self.tokenizer.convert_tokens_to_ids(tokens)
            example.bert_tokens = tokens

            if data_type == 'train':
                sub_start, sub_end = example.sub_pos[0], example.sub_pos[1]
                for i, token in enumerate(example.context):
                    if sub_start <= i < sub_end:
                        token_type_id[i + 1] = 1
                    pos_start_id[i + 1] = example.relative_pos_start[i]
                    pos_end_id[i + 1] = example.relative_pos_end[i]

            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=segment_id,
                    po_label=example.po_list,
                    s1=s1,
                    s2=s2
                ))

        logging.info("Building instances completed")
        return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), use_bert=True, data_type=data_type)


class SPODataset(Dataset):
    def __init__(self, features, predict_num, data_type, use_bert=False):
        super(SPODataset, self).__init__()
        self.use_bert = use_bert
        self.is_train = True if data_type == 'train' else False
        self.q_ids = [f.p_id for f in features]
        self.passages = [f.passage_id for f in features]
        self.segment_ids = [f.segment_id for f in features]
        self.predict_num = predict_num

        if self.is_train:
            self.token_type = [f.token_type_id for f in features]
            self.pos_start_ids = [f.pos_start_id for f in features]
            self.pos_end_ids = [f.pos_end_id for f in features]
            self.s1 = [f.s1 for f in features]
            self.s2 = [f.s2 for f in features]
            self.po_label = [f.po_label for f in features]

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, index):
        if self.is_train:
            return self.q_ids[index], self.passages[index], self.segment_ids[index], self.token_type[index], \
                   self.pos_start_ids[index], self.pos_end_ids[index], self.s1[index], self.s2[index], \
                   self.po_label[index]
        else:
            return self.q_ids[index], self.passages[index], self.segment_ids[index]

    def _create_collate_fn(self, batch_first=False):
        def collate(examples):
            if self.is_train:
                p_ids, passages, segment_ids, token_type, pos_start_ids, pos_end_ids, s1, s2, label = zip(*examples)

                p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
                passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first)
                segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first)

                token_type_tensor, _ = padding(token_type, is_float=False, batch_first=batch_first)
                pos_start_tensor, _ = padding(pos_start_ids, is_float=False, batch_first=batch_first)
                pos_end_tensor, _ = padding(pos_end_ids, is_float=False, batch_first=batch_first)
                s1_tensor, _ = padding(s1, is_float=True, batch_first=batch_first)
                s2_tensor, _ = padding(s2, is_float=True, batch_first=batch_first)
                po1_tensor, po2_tensor = spo_padding(passages, label, class_num=self.predict_num, is_float=True,
                                                     use_bert=self.use_bert)
                return p_ids, passages_tensor, segment_tensor, token_type_tensor, s1_tensor, s2_tensor, po1_tensor, \
                       po2_tensor
            else:
                p_ids, passages, segment_ids = zip(*examples)
                p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
                passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first)
                segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first)
                return p_ids, passages_tensor, segment_tensor

        return partial(collate)

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, batch_first=True, pin_memory=False,
                       drop_last=False):
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle,
                          collate_fn=self._create_collate_fn(batch_first),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
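Note (editor's sketch, not part of the commit): the relative-position scheme described in the comments of Reader._read, with a hypothetical clamp helper standing in for utils.data_util._handle_pos_limit, whose implementation is not shown in this diff:

    def clamp_positions(offsets, limit=30):
        # hypothetical stand-in for _handle_pos_limit: clamp to [-limit, limit],
        # then shift by limit + 1 so index 0 stays free for padding
        return [max(-limit, min(limit, o)) + limit + 1 for o in offsets]

    context_len, sub_start = 8, 3
    pos_start = [i - sub_start for i in range(context_len)]  # offsets -3 .. 4
    print(clamp_positions(pos_start))  # values land in [1, 61]; here [28, ..., 35]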
171  run/entity_relation_jointed_extraction/mpn/main.py  Normal file

@@ -0,0 +1,171 @@
# _*_ coding:utf-8 _*_
import argparse
import logging
import os

from config.baidu_spo_config import BAIDU_RELATION
from run.entity_relation_jointed_extraction.mpn.data_loader import Reader, Vocabulary, Feature
from run.entity_relation_jointed_extraction.mpn.train import Trainer
from utils.file_util import save, load

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')


def get_args():
    parser = argparse.ArgumentParser()

    # file parameters
    parser.add_argument("--input", default=None, type=str, required=True)
    parser.add_argument("--output", default=None, type=str, required=False,
                        help="The output directory where the model checkpoints and predictions will be written.")
    # "cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5"
    # 'cpt/baidu_w2v/w2v.txt'
    parser.add_argument('--embedding_file', type=str,
                        default='cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5')

    # choice parameters
    parser.add_argument('--entity_type', type=str, default='disease')
    # note: argparse treats any non-empty string as True for type=bool flags
    parser.add_argument('--use_word2vec', type=bool, default=False)
    parser.add_argument('--use_bert', type=bool, default=False)
    parser.add_argument('--seg_char', type=bool, default=True)

    # train parameters
    parser.add_argument('--train_mode', type=str, default="train")
    parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epoch_num", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--patience_stop', type=int, default=10, help='Patience for learning early stop')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")

    # bert parameters
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")

    # model parameters
    parser.add_argument("--max_len", default=1000, type=int)
    parser.add_argument('--word_emb_size', type=int, default=300)
    parser.add_argument('--char_emb_size', type=int, default=300)
    parser.add_argument('--entity_emb_size', type=int, default=300)
    parser.add_argument('--pos_limit', type=int, default=30)
    parser.add_argument('--pos_dim', type=int, default=300)
    parser.add_argument('--pos_size', type=int, default=62)

    parser.add_argument('--hidden_size', type=int, default=150)
    parser.add_argument('--bert_hidden_size', type=int, default=768)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rnn_encoder', type=str, default='lstm', help="must be one of: lstm or gru")
    parser.add_argument('--bidirectional', type=bool, default=True)
    parser.add_argument('--pin_memory', type=bool, default=False)
    parser.add_argument('--transformer_layers', type=int, default=1)
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--dim_feedforward', type=int, default=2048)
    args = parser.parse_args()
    if args.use_word2vec:
        args.cache_data = args.input + '/char2v_cache_data/'
    elif args.use_bert:
        args.cache_data = args.input + '/char_bert_cache_data/'
    else:
        args.cache_data = args.input + '/char_cache_data/'
    return args


def build_dataset(args, reader, vocab, debug=False):
    char2idx, char_emb = None, None
    train_src = args.input + "/train_data.json"
    dev_src = args.input + "/dev_data.json"

    train_examples_file = args.cache_data + "/train-examples.pkl"
    dev_examples_file = args.cache_data + "/dev-examples.pkl"

    char_emb_file = args.cache_data + "/char_emb.pkl"
    char_dictionary = args.cache_data + "/char_dict.pkl"

    if not os.path.exists(train_examples_file):

        train_examples = reader.read_examples(train_src, data_type='train')
        dev_examples = reader.read_examples(dev_src, data_type='dev')

        if not args.use_bert:
            # todo : min_word_count=3 ?
            vocab.build_vocab_only_with_char(train_examples, min_char_count=1)
            if args.use_word2vec and args.embedding_file:
                char_emb = vocab.make_embedding(vocab=vocab.char_vocab,
                                                embedding_file=args.embedding_file,
                                                emb_size=args.word_emb_size)
                save(char_emb_file, char_emb, message="char embedding")
            save(char_dictionary, vocab.char2idx, message="char dictionary")
            char2idx = vocab.char2idx
        save(train_examples_file, train_examples, message="train examples")
        save(dev_examples_file, dev_examples, message="dev examples")
    else:
        if not args.use_bert:
            if args.use_word2vec and args.embedding_file:
                char_emb = load(char_emb_file)
            char2idx = load(char_dictionary)
            logging.info("total char vocabulary size is {} ".format(len(char2idx)))
        train_examples, dev_examples = load(train_examples_file), load(dev_examples_file)

    logging.info('train examples size is {}'.format(len(train_examples)))
    logging.info('dev examples size is {}'.format(len(dev_examples)))

    if not args.use_bert:
        args.vocab_size = len(char2idx)
    convert_examples_features = Feature(args, token2idx_dict=char2idx)

    train_examples = train_examples[:10] if debug else train_examples
    dev_examples = dev_examples[:10] if debug else dev_examples

    train_data_set = convert_examples_features(train_examples, data_type='train')
    dev_data_set = convert_examples_features(dev_examples, data_type='dev')
    train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory)
    dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size)

    data_loaders = train_data_loader, dev_data_loader
    eval_examples = train_examples, dev_examples

    return eval_examples, data_loaders, char_emb


def main():
    args = get_args()
    if not os.path.exists(args.output):
        print('mkdir {}'.format(args.output))
        os.makedirs(args.output)
    if not os.path.exists(args.cache_data):
        print('mkdir {}'.format(args.cache_data))
        os.makedirs(args.cache_data)

    logger.info("** ** * build dataset ** ** * ")
    reader = Reader(seg_char=args.seg_char, max_len=args.max_len)
    vocab = Vocabulary()

    eval_examples, data_loaders, char_emb = build_dataset(args, reader, vocab, debug=False)

    trainer = Trainer(args, data_loaders, eval_examples, char_emb, spo_conf=BAIDU_RELATION)

    if args.train_mode == "train":
        trainer.train(args)
    elif args.train_mode == "eval":
        # trainer.resume(args)
        # trainer.eval_data_set("train")
        trainer.eval_data_set("dev")
    elif args.train_mode == "resume":
        # trainer.resume(args)
        trainer.show("dev")  # bad case analysis


if __name__ == '__main__':
    main()
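Note (editor's sketch, not part of the commit): a plausible invocation from the repository root, assuming the run/... packages are importable; data/baidu and output/baidu_mpn are placeholder paths, and the file names under --input follow what build_dataset expects (train_data.json / dev_data.json):

    PYTHONPATH=. python -m run.entity_relation_jointed_extraction.mpn.main \
        --input data/baidu \
        --output output/baidu_mpn \
        --train_mode train \
        --train_batch_size 4 \
        --epoch_num 3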
265  run/entity_relation_jointed_extraction/mpn/train.py  Normal file

@@ -0,0 +1,265 @@
# _*_ coding:utf-8 _*_
import logging
import random
import sys
import time

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

import models.ere_net.bert_mpn as bert_mpn
import models.ere_net.mpn as mpn
from utils.optimizer_util import set_optimizer

logger = logging.getLogger(__name__)


class Trainer(object):

    def __init__(self, args, data_loaders, examples, char_emb, spo_conf):
        if args.use_bert:
            self.model = bert_mpn.ERENet(args, spo_conf)
        else:
            self.model = mpn.ERENet(args, char_emb, spo_conf)

        self.args = args

        self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()

        self.id2rel = {item: key for key, item in spo_conf.items()}
        self.rel2id = spo_conf

        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
        self.model.to(self.device)
        # self.resume(args)
        logging.info('total gpu num is {}'.format(self.n_gpu))
        if self.n_gpu > 1:
            # note: device ids are hard-coded here
            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])

        train_dataloader, dev_dataloader = data_loaders
        train_eval, dev_eval = examples
        self.eval_file_choice = {
            "train": train_eval,
            "dev": dev_eval,
        }
        self.data_loader_choice = {
            "train": train_dataloader,
            "dev": dev_dataloader,
        }
        self.optimizer = set_optimizer(args, self.model,
                                       train_steps=(int(len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)

    def train(self, args):

        best_f1 = 0.0
        patience_stop = 0
        self.model.train()
        step_gap = 20
        for epoch in range(int(args.epoch_num)):

            global_loss = 0.0

            for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5,
                                    desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):

                loss = self.forward(batch)

                if step % step_gap == 0:
                    global_loss += loss
                    current_loss = global_loss / step_gap
                    print(
                        u"step {} / {} of epoch {}, train/loss: {}".format(step, len(self.data_loader_choice["train"]),
                                                                           epoch, current_loss))
                    global_loss = 0.0

            res_dev = self.eval_data_set("dev")
            if res_dev['f1'] >= best_f1:
                best_f1 = res_dev['f1']
                logging.info("** ** * Saving fine-tuned model ** ** * ")
                model_to_save = self.model.module if hasattr(self.model,
                                                             'module') else self.model  # Only save the model itself
                output_model_file = args.output + "/pytorch_model.bin"
                torch.save(model_to_save.state_dict(), str(output_model_file))
                patience_stop = 0
            else:
                patience_stop += 1
            if patience_stop >= args.patience_stop:
                return

    def resume(self, args):
        resume_model_file = args.output + "/pytorch_model.bin"
        logging.info("=> loading checkpoint '{}'".format(resume_model_file))
        checkpoint = torch.load(resume_model_file, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def forward(self, batch, chosen=u'train', eval=False, answer_dict=None):

        batch = tuple(t.to(self.device) for t in batch)
        if not eval:

            p_ids, input_ids, segment_ids, token_type_ids, s1, s2, po1, po2 = batch
            loss = self.model(passages=input_ids, token_type_ids=token_type_ids, segment_ids=segment_ids, s1=s1, s2=s2,
                              po1=po1, po2=po2,
                              is_eval=eval)
            if self.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu

            loss.backward()
            loss = loss.item()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss
        else:
            p_ids, input_ids, segment_ids = batch
            eval_file = self.eval_file_choice[chosen]
            qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor = self.model(q_ids=p_ids, eval_file=eval_file,
                                                                                passages=input_ids, is_eval=eval)
            ans_dict = self.convert_spo_contour(qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor, eval_file,
                                                answer_dict, use_bert=self.args.use_bert)
            return ans_dict

    def eval_data_set(self, chosen="dev"):

        self.model.eval()

        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        answer_dict = {i: [[], []] for i in range(len(eval_file))}

        last_time = time.time()
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
        used_time = time.time() - last_time
        logging.info('chosen {} took : {} sec'.format(chosen, used_time))
        res = self.evaluate(eval_file, answer_dict, chosen)
        # self.detail_evaluate(eval_file, answer_dict, chosen)
        self.model.train()
        return res

    def show(self, chosen="dev"):

        self.model.eval()

        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        answer_dict = {i: [[], []] for i in range(len(eval_file))}
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                # forward fills answer_dict in place, mirroring eval_data_set
                self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
        self.badcase_analysis(eval_file, answer_dict, chosen)

    @staticmethod
    def evaluate(eval_file, answer_dict, chosen):

        entity_em = 0
        entity_pred_num = 0
        entity_gold_num = 0

        triple_em = 0
        triple_pred_num = 0
        triple_gold_num = 0
        for key, value in answer_dict.items():
            triple_gold = eval_file[key].gold_answer
            entity_gold = eval_file[key].sub_entity_list

            entity_pred, triple_pred = value

            entity_em += len(set(entity_pred) & set(entity_gold))
            entity_pred_num += len(set(entity_pred))
            entity_gold_num += len(set(entity_gold))

            triple_em += len(set(triple_pred) & set(triple_gold))
            triple_pred_num += len(set(triple_pred))
            triple_gold_num += len(set(triple_gold))

        entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0.
        entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0.
        entity_f1 = 2 * entity_recall * entity_precision / (entity_recall + entity_precision) if (
                entity_recall + entity_precision) != 0 else 0.0

        precision = 100.0 * triple_em / triple_pred_num if triple_pred_num > 0 else 0.
        recall = 100.0 * triple_em / triple_gold_num if triple_gold_num > 0 else 0.
        f1 = 2 * recall * precision / (recall + precision) if (recall + precision) != 0 else 0.0
        print('============================================')
        print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".format(chosen, entity_em, entity_pred_num,
                                                                                   entity_gold_num))
        print(
            "{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".format(chosen, entity_f1, entity_precision,
                                                                                   entity_recall))
        print('============================================')
        print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, triple_em, triple_pred_num, triple_gold_num))
        print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1, precision, recall))
        return {'f1': f1, "recall": recall, "precision": precision}

    @staticmethod
    def badcase_analysis(eval_file, answer_dict, chosen):
        em = 0
        pre = 0
        gold = 0
        content = ''
        for key, value in answer_dict.items():
            entity_name = eval_file[int(key)].entity_name
            context = eval_file[int(key)].context
            ground_truths = eval_file[int(key)].gold_answer
            value, l1, l2 = value
            prediction = list(value) if len(value) else ['']
            assert type(prediction) == type(ground_truths)

            intersection = set(prediction) & set(ground_truths)

            if prediction == ground_truths == ['']:
                continue
            if set(prediction) != set(ground_truths):
                ground_truths = list(sorted(set(ground_truths)))
                prediction = list(sorted(set(prediction)))
                print('raw context is:\t' + context)
                print('subject_name is:\t' + entity_name)
                print('pred_text is:\t' + '\t'.join(prediction))
                print('gold_text is:\t' + '\t'.join(ground_truths))
                content += 'raw context is:\t' + context + '\n'
                content += 'subject_name is:\t' + entity_name + '\n'
                content += 'pred_text is:\t' + '\t'.join(prediction) + '\n'
                content += 'gold_text is:\t' + '\t'.join(ground_truths) + '\n'
                content += '==============================='
            em += len(intersection)
            pre += len(set(prediction))
            gold += len(set(ground_truths))
        with open('badcase_{}.txt'.format(chosen), 'w') as f:
            f.write(content)

    def convert_spo_contour(self, qid_tensor, po1, po2, s_tensor, e_tensor, eval_file, answer_dict, use_bert=False):
        for qid, s, e, o1, o2 in zip(qid_tensor.data.cpu().numpy(), s_tensor.data.cpu().numpy(),
                                     e_tensor.data.cpu().numpy(), po1.data.cpu().numpy(), po2.data.cpu().numpy()):
            if qid == -1:
                continue
            context = eval_file[qid.item()].context if not use_bert else eval_file[qid.item()].bert_tokens
            gold_answer = eval_file[qid].gold_answer

            _subject = ''.join(context[s:e]) if use_bert else context[s:e]
            answers = list()
            start, end = np.where(o1 > 0.5), np.where(o2 > 0.5)
            for _start, _predict_id_start in zip(*start):
                if _start > len(context) or (_start == 0 and use_bert):
                    continue
                for _end, _predict_id_end in zip(*end):
                    if _start <= _end < len(context) and _predict_id_start == _predict_id_end:
                        _object = ''.join(context[_start: _end + 1]) if use_bert else context[_start: _end + 1]
                        _predicate = self.id2rel[_predict_id_start]
                        answers.append((_subject, _predicate, _object))
                        break

            if qid not in answer_dict:
                print('error in answer_dict')
            else:
                answer_dict[qid][0].append((_subject, s - 1, e - 1))
                answer_dict[qid][1].extend(answers)
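Note (editor's sketch, not part of the commit): the thresholded object/predicate decoding inside convert_spo_contour, reduced to its core. Numpy arrays stand in for the sigmoid outputs po1/po2, and the boundary guards of the full method are omitted:

    import numpy as np

    def decode_po(o1, o2, id2rel, threshold=0.5):
        """Pair each start with the nearest end at or after it that fires
        for the same predicate id."""
        answers = []
        starts, ends = np.where(o1 > threshold), np.where(o2 > threshold)
        for _start, _pid_start in zip(*starts):
            for _end, _pid_end in zip(*ends):
                if _start <= _end and _pid_start == _pid_end:
                    answers.append((_start, _end + 1, id2rel[_pid_start]))
                    break
        return answers

    o1 = np.zeros((6, 2)); o2 = np.zeros((6, 2))
    o1[1, 0] = 0.9; o2[3, 0] = 0.8
    print(decode_po(o1, o2, {0: "作者", 1: "导演"}))  # span tokens 1..3, predicate "作者"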