This commit is contained in:
loujie0822 2020-08-28 00:09:57 +08:00
parent 63b1a3eff5
commit 247cd6d673
3 changed files with 116 additions and 32 deletions

View File

@ -14,6 +14,7 @@ class EntExtractNet(BertPreTrainedModel):
def __init__(self, config, classes_num):
super(EntExtractNet, self).__init__(config, classes_num)
print('ent_po_net.py')
self.bert = BertModel(config)
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,

View File

@ -43,8 +43,8 @@ def get_args():
parser.add_argument('--device_id', type=int, default=0)
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
parser.add_argument("--debug",
action='store_true', )
parser.add_argument("--debug",action='store_true', )
parser.add_argument("--diff_lr", action='store_true', )
# bert parameters
parser.add_argument("--do_lower_case",
action='store_true',

View File

@ -1,6 +1,5 @@
# _*_ coding:utf-8 _*_
import codecs
import json
import logging
import sys
import time
@ -8,12 +7,10 @@ from warnings import simplefilter
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from deepIE.chip_ent.ent_extract_pointer import ent_po_net as ent_net
from deepIE.chip_ent.ent_extract_pointer import ent_po_net_lstm as ent_net_lstm
from layers.encoders.transformers.bert.bert_optimization import BertAdam
simplefilter(action='ignore', category=FutureWarning)
@ -26,16 +23,17 @@ class Trainer(object):
self.args = args
self.tokenizer = tokenizer
self.max_len = args.max_len-2
self.max_len = args.max_len - 2
self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
self.n_gpu = torch.cuda.device_count()
self.load_ent_dict()
self.id2rel = {item: key for key, item in spo_conf.items()}
self.rel2id = spo_conf
if self.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
if args.encoder_type=='lstm':
if args.encoder_type == 'lstm':
self.model = ent_net_lstm.EntExtractNet.from_pretrained(args.bert_model, classes_num=len(spo_conf))
else:
self.model = ent_net.EntExtractNet.from_pretrained(args.bert_model, classes_num=len(spo_conf))
@ -43,9 +41,10 @@ class Trainer(object):
self.model.to(self.device)
if args.train_mode != "train":
self.resume(args)
logging.info('total gpu num is {}'.format(self.n_gpu))
if self.n_gpu > 1:
self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
# if self.n_gpu > 1:
# logging.info('total gpu num is {}'.format(self.n_gpu))
# self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
train_dataloader, dev_dataloader, test_dataloader = data_loaders
train_eval, dev_eval, test_eval = examples
@ -68,11 +67,39 @@ class Trainer(object):
param_optimizer = list(model.named_parameters())
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
# TODO:设置不同学习率
if args.diff_lr:
logging.info('设置不同学习率')
for n, p in param_optimizer:
if not n.startswith('bert') and not any(nd in n for nd in no_decay):
print(n)
print('+' * 10)
for n, p in param_optimizer:
if not n.startswith('bert') and any(nd in n for nd in no_decay):
print(n)
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if
not any(nd in n for nd in no_decay) and n.startswith('bert')],
'weight_decay': 0.01, 'lr': args.learning_rate},
{'params': [p for n, p in param_optimizer if
not any(nd in n for nd in no_decay) and not n.startswith('bert')],
'weight_decay': 0.01, 'lr': args.learning_rate * 10},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and n.startswith('bert')],
'weight_decay': 0.0, 'lr': args.learning_rate},
{'params': [p for n, p in param_optimizer if
any(nd in n for nd in no_decay) and not n.startswith('bert')],
'weight_decay': 0.0, 'lr': args.learning_rate * 10}
]
else:
logging.info('原始设置学习率设置')
# TODO:原始设置
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
@ -155,7 +182,7 @@ class Trainer(object):
data_loader = self.data_loader_choice[chosen]
eval_file = self.eval_file_choice[chosen]
answer_dict = {i: [[], []] for i in range(len(eval_file))}
answer_dict = {i: [[], [], []] for i in range(len(eval_file))}
last_time = time.time()
with torch.no_grad():
@ -176,7 +203,7 @@ class Trainer(object):
data_loader = self.data_loader_choice[chosen]
eval_file = self.eval_file_choice[chosen]
answer_dict = {i: [[], []] for i in range(len(eval_file))}
answer_dict = {i: [[], [], []] for i in range(len(eval_file))}
last_time = time.time()
with torch.no_grad():
@ -188,31 +215,86 @@ class Trainer(object):
# self.convert2result(eval_file, answer_dict)
with codecs.open(self.args.res_path, 'w', 'utf-8') as f:
for key, ans_list in answer_dict.items():
out_put = {}
out_put['text'] = eval_file[int(key)].raw_text
spo_tuple_lst = ans_list[1]
spo_lst = []
for (s, p, o) in spo_tuple_lst:
spo_lst.append({"predicate": p, "subject": s, "object": {"@value": o}})
out_put['spo_list'] = spo_lst
for key in answer_dict.keys():
json_str = json.dumps(out_put, ensure_ascii=False)
f.write(json_str)
f.write('\n')
raw_text = answer_dict[key][2]
if raw_text == []:
continue
pred = answer_dict[key][1]
# pred = self.clean_result_with_dct(raw_text, pred)
pred_text = []
for (s, e, ent_name, ent_type) in pred:
pred_text.append(' '.join([str(s), str(e), ent_type]))
if len(pred_text) == 0:
f.write(raw_text + '\n')
else:
f.write(raw_text + '|||' + '|||'.join(pred_text) + '|||' + '\n')
def clean_result(self, text, po_lst):
    """Resolve overlapping entity predictions, preferring longer spans.

    Candidates are de-duplicated and visited from the longest span to the
    shortest (ties broken by smaller start offset, then by first appearance
    in ``po_lst``). A candidate is dropped when either of its end points
    lies on a position already claimed by a kept entity, unless both end
    points lie on positions last claimed by an entity of type ``'sym'``.

    :param text: the text the offsets refer to; only its length is used.
    :param po_lst: iterable of hashable ``(start, end, ent_name, ent_type)``
        tuples with inclusive offsets.
    :return: list of the kept tuples, sorted by start offset.
    """
    text_len = len(text)

    # dict.fromkeys removes duplicates but keeps first-seen order, so the
    # outcome for equal-span candidates no longer depends on the (hash
    # randomised) iteration order of a set.
    candidates = list(dict.fromkeys(po_lst))
    # Longest span first, then leftmost; the sort is stable, so remaining
    # ties keep their input order.
    candidates.sort(key=lambda item: (item[0] - item[1], item[0]))

    area_mask = [False] * text_len  # position already claimed by a kept entity
    area_type = [False] * text_len  # position last claimed by a 'sym' entity
    new_po_list = []
    for s, e, ent_name, ent_type in candidates:
        # Ignore spans outside the text: indexing them would either raise
        # IndexError, wrap around (negative offsets), or silently grow the
        # masks through slice assignment.
        if s < 0 or e >= text_len or s > e:
            continue
        touches_claimed = area_mask[s] or area_mask[e]
        if touches_claimed and not (area_type[s] and area_type[e]):
            continue
        width = e - s + 1
        area_mask[s:e + 1] = [True] * width
        area_type[s:e + 1] = [ent_type == 'sym'] * width
        new_po_list.append((s, e, ent_name, ent_type))

    new_po_list.sort(key=lambda item: item[0])
    return new_po_list
