first commit
This commit is contained in:
parent
cbfdab5e92
commit
63c1f7f368
3
.gitignore
vendored
3
.gitignore
vendored
@ -31,4 +31,5 @@ Pipfile
|
||||
*.bak
|
||||
|
||||
|
||||
|
||||
transformer_cpt/
|
||||
transformer_cpt_en/
|
@ -1,258 +0,0 @@
|
||||
"""
|
||||
实体抽取,按照字切分
|
||||
"""
|
||||
import codecs
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
from utils import extract_chinese_and_punct
|
||||
from utils.data_util import sequence_padding
|
||||
|
||||
chineseandpunctuationextractor = extract_chinese_and_punct.ChineseAndPunctuationExtractor()
|
||||
|
||||
|
||||
class Example(object):
    """One text chunk (possibly a slice of a longer text) with its entities.

    Every constructor argument defaults to ``None`` and is stored unchanged
    on the instance under the same name.
    """

    def __init__(self,
                 p_id=None,
                 text_id=None,
                 g_raw_text=None,
                 context=None,
                 tok_to_orig_start_index=None,
                 tok_to_orig_end_index=None,
                 bert_tokens=None,
                 l_gold_ent=None,
                 g_gold_ent=None,
                 is_split=None,
                 span_index=None,
                 po_list=None):
        # Index of this chunk (after splitting) and of the original text.
        self.p_id = p_id
        self.text_id = text_id

        # Text of this chunk, and the full text before splitting.
        self.context = context
        self.g_raw_text = g_raw_text

        self.tok_to_orig_start_index = tok_to_orig_start_index
        self.tok_to_orig_end_index = tok_to_orig_end_index
        self.bert_tokens = bert_tokens

        # Gold entities local to this chunk, and those of the whole text.
        self.l_gold_ent = l_gold_ent
        self.g_gold_ent = g_gold_ent

        self.is_split = is_split
        self.span_index = span_index
        self.po_list = po_list
|
||||
|
||||
|
||||
class InputFeature(object):
    """Plain container for the per-example feature fields.

    Every constructor argument defaults to ``None`` and is stored unchanged
    on the instance under the same name.
    """

    def __init__(self,
                 p_id=None,
                 passage_id=None,
                 token_type_id=None,
                 pos_start_id=None,
                 pos_end_id=None,
                 segment_id=None,
                 po_label=None):
        self.p_id = p_id
        self.passage_id = passage_id
        self.token_type_id = token_type_id

        self.pos_start_id = pos_start_id
        self.pos_end_id = pos_end_id

        self.segment_id = segment_id
        self.po_label = po_label
|
||||
|
||||
|
||||
class Reader(object):
    """Parse ``|||``-separated entity files into character-level examples.

    Each input line has the form ``text|||start end type|||start end type|||``
    where ``start``/``end`` are inclusive character offsets into ``text``.
    """

    def __init__(self, spo_conf, tokenizer=None, max_seq_length=None):
        """
        :param spo_conf: mapping from entity-type name to its integer label.
        :param tokenizer: stored on the instance; not used by the reader.
        :param max_seq_length: sequence budget including the ``[CLS]`` and
            ``[SEP]`` markers; two positions are reserved for them.
        """
        self.spo_conf = spo_conf
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length - 2

    def read_examples(self, filename, data_type):
        """Read ``filename`` and return the list of ``Example`` objects."""
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def split_text(self, text):
        """Cut ``text`` into consecutive windows of ``self.max_seq_length``.

        A text whose length is an exact multiple of the window no longer
        yields a trailing empty chunk. An empty text still yields one (empty)
        chunk so that every input line produces at least one example.
        """
        window = self.max_seq_length
        chunks = [text[begin:begin + window] for begin in range(0, len(text), window)]
        return chunks or [text]

    def _read(self, filename, data_type):
        """Build the examples for every line of ``filename``.

        Entities that do not lie entirely inside a single chunk are dropped
        from that chunk's local labels (they remain in ``g_gold_ent``).
        """
        examples = []
        before_text_num = 0
        after_ent_num = 0
        before_ent_num = 0
        p_id = 0
        window = self.max_seq_length

        # The built-in open() is used with an explicit encoding: the data is
        # Chinese text and must not depend on the platform's locale, while
        # lines are still split on newline characters only.
        with open(filename, 'r', encoding='utf-8') as fr:
            for text_id, line in enumerate(fr):
                before_text_num += 1

                fields = line.strip().split('|||')
                raw_text = fields[0]

                parsed = []
                for field in fields[1:]:
                    if field == '':
                        continue
                    start, end, ent_type = field.split()
                    start, end = int(start), int(end)
                    parsed.append((start, end, raw_text[start:end + 1], ent_type))
                # De-duplicate while keeping first-seen order, so the entity
                # order is reproducible across runs.
                ent_lst = list(dict.fromkeys(parsed))
                before_ent_num += len(ent_lst)

                text_lst = self.split_text(raw_text)
                is_split = len(text_lst) > 1

                for i, text in enumerate(text_lst):
                    offset = i * window
                    tokens = ["[CLS]"] + [c.lower() for c in text] + ["[SEP]"]

                    l_gold_ent = []
                    po_list = []
                    for (start, end, ent_name_, ent_type) in ent_lst:
                        inside = (offset <= start < offset + window
                                  and offset <= end < offset + window)
                        if not inside:
                            continue
                        local_start, local_end = start - offset, end - offset
                        ent_name = text[local_start:local_end + 1]
                        if ent_name == '':
                            logging.warning(
                                'empty entity span ({}, {}) in text {}'.format(start, end, text_id))
                        # Internal sanity check: the chunk-local slice must
                        # reproduce the slice taken from the full text.
                        assert ent_name == ent_name_

                        po_list.append((local_start, local_end, self.spo_conf[ent_type]))
                        l_gold_ent.append((local_start, local_end, ent_name, ent_type))
                    after_ent_num += len(l_gold_ent)

                    examples.append(
                        Example(
                            p_id=p_id,
                            text_id=text_id,
                            g_raw_text=raw_text,
                            context=text,
                            g_gold_ent=ent_lst,
                            l_gold_ent=l_gold_ent,
                            is_split=is_split,
                            span_index=i if is_split else -1,
                            bert_tokens=tokens,
                            po_list=po_list,
                        ))
                    p_id += 1

        logging.info('total size before split in {} is {}'.format(data_type, before_text_num))
        logging.info('total size after split in {} is {}'.format(data_type, len(examples)))
        logging.info('after_ent_num in {} is {}'.format(data_type, after_ent_num))
        logging.info('before_ent_num in {} is {}'.format(data_type, before_ent_num))
        logging.info("{} total size is {} ".format(data_type, len(examples)))
        logging.info("=" * 15)
        return examples
