Add the transformer framework for SPO extraction
This commit is contained in:
parent 348bd5353e
commit 6b337c5a28
51	config/spo_config_v2.py	Normal file
@@ -0,0 +1,51 @@
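# Relation (predicate) name -> label id for the Baidu SPO dataset (schema v2).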
BAIDU_RELATION = {
    "朝代": 0,
    "人口数量": 1,
    "出生地": 2,
    "连载网站": 3,
    "身高": 4,
    "占地面积": 5,
    "作者": 6,
    "目": 7,
    "母亲": 8,
    "海拔": 9,
    "作词": 10,
    "嘉宾": 11,
    "总部地点": 12,
    "出版社": 13,
    "主持人": 14,
    "出生日期": 15,
    "所在城市": 16,
    "修业年限": 17,
    "祖籍": 18,
    "邮政编码": 19,
    "毕业院校": 20,
    "气候": 21,
    "号": 22,
    "注册资本": 23,
    "丈夫": 24,
    "国籍": 25,
    "主角": 26,
    "主演": 27,
    "民族": 28,
    "董事长": 29,
    "所属专辑": 30,
    "专业代码": 31,
    "改编自": 32,
    "歌手": 33,
    "编剧": 34,
    "妻子": 35,
    "面积": 36,
    "作曲": 37,
    "官方语言": 38,
    "出品公司": 39,
    "成立日期": 40,
    "简称": 41,
    "首都": 42,
    "父亲": 43,
    "字": 44,
    "制片人": 45,
    "上映时间": 46,
    "创始人": 47,
    "导演": 48
}
162	models/spo_net/multi_pointer_net.py	Normal file
@@ -0,0 +1,162 @@
# _*_ coding:utf-8 _*_
import warnings

import numpy as np
import torch
import torch.nn as nn
from transformers import BertModel
from transformers import BertPreTrainedModel

from layers.encoders.transformers.bert.layernorm import ConditionalLayerNorm
from utils.data_util import batch_gather

warnings.filterwarnings("ignore")


class ERENet(BertPreTrainedModel):
    """
    ERENet: joint entity and relation extraction
    """

    def __init__(self, config, classes_num):
        super(ERENet, self).__init__(config, classes_num)
        self.classes_num = classes_num

        # BERT model
        self.bert = BertModel(config)

        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
                                             padding_idx=0)
        # self.encoder_layer = TransformerEncoderLayer(config.hidden_size, nhead=4)
        # self.transformer_encoder = TransformerEncoder(self.encoder_layer, num_layers=1)
        self.LayerNorm = ConditionalLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # pointer network heads
        self.po_dense = nn.Linear(config.hidden_size, self.classes_num * 2)
        self.subject_dense = nn.Linear(config.hidden_size, 2)
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        self.init_weights()

    def forward(self, q_ids=None, passage_ids=None, segment_ids=None, token_type_ids=None, subject_ids=None,
                subject_labels=None,
                object_labels=None, eval_file=None,
                is_eval=False):
        mask = (passage_ids != 0).float()
        bert_encoder = self.bert(passage_ids, token_type_ids=segment_ids, attention_mask=mask)[0]
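        # Training path: condition on the gold subject span (teacher forcing)
        # and return the summed subject + object/predicate BCE losses.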
        if not is_eval:
            # subject_encoder = self.token_entity_emb(token_type_ids)
            # context_encoder = bert_encoder + subject_encoder

            sub_start_encoder = batch_gather(bert_encoder, subject_ids[:, 0])
            sub_end_encoder = batch_gather(bert_encoder, subject_ids[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(bert_encoder, subject)

            sub_preds = self.subject_dense(bert_encoder)
            po_preds = self.po_dense(context_encoder).reshape(passage_ids.size(0), -1, self.classes_num, 2)

            subject_loss = self.loss_fct(sub_preds, subject_labels)
            # subject_loss = F.binary_cross_entropy(F.sigmoid(sub_preds) ** 2, subject_labels, reduction='none')
            subject_loss = subject_loss.mean(2)
            subject_loss = torch.sum(subject_loss * mask.float()) / torch.sum(mask.float())

            po_loss = self.loss_fct(po_preds, object_labels)
            # po_loss = F.binary_cross_entropy(F.sigmoid(po_preds) ** 4, object_labels, reduction='none')
            po_loss = torch.sum(po_loss.mean(3), 2)
            po_loss = torch.sum(po_loss * mask.float()) / torch.sum(mask.float())

            loss = subject_loss + po_loss

            return loss
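
        # Inference path: decode candidate subject spans from thresholded
        # start/end probabilities, then score object/predicate pointers for
        # each decoded subject.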
        else:

            subject_preds = nn.Sigmoid()(self.subject_dense(bert_encoder))
            answer_list = list()
            for qid, sub_pred in zip(q_ids.cpu().numpy(),
                                     subject_preds.cpu().numpy()):
                context = eval_file[qid].bert_tokens
                start = np.where(sub_pred[:, 0] > 0.6)[0]
                end = np.where(sub_pred[:, 1] > 0.5)[0]
                subjects = []
                for i in start:
                    j = end[end >= i]
                    if i == 0 or i > len(context) - 2:
                        continue

                    if len(j) > 0:
                        j = j[0]
                        if j > len(context) - 2:
                            continue
                        subjects.append((i, j))

                answer_list.append(subjects)
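
            # Expand each passage once per predicted subject so that all
            # (passage, subject) pairs can be scored in batched passes.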
            qid_ids, bert_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
            for i, subjects in enumerate(answer_list):
                if subjects:
                    qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                    pass_tensor = passage_ids[i, :].unsqueeze(0).expand(len(subjects), passage_ids.size(1))
                    new_bert_encoder = bert_encoder[i, :, :].unsqueeze(0).expand(len(subjects), bert_encoder.size(1),
                                                                                 bert_encoder.size(2))

                    token_type_id = torch.zeros((len(subjects), passage_ids.size(1)), dtype=torch.long)
                    for index, (start, end) in enumerate(subjects):
                        token_type_id[index, start:end + 1] = 1

                    qid_ids.append(qid)
                    pass_ids.append(pass_tensor)
                    subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                    bert_encoders.append(new_bert_encoder)
                    token_type_ids.append(token_type_id)

            if len(qid_ids) == 0:
                subject_ids = torch.zeros(1, 2).long().to(bert_encoder.device)
                qid_tensor = torch.tensor([-1], dtype=torch.long).to(bert_encoder.device)
                po_tensor = torch.zeros(1, bert_encoder.size(1)).long().to(bert_encoder.device)
                return qid_tensor, subject_ids, po_tensor

            qids = torch.cat(qid_ids).to(bert_encoder.device)
            pass_ids = torch.cat(pass_ids).to(bert_encoder.device)
            bert_encoders = torch.cat(bert_encoders).to(bert_encoder.device)
            # token_type_ids = torch.cat(token_type_ids).to(bert_encoder.device)
            subject_ids = torch.cat(subject_ids).to(bert_encoder.device)

            flag = False
            split_heads = 1024

            bert_encoders_ = torch.split(bert_encoders, split_heads, dim=0)
            pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
            # token_type_ids_ = torch.split(token_type_ids, split_heads, dim=0)
            subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)

            po_preds = list()
            for i in range(len(bert_encoders_)):
                bert_encoders = bert_encoders_[i]
                # token_type_ids = token_type_ids_[i]
                pass_ids = pass_ids_[i]
                subject_encoder = subject_encoder_[i]
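
                # Work around a batch-size-1 edge case: duplicate the single
                # row, run the heads on a batch of two, and drop the duplicate
                # prediction again below.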
                if bert_encoders.size(0) == 1:
                    flag = True
                    # print('flag = True**********')
                    bert_encoders = bert_encoders.expand(2, bert_encoders.size(1), bert_encoders.size(2))
                    subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
                    # pass_ids = pass_ids.expand(2, pass_ids.size(1))

                sub_start_encoder = batch_gather(bert_encoders, subject_encoder[:, 0])
                sub_end_encoder = batch_gather(bert_encoders, subject_encoder[:, 1])
                subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
                context_encoder = self.LayerNorm(bert_encoders, subject)

                # context_encoder = self.LayerNorm(context_encoder)
                po_pred = self.po_dense(context_encoder).reshape(subject_encoder.size(0), -1, self.classes_num, 2)

                if flag:
                    po_pred = po_pred[1, :, :, :].unsqueeze(0)

                po_preds.append(po_pred)

            po_tensor = torch.cat(po_preds).to(qids.device)
            po_tensor = nn.Sigmoid()(po_tensor)
            return qids, subject_ids, po_tensor
5	requirements.txt	Normal file
@@ -0,0 +1,5 @@
sklearn==0.0
torch==1.2.0
tqdm==4.35.0
transformers==2.2.0
239	run/spo_extraction/transformers_multi_label_span/data_loader.py	Normal file
@@ -0,0 +1,239 @@
import codecs
import json
import logging
import random
from functools import partial

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from utils.data_util import search, sequence_padding


class PredictObject(object):
    def __init__(self,
                 object_name,
                 object_start,
                 object_end,
                 predict_type,
                 predict_type_id
                 ):
        self.object_name = object_name
        self.object_start = object_start
        self.object_end = object_end
        self.predict_type = predict_type
        self.predict_type_id = predict_type_id


class Example(object):
    def __init__(self,
                 p_id=None,
                 context=None,
                 bert_tokens=None,
                 sub_entity_list=None,
                 gold_answer=None):
        self.p_id = p_id
        self.context = context
        self.bert_tokens = bert_tokens
        self.sub_entity_list = sub_entity_list
        self.gold_answer = gold_answer


class InputFeature(object):

    def __init__(self,
                 p_id=None,
                 passage_id=None,
                 token_type_id=None,
                 pos_start_id=None,
                 pos_end_id=None,
                 segment_id=None,
                 po_label=None,
                 s1=None,
                 s2=None):
        self.p_id = p_id
        self.passage_id = passage_id
        self.token_type_id = token_type_id
        self.pos_start_id = pos_start_id
        self.pos_end_id = pos_end_id
        self.segment_id = segment_id
        self.po_label = po_label
        self.s1 = s1
        self.s2 = s2


class Reader(object):
    def __init__(self, seg_char=True):
        self.seg_char = seg_char

    def read_examples(self, filename, data_type):
        logging.info("Generating {} examples...".format(data_type))
        return self._read(filename, data_type)

    def _read(self, filename, data_type):

        examples = []
        with codecs.open(filename, 'r') as f:
            gold_num = 0
            p_id = 0
            for line in tqdm(f):
                p_id += 1
                data_json = json.loads(line.strip())

                text = data_json['text'].lower().replace(' ', '')
                sub_po_dict, sub_ent_list, spo_list = dict(), list(), list()

                for spo in data_json['spo_list']:
                    # TODO: gold answers are lower-cased here; they should be
                    # restored to their original form at prediction time
                    subject_name = spo['subject'].lower()
                    object_name = spo['object'].lower()
                    sub_ent_list.append(subject_name)
                    spo_list.append((subject_name, spo['predicate'], object_name))

                examples.append(
                    Example(
                        p_id=p_id,
                        context=text,
                        sub_entity_list=list(set(sub_ent_list)),
                        gold_answer=spo_list
                    )
                )
                gold_num += len(set(spo_list))
            print('total gold num is {}'.format(gold_num))

        logging.info("{} total size is {} ".format(data_type, len(examples)))

        return examples


class Feature(object):
    def __init__(self, max_len, spo_config, tokenizer):
        self.max_len = max_len
        self.spo_config = spo_config
        self.tokenizer = tokenizer

    def __call__(self, examples, data_type):
        return self.convert_examples_to_bert_features(examples, data_type)

    def convert_examples_to_bert_features(self, examples, data_type):
        logging.info("converting {} examples to features.".format(data_type))

        examples2features = list()
        for index, example in enumerate(examples):
            examples2features.append((index, example))

        logging.info("Building instances completed")
        return SPODataset(examples2features, spo_config=self.spo_config, data_type=data_type,
                          tokenizer=self.tokenizer, max_len=self.max_len)


class SPODataset(Dataset):
    def __init__(self, data, spo_config, data_type, tokenizer=None, max_len=128):
        super(SPODataset, self).__init__()
        self.spo_config = spo_config
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.q_ids = [f[0] for f in data]
        self.features = [f[1] for f in data]
        self.is_train = True if data_type == 'train' else False

    def __len__(self):
        return len(self.q_ids)

    def __getitem__(self, index):
        return self.q_ids[index], self.features[index]

    def _create_collate_fn(self):
        def collate(examples):
            p_ids, examples = zip(*examples)
            p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
            batch_token_ids, batch_segment_ids = [], []
            batch_token_type_ids, batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], [], []
            for example in examples:
                token_ids = truncate_sequence(self.tokenizer.encode(example.context), max_length=self.max_len)
                segment_ids = len(token_ids) * [0]

                example.bert_tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
                example.token_ids = token_ids

                if self.is_train:
                    spoes = {}
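                    # spoes maps a subject token span (start, end) to the list of
                    # (object_start, object_end, predicate_id) triples in this passage.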
                    for s, p, o in example.gold_answer:
                        s = self.tokenizer.encode(s)[1:-1]
                        p = self.spo_config[p]
                        o = self.tokenizer.encode(o)[1:-1]
                        # TODO: the span search here could support several matching strategies
                        s_idx = search(s, token_ids)
                        o_idx = search(o, token_ids)
                        if s_idx != -1 and o_idx != -1:
                            s = (s_idx, s_idx + len(s) - 1)
                            o = (o_idx, o_idx + len(o) - 1, p)
                            if s not in spoes:
                                spoes[s] = []
                            spoes[s].append(o)

                    if spoes:
                        # subject labels
                        token_type_ids = np.zeros(len(token_ids), dtype=np.long)
                        subject_labels = np.zeros((len(token_ids), 2), dtype=np.float32)
                        for s in spoes:
                            subject_labels[s[0], 0] = 1
                            subject_labels[s[1], 1] = 1
                        # randomly sample one subject; the object head is trained
                        # only on this subject's objects in this step
                        subject_ids = random.choice(list(spoes.keys()))
                        # start, end = np.array(list(spoes.keys())).T
                        # start = np.random.choice(start)
                        # end = np.random.choice(end[end >= start])
                        # token_type_ids[start:end + 1] = 1
                        # subject_ids = (start, end)
                        # object labels for the chosen subject
                        object_labels = np.zeros((len(token_ids), len(self.spo_config), 2), dtype=np.float32)
                        for o in spoes.get(subject_ids, []):
                            object_labels[o[0], o[2], 0] = 1
                            object_labels[o[1], o[2], 1] = 1
                        batch_token_ids.append(token_ids)
                        batch_token_type_ids.append(token_type_ids)

                        batch_segment_ids.append(segment_ids)
                        batch_subject_labels.append(subject_labels)
                        batch_subject_ids.append(subject_ids)
                        batch_object_labels.append(object_labels)
                else:
                    batch_token_ids.append(token_ids)
                    batch_segment_ids.append(segment_ids)

            batch_token_ids = sequence_padding(batch_token_ids, is_float=False)
            batch_segment_ids = sequence_padding(batch_segment_ids, is_float=False)
            if not self.is_train:
                return p_ids, batch_token_ids, batch_segment_ids
            else:
                batch_token_type_ids = sequence_padding(batch_token_type_ids, is_float=False)
                batch_subject_ids = torch.tensor(batch_subject_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels, padding=np.zeros(2), is_float=True)
                batch_object_labels = sequence_padding(batch_object_labels, padding=np.zeros((len(self.spo_config), 2)),
                                                       is_float=True)
                return batch_token_ids, batch_segment_ids, batch_token_type_ids, batch_subject_ids, batch_subject_labels, batch_object_labels

        return partial(collate)

    def get_dataloader(self, batch_size, num_workers=0, shuffle=False, pin_memory=False,
                       drop_last=False):
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(),
                          num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)


def truncate_sequence(first_sequence,
                      max_length,
                      pop_index=-2):
    """Truncate the sequence to the total max length by popping tokens
    (by default the second-to-last position, preserving the final token).
    """
    while True:
        total_length = len(first_sequence)
        if total_length <= max_length:
            break
        else:
            first_sequence.pop(pop_index)

    return first_sequence
141	run/spo_extraction/transformers_multi_label_span/main.py	Normal file
@@ -0,0 +1,141 @@
# _*_ coding:utf-8 _*_
import argparse
import logging
import os
import random

import numpy as np
import torch
from transformers import BertTokenizer

from config import spo_config_v1, spo_config_v2
from run.spo_extraction.transformers_multi_label_span.data_loader import Reader, Feature
from run.spo_extraction.transformers_multi_label_span.train import Trainer
from utils.file_util import save, load

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
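
# Example invocation (illustrative only; paths and hyperparameter values are placeholders):
#   python -m run.spo_extraction.transformers_multi_label_span.main \
#       --input data/baidu_spo --output output/spo_v2 \
#       --bert_model bert-base-chinese --baidu_spo_version v2 \
#       --train_batch_size 16 --epoch_num 10 --train_mode train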


def get_args():
    parser = argparse.ArgumentParser()

    # file parameters
    parser.add_argument("--input", default=None, type=str, required=True)
    parser.add_argument("--output", default=None, type=str, required=False,
                        help="The output directory where the model checkpoints and predictions will be written.")

    # choice parameters
    parser.add_argument('--baidu_spo_version', type=str, default="v1")

    # train parameters
    parser.add_argument('--train_mode', type=str, default="train")
    parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epoch_num", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--patience_stop', type=int, default=10, help='Patience for learning early stop')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")

    # bert parameters
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    # parser.add_argument("--tokenizer_path", default='bert-base-chinese', type=str)

    # model parameters
    parser.add_argument("--max_len", default=1000, type=int)
    parser.add_argument('--entity_emb_size', type=int, default=300)
    parser.add_argument('--pos_limit', type=int, default=30)
    parser.add_argument('--pos_dim', type=int, default=300)
    parser.add_argument('--pos_size', type=int, default=62)

    parser.add_argument('--hidden_size', type=int, default=150)
    parser.add_argument('--bert_hidden_size', type=int, default=768)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--bidirectional', type=bool, default=True)
    parser.add_argument('--pin_memory', type=bool, default=False)
    args = parser.parse_args()
    args.cache_data = args.input + '/char_bert_cache_data/'
    return args


def build_dataset(args, spo_config, reader, debug=False):
    train_src = args.input + "/train_data.json"
    dev_src = args.input + "/dev_data.json"

    train_examples_file = args.cache_data + "/train-examples.pkl"
    dev_examples_file = args.cache_data + "/dev-examples.pkl"

    if not os.path.exists(train_examples_file):
        train_examples = reader.read_examples(train_src, data_type='train')
        dev_examples = reader.read_examples(dev_src, data_type='dev')
        save(train_examples_file, train_examples, message="train examples")
        save(dev_examples_file, dev_examples, message="dev examples")
    else:
        train_examples, dev_examples = load(train_examples_file), load(dev_examples_file)

    logging.info('train examples size is {}'.format(len(train_examples)))
    logging.info('dev examples size is {}'.format(len(dev_examples)))

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)

    convert_examples_features = Feature(max_len=args.max_len, spo_config=spo_config, tokenizer=tokenizer)

    train_examples = train_examples[:2] if debug else train_examples
    dev_examples = dev_examples[:2] if debug else dev_examples

    train_data_set = convert_examples_features(train_examples, data_type='train')
    dev_data_set = convert_examples_features(dev_examples, data_type='dev')
    train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory)
    dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size)

    data_loaders = train_data_loader, dev_data_loader
    eval_examples = train_examples, dev_examples

    return eval_examples, data_loaders, tokenizer


def main():
    args = get_args()
    if not os.path.exists(args.output):
        print('mkdir {}'.format(args.output))
        os.makedirs(args.output)
    if not os.path.exists(args.cache_data):
        print('mkdir {}'.format(args.cache_data))
        os.makedirs(args.cache_data)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    logger.info("** ** * build dataset ** ** * ")

    spo_conf = spo_config_v1.BAIDU_RELATION if args.baidu_spo_version == 'v1' else spo_config_v2.BAIDU_RELATION

    eval_examples, data_loaders, tokenizer = build_dataset(args, spo_conf, Reader(), debug=False)
    trainer = Trainer(args, data_loaders, eval_examples, spo_conf=spo_conf, tokenizer=tokenizer)

    if args.train_mode == "train":
        trainer.train(args)
    elif args.train_mode == "eval":
        # trainer.resume(args)
        # trainer.eval_data_set("train")
        trainer.eval_data_set("dev")
    elif args.train_mode == "resume":
        # trainer.resume(args)
        trainer.show("dev")  # bad case analysis


if __name__ == '__main__':
    main()
278	run/spo_extraction/transformers_multi_label_span/train.py	Normal file
@@ -0,0 +1,278 @@
# _*_ coding:utf-8 _*_
import logging
import sys
import time

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

import models.spo_net.multi_pointer_net as mpn
from layers.encoders.transformers.bert.bert_optimization import BertAdam

logger = logging.getLogger(__name__)


class Trainer(object):

    def __init__(self, args, data_loaders, examples, spo_conf, tokenizer):

        class SPO(tuple):
            """A container class for triples.
            It behaves essentially like a tuple, but overrides __hash__ and __eq__
            so that checking whether two triples are equivalent is more tolerant.
            """

            def __init__(self, spo):
                self.spox = (
                    tuple(tokenizer.tokenize(spo[0].replace(' ', ''))),
                    spo[1],
                    tuple(tokenizer.tokenize(spo[2].replace(' ', ''))),
                )

            def __hash__(self):
                return self.spox.__hash__()

            def __eq__(self, spo):
                return self.spox == spo.spox

        self.args = args
        self.spo_tuple = SPO
        self.tokenizer = tokenizer
        self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()

        self.id2rel = {item: key for key, item in spo_conf.items()}
        self.rel2id = spo_conf

        if self.n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
        self.model = mpn.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))

        self.model.to(self.device)
        # self.resume(args)
        logging.info('total gpu num is {}'.format(self.n_gpu))
        if self.n_gpu > 1:
            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])

        train_dataloader, dev_dataloader = data_loaders
        train_eval, dev_eval = examples
        self.eval_file_choice = {
            "train": train_eval,
            "dev": dev_eval,
        }
        self.data_loader_choice = {
            "train": train_dataloader,
            "dev": dev_dataloader,
        }
        # TODO: switch to a newer optimizer later and add gradient clipping
        self.optimizer = self.set_optimizer(args, self.model,
                                            train_steps=(int(
                                                len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)

    def set_optimizer(self, args, model, train_steps=None):
        param_optimizer = list(model.named_parameters())
        param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=train_steps)
        return optimizer

    def train(self, args):

        best_f1 = 0.0
        patience_stop = 0
        self.model.train()
        step_gap = 20
        for epoch in range(int(args.epoch_num)):

            global_loss = 0.0

            for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5,
                                    desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):

                loss = self.forward(batch)
                global_loss += loss

                if step % step_gap == 0:
                    current_loss = global_loss / step_gap
                    print(
                        u"step {} / {} of epoch {}, train/loss: {}".format(step, len(self.data_loader_choice["train"]),
                                                                           epoch, current_loss))
                    global_loss = 0.0

                if step % 500 == 0 and epoch >= 6:
                    res_dev = self.eval_data_set("dev")
                    if res_dev['f1'] >= best_f1:
                        best_f1 = res_dev['f1']
                        logging.info("** ** * Saving fine-tuned model ** ** * ")
                        model_to_save = self.model.module if hasattr(self.model,
                                                                     'module') else self.model  # Only save the model itself
                        output_model_file = args.output + "/pytorch_model.bin"
                        torch.save(model_to_save.state_dict(), str(output_model_file))
                        patience_stop = 0
                    else:
                        patience_stop += 1
                    if patience_stop >= args.patience_stop:
                        return

            res_dev = self.eval_data_set("dev")
            if res_dev['f1'] >= best_f1:
                best_f1 = res_dev['f1']
                logging.info("** ** * Saving fine-tuned model ** ** * ")
                model_to_save = self.model.module if hasattr(self.model,
                                                             'module') else self.model  # Only save the model itself
                output_model_file = args.output + "/pytorch_model.bin"
                torch.save(model_to_save.state_dict(), str(output_model_file))
                patience_stop = 0
            else:
                patience_stop += 1
            if patience_stop >= args.patience_stop:
                return

    def resume(self, args):
        resume_model_file = args.output + "/pytorch_model.bin"
        logging.info("=> loading checkpoint '{}'".format(resume_model_file))
        checkpoint = torch.load(resume_model_file, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def forward(self, batch, chosen=u'train', eval=False, answer_dict=None):

        batch = tuple(t.to(self.device) for t in batch)
        if not eval:
            input_ids, segment_ids, token_type_ids, subject_ids, subject_labels, object_labels = batch
            loss = self.model(passage_ids=input_ids, segment_ids=segment_ids, token_type_ids=token_type_ids,
                              subject_ids=subject_ids, subject_labels=subject_labels, object_labels=object_labels)
            if self.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            loss.backward()
            loss = loss.item()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss
        else:
            p_ids, input_ids, segment_ids = batch
            eval_file = self.eval_file_choice[chosen]
            qids, subject_pred, po_pred = self.model(q_ids=p_ids,
                                                     passage_ids=input_ids,
                                                     segment_ids=segment_ids,
                                                     eval_file=eval_file, is_eval=eval)
            ans_dict = self.convert_spo_contour(qids, subject_pred, po_pred, eval_file,
                                                answer_dict)
            return ans_dict

    def eval_data_set(self, chosen="dev"):

        self.model.eval()

        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        answer_dict = {i: [[], []] for i in range(len(eval_file))}

        last_time = time.time()
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                self.forward(batch, chosen, eval=True, answer_dict=answer_dict)
        used_time = time.time() - last_time
        logging.info('chosen {} took : {} sec'.format(chosen, used_time))
        res = self.evaluate(eval_file, answer_dict, chosen)
        self.model.train()
        return res

    def show(self, chosen="dev"):

        self.model.eval()
        answer_dict = {}

        data_loader = self.data_loader_choice[chosen]
        eval_file = self.eval_file_choice[chosen]
        with torch.no_grad():
            for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
                loss, answer_dict_ = self.forward(batch, chosen, eval=True)
                answer_dict.update(answer_dict_)

    def evaluate(self, eval_file, answer_dict, chosen):

        entity_em = 0
        entity_pred_num = 0
        entity_gold_num = 0
        X, Y, Z = 1e-10, 1e-10, 1e-10
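        # X: matched triples, Y: predicted triples, Z: gold triples; the 1e-10
        # initialization guards against division by zero in the P/R/F1 ratios.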
        for key, value in answer_dict.items():
            triple_gold = eval_file[key].gold_answer
            entity_gold = eval_file[key].sub_entity_list

            entity_pred, triple_pred = value

            entity_em += len(set(entity_pred) & set(entity_gold))
            entity_pred_num += len(set(entity_pred))
            entity_gold_num += len(set(entity_gold))

            R = set([self.spo_tuple(spo) for spo in triple_pred])
            T = set([self.spo_tuple(spo) for spo in triple_gold])

            X += len(R & T)
            Y += len(R)
            Z += len(T)

        f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z

        entity_precision = 100.0 * entity_em / entity_pred_num if entity_pred_num > 0 else 0.
        entity_recall = 100.0 * entity_em / entity_gold_num if entity_gold_num > 0 else 0.
        entity_f1 = 2 * entity_recall * entity_precision / (entity_recall + entity_precision) if (
                entity_recall + entity_precision) != 0 else 0.0

        print('============================================')
        print("{}/entity_em: {},\tentity_pred_num&entity_gold_num: {}\t{} ".format(chosen, entity_em, entity_pred_num,
                                                                                   entity_gold_num))
        print(
            "{}/entity_f1: {}, \tentity_precision: {},\tentity_recall: {} ".format(chosen, entity_f1, entity_precision,
                                                                                   entity_recall))
        print('============================================')
        print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, X, Y, Z))
        print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1 * 100, precision * 100,
                                                                recall * 100))
        return {'f1': f1, "recall": recall, "precision": precision}

    def convert_spo_contour(self, qids, subject_preds, po_preds, eval_file, answer_dict):

        for qid, subject, po_pred in zip(qids.data.cpu().numpy(), subject_preds.data.cpu().numpy(),
                                         po_preds.data.cpu().numpy()):
            if qid == -1:
                continue
            tokens = eval_file[qid.item()].bert_tokens
            token_ids = eval_file[qid.item()].token_ids
            start = np.where(po_pred[:, :, 0] > 0.6)
            end = np.where(po_pred[:, :, 1] > 0.5)
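            # Asymmetric decoding thresholds (0.6 for starts, 0.5 for ends)
            # mirror the subject decoding in the model's eval path.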

            spoes = []
            for _start, predicate1 in zip(*start):
                if _start > len(tokens) - 2 or _start == 0:
                    continue
                for _end, predicate2 in zip(*end):
                    if _start <= _end <= len(tokens) - 2 and predicate1 == predicate2:
                        spoes.append((subject, predicate1, (_start, _end)))
                        break
            po_predict = []
            for s, p, o in spoes:
                po_predict.append(
                    (self.tokenizer.decode(token_ids[s[0]:s[1] + 1]).replace(' ', ''),
                     self.id2rel[p],
                     self.tokenizer.decode(token_ids[o[0]:o[1] + 1]).replace(' ', ''))
                )

            if qid not in answer_dict:
                raise ValueError('error in answer_dict ')
            else:
                answer_dict[qid][0].append(
                    self.tokenizer.decode(token_ids[subject[0]:subject[1] + 1], tokens[subject[0]:subject[1] + 1]))
                answer_dict[qid][1].extend(po_predict)