# DeepIE/run/attribute_extraction/data_loader.py
import json
import logging
from collections import Counter
from functools import partial

import jieba
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from layers.encoders.transformers.bert.bert_tokenization import BertTokenizer
from utils.data_util import padding, mpn_padding, _handle_pos_limit

config = {
    'drug': {
        '药品-用药频率': 0,
        '药品-持续时间': 1,
        '药品-用药剂量': 2,
        '药品-用药方法': 3,
        '药品-不良反应': 4,
    },
    'disease': {
        '疾病-检查方法': 0,
        '疾病-临床表现': 1,
        '疾病-非药治疗': 2,
        '疾病-药品名称': 3,
        '疾病-部位': 4,
    }
}
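
# Each entity type maps attribute-type names to class indices; the index is stored
# on Attribute.attr_type_id and the number of classes is taken as
# len(config[entity_type]) further below. For illustration:
#   config['drug']['药品-用药剂量'] == 2
#   config['disease']['疾病-部位'] == 4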


class Attribute(object):
    def __init__(self,
                 value,
                 value_pos_start,
                 value_pos_end,
                 attr_type,
                 attr_type_id
                 ):
        self.value = value
        self.value_pos_start = value_pos_start
        self.value_pos_end = value_pos_end
        self.attr_type = attr_type
        self.attr_type_id = attr_type_id


class Example(object):
    def __init__(self,
                 p_id=None,
                 context=None,
                 bert_tokens=None,
                 entity_name=None,
                 entity_position=None,
                 pos_start=None,
                 pos_end=None,
                 gold_attr_list=None,
                 gold_answer=None):
        self.p_id = p_id
        self.context = context
        self.bert_tokens = bert_tokens
        self.entity_name = entity_name
        self.entity_position = entity_position
        self.pos_start = pos_start
        self.pos_end = pos_end
        self.gold_attr_list = gold_attr_list
        self.gold_answer = gold_answer

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        gold_repr = []
        for attr in self.gold_attr_list:
            gold_repr.append(attr.attr_type + '-' + attr.value)
        return 'entity {} with attribute:\n{}'.format(self.entity_name, '\t'.join(gold_repr))
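
    # Illustrative repr (hypothetical values, not taken from the dataset): for an
    # Example with entity_name='阿司匹林' and a single gold
    # Attribute(attr_type='药品-用药剂量', value='5mg'), repr(example) reads:
    #   entity 阿司匹林 with attribute:
    #   药品-用药剂量-5mg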


class InputFeature(object):
    def __init__(self,
                 p_id=None,
                 passage_id=None,
                 token_type_id=None,
                 pos_start_id=None,
                 pos_end_id=None,
                 segment_id=None,
                 label=None):
        self.p_id = p_id
        self.passage_id = passage_id
        self.token_type_id = token_type_id
        self.pos_start_id = pos_start_id
        self.pos_end_id = pos_end_id
        self.segment_id = segment_id
        self.label = label


class Reader(object):
    def __init__(self, do_lowercase=False, seg_char=False, max_len=600, entity_type="药品名称"):
        self.do_lowercase = do_lowercase
        self.seg_char = seg_char
        self.max_len = max_len
        # entity_type must be a key of `config` ('drug' or 'disease'); other values
        # raise a KeyError here.
        self.entity_config = config[entity_type]
        if self.seg_char:
            logging.info("seg_char...")
        else:
            logging.info("seg_word using jieba ...")

    def read_examples(self, filename, data_type):
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def _read(self, filename, data_type):
        with open(filename, 'r', encoding='utf-8') as fh:
            source_data = json.load(fh)
        examples = []
        for p_id in tqdm(range(len(source_data))):
            data = source_data[p_id]
            para = data['text']
            context = para if self.seg_char else ''.join(jieba.lcut(para))
            if len(context) > self.max_len:
                context = context[:self.max_len]
            context = context.lower() if self.do_lowercase else context
            _data_dict = dict()
            _data_dict['id'] = p_id
            entity_name, entity_position = data['entity'][0], data['entity'][1]
            entity_name = entity_name.lower() if self.do_lowercase else entity_name
            start, end = entity_position
            assert entity_name == context[start:end]
            # pos_start & pos_end: positions of each token relative to the entity,
            # limited to [-30, 30]. At embedding time the whole range is shifted by
            # +31 into [1, 61], giving 62 position tokens with 0 reserved for padding.
            pos_start = list(map(lambda i: i - start, list(range(len(context)))))
            pos_end = list(map(lambda i: i - end, list(range(len(context)))))
            pos_start = _handle_pos_limit(pos_start)
            pos_end = _handle_pos_limit(pos_end)
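            # Worked example (assuming _handle_pos_limit clips to [-30, 30] and shifts
            # by +31, as described above): for an entity starting at index 5, the
            # character at index 3 has raw offset 3 - 5 = -2, which becomes
            # -2 + 31 = 29; offsets outside the window are clipped to the boundary first.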
            attribute_list = data['attribute_list']
            gold_attr_list = []
            for attribute in attribute_list:
                attr_type = attribute['type']
                value = attribute['value']
                value, value_pos_start, value_pos_end = value[0], value[1], value[2]
                value = value.lower() if self.do_lowercase else value
                assert value == context[value_pos_start:value_pos_end]
                gold_attr_list.append(Attribute(
                    value=value,
                    value_pos_start=value_pos_start,
                    value_pos_end=value_pos_end,
                    attr_type=attr_type,
                    attr_type_id=self.entity_config[attr_type]
                ))
            gold_answer = [
                attr.attr_type + '@' + attr.value + str(attr.value_pos_start) + '-' + str(attr.value_pos_end + 1)
                for attr in gold_attr_list]
            examples.append(
                Example(
                    p_id=p_id,
                    context=context,
                    entity_name=entity_name,
                    entity_position=entity_position,
                    pos_start=pos_start,
                    pos_end=pos_end,
                    gold_attr_list=gold_attr_list,
                    gold_answer=gold_answer
                )
            )
        logging.info("{} total size is {} ".format(data_type, len(examples)))
        return examples
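

# Usage sketch for Reader (hedged: the JSON path below is a placeholder, not a path
# shipped with this module):
#   reader = Reader(seg_char=True, max_len=600, entity_type='drug')
#   train_examples = reader.read_examples('data/attribute/train.json', 'train')
#   print(train_examples[0])  # prints the entity and its gold attribute list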


class Vocabulary(object):
    def __init__(self, special_tokens=["<OOV>", "<MASK>"]):
        self.char_vocab = None
        self.emb_mat = None
        self.char2idx = None
        self.char_counter = Counter()
        self.special_tokens = special_tokens

    def build_vocab_only_with_char(self, examples, min_char_count=-1):
        logging.info("Building vocabulary only with character...")
        self.char_vocab = ["<PAD>"]
        if self.special_tokens is not None and isinstance(self.special_tokens, list):
            self.char_vocab.extend(self.special_tokens)
        for example in tqdm(examples):
            for char in example.context:
                self.char_counter[char] += 1
        for w, v in self.char_counter.most_common():
            if v >= min_char_count:
                self.char_vocab.append(w)
        self.char2idx = {token: idx for idx, token in enumerate(self.char_vocab)}
        logging.info("total char counter size is {} ".format(len(self.char_counter)))
        logging.info("total char vocabulary size is {} ".format(len(self.char_vocab)))

    def _load_embedding(self, embedding_file, embedding_dict):
        with open(embedding_file, encoding='utf-8') as f:
            for line in f:
                if len(line.rstrip().split(" ")) <= 2:
                    continue
                token, vector = line.rstrip().split(" ", 1)
                embedding_dict[token] = np.fromstring(vector, dtype=np.float64, sep=" ")
        return embedding_dict

    def make_embedding(self, vocab, embedding_file, emb_size):
        embedding_dict = dict()
        embedding_dict["<PAD>"] = np.array([0. for _ in range(emb_size)])
        self._load_embedding(embedding_file, embedding_dict)
        count = 0
        for token in tqdm(vocab):
            if token not in embedding_dict:
                count += 1
                embedding_dict[token] = np.array([np.random.normal(scale=0.1) for _ in range(emb_size)])
        logging.info(
            "{} / {} tokens have corresponding embedding vectors".format(len(vocab) - count, len(vocab)))
        emb_mat = [embedding_dict[token] for idx, token in enumerate(vocab)]
        return emb_mat
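

# Usage sketch for Vocabulary (hedged: the embedding path is a placeholder; the file
# is expected to hold "token v1 v2 ... vn" lines, which is what _load_embedding parses):
#   vocab = Vocabulary()
#   vocab.build_vocab_only_with_char(train_examples, min_char_count=1)
#   emb_mat = vocab.make_embedding(vocab.char_vocab, 'data/char_embedding_300d.txt', 300)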


class Feature(object):
    def __init__(self, args, token2idx_dict):
        self.bert = args.use_bert
        self.token2idx_dict = token2idx_dict
        if self.bert:
            self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    def token2wid(self, token):
        if token in self.token2idx_dict:
            return self.token2idx_dict[token]
        return self.token2idx_dict["<OOV>"]

    def __call__(self, examples, entity_type, data_type):
        if self.bert:
            return self.convert_examples_to_bert_features(examples, entity_type, data_type)
        else:
            return self.convert_examples_to_features(examples, entity_type, data_type)

    def convert_examples_to_features(self, examples, entity_type, data_type):
        logging.info("Processing {} examples...".format(data_type))
        examples2features = list()
        for index, example in enumerate(examples):
            gold_attr_list = example.gold_attr_list
            ent_start, ent_end = example.entity_position[0], example.entity_position[1]
            passage_id = np.zeros(len(example.context), dtype=np.int64)
            token_type_id = np.zeros(len(example.context), dtype=np.int64)
            pos_start_id = np.zeros(len(example.context), dtype=np.int64)
            pos_end_id = np.zeros(len(example.context), dtype=np.int64)
            for i, token in enumerate(example.context):
                if ent_start <= i < ent_end:
                    # token = "<MASK>"
                    token_type_id[i] = 1
                passage_id[i] = self.token2wid(token)
                pos_start_id[i] = example.pos_start[i]
                pos_end_id[i] = example.pos_end[i]
            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=token_type_id,
                    label=gold_attr_list
                ))
        logging.info("Finished building instances")
        return AttributeMPNDataset(examples2features, attribute_num=len(config[entity_type]))
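
    # Illustration (hypothetical toy case): for a six-character context whose entity
    # occupies positions [2, 4), token_type_id becomes [0, 0, 1, 1, 0, 0]; passage_id
    # holds the vocabulary id of each character, and pos_start_id / pos_end_id hold
    # the clipped, shifted offsets relative to the entity boundaries.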

    def convert_examples_to_bert_features(self, examples, entity_type, data_type):
        logging.info("Processing {} examples...".format(data_type))
        examples2features = list()
        for index, example in enumerate(examples):
            gold_attr_list = example.gold_attr_list
            ent_start, ent_end = example.entity_position[0], example.entity_position[1]
            segment_id = np.zeros(len(example.context) + 2, dtype=np.int64)
            token_type_id = np.zeros(len(example.context) + 2, dtype=np.int64)
            pos_start_id = np.zeros(len(example.context) + 2, dtype=np.int64)
            pos_end_id = np.zeros(len(example.context) + 2, dtype=np.int64)
            tokens = ["[CLS]"]
            raw_tokens = ["[CLS]"]
            for i, token in enumerate(example.context):
                raw_tokens.append(token)
                if ent_start <= i < ent_end:
                    # token_type_id[i + 1] = 1
                    # segment_id[i + 1] = 1
                    token = '[unused1]'
                tokens.append(token)
                pos_start_id[i + 1] = example.pos_start[i]
                pos_end_id[i + 1] = example.pos_end[i]
            tokens.append("[SEP]")
            raw_tokens.append("[SEP]")
            passage_id = self.tokenizer.convert_tokens_to_ids(tokens)
            example.bert_tokens = raw_tokens
            examples2features.append(
                InputFeature(
                    p_id=index,
                    passage_id=passage_id,
                    token_type_id=token_type_id,
                    pos_start_id=pos_start_id,
                    pos_end_id=pos_end_id,
                    segment_id=segment_id,
                    label=gold_attr_list
                ))
        logging.info("Finished building instances")
        return AttributeMPNDataset(examples2features, attribute_num=len(config[entity_type]), use_bert=True)
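
    # Illustration (hypothetical toy case): for context "头孢可致皮疹" with the entity
    # spanning [0, 2), tokens becomes ["[CLS]", "[unused1]", "[unused1]", "可", "致",
    # "皮", "疹", "[SEP]"] before id conversion, while raw_tokens keeps the original
    # characters; the "+ 2" array sizes above make room for [CLS] and [SEP].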


class AttributeMPNDataset(Dataset):
    def __init__(self, features, attribute_num, use_bert=False):
        super(AttributeMPNDataset, self).__init__()
        self.use_bert = use_bert
        self.q_ids = [f.p_id for f in features]
        self.passages = [f.passage_id for f in features]
        self.token_type = [f.token_type_id for f in features]
        self.label = [f.label for f in features]
        self.attribute_num = attribute_num
        self.segment_ids = [f.segment_id for f in features]
        self.pos_start_ids = [f.pos_start_id for f in features]
        self.pos_end_ids = [f.pos_end_id for f in features]

    def __len__(self):
        return len(self.passages)

    def __getitem__(self, index):
        return (self.q_ids[index], self.passages[index], self.token_type[index], self.segment_ids[index],
                self.label[index], self.pos_start_ids[index], self.pos_end_ids[index])

    def _create_collate_fn(self, batch_first=False):
        def collate(examples):
            p_ids, passages, token_type, segment_ids, label, pos_start_ids, pos_end_ids = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first)
            pos_start_tensor, _ = padding(pos_start_ids, is_float=False, batch_first=batch_first)
            pos_end_tensor, _ = padding(pos_end_ids, is_float=False, batch_first=batch_first)
            token_type_tensor, _ = padding(token_type, is_float=False, batch_first=batch_first)
            segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first)
            o1_tensor, o2_tensor = mpn_padding(passages, label, class_num=self.attribute_num, is_float=True,
                                               use_bert=self.use_bert)
            return p_ids, passages_tensor, token_type_tensor, segment_tensor, pos_start_tensor, pos_end_tensor, o1_tensor, o2_tensor

        return partial(collate)
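
    # Shape sketch (under the assumption that padding() pads each field to the longest
    # sequence in the batch and mpn_padding() builds per-attribute start/end label
    # tensors; both come from utils.data_util): passages_tensor is [batch, max_len]
    # with batch_first=True, and o1_tensor / o2_tensor are float tensors with one
    # channel per attribute class (attribute_num).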

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, batch_first=True, pin_memory=False,
                       drop_last=False):
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(batch_first),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
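

if __name__ == '__main__':
    # End-to-end sketch of the character-embedding (non-BERT) path. Hedged: the JSON
    # path, the Namespace fields and the batch size are placeholders chosen for
    # illustration, not values taken from the original training scripts.
    from argparse import Namespace

    args = Namespace(use_bert=False, bert_model=None, do_lower_case=True)

    reader = Reader(seg_char=True, entity_type='drug')
    examples = reader.read_examples('data/attribute/train.json', 'train')

    vocab = Vocabulary()
    vocab.build_vocab_only_with_char(examples, min_char_count=1)

    feature = Feature(args, token2idx_dict=vocab.char2idx)
    dataset = feature(examples, entity_type='drug', data_type='train')
    loader = dataset.get_dataloader(batch_size=8, shuffle=True, batch_first=True)
    for batch in loader:
        p_ids, passages, token_type, segment, pos_start, pos_end, o1, o2 = batch
        break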