This commit is contained in:
loujie0822 2020-09-07 02:24:48 +08:00
parent 98f52ab35b
commit 6b357403e4
6 changed files with 930 additions and 0 deletions

View File

@ -0,0 +1,257 @@
"""
不使用BERT自带的tokenizer只是按照char进行序列标注
不再随机选择subject而是将其全部flatten
"""
import json
import logging
from functools import partial
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from deepIE.chip_rel.utils.data_utils import search_spo_index
from utils import extract_chinese_and_punct
from utils.data_util import sequence_padding
chineseandpunctuationextractor = extract_chinese_and_punct.ChineseAndPunctuationExtractor()
class Example(object):
    """Container for one raw example plus its char-level token and label data.

    All fields default to None; unlabeled (test) examples leave the label and
    gold fields unset.
    """

    def __init__(self,
                 p_id=None,
                 raw_text=None,
                 context=None,
                 choice_sub=None,
                 tok_to_orig_start_index=None,
                 tok_to_orig_end_index=None,
                 bert_tokens=None,
                 ent_labels=None,
                 rel_labels=None,
                 gold_rel=None,
                 gold_ent=None):
        # identifiers / raw inputs
        self.p_id = p_id
        self.raw_text = raw_text
        self.context = context
        self.choice_sub = choice_sub
        # char-token bookkeeping
        self.tok_to_orig_start_index = tok_to_orig_start_index
        self.tok_to_orig_end_index = tok_to_orig_end_index
        self.bert_tokens = bert_tokens
        # supervision targets (spans / (subj, obj, predicate) triples)
        self.ent_labels = ent_labels
        self.rel_labels = rel_labels
        # gold annotations, kept for evaluation
        self.gold_rel = gold_rel
        self.gold_ent = gold_ent
class InputFeature(object):
    """Per-example feature bundle fed to the model; all fields default to None."""

    def __init__(self,
                 p_id=None,
                 passage_id=None,
                 token_type_id=None,
                 pos_start_id=None,
                 pos_end_id=None,
                 segment_id=None,
                 po_label=None,
                 s1=None,
                 s2=None):
        # example id and encoded passage
        self.p_id = p_id
        self.passage_id = passage_id
        self.token_type_id = token_type_id
        # positional ids
        self.pos_start_id = pos_start_id
        self.pos_end_id = pos_end_id
        self.segment_id = segment_id
        # labels
        self.po_label = po_label
        self.s1 = s1
        self.s2 = s2
class Reader(object):
    """Reads JSON-lines SPO files into Example objects using char-level tokens.

    Texts are split into lowercase characters (no WordPiece), wrapped with
    [CLS]/[SEP], and gold subject/object spans are located with
    ``search_spo_index``.
    """

    def __init__(self, spo_conf, tokenizer=None, max_seq_length=None):
        self.spo_conf = spo_conf          # predicate name -> label id
        self.tokenizer = tokenizer        # unused here; kept for interface compatibility
        self.max_seq_length = max_seq_length

    def read_examples(self, filename, data_type):
        """Return the list of Example objects parsed from *filename*."""
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def _read(self, filename, data_type):
        """Parse one JSON-lines file; lines without 'spo_list' get no labels."""
        examples = []
        gold_num = 0
        # explicit encoding: the corpus is Chinese text
        with open(filename, 'r', encoding='utf-8') as fr:
            # stream the file instead of materializing it with readlines()
            for p_id, line in enumerate(tqdm(fr), start=1):
                data_line = json.loads(line.strip())
                text_raw = data_line['text']
                # char-level "tokens", truncated to leave room for [CLS]/[SEP]
                tokens = [ch.lower() for ch in text_raw]
                tokens = ["[CLS]"] + tokens[:self.max_seq_length - 2] + ["[SEP]"]
                if 'spo_list' not in data_line:
                    # unlabeled (test) example
                    examples.append(
                        Example(
                            p_id=p_id,
                            context=text_raw,
                            bert_tokens=tokens,
                        ))
                    continue
                gold_ent_lst, gold_spo_lst = [], []
                ent_labels = []
                rel_labels = []
                for spo in data_line['spo_list']:
                    subject_text = spo['subject']
                    predicate = spo['predicate']
                    # avoid shadowing the builtin `object`
                    object_text = spo['object']['@value']
                    gold_ent_lst.append(subject_text)
                    gold_ent_lst.append(object_text)
                    gold_spo_lst.append((subject_text, predicate, object_text))
                    subject_sub_tokens = [ch.lower() for ch in subject_text]
                    object_sub_tokens = [ch.lower() for ch in object_text]
                    subject_start, object_start = search_spo_index(tokens, subject_sub_tokens, object_sub_tokens)
                    # resolve label id before the match check so an unknown
                    # predicate still raises KeyError (as before)
                    predicate_label = self.spo_conf[predicate]
                    if subject_start != -1 and object_start != -1:
                        subject_span = (subject_start, subject_start + len(subject_sub_tokens) - 1)
                        object_span = (object_start, object_start + len(object_sub_tokens) - 1)
                        ent_labels.append(subject_span)
                        ent_labels.append(object_span)
                        rel_labels.append((subject_span[0], object_span[0], predicate_label))
                    else:
                        # span not found in the (possibly truncated) text
                        print('error')
                        print(subject_sub_tokens, object_sub_tokens, text_raw)
                examples.append(
                    Example(
                        p_id=p_id,
                        context=text_raw,
                        bert_tokens=tokens,
                        gold_ent=gold_ent_lst,
                        gold_rel=gold_spo_lst,
                        ent_labels=ent_labels,
                        rel_labels=rel_labels
                    ))
                gold_num += len(gold_spo_lst)
        logging.info('total gold spo num in {} is {}'.format(data_type, gold_num))
        logging.info("{} total size is {} ".format(data_type, len(examples)))
        return examples