|
||||
|
||||
|
||||
class Feature(object):
    """Callable that wraps a list of examples into an ``SPODataset``."""

    def __init__(self, max_len, spo_config, tokenizer):
        self.max_len = max_len
        self.spo_config = spo_config
        self.tokenizer = tokenizer

    def __call__(self, examples, data_type):
        """Shorthand for :meth:`convert_examples_to_bert_features`."""
        return self.convert_examples_to_bert_features(examples, data_type)

    def convert_examples_to_bert_features(self, examples, data_type):
        """Pair every example with its position and build the dataset.

        :param examples: iterable of example objects.
        :param data_type: split name, forwarded to ``SPODataset``.
        :return: an ``SPODataset`` over ``(index, example)`` pairs.
        """
        logging.info("convert {} examples to features .".format(data_type))

        indexed_examples = list(enumerate(examples))

        logging.info("Built instances is Completed")
        return SPODataset(
            indexed_examples,
            spo_config=self.spo_config,
            data_type=data_type,
            tokenizer=self.tokenizer,
            max_len=self.max_len,
        )
|
||||
|
||||
|
||||
class SPODataset(Dataset):
    """Dataset over ``(index, example)`` pairs with a batching collate function."""

    def __init__(self, data, spo_config, data_type, tokenizer=None, max_len=128):
        """
        :param data: sequence of ``(index, example)`` pairs.
        :param spo_config: mapping of entity types; its size is the number of
            span classes.
        :param data_type: split name; only ``'train'`` switches on the
            training batch layout (no ids in the batch).
        :param tokenizer: object whose ``encode`` turns a token list into ids.
        :param max_len: stored on the instance.
        """
        super(SPODataset, self).__init__()
        self.spo_config = spo_config
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.q_ids = [pair[0] for pair in data]
        self.features = [pair[1] for pair in data]
        self.is_train = data_type == 'train'

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, index):
        return self.q_ids[index], self.features[index]

    def _create_collate_fn(self):
        """Return the function that turns a list of items into one batch."""

        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor(list(p_ids), dtype=torch.long)

            batch_token_ids, batch_segment_ids = [], []
            batch_token_type_ids, batch_point_labels, batch_span_labels = [], [], []
            for example in examples:
                # bert_tokens already carries [CLS]/[SEP]; drop the extra
                # first/last ids that encode() returns around them.
                token_ids = self.tokenizer.encode(example.bert_tokens)[1:-1]
                # np.long / np.int were removed from NumPy; use an explicit
                # fixed-width integer type instead.
                token_type_ids = np.zeros(len(token_ids), dtype=np.int64)
                segment_ids = len(token_ids) * [0]

                batch_token_ids.append(token_ids)
                batch_token_type_ids.append(token_type_ids)
                batch_segment_ids.append(segment_ids)

                batch_span_labels.append(example.po_list)

                # Column 0 marks entity starts, column 1 entity ends; the +1
                # shifts text offsets past the leading [CLS] token.
                point_labels = np.zeros((len(token_ids), 2), dtype=np.int64)
                for (s, e, p) in example.po_list:
                    point_labels[s + 1, 0] = 1
                    point_labels[e + 1, 1] = 1
                batch_point_labels.append(point_labels)

            batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
            batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
            batch_segment_ids = sequence_padding(batch_segment_ids, is_float=False)

            batch_point_labels = sequence_padding(batch_point_labels, padding=np.zeros(2), is_float=False)
            batch_span_labels = span_padding(batch_token_ids, batch_span_labels, is_float=True,
                                             class_num=len(self.spo_config))
            if not self.is_train:
                # Evaluation batches also carry the example ids.
                return (p_ids, batch_token_ids, batch_token_type_ids, batch_segment_ids,
                        batch_point_labels, batch_span_labels)
            return (batch_token_ids, batch_token_type_ids, batch_segment_ids,
                    batch_point_labels, batch_span_labels)

        return collate

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, pin_memory=False,
                       drop_last=False):
        """Wrap this dataset in a ``DataLoader`` that uses the collate function."""
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
|
||||
|
||||
|
||||
def span_padding(seqs, span_lable, is_float=False, class_num=None):
    """Build the dense span-label tensor for a batch.

    :param seqs: batch of sequences; only their lengths are used.
    :param span_lable: per sequence, an iterable of ``(start, end, label)``
        triples.
    :param is_float: produce a float tensor when true, a long tensor otherwise.
    :param class_num: size of the last (label) dimension.
    :return: tensor of shape ``(batch, max_len, max_len, class_num)`` that is
        zero everywhere except ``[i, start + 1, end + 1, label] = 1``.
    """
    longest = max(len(seq) for seq in seqs)
    dtype = torch.float32 if is_float else torch.long
    span_tensor = torch.zeros((len(seqs), longest, longest, class_num), dtype=dtype)

    for row, triples in enumerate(span_lable):
        for triple in triples:
            start, end, label = triple[0], triple[1], triple[2]
            # Offsets are shifted by one position before indexing.
            span_tensor[row, start + 1, end + 1, label] = 1

    return span_tensor
|
@ -1,159 +0,0 @@
|
||||
# _*_ coding:utf-8 _*_
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from warnings import simplefilter
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from deepIE.chip_ent.ent_stacked_span.data_loader_char import Reader, Feature
|
||||
from deepIE.chip_ent.ent_stacked_span.train import Trainer
|
||||
from deepIE.config.config import CMeEnt_CONFIG
|
||||
from utils.file_util import save, load
|
||||
|
||||
simplefilter(action='ignore', category=FutureWarning)
|
||||
|
||||
logger = logging.getLogger()
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def get_args(argv=None):
    """Parse the command-line options of the training script.

    :param argv: optional list of argument strings; ``None`` (the default)
        makes argparse read ``sys.argv`` as before.
    :return: the parsed namespace, extended with ``cache_data`` (the cache
        directory derived from ``--input``, ``--bert_model`` and ``--max_len``).
    """

    def _str2bool(value):
        # ``type=bool`` would turn every non-empty string, including
        # "False", into True; parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()

    # file parameters
    parser.add_argument("--input", default=None, type=str, required=True)
    parser.add_argument("--res_path", default=None, type=str, required=False)
    parser.add_argument("--output", default=None, type=str, required=False,
                        help="The output directory where the model checkpoints and predictions will be written.")

    # choice parameters
    parser.add_argument('--spo_version', type=str, default="v1")

    # train parameters
    parser.add_argument('--train_mode', type=str, default="train")
    parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epoch_num", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--patience_stop', type=int, default=10, help='Patience for learning early stop')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")

    parser.add_argument("--debug", action='store_true')
    parser.add_argument("--diff_lr", action='store_true')

    # bert parameters
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--warmup_proportion", default=0.04, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")

    # model parameters
    parser.add_argument("--max_len", default=1000, type=int)
    parser.add_argument('--entity_emb_size', type=int, default=300)
    parser.add_argument('--pos_limit', type=int, default=30)
    parser.add_argument('--pos_dim', type=int, default=300)
    parser.add_argument('--pos_size', type=int, default=62)

    parser.add_argument('--hidden_size', type=int, default=150)
    parser.add_argument('--bert_hidden_size', type=int, default=768)
    # A dropout rate is a fraction; ``type=int`` rejected values like 0.3.
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--bidirectional', type=_str2bool, default=True)
    parser.add_argument('--pin_memory', type=_str2bool, default=False)
    parser.add_argument('--encoder_type', type=str, default=None)
    args = parser.parse_args(argv)

    # The cache directory is named after the second path component of the
    # model path (unchanged for "dir/model" style values); a model name
    # without any "/" is used as-is instead of raising IndexError.
    model_parts = str(args.bert_model).split('/')
    model_name = model_parts[1] if len(model_parts) > 1 else model_parts[0]
    args.cache_data = args.input + '/{}_stacked_span_cache_data_{}/'.format(model_name, str(args.max_len))
    return args
