Rename

This commit is contained in:
loujie0822 2020-08-19 21:39:31 +08:00
parent a951d3075c
commit c3faa8ab87
7 changed files with 178 additions and 9 deletions

View File

@@ -1,4 +1,7 @@
# _*_ coding:utf-8 _*_
"""
spo_bert: for Chinese BERT and RoBERTa
"""
import argparse
import logging
import os
@@ -10,8 +13,8 @@ import torch
from transformers import BertTokenizer
from deepIE.chip_rel.config.config import CMeIE_CONFIG
from deepIE.chip_rel.etl_span_transformers.data_loader_ptms_total_sub import Reader, Feature
from deepIE.chip_rel.etl_span_transformers.train import Trainer
from deepIE.chip_rel.spo_transformers.data_loader_ptms_total_sub import Reader, Feature
from deepIE.chip_rel.spo_transformers.train import Trainer
from utils.file_util import save, load
simplefilter(action='ignore', category=FutureWarning)
@@ -25,6 +28,7 @@ def get_args():
# file parameters
parser.add_argument("--input", default=None, type=str, required=True)
parser.add_argument("--res_path", default=None, type=str, required=False)
parser.add_argument("--output"
, default=None, type=str, required=False,
help="The output directory where the model checkpoints and predictions will be written.")
@@ -48,7 +52,7 @@ def get_args():
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--warmup_proportion", default=0.1, type=float,
parser.add_argument("--warmup_proportion", default=0.04, type=float,
help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
"of training.")
parser.add_argument("--bert_model", default=None, type=str,
@@ -70,7 +74,7 @@ def get_args():
parser.add_argument('--bidirectional', type=bool, default=True)
parser.add_argument('--pin_memory', type=bool, default=False)
args = parser.parse_args()
args.cache_data = args.input + '/bert_cache_data_{}/'.format(str(args.max_len))
args.cache_data = args.input + '/{}_cache_data_{}/'.format(str(args.bert_model).split('/')[1],str(args.max_len))
return args
@@ -136,7 +140,7 @@ def main():
logger.info("** ** * bulid dataset ** ** * ")
spo_conf = CMeIE_CONFIG if args.spo_version == 'v1' else None
spo_conf = CMeIE_CONFIG
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
reader = Reader(spo_conf, tokenizer, max_seq_length=args.max_len)
eval_examples, data_loaders, tokenizer = bulid_dataset(args, spo_conf, reader, tokenizer, debug=args.debug)

View File

@@ -11,7 +11,9 @@ import torch
import torch.nn as nn
from tqdm import tqdm
import models.spo_net.etl_span_transformers as etl
import models.spo_net.etl_span_transformers as etl_bert
import models.spo_net.etl_span_albert as etl_albert
from layers.encoders.transformers.bert.bert_optimization import BertAdam
simplefilter(action='ignore', category=FutureWarning)
@@ -32,7 +34,14 @@ class Trainer(object):
if self.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
self.model = etl.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))
if 'albert' in args.bert_model:
self.model = etl_albert.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))
else:
"""
Generic pretrained path: for Chinese BERT, RoBERTa, and the various wwm variants
"""
self.model = etl_bert.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))
self.model.to(self.device)
if args.train_mode != "train":
@@ -196,7 +205,7 @@ class Trainer(object):
self.convert2result(eval_file, answer_dict)
with codecs.open('result_chip_0813v1.json', 'w', 'utf-8') as f:
with codecs.open(self.args.res_path, 'w', 'utf-8') as f:
for key, ans_list in answer_dict.items():
out_put = {}
out_put['text'] = eval_file[int(key)].raw_text

View File