class Feature(object):
    """Factory converting Example lists into a SPODataset.

    Callable: ``feature(examples, data_type)`` pairs each example with its
    index and wraps the result in a SPODataset.
    """

    def __init__(self, max_len, spo_config, tokenizer):
        self.max_len = max_len
        self.spo_config = spo_config
        self.tokenizer = tokenizer

    def __call__(self, examples, data_type):
        return self.convert_examples_to_bert_features(examples, data_type)

    def convert_examples_to_bert_features(self, examples, data_type):
        """Wrap (index, example) pairs in a SPODataset for *data_type*."""
        logging.info("convert {} examples to features .".format(data_type))
        # idiomatic replacement of the manual enumerate-and-append loop
        examples2features = list(enumerate(examples))
        logging.info("Built instances is Completed")
        return SPODataset(examples2features, spo_config=self.spo_config, data_type=data_type,
                          tokenizer=self.tokenizer, max_len=self.max_len)
class SPODataset(Dataset):
    """Dataset over (index, Example) pairs with a batching collate function.

    The collate function encodes char tokens, builds dense entity start/end
    targets and the (B, L, rel, L) relation selection tensor. Train batches
    omit the example ids; dev/test batches include them for decoding.
    """

    def __init__(self, data, spo_config, data_type, tokenizer=None, max_len=128):
        super(SPODataset, self).__init__()
        self.spo_config = spo_config
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.q_ids = [f[0] for f in data]
        self.features = [f[1] for f in data]
        # simplified from `True if ... else False`
        self.is_train = data_type == 'train'

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, index):
        return self.q_ids[index], self.features[index]

    def _create_collate_fn(self):
        def collate(examples):
            """Batch (q_id, Example) pairs into padded tensors."""
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor(p_ids, dtype=torch.long)
            batch_token_ids, batch_segment_ids = [], []
            batch_ent_labels, batch_rel_labels = [], []
            for example in examples:
                # strip the ids the tokenizer adds: the tokens already carry
                # [CLS]/[SEP]
                token_ids = self.tokenizer.encode(example.bert_tokens)[1:-1]
                batch_token_ids.append(token_ids)
                batch_segment_ids.append([0] * len(token_ids))
                # dense start/end indicators per token
                ent_label_ids = np.zeros((len(token_ids), 2), dtype=np.float32)
                # unlabeled (test) examples carry None labels — treat as empty
                # instead of crashing
                for start, end in (example.ent_labels or []):
                    ent_label_ids[start, 0] = 1
                    ent_label_ids[end, 1] = 1
                batch_ent_labels.append(ent_label_ids)
                batch_rel_labels.append(example.rel_labels or [])
            # the padding logic was duplicated in both branches; do it once
            batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
            batch_segment_ids = sequence_padding(batch_segment_ids, is_float=False)
            batch_ent_labels = sequence_padding(batch_ent_labels, padding=np.zeros(2), is_float=True)
            batch_rel_labels = select_padding(batch_token_ids, batch_rel_labels, is_float=True,
                                              class_num=len(self.spo_config), use_bert=True)
            if self.is_train:
                return batch_token_ids, batch_segment_ids, batch_ent_labels, batch_rel_labels
            # eval batches additionally carry the example ids
            return p_ids, batch_token_ids, batch_segment_ids, batch_ent_labels, batch_rel_labels

        # `partial(collate)` with no bound arguments was a no-op
        return collate

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, pin_memory=False,
                       drop_last=False):
        """Build a DataLoader over this dataset using the custom collate fn."""
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
def select_padding(seqs, select, is_float=False, class_num=None, use_bert=False):
    """Build a dense (batch, L, class_num, L) relation target tensor.

    Args:
        seqs: batch of token-id sequences; only their max length is used.
        select: per-example lists of (subject_pos, object_pos, predicate_id).
        is_float: emit float32 targets (for BCE) instead of int64.
        class_num: number of relation classes (size of the predicate axis).
        use_bert: unused; kept for interface compatibility with callers.

    Returns:
        Tensor with 1 at [i, subject_pos, predicate_id, object_pos] for every
        triple of example i, 0 elsewhere.
    """
    batch_length = max(len(s) for s in seqs)
    # torch.zeros replaces FloatTensor(...).fill_(float(0)) / LongTensor(...).fill_(0)
    dtype = torch.float32 if is_float else torch.int64
    seq_tensor = torch.zeros(len(seqs), batch_length, class_num, batch_length, dtype=dtype)
    for i, triplet_list in enumerate(select):
        for subject_pos, object_pos, predicate in triplet_list:
            seq_tensor[i, subject_pos, predicate, object_pos] = 1
    return seq_tensor