|
||||
|
||||
|
||||
def bulid_dataset(args, spo_config, reader, tokenizer, debug=False):
    """Read the three data splits, cache them and build their data loaders.

    :param args: parsed options; ``input``, ``cache_data``, ``max_len``,
        ``train_batch_size`` and ``pin_memory`` are read.
    :param spo_config: entity-type configuration passed to ``Feature``.
    :param reader: object whose ``read_examples(path, data_type=...)``
        returns the examples of one split.
    :param tokenizer: passed to ``Feature`` and returned unchanged.
    :param debug: keep only the first three examples of every split.
    :return: ``((train, dev, test) examples, (train, dev, test) loaders,
        tokenizer)``.
    """
    splits = ('train', 'dev', 'test')
    sources = (
        args.input + "/train_data.txt",
        args.input + "/val_data.txt",
        args.input + "/test1.txt",
    )

    # NOTE: the examples are always re-read from the source files; the
    # pickles written below are not loaded back by this function.
    examples = [reader.read_examples(src, data_type=split) for split, src in zip(splits, sources)]
    for split, split_examples in zip(splits, examples):
        save(args.cache_data + "/{}-examples.pkl".format(split), split_examples,
             message="{} examples".format(split))

    convert_examples_features = Feature(max_len=args.max_len, spo_config=spo_config, tokenizer=tokenizer)

    if debug:
        examples = [split_examples[:3] for split_examples in examples]
    train_examples, dev_examples, test_examples = examples

    train_data_set, dev_data_set, test_data_set = [
        convert_examples_features(split_examples, data_type=split)
        for split, split_examples in zip(splits, examples)
    ]

    # Only the training loader shuffles and honours ``pin_memory``.
    data_loaders = (
        train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory),
        dev_data_set.get_dataloader(args.train_batch_size),
        test_data_set.get_dataloader(args.train_batch_size),
    )
    eval_examples = (train_examples, dev_examples, test_examples)

    return eval_examples, data_loaders, tokenizer
|
||||
|
||||
|
||||
def main():
    """Script entry point: prepare directories, data and trainer, then run
    the action selected by ``--train_mode`` (``train``, ``eval`` or
    ``predict``; any other value runs nothing after the setup)."""
    args = get_args()

    # Create the output and cache directories when they are missing.
    for directory in (args.output, args.cache_data):
        if not os.path.exists(directory):
            print('mkdir {}'.format(directory))
            os.makedirs(directory)

    # Seed every random number generator used by the pipeline.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    logger.info("** ** * bulid dataset ** ** * ")

    spo_conf = CMeEnt_CONFIG
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
    reader = Reader(spo_conf, tokenizer, max_seq_length=args.max_len)
    eval_examples, data_loaders, tokenizer = bulid_dataset(
        args, spo_conf, reader, tokenizer, debug=args.debug)
    trainer = Trainer(args, data_loaders, eval_examples, spo_conf=spo_conf, tokenizer=tokenizer)

    mode = args.train_mode
    if mode == "train":
        trainer.train(args)
    elif mode == "eval":
        trainer.eval_data_set("dev")
    elif mode == "predict":
        trainer.predict_data_set("test")


if __name__ == '__main__':
    main()
|
@ -1,125 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import BertModel
|
||||
from transformers import BertPreTrainedModel
|
||||
|
||||
|
||||
class SingleNonLinearClassifier(nn.Module):
    """Dropout, then a linear projection, then a GELU activation."""

    def __init__(self, hidden_size, num_label, dropout_rate):
        """
        :param hidden_size: size of the last dimension of the input.
        :param num_label: size of the last dimension of the output.
        :param dropout_rate: dropout probability applied to the input.
        """
        super().__init__()
        self.num_label = num_label
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(hidden_size, num_label)

    def forward(self, input_features):
        """Return ``gelu(classifier(dropout(input_features)))``."""
        projected = self.classifier(self.dropout(input_features))
        return F.gelu(projected)
|
||||
|
||||
|
||||
class MultiNonLinearClassifier(nn.Module):
    """Dropout followed by a single linear projection (no activation)."""

    def __init__(self, hidden_size, num_label, dropout_rate):
        """
        :param hidden_size: size of the last dimension of the input.
        :param num_label: size of the last dimension of the output.
        :param dropout_rate: dropout probability applied to the input.
        """
        super().__init__()
        self.num_label = num_label
        self.classifier1 = nn.Linear(hidden_size, num_label)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_features):
        """Return ``classifier1(dropout(input_features))``."""
        dropped = self.dropout(input_features)
        return self.classifier1(dropped)
