Rename
This commit is contained in:
parent a951d3075c
commit c3faa8ab87
@@ -1,4 +1,7 @@
 # _*_ coding:utf-8 _*_
+"""
+spo_bert: for Chinese BERT and RoBERTa
+"""
 import argparse
 import logging
 import os
@@ -10,8 +13,8 @@ import torch
 from transformers import BertTokenizer

 from deepIE.chip_rel.config.config import CMeIE_CONFIG
-from deepIE.chip_rel.etl_span_transformers.data_loader_ptms_total_sub import Reader, Feature
-from deepIE.chip_rel.etl_span_transformers.train import Trainer
+from deepIE.chip_rel.spo_transformers.data_loader_ptms_total_sub import Reader, Feature
+from deepIE.chip_rel.spo_transformers.train import Trainer
 from utils.file_util import save, load

 simplefilter(action='ignore', category=FutureWarning)
@@ -25,6 +28,7 @@ def get_args():

     # file parameters
     parser.add_argument("--input", default=None, type=str, required=True)
+    parser.add_argument("--res_path", default=None, type=str, required=False)
     parser.add_argument("--output", default=None, type=str, required=False,
                         help="The output directory where the model checkpoints and predictions will be written.")
@@ -48,7 +52,7 @@ def get_args():
     parser.add_argument("--do_lower_case",
                         action='store_true',
                         help="Whether to lower case the input text. True for uncased models, False for cased models.")
-    parser.add_argument("--warmup_proportion", default=0.1, type=float,
+    parser.add_argument("--warmup_proportion", default=0.04, type=float,
                         help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                              "of training.")
     parser.add_argument("--bert_model", default=None, type=str,
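Note on the new default: warmup_proportion is the fraction of all optimizer steps spent ramping the learning rate up linearly, and in this repo it is presumably handed to the vendored BertAdam. A self-contained sketch of the arithmetic, with made-up step counts:

# Sketch only: the numbers are made up, and the mapping onto BertAdam's warmup argument is assumed.
warmup_proportion = 0.04
num_train_steps = 12500                      # e.g. steps_per_epoch * epochs
warmup_steps = int(warmup_proportion * num_train_steps)
print(warmup_steps)                          # 500 -> LR ramps up linearly over the first 500 steps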
@@ -70,7 +74,7 @@ def get_args():
     parser.add_argument('--bidirectional', type=bool, default=True)
     parser.add_argument('--pin_memory', type=bool, default=False)
     args = parser.parse_args()
-    args.cache_data = args.input + '/bert_cache_data_{}/'.format(str(args.max_len))
+    args.cache_data = args.input + '/{}_cache_data_{}/'.format(str(args.bert_model).split('/')[1], str(args.max_len))
     return args

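The cache directory now embeds the checkpoint directory name, so caches built with different pretrained models no longer overwrite each other. A quick illustration with hypothetical argument values:

# Hypothetical values, for illustration only.
bert_model = 'transformer_cpt/chinese_roberta_wwm'   # assumed to contain exactly one '/'
input_dir = 'data/chip_rel'
max_len = 128
cache_data = input_dir + '/{}_cache_data_{}/'.format(str(bert_model).split('/')[1], str(max_len))
print(cache_data)   # data/chip_rel/chinese_roberta_wwm_cache_data_128/
# Note: split('/')[1] raises IndexError if --bert_model contains no '/'.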
@@ -136,7 +140,7 @@ def main():

     logger.info("** ** * bulid dataset ** ** * ")

-    spo_conf = CMeIE_CONFIG if args.spo_version == 'v1' else None
+    spo_conf = CMeIE_CONFIG
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)
     reader = Reader(spo_conf, tokenizer, max_seq_length=args.max_len)
     eval_examples, data_loaders, tokenizer = bulid_dataset(args, spo_conf, reader, tokenizer, debug=args.debug)
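spo_conf is now unconditionally the CMeIE predicate schema; its length becomes classes_num, and the pointer head in the model allocates two logits (start and end) per predicate. A purely illustrative sketch of that relationship (the real mapping lives in deepIE.chip_rel.config.config):

# Illustrative only; the real predicate set is CMeIE_CONFIG.
spo_conf_example = {'预防': 0, '病因': 1, '临床表现': 2}   # predicate -> class id (hypothetical subset)
classes_num = len(spo_conf_example)
po_logits_per_token = classes_num * 2      # one start and one end logit per predicate (cf. po_dense below)
print(classes_num, po_logits_per_token)    # 3 6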
@@ -11,7 +11,9 @@ import torch
 import torch.nn as nn
 from tqdm import tqdm

-import models.spo_net.etl_span_transformers as etl
+import models.spo_net.etl_span_transformers as etl_bert
+import models.spo_net.etl_span_albert as etl_albert
+
 from layers.encoders.transformers.bert.bert_optimization import BertAdam

 simplefilter(action='ignore', category=FutureWarning)
@@ -32,7 +34,14 @@ class Trainer(object):

         if self.n_gpu > 0:
             torch.cuda.manual_seed_all(args.seed)
-        self.model = etl.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))
+
+        if 'albert' in args.bert_model:
+            self.model = etl_albert.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))
+        else:
+            """
+            General pretrained models: for Chinese BERT, RoBERTa, and the various wwm variants
+            """
+            self.model = etl_bert.ERENet.from_pretrained(args.bert_model, classes_num=len(spo_conf))

         self.model.to(self.device)
         if args.train_mode != "train":
@@ -196,7 +205,7 @@ class Trainer(object):

         self.convert2result(eval_file, answer_dict)

-        with codecs.open('result_chip_0813v1.json', 'w', 'utf-8') as f:
+        with codecs.open(self.args.res_path, 'w', 'utf-8') as f:
             for key, ans_list in answer_dict.items():
                 out_put = {}
                 out_put['text'] = eval_file[int(key)].raw_text
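The prediction file path is no longer hard-coded and comes from the new --res_path argument. For orientation, the loop above emits one JSON object per input sentence; a minimal stand-alone sketch of that JSON-lines pattern (the spo_list field and record contents below are assumed, not taken from this hunk):

import codecs
import json

answer_dict = {'0': [{'predicate': '病因', 'subject': '感冒', 'object': '病毒感染'}]}   # hypothetical record
with codecs.open('result_example.json', 'w', 'utf-8') as f:
    for key, ans_list in answer_dict.items():
        out_put = {'text': '示例句子', 'spo_list': ans_list}
        f.write(json.dumps(out_put, ensure_ascii=False) + '\n')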
models/spo_net/etl_span_albert.py (new file, 155 lines)
@@ -0,0 +1,155 @@
# _*_ coding:utf-8 _*_


"""
For Chinese ALBERT only

"""
import warnings

import numpy as np
import torch
import torch.nn as nn
from transformers import AlbertModel
from transformers import AlbertPreTrainedModel

from layers.encoders.transformers.bert.layernorm import ConditionalLayerNorm
from utils.data_util import batch_gather

warnings.filterwarnings("ignore")


class ERENet(AlbertPreTrainedModel):
    """
    ERENet : entity relation jointed extraction
    """

    def __init__(self, config, classes_num):
        super(ERENet, self).__init__(config, classes_num)
        self.classes_num = classes_num

        # ALBERT encoder

        self.albert = AlbertModel(config)
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
                                             padding_idx=0)
        self.LayerNorm = ConditionalLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # pointer network
        self.po_dense = nn.Linear(config.hidden_size, self.classes_num * 2)
        self.subject_dense = nn.Linear(config.hidden_size, 2)
        self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        self.init_weights()

    def forward(self, q_ids=None, passage_ids=None, segment_ids=None, token_type_ids=None, subject_ids=None,
                subject_labels=None,
                object_labels=None, eval_file=None,
                is_eval=False):
        mask = (passage_ids != 0).float()
        bert_encoder = self.albert(passage_ids, token_type_ids=segment_ids, attention_mask=mask)[0]
        if not is_eval:
            sub_start_encoder = batch_gather(bert_encoder, subject_ids[:, 0])
            sub_end_encoder = batch_gather(bert_encoder, subject_ids[:, 1])
            subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
            context_encoder = self.LayerNorm(bert_encoder, subject)

            sub_preds = self.subject_dense(bert_encoder)
            po_preds = self.po_dense(context_encoder).reshape(passage_ids.size(0), -1, self.classes_num, 2)

            subject_loss = self.loss_fct(sub_preds, subject_labels)
            subject_loss = subject_loss.mean(2)
            subject_loss = torch.sum(subject_loss * mask.float()) / torch.sum(mask.float())

            po_loss = self.loss_fct(po_preds, object_labels)
            po_loss = torch.sum(po_loss.mean(3), 2)
            po_loss = torch.sum(po_loss * mask.float()) / torch.sum(mask.float())

            loss = subject_loss + po_loss

            return loss

        else:

            subject_preds = nn.Sigmoid()(self.subject_dense(bert_encoder))
            answer_list = list()
            for qid, sub_pred in zip(q_ids.cpu().numpy(),
                                     subject_preds.cpu().numpy()):
                context = eval_file[qid].bert_tokens
                start = np.where(sub_pred[:, 0] > 0.5)[0]
                end = np.where(sub_pred[:, 1] > 0.5)[0]
                subjects = []
                for i in start:
                    j = end[end >= i]
                    if i == 0 or i > len(context) - 2:
                        continue

                    if len(j) > 0:
                        j = j[0]
                        if j > len(context) - 2:
                            continue
                        subjects.append((i, j))

                answer_list.append(subjects)

            qid_ids, bert_encoders, pass_ids, subject_ids, token_type_ids = [], [], [], [], []
            for i, subjects in enumerate(answer_list):
                if subjects:
                    qid = q_ids[i].unsqueeze(0).expand(len(subjects))
                    pass_tensor = passage_ids[i, :].unsqueeze(0).expand(len(subjects), passage_ids.size(1))
                    new_bert_encoder = bert_encoder[i, :, :].unsqueeze(0).expand(len(subjects), bert_encoder.size(1),
                                                                                 bert_encoder.size(2))

                    token_type_id = torch.zeros((len(subjects), passage_ids.size(1)), dtype=torch.long)
                    for index, (start, end) in enumerate(subjects):
                        token_type_id[index, start:end + 1] = 1

                    qid_ids.append(qid)
                    pass_ids.append(pass_tensor)
                    subject_ids.append(torch.tensor(subjects, dtype=torch.long))
                    bert_encoders.append(new_bert_encoder)
                    token_type_ids.append(token_type_id)

            if len(qid_ids) == 0:
                subject_ids = torch.zeros(1, 2).long().to(bert_encoder.device)
                qid_tensor = torch.tensor([-1], dtype=torch.long).to(bert_encoder.device)
                po_tensor = torch.zeros(1, bert_encoder.size(1)).long().to(bert_encoder.device)
                return qid_tensor, subject_ids, po_tensor

            qids = torch.cat(qid_ids).to(bert_encoder.device)
            pass_ids = torch.cat(pass_ids).to(bert_encoder.device)
            bert_encoders = torch.cat(bert_encoders).to(bert_encoder.device)
            subject_ids = torch.cat(subject_ids).to(bert_encoder.device)

            flag = False
            split_heads = 1024

            bert_encoders_ = torch.split(bert_encoders, split_heads, dim=0)
            pass_ids_ = torch.split(pass_ids, split_heads, dim=0)
            subject_encoder_ = torch.split(subject_ids, split_heads, dim=0)

            po_preds = list()
            for i in range(len(bert_encoders_)):
                bert_encoders = bert_encoders_[i]
                pass_ids = pass_ids_[i]
                subject_encoder = subject_encoder_[i]

                if bert_encoders.size(0) == 1:
                    flag = True
                    bert_encoders = bert_encoders.expand(2, bert_encoders.size(1), bert_encoders.size(2))
                    subject_encoder = subject_encoder.expand(2, subject_encoder.size(1))
                sub_start_encoder = batch_gather(bert_encoders, subject_encoder[:, 0])
                sub_end_encoder = batch_gather(bert_encoders, subject_encoder[:, 1])
                subject = torch.cat([sub_start_encoder, sub_end_encoder], 1)
                context_encoder = self.LayerNorm(bert_encoders, subject)

                po_pred = self.po_dense(context_encoder).reshape(subject_encoder.size(0), -1, self.classes_num, 2)

                if flag:
                    po_pred = po_pred[1, :, :, :].unsqueeze(0)

                po_preds.append(po_pred)

            po_tensor = torch.cat(po_preds).to(qids.device)
            po_tensor = nn.Sigmoid()(po_tensor)
            return qids, subject_ids, po_tensor
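Two helpers do the heavy lifting in this file: batch_gather pulls out one encoder state per batch row at a given token index (used for the subject start and end), and ConditionalLayerNorm then conditions the normalization of every token on that subject vector. A minimal sketch of the gather step, under the assumption that the real utils.data_util.batch_gather behaves like this:

import torch

def batch_gather_sketch(encoder, index):
    # encoder: (batch, seq_len, hidden); index: (batch,) token positions
    # returns (batch, hidden): the encoder state at index[b] for each batch row b
    idx = index.view(-1, 1, 1).expand(-1, 1, encoder.size(-1))
    return torch.gather(encoder, 1, idx).squeeze(1)

enc = torch.randn(2, 5, 8)                     # toy batch: 2 sentences, 5 tokens, hidden size 8
starts = torch.tensor([1, 3])                  # subject start index per sentence
print(batch_gather_sketch(enc, starts).shape)  # torch.Size([2, 8])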
@@ -23,6 +23,7 @@ class ERENet(BertPreTrainedModel):
         self.classes_num = classes_num

         # BERT model

         self.bert = BertModel(config)
         self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
                                              padding_idx=0)
@@ -69,7 +70,7 @@ class ERENet(BertPreTrainedModel):
             for qid, sub_pred in zip(q_ids.cpu().numpy(),
                                      subject_preds.cpu().numpy()):
                 context = eval_file[qid].bert_tokens
-                start = np.where(sub_pred[:, 0] > 0.6)[0]
+                start = np.where(sub_pred[:, 0] > 0.5)[0]
                 end = np.where(sub_pred[:, 1] > 0.5)[0]
                 subjects = []
                 for i in start:
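The subject-start threshold now matches the end threshold, so slightly less confident start positions are kept. A tiny numeric illustration of the decode step (probabilities are made up):

import numpy as np

sub_pred = np.array([[0.10, 0.05],
                     [0.55, 0.02],
                     [0.20, 0.70]])                  # columns: [start prob, end prob] per token
print(np.where(sub_pred[:, 0] > 0.6)[0])             # [] -> token 1 was dropped by the old 0.6 cut-off
print(np.where(sub_pred[:, 0] > 0.5)[0])             # [1]
print(np.where(sub_pred[:, 1] > 0.5)[0])             # [2] -> candidate subject span (1, 2) after pairing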