View File

@ -0,0 +1,177 @@
# _*_ coding:utf-8 _*_
"""
mhs_bert:
"""
import argparse
import logging
import os
import random
from warnings import simplefilter
import numpy as np
import torch
from transformers import BertTokenizer
from deepIE.chip_rel.spo_mhs_pointer.data_loader_char_total_sub import Reader, Feature
from deepIE.chip_rel.spo_mhs_pointer.train import Trainer
from deepIE.config.config import CMeIE_CONFIG
from utils.file_util import save, load
simplefilter(action='ignore', category=FutureWarning)
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
def get_args():
    """Parse command-line arguments for MHS-pointer SPO training.

    Also derives ``args.cache_data`` from the input dir, the bert model name
    and max_len, so cached pickles are invalidated when those change.
    """

    def _str2bool(value):
        # argparse's type=bool is a trap: bool('False') is True because any
        # non-empty string is truthy. Accept the usual textual spellings.
        if isinstance(value, bool):
            return value
        if value.lower() in ('true', 't', 'yes', 'y', '1'):
            return True
        if value.lower() in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    # file parameters
    parser.add_argument("--input", default=None, type=str, required=True)
    parser.add_argument("--res_path", default=None, type=str, required=False)
    parser.add_argument("--output", default=None, type=str, required=False,
                        help="The output directory where the model checkpoints and predictions will be written.")
    # choice parameters
    parser.add_argument('--spo_version', type=str, default="v1")
    # train parameters
    parser.add_argument('--train_mode', type=str, default="train")
    parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epoch_num", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--patience_stop', type=int, default=10, help='Patience for learning early stop')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--use_bert", action='store_true')
    parser.add_argument("--diff_lr", action='store_true')
    # bert parameters
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    # model parameters
    parser.add_argument("--max_len", default=1000, type=int)
    parser.add_argument('--entity_emb_size', type=int, default=300)
    parser.add_argument('--pos_limit', type=int, default=30)
    parser.add_argument('--pos_dim', type=int, default=300)
    parser.add_argument('--pos_size', type=int, default=62)
    parser.add_argument('--hidden_size', type=int, default=150)
    parser.add_argument('--bert_hidden_size', type=int, default=768)
    # was type=int with a float default: a CLI-supplied dropout would have been
    # truncated to an integer
    parser.add_argument('--dropout', type=float, default=0.5)
    # was type=bool: "--bidirectional False" used to evaluate to True
    parser.add_argument('--bidirectional', type=_str2bool, default=True)
    parser.add_argument('--pin_memory', type=_str2bool, default=False)
    parser.add_argument('--activation', type=str, default='tanh')
    parser.add_argument('--rel_emb_size', type=int, default=100)
    parser.add_argument('--ent_emb_size', type=int, default=100)
    args = parser.parse_args()
    # tag the cache dir with the model name ("dir/model" -> "model") and
    # max_len; guard against bert_model values without a slash, for which the
    # old split('/')[1] raised IndexError
    bert_name = str(args.bert_model)
    model_tag = bert_name.split('/')[1] if '/' in bert_name else bert_name
    args.cache_data = args.input + '/{}_mhs_cache_data_{}/'.format(model_tag, str(args.max_len))
    return args