|
||||
|
||||
|
||||
class EntExtractNet(BertPreTrainedModel):
    """
    Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware and
    encoded by BERT

    The network predicts, for every token, whether it starts or ends an
    entity, and for every (start, end) token pair a score per entity class.
    """

    def __init__(self, config, classes_num):
        """
        :param config: BERT configuration; ``config.hidden_size`` sizes the heads.
        :param classes_num: number of entity classes scored for each span.
        """
        super(EntExtractNet, self).__init__(config, classes_num)
        print('ent_po_net.py')  # debug trace kept from the original implementation

        self.bert = BertModel(config)

        self.classes_num = classes_num

        self.start_outputs = SingleNonLinearClassifier(config.hidden_size, 2, dropout_rate=0.1)
        self.end_outputs = SingleNonLinearClassifier(config.hidden_size, 2, dropout_rate=0.1)

        self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, self.classes_num, dropout_rate=0.1)

        # Not used in forward(); kept so the parameter set (and therefore
        # saved checkpoints) stays unchanged.
        self.subject_dense = nn.Linear(config.hidden_size, 2)

        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        self.init_weights()

    def forward(self, q_ids=None, passage_id=None, token_type_id=None, segment_id=None, point_labels=None,
                span_labels=None, is_eval=False):
        """Compute the training losses or the evaluation predictions.

        :param q_ids: unused.
        :param passage_id: token ids, ``batch x seq_len``; id 0 is treated as padding.
        :param token_type_id: unused.
        :param segment_id: passed to BERT as ``token_type_ids``.
        :param point_labels: start/end labels, indexed as ``[:, :, 0]`` and
            ``[:, :, 1]``; read only when ``is_eval`` is false.
        :param span_labels: span labels, flattened per example against the
            span logits; read only when ``is_eval`` is false.
        :param is_eval: select the prediction branch.
        :return: ``(total_loss, start_loss, end_loss, span_loss)`` in training,
            ``(start_labels, end_labels, span_scores)`` in evaluation.
        """
        mask = (passage_id != 0).float()
        # Index the encoder output instead of tuple-unpacking it: this works
        # both when BERT returns a plain tuple and when it returns a
        # ModelOutput object (whose iteration yields field names).
        bert_encoder = self.bert(passage_id, token_type_ids=segment_id,
                                 attention_mask=mask)[0]  # batch x seq_len x hidden
        batch_size, seq_len, hid_size = bert_encoder.size()

        start_logits = self.start_outputs(bert_encoder)  # batch x seq_len x 2
        end_logits = self.end_outputs(bert_encoder)  # batch x seq_len x 2

        # for every position $i$ in sequence, should concate $j$ to
        # predict if $i$ and $j$ are start_pos and end_pos for an entity.
        start_extend = bert_encoder.unsqueeze(2).expand(-1, -1, seq_len, -1)
        end_extend = bert_encoder.unsqueeze(1).expand(-1, seq_len, -1, -1)

        span_matrix = torch.cat([start_extend, end_extend], 3)  # batch x seq_len x seq_len x 2*hidden

        span_logits = self.span_embedding(span_matrix)  # batch x seq_len x seq_len x classes_num
        # Only has an effect when classes_num == 1 (drops the class axis).
        span_logits = torch.squeeze(span_logits, -1)

        if not is_eval:
            start_positions = point_labels[:, :, 0]
            end_positions = point_labels[:, :, 1]

            valid_num = torch.sum(mask)

            loss_fct = nn.CrossEntropyLoss(reduction="none")

            # Token-level losses are averaged over non-padding tokens only.
            start_loss = loss_fct(start_logits.view(-1, 2), start_positions.view(-1))
            start_loss = torch.sum(start_loss * mask.view(-1))
            start_loss = start_loss / valid_num.float()

            end_loss = loss_fct(end_logits.view(-1, 2), end_positions.view(-1))
            end_loss = torch.sum(end_loss * mask.view(-1))
            end_loss = end_loss / valid_num.float()

            span_loss_fct = nn.BCEWithLogitsLoss(reduction="none")

            # A span counts only when both of its end points are real tokens.
            span_mask = (mask.unsqueeze(2) *
                         mask.unsqueeze(1)).unsqueeze(3).expand(-1, -1, -1, self.classes_num)

            span_loss = span_loss_fct(span_logits.view(batch_size, -1), span_labels.view(batch_size, -1).float())
            valid_span_num = torch.sum(span_mask)
            span_loss = torch.sum(span_loss.view(-1) * span_mask.reshape(-1).float())
            span_loss = span_loss / valid_span_num.float()

            total_loss = start_loss + end_loss + span_loss

            return total_loss, start_loss, end_loss, span_loss
        else:
            span_scores = torch.sigmoid(span_logits)
            start_labels = torch.argmax(start_logits, dim=-1)
            end_labels = torch.argmax(end_logits, dim=-1)
            return start_labels, end_labels, span_scores
