fixed
This commit is contained in:
parent
63b1a3eff5
commit
247cd6d673
@ -14,6 +14,7 @@ class EntExtractNet(BertPreTrainedModel):
|
||||
|
||||
def __init__(self, config, classes_num):
|
||||
super(EntExtractNet, self).__init__(config, classes_num)
|
||||
print('ent_po_net.py')
|
||||
|
||||
self.bert = BertModel(config)
|
||||
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
|
||||
|
@ -43,8 +43,8 @@ def get_args():
|
||||
parser.add_argument('--device_id', type=int, default=0)
|
||||
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
|
||||
|
||||
parser.add_argument("--debug",
|
||||
action='store_true', )
|
||||
parser.add_argument("--debug",action='store_true', )
|
||||
parser.add_argument("--diff_lr", action='store_true', )
|
||||
# bert parameters
|
||||
parser.add_argument("--do_lower_case",
|
||||
action='store_true',
|
||||
|
@ -1,6 +1,5 @@
|
||||
# _*_ coding:utf-8 _*_
|
||||
import codecs
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
@ -8,12 +7,10 @@ from warnings import simplefilter
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from deepIE.chip_ent.ent_extract_pointer import ent_po_net as ent_net
|
||||
from deepIE.chip_ent.ent_extract_pointer import ent_po_net_lstm as ent_net_lstm
|
||||
|
||||
from layers.encoders.transformers.bert.bert_optimization import BertAdam
|
||||
|
||||
simplefilter(action='ignore', category=FutureWarning)
|
||||
@ -29,6 +26,7 @@ class Trainer(object):
|
||||
self.max_len = args.max_len - 2
|
||||
self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
|
||||
self.n_gpu = torch.cuda.device_count()
|
||||
self.load_ent_dict()
|
||||
|
||||
self.id2rel = {item: key for key, item in spo_conf.items()}
|
||||
self.rel2id = spo_conf
|
||||
@ -43,9 +41,10 @@ class Trainer(object):
|
||||
self.model.to(self.device)
|
||||
if args.train_mode != "train":
|
||||
self.resume(args)
|
||||
logging.info('total gpu num is {}'.format(self.n_gpu))
|
||||
if self.n_gpu > 1:
|
||||
self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
|
||||
|
||||
# if self.n_gpu > 1:
|
||||
# logging.info('total gpu num is {}'.format(self.n_gpu))
|
||||
# self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
|
||||
|
||||
train_dataloader, dev_dataloader, test_dataloader = data_loaders
|
||||
train_eval, dev_eval, test_eval = examples
|
||||
@ -68,6 +67,34 @@ class Trainer(object):
|
||||
param_optimizer = list(model.named_parameters())
|
||||
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
|
||||
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
||||
|
||||
# TODO:设置不同学习率
|
||||
if args.diff_lr:
|
||||
logging.info('设置不同学习率')
|
||||
for n, p in param_optimizer:
|
||||
if not n.startswith('bert') and not any(nd in n for nd in no_decay):
|
||||
print(n)
|
||||
print('+' * 10)
|
||||
for n, p in param_optimizer:
|
||||
if not n.startswith('bert') and any(nd in n for nd in no_decay):
|
||||
print(n)
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in param_optimizer if
|
||||
not any(nd in n for nd in no_decay) and n.startswith('bert')],
|
||||
'weight_decay': 0.01, 'lr': args.learning_rate},
|
||||
{'params': [p for n, p in param_optimizer if
|
||||
not any(nd in n for nd in no_decay) and not n.startswith('bert')],
|
||||
'weight_decay': 0.01, 'lr': args.learning_rate * 10},
|
||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and n.startswith('bert')],
|
||||
'weight_decay': 0.0, 'lr': args.learning_rate},
|
||||
{'params': [p for n, p in param_optimizer if
|
||||
any(nd in n for nd in no_decay) and not n.startswith('bert')],
|
||||
'weight_decay': 0.0, 'lr': args.learning_rate * 10}
|
||||
]
|
||||
else:
|
||||
logging.info('原始设置学习率设置')
|
||||
|
||||
# TODO:原始设置
|
||||
optimizer_grouped_parameters = [
|
||||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
|
||||
'weight_decay': 0.01},
|
||||
@ -155,7 +182,7 @@ class Trainer(object):
|
||||
|
||||
data_loader = self.data_loader_choice[chosen]
|
||||
eval_file = self.eval_file_choice[chosen]
|
||||
answer_dict = {i: [[], []] for i in range(len(eval_file))}
|
||||
answer_dict = {i: [[], [], []] for i in range(len(eval_file))}
|
||||
|
||||
last_time = time.time()
|
||||
with torch.no_grad():
|
||||
@ -176,7 +203,7 @@ class Trainer(object):
|
||||
|
||||
data_loader = self.data_loader_choice[chosen]
|
||||
eval_file = self.eval_file_choice[chosen]
|
||||
answer_dict = {i: [[], []] for i in range(len(eval_file))}
|
||||
answer_dict = {i: [[], [], []] for i in range(len(eval_file))}
|
||||
|
||||
last_time = time.time()
|
||||
with torch.no_grad():
|
||||
@ -188,31 +215,86 @@ class Trainer(object):
|
||||
# self.convert2result(eval_file, answer_dict)
|
||||
|
||||
with codecs.open(self.args.res_path, 'w', 'utf-8') as f:
|
||||
for key, ans_list in answer_dict.items():
|
||||
out_put = {}
|
||||
out_put['text'] = eval_file[int(key)].raw_text
|
||||
spo_tuple_lst = ans_list[1]
|
||||
spo_lst = []
|
||||
for (s, p, o) in spo_tuple_lst:
|
||||
spo_lst.append({"predicate": p, "subject": s, "object": {"@value": o}})
|
||||
out_put['spo_list'] = spo_lst
|
||||
for key in answer_dict.keys():
|
||||
|
||||
json_str = json.dumps(out_put, ensure_ascii=False)
|
||||
f.write(json_str)
|
||||
f.write('\n')
|
||||
raw_text = answer_dict[key][2]
|
||||
if raw_text == []:
|
||||
continue
|
||||
pred = answer_dict[key][1]
|
||||
# pred = self.clean_result_with_dct(raw_text, pred)
|
||||
pred_text = []
|
||||
for (s, e, ent_name, ent_type) in pred:
|
||||
pred_text.append(' '.join([str(s), str(e), ent_type]))
|
||||
if len(pred_text) == 0:
|
||||
f.write(raw_text + '\n')
|
||||
else:
|
||||
f.write(raw_text + '|||' + '|||'.join(pred_text) + '|||' + '\n')
|
||||
|
||||
def clean_result(self, text, po_lst):
    """Clean predicted entity spans by dropping conflicting overlaps.

    Spans are visited from longest to shortest (ties broken by start
    index, then entity name, then entity type) and a span is kept when
    either

    * neither of its endpoints is already covered by a kept span, or
    * both of its endpoints lie in an area last marked by a kept span
      whose type is ``'sym'``.

    :param text: the text the spans index into; only its length is used.
    :param po_lst: iterable of hashable ``(start, end, ent_name, ent_type)``
        tuples with inclusive ``end``.
    :return: the kept spans as a list, sorted by ``start``.
    :raises IndexError: if a span endpoint lies outside ``text``.
    """
    # De-duplicate, then impose a *total* order.  Sorting only on length and
    # start leaves equal spans in set-iteration order, which depends on
    # string hashing and can differ between runs; the name/type tie-breakers
    # make the result reproducible.
    ordered_spans = sorted(
        set(po_lst),
        key=lambda span: (span[0] - span[1], span[0], span[2], span[3]),
    )

    area_mask = [0] * len(text)       # 1 where some kept span covers the char
    area_type = [False] * len(text)   # True where the covering span is 'sym'
    new_po_list = []
    for (s, e, ent_name, ent_type) in ordered_spans:
        endpoint_covered = area_mask[s] == 1 or area_mask[e] == 1
        both_in_sym_area = area_type[s] and area_type[e]
        if endpoint_covered and not both_in_sym_area:
            continue

        width = e - s + 1
        area_mask[s:e + 1] = [1] * width
        # The most recently kept span decides the area type of its range.
        area_type[s:e + 1] = [ent_type == 'sym'] * width
        new_po_list.append((s, e, ent_name, ent_type))

    # Stable sort: spans sharing a start keep their longest-first order.
    new_po_list.sort(key=lambda span: span[0])
    return new_po_list