def bulid_dataset(args, spo_config, reader, tokenizer, debug=False):
    """Build train/dev/test datasets and dataloaders, caching parsed examples.

    (The name "bulid" is a typo for "build"; kept because callers use it.)

    Parsed Example lists are pickled under ``args.cache_data`` on first run and
    reloaded afterwards — note the cache is only rebuilt when the *train* file
    is missing, so stale dev/test caches are possible if deleted individually.

    Args:
        args: parsed CLI namespace (input, cache_data, train_batch_size, ...).
        spo_config: predicate name -> label id mapping.
        reader: Reader used to parse the raw JSON-lines files.
        tokenizer: BERT tokenizer forwarded to the Feature converter.
        debug: when True, keep only 2 examples per split for a smoke run.

    Returns:
        (eval_examples, data_loaders, tokenizer) where eval_examples and
        data_loaders are (train, dev, test) triples.
    """
    # raw input files
    train_src = args.input + "/train_data.json"
    dev_src = args.input + "/val_data.json"
    test_src = args.input + "/test1.json"
    # pickled caches of the parsed Example lists
    train_examples_file = args.cache_data + "/train-examples.pkl"
    dev_examples_file = args.cache_data + "/dev-examples.pkl"
    test_examples_file = args.cache_data + "/test-examples.pkl"
    if not os.path.exists(train_examples_file):
        train_examples = reader.read_examples(train_src, data_type='train')
        dev_examples = reader.read_examples(dev_src, data_type='dev')
        test_examples = reader.read_examples(test_src, data_type='test')
        save(train_examples_file, train_examples, message="train examples")
        save(dev_examples_file, dev_examples, message="dev examples")
        save(test_examples_file, test_examples, message="test examples")
    else:
        logging.info('loading train cache_data {}'.format(train_examples_file))
        logging.info('loading dev cache_data {}'.format(dev_examples_file))
        logging.info('loading test cache_data {}'.format(test_examples_file))
        train_examples, dev_examples, test_examples = load(train_examples_file), load(dev_examples_file), load(
            test_examples_file)
    logging.info('train examples size is {}'.format(len(train_examples)))
    logging.info('dev examples size is {}'.format(len(dev_examples)))
    logging.info('test examples size is {}'.format(len(test_examples)))
    convert_examples_features = Feature(max_len=args.max_len, spo_config=spo_config, tokenizer=tokenizer)
    # debug mode keeps a tiny slice of each split
    train_examples = train_examples[:2] if debug else train_examples
    # train_examples = train_examples[:1000]
    dev_examples = dev_examples[:2] if debug else dev_examples
    test_examples = test_examples[:2] if debug else test_examples
    train_data_set = convert_examples_features(train_examples, data_type='train')
    dev_data_set = convert_examples_features(dev_examples, data_type='dev')
    test_data_set = convert_examples_features(test_examples, data_type='test')
    # only the training loader shuffles; all three share the same batch size
    train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory)
    dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size)
    test_data_loader = test_data_set.get_dataloader(args.train_batch_size)
    data_loaders = train_data_loader, dev_data_loader, test_data_loader
    eval_examples = train_examples, dev_examples, test_examples
    return eval_examples, data_loaders, tokenizer
def main():
    """CLI entry point: parse args, build datasets, then dispatch on train_mode."""
    args = get_args()
    # make sure the output and cache directories exist
    for directory in (args.output, args.cache_data):
        if not os.path.exists(directory):
            print('mkdir {}'.format(directory))
            os.makedirs(directory)
    # seed every RNG the pipeline relies on
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logger.info("** ** * bulid dataset ** ** * ")
    spo_conf = CMeIE_CONFIG
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
    reader = Reader(spo_conf, tokenizer, max_seq_length=args.max_len)
    eval_examples, data_loaders, tokenizer = bulid_dataset(args, spo_conf, reader, tokenizer, debug=args.debug)
    trainer = Trainer(args, data_loaders, eval_examples, spo_conf=spo_conf, tokenizer=tokenizer)
    mode = args.train_mode
    if mode == "train":
        trainer.train(args)
    elif mode == "eval":
        # trainer.resume(args)
        # trainer.eval_data_set("train")
        trainer.eval_data_set("dev")
    elif mode == "predict":
        trainer.predict_data_set("test")
    elif mode == "resume":
        # trainer.resume(args)
        trainer.show("dev")  # bad case analysis


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,91 @@
# _*_ coding:utf-8 _*_
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BertModel
from deepIE.config.config import CMeIE_CONFIG
warnings.filterwarnings("ignore")
class MHSNet(nn.Module):
    """
    MHSNet : entity relation extraction
    BERT encoder + pointer-style entity head + multi-head-selection relation
    scorer over all token pairs.
    """

    def __init__(self, args):
        super(MHSNet, self).__init__()
        # activation shared by the selection MLP layers
        if args.activation.lower() == 'relu':
            self.activation = nn.ReLU()
        elif args.activation.lower() == 'tanh':
            self.activation = nn.Tanh()
        # relation embeddings double as the bilinear scoring weights
        self.rel_emb = nn.Embedding(num_embeddings=len(CMeIE_CONFIG), embedding_dim=args.rel_emb_size)
        # embedding of the binary entity-start indicator fed to the scorer
        self.ent_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.ent_emb_size)
        self.bert = BertModel.from_pretrained(args.bert_model)
        # pointer head: per-token (start, end) logits
        self.ent_dense = nn.Linear(self.bert.config.hidden_size, 2)
        self.selection_u = nn.Linear(self.bert.config.hidden_size + args.ent_emb_size,
                                     args.rel_emb_size)
        self.selection_v = nn.Linear(self.bert.config.hidden_size + args.ent_emb_size,
                                     args.rel_emb_size)
        self.selection_uv = nn.Linear(2 * args.rel_emb_size,
                                      args.rel_emb_size)
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, passage_ids=None, segment_ids=None, ent_ids=None, rel_ids=None,
                is_eval=False):
        """Compute losses (train) or return eval tensors for decoding.

        Returns:
            train: (total_loss, ent_loss, selection_loss) scalars.
            eval:  currently the gold (ent_ids, rel_ids) — see NOTE below.
        """
        bert_encoder = self.bert(passage_ids, token_type_ids=segment_ids, attention_mask=(passage_ids != 0).float())
        bert_encoder = bert_encoder[0]  # last hidden states: (B, L, H)
        ent_pre = self.ent_dense(bert_encoder)
        mask = passage_ids != 0
        if is_eval:
            # predicted start indicator per token (sigmoid > 0.5 on channel 0)
            ent_label_ids = (torch.sigmoid(ent_pre) > .5)[:, :, 0].long()
        else:
            # teacher forcing with the gold start indicator; .long() replaces
            # the old torch.tensor(tensor, ...) copy (also removed a stray
            # debug print of the tensor's device)
            ent_label_ids = ent_ids[:, :, 0].long()
        ent_encoder = self.ent_emb(ent_label_ids)
        rel_encoder = torch.cat((bert_encoder, ent_encoder), dim=2)
        B, L, H = rel_encoder.size()
        u = self.activation(self.selection_u(rel_encoder)).unsqueeze(1).expand(B, L, L, -1)
        v = self.activation(self.selection_v(rel_encoder)).unsqueeze(2).expand(B, L, L, -1)
        uv = self.activation(self.selection_uv(torch.cat((u, v), dim=-1)))
        # selection_logits[b, i, r, j]: score of relation r between tokens i, j
        selection_logits = torch.einsum('bijh,rh->birj', uv, self.rel_emb.weight)
        if is_eval:
            # NOTE(review): this returns the *gold* labels instead of the
            # predictions (see the commented line) — looks like a debugging
            # shortcut. Kept as-is because Trainer currently decodes these;
            # confirm whether predictions were intended.
            # return ent_pre, selection_logits
            return ent_ids, rel_ids
        else:
            ent_loss = self.loss_fct(ent_pre, ent_ids)
            ent_loss = ent_loss.mean(2)
            # average only over real (non-pad) tokens
            ent_loss = torch.sum(ent_loss * mask.float()) / torch.sum(mask.float())
            selection_loss = self.masked_BCEloss(mask, selection_logits, rel_ids)
            # selection loss is heavily up-weighted
            loss = ent_loss + 100 * selection_loss
            return loss, ent_loss, selection_loss

    def masked_BCEloss(self, mask, selection_logits, selection_gold):
        """BCE-with-logits restricted to (non-pad i, non-pad j) positions."""
        # batch x seq x rel x seq
        selection_mask = (mask.unsqueeze(2) *
                          mask.unsqueeze(1)).unsqueeze(2).expand(-1, -1, len(CMeIE_CONFIG), -1)
        selection_loss = F.binary_cross_entropy_with_logits(selection_logits,
                                                            selection_gold,
                                                            reduction='none')
        selection_loss = selection_loss.masked_select(selection_mask).sum()
        selection_loss /= mask.sum()
        return selection_loss