|
||||
|
@ -1,413 +0,0 @@
|
||||
# _*_ coding:utf-8 _*_
|
||||
import codecs
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from warnings import simplefilter
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from deepIE.chip_ent.ent_stacked_span import stacked_span as ent_net
|
||||
from layers.encoders.transformers.bert.bert_optimization import BertAdam
|
||||
|
||||
simplefilter(action='ignore', category=FutureWarning)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, args, data_loaders, examples, spo_conf, tokenizer):
    """Build the model, optimizer and per-split lookup tables.

    :param args: parsed options; ``max_len``, ``device_id``, ``seed``,
        ``encoder_type``, ``bert_model``, ``train_mode``,
        ``train_batch_size`` and ``epoch_num`` are read here.
    :param data_loaders: ``(train, dev, test)`` data loaders.
    :param examples: ``(train, dev, test)`` example lists.
    :param spo_conf: mapping from entity-type name to label id.
    :param tokenizer: stored on the instance.
    :raises NotImplementedError: when ``args.encoder_type == 'lstm'``.
    """
    self.args = args
    self.tokenizer = tokenizer
    self.max_len = args.max_len - 2
    self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
    self.n_gpu = torch.cuda.device_count()
    self.load_ent_dict()

    self.id2rel = {item: key for key, item in spo_conf.items()}
    self.rel2id = spo_conf

    if self.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if args.encoder_type == 'lstm':
        # The LSTM variant was referenced through a module that this file
        # never imports; fail with an explicit message instead of a NameError.
        raise NotImplementedError(
            "encoder_type 'lstm' is not available: no LSTM network module is imported")
    self.model = ent_net.EntExtractNet.from_pretrained(args.bert_model, classes_num=len(spo_conf))

    self.model.to(self.device)
    if args.train_mode != "train":
        # Evaluation and prediction start from the saved checkpoint.
        self.resume(args)

    if self.n_gpu > 1:
        logging.info('total gpu num is {}'.format(self.n_gpu))
        # NOTE(review): the device ids are fixed to the first two GPUs,
        # whatever n_gpu and args.device_id are — confirm this is intended.
        self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])

    train_dataloader, dev_dataloader, test_dataloader = data_loaders
    train_eval, dev_eval, test_eval = examples
    self.eval_file_choice = {
        "train": train_eval,
        "dev": dev_eval,
        "test": test_eval
    }
    self.data_loader_choice = {
        "train": train_dataloader,
        "dev": dev_dataloader,
        "test": test_dataloader
    }
    # TODO: switch to a newer optimizer later and add gradient clipping.
    self.optimizer = self.set_optimizer(args, self.model,
                                        train_steps=(int(
                                            len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)
|
||||
|
||||
def set_optimizer(self, args, model, train_steps=None):
    """Create the ``BertAdam`` optimizer with its parameter groups.

    Parameters whose name contains ``pooler`` are left out. Parameters whose
    name matches one of the no-decay markers get weight decay 0.0, all
    others 0.01. With ``args.diff_lr`` the parameters outside the BERT
    encoder additionally get ten times the base learning rate.

    :param args: options; ``diff_lr``, ``learning_rate`` and
        ``warmup_proportion`` are read.
    :param model: the model whose named parameters are optimized.
    :param train_steps: total number of optimization steps (``t_total``).
    :return: the configured ``BertAdam`` instance.
    """
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # Encoder parameter names carry a "module." prefix when n_gpu > 1.
    flag = 'module.bert' if self.n_gpu > 1 else 'bert'
    named_params = [(name, param) for name, param in model.named_parameters()
                    if 'pooler' not in name]

    def skips_decay(name):
        return any(marker in name for marker in no_decay)

    def in_encoder(name):
        return name.startswith(flag)

    def select(want_no_decay, want_encoder=None):
        # want_encoder=None means "do not filter on encoder membership".
        return [param for name, param in named_params
                if skips_decay(name) == want_no_decay
                and (want_encoder is None or in_encoder(name) == want_encoder)]

    # TODO: use different learning rates
    if args.diff_lr:
        logging.info('设置不同学习率')
        for name, _ in named_params:
            if not in_encoder(name) and not skips_decay(name):
                print(name)
        print('+' * 10)
        for name, _ in named_params:
            if not in_encoder(name) and skips_decay(name):
                print(name)
        optimizer_grouped_parameters = [
            {'params': select(False, True), 'weight_decay': 0.01, 'lr': args.learning_rate},
            {'params': select(False, False), 'weight_decay': 0.01, 'lr': args.learning_rate * 10},
            {'params': select(True, True), 'weight_decay': 0.0, 'lr': args.learning_rate},
            {'params': select(True, False), 'weight_decay': 0.0, 'lr': args.learning_rate * 10},
        ]
    else:
        logging.info('原始设置学习率设置')
        # TODO: original setting
        optimizer_grouped_parameters = [
            {'params': select(False), 'weight_decay': 0.01},
            {'params': select(True), 'weight_decay': 0.0},
        ]

    return BertAdam(optimizer_grouped_parameters,
                    lr=args.learning_rate,
                    warmup=args.warmup_proportion,
                    t_total=train_steps)
|
||||
|
||||
def train(self, args):
    """Run the training loop with per-epoch evaluation and early stopping.

    After every epoch the dev split is evaluated; the model weights are
    written to ``args.output + "/pytorch_model.bin"`` whenever the dev F1
    does not decrease, and training stops once it failed to do so for
    ``args.patience_stop`` consecutive epochs.

    :param args: options; ``epoch_num``, ``output`` and ``patience_stop``
        are read.
    """
    best_f1 = 0.0
    patience_stop = 0
    self.model.train()
    step_gap = 20
    train_loader = self.data_loader_choice[u"train"]
    for epoch in range(int(args.epoch_num)):

        # Running sums of (total, start, end, span) loss since the last report.
        running = [0.0, 0.0, 0.0, 0.0]

        for step, batch in tqdm(enumerate(train_loader), mininterval=5,
                                desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):

            losses = self.forward(batch)
            # Accumulate on EVERY step so the reported value really is the
            # mean over the last ``step_gap`` steps (previously only every
            # step_gap-th loss was added and then divided by step_gap).
            running = [total + value for total, value in zip(running, losses)]

            if (step + 1) % step_gap == 0:
                averaged = [round(total / step_gap * 100, 5) for total in running]
                print(
                    u"step {} / {} of epoch {}, train/loss: {}\tstart:{}\tend:{}\tspan:{}".format(
                        step + 1, len(train_loader), epoch, *averaged))
                running = [0.0, 0.0, 0.0, 0.0]

        res_dev = self.eval_data_set("dev")
        if res_dev['f1'] >= best_f1:
            best_f1 = res_dev['f1']
            logging.info("** ** * Saving fine-tuned model ** ** * ")
            # Only save the model itself, not a DataParallel wrapper.
            model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
            output_model_file = args.output + "/pytorch_model.bin"
            torch.save(model_to_save.state_dict(), str(output_model_file))
            patience_stop = 0
        else:
            patience_stop += 1
        if patience_stop >= args.patience_stop:
            return
|
||||
|
||||
def resume(self, args):
|
||||
resume_model_file = args.output + "/pytorch_model.bin"
|
||||
logging.info("=> loading checkpoint '{}'".format(resume_model_file))
|
||||
checkpoint = torch.load(resume_model_file, map_location='cpu')
|
||||
self.model.load_state_dict(checkpoint)
|
||||
|
||||
def forward(self, batch, chosen=u'train', eval=False, answer_dict=None):
    """Run one batch through the model.

    In training mode (``eval`` false) this also back-propagates and performs
    one optimizer step, then returns the four loss values as Python floats.
    In evaluation mode the predictions are handed to
    ``convert_spo_contour`` together with the examples of split ``chosen``
    and ``answer_dict``, and its result is returned.

    :param batch: tuple of tensors as produced by the data loader (five
        tensors in training, six — with the example ids first — otherwise).
    :param chosen: split name used to look up the evaluation examples.
    :param eval: select the evaluation branch.
    :param answer_dict: accumulator passed through to ``convert_spo_contour``.
    """
    tensors = tuple(item.to(self.device) for item in batch)

    if eval:
        p_ids, input_ids, token_type_ids, segment_ids, point_labels, span_labels = tensors
        eval_file = self.eval_file_choice[chosen]
        start_pred, end_pred, span_scores = self.model(
            passage_id=input_ids,
            token_type_id=token_type_ids,
            segment_id=segment_ids,
            point_labels=point_labels,
            span_labels=span_labels,
            is_eval=eval,
        )
        return self.convert_spo_contour(p_ids, start_pred, end_pred, span_scores, eval_file,
                                        answer_dict)

    input_ids, token_type_ids, segment_ids, point_labels, span_labels = tensors
    loss, start_loss, end_loss, span_loss = self.model(
        passage_id=input_ids,
        token_type_id=token_type_ids,
        segment_id=segment_ids,
        point_labels=point_labels,
        span_labels=span_labels,
    )
    if self.n_gpu > 1:
        # Average the per-device values returned under multi-GPU execution.
        loss, start_loss, end_loss, span_loss = (
            loss.mean(), start_loss.mean(), end_loss.mean(), span_loss.mean())

    loss.backward()
    loss_value = loss.item()
    self.optimizer.step()
    self.optimizer.zero_grad()
    return loss_value, start_loss.item(), end_loss.item(), span_loss.item()
|
||||
|
||||
def eval_data_set(self, chosen="dev"):
    """Evaluate the model on split ``chosen`` and return the metrics.

    The model is put in eval mode for the pass and switched back to train
    mode before returning the result of ``self.evaluate``.
    """
    self.model.eval()

    loader = self.data_loader_choice[chosen]
    eval_file = self.eval_file_choice[chosen]
    # One three-slot entry per example, filled in by forward().
    answer_dict = {idx: [[], [], []] for idx in range(len(eval_file))}

    started = time.time()
    with torch.no_grad():
        for _, batch in tqdm(enumerate(loader), mininterval=5, leave=False, file=sys.stdout):
            self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
    elapsed = time.time() - started
    logging.info('chosen {} took : {} sec'.format(chosen, elapsed))

    metrics = self.evaluate(eval_file, answer_dict, chosen)
    self.model.train()
    return metrics
|
||||
|
||||
def predict_data_set(self, chosen="dev"):
    """Predict on split ``chosen`` and write the result to ``args.res_path``.

    One line is written per entry whose text slot was filled:
    ``text|||start end type|||...|||`` when there are predictions, or the
    bare text otherwise. Entries whose text slot is still ``[]`` are skipped.
    """
    self.model.eval()

    loader = self.data_loader_choice[chosen]
    eval_file = self.eval_file_choice[chosen]
    # One three-slot entry per example, filled in by forward().
    answer_dict = {idx: [[], [], []] for idx in range(len(eval_file))}

    started = time.time()
    with torch.no_grad():
        for _, batch in tqdm(enumerate(loader), mininterval=5, leave=False, file=sys.stdout):
            self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
    elapsed = time.time() - started
    logging.info('chosen {} took : {} sec'.format(chosen, elapsed))

    with codecs.open(self.args.res_path, 'w', 'utf-8') as f:
        for entry in answer_dict.values():
            raw_text = entry[2]
            if raw_text == []:
                continue
            pred_text = [' '.join([str(s), str(e), ent_type])
                         for (s, e, ent_name, ent_type) in entry[1]]
            if len(pred_text) == 0:
                f.write(raw_text + '\n')
            else:
                f.write(raw_text + '|||' + '|||'.join(pred_text) + '|||' + '\n')
|
||||
|
||||
def clean_result(self, text, po_lst):
    """Clean the predictions by dropping spans that collide with longer ones.

    Duplicates are removed, then candidates are visited from the largest
    ``end - start`` to the smallest (smaller ``start`` first among equal
    lengths). A candidate is discarded when one of its two endpoints is
    already covered by a kept span, unless both endpoints currently lie
    inside a kept span of type ``'sym'``.

    :param text: the text the inclusive ``(start, end)`` offsets index into.
    :param po_lst: iterable of ``(start, end, ent_name, ent_type)`` tuples.
    :return: the kept tuples, ordered by ``start``.
    """
    candidates = list(set(po_lst))
    # Longest first, then leftmost; the sort is stable, so any remaining
    # ties keep the order the de-duplicated list happened to have.
    candidates.sort(key=lambda item: (-(item[1] - item[0]), item[0]))

    covered = [0] * len(text)
    inside_sym = [False] * len(text)
    kept = []
    for start, end, ent_name, ent_type in candidates:
        endpoint_taken = covered[start] == 1 or covered[end] == 1
        if endpoint_taken and not (inside_sym[start] and inside_sym[end]):
            continue

        width = end - start + 1
        covered[start:end + 1] = [1] * width
        inside_sym[start:end + 1] = [ent_type == 'sym'] * width
        kept.append((start, end, ent_name, ent_type))

    kept.sort(key=lambda item: item[0])
    return kept
|
||||
|
||||
def clean_result_with_dct(self, text, po_lst):
    """Clean the predictions by correcting entity types from the dictionary.

    For every tuple whose entity name has a non-``None`` entry in
    ``self.ent_dct`` the predicted type is replaced by the dictionary's
    type; all other tuples are passed through unchanged.

    :param text: unused by this method.
    :param po_lst: iterable of ``(start, end, ent_name, ent_type)`` tuples.
    :return: a new list of tuples in the same order as ``po_lst``.
    """
    logging.info('清洗结果 利用词典来纠正实体类型')
    corrected = []
    for start, end, ent_name, ent_type in po_lst:
        known_type = self.ent_dct.get(ent_name)
        final_type = ent_type if known_type is None else known_type
        corrected.append((start, end, ent_name, final_type))
    return corrected
|
||||
|
||||
def load_ent_dict(self):
    """Load the entity-name -> entity-type dictionary into ``self.ent_dct``.

    Reads ``deepIE/chip_ent/data/ent_dict.txt`` (relative to the current
    working directory). Every non-blank line must consist of exactly two
    whitespace-separated fields: the entity name and its type. When a name
    occurs more than once the last line wins.

    :raises ValueError: if a non-blank line does not split into exactly
        two fields.
    """
    dict_path = 'deepIE/chip_ent/data/' + 'ent_dict.txt'
    logging.info('loading ent dict in {}'.format(dict_path))

    ent_dct = {}
    # Decode explicitly as UTF-8, like the other file accesses in this
    # module, instead of relying on the platform's default encoding.
    with codecs.open(dict_path, 'r', 'utf-8') as fr:
        for line in fr:
            fields = line.split()
            if not fields:
                # Tolerate blank lines rather than failing on the unpack.
                continue
            ent_name, ent_type = fields
            ent_dct[ent_name] = ent_type
    self.ent_dct = ent_dct
|
||||
|
||||
def evaluate(self, eval_file, answer_dict, chosen):
    """Compute micro precision / recall / F1 over all collected answers.

    For every slot of ``answer_dict`` the gold items (index 0) and the
    predicted items (index 1) are compared as sets; exact-match hits and
    the set sizes are summed over all slots before the ratios are taken.
    A summary is printed to stdout.

    :param eval_file: unused by this method.
    :param answer_dict: mapping to ``[gold, predicted, raw_text]`` slots.
    :param chosen: split name, only used to label the printed summary.
    :return: ``{'f1': ..., 'recall': ..., 'precision': ...}`` as fractions
        (``0`` whenever the corresponding denominator is zero).
    """
    hit_total, pred_total, gold_total = 0.0, 0.0, 0.0
    for slot in answer_dict.values():
        gold_items = set(slot[0])
        pred_items = set(slot[1])
        hit_total += len(pred_items & gold_items)
        pred_total += len(pred_items)
        gold_total += len(gold_items)

    if pred_total != 0:
        p = hit_total / pred_total
    else:
        p = 0
    if gold_total != 0:
        r = hit_total / gold_total
    else:
        r = 0
    if p + r != 0:
        f = 2 * p * r / (p + r)
    else:
        f = 0

    print('============================================')
    print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, hit_total, pred_total, gold_total))
    print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f * 100, p * 100, r * 100))
    return {'f1': f, "recall": r, "precision": p}
