DeepIE/run/relation_extraction/etl_stl/train.py
2020-05-20 19:32:45 +08:00

283 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# _*_ coding:utf-8 _*_
import logging
import random
import sys
import time
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import models.spo_net.etl_stl as etl
from utils.data_util import Tokenizer
from utils.optimizer_util import set_optimizer
logger = logging.getLogger(__name__)
tokenizer = Tokenizer('cpt/bert-base-chinese/vocab.txt', do_lower_case=True)
class SPO(tuple):
    """Hashable wrapper for a (subject, predicate, object) triple.

    Behaves like a tuple but overrides __hash__ and __eq__ so that two
    triples compare equal when their tokenized subject/object match —
    this makes the evaluation tolerant to surface-form differences
    (casing/whitespace normalized away by the tokenizer).
    """

    def __init__(self, spo):
        # Tokenize subject and object once at construction; the predicate
        # (spo[1]) is compared verbatim.
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):
        return hash(self.spox)

    def __eq__(self, spo):
        # Guard: the original unconditionally read `spo.spox`, which raised
        # AttributeError when an SPO was compared against a non-SPO object
        # (e.g. a plain tuple or None). NotImplemented lets Python fall back
        # to the other operand's comparison instead of crashing.
        if not isinstance(spo, SPO):
            return NotImplemented
        return self.spox == spo.spox