View File

@ -0,0 +1,97 @@
# _*_ coding:utf-8 _*_
import numpy as np
import torch
from deepIE.config.config import CMeIE_CONFIG, Ent_BIO
reversed_relation_vocab = {v: k for k, v in CMeIE_CONFIG.items()}
reversed_bio_vocab = {v: k for k, v in Ent_BIO.items()}
def find_tag_position(find_list, seq_len, text):
    """Extract text spans from a numeric tag sequence.

    A span starts where the tag equals 1 and extends over the following
    positions whose tag equals ``find_list[start] + 1`` (i.e. the matching
    continuation tag); it is cut at the first position that breaks the pattern.
    Returns the list of covered substrings of *text*.
    """
    spans = []
    j = 0
    while j < seq_len:
        if find_list[j] != 1:
            j += 1
            continue
        start = j
        # default: the span runs to the end of the sequence
        end = seq_len - 1
        for k in range(start + 1, seq_len):
            if find_list[k] != find_list[start] + 1:
                end = k - 1
                break
        spans.append(text[start:end + 1])
        j = end + 1
    return spans
def find_entity(pos, text, sequence_tags):
    """Recover the entity string that *ends* at position ``pos``.

    Only positions that close an entity produce output: a lone 'B' tag yields
    a single character, and an 'I' tag that is not continued triggers a walk
    left to the opening 'B'. Any other position returns ''.
    """
    if pos >= len(text):
        return ''
    tag = sequence_tags[pos]
    at_last = pos == len(text) - 1
    # single-character entity: a B tag that is not continued
    if tag == 'B' and (at_last or sequence_tags[pos + 1] == 'O'):
        return text[pos]
    # multi-character entity: an I tag closing the span — walk back to its B
    if tag == 'I' and (at_last or sequence_tags[pos + 1] == 'O'):
        collected = []
        i = pos
        while i >= 0 and sequence_tags[i] == 'I':
            collected.append(text[i])
            i -= 1
        if i >= 0 and sequence_tags[i] == 'B':
            collected.append(text[i])
        return ''.join(reversed(collected))
    return ''