|
||||
|
||||
def convert2result(self, eval_file, answer_dict):
    """Turn token-level subject/object spans into text triples.

    For every ``qid`` the mapping stored at ``answer_dict[qid][2]``
    (subject span -> list of ``(obj_start, obj_end, relation_id)``) is
    converted to ``(subject_text, predicate, object_text)`` tuples, which
    are appended to ``answer_dict[qid][1]``. Token positions are shifted
    by one before being looked up in the example's token-to-character
    tables (presumably to skip a leading special token -- confirm).
    Each object list is sorted in place by relation id.

    :param eval_file: indexable by ``qid``; items expose ``context``,
        ``tok_to_orig_start_index`` and ``tok_to_orig_end_index``.
    :param answer_dict: mapping ``qid -> [gold, predicted, spo_mapping]``.
    """
    for qid, slot in answer_dict.items():
        example = eval_file[qid]
        text = example.context
        char_starts = example.tok_to_orig_start_index
        char_ends = example.tok_to_orig_end_index

        def surface(first_tok, last_tok):
            # Character slice covering tokens first_tok..last_tok (1-based).
            return text[char_starts[first_tok - 1]:char_ends[last_tok - 1] + 1]

        triples = []
        for subject_span, po_items in slot[2].items():
            po_items.sort(key=lambda item: item[2])
            subject_text = surface(subject_span[0], subject_span[1])
            for obj_start, obj_end, rel_id in po_items:
                object_text = surface(obj_start, obj_end)
                triples.append((subject_text, self.id2rel[rel_id], object_text))
        slot[1].extend(triples)