def clean_result_with_dct(self, text, po_lst):
    """Override predicted entity types with the types listed in ``self.ent_dct``.

    :param text: accepted for signature compatibility; not used here.
    :param po_lst: iterable of ``(start, end, ent_name, ent_type)`` tuples.
    :return: new list of tuples in the same order; ``ent_type`` is replaced
        by the dictionary entry for ``ent_name`` whenever one exists.
    """
    logging.info('清洗结果 利用词典来纠正实体类型')
    lookup = self.ent_dct

    def _retype(entity):
        # Keep the predicted type unless the dictionary knows this name.
        start, end, name, predicted = entity
        known = lookup.get(name)
        return (start, end, name, predicted if known is None else known)

    return [_retype(entity) for entity in po_lst]
def load_ent_dict(self):
    """Load the entity dictionary file into ``self.ent_dct``.

    Each non-blank line of ``deepIE/chip_ent/data/ent_dict.txt`` must hold
    exactly two whitespace-separated fields, ``ent_name ent_type``. When a
    name appears more than once, the last line wins.

    :raises ValueError: if a non-blank line does not have exactly two fields.
    """
    dict_path = 'deepIE/chip_ent/data/' + 'ent_dict.txt'
    logging.info('loading ent dict in {}'.format(dict_path))

    ent_dct = {}
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, matching the UTF-8 handling used elsewhere in this file.
    with open(dict_path, 'r', encoding='utf-8') as fr:
        for line_no, line in enumerate(fr, 1):
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no entry.
                continue
            if len(fields) != 2:
                raise ValueError(
                    'malformed line {} in {}: expected "<ent_name> <ent_type>", got {!r}'.format(
                        line_no, dict_path, line))
            ent_name, ent_type = fields
            ent_dct[ent_name] = ent_type
    self.ent_dct = ent_dct
def evaluate(self, eval_file, answer_dict, chosen):
spo_em, spo_pred_num, spo_gold_num = 0.0, 0.0, 0.0
for key in answer_dict.keys():
context = eval_file[key].context
raw_text = answer_dict[key][2]
triple_gold = answer_dict[key][0]
triple_pred = answer_dict[key][1]
# triple_pred = self.clean_result_with_dct(raw_text, triple_pred)
# if set(triple_pred) != set(triple_gold):
# print()
# print(context)
# print(raw_text)
# triple_pred.sort(key=lambda x: x[0])
# triple_gold.sort(key=lambda x: x[0])
# print(triple_pred)
# print(triple_gold)
@ -280,7 +362,7 @@ class Trainer(object):
start, end, p = po
ent_name = context[start - 1:end]
predicate = self.id2rel[p]
po_lst.append((start-1, end-1, ent_name, predicate))
po_lst.append((start - 1, end - 1, ent_name, predicate))
if text_id not in answer_dict:
raise ValueError('text_id error in answer_dict ')
@ -297,4 +379,5 @@ class Trainer(object):
answer_dict[text_id][1].extend(po_lst)
if len(answer_dict[text_id][0]) > 1:
continue
answer_dict[text_id][0]=example.g_gold_ent
answer_dict[text_id][0] = example.g_gold_ent
answer_dict[text_id][2] = example.g_raw_text