def selection_decode(q_ids, eval_file, ent_pre, rel_pre):
    """Decode entity and relation tensors back into text spans and SPO triples.

    Args:
        q_ids: tensor of example ids indexing into ``eval_file``.
        eval_file: Example collection providing ``bert_tokens`` and ``context``.
        ent_pre: (batch, seq, 2) entity start/end scores, thresholded at 0.5.
        rel_pre: (batch, seq, rel, seq) boolean relation selection tensor.

    Returns:
        {q_id: (entity_list, triple_list)} for every example in the batch.
    """
    ent_list = list()
    ent_start_list = list()
    for qid, sub_pred in zip(q_ids.cpu().numpy(), ent_pre.cpu().numpy()):
        context = eval_file[qid].bert_tokens
        raw_text = eval_file[qid].context
        # positions whose start / end score exceeds 0.5
        start = np.where(sub_pred[:, 0] > 0.5)[0]
        end = np.where(sub_pred[:, 1] > 0.5)[0]
        ents = []
        ent_start = {}
        for i in start:
            # candidate ends at or after this start; take the earliest
            j = end[end >= i]
            # skip [CLS] (position 0) and anything past the text ([SEP]/pad)
            if i == 0 or i > len(context) - 2:
                continue
            if len(j) > 0:
                j = j[0]
                if j > len(context) - 2:
                    continue
                # token position i maps to raw char i-1 ([CLS] offset), so the
                # span [i, j] in token space is raw_text[i-1 : j]
                ent_name = raw_text[i - 1:j]
                ent_start[i] = ent_name
                ents.append(ent_name)
        ent_list.append(ents)
        ent_start_list.append(ent_start)
    batch_num = len(rel_pre)
    result = [[] for _ in range(batch_num)]
    # all (batch, i, predicate, j) positions flagged true
    idx = torch.nonzero(rel_pre.cpu())
    answer_dict = dict()
    for i in range(idx.size(0)):
        b, s, p, o = idx[i].tolist()
        predicate = reversed_relation_vocab[p]
        # NOTE(review): select_padding writes targets as
        # [subject_pos, predicate, object_pos], yet here the s-axis is decoded
        # as the *object* and the o-axis as the *subject* — confirm whether
        # this swap is intentional.
        object = ent_start_list[b].get(s, '')
        if object == '':
            continue
        subject = ent_start_list[b].get(o, '')
        if subject == '':
            continue
        result[b].append((subject, predicate, object))
    for q_id, res, ent_ in zip(q_ids.cpu(), result, ent_list):
        answer_dict[q_id.item()] = (ent_, res)
    return answer_dict

View File