|
||||
|
||||
def clean_result_with_dct(self, text, po_lst):
    """Clean results by correcting entity types with the entity dictionary.

    For every ``(start, end, ent_name, ent_type)`` span, the type is
    replaced by ``self.ent_dct[ent_name]`` when the dictionary has a
    non-``None`` entry for that name; otherwise the span is passed through
    unchanged.  ``text`` is accepted but not used.

    :param text: unused.
    :param po_lst: iterable of ``(start, end, ent_name, ent_type)`` tuples.
    :return: a new list of spans in the same order as ``po_lst``.
    """
    logging.info('清洗结果 利用词典来纠正实体类型')
    lookup = self.ent_dct.get

    def corrected(span):
        start, end, name, predicted_type = span
        dict_type = lookup(name, None)
        final_type = predicted_type if dict_type is None else dict_type
        return (start, end, name, final_type)

    return [corrected(span) for span in po_lst]
|
||||
|
||||
def load_ent_dict(self):
    """Load the entity-name -> entity-type dictionary into ``self.ent_dct``.

    Reads ``deepIE/chip_ent/data/ent_dict.txt`` (relative to the current
    working directory).  Every non-blank line must consist of exactly two
    whitespace-separated fields: the entity name and its type.  Later lines
    overwrite earlier ones for the same name.

    :raises OSError: if the dictionary file cannot be opened.
    :raises ValueError: if a non-blank line does not have exactly two fields.
    """
    dict_path = 'deepIE/chip_ent/data/' + 'ent_dict.txt'
    logging.info('loading ent dict in {}'.format(dict_path))

    ent_dct = {}
    # Decode explicitly as UTF-8 instead of relying on the platform default,
    # which breaks on non-ASCII entity names under e.g. a cp1252/GBK locale.
    # NOTE(review): assumes the file is UTF-8 like the other text I/O here.
    with open(dict_path, 'r', encoding='utf-8') as fr:
        for line in fr:
            fields = line.split()
            if not fields:
                # Blank (or whitespace-only) lines carry no entry.
                continue
            ent_name, ent_type = fields
            ent_dct[ent_name] = ent_type
    self.ent_dct = ent_dct