|
||||
|
||||
def convert_spo_contour(self, qids, start_preds, end_preds, span_scores, eval_file, answer_dict, threshold=0.5):
    """Decode one batch of span predictions into typed entities.

    For each example, every (start, end) pair with a non-zero start flag,
    a non-zero end flag and ``end >= start`` is tried against every entry
    of ``self.id2rel``; a pair is emitted when
    ``span_score[start][end][label] >= threshold``. Position 0 and
    positions past ``len(bert_tokens) - 2`` are ignored (presumably the
    special tokens -- confirm). Emitted offsets are shifted by one and, for
    split examples, by ``span_index * self.max_len``. Results are
    appended to ``answer_dict[text_id][1]``; gold entities and raw text are
    stored unless the gold slot already holds more than one item.

    :param qids: tensor of example indices into ``eval_file``.
    :param start_preds: per-token start flags, one row per example.
    :param end_preds: per-token end flags, one row per example.
    :param span_scores: scores indexed ``[example][start][end][label]``.
    :param eval_file: indexable collection of examples.
    :param answer_dict: mapping ``text_id -> [gold, predicted, raw_text]``,
        updated in place.
    :param threshold: minimum score for a span/label to be emitted.
    :raises ValueError: if an example's ``text_id`` is not in ``answer_dict``.
    """
    batch_qids = qids.data.cpu().numpy()
    batch_starts = start_preds.data.cpu().numpy().tolist()
    batch_ends = end_preds.data.cpu().numpy().tolist()
    batch_scores = span_scores.data.cpu().numpy().tolist()

    for qid, start_pred, end_pred, span_score in zip(batch_qids, batch_starts, batch_ends, batch_scores):
        example = eval_file[qid.item()]
        text_id = example.text_id
        last_valid = len(example.bert_tokens) - 2
        context = example.context

        start_positions = [pos for pos, flag in enumerate(start_pred) if flag != 0]
        end_positions = [pos for pos, flag in enumerate(end_pred) if flag != 0]

        # Pass 1: collect every (start, end, label) that clears the threshold.
        hits = []
        for start_tok in start_positions:
            if start_tok == 0 or start_tok > last_valid:
                continue
            for end_tok in end_positions:
                if end_tok < start_tok:
                    continue
                if end_tok == 0 or end_tok > last_valid:
                    continue
                for label_id in range(len(self.id2rel)):
                    if span_score[start_tok][end_tok][label_id] >= threshold:
                        hits.append((start_tok, end_tok, label_id))

        # Pass 2: shift token positions by one and attach text and label.
        entities = [(start_tok - 1, end_tok - 1, context[start_tok - 1:end_tok], self.id2rel[label_id])
                    for start_tok, end_tok, label_id in hits]

        if text_id not in answer_dict:
            raise ValueError('text_id error in answer_dict ')

        if example.is_split:
            # Move offsets from the split piece back into the full text.
            split_index = example.span_index
            entities = [(start + split_index * self.max_len,
                         end + split_index * self.max_len,
                         ent_name,
                         ent_type)
                        for start, end, ent_name, ent_type in entities]

        slot = answer_dict[text_id]
        slot[1].extend(entities)
        if len(slot[0]) <= 1:
            slot[0] = example.g_gold_ent
            slot[2] = example.g_raw_text