@ -0,0 +1,308 @@
# _*_ coding:utf-8 _*_
import logging
import random
import sys
import time
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from deepIE.chip_rel.spo_mhs_pointer.mhs_pointer import MHSNet
from deepIE.chip_rel.spo_mhs_pointer.select_pointer_decoder import selection_decode
from deepIE.config.config import CMeIE_CONFIG
from layers.encoders.transformers.bert.bert_optimization import BertAdam
logger = logging.getLogger(__name__)
class Trainer(object):
    """Training / evaluation driver for the MHS pointer SPO model."""

    def __init__(self, args, data_loaders, examples, spo_conf, tokenizer):
        """Build the model, move it to the device and prepare the optimizer.

        Args:
            args: parsed CLI namespace.
            data_loaders: (train, dev, test) DataLoader triple.
            examples: (train, dev, test) Example lists used for decoding/metrics.
            spo_conf: predicate name -> label id mapping.
            tokenizer: BERT tokenizer (kept for decoding helpers).
        """
        self.args = args
        self.max_len = args.max_len
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()
        # id <-> predicate mappings
        self.id2rel = {item: key for key, item in spo_conf.items()}
        self.rel2id = spo_conf
        # fix all RNG seeds for reproducibility
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
        # NOTE(review): the model is only constructed when --use_bert is set;
        # without it the following .to() raises AttributeError — confirm intended.
        if args.use_bert:
            self.model = MHSNet(args)
        self.model.to(self.device)
        if self.n_gpu > 1:
            logging.info('total gpu num is {}'.format(self.n_gpu))
            # device ids are hard-coded to the first two GPUs
            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
        train_dataloader, dev_dataloader, test_dataloader = data_loaders
        train_eval, dev_eval, test_eval = examples
        self.eval_file_choice = {
            "train": train_eval,
            "dev": dev_eval,
            "test": test_eval
        }
        self.data_loader_choice = {
            "train": train_dataloader,
            "dev": dev_dataloader,
            "test": test_dataloader
        }
        # total optimizer steps = batches-per-epoch * epochs (drives warmup)
        self.optimizer = self.set_optimizer(args, self.model,
                                            train_steps=(int(
                                                len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)
    def set_optimizer(self, args, model, train_steps=None):
        """Build a BertAdam optimizer with linear warmup.

        With --diff_lr, parameters whose names start with the CRF prefix get
        10x the base learning rate (legacy of a CRF variant of this model);
        otherwise two standard weight-decay groups are used.

        Args:
            args: CLI namespace (learning_rate, warmup_proportion, diff_lr).
            model: the (possibly DataParallel-wrapped) network.
            train_steps: total optimizer steps, for the warmup schedule.
        """
        param_optimizer = list(model.named_parameters())
        # drop BERT's pooler, which this model never uses
        param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        # DataParallel prefixes parameter names with 'module.'
        starts_flag = 'module.bert' if self.n_gpu > 1 else 'bert'
        starts_crf = 'module.crf' if self.n_gpu > 1 else 'crf'
        # TODO: tune the per-group learning-rate setup
        if args.diff_lr:
            # dead alternative kept for reference: BERT vs non-BERT grouping
            # logging.info('设置不同学习率')
            # for n, p in param_optimizer:
            #     if not n.startswith(starts_flag) and not any(nd in n for nd in no_decay):
            #         print(n)
            # print('+' * 10)
            # for n, p in param_optimizer:
            #     if not n.startswith(starts_flag) and any(nd in n for nd in no_decay):
            #         print(n)
            # optimizer_grouped_parameters = [
            #     {'params': [p for n, p in param_optimizer if
            #                 not any(nd in n for nd in no_decay) and n.startswith(starts_flag)],
            #      'weight_decay': 0.01, 'lr': args.learning_rate},
            #     {'params': [p for n, p in param_optimizer if
            #                 not any(nd in n for nd in no_decay) and not n.startswith(starts_flag)],
            #      'weight_decay': 0.01, 'lr': args.learning_rate * 20},
            #     {'params': [p for n, p in param_optimizer if
            #                 any(nd in n for nd in no_decay) and n.startswith(starts_flag)],
            #      'weight_decay': 0.0, 'lr': args.learning_rate},
            #     {'params': [p for n, p in param_optimizer if
            #                 any(nd in n for nd in no_decay) and not n.startswith(starts_flag)],
            #      'weight_decay': 0.0, 'lr': args.learning_rate * 20}
            # ]
            logging.info('CRF层设置不同学习率')
            # debug dump of the CRF-prefixed parameter names per group
            for n, p in param_optimizer:
                if n.startswith(starts_crf) and not any(nd in n for nd in no_decay):
                    print(n)
            print('+' * 10)
            for n, p in param_optimizer:
                if n.startswith(starts_crf) and any(nd in n for nd in no_decay):
                    print(n)
            # four groups: (decay / no-decay) x (CRF 10x lr / base lr)
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if
                            not any(nd in n for nd in no_decay) and n.startswith(starts_crf)],
                 'weight_decay': 0.01, 'lr': args.learning_rate * 10},
                {'params': [p for n, p in param_optimizer if
                            not any(nd in n for nd in no_decay) and not n.startswith(starts_crf)],
                 'weight_decay': 0.01, 'lr': args.learning_rate},
                {'params': [p for n, p in param_optimizer if
                            any(nd in n for nd in no_decay) and n.startswith(starts_crf)],
                 'weight_decay': 0.0, 'lr': args.learning_rate * 10},
                {'params': [p for n, p in param_optimizer if
                            any(nd in n for nd in no_decay) and not n.startswith(starts_crf)],
                 'weight_decay': 0.0, 'lr': args.learning_rate}
            ]
        else:
            logging.info('原始设置学习率设置')
            # TODO: original two-group setup (decay vs no-decay), single lr
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                 'weight_decay': 0.01},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=train_steps)
        return optimizer