|
||||
|
||||
def evaluate(self, eval_file, answer_dict, chosen):
|
||||
|
||||
spo_em, spo_pred_num, spo_gold_num = 0.0, 0.0, 0.0
|
||||
|
||||
for key in answer_dict.keys():
|
||||
context = eval_file[key].context
|
||||
raw_text = answer_dict[key][2]
|
||||
triple_gold = answer_dict[key][0]
|
||||
triple_pred = answer_dict[key][1]
|
||||
# triple_pred = self.clean_result_with_dct(raw_text, triple_pred)
|
||||
|
||||
# if set(triple_pred) != set(triple_gold):
|
||||
# print()
|
||||
# print(context)
|
||||
# print(raw_text)
|
||||
# triple_pred.sort(key=lambda x: x[0])
|
||||
# triple_gold.sort(key=lambda x: x[0])
|
||||
# print(triple_pred)
|
||||
# print(triple_gold)
|
||||
|
||||
@ -298,3 +380,4 @@ class Trainer(object):
|
||||
if len(answer_dict[text_id][0]) > 1:
|
||||
continue
|
||||
answer_dict[text_id][0] = example.g_gold_ent
|
||||
answer_dict[text_id][2] = example.g_raw_text
|
||||
|
Loading…
Reference in New Issue
Block a user