This commit is contained in:
parent afce3517f9
commit a951d3075c
@@ -1,3 +1,6 @@
"""
Randomly select a subject
"""
import json
import logging
import random
@@ -113,8 +116,8 @@ class Reader(object):
                    object = spo['object']['@value']
                    gold_spo_lst.append((subject, predicate, object))

-                   subject_sub_tokens = covert_to_tokens(subject,tokenizer=self.tokenizer)
-                   object_sub_tokens = covert_to_tokens(object,tokenizer=self.tokenizer)
+                   subject_sub_tokens = covert_to_tokens(subject, tokenizer=self.tokenizer)
+                   object_sub_tokens = covert_to_tokens(object, tokenizer=self.tokenizer)
                    subject_start, object_start = search_spo_index(tokens, subject_sub_tokens, object_sub_tokens)

                    predicate_label = self.spo_conf[predicate]
@@ -0,0 +1,267 @@
"""
No longer select a subject at random; instead, flatten all subjects (one example per subject).
"""
import json
import logging
from functools import partial

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from deepIE.chip_rel.utils.data_utils import covert_to_tokens, search_spo_index
from utils import extract_chinese_and_punct
from utils.data_util import sequence_padding

chineseandpunctuationextractor = extract_chinese_and_punct.ChineseAndPunctuationExtractor()

class Example(object):
    def __init__(self,
                 p_id=None,
                 raw_text=None,
                 context=None,
                 choice_sub=None,
                 tok_to_orig_start_index=None,
                 tok_to_orig_end_index=None,
                 bert_tokens=None,
                 spoes=None,
                 sub_entity_list=None,
                 gold_answer=None, ):
        self.p_id = p_id
        self.context = context
        self.raw_text = raw_text
        self.choice_sub = choice_sub
        self.tok_to_orig_start_index = tok_to_orig_start_index
        self.tok_to_orig_end_index = tok_to_orig_end_index
        self.bert_tokens = bert_tokens
        self.spoes = spoes
        self.sub_entity_list = sub_entity_list
        self.gold_answer = gold_answer

class InputFeature(object):

    def __init__(self,
                 p_id=None,
                 passage_id=None,
                 token_type_id=None,
                 pos_start_id=None,
                 pos_end_id=None,
                 segment_id=None,
                 po_label=None,
                 s1=None,
                 s2=None):
        self.p_id = p_id
        self.passage_id = passage_id
        self.token_type_id = token_type_id
        self.pos_start_id = pos_start_id
        self.pos_end_id = pos_end_id
        self.segment_id = segment_id
        self.po_label = po_label
        self.s1 = s1
        self.s2 = s2

class Reader(object):
    def __init__(self, spo_conf, tokenizer=None, max_seq_length=None):
        self.spo_conf = spo_conf
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def read_examples(self, filename, data_type):
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def _read(self, filename, data_type):

        examples = []
        gold_num = 0
        with open(filename, 'r') as fr:
            p_id = 0
            for line in tqdm(fr.readlines()):
                p_id += 1
                data_line = json.loads(line.strip())
                text_raw = data_line['text']

                tokens, tok_to_orig_start_index, tok_to_orig_end_index = covert_to_tokens(text_raw,
                                                                                          tokenizer=self.tokenizer,
                                                                                          max_seq_length=self.max_seq_length,
                                                                                          return_orig_index=True)
                tokens = ["[CLS]"] + tokens + ["[SEP]"]

                if 'spo_list' not in data_line:
                    examples.append(
                        Example(
                            p_id=p_id,
                            raw_text=data_line['text'],
                            context=text_raw,
                            tok_to_orig_start_index=tok_to_orig_start_index,
                            tok_to_orig_end_index=tok_to_orig_end_index,
                            bert_tokens=tokens,
                            sub_entity_list=None,
                            gold_answer=None,
                            spoes=None
                        ))
                    continue

                gold_ent_lst, gold_spo_lst = [], []
                spo_list = data_line['spo_list']
                spoes = {}
                for spo in spo_list:

                    subject = spo['subject']
                    gold_ent_lst.append(subject)
                    predicate = spo['predicate']
                    object = spo['object']['@value']
                    gold_spo_lst.append((subject, predicate, object))

                    subject_sub_tokens = covert_to_tokens(subject, tokenizer=self.tokenizer)
                    object_sub_tokens = covert_to_tokens(object, tokenizer=self.tokenizer)
                    subject_start, object_start = search_spo_index(tokens, subject_sub_tokens, object_sub_tokens)

                    predicate_label = self.spo_conf[predicate]

                    if subject_start != -1 and object_start != -1:
                        s = (subject_start, subject_start + len(subject_sub_tokens) - 1)
                        o = (object_start, object_start + len(object_sub_tokens) - 1, predicate_label)
                        if s not in spoes:
                            spoes[s] = []
                        spoes[s].append(o)
                    if subject_start == -1 or object_start == -1:
                        print('error')
                        print(subject_sub_tokens, object_sub_tokens, text_raw)
                if data_type == 'train':
                    for s in spoes.keys():
                        examples.append(
                            Example(
                                p_id=p_id,
                                context=text_raw,
                                choice_sub=s,
                                tok_to_orig_start_index=tok_to_orig_start_index,
                                tok_to_orig_end_index=tok_to_orig_end_index,
                                bert_tokens=tokens,
                                sub_entity_list=gold_ent_lst,
                                gold_answer=gold_spo_lst,
                                spoes=spoes
                            ))
                else:
                    examples.append(
                        Example(
                            p_id=p_id,
                            context=text_raw,
                            tok_to_orig_start_index=tok_to_orig_start_index,
                            tok_to_orig_end_index=tok_to_orig_end_index,
                            bert_tokens=tokens,
                            sub_entity_list=gold_ent_lst,
                            gold_answer=gold_spo_lst,
                            spoes=spoes
                        ))
                gold_num += len(gold_spo_lst)

        logging.info('total gold spo num in {} is {}'.format(data_type, gold_num))

        logging.info("{} total size is {} ".format(data_type, len(examples)))

        return examples

class Feature(object):
    def __init__(self, max_len, spo_config, tokenizer):
        self.max_len = max_len
        self.spo_config = spo_config
        self.tokenizer = tokenizer

    def __call__(self, examples, data_type):
        return self.convert_examples_to_bert_features(examples, data_type)

    def convert_examples_to_bert_features(self, examples, data_type):
        logging.info("convert {} examples to features .".format(data_type))

        examples2features = list()
        for index, example in enumerate(examples):
            examples2features.append((index, example))

        logging.info("Built instances is Completed")
        return SPODataset(examples2features, spo_config=self.spo_config, data_type=data_type,
                          tokenizer=self.tokenizer, max_len=self.max_len)

class SPODataset(Dataset):
    def __init__(self, data, spo_config, data_type, tokenizer=None, max_len=128):
        super(SPODataset, self).__init__()
        self.spo_config = spo_config
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.q_ids = [f[0] for f in data]
        self.features = [f[1] for f in data]
        self.is_train = True if data_type == 'train' else False

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, index):
        return self.q_ids[index], self.features[index]

    def _create_collate_fn(self):
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_token_ids, batch_segment_ids = [], []
            batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
            for example in examples:
                spoes = example.spoes
                token_ids = self.tokenizer.encode(example.bert_tokens)[1:-1]
                segment_ids = len(token_ids) * [0]

                if self.is_train:
                    if spoes:
                        # subject labels
                        token_type_ids = np.zeros(len(token_ids), dtype=np.long)
                        subject_labels = np.zeros((len(token_ids), 2), dtype=np.float32)
                        for s in spoes:
                            subject_labels[s[0], 0] = 1
                            subject_labels[s[1], 1] = 1
                        # randomly pick one subject
                        # subject_ids = random.choice(list(spoes.keys()))

                        # not random: use the subject this example was built for
                        subject_ids = example.choice_sub

                        # object labels for that subject
                        object_labels = np.zeros((len(token_ids), len(self.spo_config), 2), dtype=np.float32)
                        for o in spoes.get(subject_ids, []):
                            object_labels[o[0], o[2], 0] = 1
                            object_labels[o[1], o[2], 1] = 1
                        batch_token_ids.append(token_ids)
                        batch_token_type_ids.append(token_type_ids)

                        batch_segment_ids.append(segment_ids)
                        batch_subject_labels.append(subject_labels)
                        batch_subject_ids.append(subject_ids)
                        batch_object_labels.append(object_labels)
                else:
                    batch_token_ids.append(token_ids)
                    batch_segment_ids.append(segment_ids)

            batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
            batch_segment_ids = sequence_padding(batch_segment_ids, is_float=False)
            if not self.is_train:
                return p_ids, batch_token_ids, batch_segment_ids
            else:
                batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
                batch_subject_ids = torch.tensor(batch_subject_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels, padding=np.zeros(2), is_float=True)
                batch_object_labels = sequence_padding(batch_object_labels, padding=np.zeros((len(self.spo_config), 2)),
                                                       is_float=True)
                return batch_token_ids, batch_segment_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels

        return partial(collate)

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, pin_memory=False,
                       drop_last=False):
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
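
Editor's note: the new data_loader_ptms_total_sub.py replaces per-sentence random subject sampling with flattening. Reader._read now appends one Example per gold subject (choice_sub=s), and collate() reads example.choice_sub instead of calling random.choice on spoes. The snippet below is a minimal, self-contained sketch of that idea and of the subject/object pointer labels; it is not part of the commit, and the toy spoes dictionary, sequence length, and predicate count are invented for illustration.

# Sketch only (not repository code): flattening subjects and building pointer labels.
import random
import numpy as np

seq_len, num_predicates = 12, 3                     # made-up sizes
# spoes: subject span -> list of (object_start, object_end, predicate_label)
spoes = {(1, 2): [(5, 6, 0)], (8, 9): [(3, 4, 2)]}  # made-up gold triples

# old loader (data_loader_ptms.py): one example per sentence, one subject sampled at collate time
subject_ids = random.choice(list(spoes.keys()))

# new loader (data_loader_ptms_total_sub.py): one example per gold subject,
# so every subject (and its objects) is seen in every epoch
examples = [{"choice_sub": s, "spoes": spoes} for s in spoes.keys()]

for ex in examples:
    # subject start/end pointers over the whole sentence (shared by all examples)
    subject_labels = np.zeros((seq_len, 2), dtype=np.float32)
    for s in ex["spoes"]:
        subject_labels[s[0], 0] = 1
        subject_labels[s[1], 1] = 1

    # object start/end pointers per predicate, conditioned on this example's subject
    subject_ids = ex["choice_sub"]
    object_labels = np.zeros((seq_len, num_predicates, 2), dtype=np.float32)
    for o in ex["spoes"].get(subject_ids, []):
        object_labels[o[0], o[2], 0] = 1
        object_labels[o[1], o[2], 1] = 1

    print(subject_ids, int(subject_labels.sum()), int(object_labels.sum()))
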
@@ -10,7 +10,7 @@ import torch
from transformers import BertTokenizer

from deepIE.chip_rel.config.config import CMeIE_CONFIG
-from deepIE.chip_rel.etl_span_transformers.data_loader_ptms import Reader, Feature
+from deepIE.chip_rel.etl_span_transformers.data_loader_ptms_total_sub import Reader, Feature
from deepIE.chip_rel.etl_span_transformers.train import Trainer
from utils.file_util import save, load

@@ -43,7 +43,7 @@ def get_args():
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")

    parser.add_argument("--debug",
-                       action='store_true',)
+                       action='store_true', )
    # bert parameters
    parser.add_argument("--do_lower_case",
                        action='store_true',
@@ -77,35 +77,46 @@ def get_args():
def bulid_dataset(args, spo_config, reader, tokenizer, debug=False):
    train_src = args.input + "/train_data.json"
    dev_src = args.input + "/val_data.json"
    test_src = args.input + "/test1.json"

    train_examples_file = args.cache_data + "/train-examples.pkl"
    dev_examples_file = args.cache_data + "/dev-examples.pkl"
    test_examples_file = args.cache_data + "/test-examples.pkl"

    if not os.path.exists(train_examples_file):
        train_examples = reader.read_examples(train_src, data_type='train')
        dev_examples = reader.read_examples(dev_src, data_type='dev')
        test_examples = reader.read_examples(test_src, data_type='test')

        save(train_examples_file, train_examples, message="train examples")
        save(dev_examples_file, dev_examples, message="dev examples")
        save(test_examples_file, test_examples, message="test examples")
    else:
        logging.info('loading train cache_data {}'.format(train_examples_file))
        logging.info('loading dev cache_data {}'.format(dev_examples_file))
-       train_examples, dev_examples = load(train_examples_file), load(dev_examples_file)
+       logging.info('loading test cache_data {}'.format(test_examples_file))
+       train_examples, dev_examples, test_examples = load(train_examples_file), load(dev_examples_file), load(test_examples_file)

    logging.info('train examples size is {}'.format(len(train_examples)))
    logging.info('dev examples size is {}'.format(len(dev_examples)))
    logging.info('test examples size is {}'.format(len(test_examples)))

    convert_examples_features = Feature(max_len=args.max_len, spo_config=spo_config, tokenizer=tokenizer)

    train_examples = train_examples[:2] if debug else train_examples
    dev_examples = dev_examples[:2] if debug else dev_examples
    test_examples = test_examples[:2] if debug else test_examples

    train_data_set = convert_examples_features(train_examples, data_type='train')
    dev_data_set = convert_examples_features(dev_examples, data_type='dev')
    test_data_set = convert_examples_features(test_examples, data_type='test')

    train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory)
    dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size)
    test_data_loader = test_data_set.get_dataloader(args.train_batch_size)

-   data_loaders = train_data_loader, dev_data_loader
-   eval_examples = train_examples, dev_examples
+   data_loaders = train_data_loader, dev_data_loader, test_data_loader
+   eval_examples = train_examples, dev_examples, test_examples

    return eval_examples, data_loaders, tokenizer

@@ -138,7 +149,7 @@ def main():
        # trainer.eval_data_set("train")
        trainer.eval_data_set("dev")
    elif args.train_mode == "predict":
-       trainer.predict_data_set("dev")
+       trainer.predict_data_set("test")
    elif args.train_mode == "resume":
        # trainer.resume(args)
        trainer.show("dev")  # bad case analysis

@@ -8,6 +8,7 @@ from warnings import simplefilter

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import models.spo_net.etl_span_transformers as etl
@@ -34,21 +35,23 @@ class Trainer(object):
        self.model = etl.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))

        self.model.to(self.device)
-       if args.train_mode == "predict":
+       if args.train_mode != "train":
            self.resume(args)
-       # logging.info('total gpu num is {}'.format(self.n_gpu))
-       # if self.n_gpu > 1:
-       #     self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
+       logging.info('total gpu num is {}'.format(self.n_gpu))
+       if self.n_gpu > 1:
+           self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])

-       train_dataloader, dev_dataloader = data_loaders
-       train_eval, dev_eval = examples
+       train_dataloader, dev_dataloader, test_dataloader = data_loaders
+       train_eval, dev_eval, test_eval = examples
        self.eval_file_choice = {
            "train": train_eval,
            "dev": dev_eval,
            "test": test_eval
        }
        self.data_loader_choice = {
            "train": train_dataloader,
            "dev": dev_dataloader,
            "test": test_dataloader
        }
        # todo: switch to a new optimizer later and add gradient clipping
        self.optimizer = self.set_optimizer(args, self.model,
@@ -193,7 +196,7 @@ class Trainer(object):

        self.convert2result(eval_file, answer_dict)

-       with codecs.open('result_6.json', 'w', 'utf-8') as f:
+       with codecs.open('result_chip_0813v1.json', 'w', 'utf-8') as f:
            for key, ans_list in answer_dict.items():
                out_put = {}
                out_put['text'] = eval_file[int(key)].raw_text
@@ -224,12 +227,21 @@ class Trainer(object):
        spo_em, spo_pred_num, spo_gold_num = 0.0, 0.0, 0.0

        for key in answer_dict.keys():

            context = eval_file[key].context

            entity_pred = answer_dict[key][0]
            entity_gold = eval_file[key].sub_entity_list

            triple_pred = answer_dict[key][1]
            triple_gold = eval_file[key].gold_answer

            # if set(triple_pred) != set(triple_gold):
            #     print()
            #     print(context)
            #     print(triple_pred)
            #     print(triple_gold)

            ent_em += len(set(entity_pred) & set(entity_gold))
            ent_pred_num += len(set(entity_pred))
            ent_gold_num += len(set(entity_gold))
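
Editor's note: these exact-match counters presumably feed the usual micro precision/recall/F1 computed elsewhere in Trainer; the short snippet below only illustrates that arithmetic with made-up predictions and is not part of the commit.

# Illustration only (not repository code): turning the counters above into P/R/F1.
entity_pred = ["阿司匹林", "头痛"]      # made-up predicted subjects
entity_gold = ["阿司匹林", "偏头痛"]    # made-up gold subjects

ent_em = len(set(entity_pred) & set(entity_gold))   # exact matches
ent_pred_num = len(set(entity_pred))
ent_gold_num = len(set(entity_gold))

precision = ent_em / ent_pred_num if ent_pred_num else 0.0
recall = ent_em / ent_gold_num if ent_gold_num else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
print(precision, recall, f1)   # 0.5 0.5 0.5
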
@@ -299,7 +311,7 @@ class Trainer(object):
            context = eval_file[qid.item()].context
            tok_to_orig_start_index = eval_file[qid.item()].tok_to_orig_start_index
            tok_to_orig_end_index = eval_file[qid.item()].tok_to_orig_end_index
-           start = np.where(po_pred[:, :, 0] > 0.6)
+           start = np.where(po_pred[:, :, 0] > 0.5)
            end = np.where(po_pred[:, :, 1] > 0.5)

            for _start, predicate1 in zip(*start):
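
Editor's note: the change above lowers the object-start cutoff from 0.6 to 0.5 so start and end pointers are filtered symmetrically. Below is a small, self-contained sketch of how such thresholded start/end scores are typically decoded into candidate object spans; it is not part of the commit, the scores are random, and the nearest-end pairing rule is an assumption, since the rest of the decoding loop is not shown in this hunk.

# Sketch only (not the repository's decoding code).
import numpy as np

seq_len, num_predicates = 8, 3
rng = np.random.default_rng(0)
po_pred = rng.random((seq_len, num_predicates, 2))   # [..., 0] = start score, [..., 1] = end score

start = np.where(po_pred[:, :, 0] > 0.5)   # (token positions, predicate ids) above threshold
end = np.where(po_pred[:, :, 1] > 0.5)

for _start, predicate1 in zip(*start):
    # assumed pairing rule: take the nearest end of the same predicate at or after the start
    for _end, predicate2 in zip(*end):
        if _end >= _start and predicate1 == predicate2:
            print("candidate object span ({}, {}) for predicate {}".format(_start, _end, predicate1))
            break
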
@@ -7,7 +7,7 @@ from utils import extract_chinese_and_punct
chineseandpunctuationextractor = extract_chinese_and_punct.ChineseAndPunctuationExtractor()
moren_tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True)


-def covert_to_tokens(text, tokenizer=None, return_orig_index=False, max_seq_length=500):
+def covert_to_tokens(text, tokenizer=None, return_orig_index=False, max_seq_length=300):
    if not tokenizer:
        tokenizer = moren_tokenizer
    sub_text = []