def train(self, args):
best_f1 = 0.0
patience_stop = 0
self.model.train()
step_gap = 20
for epoch in range(int(args.epoch_num)):
global_loss, global_crf_loss, global_selection_loss = 0.0, 0.0, 0.0
for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5,
desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):
loss, crf_loss, selection_loss = self.forward(batch)
if step % step_gap == 0:
global_loss += loss
global_crf_loss += crf_loss
global_selection_loss += selection_loss
current_loss = global_loss / step_gap
current_crf_loss = global_crf_loss / step_gap
current_selection_loss = global_selection_loss / step_gap
print(
u"step {} / {} of epoch {}, train/loss: {}\tcrf:{}\tsel:{}".format(step, len(
self.data_loader_choice["train"]),
epoch,
current_loss,
current_crf_loss,
current_selection_loss))
global_loss, global_crf_loss, global_selection_loss = 0.0, 0.0, 0.0
res_dev = self.eval_data_set("dev")
if res_dev['f1'] >= best_f1:
best_f1 = res_dev['f1']
logging.info("** ** * Saving fine-tuned model ** ** * ")
model_to_save = self.model.module if hasattr(self.model,
'module') else self.model # Only save the model it-self
output_model_file = args.output + "/pytorch_model.bin"
torch.save(model_to_save.state_dict(), str(output_model_file))
patience_stop = 0
else:
patience_stop += 1
if patience_stop >= args.patience_stop:
return
def resume(self, args):
resume_model_file = args.output + "/pytorch_model.bin"
logging.info("=> loading checkpoint '{}'".format(resume_model_file))
checkpoint = torch.load(resume_model_file, map_location='cpu')
self.model.load_state_dict(checkpoint)
    def forward(self, batch, chosen=u'train', eval=False, answer_dict=None):
        """Run one batch through the model.

        Train mode (eval=False): performs backward + optimizer step, returns
        the three scalar losses. Eval mode: returns the decoded answer dict.

        NOTE: the parameter name ``eval`` shadows the builtin, but callers pass
        it by keyword, so it is kept for compatibility. ``answer_dict`` is
        accepted but never read — the result is rebuilt per batch.
        """
        batch = tuple(t.to(self.device) for t in batch)
        if not eval:
            # train batches carry no p_ids (see the collate function)
            passage_ids, segment_ids, ent_ids, rel_ids = batch
            loss, crf_loss, selection_loss = self.model(passage_ids=passage_ids, segment_ids=segment_ids,
                                                        ent_ids=ent_ids, rel_ids=rel_ids)
            if self.n_gpu > 1:
                loss = loss.mean()
                crf_loss = crf_loss.mean()
                selection_loss = selection_loss.mean()  # mean() to average on multi-gpu.
            loss.backward()
            # convert to Python floats only after backward
            loss = loss.item()
            crf_loss = crf_loss.item()
            selection_loss = selection_loss.item()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss, crf_loss, selection_loss
        else:
            # eval batches are prefixed with the example ids
            p_ids, passage_ids, segment_ids, ent_ids, rel_ids = batch
            eval_file = self.eval_file_choice[chosen]
            ent_logits, rel_logits = self.model(passage_ids=passage_ids, segment_ids=segment_ids, ent_ids=ent_ids,
                                                rel_ids=rel_ids, is_eval=eval)
            answer_dict = self.convert_select_contour(eval_file, p_ids, passage_ids, ent_logits, rel_logits)
            # p_ids, passage_ids, segment_ids , ent_ids, rel_ids = batch
            # eval_file = self.eval_file_choice[chosen]
            # answer_dict = convert_select_contour(eval_file, p_ids, passage_ids, ent_ids, rel_ids)
            return answer_dict
    def eval_data_set(self, chosen="dev"):
        """Evaluate the model on the *chosen* split and return its metric dict."""
        self.model.eval()
        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        answer_dict = {}
        last_time = time.time()
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                # forward(eval=True) returns {q_id: (entities, triples)} per batch
                answer_dict_ = self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
                answer_dict.update(answer_dict_)
        used_time = time.time() - last_time
        logging.info('chosen {} took : {} sec'.format(chosen, used_time))
        res = self.evaluate(eval_file, answer_dict, chosen)
        # restore training mode for the caller (train() evaluates each epoch)
        self.model.train()
        return res
def show(self, chosen="dev"):
self.model.eval()
answer_dict = {}
data_loader = self.data_loader_choice[chosen]
eval_file = self.eval_file_choice[chosen]
with torch.no_grad():
for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
loss, answer_dict_ = self.forward(batch, chosen, eval=True)
answer_dict.update(answer_dict_)
self.badcase_analysis(eval_file, answer_dict, chosen)
@staticmethod
def evaluate(eval_file, answer_dict, chosen):
entity_em = 0
entity_pred_num = 0
entity_gold_num = 0
X, Y, Z = 1e-10, 1e-10, 1e-10
for key, value in answer_dict.items():
triple_gold = eval_file[key].gold_rel
entity_gold = eval_file[key].gold_ent
context = eval_file[key].context
entity_pred, triple_pred = value
entity_em += len(set(entity_pred) & set(entity_gold))
entity_pred_num += len(set(entity_pred))
entity_gold_num += len(set(entity_gold))
# if set(entity_pred) != set(entity_gold):
# print(set(entity_pred))
# print(set(entity_gold))
# print(context)
# print()
R = set([spo for spo in triple_pred])
T = set([spo for spo in triple_gold])
# if R != T:
# print(R)
# print(T)
X += len(R & T)
Y += len(R)
Z += len(T)
f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0.
entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0.
entity_f1 = 2 * entity_recall * entity_precision / (entity_recall + entity_precision) if (
entity_recall + entity_precision) != 0 else 0.0
print('============================================')
print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".format(chosen, entity_em, entity_pred_num,
entity_gold_num))
print(
"{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".format(chosen, entity_f1, entity_precision,
entity_recall))
print('============================================')
print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, X, Y, Z))
print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1 * 100, precision * 100,
recall * 100))
return {'f1': f1, "recall": recall, "precision": precision}
    def convert_select_contour(self, eval_file, q_ids, input_ids, ent_logit, rel_logit):
        """Decode batch model outputs into {q_id: (entities, triples)}.

        NOTE(review): with MHSNet.forward(is_eval=True) currently returning the
        gold ent_ids/rel_ids, ``ent_logit``/``rel_logit`` here are 0/1 label
        tensors, so the ``> 0`` threshold acts on labels rather than logits —
        confirm this is intended.
        """
        mask = input_ids != 0
        # validity mask over token pairs (currently unused — see the commented
        # thresholding line below)
        selection_mask = (mask.unsqueeze(2) *
                          mask.unsqueeze(1)).unsqueeze(2).expand(
            -1, -1, len(CMeIE_CONFIG), -1)  # batch x seq x rel x seq
        # rel_pre = torch.sigmoid(rel_logit) * selection_mask.float() > 0.5
        rel_pre = rel_logit > 0
        ent_pre = ent_logit
        answer_dict = selection_decode(q_ids, eval_file, ent_pre, rel_pre)
        return answer_dict