@@ -0,0 +1,155 @@
# _*_ coding:utf-8 _*_
"""
Only for Chinese ALBERT
"""
import warnings
import numpy as np
import torch
import torch.nn as nn
from transformers import AlbertModel
from transformers import AlbertPreTrainedModel
from layers.encoders.transformers.bert.layernorm import ConditionalLayerNorm
from utils.data_util import batch_gather
warnings.filterwarnings("ignore")
class ERENet(AlbertPreTrainedModel):
"""
ERENet: joint entity and relation extraction
"""
def __init__(self, config, classes_num):
super(ERENet, self).__init__(config, classes_num)
self.classes_num = classes_num
# ALBERT model
self.albert = AlbertModel(config)
self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
padding_idx=0)
self.LayerNorm = ConditionalLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# pointer network
self.po_dense = nn.Linear(config.hidden_size, self.classes_num * 2)
self.subject_dense = nn.Linear(config.hidden_size, 2)
self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')
self.init_weights()
def forward(self, q_ids=None, passage_ids=None, segment_ids=None, token_type_ids=None, subject_ids=None,
subject_labels=None,
object_labels=None, eval_file=None,
is_eval=False):
mask = (passage_ids != 0).float()
bert_encoder = self.albert(passage_ids, token_type_ids=segment_ids, attention_mask=mask)[0]
if not is_eval:
sub_start_encoder = batch_gather(bert_encoder, subject_ids[:, 0])
sub_end_encoder = batch_gather(bert_encoder, subject_ids[:, 1])
subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
context_encoder = self.LayerNorm(bert_encoder, subject)
sub_preds = self.subject_dense(bert_encoder)
po_preds = self.po_dense(context_encoder).reshape(passage_ids.size(0), -1, self.classes_num, 2)
subject_loss = self.loss_fct(sub_preds, subject_labels)
subject_loss = subject_loss.mean(2)
subject_loss = torch.sum(subject_loss * mask.float()) / torch.sum(mask.float())
po_loss = self.loss_fct(po_preds, object_labels)
po_loss = torch.sum(po_loss.mean(3), 2)
po_loss = torch.sum(po_loss * mask.float()) / torch.sum(mask.float())
loss = subject_loss + po_loss
return loss
else:
subject_preds = nn.Sigmoid()(self.subject_dense(bert_encoder))
answer_list = list()
for qid, sub_pred in zip(q_ids.cpu().numpy(),
subject_preds.cpu().numpy()):
context = eval_file[qid].bert_tokens
start = np.where(sub_pred[:, 0] > 0.5)[0]
end = np.where(sub_pred[:, 1] > 0.5)[0]
subjects = []
for i in start:
j = end[end >= i]
if i == 0 or i > len(context) - 2:
continue
if len(j) > 0:
j = j[0]
if j > len(context) - 2:
continue
subjects.append((i, j))
answer_list.append(subjects)
qid_ids, bert_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
for i, subjects in enumerate(answer_list):
if subjects:
qid = q_ids[i].unsqueeze(0).expand(len(subjects))
pass_tensor = passage_ids[i, :].unsqueeze(0).expand(len(subjects), passage_ids.size(1))
new_bert_encoder = bert_encoder[i, :, :].unsqueeze(0).expand(len(subjects), bert_encoder.size(1),
bert_encoder.size(2))
token_type_id = torch.zeros((len(subjects), passage_ids.size(1)), dtype=torch.long)
for index, (start, end) in enumerate(subjects):
token_type_id[index, start:end + 1] = 1
qid_ids.append(qid)
pass_ids.append(pass_tensor)
subject_ids.append(torch.tensor(subjects, dtype=torch.long))
bert_encoders.append(new_bert_encoder)
token_type_ids.append(token_type_id)
if len(qid_ids) == 0:
subject_ids = torch.zeros(1, 2).long().to(bert_encoder.device)
qid_tensor = torch.tensor([-1], dtype=torch.long).to(bert_encoder.device)
po_tensor = torch.zeros(1, bert_encoder.size(1)).long().to(bert_encoder.device)
return qid_tensor, subject_ids, po_tensor
qids = torch.cat(qid_ids).to(bert_encoder.device)
pass_ids = torch.cat(pass_ids).to(bert_encoder.device)
bert_encoders = torch.cat(bert_encoders).to(bert_encoder.device)
subject_ids = torch.cat(subject_ids).to(bert_encoder.device)
flag = False
split_heads = 1024
bert_encoders_ = torch.split(bert_encoders, split_heads, dim=0)
pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)
po_preds = list()
for i in range(len(bert_encoders_)):
bert_encoders = bert_encoders_[i]
pass_ids = pass_ids_[i]
subject_encoder = subject_encoder_[i]
if bert_encoders.size(0) == 1:
flag = True
bert_encoders = bert_encoders.expand(2, bert_encoders.size(1), bert_encoders.size(2))
subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
sub_start_encoder = batch_gather(bert_encoders, subject_encoder[:, 0])
sub_end_encoder = batch_gather(bert_encoders, subject_encoder[:, 1])
subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
context_encoder = self.LayerNorm(bert_encoders, subject)
po_pred = self.po_dense(context_encoder).reshape(subject_encoder.size(0), -1, self.classes_num, 2)
if flag:
po_pred = po_pred[1, :, :, :].unsqueeze(0)
po_preds.append(po_pred)
po_tensor = torch.cat(po_preds).to(qids.device)
po_tensor = nn.Sigmoid()(po_tensor)
return qids, subject_ids, po_tensor

View File

@@ -23,6 +23,7 @@ class ERENet(BertPreTrainedModel):
self.classes_num = classes_num
# BERT model
self.bert = BertModel(config)
self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
padding_idx=0)
@@ -69,7 +70,7 @@ class ERENet(BertPreTrainedModel):
for qid, sub_pred in zip(q_ids.cpu().numpy(),
subject_preds.cpu().numpy()):
context = eval_file[qid].bert_tokens
start = np.where(sub_pred[:, 0] > 0.6)[0]
start = np.where(sub_pred[:, 0] > 0.5)[0]
end = np.where(sub_pred[:, 1] > 0.5)[0]
subjects = []
for i in start: