From db2b8154319151c31fbddcb1c20fc7f319917937 Mon Sep 17 00:00:00 2001 From: loujie0822 Date: Fri, 14 Feb 2020 23:27:02 +0800 Subject: [PATCH 1/5] z --- .../__init__.py | 0 .../mpn/__init__.py | 0 .../mpn/data_loader.py | 462 ++++++++++++++++++ .../mpn/main.py | 171 +++++++ .../mpn/train.py | 295 +++++++++++ 5 files changed, 928 insertions(+) create mode 100644 run/entity_relation_jointed_extraction/__init__.py create mode 100644 run/entity_relation_jointed_extraction/mpn/__init__.py create mode 100644 run/entity_relation_jointed_extraction/mpn/data_loader.py create mode 100644 run/entity_relation_jointed_extraction/mpn/main.py create mode 100644 run/entity_relation_jointed_extraction/mpn/train.py diff --git a/run/entity_relation_jointed_extraction/__init__.py b/run/entity_relation_jointed_extraction/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/run/entity_relation_jointed_extraction/mpn/__init__.py b/run/entity_relation_jointed_extraction/mpn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/run/entity_relation_jointed_extraction/mpn/data_loader.py b/run/entity_relation_jointed_extraction/mpn/data_loader.py new file mode 100644 index 0000000..02dee4b --- /dev/null +++ b/run/entity_relation_jointed_extraction/mpn/data_loader.py @@ -0,0 +1,462 @@ +import codecs +import json +import logging +from collections import Counter +from functools import partial + +import jieba +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader +from tqdm import tqdm + +from config.baidu_spo_config import BAIDU_RELATION +from layers.encoders.transformers.bert.bert_tokenization import BertTokenizer +from utils.data_util import padding, _handle_pos_limit, find_position, spo_padding + + +class PredictObject(object): + def __init__(self, + object_name, + object_start, + object_end, + predict_type, + predict_type_id + ): + self.object_name = object_name + self.object_start = object_start + self.object_end = object_end + self.predict_type = predict_type + self.predict_type_id = predict_type_id + + +class Example(object): + def __init__(self, + p_id=None, + context=None, + bert_tokens=None, + sub_pos=None, + sub_entity_list=None, + relative_pos_start=None, + relative_pos_end=None, + po_list=None, + gold_answer=None): + self.p_id = p_id + self.context = context + self.bert_tokens = bert_tokens + self.sub_pos = sub_pos + self.sub_entity_list = sub_entity_list + self.relative_pos_start = relative_pos_start + self.relative_pos_end = relative_pos_end + self.po_list = po_list + self.gold_answer = gold_answer + + +class InputFeature(object): + + def __init__(self, + p_id=None, + passage_id=None, + token_type_id=None, + pos_start_id=None, + pos_end_id=None, + segment_id=None, + po_label=None, + s1=None, + s2=None): + self.p_id = p_id + self.passage_id = passage_id + self.token_type_id = token_type_id + self.pos_start_id = pos_start_id + self.pos_end_id = pos_end_id + self.segment_id = segment_id + self.po_label = po_label + self.s1 = s1 + self.s2 = s2 + + +class Reader(object): + def __init__(self, do_lowercase=False, seg_char=False, max_len=600): + + self.do_lowercase = do_lowercase + self.seg_char = seg_char + self.max_len = max_len + self.relation_config = BAIDU_RELATION + + if self.seg_char: + logging.info("seg_char...") + else: + logging.info("seg_word using jieba ...") + + def read_examples(self, filename, data_type): + logging.info("Generating {} examples...".format(data_type)) + return self._read(filename, data_type) + + def _data_process(self, filename, data_type='train'): + with codecs.open(filename, 'r') as f: + output_data = list() + for line in tqdm(f): + data_json = json.loads(line.strip()) + text = data_json['text'].lower() + sub_po_dict, sub_ent_list, spo_list = dict(), list(), list() + + for spo in data_json['spo_list']: + # TODO .strip('《》').strip() + subject_name = spo['subject'].lower().strip('《》').strip() + object_name = spo['object'].lower().strip('《》').strip() + s_start, s_end = find_position(subject_name, text) + o_start, o_end = find_position(object_name, text) + + if text[s_start:s_end] != subject_name: + print(subject_name) + subject_name = spo['subject'].lower().replace('》', '').replace('《', '') + s_start, s_end = find_position(subject_name, text) + if s_start != -1 and o_start != -1: + sub_ent_list.append((subject_name, s_start, s_end)) + spo_list.append((subject_name, spo['predicate'], object_name)) + if subject_name not in sub_po_dict: + sub_po_dict[subject_name] = {} + sub_po_dict[subject_name]['sub_pos'] = [s_start, s_end] + sub_po_dict[subject_name]['po_list'] = [ + {'predict': spo['predicate'], 'object': (object_name, o_start, o_end)}] + else: + sub_po_dict[subject_name]['po_list'].append( + {'predict': spo['predicate'], 'object': (object_name, o_start, o_end)}) + text_spo = dict() + text_spo['context'] = text + text_spo['sub_po_dict'] = sub_po_dict + text_spo['spo_list'] = list(set(spo_list)) + text_spo['sub_ent_list'] = list(set(sub_ent_list)) + output_data.append(text_spo) + + if data_type == 'train': + return self._convert_train_data(output_data) + return output_data + + @staticmethod + def _convert_train_data(src_data): + """ + 将train_data转化为满足训练要求的形式,即: + 1条数据为:一个subject对应响应的(predict,object)-->sub_po_dict + :param data: + :return: + """ + spo_data = [] + for data in src_data: + for sub_ent, po_dict in data['sub_po_dict'].items(): + data['sub_name'] = sub_ent + data['sub_pos'] = po_dict['sub_pos'] + data['po_list'] = po_dict['po_list'] + + spo_data.append(data) + return spo_data + + def _read(self, filename, data_type): + + data_set = self._data_process(filename, data_type) + logging.info("{} data_set total size is {} ".format(data_type, len(data_set))) + examples = [] + for p_id in tqdm(range(len(data_set))): + data = data_set[p_id] + para = data['context'] + context = para if self.seg_char else ''.join(jieba.lcut(para)) + if len(context) > self.max_len: + context = context[:self.max_len] + + if data_type == 'train': + + start, end = data['sub_pos'] + if start >= self.max_len or end >= self.max_len: + continue + assert data['sub_name'] == context[start:end] + + # pos_start&pos_end: 指句子中词语相对subject_entity的position(相对距离) + # 如:[-30, 30],embed 时整体+31,变成[1, 61] + # 则一共62个pos token,0 留给 pad + pos_start = list(map(lambda i: i - start, list(range(len(context))))) + pos_end = list(map(lambda i: i - end, list(range(len(context))))) + relative_pos_start = _handle_pos_limit(pos_start) + relative_pos_end = _handle_pos_limit(pos_end) + + po_list = [] + for predict_object in data['po_list']: + predict_type = predict_object['predict'] + object_ = predict_object['object'] + object_name, object_start, object_end = object_[0], object_[1], object_[2] + + if object_start >= self.max_len or object_end >= self.max_len: + continue + assert object_name == context[object_start:object_end] + + po_list.append(PredictObject( + object_name=object_name, + object_start=object_start, + object_end=object_end, + predict_type=predict_type, + predict_type_id=self.relation_config[predict_type] + )) + + examples.append( + Example( + p_id=p_id, + context=context, + sub_pos=data['sub_pos'], + sub_entity_list=data['sub_ent_list'], + relative_pos_start=relative_pos_start, + relative_pos_end=relative_pos_end, + po_list=po_list, + gold_answer=data['spo_list'] + ) + ) + else: + examples.append( + Example( + p_id=p_id, + context=context, + sub_pos=None, + sub_entity_list=data['sub_ent_list'], + relative_pos_start=None, + relative_pos_end=None, + po_list=None, + gold_answer=data['spo_list'] + ) + ) + + logging.info("{} total size is {} ".format(data_type, len(examples))) + + return examples + + +class Vocabulary(object): + + def __init__(self, special_tokens=["", ""]): + + self.char_vocab = None + self.emb_mat = None + self.char2idx = None + self.char_counter = Counter() + self.special_tokens = special_tokens + + def build_vocab_only_with_char(self, examples, min_char_count=-1): + + logging.info("Building vocabulary only with character...") + + self.char_vocab = [""] + + if self.special_tokens is not None and isinstance(self.special_tokens, list): + self.char_vocab.extend(self.special_tokens) + + for example in tqdm(examples): + for char in example.context: + self.char_counter[char] += 1 + + for w, v in self.char_counter.most_common(): + if v >= min_char_count: + self.char_vocab.append(w) + + self.char2idx = {token: idx for idx, token in enumerate(self.char_vocab)} + + logging.info("total char counter size is {} ".format(len(self.char_counter))) + logging.info("total char vocabulary size is {} ".format(len(self.char_vocab))) + + def _load_embedding(self, embedding_file, embedding_dict): + + with open(embedding_file) as f: + for line in f: + if len(line.rstrip().split(" ")) <= 2: continue + token, vector = line.rstrip().split(" ", 1) + embedding_dict[token] = np.fromstring(vector, dtype=np.float, sep=" ") + return embedding_dict + + def make_embedding(self, vocab, embedding_file, emb_size): + + embedding_dict = dict() + embedding_dict[""] = np.array([0. for _ in range(emb_size)]) + + self._load_embedding(embedding_file, embedding_dict) + + count = 0 + for token in tqdm(vocab): + if token not in embedding_dict: + count += 1 + embedding_dict[token] = np.array([np.random.normal(scale=0.1) for _ in range(emb_size)]) + logging.info( + "{} / {} tokens have corresponding in embedding vector".format(len(vocab) - count, len(vocab))) + + emb_mat = [embedding_dict[token] for idx, token in enumerate(vocab)] + + return emb_mat + + +class Feature(object): + def __init__(self, args, token2idx_dict): + self.bert = args.use_bert + self.token2idx_dict = token2idx_dict + if self.bert: + self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) + + def token2wid(self, token): + if token in self.token2idx_dict: + return self.token2idx_dict[token] + return self.token2idx_dict[""] + + def __call__(self, examples, entity_type, data_type): + + if self.bert: + return self.convert_examples_to_bert_features(examples, entity_type, data_type) + else: + return self.convert_examples_to_features(examples, data_type) + + def convert_examples_to_features(self, examples, data_type): + + logging.info("convert {} examples to features .".format(data_type)) + + examples2features = list() + for index, example in enumerate(examples): + + passage_id = np.zeros(len(example.context), dtype=np.int) + segment_id = np.zeros(len(example.context), dtype=np.int) + token_type_id = np.zeros(len(example.context), dtype=np.int) + pos_start_id = np.zeros(len(example.context), dtype=np.int) + pos_end_id = np.zeros(len(example.context), dtype=np.int) + s1 = np.zeros(len(example.context), dtype=np.float) + s2 = np.zeros(len(example.context), dtype=np.float) + + for (_, start, end) in example.sub_entity_list: + if start >= len(example.context) or end >= len(example.context): + continue + s1[start] = 1.0 + s2[end - 1] = 1.0 + + if data_type == 'train': + sub_start, sub_end = example.sub_pos[0], example.sub_pos[1] + for i, token in enumerate(example.context): + if sub_start <= i < sub_end: + # token = "" + token_type_id[i] = 1 + passage_id[i] = self.token2wid(token) + pos_start_id[i] = example.relative_pos_start[i] + pos_end_id[i] = example.relative_pos_end[i] + + examples2features.append( + InputFeature( + p_id=index, + passage_id=passage_id, + token_type_id=token_type_id, + pos_start_id=pos_start_id, + pos_end_id=pos_end_id, + segment_id=segment_id, + po_label=example.po_list, + s1=s1, + s2=s2 + + )) + + logging.info("Built instances is Completed") + return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), data_type=data_type) + + def convert_examples_to_bert_features(self, examples, entity_type, data_type): + + logging.info("Processing {} examples...".format(data_type)) + + examples2features = list() + for index, example in enumerate(examples): + + gold_attr_list = example.gold_attr_list + ent_start, ent_end = example.entity_position[0], example.entity_position[1] + segment_id = np.zeros(len(example.context) + 2, dtype=np.int) + token_type_id = np.zeros(len(example.context) + 2, dtype=np.int) + pos_start_id = np.zeros(len(example.context) + 2, dtype=np.int) + pos_end_id = np.zeros(len(example.context) + 2, dtype=np.int) + + tokens = ["[CLS]"] + raw_tokens = ["[CLS]"] + for i, token in enumerate(example.context): + raw_tokens.append(token) + if ent_start <= i < ent_end: + # token_type_id[i + 1] = 1 + # segment_id[i + 1] = 1 + token = '[unused1]' + tokens.append(token) + pos_start_id[i + 1] = example.pos_start[i] + pos_end_id[i + 1] = example.pos_end[i] + + tokens.append("[SEP]") + raw_tokens.append("[SEP]") + passage_id = self.tokenizer.convert_tokens_to_ids(tokens) + example.bert_tokens = raw_tokens + examples2features.append( + InputFeature( + p_id=index, + passage_id=passage_id, + token_type_id=token_type_id, + pos_start_id=pos_start_id, + pos_end_id=pos_end_id, + segment_id=segment_id, + po_label=gold_attr_list + )) + + logging.info("Built instances is Completed") + return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), use_bert=True) + + +class SPODataset(Dataset): + def __init__(self, features, predict_num, data_type, use_bert=False): + super(SPODataset, self).__init__() + self.use_bert = use_bert + self.is_train = True if data_type == 'train' else False + self.q_ids = [f.p_id for f in features] + self.passages = [f.passage_id for f in features] + self.segment_ids = [f.segment_id for f in features] + self.predict_num = predict_num + + if self.is_train: + self.token_type = [f.token_type_id for f in features] + self.pos_start_ids = [f.pos_start_id for f in features] + self.pos_end_ids = [f.pos_end_id for f in features] + self.s1 = [f.s1 for f in features] + self.s2 = [f.s2 for f in features] + self.po_label = [f.po_label for f in features] + + def __len__(self): + return len(self.passages) + + def __getitem__(self, index): + if self.is_train: + return self.q_ids[index], self.passages[index], self.segment_ids[index], self.token_type[index], \ + self.pos_start_ids[index], self.pos_end_ids[index], self.s1[index], self.s2[index], self.po_label[ + index] + else: + return self.q_ids[index], self.passages[index], self.segment_ids[index] + + def _create_collate_fn(self, batch_first=False): + def collate(examples): + if self.is_train: + p_ids, passages, segment_ids, token_type, pos_start_ids, pos_end_ids, s1, s2, label = zip(*examples) + + p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long) + passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first) + segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first) + + token_type_tensor, _ = padding(token_type, is_float=False, batch_first=batch_first) + pos_start_tensor, _ = padding(pos_start_ids, is_float=False, batch_first=batch_first) + pos_end_tensor, _ = padding(pos_end_ids, is_float=False, batch_first=batch_first) + s1_tensor, _ = padding(s1, is_float=True, batch_first=batch_first) + s2_tensor, _ = padding(s2, is_float=True, batch_first=batch_first) + po1_tensor, po2_tensor = spo_padding(passages, label, class_num=self.predict_num, is_float=True, + use_bert=self.use_bert) + return p_ids, passages_tensor, segment_tensor, token_type_tensor, pos_start_tensor, pos_end_tensor, \ + s1_tensor, s2_tensor, po1_tensor, po2_tensor + else: + p_ids, passages, segment_ids = zip(*examples) + p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long) + passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first) + segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first) + return p_ids, passages_tensor, segment_tensor + + return partial(collate) + + def get_dataloader(self, batch_size, num_workers=0, shuffle=False, batch_first=True, pin_memory=False, + drop_last=False): + return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(batch_first), + num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last) diff --git a/run/entity_relation_jointed_extraction/mpn/main.py b/run/entity_relation_jointed_extraction/mpn/main.py new file mode 100644 index 0000000..a7a957c --- /dev/null +++ b/run/entity_relation_jointed_extraction/mpn/main.py @@ -0,0 +1,171 @@ +# _*_ coding:utf-8 _*_ +import argparse +import logging +import os + +from config.baidu_spo_config import BAIDU_RELATION +from run.entity_relation_jointed_extraction.mpn.data_loader import Reader, Vocabulary, Feature +from run.entity_relation_jointed_extraction.mpn.train import Trainer +from utils.file_util import save, load + +logger = logging.getLogger() +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + +def get_args(): + parser = argparse.ArgumentParser() + + # file parameters + parser.add_argument("--input", default=None, type=str, required=True) + parser.add_argument("--output" + , default=None, type=str, required=False, + help="The output directory where the model checkpoints and predictions will be written.") + # "cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5" + # 'cpt/baidu_w2v/w2v.txt' + parser.add_argument('--embedding_file', type=str, + default='cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5') + + # choice parameters + parser.add_argument('--entity_type', type=str, default='disease') + parser.add_argument('--use_word2vec', type=bool, default=False) + parser.add_argument('--use_bert', type=bool, default=False) + parser.add_argument('--seg_char', type=bool, default=True) + + # train parameters + parser.add_argument('--train_mode', type=str, default="train") + parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.") + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument("--epoch_num", default=3, type=int, + help="Total number of training epochs to perform.") + parser.add_argument('--patience_stop', type=int, default=10, help='Patience for learning early stop') + parser.add_argument('--device_id', type=int, default=0) + parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") + + # bert parameters + parser.add_argument("--do_lower_case", + action='store_true', + help="Whether to lower case the input text. True for uncased models, False for cased models.") + parser.add_argument("--warmup_proportion", default=0.1, type=float, + help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% " + "of training.") + parser.add_argument("--bert_model", default=None, type=str, + help="Bert pre-trained model selected in the list: bert-base-uncased, " + "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, " + "bert-base-multilingual-cased, bert-base-chinese.") + + # model parameters + parser.add_argument("--max_len", default=1000, type=int) + parser.add_argument('--word_emb_size', type=int, default=300) + parser.add_argument('--char_emb_size', type=int, default=300) + parser.add_argument('--entity_emb_size', type=int, default=300) + parser.add_argument('--pos_limit', type=int, default=30) + parser.add_argument('--pos_dim', type=int, default=300) + parser.add_argument('--pos_size', type=int, default=62) + + parser.add_argument('--hidden_size', type=int, default=150) + parser.add_argument('--num_layers', type=int, default=2) + parser.add_argument('--dropout', type=int, default=0.5) + parser.add_argument('--rnn_encoder', type=str, default='lstm', help="must choose in blow: lstm or gru") + parser.add_argument('--bidirectional', type=bool, default=True) + parser.add_argument('--pin_memory', type=bool, default=False) + parser.add_argument('--transformer_layers', type=int, default=1) + parser.add_argument('--nhead', type=int, default=4) + parser.add_argument('--dim_feedforward', type=int, default=2048) + args = parser.parse_args() + if args.use_word2vec: + args.cache_data = args.input + '/char2v_cache_data/' + elif args.use_bert: + args.cache_data = args.input + '/char_bert_cache_data/' + else: + args.cache_data = args.input + '/char_cache_data/' + return args + + +def bulid_dataset(args, reader, vocab, debug=False): + char2idx, char_emb = None, None + train_src = args.input + "/train_data.json" + dev_src = args.input + "/dev_data.json" + + train_examples_file = args.cache_data + "/train-examples.pkl" + dev_examples_file = args.cache_data + "/dev-examples.pkl" + + char_emb_file = args.cache_data + "/char_emb.pkl" + char_dictionary = args.cache_data + "/char_dict.pkl" + + if not os.path.exists(train_examples_file): + + train_examples = reader.read_examples(train_src, data_type='train') + dev_examples = reader.read_examples(dev_src, data_type='dev') + + if not args.use_bert: + # todo : min_word_count=3 ? + vocab.build_vocab_only_with_char(train_examples, min_char_count=1) + if args.use_word2vec and args.embedding_file: + char_emb = vocab.make_embedding(vocab=vocab.char_vocab, + embedding_file=args.embedding_file, + emb_size=args.word_emb_size) + save(char_emb_file, char_emb, message="char embedding") + save(char_dictionary, vocab.char2idx, message="char dictionary") + char2idx = vocab.char2idx + save(train_examples_file, train_examples, message="train examples") + save(dev_examples_file, dev_examples, message="dev examples") + else: + if not args.use_bert: + if args.use_word2vec and args.embedding_file: + char_emb = load(char_emb_file) + char2idx = load(char_dictionary) + logging.info("total char vocabulary size is {} ".format(len(char2idx))) + train_examples, dev_examples = load(train_examples_file), load(dev_examples_file) + + logging.info('train examples size is {}'.format(len(train_examples))) + logging.info('dev examples size is {}'.format(len(dev_examples))) + + if not args.use_bert: + args.vocab_size = len(char2idx) + convert_examples_features = Feature(args, token2idx_dict=char2idx) + + train_examples = train_examples[:10] if debug else train_examples + dev_examples = dev_examples[:10] if debug else dev_examples + + train_data_set = convert_examples_features(train_examples, entity_type=args.entity_type, + data_type='train') + dev_data_set = convert_examples_features(dev_examples, entity_type=args.entity_type, + data_type='dev') + train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory) + dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size) + + data_loaders = train_data_loader, dev_data_loader + eval_examples = train_examples, dev_examples + + return eval_examples, data_loaders, char_emb + +def main(): + args = get_args() + if not os.path.exists(args.output): + print('mkdir {}'.format(args.output)) + os.makedirs(args.output) + if not os.path.exists(args.cache_data): + print('mkdir {}'.format(args.cache_data)) + os.makedirs(args.cache_data) + + logger.info("** ** * bulid dataset ** ** * ") + reader = Reader(seg_char=args.seg_char, max_len=args.max_len) + vocab = Vocabulary() + + eval_examples, data_loaders, char_emb = bulid_dataset(args, reader, vocab, debug=False) + + trainer = Trainer(args, data_loaders, eval_examples, char_emb, spo_conf=BAIDU_RELATION) + + if args.train_mode == "train": + trainer.train(args) + elif args.train_mode == "eval": + trainer.resume(args) + trainer.eval_data_set("train") + trainer.eval_data_set("dev") + elif args.train_mode == "resume": + # trainer.resume(args) + trainer.show("dev") # bad case analysis + + +if __name__ == '__main__': + main() diff --git a/run/entity_relation_jointed_extraction/mpn/train.py b/run/entity_relation_jointed_extraction/mpn/train.py new file mode 100644 index 0000000..cc87d8a --- /dev/null +++ b/run/entity_relation_jointed_extraction/mpn/train.py @@ -0,0 +1,295 @@ +# _*_ coding:utf-8 _*_ +import logging +import random +import sys +import time + +import numpy as np +import torch +from torch import nn +from tqdm import tqdm + +import models.attribute_extract_net.bert_mpn as bert_mpn +import models.attribute_extract_net.mpn as mpn +from utils.optimizer_util import set_optimizer + +logger = logging.getLogger(__name__) + + +class Trainer(object): + + def __init__(self, args, data_loaders, examples, char_emb, spo_conf): + if args.use_bert: + self.model = bert_mpn.AttributeExtractNet.from_pretrained(args.bert_model, args, spo_conf) + else: + self.model = mpn.AttributeExtractNet(args, char_emb, spo_conf) + + self.args = args + + self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu") + self.n_gpu = torch.cuda.device_count() + + self.id2rel = {item: key for key, item in attribute_conf.items()} + self.rel2id =attribute_conf + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if self.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + self.model.to(self.device) + # self.resume(args) + logging.info('total gpu num is {}'.format(self.n_gpu)) + if self.n_gpu > 1: + self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1]) + + train_dataloader, dev_dataloader = data_loaders + train_eval, dev_eval = examples + self.eval_file_choice = { + "train": train_eval, + "dev": dev_eval, + } + self.data_loader_choice = { + "train": train_dataloader, + "dev": dev_dataloader, + } + self.optimizer = set_optimizer(args, self.model, + train_steps=(int(len(train_eval) / args.train_batch_size) + 1) * args.epoch_num) + + def train(self, args): + + best_f1 = 0.0 + patience_stop = 0 + self.model.train() + step_gap = 20 + for epoch in range(int(args.epoch_num)): + + global_loss = 0.0 + + for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5, + desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout): + + loss, answer_dict_ = self.forward(batch) + + if step % step_gap == 0: + global_loss += loss + current_loss = global_loss / step_gap + print( + u"step {} / {} of epoch {}, train/loss: {}".format(step, len(self.data_loader_choice["train"]), + epoch, current_loss)) + global_loss = 0.0 + + res_dev = self.eval_data_set("dev") + if res_dev['f1'] >= best_f1: + best_f1 = res_dev['f1'] + logging.info("** ** * Saving fine-tuned model ** ** * ") + model_to_save = self.model.module if hasattr(self.model, + 'module') else self.model # Only save the model it-self + output_model_file = args.output + "/pytorch_model.bin" + torch.save(model_to_save.state_dict(), str(output_model_file)) + patience_stop = 0 + else: + patience_stop += 1 + if patience_stop >= args.patience_stop: + return + + def resume(self, args): + resume_model_file = args.output + "/pytorch_model.bin" + logging.info("=> loading checkpoint '{}'".format(resume_model_file)) + checkpoint = torch.load(resume_model_file, map_location='cpu') + self.model.load_state_dict(checkpoint) + + def forward(self, batch, chosen=u'train', grad=True, eval=False, detail=False): + + batch = tuple(t.to(self.device) for t in batch) + + p_ids, passage_id, token_type_id, segment_id, pos_start,pos_end,start_id, end_id = batch + loss, po1, po2 = self.model(passage_id=passage_id, token_type_id=token_type_id, segment_id=segment_id, + pos_start=pos_start,pos_end=pos_end,start_id=start_id, end_id=end_id, is_eval=eval) + + if self.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu. + + if grad: + loss.backward() + loss = loss.item() + self.optimizer.step() + self.optimizer.zero_grad() + + if eval: + eval_file = self.eval_file_choice[chosen] + answer_dict_ = convert_pointer_net_contour(eval_file, p_ids, po1, po2, self.id2rel, + use_bert=self.args.use_bert) + else: + answer_dict_ = None + return loss, answer_dict_ + + def eval_data_set(self, chosen="dev"): + + self.model.eval() + answer_dict = {} + + data_loader = self.data_loader_choice[chosen] + eval_file = self.eval_file_choice[chosen] + last_time = time.time() + with torch.no_grad(): + for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout): + loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True) + answer_dict.update(answer_dict_) + used_time = time.time() - last_time + logging.info('chosen {} took : {} sec'.format(chosen, used_time)) + res = self.evaluate(eval_file, answer_dict, chosen) + self.detail_evaluate(eval_file, answer_dict, chosen) + self.model.train() + return res + + def show(self, chosen="dev"): + + self.model.eval() + answer_dict = {} + + data_loader = self.data_loader_choice[chosen] + eval_file = self.eval_file_choice[chosen] + with torch.no_grad(): + for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout): + loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True, detail=True) + answer_dict.update(answer_dict_) + self.badcase_analysis(eval_file, answer_dict, chosen) + + @staticmethod + def evaluate(eval_file, answer_dict, chosen): + + em = 0 + pre = 0 + gold = 0 + for key, value in answer_dict.items(): + ground_truths = eval_file[int(key)].gold_answer + value, l1, l2 = value + prediction = list(value) if len(value) else [] + assert type(prediction) == type(ground_truths) + intersection = set(prediction) & set(ground_truths) + + em += len(intersection) + pre += len(set(prediction)) + gold += len(set(ground_truths)) + + precision = 100.0 * em / pre if pre > 0 else 0. + recall = 100.0 * em / gold if gold > 0 else 0. + f1 = 2 * recall * precision / (recall + precision) if (recall + precision) != 0 else 0.0 + print('============================================') + print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, em, pre, gold)) + print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1, precision, + recall)) + return {'f1': f1, "recall": recall, "precision": precision, 'em': em, 'pre': pre, 'gold': gold} + + def detail_evaluate(self, eval_file, answer_dict, chosen): + def generate_detail_dict(spo_list): + dict_detail = dict() + for i, tag in enumerate(spo_list): + detail_name = tag.split('@')[0] + if detail_name not in dict_detail: + dict_detail[detail_name] = [tag] + else: + dict_detail[detail_name].append(tag) + return dict_detail + + total_detail = {} + for key, value in answer_dict.items(): + ground_truths = eval_file[int(key)].gold_answer + value, l1, l2 = value + prediction = list(value) if len(value) else [] + + gold_detail = generate_detail_dict(ground_truths) + pred_detail = generate_detail_dict(prediction) + for key in self.rel2id.keys(): + + pred = pred_detail.get(key, []) + gold = gold_detail.get(key, []) + em = len(set(pred) & set(gold)) + pred_num = len(set(pred)) + gold_num = len(set(gold)) + + if key not in total_detail: + total_detail[key] = dict() + total_detail[key]['em'] = em + total_detail[key]['pred_num'] = pred_num + total_detail[key]['gold_num'] = gold_num + else: + total_detail[key]['em'] += em + total_detail[key]['pred_num'] += pred_num + total_detail[key]['gold_num'] += gold_num + for key, res_dict_ in total_detail.items(): + res_dict_['p'] = 100.0 * res_dict_['em'] / res_dict_['pred_num'] if res_dict_['pred_num'] != 0 else 0.0 + res_dict_['r'] = 100.0 * res_dict_['em'] / res_dict_['gold_num'] if res_dict_['gold_num'] != 0 else 0.0 + res_dict_['f'] = 2 * res_dict_['p'] * res_dict_['r'] / (res_dict_['p'] + res_dict_['r']) if res_dict_['p'] + \ + res_dict_[ + 'r'] != 0 else 0.0 + + for gold_key, res_dict_ in total_detail.items(): + print('===============================================================') + print("{}/em: {},\tpred_num&gold_num: {}\t{} ".format(gold_key, res_dict_['em'], res_dict_['pred_num'], + res_dict_['gold_num'])) + print( + "{}/f1: {},\tprecison&recall: {}\t{}".format(gold_key, res_dict_['f'], res_dict_['p'], res_dict_['r'])) + + @staticmethod + def badcase_analysis(eval_file, answer_dict, chosen): + em = 0 + pre = 0 + gold = 0 + content = '' + for key, value in answer_dict.items(): + entity_name = eval_file[int(key)].entity_name + context = eval_file[int(key)].context + ground_truths = eval_file[int(key)].gold_answer + value, l1, l2 = value + prediction = list(value) if len(value) else [''] + assert type(prediction) == type(ground_truths) + + intersection = set(prediction) & set(ground_truths) + + if prediction == ground_truths == ['']: + continue + if set(prediction) != set(ground_truths): + ground_truths = list(sorted(set(ground_truths))) + prediction = list(sorted(set(prediction))) + print('raw context is:\t' + context) + print('subject_name is:\t' + entity_name) + print('pred_text is:\t' + '\t'.join(prediction)) + print('gold_text is:\t' + '\t'.join(ground_truths)) + content += 'raw context is:\t' + context + '\n' + content += 'subject_name is:\t' + entity_name + '\n' + content += 'pred_text is:\t' + '\t'.join(prediction) + '\n' + content += 'gold_text is:\t' + '\t'.join(ground_truths) + '\n' + content += '===============================' + em += len(intersection) + pre += len(set(prediction)) + gold += len(set(ground_truths)) + with open('badcase_{}.txt'.format(chosen), 'w') as f: + f.write(content) + + +def convert_pointer_net_contour(eval_file, q_ids, po1, po2, id2rel, use_bert=False): + answer_dict = dict() + for qid, o1, o2 in zip(q_ids, po1.data.cpu().numpy(), po2.data.cpu().numpy()): + + context = eval_file[qid.item()].context if not use_bert else eval_file[qid.item()].bert_tokens + gold_attr_list = eval_file[qid.item()].gold_attr_list + gold_answer = [attr.attr_type + '@' + attr.value for attr in gold_attr_list] + + answers = list() + start, end = np.where(o1 > 0.5), np.where(o2 > 0.5) + for _start, _attr_type_id_start in zip(*start): + if _start > len(context) or (_start == 0 and use_bert): + continue + for _end, _attr_type_id_end in zip(*end): + if _start <= _end < len(context) and _attr_type_id_start == _attr_type_id_end: + _attr_value = ''.join(context[_start: _end + 1]) if use_bert else context[_start: _end + 1] + _attr_type = id2rel[_attr_type_id_start] + _attr = _attr_type + '@' + _attr_value + answers.append(_attr) + break + + answer_dict[str(qid.item())] = [answers, o1, o2] + + return answer_dict From 3e0e4a65d6afd3dd8ec50a04b67197a3e2772eba Mon Sep 17 00:00:00 2001 From: loujie0822 Date: Sat, 15 Feb 2020 18:24:17 +0800 Subject: [PATCH 2/5] z --- .gitignore | 2 +- pyscripts/__init__.py | 0 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 pyscripts/__init__.py diff --git a/.gitignore b/.gitignore index 1a77509..ff77dde 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ settings.py instance/ data/ - +pyscripts/ .pytest_cache/ .coverage diff --git a/pyscripts/__init__.py b/pyscripts/__init__.py new file mode 100644 index 0000000..e69de29 From 1a3811d8d0ce115443789302a8756dfdfd62588a Mon Sep 17 00:00:00 2001 From: loujie0822 Date: Mon, 17 Feb 2020 13:12:36 +0800 Subject: [PATCH 3/5] z --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index ff77dde..e50ed84 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ settings.py instance/ data/ +cpt/ pyscripts/ .pytest_cache/ .coverage From dd6dbb0f35950c3fa5393900c2733f443c110d9f Mon Sep 17 00:00:00 2001 From: loujie0822 Date: Mon, 17 Feb 2020 13:14:58 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AE=9E=E4=BD=93?= =?UTF-8?q?=E5=85=B3=E7=B3=BB=E8=81=94=E5=90=88=E6=8A=BD=E5=8F=96=E6=9C=BA?= =?UTF-8?q?=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/ere_net/bert_mpn.py | 232 ++++++++++++++++ models/ere_net/mpn.py | 248 ++++++++++++++++++ .../mpn/{data.py => data_loader.py} | 0 run/attribute_extract/mpn/main.py | 2 +- .../mpn/data_loader.py | 59 +++-- .../mpn/main.py | 12 +- .../mpn/train.py | 198 ++++++-------- 7 files changed, 607 insertions(+), 144 deletions(-) create mode 100644 models/ere_net/bert_mpn.py create mode 100644 models/ere_net/mpn.py rename run/attribute_extract/mpn/{data.py => data_loader.py} (100%) diff --git a/models/ere_net/bert_mpn.py b/models/ere_net/bert_mpn.py new file mode 100644 index 0000000..3f1a2b0 --- /dev/null +++ b/models/ere_net/bert_mpn.py @@ -0,0 +1,232 @@ +# _*_ coding:utf-8 _*_ +import warnings + +import numpy as np +import torch +import torch.nn as nn + +from layers.encoders.transformers.bert.bert_model import BertModel + +warnings.filterwarnings("ignore") +from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer + + +class EntityNET(nn.Module): + """ + ERENet : entity relation extraction + """ + + def __init__(self, args): + super(EntityNET, self).__init__() + + self.sb1 = nn.Linear(args.bert_hidden_size, 1) + self.sb2 = nn.Linear(args.bert_hidden_size, 1) + + def forward(self, sent_encoder, q_ids=None, eval_file=None, passages=None, s1=None, s2=None, is_eval=False): + + sequence_mask = passages != 0 + sb1 = self.sb1(sent_encoder).squeeze() + sb2 = self.sb2(sent_encoder).squeeze() + + if not is_eval: + loss_fct = nn.BCEWithLogitsLoss(reduction='none') + + sb1_loss = loss_fct(sb1, s1) + s1_loss = torch.sum(sb1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) + + s2_loss = loss_fct(sb2, s2) + s2_loss = torch.sum(s2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) + + ent_loss = s1_loss + s2_loss + return ent_loss + else: + answer_list = self.predict(eval_file, q_ids, sb1, sb2) + return answer_list + + def predict(self, eval_file, q_ids=None, sb1=None, sb2=None): + answer_list = list() + for qid, p1, p2 in zip(q_ids.cpu().numpy(), + sb1.cpu().numpy(), + sb2.cpu().numpy()): + + context = eval_file[qid].context + start = None + end = None + threshold = 0.0 + positions = list() + for idx in range(0, len(context)): + if idx == 0: + continue + if p1[idx] > threshold and start is None: + start = idx + if p2[idx] > threshold and end is None: + end = idx + if start is not None and end is not None and start <= end: + positions.append((start, end + 1)) + start = None + end = None + answer_list.append(positions) + + return answer_list + + +class RelNET(nn.Module): + """ + ERENet : entity relation extraction + """ + + def __init__(self, args, spo_conf): + super(RelNET, self).__init__() + self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.bert_hidden_size, + padding_idx=0) + self.encoder_layer = TransformerEncoderLayer(args.bert_hidden_size, args.nhead) + self.transformer_encoder = TransformerEncoder(self.encoder_layer, args.transformer_layers) + + self.classes_num = len(spo_conf) + self.ob1 = nn.Linear(args.bert_hidden_size, self.classes_num) + self.ob2 = nn.Linear(args.bert_hidden_size, self.classes_num) + + def forward(self, passages=None, sent_encoder=None, posit_ids=None, o1=None, o2=None, is_eval=False): + mask = passages.eq(0) + + subject_encoder = sent_encoder + self.token_entity_emb(posit_ids) + + subject_encoder = torch.transpose(subject_encoder, 1, 0) + transformer_encoder = self.transformer_encoder(subject_encoder, src_key_padding_mask=mask) + transformer_encoder = torch.transpose(transformer_encoder, 0, 1) + + po1 = self.ob1(transformer_encoder) + po2 = self.ob2(transformer_encoder) + + if not is_eval: + loss_fct = nn.BCEWithLogitsLoss(reduction='none') + + sequence_mask = passages != 0 + + s1_loss = loss_fct(po1, o1) + s1_loss = torch.sum(s1_loss, 2) + s1_loss = torch.sum(s1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num + + s2_loss = loss_fct(po2, o2) + s2_loss = torch.sum(s2_loss, 2) + s2_loss = torch.sum(s2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num + + rel_loss = s1_loss + s2_loss + + return rel_loss + + else: + po1 = nn.Sigmoid()(po1) + po2 = nn.Sigmoid()(po2) + return po1, po2 + + +class ERENet(nn.Module): + """ + ERENet : entity relation extraction + """ + + def __init__(self, args, spo_conf): + super(ERENet, self).__init__() + print('joint entity relation extraction') + self.bert_encoder = BertModel.from_pretrained(args.bert_model) + self.entity_extraction = EntityNET(args) + self.rel_extraction = RelNET(args, spo_conf) + + def forward(self, q_ids=None, eval_file=None, passages=None, token_type_ids=None, segment_ids=None, s1=None, + s2=None, po1=None, po2=None, is_eval=False): + + sequence_mask = passages != 0 + sent_encoder, _ = self.bert_encoder(passages, token_type_ids=segment_ids, attention_mask=sequence_mask, + output_all_encoded_layers=False) + + if not is_eval: + # entity_extraction + ent_loss = self.entity_extraction(sent_encoder, passages=passages, s1=s1, s2=s2, + is_eval=is_eval) + + # rel_extraction + rel_loss = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, posit_ids=token_type_ids, + o1=po1, + o2=po2, is_eval=False) + + # add total loss + total_loss = ent_loss + rel_loss + + return total_loss + + + else: + + answer_list = self.entity_extraction(sent_encoder, q_ids=q_ids, eval_file=eval_file, + passages=passages, is_eval=is_eval) + start_list, end_list = list(), list() + qid_list, pass_list, posit_list, sent_list = list(), list(), list(), list() + for i, ans_list in enumerate(answer_list): + seq_len = passages.size(1) + posit_ids = [] + for ans_tuple in ans_list: + posit_array = np.zeros(seq_len, dtype=np.int) + start, end = ans_tuple[0], ans_tuple[1] + start_list.append(start) + end_list.append(end) + posit_array[start:end] = 1 + posit_ids.append(posit_array) + + if len(posit_ids) == 0: + continue + qid_ = q_ids[i].unsqueeze(0).expand(len(posit_ids)) + sent_tensor = sent_encoder[i, :, :].unsqueeze(0).expand(len(posit_ids), sent_encoder.size(1), + sent_encoder.size(2)) + pass_tensor = passages[i, :].unsqueeze(0).expand(len(posit_ids), passages.size(1)) + posit_tensor = torch.tensor(posit_ids, dtype=torch.long).to(sent_encoder.device) + + qid_list.append(qid_) + pass_list.append(pass_tensor) + posit_list.append(posit_tensor) + sent_list.append(sent_tensor) + + if len(qid_list) == 0: + qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device) + return qid_tensor, qid_tensor, qid_tensor, qid_tensor, qid_tensor + + qid_tensor = torch.cat(qid_list).to(sent_encoder.device) + sent_tensor = torch.cat(sent_list).to(sent_encoder.device) + pass_tensor = torch.cat(pass_list).to(sent_encoder.device) + posi_tensor = torch.cat(posit_list).to(sent_encoder.device) + + flag = False + split_heads = 1024 + + inputs = torch.split(pass_tensor, split_heads, dim=0) + posits = torch.split(posi_tensor, split_heads, dim=0) + sents = torch.split(sent_tensor, split_heads, dim=0) + + po1_list, po2_list = list(), list() + for i in range(len(inputs)): + passages = inputs[i] + sent_encoder = sents[i] + posit_ids = posits[i] + + if passages.size(0) == 1: + flag = True + passages = passages.expand(2, passages.size(1)) + sent_encoder = sent_encoder.expand(2, sent_encoder.size(1), sent_encoder.size(2)) + posit_ids = posit_ids.expand(2, posit_ids.size(1)) + + po1, po2 = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, posit_ids=posit_ids, + is_eval=is_eval) + if flag: + po1 = po1[1, :, :].unsqueeze(0) + po2 = po2[1, :, :].unsqueeze(0) + + po1_list.append(po1) + po2_list.append(po2) + + po1_tensor = torch.cat(po1_list).to(sent_encoder.device) + po2_tensor = torch.cat(po2_list).to(sent_encoder.device) + + s_tensor = torch.tensor(start_list, dtype=torch.long).to(sent_encoder.device) + e_tensor = torch.tensor(end_list, dtype=torch.long).to(sent_encoder.device) + + return qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor diff --git a/models/ere_net/mpn.py b/models/ere_net/mpn.py new file mode 100644 index 0000000..0d30656 --- /dev/null +++ b/models/ere_net/mpn.py @@ -0,0 +1,248 @@ +import warnings + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer + +from layers.encoders.rnns.stacked_rnn import StackedBRNN + +warnings.filterwarnings("ignore") + + +class SentenceEncoder(nn.Module): + def __init__(self, args, input_size): + super(SentenceEncoder, self).__init__() + rnn_type = nn.LSTM if args.rnn_encoder == 'lstm' else nn.GRU + self.encoder = StackedBRNN( + input_size=input_size, + hidden_size=args.hidden_size, + num_layers=args.num_layers, + dropout_rate=args.dropout, + dropout_output=True, + concat_layers=False, + rnn_type=rnn_type, + padding=True + ) + + def forward(self, input, mask): + return self.encoder(input, mask) + + +class EntityNET(nn.Module): + """ + EntityNET : entity extraction using pointer network + """ + + def __init__(self, args, char_emb): + super(EntityNET, self).__init__() + + if char_emb is not None: + self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False, + padding_idx=0) + else: + self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size, + padding_idx=0) + + self.sentence_encoder = SentenceEncoder(args, args.word_emb_size) + self.s1 = nn.Linear(args.hidden_size * 2, 1) + self.s2 = nn.Linear(args.hidden_size * 2, 1) + + def forward(self, q_ids=None, eval_file=None, passages=None, s1=None, s2=None, is_eval=False): + mask = passages.eq(0) + sequence_mask = passages != 0 + + char_emb = self.char_emb(passages) + + sent_encoder = self.sentence_encoder(char_emb, mask) + + s1_ = self.s1(sent_encoder).squeeze() + s2_ = self.s2(sent_encoder).squeeze() + + if not is_eval: + loss_fct = nn.BCEWithLogitsLoss(reduction='none') + + sb1_loss = loss_fct(s1_, s1) + s1_loss = torch.sum(sb1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) + + s2_loss = loss_fct(s2_, s2) + s2_loss = torch.sum(s2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) + + ent_loss = s1_loss + s2_loss + return sent_encoder, ent_loss + else: + answer_list = self.predict(eval_file, q_ids, s1_, s2_) + return sent_encoder, answer_list + + def predict(self, eval_file, q_ids=None, s1=None, s2=None): + sub_ans_list = list() + for qid, p1, p2 in zip(q_ids.cpu().numpy(), + s1.cpu().numpy(), + s2.cpu().numpy()): + start = None + end = None + threshold = 0.0 + positions = list() + for idx in range(0, len(eval_file[qid].context)): + if p1[idx] > threshold and start is None: + start = idx + if p2[idx] > threshold and end is None: + end = idx + if start is not None and end is not None and start <= end: + positions.append((start, end + 1)) + start = None + end = None + sub_ans_list.append(positions) + + return sub_ans_list + + +class RelNET(nn.Module): + """ + ERENet : entity relation extraction + """ + + def __init__(self, args, spo_conf): + super(RelNET, self).__init__() + self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size, + padding_idx=0) + self.sentence_encoder = SentenceEncoder(args, args.word_emb_size) + self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead) + + self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers) + + self.classes_num = len(spo_conf) + self.po1 = nn.Linear(args.hidden_size * 2, self.classes_num) + self.po2 = nn.Linear(args.hidden_size * 2, self.classes_num) + + def forward(self, passages=None, sent_encoder=None, token_type_id=None, po1=None, po2=None, is_eval=False): + mask = passages.eq(0) + sequence_mask = passages != 0 + + subject_encoder = sent_encoder + self.token_entity_emb(token_type_id) + sent_sub_aware_encoder = self.sentence_encoder(subject_encoder, mask).transpose(1, 0) + + transformer_encoder = self.transformer_encoder(sent_sub_aware_encoder, src_key_padding_mask=mask).transpose(0, + 1) + + po1_ = self.po1(transformer_encoder) + po2_ = self.po2(transformer_encoder) + + if not is_eval: + loss_fct = nn.BCEWithLogitsLoss(reduction='none') + + po1_loss = loss_fct(po1_, po1) + po1_loss = torch.sum(po1_loss, 2) + po1_loss = torch.sum(po1_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num + + po2_loss = loss_fct(po2_, po2) + po2_loss = torch.sum(po2_loss, 2) + po2_loss = torch.sum(po2_loss * sequence_mask.float()) / torch.sum(sequence_mask.float()) / self.classes_num + + rel_loss = po1_loss + po2_loss + + return rel_loss + + else: + po1 = nn.Sigmoid()(po1_) + po2 = nn.Sigmoid()(po2_) + return po1, po2 + + +class ERENet(nn.Module): + """ + ERENet : entity relation jointed extraction with Multi-label Pointer Network(MPN) based Entity-aware + """ + + def __init__(self, args, char_emb, spo_conf): + super(ERENet, self).__init__() + print('joint entity relation extraction') + self.entity_extraction = EntityNET(args, char_emb) + self.rel_extraction = RelNET(args, spo_conf) + + def forward(self, q_ids=None, eval_file=None, passages=None, token_type_ids=None, segment_ids=None, s1=None, + s2=None, po1=None, po2=None, is_eval=False): + + if not is_eval: + sent_encoder, ent_loss = self.entity_extraction(passages=passages, s1=s1, s2=s2, is_eval=is_eval) + rel_loss = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, token_type_id=token_type_ids, + po1=po1, po2=po2, is_eval=False) + total_loss = ent_loss + rel_loss + + return total_loss + else: + + sent_encoder, answer_list = self.entity_extraction(q_ids=q_ids, eval_file=eval_file, + passages=passages, is_eval=is_eval) + start_list, end_list = list(), list() + qid_list, pass_list, posit_list, sent_list = list(), list(), list(), list() + for i, ans_list in enumerate(answer_list): + seq_len = passages.size(1) + posit_ids = [] + for ans_tuple in ans_list: + posit_array = np.zeros(seq_len, dtype=np.int) + start, end = ans_tuple[0], ans_tuple[1] + start_list.append(start) + end_list.append(end) + posit_array[start:end] = 1 + posit_ids.append(posit_array) + + if len(posit_ids) == 0: + continue + qid_ = q_ids[i].unsqueeze(0).expand(len(posit_ids)) + sent_tensor = sent_encoder[i, :, :].unsqueeze(0).expand(len(posit_ids), sent_encoder.size(1), + sent_encoder.size(2)) + pass_tensor = passages[i, :].unsqueeze(0).expand(len(posit_ids), passages.size(1)) + posit_tensor = torch.tensor(posit_ids, dtype=torch.long).to(sent_encoder.device) + + qid_list.append(qid_) + pass_list.append(pass_tensor) + posit_list.append(posit_tensor) + sent_list.append(sent_tensor) + + if len(qid_list) == 0: + # print('len(qid_list)==0:') + qid_tensor = torch.tensor([-1, -1], dtype=torch.long).to(sent_encoder.device) + return qid_tensor, qid_tensor, qid_tensor, qid_tensor, qid_tensor + + qid_tensor = torch.cat(qid_list).to(sent_encoder.device) + sent_tensor = torch.cat(sent_list).to(sent_encoder.device) + pass_tensor = torch.cat(pass_list).to(sent_encoder.device) + posi_tensor = torch.cat(posit_list).to(sent_encoder.device) + + flag = False + split_heads = 1024 + + inputs = torch.split(pass_tensor, split_heads, dim=0) + posits = torch.split(posi_tensor, split_heads, dim=0) + sents = torch.split(sent_tensor, split_heads, dim=0) + + po1_list, po2_list = list(), list() + for i in range(len(inputs)): + passages = inputs[i] + sent_encoder = sents[i] + posit_ids = posits[i] + + if passages.size(0) == 1: + flag = True + # print('flag = True**********') + passages = passages.expand(2, passages.size(1)) + sent_encoder = sent_encoder.expand(2, sent_encoder.size(1), sent_encoder.size(2)) + posit_ids = posit_ids.expand(2, posit_ids.size(1)) + + po1, po2 = self.rel_extraction(passages=passages, sent_encoder=sent_encoder, token_type_id=posit_ids, + is_eval=is_eval) + if flag: + po1 = po1[1, :, :].unsqueeze(0) + po2 = po2[1, :, :].unsqueeze(0) + + po1_list.append(po1) + po2_list.append(po2) + + po1_tensor = torch.cat(po1_list).to(sent_encoder.device) + po2_tensor = torch.cat(po2_list).to(sent_encoder.device) + + s_tensor = torch.tensor(start_list, dtype=torch.long).to(sent_encoder.device) + e_tensor = torch.tensor(end_list, dtype=torch.long).to(sent_encoder.device) + + return qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor diff --git a/run/attribute_extract/mpn/data.py b/run/attribute_extract/mpn/data_loader.py similarity index 100% rename from run/attribute_extract/mpn/data.py rename to run/attribute_extract/mpn/data_loader.py diff --git a/run/attribute_extract/mpn/main.py b/run/attribute_extract/mpn/main.py index 7c141fa..aba4db6 100644 --- a/run/attribute_extract/mpn/main.py +++ b/run/attribute_extract/mpn/main.py @@ -3,7 +3,7 @@ import argparse import logging import os -from run.attribute_extract.mpn.data import Reader, Vocabulary, config, Feature +from run.attribute_extract.mpn.data_loader import Reader, Vocabulary, config, Feature from run.attribute_extract.mpn.train import Trainer from utils.file_util import save, load diff --git a/run/entity_relation_jointed_extraction/mpn/data_loader.py b/run/entity_relation_jointed_extraction/mpn/data_loader.py index 02dee4b..eafeef7 100644 --- a/run/entity_relation_jointed_extraction/mpn/data_loader.py +++ b/run/entity_relation_jointed_extraction/mpn/data_loader.py @@ -93,8 +93,9 @@ class Reader(object): return self._read(filename, data_type) def _data_process(self, filename, data_type='train'): + output_data = list() with codecs.open(filename, 'r') as f: - output_data = list() + gold_num = 0 for line in tqdm(f): data_json = json.loads(line.strip()) text = data_json['text'].lower() @@ -108,12 +109,13 @@ class Reader(object): o_start, o_end = find_position(object_name, text) if text[s_start:s_end] != subject_name: - print(subject_name) + # print(subject_name) subject_name = spo['subject'].lower().replace('》', '').replace('《', '') s_start, s_end = find_position(subject_name, text) if s_start != -1 and o_start != -1: sub_ent_list.append((subject_name, s_start, s_end)) spo_list.append((subject_name, spo['predicate'], object_name)) + if subject_name not in sub_po_dict: sub_po_dict[subject_name] = {} sub_po_dict[subject_name]['sub_pos'] = [s_start, s_end] @@ -127,10 +129,12 @@ class Reader(object): text_spo['sub_po_dict'] = sub_po_dict text_spo['spo_list'] = list(set(spo_list)) text_spo['sub_ent_list'] = list(set(sub_ent_list)) + gold_num += len(set(spo_list)) output_data.append(text_spo) if data_type == 'train': return self._convert_train_data(output_data) + # print(f'total gold num is {gold_num}') return output_data @staticmethod @@ -300,10 +304,10 @@ class Feature(object): return self.token2idx_dict[token] return self.token2idx_dict[""] - def __call__(self, examples, entity_type, data_type): + def __call__(self, examples, data_type): if self.bert: - return self.convert_examples_to_bert_features(examples, entity_type, data_type) + return self.convert_examples_to_bert_features(examples, data_type) else: return self.convert_examples_to_features(examples, data_type) @@ -328,13 +332,15 @@ class Feature(object): s1[start] = 1.0 s2[end - 1] = 1.0 + for i, token in enumerate(example.context): + passage_id[i] = self.token2wid(token) + if data_type == 'train': sub_start, sub_end = example.sub_pos[0], example.sub_pos[1] for i, token in enumerate(example.context): if sub_start <= i < sub_end: # token = "" token_type_id[i] = 1 - passage_id[i] = self.token2wid(token) pos_start_id[i] = example.relative_pos_start[i] pos_end_id[i] = example.relative_pos_end[i] @@ -355,36 +361,41 @@ class Feature(object): logging.info("Built instances is Completed") return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), data_type=data_type) - def convert_examples_to_bert_features(self, examples, entity_type, data_type): + def convert_examples_to_bert_features(self, examples, data_type): logging.info("Processing {} examples...".format(data_type)) examples2features = list() for index, example in enumerate(examples): - gold_attr_list = example.gold_attr_list - ent_start, ent_end = example.entity_position[0], example.entity_position[1] segment_id = np.zeros(len(example.context) + 2, dtype=np.int) token_type_id = np.zeros(len(example.context) + 2, dtype=np.int) pos_start_id = np.zeros(len(example.context) + 2, dtype=np.int) pos_end_id = np.zeros(len(example.context) + 2, dtype=np.int) + s1 = np.zeros(len(example.context) + 2, dtype=np.float) + s2 = np.zeros(len(example.context) + 2, dtype=np.float) + + for (_, start, end) in example.sub_entity_list: + if start >= len(example.context) or end >= len(example.context): + continue + s1[start + 1] = 1.0 + s2[end] = 1.0 tokens = ["[CLS]"] - raw_tokens = ["[CLS]"] for i, token in enumerate(example.context): - raw_tokens.append(token) - if ent_start <= i < ent_end: - # token_type_id[i + 1] = 1 - # segment_id[i + 1] = 1 - token = '[unused1]' tokens.append(token) - pos_start_id[i + 1] = example.pos_start[i] - pos_end_id[i + 1] = example.pos_end[i] - tokens.append("[SEP]") - raw_tokens.append("[SEP]") passage_id = self.tokenizer.convert_tokens_to_ids(tokens) - example.bert_tokens = raw_tokens + example.bert_tokens = tokens + + if data_type == 'train': + sub_start, sub_end = example.sub_pos[0], example.sub_pos[1] + for i, token in enumerate(example.context): + if sub_start <= i < sub_end: + token_type_id[i + 1] = 1 + pos_start_id[i + 1] = example.relative_pos_start[i] + pos_end_id[i + 1] = example.relative_pos_end[i] + examples2features.append( InputFeature( p_id=index, @@ -393,11 +404,13 @@ class Feature(object): pos_start_id=pos_start_id, pos_end_id=pos_end_id, segment_id=segment_id, - po_label=gold_attr_list + po_label=example.po_list, + s1=s1, + s2=s2 )) logging.info("Built instances is Completed") - return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), use_bert=True) + return SPODataset(examples2features, predict_num=len(BAIDU_RELATION), use_bert=True,data_type=data_type) class SPODataset(Dataset): @@ -445,8 +458,8 @@ class SPODataset(Dataset): s2_tensor, _ = padding(s2, is_float=True, batch_first=batch_first) po1_tensor, po2_tensor = spo_padding(passages, label, class_num=self.predict_num, is_float=True, use_bert=self.use_bert) - return p_ids, passages_tensor, segment_tensor, token_type_tensor, pos_start_tensor, pos_end_tensor, \ - s1_tensor, s2_tensor, po1_tensor, po2_tensor + return p_ids, passages_tensor, segment_tensor, token_type_tensor, s1_tensor, s2_tensor, po1_tensor, \ + po2_tensor else: p_ids, passages, segment_ids = zip(*examples) p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long) diff --git a/run/entity_relation_jointed_extraction/mpn/main.py b/run/entity_relation_jointed_extraction/mpn/main.py index a7a957c..6714f17 100644 --- a/run/entity_relation_jointed_extraction/mpn/main.py +++ b/run/entity_relation_jointed_extraction/mpn/main.py @@ -63,6 +63,7 @@ def get_args(): parser.add_argument('--pos_size', type=int, default=62) parser.add_argument('--hidden_size', type=int, default=150) + parser.add_argument('--bert_hidden_size', type=int, default=768) parser.add_argument('--num_layers', type=int, default=2) parser.add_argument('--dropout', type=int, default=0.5) parser.add_argument('--rnn_encoder', type=str, default='lstm', help="must choose in blow: lstm or gru") @@ -127,10 +128,8 @@ def bulid_dataset(args, reader, vocab, debug=False): train_examples = train_examples[:10] if debug else train_examples dev_examples = dev_examples[:10] if debug else dev_examples - train_data_set = convert_examples_features(train_examples, entity_type=args.entity_type, - data_type='train') - dev_data_set = convert_examples_features(dev_examples, entity_type=args.entity_type, - data_type='dev') + train_data_set = convert_examples_features(train_examples, data_type='train') + dev_data_set = convert_examples_features(dev_examples, data_type='dev') train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory) dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size) @@ -139,6 +138,7 @@ def bulid_dataset(args, reader, vocab, debug=False): return eval_examples, data_loaders, char_emb + def main(): args = get_args() if not os.path.exists(args.output): @@ -159,8 +159,8 @@ def main(): if args.train_mode == "train": trainer.train(args) elif args.train_mode == "eval": - trainer.resume(args) - trainer.eval_data_set("train") + # trainer.resume(args) + # trainer.eval_data_set("train") trainer.eval_data_set("dev") elif args.train_mode == "resume": # trainer.resume(args) diff --git a/run/entity_relation_jointed_extraction/mpn/train.py b/run/entity_relation_jointed_extraction/mpn/train.py index cc87d8a..5d81a6c 100644 --- a/run/entity_relation_jointed_extraction/mpn/train.py +++ b/run/entity_relation_jointed_extraction/mpn/train.py @@ -9,8 +9,8 @@ import torch from torch import nn from tqdm import tqdm -import models.attribute_extract_net.bert_mpn as bert_mpn -import models.attribute_extract_net.mpn as mpn +import models.ere_net.bert_mpn as bert_mpn +import models.ere_net.mpn as mpn from utils.optimizer_util import set_optimizer logger = logging.getLogger(__name__) @@ -20,17 +20,17 @@ class Trainer(object): def __init__(self, args, data_loaders, examples, char_emb, spo_conf): if args.use_bert: - self.model = bert_mpn.AttributeExtractNet.from_pretrained(args.bert_model, args, spo_conf) + self.model = bert_mpn.ERENet(args, spo_conf) else: - self.model = mpn.AttributeExtractNet(args, char_emb, spo_conf) + self.model = mpn.ERENet(args, char_emb, spo_conf) self.args = args self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu") self.n_gpu = torch.cuda.device_count() - self.id2rel = {item: key for key, item in attribute_conf.items()} - self.rel2id =attribute_conf + self.id2rel = {item: key for key, item in spo_conf.items()} + self.rel2id = spo_conf random.seed(args.seed) np.random.seed(args.seed) @@ -69,7 +69,7 @@ class Trainer(object): for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5, desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout): - loss, answer_dict_ = self.forward(batch) + loss = self.forward(batch) if step % step_gap == 0: global_loss += loss @@ -99,47 +99,48 @@ class Trainer(object): checkpoint = torch.load(resume_model_file, map_location='cpu') self.model.load_state_dict(checkpoint) - def forward(self, batch, chosen=u'train', grad=True, eval=False, detail=False): + def forward(self, batch, chosen=u'train', eval=False, answer_dict=None): batch = tuple(t.to(self.device) for t in batch) + if not eval: - p_ids, passage_id, token_type_id, segment_id, pos_start,pos_end,start_id, end_id = batch - loss, po1, po2 = self.model(passage_id=passage_id, token_type_id=token_type_id, segment_id=segment_id, - pos_start=pos_start,pos_end=pos_end,start_id=start_id, end_id=end_id, is_eval=eval) + p_ids, input_ids, segment_ids, token_type_ids, s1, s2, po1, po2 = batch + loss = self.model(passages=input_ids, token_type_ids=token_type_ids, segment_ids=segment_ids, s1=s1, s2=s2, + po1=po1, po2=po2, + is_eval=eval) + if self.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu. - if self.n_gpu > 1: - loss = loss.mean() # mean() to average on multi-gpu. - - if grad: loss.backward() loss = loss.item() self.optimizer.step() self.optimizer.zero_grad() - - if eval: - eval_file = self.eval_file_choice[chosen] - answer_dict_ = convert_pointer_net_contour(eval_file, p_ids, po1, po2, self.id2rel, - use_bert=self.args.use_bert) + return loss else: - answer_dict_ = None - return loss, answer_dict_ + p_ids, input_ids, segment_ids = batch + eval_file = self.eval_file_choice[chosen] + qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor = self.model(q_ids=p_ids, eval_file=eval_file, + passages=input_ids, is_eval=eval) + ans_dict = self.convert_spo_contour(qid_tensor, po1_tensor, po2_tensor, s_tensor, e_tensor, eval_file, + answer_dict, use_bert=self.args.use_bert) + return ans_dict def eval_data_set(self, chosen="dev"): self.model.eval() - answer_dict = {} data_loader = self.data_loader_choice[chosen] eval_file = self.eval_file_choice[chosen] + answer_dict = {i: [[], []] for i in range(len(eval_file))} + last_time = time.time() with torch.no_grad(): for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout): - loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True) - answer_dict.update(answer_dict_) + self.forward(batch, chosen, eval=True, answer_dict=answer_dict) used_time = time.time() - last_time logging.info('chosen {} took : {} sec'.format(chosen, used_time)) res = self.evaluate(eval_file, answer_dict, chosen) - self.detail_evaluate(eval_file, answer_dict, chosen) + # self.detail_evaluate(eval_file, answer_dict, chosen) self.model.train() return res @@ -152,85 +153,53 @@ class Trainer(object): eval_file = self.eval_file_choice[chosen] with torch.no_grad(): for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout): - loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True, detail=True) + loss, answer_dict_ = self.forward(batch, chosen, eval=True) answer_dict.update(answer_dict_) self.badcase_analysis(eval_file, answer_dict, chosen) @staticmethod def evaluate(eval_file, answer_dict, chosen): - em = 0 - pre = 0 - gold = 0 + entity_em = 0 + entity_pred_num = 0 + entity_gold_num = 0 + + triple_em = 0 + triple_pred_num = 0 + triple_gold_num = 0 for key, value in answer_dict.items(): - ground_truths = eval_file[int(key)].gold_answer - value, l1, l2 = value - prediction = list(value) if len(value) else [] - assert type(prediction) == type(ground_truths) - intersection = set(prediction) & set(ground_truths) + triple_gold = eval_file[key].gold_answer + entity_gold = eval_file[key].sub_entity_list - em += len(intersection) - pre += len(set(prediction)) - gold += len(set(ground_truths)) + entity_pred, triple_pred = value - precision = 100.0 * em / pre if pre > 0 else 0. - recall = 100.0 * em / gold if gold > 0 else 0. + entity_em += len(set(entity_pred) & set(entity_gold)) + entity_pred_num += len(set(entity_pred)) + entity_gold_num += len(set(entity_gold)) + + triple_em += len(set(triple_pred) & set(triple_gold)) + triple_pred_num += len(set(triple_pred)) + triple_gold_num += len(set(triple_gold)) + + entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0. + entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0. + entity_f1 = 2 * entity_recall * entity_precision / (entity_recall + entity_precision) if ( + entity_recall + entity_precision) != 0 else 0.0 + + precision = 100.0 * triple_em / triple_pred_num if triple_pred_num > 0 else 0. + recall = 100.0 * triple_em / triple_gold_num if triple_gold_num > 0 else 0. f1 = 2 * recall * precision / (recall + precision) if (recall + precision) != 0 else 0.0 print('============================================') - print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, em, pre, gold)) + print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".format(chosen, entity_em, entity_pred_num, + entity_gold_num)) + print( + "{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".format(chosen, entity_f1, entity_precision, + entity_recall)) + print('============================================') + print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, triple_em, triple_pred_num, triple_gold_num)) print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1, precision, recall)) - return {'f1': f1, "recall": recall, "precision": precision, 'em': em, 'pre': pre, 'gold': gold} - - def detail_evaluate(self, eval_file, answer_dict, chosen): - def generate_detail_dict(spo_list): - dict_detail = dict() - for i, tag in enumerate(spo_list): - detail_name = tag.split('@')[0] - if detail_name not in dict_detail: - dict_detail[detail_name] = [tag] - else: - dict_detail[detail_name].append(tag) - return dict_detail - - total_detail = {} - for key, value in answer_dict.items(): - ground_truths = eval_file[int(key)].gold_answer - value, l1, l2 = value - prediction = list(value) if len(value) else [] - - gold_detail = generate_detail_dict(ground_truths) - pred_detail = generate_detail_dict(prediction) - for key in self.rel2id.keys(): - - pred = pred_detail.get(key, []) - gold = gold_detail.get(key, []) - em = len(set(pred) & set(gold)) - pred_num = len(set(pred)) - gold_num = len(set(gold)) - - if key not in total_detail: - total_detail[key] = dict() - total_detail[key]['em'] = em - total_detail[key]['pred_num'] = pred_num - total_detail[key]['gold_num'] = gold_num - else: - total_detail[key]['em'] += em - total_detail[key]['pred_num'] += pred_num - total_detail[key]['gold_num'] += gold_num - for key, res_dict_ in total_detail.items(): - res_dict_['p'] = 100.0 * res_dict_['em'] / res_dict_['pred_num'] if res_dict_['pred_num'] != 0 else 0.0 - res_dict_['r'] = 100.0 * res_dict_['em'] / res_dict_['gold_num'] if res_dict_['gold_num'] != 0 else 0.0 - res_dict_['f'] = 2 * res_dict_['p'] * res_dict_['r'] / (res_dict_['p'] + res_dict_['r']) if res_dict_['p'] + \ - res_dict_[ - 'r'] != 0 else 0.0 - - for gold_key, res_dict_ in total_detail.items(): - print('===============================================================') - print("{}/em: {},\tpred_num&gold_num: {}\t{} ".format(gold_key, res_dict_['em'], res_dict_['pred_num'], - res_dict_['gold_num'])) - print( - "{}/f1: {},\tprecison&recall: {}\t{}".format(gold_key, res_dict_['f'], res_dict_['p'], res_dict_['r'])) + return {'f1': f1, "recall": recall, "precision": precision} @staticmethod def badcase_analysis(eval_file, answer_dict, chosen): @@ -268,28 +237,29 @@ class Trainer(object): with open('badcase_{}.txt'.format(chosen), 'w') as f: f.write(content) - -def convert_pointer_net_contour(eval_file, q_ids, po1, po2, id2rel, use_bert=False): - answer_dict = dict() - for qid, o1, o2 in zip(q_ids, po1.data.cpu().numpy(), po2.data.cpu().numpy()): - - context = eval_file[qid.item()].context if not use_bert else eval_file[qid.item()].bert_tokens - gold_attr_list = eval_file[qid.item()].gold_attr_list - gold_answer = [attr.attr_type + '@' + attr.value for attr in gold_attr_list] - - answers = list() - start, end = np.where(o1 > 0.5), np.where(o2 > 0.5) - for _start, _attr_type_id_start in zip(*start): - if _start > len(context) or (_start == 0 and use_bert): + def convert_spo_contour(self, qid_tensor, po1, po2, s_tensor, e_tensor, eval_file, answer_dict, use_bert=False): + for qid, s, e, o1, o2 in zip(qid_tensor.data.cpu().numpy(), s_tensor.data.cpu().numpy(), + e_tensor.data.cpu().numpy(), po1.data.cpu().numpy(), po2.data.cpu().numpy()): + if qid == -1: continue - for _end, _attr_type_id_end in zip(*end): - if _start <= _end < len(context) and _attr_type_id_start == _attr_type_id_end: - _attr_value = ''.join(context[_start: _end + 1]) if use_bert else context[_start: _end + 1] - _attr_type = id2rel[_attr_type_id_start] - _attr = _attr_type + '@' + _attr_value - answers.append(_attr) - break + context = eval_file[qid.item()].context if not use_bert else eval_file[qid.item()].bert_tokens + gold_answer = eval_file[qid].gold_answer - answer_dict[str(qid.item())] = [answers, o1, o2] + _subject = ''.join(context[s:e]) if use_bert else context[s:e] + answers = list() + start, end = np.where(o1 > 0.5), np.where(o2 > 0.5) + for _start, _predict_id_start in zip(*start): + if _start > len(context) or (_start == 0 and use_bert): + continue + for _end, _predict_id_end in zip(*end): + if _start <= _end < len(context) and _predict_id_start == _predict_id_end: + _obeject = ''.join(context[_start: _end + 1]) if use_bert else context[_start: _end + 1] + _predicate = self.id2rel[_predict_id_start] + answers.append((_subject, _predicate, _obeject)) + break - return answer_dict + if qid not in answer_dict: + print('erro in answer_dict ') + else: + answer_dict[qid][0].append((_subject, s-1, e-1)) + answer_dict[qid][1].extend(answers) From 824e913df1848199548bac9996b9159c221c82c8 Mon Sep 17 00:00:00 2001 From: loujie0822 Date: Mon, 17 Feb 2020 13:16:03 +0800 Subject: [PATCH 5/5] add baidu spo config --- config/baidu_spo_config.py | 51 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 config/baidu_spo_config.py diff --git a/config/baidu_spo_config.py b/config/baidu_spo_config.py new file mode 100644 index 0000000..64a5243 --- /dev/null +++ b/config/baidu_spo_config.py @@ -0,0 +1,51 @@ +BAIDU_RELATION = { + "朝代": 0, + "人口数量": 1, + "出生地": 2, + "连载网站": 3, + "身高": 4, + "占地面积": 5, + "作者": 6, + "目": 7, + "母亲": 8, + "海拔": 9, + "作词": 10, + "嘉宾": 11, + "总部地点": 12, + "出版社": 13, + "主持人": 14, + "出生日期": 15, + "所在城市": 16, + "修业年限": 17, + "祖籍": 18, + "邮政编码": 19, + "毕业院校": 20, + "气候": 21, + "号": 22, + "注册资本": 23, + "丈夫": 24, + "国籍": 25, + "主角": 26, + "主演": 27, + "民族": 28, + "董事长": 29, + "所属专辑": 30, + "专业代码": 31, + "改编自": 32, + "歌手": 33, + "编剧": 34, + "妻子": 35, + "面积": 36, + "作曲": 37, + "官方语言": 38, + "出品公司": 39, + "成立日期": 40, + "简称": 41, + "首都": 42, + "父亲": 43, + "字": 44, + "制片人": 45, + "上映时间": 46, + "创始人": 47, + "导演": 48 +}