chip rel first commit

This commit is contained in:
parent b6182e2dc3
commit 84e81fd375
@@ -80,9 +80,6 @@ class Reader(object):
             p_id += 1
             data_line = json.loads(line.strip())
             text_raw = data_line['text']
-            if data_type == 'test':
-                break
-            spo_list = data_line['spo_list']
 
             tokens, tok_to_orig_start_index, tok_to_orig_end_index = covert_to_tokens(text_raw,
                                                                                       tokenizer=self.tokenizer,
@@ -90,8 +87,23 @@ class Reader(object):
                                                                                       return_orig_index=True)
             tokens = ["[CLS]"] + tokens + ["[SEP]"]
 
+            gold_ent_lst, gold_spo_lst = [], []
+            if 'spo_list' not in data_line:
+                examples.append(
+                    Example(
+                        p_id=p_id,
+                        raw_text=data_line['text'],
+                        context=text_raw,
+                        tok_to_orig_start_index=tok_to_orig_start_index,
+                        tok_to_orig_end_index=tok_to_orig_end_index,
+                        bert_tokens=tokens,
+                        sub_entity_list=None,
+                        gold_answer=None,
+                        spoes=None
+                    ))
+                continue
 
             gold_ent_lst, gold_spo_lst = [], []
             spo_list = data_line['spo_list']
             spoes = {}
             for spo in spo_list:
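The Example class itself is not part of this diff; the sketch below is a minimal container with exactly the fields the new call passes (field names are taken from the hunk above, the types and defaults are assumptions):

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class Example:
    p_id: int                                   # running record id
    raw_text: str                               # original JSON 'text' field
    context: str                                # the text that was tokenized
    tok_to_orig_start_index: List[int]          # sub-token -> original char start
    tok_to_orig_end_index: List[int]            # sub-token -> original char end
    bert_tokens: List[str]                      # ["[CLS]"] + tokens + ["[SEP]"]
    sub_entity_list: Optional[List[str]] = None # None for unlabeled records
    gold_answer: Optional[List[Tuple]] = None   # gold SPO triples, None at test time
    spoes: Optional[Dict] = None                # subject span -> object/predicate info

Together with the previous hunk, this changes how unlabeled data flows through the reader: instead of breaking out of the loop when data_type == 'test', records without a 'spo_list' key are wrapped in an Example whose gold fields are None and skipped past the labeling code with continue.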
@@ -101,8 +113,8 @@ class Reader(object):
                 object = spo['object']['@value']
                 gold_spo_lst.append((subject, predicate, object))
 
-                subject_sub_tokens = covert_to_tokens(subject)
-                object_sub_tokens = covert_to_tokens(object)
+                subject_sub_tokens = covert_to_tokens(subject, tokenizer=self.tokenizer)
+                object_sub_tokens = covert_to_tokens(object, tokenizer=self.tokenizer)
                 subject_start, object_start = search_spo_index(tokens, subject_sub_tokens, object_sub_tokens)
 
                 predicate_label = self.spo_conf[predicate]
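search_spo_index is defined elsewhere in the repository and not shown in this diff; a plausible sketch, assuming it returns the start offsets of the subject and object sub-token sequences inside tokens, with -1 for a span that cannot be found:

from typing import List, Tuple

def find_sublist(tokens: List[str], piece: List[str]) -> int:
    # Index of the first occurrence of piece inside tokens, or -1.
    n, m = len(tokens), len(piece)
    for i in range(n - m + 1):
        if tokens[i:i + m] == piece:
            return i
    return -1

def search_spo_index(tokens: List[str], subject_sub_tokens: List[str],
                     object_sub_tokens: List[str]) -> Tuple[int, int]:
    # Hypothetical reconstruction: locate both spans in the
    # [CLS]/[SEP]-wrapped sub-token sequence built earlier.
    return find_sublist(tokens, subject_sub_tokens), find_sublist(tokens, object_sub_tokens)

The fix in this hunk matters for exactly that matching step: passing tokenizer=self.tokenizer makes subject and object split the same way as the sentence itself, whereas the old calls fell back to the default tokenizer inside covert_to_tokens and could produce sub-tokens that never align with tokens.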
@@ -74,7 +74,7 @@ def get_args():
 
 def bulid_dataset(args, spo_config, reader, tokenizer, debug=False):
     train_src = args.input + "/train_data.json"
-    dev_src = args.input + "/test2_data.json"
+    dev_src = args.input + "/val_data.json"
 
     train_examples_file = args.cache_data + "/train-examples.pkl"
     dev_examples_file = args.cache_data + "/dev-examples.pkl"
@@ -126,7 +126,7 @@ def main():
     spo_conf = CMeIE_CONFIG if args.spo_version == 'v1' else None
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
     reader = Reader(spo_conf, tokenizer, max_seq_length=args.max_len)
-    eval_examples, data_loaders, tokenizer = bulid_dataset(args, spo_conf, reader, tokenizer, debug=False)
+    eval_examples, data_loaders, tokenizer = bulid_dataset(args, spo_conf, reader, tokenizer, debug=True)
     trainer = Trainer(args, data_loaders, eval_examples, spo_conf=spo_conf, tokenizer=tokenizer)
 
     if args.train_mode == "train":
@@ -5,11 +5,11 @@ from transformers import BertTokenizer
 from utils import extract_chinese_and_punct
 
 chineseandpunctuationextractor = extract_chinese_and_punct.ChineseAndPunctuationExtractor()
 
+moren_tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True)
+
 def covert_to_tokens(text, tokenizer=None, return_orig_index=False, max_seq_length=500):
     if not tokenizer:
-        tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert', do_lower_case=True)
+        tokenizer = moren_tokenizer
     sub_text = []
     buff = ""
     flag_en = False
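The motivation for the last hunk: covert_to_tokens previously constructed a fresh BertTokenizer from disk on every call that hit the fallback branch; hoisting it to the module-level moren_tokenizer ("moren" is pinyin for 默认, "default") loads the vocabulary once at import time. A minimal sketch of the same pattern as a lazy variant, under assumed names:

from functools import lru_cache
from typing import List, Optional
from transformers import BertTokenizer

@lru_cache(maxsize=1)
def _default_tokenizer() -> BertTokenizer:
    # Memoized fallback: the vocabulary is read from disk only on the
    # first call that actually needs it, then reused.
    return BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True)

def covert_to_tokens_sketch(text: str, tokenizer: Optional[BertTokenizer] = None) -> List[str]:
    # Same effect as the patched covert_to_tokens fallback, without paying
    # the load cost at import time when a tokenizer is always supplied.
    tokenizer = tokenizer or _default_tokenizer()
    return tokenizer.tokenize(text)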