|
@ -16,14 +16,19 @@ def spo_to_text(spo):
|
||||
def data_preprocess(dir_path):
|
||||
total_files = []
|
||||
|
||||
for i in range(1, 5):
|
||||
files_num=0
|
||||
|
||||
for i in list(range(6, 65)):
|
||||
files_num+=1
|
||||
with open(dir_path + 'res_data_set_{}.json'.format(i), 'r') as fr:
|
||||
print(dir_path + 'res_data_set_{}.json'.format(i))
|
||||
total_files.append(fr.readlines())
|
||||
with open(dir_path + 'res_data_set_{}.json'.format(5), 'r') as fr:
|
||||
print(dir_path + 'res_data_set_{}.json'.format(5))
|
||||
with open(dir_path + 'res_data_set_{}.json'.format(60), 'r') as fr:
|
||||
print(dir_path + 'res_data_set_{}.json'.format(60))
|
||||
files_num += 1
|
||||
p_id = 0
|
||||
final_output = []
|
||||
rel_num=0
|
||||
for line in tqdm(fr.readlines()):
|
||||
src_data = json.loads(line.strip())
|
||||
|
||||
@ -51,19 +56,30 @@ def data_preprocess(dir_path):
|
||||
if spo2text not in spo_source:
|
||||
spo_source[spo2text] = spo
|
||||
for k, v in spo_count.items():
|
||||
if v >= 3:
|
||||
if v >= 23:
|
||||
new_spo_list.append(spo_source[k])
|
||||
rel_num+=len(new_spo_list)
|
||||
final_output.append((text, new_spo_list))
|
||||
p_id += 1
|
||||
with codecs.open('result_chip_0819v1.json', 'w', 'utf-8') as f:
|
||||
for (text, new_spo_list) in final_output:
|
||||
out_put = {}
|
||||
out_put['text'] = text
|
||||
out_put['spo_list'] = new_spo_list
|
||||
json_str = json.dumps(out_put, ensure_ascii=False)
|
||||
f.write(json_str)
|
||||
f.write('\n')
|
||||
print('rel_num',rel_num)
|
||||
print('files_num', files_num)
|
||||
# with codecs.open('deepIE/chip_rel/res_commit/result_chip_0925v5.json', 'w', 'utf-8') as f:
|
||||
# for (text, new_spo_list) in final_output:
|
||||
# out_put = {}
|
||||
# out_put['text'] = text
|
||||
# out_put['spo_list'] = new_spo_list
|
||||
# json_str = json.dumps(out_put, ensure_ascii=False)
|
||||
# f.write(json_str)
|
||||
# f.write('\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_preprocess('deepIE/chip_rel/res/')
|
||||
#result_chip_0924v1,6-54,v=22:11665
|
||||
#result_chip_0924v2,6-54,v=23:11403
|
||||
#result_chip_0924v3,6-54,v=21: 11914
|
||||
#result_chip_0924v4,6-54,146-155,v=27: 11542
|
||||
#result_chip_0925v3,6-60,146-155,v=25:11511 0.665
|
||||
#result_chip_0925v4,6-60,146-155,v=24:11728 0.6656
|
||||
#result_chip_0925v5,6-60,146-155,v=23:11976 0.6659
|
||||
|
||||
|
@ -211,21 +211,19 @@ class SPODataset(Dataset):
|
||||
batch_token_ids.append(token_ids)
|
||||
batch_segment_ids.append(segment_ids)
|
||||
|
||||
ent_label_ids = np.zeros((len(token_ids), 2), dtype=np.float32)
|
||||
for s in ent_labels:
|
||||
ent_label_ids[s[0], 0] = 1
|
||||
ent_label_ids[s[1], 1] = 1
|
||||
if self.is_train:
|
||||
ent_label_ids = np.zeros((len(token_ids), 2), dtype=np.float32)
|
||||
for s in ent_labels:
|
||||
ent_label_ids[s[0], 0] = 1
|
||||
ent_label_ids[s[1], 1] = 1
|
||||
|
||||
batch_ent_labels.append(ent_label_ids)
|
||||
batch_rel_labels.append(rel_labels)
|
||||
batch_ent_labels.append(ent_label_ids)
|
||||
batch_rel_labels.append(rel_labels)
|
||||
|
||||
batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
|
||||
batch_segment_ids = sequence_padding(batch_segment_ids, is_float=False)
|
||||
if not self.is_train:
|
||||
batch_ent_labels = sequence_padding(batch_ent_labels, padding=np.zeros(2), is_float=True)
|
||||
batch_rel_labels = select_padding(batch_token_ids, batch_rel_labels, is_float=True,
|
||||
class_num=len(self.spo_config), use_bert=True)
|
||||
return p_ids, batch_token_ids, batch_segment_ids, batch_ent_labels, batch_rel_labels
|
||||
return p_ids, batch_token_ids, batch_segment_ids
|
||||
else:
|
||||
batch_ent_labels = sequence_padding(batch_ent_labels, padding=np.zeros(2), is_float=True)
|
||||
batch_rel_labels = select_padding(batch_token_ids, batch_rel_labels, is_float=True,
|
||||
|
@ -93,7 +93,7 @@ def get_args():
|
||||
def bulid_dataset(args, spo_config, reader, tokenizer, debug=False):
|
||||
train_src = args.input + "/train_data.json"
|
||||
dev_src = args.input + "/val_data.json"
|
||||
test_src = args.input + "/test1.json"
|
||||
test_src = args.input + "/pse_label.json"
|
||||
|
||||
train_examples_file = args.cache_data + "/train-examples.pkl"
|
||||
dev_examples_file = args.cache_data + "/dev-examples.pkl"
|
||||
|
@ -1,4 +1,6 @@
|
||||
# _*_ coding:utf-8 _*_
|
||||
import codecs
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
@ -45,6 +47,8 @@ class Trainer(object):
|
||||
self.model = bert.MHSNet(args)
|
||||
|
||||
self.model.to(self.device)
|
||||
if args.train_mode != "train":
|
||||
self.resume(args)
|
||||
|
||||
if self.n_gpu > 1:
|
||||
logging.info('total gpu num is {}'.format(self.n_gpu))
|
||||
@ -161,9 +165,11 @@ class Trainer(object):
|
||||
u"step {} / {} of epoch {}, train/loss: {}\tner:{}\trel:{}".format(step, len(
|
||||
self.data_loader_choice["train"]),
|
||||
epoch,
|
||||
round(current_loss*100,5),
|
||||
round(current_crf_loss*100,5),
|
||||
round(current_selection_loss*100,5)))
|
||||
round(current_loss, 5),
|
||||
round(current_crf_loss, 5),
|
||||
round(
|
||||
current_selection_loss,
|
||||
5)))
|
||||
global_loss, global_crf_loss, global_selection_loss = 0.0, 0.0, 0.0
|
||||
|
||||
res_dev = self.eval_data_set("dev")
|
||||
@ -206,10 +212,9 @@ class Trainer(object):
|
||||
self.optimizer.zero_grad()
|
||||
return loss, crf_loss, selection_loss
|
||||
else:
|
||||
p_ids, passage_ids, segment_ids, ent_ids, rel_ids = batch
|
||||
p_ids, passage_ids, segment_ids = batch
|
||||
eval_file = self.eval_file_choice[chosen]
|
||||
ent_logits, rel_logits = self.model(passage_ids=passage_ids, segment_ids=segment_ids, ent_ids=ent_ids,
|
||||
rel_ids=rel_ids, is_eval=eval)
|
||||
ent_logits, rel_logits = self.model(passage_ids=passage_ids, segment_ids=segment_ids, is_eval=eval)
|
||||
answer_dict = self.convert_select_contour(eval_file, p_ids, passage_ids, ent_logits, rel_logits)
|
||||
# p_ids, passage_ids, segment_ids , ent_ids, rel_ids = batch
|
||||
# eval_file = self.eval_file_choice[chosen]
|
||||
@ -236,6 +241,36 @@ class Trainer(object):
|
||||
self.model.train()
|
||||
return res
|
||||
|
||||
def predict_data_set(self, chosen="dev"):
    """Predict SPO triples for one data split and dump them as JSON lines.

    Each batch is passed to ``self.forward`` in eval mode; the mapping it
    returns is merged into one ``answer_dict``. For every entry a JSON
    object with the example's ``text`` and its ``spo_list`` is written as
    one line to ``self.args.res_path``. The model is left in eval mode.

    :param chosen: key into ``self.data_loader_choice`` and
        ``self.eval_file_choice`` selecting the split.
    """
    self.model.eval()

    loader = self.data_loader_choice[chosen]
    examples = self.eval_file_choice[chosen]
    answer_dict = {}

    started_at = time.time()
    with torch.no_grad():
        for _, batch in tqdm(enumerate(loader), mininterval=5, leave=False, file=sys.stdout):
            answer_dict.update(self.forward(batch, chosen, eval=True))
    elapsed = time.time() - started_at
    logging.info('chosen {} took : {} sec'.format(chosen, elapsed))

    with codecs.open(self.args.res_path, 'w', 'utf-8') as out:
        for key, value in answer_dict.items():
            # Each value is an (entity predictions, SPO tuples) pair; only
            # the SPO tuples are exported.
            _entity_pred, spo_tuple_lst = value
            record = {
                'text': examples[key].context,
                'spo_list': [{"predicate": p, "subject": s, "object": {"@value": o}}
                             for (s, p, o) in spo_tuple_lst],
            }
            out.write(json.dumps(record, ensure_ascii=False) + '\n')
|
||||
|
||||
def show(self, chosen="dev"):
|
||||
|
||||
self.model.eval()
|
||||
|
Loading…
Reference in New Issue
Block a user