# DeepIE/run/relation_extraction/etl_stl/data_loader.py
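# Data-loading utilities for the ETL/STL relation-extraction pipeline: Reader parses
# SPO-annotated JSON lines, Vocabulary builds character/word vocabularies (optionally
# backed by pretrained word vectors), and Feature / SPODataset turn examples into
# padded tensor batches for training and evaluation.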
import codecs
import json
import logging
import random
from collections import Counter
from functools import partial
import jieba
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from config.spo_config_v1 import BAIDU_BIES, BAIDU_ENTITY
from utils.data_util import Tokenizer, search, sequence_padding
class PredictObject(object):
def __init__(self,
object_name,
object_start,
object_end,
predict_type,
predict_type_id
):
self.object_name = object_name
self.object_start = object_start
self.object_end = object_end
self.predict_type = predict_type
self.predict_type_id = predict_type_id
class Example(object):
def __init__(self,
p_id=None,
context=None,
bert_tokens=None,
text_word=None,
sub_pos=None,
sub_entity_list=None,
relative_pos_start=None,
relative_pos_end=None,
po_list=None,
gold_answer=None,
token_ids=None):
self.p_id = p_id
self.context = context
self.text_word = text_word
self.bert_tokens = bert_tokens
self.sub_pos = sub_pos
self.sub_entity_list = sub_entity_list
self.relative_pos_start = relative_pos_start
self.relative_pos_end = relative_pos_end
self.po_list = po_list
self.gold_answer = gold_answer
self.token_ids = token_ids
class InputFeature(object):
def __init__(self,
p_id=None,
passage_id=None,
token_type_id=None,
pos_start_id=None,
pos_end_id=None,
segment_id=None,
po_label=None,
s1=None,
s2=None):
self.p_id = p_id
self.passage_id = passage_id
self.token_type_id = token_type_id
self.pos_start_id = pos_start_id
self.pos_end_id = pos_end_id
self.segment_id = segment_id
self.po_label = po_label
self.s1 = s1
self.s2 = s2
class Reader(object):
def __init__(self, do_lowercase=False, seg_char=False):
self.do_lowercase = do_lowercase
self.seg_char = seg_char
self.relation_config = BAIDU_BIES
if self.seg_char:
logging.info("seg_char...")
else:
logging.info("seg_word using jieba ...")
def read_examples(self, filename, data_type):
logging.info("Generating {} examples...".format(data_type))
return self._read(filename, data_type)
def _read(self, filename, data_type):
examples = []
with codecs.open(filename, 'r', encoding='utf-8') as f:
gold_num = 0
p_id = 0
for line in tqdm(f):
p_id += 1
data_json = json.loads(line.strip())
text = data_json['text'].lower().replace(' ', '')
text_word = jieba.lcut(text)
sub_po_dict, sub_ent_list, spo_list = dict(), list(), list()
for spo in data_json['spo_list']:
subject_name = spo['subject'].lower().replace(' ', '')
object_name = spo['object'].lower().replace(' ', '')
sub_ent_list.append(subject_name)
spo_list.append((subject_name, spo['predicate'], object_name))
examples.append(
Example(
p_id=p_id,
context=text,
text_word=text_word,
sub_entity_list=list(set(sub_ent_list)),
gold_answer=spo_list
)
)
gold_num += len(set(spo_list))
logging.info('total gold num is {}'.format(gold_num))
logging.info("{} total size is {} ".format(data_type, len(examples)))
return examples
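# Expected input: one JSON object per line; only the keys read above are required.
# A schematic line (values elided) looks like:
#   {"text": "...", "spo_list": [{"subject": "...", "predicate": "...", "object": "..."}]}
# Each line becomes one Example whose gold_answer holds (subject, predicate, object) triples.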
class Vocabulary(object):
def __init__(self, special_tokens=["<OOV>", "<MASK>"]):
self.char_vocab = None
self.emb_mat = None
self.char2idx = dict()
self.word2idx = dict()
self.char_counter = Counter()
self.word_counter = Counter()
self.special_tokens = special_tokens
def build_vocab_only_with_char(self, examples, min_char_count=-1, min_word_count=-1):
logging.info("Building vocabulary only with character...")
self.char_vocab = ["<PAD>"]
self.word_vocab = ["<PAD>"]
if self.special_tokens is not None and isinstance(self.special_tokens, list):
self.char_vocab.extend(self.special_tokens)
self.word_vocab.extend(self.special_tokens)
for example in tqdm(examples):
for char in example.context:
self.char_counter[char] += 1
for word in example.text_word:
self.word_counter[word] += 1
for c, v in self.char_counter.most_common():
if v >= min_char_count:
self.char_vocab.append(c)
for w, v in self.word_counter.most_common():
if v >= min_word_count:
self.word_vocab.append(w)
self.char2idx = {token: idx for idx, token in enumerate(self.char_vocab)}
logging.info("total char counter size is {} ".format(len(self.char_counter)))
logging.info("total char vocabulary size is {} ".format(len(self.char_vocab)))
logging.info("total word vocabulary size without embedding is {} ".format(len(self.word_vocab)))
def _load_embedding(self, embedding_file, embedding_dict):
with open(embedding_file, encoding='utf-8') as f:
for line in f:
parts = line.rstrip().split(" ")
# skip short/header lines such as the word2vec "<vocab_size> <dim>" header
if len(parts) <= 2: continue
# np.fromstring and np.float are deprecated/removed in recent NumPy; build the vector explicitly
embedding_dict[parts[0]] = np.asarray(parts[1:], dtype=np.float32)
return embedding_dict
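# The embedding file is assumed to be word2vec-style text: one token per line followed by
# its space-separated vector components; very short lines (e.g. a "<vocab_size> <dim>"
# header) are skipped by the length check above.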
def make_embedding(self, vocab, embedding_file, emb_size):
embedding_dict = dict()
embedding_dict["<PAD>"] = np.array([0. for _ in range(emb_size)])
self._load_embedding(embedding_file, embedding_dict)
logging.info("total embedding size is {} ".format(len(embedding_dict)))
# emb_mat = [embedding_dict[token] for token in vocab if token in embedding_dict]
#
# index = 0
# for token in embedding_dict.keys():
# if token in vocab:
# self.word2idx.update({token: index})
# index += 1
count = 0
emb_mat = []
index = 0
for token in tqdm(vocab):
if token in embedding_dict:
self.word2idx[token] = index
emb_mat.append(embedding_dict[token])
index += 1
else:
count += 1
logging.info(
"{} / {} vocabulary tokens have a pretrained embedding vector".format(len(vocab) - count, len(vocab)))
logging.info("total word vocabulary size is {} ".format(len(self.word2idx)))
return emb_mat
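# Minimal usage sketch (the embedding path and dimension below are placeholders):
#   vocab = Vocabulary()
#   vocab.build_vocab_only_with_char(train_examples)  # also counts words for the embedding lookup
#   emb_mat = vocab.make_embedding(vocab.word_vocab, "word_vectors.txt", emb_size=300)
# Note that word2idx is only populated inside make_embedding; char-only setups rely on char2idx.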
class Feature(object):
def __init__(self, args, char2idx, word2idx):
self.bert = args.use_bert
self.char2idx = char2idx
self.word2idx = word2idx
self.max_len = args.max_len
if self.bert:
self.tokenizer = Tokenizer(args.bert_model + '/vocab.txt', do_lower_case=True)
def __call__(self, examples, data_type):
if self.bert:
return self.convert_examples_to_bert_features(examples, data_type)
else:
return self.convert_examples_to_features(examples, data_type)
def convert_examples_to_bert_features(self, examples, data_type):
logging.info("convert {} examples to features .".format(data_type))
examples2features = list()
for index, example in enumerate(examples):
examples2features.append((index, example))
logging.info("Built instances is Completed")
return SPOBERTDataset(examples2features, use_bert=True, data_type=data_type,
tokenizer=self.tokenizer, max_len=self.max_len)
def convert_examples_to_features(self, examples, data_type):
logging.info("convert {} examples to features .".format(data_type))
examples2features = list()
for index, example in enumerate(examples):
examples2features.append((index, example))
logging.info("Built instances is Completed")
return SPODataset(examples2features, use_bert=False, data_type=data_type,
word2idx=self.word2idx, char2idx=self.char2idx, max_len=self.max_len)
class SPODataset(Dataset):
def __init__(self, data, data_type, use_bert=False, word2idx=None, char2idx=None, max_len=128):
super(SPODataset, self).__init__()
self.use_bert = use_bert
self.word2idx = word2idx
self.char2idx = char2idx
self.max_len = max_len
self.q_ids = [f[0] for f in data]
self.features = [f[1] for f in data]
self.is_train = (data_type == 'train')
def __len__(self):
return len(self.q_ids)
def __getitem__(self, index):
return self.q_ids[index], self.features[index]
def _create_collate_fn(self, batch_first=False):
def collate(examples):
p_ids, examples = zip(*examples)
p_ids = torch.tensor(list(p_ids), dtype=torch.long)
batch_char_ids, batch_word_ids = [], []
batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
for example in examples:
# TODO: truncation to max_len happens only after the char/word alignment check below
char_ids = [self.char2idx.get(char, 1) for char in example.context]  # 1 -> <OOV>
# repeat each word id once per character so that word ids align with char positions
word_ids = [self.word2idx.get(word, 0) for word in example.text_word for _ in word]
assert len(char_ids) == len(word_ids), \
"char/word alignment mismatch ({} vs {}) for context: {}".format(
len(char_ids), len(word_ids), example.context)
char_ids = char_ids[:self.max_len]
word_ids = word_ids[:self.max_len]
# example.context = example.context[:self.max_len]
if self.is_train:
spoes = {}  # subject span (start, end) -> list of (obj_start, obj_end, predicate)
for s, p, o in example.gold_answer:
s = [self.char2idx.get(s_, 1) for s_ in s]
# p = BAIDU_RELATION[p]
o = [self.char2idx.get(o_, 1) for o_ in o]
s_idx = search(s, char_ids)
o_idx = search(o, char_ids)
if s_idx != -1 and o_idx != -1:
s = (s_idx, s_idx + len(s) - 1)
o = (o_idx, o_idx + len(o) - 1, p)
if s not in spoes:
spoes[s] = []
spoes[s].append(o)
if spoes:
# subject labels: tag every gold subject span with B/I over the characters
# (np.long / np.int were removed in recent NumPy, so use explicit int64)
token_type_ids = np.zeros(len(char_ids), dtype=np.int64)
subject_labels = np.zeros(len(char_ids), dtype=np.int64)
for s in spoes:
subject_labels[s[0]] = BAIDU_ENTITY['B']
for index in range(s[0] + 1, s[1] + 1):
subject_labels[index] = BAIDU_ENTITY['I']
# randomly sample one gold subject and mark its span in token_type_ids
subject_ids = random.choice(list(spoes.keys()))
token_type_ids[subject_ids[0]:subject_ids[1] + 1] = 1
# object labels for the sampled subject: B-/I- tags also encode the predicate
object_labels = np.zeros(len(char_ids), dtype=np.int64)
for o in spoes.get(subject_ids, []):
object_labels[o[0]] = BAIDU_BIES['B' + '-' + o[2]]
for index in range(o[0] + 1, o[1] + 1):
object_labels[index] = BAIDU_BIES['I' + '-' + o[2]]
batch_char_ids.append(char_ids)
batch_word_ids.append(word_ids)
batch_token_type_ids.append(token_type_ids)
batch_subject_labels.append(subject_labels)
batch_subject_ids.append(subject_ids)
batch_object_labels.append(object_labels)
else:
batch_char_ids.append(char_ids)
batch_word_ids.append(word_ids)
batch_char_ids = sequence_padding(batch_char_ids, is_float=False)
batch_word_ids = sequence_padding(batch_word_ids, is_float=False)
if not self.is_train:
return p_ids, batch_char_ids, batch_word_ids
else:
batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
batch_subject_ids = torch.tensor(batch_subject_ids)
batch_subject_labels = sequence_padding(batch_subject_labels, is_float=False)
batch_object_labels = sequence_padding(batch_object_labels, is_float=False)
return batch_char_ids, batch_word_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels
return partial(collate)
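# In training mode the collate above yields
# (char_ids, word_ids, token_type_ids, subject_ids, subject_labels, object_labels);
# in eval mode it yields (p_ids, char_ids, word_ids). All sequence tensors are padded
# to the longest example in the batch by sequence_padding.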
def get_dataloader(self, batch_size, num_workers=0, shuffle=False, batch_first=True, pin_memory=False,
drop_last=False):
return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(batch_first),
num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
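# End-to-end sketch for the non-BERT path (hedged: `args` is a stand-in carrying only the
# fields this module reads, and the data path is a placeholder):
#   from argparse import Namespace
#   args = Namespace(use_bert=False, max_len=128, bert_model=None)
#   reader = Reader(seg_char=False)
#   train_examples = reader.read_examples("train.json", data_type="train")
#   vocab = Vocabulary()
#   vocab.build_vocab_only_with_char(train_examples)
#   feature = Feature(args, char2idx=vocab.char2idx, word2idx=vocab.word2idx)
#   train_loader = feature(train_examples, "train").get_dataloader(batch_size=16, shuffle=True)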