class Trainer(object):
    """Training/evaluation driver for the ETL-STL ERENet relation-extraction model.

    Owns the model, optimizer, data loaders and the raw evaluation examples,
    and exposes train / resume / eval_data_set / show entry points.
    """

    def __init__(self, args, data_loaders, examples, word_emb, ent_conf, spo_conf):
        # args: command-line namespace (device_id, seed, epoch_num, max_len,
        #   train_batch_size, patience_stop, output, use_bert, ...).
        # data_loaders: (train_dataloader, dev_dataloader).
        # examples: (train_eval, dev_eval) raw example lists used at decode time.
        # word_emb: pre-trained word-embedding weights handed to ERENet.
        # ent_conf / spo_conf: entity and predicate label->id mappings.
        self.args = args
        self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()
        # Reverse map used when decoding predicted label ids back to names.
        self.id2rel = {item: key for key, item in spo_conf.items()}
        self.rel2id = spo_conf
        # Seed every RNG involved for reproducible runs.
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
        self.model = etl.ERENet(args, word_emb, ent_conf, spo_conf)
        self.model.to(self.device)
        # self.resume(args)
        logging.info('total gpu num is {}'.format(self.n_gpu))
        if self.n_gpu > 1:
            # NOTE(review): device ids are hard-coded to [0, 1] — assumes a
            # two-GPU machine; confirm before running on another topology.
            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
        train_dataloader, dev_dataloader = data_loaders
        train_eval, dev_eval = examples
        self.eval_file_choice = {
            "train": train_eval,
            "dev": dev_eval,
        }
        self.data_loader_choice = {
            "train": train_dataloader,
            "dev": dev_dataloader,
        }
        # args.use_bert = False
        self.optimizer = set_optimizer(args, self.model,
                                       train_steps=(int(len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)

    def train(self, args):
        """Run the full training loop with dev-set early stopping.

        Saves the best-F1 checkpoint to args.output/pytorch_model.bin and
        returns after args.patience_stop epochs without improvement.
        """
        best_f1 = 0.0
        patience_stop = 0
        self.model.train()
        step_gap = 20  # console-logging interval, in steps
        for epoch in range(int(args.epoch_num)):
            global_loss = 0.0
            for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5,
                                    desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):
                loss = self.forward(batch)
                if step % step_gap == 0:
                    global_loss += loss
                    current_loss = global_loss / step_gap
                    print(
                        u"step {} / {} of epoch {}, train/loss: {}".format(step, len(self.data_loader_choice["train"]),
                                                                           epoch, current_loss))
                    global_loss = 0.0
            res_dev = self.eval_data_set("dev")
            if res_dev['f1'] >= best_f1:
                best_f1 = res_dev['f1']
                logging.info("** ** * Saving fine-tuned model ** ** * ")
                # Unwrap DataParallel before saving so the checkpoint loads
                # cleanly on a single-GPU/CPU setup.
                model_to_save = self.model.module if hasattr(self.model,
                                                             'module') else self.model  # Only save the model it-self
                output_model_file = args.output + "/pytorch_model.bin"
                torch.save(model_to_save.state_dict(), str(output_model_file))
                patience_stop = 0
            else:
                patience_stop += 1
            if patience_stop >= args.patience_stop:
                return

    def resume(self, args):
        """Load model weights from args.output/pytorch_model.bin (mapped to CPU)."""
        resume_model_file = args.output + "/pytorch_model.bin"
        logging.info("=> loading checkpoint '{}'".format(resume_model_file))
        checkpoint = torch.load(resume_model_file, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def forward(self, batch, chosen=u'train', eval=False, answer_dict=None):
        """One pass over a batch: optimizer step (eval=False) or decode (eval=True).

        Training path returns the scalar loss. Eval path decodes predictions
        into `answer_dict` via convert_spo_contour and returns its result
        (NOTE(review): convert_spo_contour has no return statement, so this
        is None — answer_dict is filled in place).
        NOTE(review): the `eval` parameter shadows the builtin; kept as-is
        because it is part of the call signature.
        """
        batch = tuple(t.to(self.device) for t in batch)
        if not eval:
            char_ids, word_ids, token_type_ids, subject_ids, subject_labels, object_labels = batch
            loss = self.model(char_ids=char_ids, word_ids=word_ids, token_type_ids=token_type_ids,
                              subject_ids=subject_ids, subject_labels=subject_labels, object_labels=object_labels)
            if self.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            loss.backward()
            loss = loss.item()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss
        else:
            p_ids, char_ids, word_ids = batch
            eval_file = self.eval_file_choice[chosen]
            qids, subject_pred, po_pred = self.model(q_ids=p_ids,
                                                    char_ids=char_ids,
                                                    word_ids=word_ids,
                                                    eval_file=eval_file, is_eval=eval)
            ans_dict = self.convert_spo_contour(qids, subject_pred, po_pred, eval_file,
                                                answer_dict, use_bert=self.args.use_bert)
            return ans_dict

    def eval_data_set(self, chosen="dev"):
        """Decode the whole `chosen` split and return its P/R/F1 metrics dict."""
        self.model.eval()
        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        # answer_dict[i] accumulates [predicted subjects, predicted triples]
        # for example i; filled in place by forward(..., eval=True).
        answer_dict = {i: [[], []] for i in range(len(eval_file))}
        last_time = time.time()
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
        used_time = time.time() - last_time
        logging.info('chosen {} took : {} sec'.format(chosen, used_time))
        res = self.evaluate(eval_file, answer_dict, chosen)
        self.model.train()  # restore training mode for subsequent epochs
        return res

    def show(self, chosen="dev"):
        """Decode the `chosen` split for manual inspection.

        NOTE(review): forward(..., eval=True) returns a single value, but
        this unpacks two (`loss, answer_dict_`) — looks stale/broken; also
        no answer_dict is passed, so convert_spo_contour would hit the
        'not in answer_dict' branch. Confirm before using.
        """
        self.model.eval()
        answer_dict = {}
        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                loss, answer_dict_ = self.forward(batch, chosen, eval=True)
                answer_dict.update(answer_dict_)
        # self.badcase_analysis(eval_file, answer_dict, chosen)

    @staticmethod
    def evaluate(eval_file, answer_dict, chosen):
        """Compute subject-entity and triple precision/recall/F1 over answer_dict.

        Prints a human-readable report and returns
        {'f1': ..., 'recall': ..., 'precision': ...} for the triple metrics.
        """
        entity_em = 0
        entity_pred_num = 0
        entity_gold_num = 0
        triple_em = 0
        triple_pred_num = 0
        triple_gold_num = 0
        # Tiny epsilons keep the divisions below safe when nothing matched.
        X, Y, Z = 1e-10, 1e-10, 1e-10
        for key, value in answer_dict.items():
            triple_gold = eval_file[key].gold_answer
            entity_gold = eval_file[key].sub_entity_list
            entity_pred, triple_pred = value
            entity_em += len(set(entity_pred) & set(entity_gold))
            entity_pred_num += len(set(entity_pred))
            entity_gold_num += len(set(entity_gold))
            # triple_em += len(set(triple_pred) & set(triple_gold))
            # triple_pred_num += len(set(triple_pred))
            # triple_gold_num += len(set(triple_gold))
            # SPO wraps triples so set intersection is tokenization-tolerant.
            R = set([SPO(spo) for spo in triple_pred])
            T = set([SPO(spo) for spo in triple_gold])
            # R = set([spo for spo in triple_pred])
            # T = set([spo for spo in triple_gold])
            # if R != T:
            #     print(R)
            #     print(T)
            X += len(R & T)
            Y += len(R)
            Z += len(T)
        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
        entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0.
        entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0.
        entity_f1 = 2 * entity_recall * entity_precision / (entity_recall + entity_precision) if (
                entity_recall + entity_precision) != 0 else 0.0
        print('============================================')
        print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".format(chosen, entity_em, entity_pred_num,
                                                                                   entity_gold_num))
        print(
            "{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".format(chosen, entity_f1, entity_precision,
                                                                                   entity_recall))
        print('============================================')
        print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, X, Y, Z))
        print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1 * 100, precision * 100,
                                                                recall * 100))
        return {'f1': f1, "recall": recall, "precision": precision}

    def convert_spo_contour(self, qids, subject_preds, po_preds, eval_file, answer_dict, use_bert=False):
        """Decode one batch of model outputs into answer_dict (mutated in place).

        qids: per-row example ids (-1 marks padding rows, skipped below).
        subject_preds: per-row (start, end) subject spans.
        po_preds: per-row predicate/object tag-id sequences decoded by
        find_tag_position. No return value — results land in answer_dict.
        """
        for qid, subject, po_pred in zip(qids.data.cpu().numpy(), subject_preds.data.cpu().numpy(),
                                         po_preds.data.cpu().numpy()):
            if qid == -1:
                # Padding entry — no real example behind it.
                continue
            tokens = eval_file[qid.item()].context
            context = eval_file[qid.item()].context
            seq_len = min(len(context), self.args.max_len)
            pred_entities = find_tag_position(po_pred, seq_len, self.id2rel)
            po_predict = []
            for o_1, o_2, tag in pred_entities:
                # Triple = (subject span text, predicate tag, object span text).
                po_predict.append((tokens[subject[0]:subject[1] + 1],
                                   tag,
                                   tokens[o_1:o_2 + 1])
                                  )
            if qid not in answer_dict:
                print('erro in answer_dict ')
            else:
                answer_dict[qid][0].append(
                    tokens[subject[0]:subject[1] + 1])
                answer_dict[qid][1].extend(po_predict)
def find_tag_position(find_list, seq_len, id2rel):
    """Decode a sequence of BIO-style label ids into (start, end, tag) spans.

    Even, non-zero ids mark the beginning of a span (a "B-<tag>" label in
    id2rel); the matching inside-label is that id + 1. A span runs until the
    first position whose id is not the inside-label, or to seq_len - 1.
    Returns a list of (start_index, end_index, tag_name) tuples, inclusive.
    """
    spans = []
    pos = 0
    while pos < seq_len:
        label = find_list[pos]
        if label == 0 or label % 2 != 0:
            # Not a span opener (outside or stray inside label) — move on.
            pos += 1
            continue
        # Tag name is the part after the 'B-' prefix of the label string.
        tag_name = id2rel[label].split('-')[1]
        # Default: the span extends to the end of the sequence unless a
        # non-matching inside label cuts it short.
        span_end = seq_len - 1
        for nxt in range(pos + 1, seq_len):
            if find_list[nxt] != label + 1:
                span_end = nxt - 1
                break
        spans.append((pos, span_end, tag_name))
        pos = span_end + 1
    return spans