first add model codes

commit 144d94550a  (parent cad8b9b8fd)
0    models/__init__.py                          Normal file
0    models/attribute_extract_net/__init__.py    Normal file
120  models/attribute_extract_net/bert_mpn.py    Normal file
@@ -0,0 +1,120 @@
import torch
from torch import nn

from layers.encoders.transformers.bert.bert_model import BertModel
from layers.encoders.transformers.bert.bert_pretrain import BertPreTrainedModel


# class AttributeExtractNet(BertPreTrainedModel):
#     """
#     Attribute Extract Net with Multi-label Pointer Network (MPN), entity-aware and
#     encoded by BERT
#     """
#
#     def __init__(self, config, args, attribute_conf):
#         print('bert mpn baseline')
#         super(AttributeExtractNet, self).__init__(config, args, attribute_conf)
#
#         self.bert = BertModel(config)
#         self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
#                                              padding_idx=0)
#
#         # sentence encoder using transformer
#         self.transformer_encoder_layer = TransformerEncoderLayer(config.hidden_size, args.nhead,
#                                                                  dim_feedforward=args.dim_feedforward)
#         self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
#
#         self.classes_num = len(attribute_conf)
#
#         # pointer network
#         self.attr_start = nn.Linear(config.hidden_size, self.classes_num)
#         self.attr_end = nn.Linear(config.hidden_size, self.classes_num)
#
#         self.apply(self.init_bert_weights)
#
#     def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None,
#                 start_id=None, end_id=None, is_eval=False):
#         mask = passage_id.eq(0)
#         sent_mask = passage_id != 0
#
#         context_encoder, _ = self.bert(passage_id, segment_id, attention_mask=sent_mask,
#                                        output_all_encoded_layers=False)
#
#         token_entity_emb = self.token_entity_emb(token_type_id)
#
#         # entity-aware sentence encoding
#         sent_entity_encoder = context_encoder + token_entity_emb
#         transformer_encoder = self.transformer_encoder(sent_entity_encoder.transpose(1, 0),
#                                                        src_key_padding_mask=mask).transpose(0, 1)
#
#         attr_start = self.attr_start(transformer_encoder)
#         attr_end = self.attr_end(transformer_encoder)
#
#         loss_fct = nn.BCEWithLogitsLoss(reduction='none')
#
#         s1_loss = loss_fct(attr_start, start_id)
#         s1_loss = torch.sum(s1_loss, 2)
#         s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
#         s2_loss = loss_fct(attr_end, end_id)
#         s2_loss = torch.sum(s2_loss, 2)
#         s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
#
#         total_loss = s1_loss + s2_loss
#         po1 = nn.Sigmoid()(attr_start)
#         po2 = nn.Sigmoid()(attr_end)
#
#         return total_loss, po1, po2


class AttributeExtractNet(BertPreTrainedModel):
    """
    Attribute Extract Net with Multi-label Pointer Network (MPN), entity-aware and
    encoded by BERT
    """

    def __init__(self, config, args, attribute_conf):
        super(AttributeExtractNet, self).__init__(config, args, attribute_conf)
        print('bert debug - 2 ')

        self.bert = BertModel(config)
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=config.hidden_size,
                                             padding_idx=0)

        # # sentence encoder using transformer
        # self.transformer_encoder_layer = TransformerEncoderLayer(config.hidden_size, args.nhead,
        #                                                          dim_feedforward=args.dim_feedforward)
        # self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)

        self.classes_num = len(attribute_conf)

        # pointer network
        self.attr_start = nn.Linear(config.hidden_size, self.classes_num)
        self.attr_end = nn.Linear(config.hidden_size, self.classes_num)

        self.apply(self.init_bert_weights)

    def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None,
                start_id=None, end_id=None, is_eval=False):
        sent_mask = passage_id != 0

        context_encoder, _ = self.bert(passage_id, segment_id, attention_mask=sent_mask,
                                       output_all_encoded_layers=False)

        attr_start = self.attr_start(context_encoder)
        attr_end = self.attr_end(context_encoder)

        loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        s1_loss = loss_fct(attr_start, start_id)
        s1_loss = torch.sum(s1_loss, 2)
        s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num

        s2_loss = loss_fct(attr_end, end_id)
        s2_loss = torch.sum(s2_loss, 2)
        s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num

        total_loss = s1_loss + s2_loss
        po1 = nn.Sigmoid()(attr_start)
        po2 = nn.Sigmoid()(attr_end)

        return total_loss, po1, po2
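The pointer heads above produce, for every token, one start score and one end score per attribute class; decoding thresholds the sigmoid outputs and pairs each start with the nearest end of the same class, which is the convention convert_pointer_net_contour in run/attribute_extract/mpn/train.py follows. A minimal decoding sketch, assuming numpy arrays of shape (seq_len, classes_num), a raw sentence string, and an illustrative 0.5 threshold (decode_spans itself is not part of this commit):

import numpy as np

def decode_spans(po1, po2, id2rel, context, threshold=0.5):
    # po1 / po2: sigmoid outputs of shape (seq_len, classes_num) for a single sentence
    answers = []
    starts, start_types = np.where(po1 > threshold)
    ends, end_types = np.where(po2 > threshold)
    for s, t in zip(starts, start_types):
        for e, t2 in zip(ends, end_types):
            # pair a start with the first end of the same attribute class at or after it
            if s <= e < len(context) and t == t2:
                answers.append(id2rel[t] + '@' + context[s:e + 1])
                break
    return answers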
384  models/attribute_extract_net/mpn.py         Normal file
@@ -0,0 +1,384 @@
import warnings

import torch
import torch.nn as nn
from torch.nn.modules.transformer import TransformerEncoder, TransformerEncoderLayer

from layers.encoders.rnns.stacked_rnn import StackedBRNN

warnings.filterwarnings("ignore")


class SentenceEncoder(nn.Module):
    def __init__(self, args, input_size):
        super(SentenceEncoder, self).__init__()
        rnn_type = nn.LSTM if args.rnn_encoder == 'lstm' else nn.GRU
        self.encoder = StackedBRNN(
            input_size=input_size,
            hidden_size=args.hidden_size,
            num_layers=args.num_layers,
            dropout_rate=args.dropout,
            dropout_output=True,
            concat_layers=False,
            rnn_type=rnn_type,
            padding=True
        )

    def forward(self, input, mask):
        return self.encoder(input, mask)


class AttributeExtractNet(nn.Module):
    """
    Attribute Extract Net with Multi-label Pointer Network (MPN)
    entity-aware fusion: char_emb + token_entity_emb + pos_start_emb + pos_end_emb
    """

    def __init__(self, args, char_emb, attribute_conf):
        super(AttributeExtractNet, self).__init__()
        print('entity-aware fusion: char_emb + token_entity_emb + pos_start_emb + pos_end_emb')

        if char_emb is not None:
            self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
                                                         padding_idx=0)
        else:
            self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
                                         padding_idx=0)

        self.pos_start = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
                                      padding_idx=0)
        self.pos_end = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
                                    padding_idx=0)
        # whether the token belongs to the entity: 1 for entity tokens, else 0
        self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
                                             padding_idx=0)
        # sentence encoder using lstm
        self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
        self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)

        # sentence encoder using transformer
        self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
                                                                 dim_feedforward=args.dim_feedforward)
        self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)

        self.classes_num = len(attribute_conf)

        # pointer network
        self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
        self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)

    def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None,
                start_id=None, end_id=None, is_eval=False):
        mask = passage_id.eq(0)
        sent_mask = passage_id != 0

        char_emb = self.char_emb(passage_id)
        pos_start_emb = self.pos_start(pos_start)
        pos_end_emb = self.pos_end(pos_end)
        token_entity_emb = self.token_entity_emb(token_type_id)
        # entity-aware sentence encoding
        sent_entity_encoder = char_emb + token_entity_emb + pos_start_emb + pos_end_emb
        sent_first_encoder = self.first_sentence_encoder(sent_entity_encoder, mask).transpose(1, 0)
        transformer_encoder = self.transformer_encoder(sent_first_encoder, src_key_padding_mask=mask).transpose(0, 1)

        attr_start = self.attr_start(transformer_encoder)
        attr_end = self.attr_end(transformer_encoder)

        loss_fct = nn.BCEWithLogitsLoss(reduction='none')

        s1_loss = loss_fct(attr_start, start_id)
        s1_loss = torch.sum(s1_loss, 2)
        s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num

        s2_loss = loss_fct(attr_end, end_id)
        s2_loss = torch.sum(s2_loss, 2)
        s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num

        total_loss = s1_loss + s2_loss
        po1 = nn.Sigmoid()(attr_start)
        po2 = nn.Sigmoid()(attr_end)

        return total_loss, po1, po2

# class AttributeExtractNet(nn.Module):
|
||||
# """
|
||||
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
|
||||
# entity-aware fusion: sent-encoder + pos_start + pos_end
|
||||
# """
|
||||
#
|
||||
# def __init__(self, args, char_emb, attribute_conf):
|
||||
# super(AttributeExtractNet, self).__init__()
|
||||
# print('entity-aware fusion: sent-encoder + pos_start')
|
||||
# if char_emb is not None:
|
||||
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
|
||||
# padding_idx=0)
|
||||
# else:
|
||||
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
|
||||
# padding_idx=0)
|
||||
#
|
||||
# self.pos_start = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
|
||||
# padding_idx=0)
|
||||
# self.pos_end = nn.Embedding(num_embeddings=args.pos_size, embedding_dim=args.pos_dim,
|
||||
# padding_idx=0)
|
||||
#
|
||||
# # token whether belong to a entity, 1 represent a entity token, else 0;
|
||||
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
|
||||
# padding_idx=0)
|
||||
# # sentence_encoder using lstm
|
||||
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
|
||||
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
|
||||
#
|
||||
# # sentence_encoder using transformer
|
||||
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
|
||||
# dim_feedforward=args.dim_feedforward)
|
||||
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
|
||||
#
|
||||
# self.classes_num = len(attribute_conf)
|
||||
#
|
||||
# # pointer net work
|
||||
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
#
|
||||
# def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None, start_id=None,
|
||||
# end_id=None, is_eval=False):
|
||||
# mask = passage_id.eq(0)
|
||||
# sent_mask = passage_id != 0
|
||||
#
|
||||
# char_emb = self.char_emb(passage_id)
|
||||
# pos_start_emb = self.pos_start(pos_start)
|
||||
# pos_end_emb = self.pos_end(pos_end)
|
||||
# token_entity_emb = self.token_entity_emb(token_type_id)
|
||||
#
|
||||
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask)
|
||||
#
|
||||
# # sent encoder based entity-aware
|
||||
# sent_entity_aware = sent_first_encoder+pos_start_emb
|
||||
#
|
||||
# sent_second_encoder = self.second_sentence_encoder(sent_entity_aware, mask).transpose(1, 0)
|
||||
# transformer_encoder = self.transformer_encoder(sent_second_encoder, src_key_padding_mask=mask).transpose(0, 1)
|
||||
#
|
||||
# attr_start = self.attr_start(transformer_encoder)
|
||||
# attr_end = self.attr_end(transformer_encoder)
|
||||
#
|
||||
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
|
||||
#
|
||||
# s1_loss = loss_fct(attr_start, start_id)
|
||||
# s1_loss = torch.sum(s1_loss, 2)
|
||||
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# s2_loss = loss_fct(attr_end, end_id)
|
||||
# s2_loss = torch.sum(s2_loss, 2)
|
||||
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# total_loss = s1_loss + s2_loss
|
||||
# po1 = nn.Sigmoid()(attr_start)
|
||||
# po2 = nn.Sigmoid()(attr_end)
|
||||
#
|
||||
# return total_loss, po1, po2
|
||||
|
||||
|
||||
# class AttributeExtractNet(nn.Module):
|
||||
# """
|
||||
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
|
||||
# entity-aware fusion: [sent-encoder, token_ent_type, global_repre]
|
||||
# """
|
||||
#
|
||||
# def __init__(self, args, char_emb, attribute_conf):
|
||||
# super(AttributeExtractNet, self).__init__()
|
||||
# print('entity-aware fusion: [sent-encoder, token_ent_type, global_repre]')
|
||||
# if char_emb is not None:
|
||||
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
|
||||
# padding_idx=0)
|
||||
# else:
|
||||
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
|
||||
# padding_idx=0)
|
||||
#
|
||||
# # token whether belong to a entity, 1 represent a entity token, else 0;
|
||||
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
|
||||
# padding_idx=0)
|
||||
# # sentence_encoder using lstm
|
||||
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
|
||||
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size * 3)
|
||||
#
|
||||
# # sentence_encoder using transformer
|
||||
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
|
||||
# dim_feedforward=args.dim_feedforward)
|
||||
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
|
||||
#
|
||||
# self.classes_num = len(attribute_conf)
|
||||
#
|
||||
# # pointer net work
|
||||
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
#
|
||||
# def forward(self, passage_id=None, token_type_id=None, segment_id=None, pos_start=None, pos_end=None, start_id=None,
|
||||
# end_id=None, is_eval=False):
|
||||
# mask = passage_id.eq(0)
|
||||
# sent_mask = passage_id != 0
|
||||
#
|
||||
# char_emb = self.char_emb(passage_id)
|
||||
# token_entity_emb = self.token_entity_emb(token_type_id)
|
||||
#
|
||||
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask)
|
||||
# global_encoder_, _ = torch.max(sent_first_encoder, 1)
|
||||
# global_encoder = global_encoder_.unsqueeze(1).expand_as(sent_first_encoder)
|
||||
# # sent encoder based entity-aware
|
||||
# sent_entity_aware = torch.cat([sent_first_encoder, token_entity_emb, global_encoder], -1)
|
||||
#
|
||||
# sent_second_encoder = self.second_sentence_encoder(sent_entity_aware, mask).transpose(1, 0)
|
||||
# transformer_encoder = self.transformer_encoder(sent_second_encoder, src_key_padding_mask=mask).transpose(0, 1)
|
||||
#
|
||||
# attr_start = self.attr_start(transformer_encoder)
|
||||
# attr_end = self.attr_end(transformer_encoder)
|
||||
#
|
||||
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
|
||||
#
|
||||
# s1_loss = loss_fct(attr_start, start_id)
|
||||
# s1_loss = torch.sum(s1_loss, 2)
|
||||
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# s2_loss = loss_fct(attr_end, end_id)
|
||||
# s2_loss = torch.sum(s2_loss, 2)
|
||||
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# total_loss = s1_loss + s2_loss
|
||||
# po1 = nn.Sigmoid()(attr_start)
|
||||
# po2 = nn.Sigmoid()(attr_end)
|
||||
#
|
||||
# return total_loss, po1, po2
|
||||
|
||||
# class AttributeExtractNet(nn.Module):
|
||||
# """
|
||||
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
|
||||
# entity-aware fusion: token_entity_embedding
|
||||
# """
|
||||
#
|
||||
# def __init__(self, args, char_emb, attribute_conf):
|
||||
# print('baseline')
|
||||
# super(AttributeExtractNet, self).__init__()
|
||||
# if char_emb is not None:
|
||||
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
|
||||
# padding_idx=0)
|
||||
# else:
|
||||
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
|
||||
# padding_idx=0)
|
||||
#
|
||||
# # token whether belong to a entity, 1 represent a entity token, else 0;
|
||||
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
|
||||
# padding_idx=0)
|
||||
# # sentence_encoder using lstm
|
||||
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
|
||||
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
|
||||
#
|
||||
# # sentence_encoder using transformer
|
||||
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
|
||||
# dim_feedforward=args.dim_feedforward)
|
||||
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
|
||||
#
|
||||
# self.classes_num = len(attribute_conf)
|
||||
#
|
||||
# # pointer net work
|
||||
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
#
|
||||
# def forward(self, passage_id=None, token_type_id=None, segment_id=None,start_id=None, end_id=None, is_eval=False):
|
||||
# mask = passage_id.eq(0)
|
||||
# sent_mask = passage_id != 0
|
||||
#
|
||||
# char_emb = self.char_emb(passage_id)
|
||||
# token_entity_emb = self.token_entity_emb(token_type_id)
|
||||
#
|
||||
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask)
|
||||
# # sent encoder based entity-aware
|
||||
# sent_entity_encoder = sent_first_encoder + token_entity_emb
|
||||
# sent_second_encoder = self.second_sentence_encoder(sent_entity_encoder, mask).transpose(1, 0)
|
||||
# transformer_encoder = self.transformer_encoder(sent_second_encoder, src_key_padding_mask=mask).transpose(0, 1)
|
||||
#
|
||||
# attr_start = self.attr_start(transformer_encoder)
|
||||
# attr_end = self.attr_end(transformer_encoder)
|
||||
#
|
||||
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
|
||||
#
|
||||
# s1_loss = loss_fct(attr_start, start_id)
|
||||
# s1_loss = torch.sum(s1_loss, 2)
|
||||
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# s2_loss = loss_fct(attr_end, end_id)
|
||||
# s2_loss = torch.sum(s2_loss, 2)
|
||||
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# total_loss = s1_loss + s2_loss
|
||||
# po1 = nn.Sigmoid()(attr_start)
|
||||
# po2 = nn.Sigmoid()(attr_end)
|
||||
#
|
||||
# return total_loss, po1, po2
|
||||
|
||||
# class AttributeExtractNet(nn.Module):
|
||||
# """
|
||||
# Attribute Extract Net with Multi-label Pointer Network(MPN) based Entity-aware
|
||||
# entity-aware fusion: explicit encoding, replacing the entity with <MASK>
|
||||
# """
|
||||
#
|
||||
# def __init__(self, args, char_emb, attribute_conf):
|
||||
# super(AttributeExtractNet, self).__init__()
|
||||
# print('entity-aware fusion: explicit encoding, replacing the entity with <MASK>')
|
||||
#
|
||||
#
|
||||
# if char_emb is not None:
|
||||
# self.char_emb = nn.Embedding.from_pretrained(torch.tensor(char_emb, dtype=torch.float32), freeze=False,
|
||||
# padding_idx=0)
|
||||
# else:
|
||||
# self.char_emb = nn.Embedding(num_embeddings=args.vocab_size, embedding_dim=args.char_emb_size,
|
||||
# padding_idx=0)
|
||||
#
|
||||
# # token whether belong to a entity, 1 represent a entity token, else 0;
|
||||
# self.token_entity_emb = nn.Embedding(num_embeddings=2, embedding_dim=args.entity_emb_size,
|
||||
# padding_idx=0)
|
||||
# # sentence_encoder using lstm
|
||||
# self.first_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
|
||||
# self.second_sentence_encoder = SentenceEncoder(args, args.char_emb_size)
|
||||
#
|
||||
# # sentence_encoder using transformer
|
||||
# self.transformer_encoder_layer = TransformerEncoderLayer(args.hidden_size * 2, args.nhead,
|
||||
# dim_feedforward=args.dim_feedforward)
|
||||
# self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, args.transformer_layers)
|
||||
#
|
||||
# self.classes_num = len(attribute_conf)
|
||||
#
|
||||
# # pointer net work
|
||||
# self.attr_start = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
# self.attr_end = nn.Linear(args.hidden_size * 2, self.classes_num)
|
||||
#
|
||||
# def forward(self, passage_id=None, token_type_id=None, segment_id=None, start_id=None, end_id=None, is_eval=False):
|
||||
# mask = passage_id.eq(0)
|
||||
# sent_mask = passage_id != 0
|
||||
#
|
||||
# char_emb = self.char_emb(passage_id)
|
||||
# # token_entity_emb = self.token_entity_emb(token_type_id)
|
||||
#
|
||||
# # char_emb = char_emb + token_entity_emb
|
||||
#
|
||||
# sent_first_encoder = self.first_sentence_encoder(char_emb, mask).transpose(1, 0)
|
||||
# # sent encoder based entity-aware
|
||||
# # sent_entity_encoder = sent_first_encoder + token_entity_emb
|
||||
# # sent_second_encoder = self.second_sentence_encoder(sent_first_encoder, mask).transpose(1, 0)
|
||||
# transformer_encoder = self.transformer_encoder(sent_first_encoder, src_key_padding_mask=mask).transpose(0, 1)
|
||||
#
|
||||
# attr_start = self.attr_start(transformer_encoder)
|
||||
# attr_end = self.attr_end(transformer_encoder)
|
||||
#
|
||||
# loss_fct = nn.BCEWithLogitsLoss(reduction='none')
|
||||
#
|
||||
# s1_loss = loss_fct(attr_start, start_id)
|
||||
# s1_loss = torch.sum(s1_loss, 2)
|
||||
# s1_loss = torch.sum(s1_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# s2_loss = loss_fct(attr_end, end_id)
|
||||
# s2_loss = torch.sum(s2_loss, 2)
|
||||
# s2_loss = torch.sum(s2_loss * sent_mask.float()) / torch.sum(sent_mask.float()) / self.classes_num
|
||||
#
|
||||
# total_loss = s1_loss + s2_loss
|
||||
# po1 = nn.Sigmoid()(attr_start)
|
||||
# po2 = nn.Sigmoid()(attr_end)
|
||||
#
|
||||
# return total_loss, po1, po2
|
0    run/__init__.py                             Normal file
0    run/attribute_extract/__init__.py           Normal file
0    run/attribute_extract/mpn/__init__.py       Normal file
392  run/attribute_extract/mpn/data.py           Normal file
@ -0,0 +1,392 @@
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter
|
||||
from functools import partial
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from layers.encoders.transformers.bert.bert_tokenization import BertTokenizer
|
||||
from utils.data_util import padding, mpn_padding, _handle_pos_limit
|
||||
|
||||
config = {
|
||||
|
||||
'drug': {
|
||||
'药品-用药频率': 0,
|
||||
'药品-持续时间': 1,
|
||||
'药品-用药剂量': 2,
|
||||
'药品-用药方法': 3,
|
||||
'药品-不良反应': 4,
|
||||
},
|
||||
'disease': {
|
||||
'疾病-检查方法': 0,
|
||||
'疾病-临床表现': 1,
|
||||
'疾病-非药治疗': 2,
|
||||
'疾病-药品名称': 3,
|
||||
'疾病-部位': 4,
|
||||
},
|
||||
'oncology_drug': {
|
||||
'药物_剂量': 0,
|
||||
'药物_给药日': 1,
|
||||
'药物_给药方式': 2,
|
||||
},
|
||||
'yingxiang_bingzao': {
|
||||
"病灶部位_异常描述因子": 0,
|
||||
"病灶部位_阴性描述因子": 1,
|
||||
"病灶部位-诊断因子": 2
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class Attribute(object):
|
||||
def __init__(self,
|
||||
value,
|
||||
value_pos_start,
|
||||
value_pos_end,
|
||||
attr_type,
|
||||
attr_type_id
|
||||
):
|
||||
self.value = value
|
||||
self.value_pos_start = value_pos_start
|
||||
self.value_pos_end = value_pos_end
|
||||
self.attr_type = attr_type
|
||||
self.attr_type_id = attr_type_id
|
||||
|
||||
|
||||
class Example(object):
|
||||
def __init__(self,
|
||||
p_id=None,
|
||||
context=None,
|
||||
bert_tokens=None,
|
||||
entity_name=None,
|
||||
entity_position=None,
|
||||
pos_start=None,
|
||||
pos_end=None,
|
||||
gold_attr_list=None,
|
||||
gold_answer=None):
|
||||
self.p_id = p_id
|
||||
self.context = context
|
||||
self.bert_tokens = bert_tokens
|
||||
self.entity_name = entity_name
|
||||
self.entity_position = entity_position
|
||||
self.pos_start = pos_start
|
||||
self.pos_end = pos_end
|
||||
self.gold_attr_list = gold_attr_list
|
||||
self.gold_answer = gold_answer
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def __repr__(self):
|
||||
gold_repr = []
|
||||
for attr in self.gold_attr_list:
|
||||
gold_repr.append(attr.attr_type + '-' + attr.value)
|
||||
return 'entity {} with attribute:\n{}'.format(self.entity_name, '\t'.join(gold_repr))
|
||||
|
||||
|
||||
class InputFeature(object):
|
||||
|
||||
def __init__(self,
|
||||
p_id=None,
|
||||
passage_id=None,
|
||||
token_type_id=None,
|
||||
pos_start_id=None,
|
||||
pos_end_id=None,
|
||||
segment_id=None,
|
||||
label=None):
|
||||
self.p_id = p_id
|
||||
self.passage_id = passage_id
|
||||
self.token_type_id = token_type_id
|
||||
self.pos_start_id = pos_start_id
|
||||
self.pos_end_id = pos_end_id
|
||||
self.segment_id = segment_id
|
||||
self.label = label
|
||||
|
||||
|
||||
class Reader(object):
|
||||
def __init__(self, do_lowercase=False, seg_char=False, max_len=600, entity_type="drug"):
|
||||
|
||||
self.do_lowercase = do_lowercase
|
||||
self.seg_char = seg_char
|
||||
self.max_len = max_len
|
||||
self.entity_config = config[entity_type]
|
||||
|
||||
if self.seg_char:
|
||||
logging.info("seg_char...")
|
||||
else:
|
||||
logging.info("seg_word using jieba ...")
|
||||
|
||||
def read_examples(self, filename, data_type):
|
||||
logging.info("Generating {} examples...".format(data_type))
|
||||
return self._read(filename, data_type)
|
||||
|
||||
def _read(self, filename, data_type):
|
||||
|
||||
with open(filename, 'r') as fh:
|
||||
source_data = json.load(fh)
|
||||
|
||||
examples = []
|
||||
for p_id in tqdm(range(len(source_data))):
|
||||
data = source_data[p_id]
|
||||
para = data['text']
|
||||
context = para if self.seg_char else ''.join(jieba.lcut(para))
|
||||
if len(context) > self.max_len:
|
||||
context = context[:self.max_len]
|
||||
context = context.lower() if self.do_lowercase else context
|
||||
|
||||
_data_dict = dict()
|
||||
_data_dict['id'] = p_id
|
||||
entity_name, entity_position = data['entity'][0], data['entity'][1]
|
||||
entity_name = entity_name.lower() if self.do_lowercase else entity_name
|
||||
start, end = entity_position
|
||||
assert entity_name == context[start:end]
|
||||
|
||||
# pos_start & pos_end: positions of each token relative to the entity, clipped to a limit
# e.g. a range of [-30, 30] is shifted by +31 at embedding time, giving [1, 61]
# so there are 62 position tokens in total, with 0 reserved for padding
|
||||
|
||||
pos_start = list(map(lambda i: i - start, list(range(len(context)))))
|
||||
pos_end = list(map(lambda i: i - end, list(range(len(context)))))
|
||||
pos_start = _handle_pos_limit(pos_start)
|
||||
pos_end = _handle_pos_limit(pos_end)
|
||||
|
||||
attribute_list = data['attribute_list']
|
||||
|
||||
gold_attr_list = []
|
||||
for attribute in attribute_list:
|
||||
attr_type = attribute['type']
|
||||
value = attribute['value']
|
||||
value, value_pos_start, value_pos_end = value[0], value[1], value[2]
|
||||
value = value.lower() if self.do_lowercase else value
|
||||
|
||||
assert value == context[value_pos_start:value_pos_end]
|
||||
|
||||
gold_attr_list.append(Attribute(
|
||||
value=value,
|
||||
value_pos_start=value_pos_start,
|
||||
value_pos_end=value_pos_end,
|
||||
attr_type=attr_type,
|
||||
attr_type_id=self.entity_config[attr_type]
|
||||
))
|
||||
gold_answer = [attr.attr_type + '@' + attr.value for attr in gold_attr_list]
|
||||
|
||||
examples.append(
|
||||
Example(
|
||||
p_id=p_id,
|
||||
context=context,
|
||||
entity_name=entity_name,
|
||||
entity_position=entity_position,
|
||||
pos_start=pos_start,
|
||||
pos_end=pos_end,
|
||||
gold_attr_list=gold_attr_list,
|
||||
gold_answer=gold_answer
|
||||
)
|
||||
)
|
||||
logging.info("{} total size is {} ".format(data_type, len(examples)))
|
||||
return examples
|
||||
|
||||
|
||||
class Vocabulary(object):
|
||||
|
||||
def __init__(self, special_tokens=["<OOV>", "<MASK>"]):
|
||||
|
||||
self.char_vocab = None
|
||||
self.emb_mat = None
|
||||
self.char2idx = None
|
||||
self.char_counter = Counter()
|
||||
self.special_tokens = special_tokens
|
||||
|
||||
def build_vocab_only_with_char(self, examples, min_char_count=-1):
|
||||
|
||||
logging.info("Building vocabulary only with character...")
|
||||
|
||||
self.char_vocab = ["<PAD>"]
|
||||
|
||||
if self.special_tokens is not None and isinstance(self.special_tokens, list):
|
||||
self.char_vocab.extend(self.special_tokens)
|
||||
|
||||
for example in tqdm(examples):
|
||||
for char in example.context:
|
||||
self.char_counter[char] += 1
|
||||
|
||||
for w, v in self.char_counter.most_common():
|
||||
if v >= min_char_count:
|
||||
self.char_vocab.append(w)
|
||||
|
||||
self.char2idx = {token: idx for idx, token in enumerate(self.char_vocab)}
|
||||
|
||||
logging.info("total char counter size is {} ".format(len(self.char_counter)))
|
||||
logging.info("total char vocabulary size is {} ".format(len(self.char_vocab)))
|
||||
|
||||
def _load_embedding(self, embedding_file, embedding_dict):
|
||||
|
||||
with open(embedding_file) as f:
|
||||
for line in f:
|
||||
if len(line.rstrip().split(" ")) <= 2: continue
|
||||
token, vector = line.rstrip().split(" ", 1)
|
||||
embedding_dict[token] = np.fromstring(vector, dtype=np.float64, sep=" ")
|
||||
return embedding_dict
|
||||
|
||||
def make_embedding(self, vocab, embedding_file, emb_size):
|
||||
|
||||
embedding_dict = dict()
|
||||
embedding_dict["<PAD>"] = np.array([0. for _ in range(emb_size)])
|
||||
|
||||
self._load_embedding(embedding_file, embedding_dict)
|
||||
|
||||
count = 0
|
||||
for token in tqdm(vocab):
|
||||
if token not in embedding_dict:
|
||||
count += 1
|
||||
embedding_dict[token] = np.array([np.random.normal(scale=0.1) for _ in range(emb_size)])
|
||||
logging.info(
|
||||
"{} / {} tokens have corresponding in embedding vector".format(len(vocab) - count, len(vocab)))
|
||||
|
||||
emb_mat = [embedding_dict[token] for idx, token in enumerate(vocab)]
|
||||
|
||||
return emb_mat
|
||||
|
||||
|
||||
class Feature(object):
|
||||
def __init__(self, args, token2idx_dict):
|
||||
self.bert = args.use_bert
|
||||
self.token2idx_dict = token2idx_dict
|
||||
if self.bert:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
|
||||
def token2wid(self, token):
|
||||
if token in self.token2idx_dict:
|
||||
return self.token2idx_dict[token]
|
||||
return self.token2idx_dict["<OOV>"]
|
||||
|
||||
def __call__(self, examples, entity_type, data_type):
|
||||
|
||||
if self.bert:
|
||||
return self.convert_examples_to_bert_features(examples, entity_type, data_type)
|
||||
else:
|
||||
return self.convert_examples_to_features(examples, entity_type, data_type)
|
||||
|
||||
def convert_examples_to_features(self, examples, entity_type, data_type):
|
||||
|
||||
logging.info("Processing {} examples...".format(data_type))
|
||||
|
||||
examples2features = list()
|
||||
for index, example in enumerate(examples):
|
||||
|
||||
gold_attr_list = example.gold_attr_list
|
||||
ent_start, ent_end = example.entity_position[0], example.entity_position[1]
|
||||
|
||||
passage_id = np.zeros(len(example.context), dtype=np.int64)
token_type_id = np.zeros(len(example.context), dtype=np.int64)
pos_start_id = np.zeros(len(example.context), dtype=np.int64)
pos_end_id = np.zeros(len(example.context), dtype=np.int64)
|
||||
|
||||
for i, token in enumerate(example.context):
|
||||
if ent_start <= i < ent_end:
|
||||
# token = "<MASK>"
|
||||
token_type_id[i] = 1
|
||||
passage_id[i] = self.token2wid(token)
|
||||
pos_start_id[i] = example.pos_start[i]
|
||||
pos_end_id[i] = example.pos_end[i]
|
||||
|
||||
examples2features.append(
|
||||
InputFeature(
|
||||
p_id=index,
|
||||
passage_id=passage_id,
|
||||
token_type_id=token_type_id,
|
||||
pos_start_id=pos_start_id,
|
||||
pos_end_id=pos_end_id,
|
||||
segment_id=token_type_id,
|
||||
label=gold_attr_list
|
||||
))
|
||||
|
||||
logging.info("Built instances is Completed")
|
||||
return AttributeMPNDataset(examples2features, attribute_num=len(config[entity_type]))
|
||||
|
||||
def convert_examples_to_bert_features(self, examples, entity_type, data_type):
|
||||
|
||||
logging.info("Processing {} examples...".format(data_type))
|
||||
|
||||
examples2features = list()
|
||||
for index, example in enumerate(examples):
|
||||
|
||||
gold_attr_list = example.gold_attr_list
|
||||
ent_start, ent_end = example.entity_position[0], example.entity_position[1]
|
||||
segment_id = np.zeros(len(example.context) + 2, dtype=np.int64)
token_type_id = np.zeros(len(example.context) + 2, dtype=np.int64)
pos_start_id = np.zeros(len(example.context) + 2, dtype=np.int64)
pos_end_id = np.zeros(len(example.context) + 2, dtype=np.int64)
|
||||
|
||||
tokens = ["[CLS]"]
|
||||
raw_tokens = ["[CLS]"]
|
||||
for i, token in enumerate(example.context):
|
||||
raw_tokens.append(token)
|
||||
if ent_start <= i < ent_end:
|
||||
# token_type_id[i + 1] = 1
|
||||
# segment_id[i + 1] = 1
|
||||
token = '[unused1]'
|
||||
tokens.append(token)
|
||||
pos_start_id[i + 1] = example.pos_start[i]
|
||||
pos_end_id[i + 1] = example.pos_end[i]
|
||||
|
||||
tokens.append("[SEP]")
|
||||
raw_tokens.append("[SEP]")
|
||||
passage_id = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
example.bert_tokens = raw_tokens
|
||||
examples2features.append(
|
||||
InputFeature(
|
||||
p_id=index,
|
||||
passage_id=passage_id,
|
||||
token_type_id=token_type_id,
|
||||
pos_start_id=pos_start_id,
|
||||
pos_end_id=pos_end_id,
|
||||
segment_id=segment_id,
|
||||
label=gold_attr_list
|
||||
))
|
||||
|
||||
logging.info("Built instances is Completed")
|
||||
return AttributeMPNDataset(examples2features, attribute_num=len(config[entity_type]), use_bert=True)
|
||||
|
||||
|
||||
class AttributeMPNDataset(Dataset):
|
||||
def __init__(self, features, attribute_num, use_bert=False):
|
||||
super(AttributeMPNDataset, self).__init__()
|
||||
self.use_bert = use_bert
|
||||
self.q_ids = [f.p_id for f in features]
|
||||
self.passages = [f.passage_id for f in features]
|
||||
self.token_type = [f.token_type_id for f in features]
|
||||
self.label = [f.label for f in features]
|
||||
self.attribute_num = attribute_num
|
||||
self.segment_ids = [f.segment_id for f in features]
|
||||
self.pos_start_ids = [f.pos_start_id for f in features]
|
||||
self.pos_end_ids = [f.pos_end_id for f in features]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.passages)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.q_ids[index], self.passages[index], self.token_type[index], self.segment_ids[index], self.label[
|
||||
index], self.pos_start_ids[index], self.pos_end_ids[index]
|
||||
|
||||
def _create_collate_fn(self, batch_first=False):
|
||||
def collate(examples):
|
||||
p_ids, passages, token_type, segment_ids, label, pos_start_ids, pos_end_ids = zip(*examples)
|
||||
p_ids = torch.tensor([p_id for p_id in p_ids], dtype=torch.long)
|
||||
passages_tensor, _ = padding(passages, is_float=False, batch_first=batch_first)
|
||||
pos_start_tensor, _ = padding(pos_start_ids, is_float=False, batch_first=batch_first)
|
||||
pos_end_tensor, _ = padding(pos_end_ids, is_float=False, batch_first=batch_first)
|
||||
token_type_tensor, _ = padding(token_type, is_float=False, batch_first=batch_first)
|
||||
segment_tensor, _ = padding(segment_ids, is_float=False, batch_first=batch_first)
|
||||
o1_tensor, o2_tensor = mpn_padding(passages, label, class_num=self.attribute_num, is_float=True,
|
||||
use_bert=self.use_bert)
|
||||
return p_ids, passages_tensor, token_type_tensor, segment_tensor, pos_start_tensor, pos_end_tensor, o1_tensor, o2_tensor
|
||||
|
||||
return partial(collate)
|
||||
|
||||
def get_dataloader(self, batch_size, num_workers=0, shuffle=False, batch_first=True, pin_memory=False,
|
||||
drop_last=False):
|
||||
return DataLoader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self._create_collate_fn(batch_first),
|
||||
num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last)
|
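Reader._read above implies the layout of train.json and dev.json: a list of records, each carrying the raw text, one target entity as [name, [start, end]], and an attribute_list whose values carry character offsets into the text. An illustrative record consistent with the asserts in the reader (the field values themselves are invented):

example_record = {
    "text": "患者服用阿司匹林每日一次",
    "entity": ["阿司匹林", [4, 8]],                           # text[4:8] == "阿司匹林"
    "attribute_list": [
        {"type": "药品-用药频率", "value": ["每日一次", 8, 12]}  # text[8:12] == "每日一次"
    ]
}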
181  run/attribute_extract/mpn/main.py           Normal file
@@ -0,0 +1,181 @@
# _*_ coding:utf-8 _*_
import argparse
import logging
import os

from run.attribute_extract.mpn.data import Reader, Vocabulary, config, Feature
from run.attribute_extract.mpn.train import Trainer
from utils.file_util import save, load

logger = logging.getLogger()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')


def get_args():
    parser = argparse.ArgumentParser()

    # file parameters
    parser.add_argument("--input", default=None, type=str, required=True)
    parser.add_argument("--output", default=None, type=str, required=False,
                        help="The output directory where the model checkpoints and predictions will be written.")
    # "cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5"
    # 'cpt/baidu_w2v/w2v.txt'
    parser.add_argument('--embedding_file', type=str,
                        default='cpt/baidu_w2v/sgns.target.word-character.char1-2.dynwin5.thr10.neg5.dim300.iter5')

    # choice parameters
    parser.add_argument('--entity_type', type=str, default='disease')
    parser.add_argument('--use_word2vec', type=bool, default=False)
    parser.add_argument('--use_bert', type=bool, default=False)
    parser.add_argument('--seg_char', type=bool, default=True)

    # train parameters
    parser.add_argument('--train_mode', type=str, default="train")
    parser.add_argument("--train_batch_size", default=4, type=int, help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epoch_num", default=3, type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--patience_stop', type=int, default=10, help='Patience for early stopping')
    parser.add_argument('--device_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")

    # bert parameters
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--bert_model", default=None, type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")

    # model parameters
    parser.add_argument("--max_len", default=1000, type=int)
    parser.add_argument('--word_emb_size', type=int, default=300)
    parser.add_argument('--char_emb_size', type=int, default=300)
    parser.add_argument('--entity_emb_size', type=int, default=300)
    parser.add_argument('--pos_limit', type=int, default=30)
    parser.add_argument('--pos_dim', type=int, default=300)
    parser.add_argument('--pos_size', type=int, default=62)

    parser.add_argument('--hidden_size', type=int, default=150)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--rnn_encoder', type=str, default='lstm', help="must be one of: lstm or gru")
    parser.add_argument('--bidirectional', type=bool, default=True)
    parser.add_argument('--pin_memory', type=bool, default=False)
    parser.add_argument('--transformer_layers', type=int, default=1)
    parser.add_argument('--nhead', type=int, default=4)
    parser.add_argument('--dim_feedforward', type=int, default=2048)
    args = parser.parse_args()
    if args.use_word2vec:
        args.cache_data = args.input + '/char2v_cache_data/'
    elif args.use_bert:
        args.cache_data = args.input + '/char_bert_cache_data/'
    else:
        args.cache_data = args.input + '/char_cache_data/'
    return args


def build_dataset(args, reader, vocab, debug=False):
    char2idx, char_emb = None, None
    train_src = args.input + "/train.json"
    dev_src = args.input + "/dev.json"

    train_examples_file = args.cache_data + "/train-examples.pkl"
    dev_examples_file = args.cache_data + "/dev-examples.pkl"

    char_emb_file = args.cache_data + "/char_emb.pkl"
    char_dictionary = args.cache_data + "/char_dict.pkl"

    if not os.path.exists(train_examples_file):

        train_examples = reader.read_examples(train_src, data_type='train')
        dev_examples = reader.read_examples(dev_src, data_type='dev')

        if not args.use_bert:
            # todo: min_word_count=3 ?
            vocab.build_vocab_only_with_char(train_examples, min_char_count=1)
            if args.use_word2vec and args.embedding_file:
                char_emb = vocab.make_embedding(vocab=vocab.char_vocab,
                                                embedding_file=args.embedding_file,
                                                emb_size=args.word_emb_size)
                save(char_emb_file, char_emb, message="char embedding")
            save(char_dictionary, vocab.char2idx, message="char dictionary")
            char2idx = vocab.char2idx
        save(train_examples_file, train_examples, message="train examples")
        save(dev_examples_file, dev_examples, message="dev examples")
    else:
        if not args.use_bert:
            if args.use_word2vec and args.embedding_file:
                char_emb = load(char_emb_file)
            char2idx = load(char_dictionary)
            logging.info("total char vocabulary size is {} ".format(len(char2idx)))
        train_examples, dev_examples = load(train_examples_file), load(dev_examples_file)

    logging.info('train examples size is {}'.format(len(train_examples)))
    logging.info('dev examples size is {}'.format(len(dev_examples)))

    if not args.use_bert:
        args.vocab_size = len(char2idx)
    convert_examples_features = Feature(args, token2idx_dict=char2idx)

    train_examples = train_examples[:10] if debug else train_examples
    dev_examples = dev_examples[:10] if debug else dev_examples

    train_data_set = convert_examples_features(train_examples, entity_type=args.entity_type,
                                               data_type='train')
    dev_data_set = convert_examples_features(dev_examples, entity_type=args.entity_type,
                                             data_type='dev')
    train_data_loader = train_data_set.get_dataloader(args.train_batch_size, shuffle=True, pin_memory=args.pin_memory)
    dev_data_loader = dev_data_set.get_dataloader(args.train_batch_size)

    data_loaders = train_data_loader, dev_data_loader
    eval_examples = train_examples, dev_examples

    return eval_examples, data_loaders, char_emb


# TODO
'''
1. add logic to build word vectors automatically
2. compare different entity-aware construction methods
3. add an ELMo encoder option
4. add adversarial training
'''


def main():
    args = get_args()
    if not os.path.exists(args.output):
        print('mkdir {}'.format(args.output))
        os.makedirs(args.output)
    if not os.path.exists(args.cache_data):
        print('mkdir {}'.format(args.cache_data))
        os.makedirs(args.cache_data)

    logger.info("** ** * build dataset ** ** * ")
    reader = Reader(seg_char=args.seg_char, max_len=args.max_len, entity_type=args.entity_type)
    vocab = Vocabulary()

    eval_examples, data_loaders, char_emb = build_dataset(args, reader, vocab, debug=False)

    trainer = Trainer(args, data_loaders, eval_examples, char_emb, attribute_conf=config[args.entity_type])

    if args.train_mode == "train":
        trainer.train(args)
    elif args.train_mode == "eval":
        trainer.resume(args)
        trainer.eval_data_set("train")
        trainer.eval_data_set("dev")
    elif args.train_mode == "resume":
        # trainer.resume(args)
        trainer.show("dev")  # bad case analysis


if __name__ == '__main__':
    main()
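A plausible invocation of this entry point from the repository root (paths, output directory, and entity type are placeholders, not part of the commit):

    python -m run.attribute_extract.mpn.main --input data/disease --output output/disease_mpn --entity_type disease --train_mode train --epoch_num 3 --train_batch_size 4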
295  run/attribute_extract/mpn/train.py          Normal file
@ -0,0 +1,295 @@
|
||||
# _*_ coding:utf-8 _*_
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
import models.attribute_extract_net.bert_mpn as bert_mpn
|
||||
import models.attribute_extract_net.mpn as mpn
|
||||
from utils.optimizer_util import set_optimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(self, args, data_loaders, examples, char_emb, attribute_conf):
|
||||
if args.use_bert:
|
||||
self.model = bert_mpn.AttributeExtractNet.from_pretrained(args.bert_model, args, attribute_conf)
|
||||
else:
|
||||
self.model = mpn.AttributeExtractNet(args, char_emb, attribute_conf)
|
||||
|
||||
self.args = args
|
||||
|
||||
self.device = torch.device("cuda:{}".format(args.device_id) if torch.cuda.is_available() else "cpu")
|
||||
self.n_gpu = torch.cuda.device_count()
|
||||
|
||||
self.id2rel = {item: key for key, item in attribute_conf.items()}
|
||||
self.rel2id = attribute_conf
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if self.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
self.model.to(self.device)
|
||||
# self.resume(args)
|
||||
logging.info('total gpu num is {}'.format(self.n_gpu))
|
||||
if self.n_gpu > 1:
|
||||
self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
|
||||
|
||||
train_dataloader, dev_dataloader = data_loaders
|
||||
train_eval, dev_eval = examples
|
||||
self.eval_file_choice = {
|
||||
"train": train_eval,
|
||||
"dev": dev_eval,
|
||||
}
|
||||
self.data_loader_choice = {
|
||||
"train": train_dataloader,
|
||||
"dev": dev_dataloader,
|
||||
}
|
||||
self.optimizer = set_optimizer(args, self.model,
|
||||
train_steps=(int(len(train_eval) / args.train_batch_size) + 1) * args.epoch_num)
|
||||
|
||||
def train(self, args):
|
||||
|
||||
best_f1 = 0.0
|
||||
patience_stop = 0
|
||||
self.model.train()
|
||||
step_gap = 20
|
||||
for epoch in range(int(args.epoch_num)):
|
||||
|
||||
global_loss = 0.0
|
||||
|
||||
for step, batch in tqdm(enumerate(self.data_loader_choice[u"train"]), mininterval=5,
|
||||
desc=u'training at epoch : %d ' % epoch, leave=False, file=sys.stdout):
|
||||
|
||||
loss, answer_dict_ = self.forward(batch)
|
||||
|
||||
if step % step_gap == 0:
|
||||
global_loss += loss
|
||||
current_loss = global_loss / step_gap
|
||||
print(
|
||||
u"step {} / {} of epoch {}, train/loss: {}".format(step, len(self.data_loader_choice["train"]),
|
||||
epoch, current_loss))
|
||||
global_loss = 0.0
|
||||
|
||||
res_dev = self.eval_data_set("dev")
|
||||
if res_dev['f1'] >= best_f1:
|
||||
best_f1 = res_dev['f1']
|
||||
logging.info("** ** * Saving fine-tuned model ** ** * ")
|
||||
model_to_save = self.model.module if hasattr(self.model,
|
||||
'module') else self.model  # only save the model itself
|
||||
output_model_file = args.output + "/pytorch_model.bin"
|
||||
torch.save(model_to_save.state_dict(), str(output_model_file))
|
||||
patience_stop = 0
|
||||
else:
|
||||
patience_stop += 1
|
||||
if patience_stop >= args.patience_stop:
|
||||
return
|
||||
|
||||
def resume(self, args):
|
||||
resume_model_file = args.output + "/pytorch_model.bin"
|
||||
logging.info("=> loading checkpoint '{}'".format(resume_model_file))
|
||||
checkpoint = torch.load(resume_model_file, map_location='cpu')
|
||||
self.model.load_state_dict(checkpoint)
|
||||
|
||||
def forward(self, batch, chosen=u'train', grad=True, eval=False, detail=False):
|
||||
|
||||
batch = tuple(t.to(self.device) for t in batch)
|
||||
|
||||
p_ids, passage_id, token_type_id, segment_id, pos_start, pos_end, start_id, end_id = batch
loss, po1, po2 = self.model(passage_id=passage_id, token_type_id=token_type_id, segment_id=segment_id,
                            pos_start=pos_start, pos_end=pos_end, start_id=start_id, end_id=end_id, is_eval=eval)
|
||||
|
||||
if self.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
|
||||
if grad:
|
||||
loss.backward()
|
||||
loss = loss.item()
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if eval:
|
||||
eval_file = self.eval_file_choice[chosen]
|
||||
answer_dict_ = convert_pointer_net_contour(eval_file, p_ids, po1, po2, self.id2rel,
|
||||
use_bert=self.args.use_bert)
|
||||
else:
|
||||
answer_dict_ = None
|
||||
return loss, answer_dict_
|
||||
|
||||
def eval_data_set(self, chosen="dev"):
|
||||
|
||||
self.model.eval()
|
||||
answer_dict = {}
|
||||
|
||||
data_loader = self.data_loader_choice[chosen]
|
||||
eval_file = self.eval_file_choice[chosen]
|
||||
last_time = time.time()
|
||||
with torch.no_grad():
|
||||
for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
|
||||
loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True)
|
||||
answer_dict.update(answer_dict_)
|
||||
used_time = time.time() - last_time
|
||||
logging.info('chosen {} took : {} sec'.format(chosen, used_time))
|
||||
res = self.evaluate(eval_file, answer_dict, chosen)
|
||||
self.detail_evaluate(eval_file, answer_dict, chosen)
|
||||
self.model.train()
|
||||
return res
|
||||
|
||||
def show(self, chosen="dev"):
|
||||
|
||||
self.model.eval()
|
||||
answer_dict = {}
|
||||
|
||||
data_loader = self.data_loader_choice[chosen]
|
||||
eval_file = self.eval_file_choice[chosen]
|
||||
with torch.no_grad():
|
||||
for _, batch in tqdm(enumerate(data_loader), mininterval=5, leave=False, file=sys.stdout):
|
||||
loss, answer_dict_ = self.forward(batch, chosen, grad=False, eval=True, detail=True)
|
||||
answer_dict.update(answer_dict_)
|
||||
self.badcase_analysis(eval_file, answer_dict, chosen)
|
||||
|
||||
@staticmethod
|
||||
def evaluate(eval_file, answer_dict, chosen):
|
||||
|
||||
em = 0
|
||||
pre = 0
|
||||
gold = 0
|
||||
for key, value in answer_dict.items():
|
||||
ground_truths = eval_file[int(key)].gold_answer
|
||||
value, l1, l2 = value
|
||||
prediction = list(value) if len(value) else []
|
||||
assert type(prediction) == type(ground_truths)
|
||||
intersection = set(prediction) & set(ground_truths)
|
||||
|
||||
em += len(intersection)
|
||||
pre += len(set(prediction))
|
||||
gold += len(set(ground_truths))
|
||||
|
||||
precision = 100.0 * em / pre if pre > 0 else 0.
|
||||
recall = 100.0 * em / gold if gold > 0 else 0.
|
||||
f1 = 2 * recall * precision / (recall + precision) if (recall + precision) != 0 else 0.0
|
||||
print('============================================')
|
||||
print("{}/em: {},\tpre&gold: {}\t{} ".format(chosen, em, pre, gold))
|
||||
print("{}/f1: {}, \tPrecision: {},\tRecall: {} ".format(chosen, f1, precision,
|
||||
recall))
|
||||
return {'f1': f1, "recall": recall, "precision": precision, 'em': em, 'pre': pre, 'gold': gold}
|
||||
|
||||
def detail_evaluate(self, eval_file, answer_dict, chosen):
|
||||
def generate_detail_dict(spo_list):
|
||||
dict_detail = dict()
|
||||
for i, tag in enumerate(spo_list):
|
||||
detail_name = tag.split('@')[0]
|
||||
if detail_name not in dict_detail:
|
||||
dict_detail[detail_name] = [tag]
|
||||
else:
|
||||
dict_detail[detail_name].append(tag)
|
||||
return dict_detail
|
||||
|
||||
total_detail = {}
|
||||
for key, value in answer_dict.items():
|
||||
ground_truths = eval_file[int(key)].gold_answer
|
||||
value, l1, l2 = value
|
||||
prediction = list(value) if len(value) else []
|
||||
|
||||
gold_detail = generate_detail_dict(ground_truths)
|
||||
pred_detail = generate_detail_dict(prediction)
|
||||
for key in self.rel2id.keys():
|
||||
|
||||
pred = pred_detail.get(key, [])
|
||||
gold = gold_detail.get(key, [])
|
||||
em = len(set(pred) & set(gold))
|
||||
pred_num = len(set(pred))
|
||||
gold_num = len(set(gold))
|
||||
|
||||
if key not in total_detail:
|
||||
total_detail[key] = dict()
|
||||
total_detail[key]['em'] = em
|
||||
total_detail[key]['pred_num'] = pred_num
|
||||
total_detail[key]['gold_num'] = gold_num
|
||||
else:
|
||||
total_detail[key]['em'] += em
|
||||
total_detail[key]['pred_num'] += pred_num
|
||||
total_detail[key]['gold_num'] += gold_num
|
||||
for key, res_dict_ in total_detail.items():
|
||||
res_dict_['p'] = 100.0 * res_dict_['em'] / res_dict_['pred_num'] if res_dict_['pred_num'] != 0 else 0.0
|
||||
res_dict_['r'] = 100.0 * res_dict_['em'] / res_dict_['gold_num'] if res_dict_['gold_num'] != 0 else 0.0
|
||||
res_dict_['f'] = 2 * res_dict_['p'] * res_dict_['r'] / (res_dict_['p'] + res_dict_['r']) if res_dict_['p'] + \
|
||||
res_dict_[
|
||||
'r'] != 0 else 0.0
|
||||
|
||||
for gold_key, res_dict_ in total_detail.items():
|
||||
print('===============================================================')
|
||||
print("{}/em: {},\tpred_num&gold_num: {}\t{} ".format(gold_key, res_dict_['em'], res_dict_['pred_num'],
|
||||
res_dict_['gold_num']))
|
||||
print(
|
||||
"{}/f1: {},\tprecison&recall: {}\t{}".format(gold_key, res_dict_['f'], res_dict_['p'], res_dict_['r']))
|
||||
|
||||
@staticmethod
|
||||
def badcase_analysis(eval_file, answer_dict, chosen):
|
||||
em = 0
|
||||
pre = 0
|
||||
gold = 0
|
||||
content = ''
|
||||
for key, value in answer_dict.items():
|
||||
entity_name = eval_file[int(key)].entity_name
|
||||
context = eval_file[int(key)].context
|
||||
ground_truths = eval_file[int(key)].gold_answer
|
||||
value, l1, l2 = value
|
||||
prediction = list(value) if len(value) else ['']
|
||||
assert type(prediction) == type(ground_truths)
|
||||
|
||||
intersection = set(prediction) & set(ground_truths)
|
||||
|
||||
if prediction == ground_truths == ['']:
|
||||
continue
|
||||
if set(prediction) != set(ground_truths):
|
||||
ground_truths = list(sorted(set(ground_truths)))
|
||||
prediction = list(sorted(set(prediction)))
|
||||
print('raw context is:\t' + context)
|
||||
print('subject_name is:\t' + entity_name)
|
||||
print('pred_text is:\t' + '\t'.join(prediction))
|
||||
print('gold_text is:\t' + '\t'.join(ground_truths))
|
||||
content += 'raw context is:\t' + context + '\n'
|
||||
content += 'subject_name is:\t' + entity_name + '\n'
|
||||
content += 'pred_text is:\t' + '\t'.join(prediction) + '\n'
|
||||
content += 'gold_text is:\t' + '\t'.join(ground_truths) + '\n'
|
||||
content += '==============================='
|
||||
em += len(intersection)
|
||||
pre += len(set(prediction))
|
||||
gold += len(set(ground_truths))
|
||||
with open('badcase_{}.txt'.format(chosen), 'w') as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def convert_pointer_net_contour(eval_file, q_ids, po1, po2, id2rel, use_bert=False):
|
||||
answer_dict = dict()
|
||||
for qid, o1, o2 in zip(q_ids, po1.data.cpu().numpy(), po2.data.cpu().numpy()):
|
||||
|
||||
context = eval_file[qid.item()].context if not use_bert else eval_file[qid.item()].bert_tokens
|
||||
gold_attr_list = eval_file[qid.item()].gold_attr_list
|
||||
gold_answer = [attr.attr_type + '@' + attr.value for attr in gold_attr_list]
|
||||
|
||||
answers = list()
|
||||
start, end = np.where(o1 > 0.5), np.where(o2 > 0.5)
|
||||
for _start, _attr_type_id_start in zip(*start):
|
||||
if _start > len(context) or (_start == 0 and use_bert):
|
||||
continue
|
||||
for _end, _attr_type_id_end in zip(*end):
|
||||
if _start <= _end < len(context) and _attr_type_id_start == _attr_type_id_end:
|
||||
_attr_value = ''.join(context[_start: _end + 1]) if use_bert else context[_start: _end + 1]
|
||||
_attr_type = id2rel[_attr_type_id_start]
|
||||
_attr = _attr_type + '@' + _attr_value
|
||||
answers.append(_attr)
|
||||
break
|
||||
|
||||
answer_dict[str(qid.item())] = [answers, o1, o2]
|
||||
|
||||
return answer_dict
|
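Trainer.evaluate above computes micro precision/recall/F1 over 'type@value' strings. A small worked check of that arithmetic, using invented predictions for the disease schema:

pred = {"疾病-临床表现@头痛", "疾病-部位@头部"}
gold = {"疾病-临床表现@头痛", "疾病-药品名称@阿司匹林"}
em = len(pred & gold)                               # 1 exact match
precision = 100.0 * em / len(pred)                  # 50.0
recall = 100.0 * em / len(gold)                     # 50.0
f1 = 2 * precision * recall / (precision + recall)  # 50.0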
0    run/entity_extraction/__init__.py           Normal file
0    run/entity_linking/__init__.py              Normal file
0    run/entity_relation_extraction/__init__.py  Normal file
0    run/event_extraction/__init__.py            Normal file
0    run/relation_extraction/__init__.py         Normal file
0    utils/__init__.py                           Normal file
51   utils/data_util.py                          Normal file
@@ -0,0 +1,51 @@
import torch


def padding(seqs, is_float=False, batch_first=False):
    lengths = [len(s) for s in seqs]

    seqs = [torch.Tensor(s) for s in seqs]
    batch_length = max(lengths)

    seq_tensor = torch.FloatTensor(batch_length, len(seqs)).fill_(float(0)) if is_float \
        else torch.LongTensor(batch_length, len(seqs)).fill_(0)

    for i, s in enumerate(seqs):
        end_seq = lengths[i]
        seq_tensor[:end_seq, i].copy_(s[:end_seq])

    if batch_first:
        seq_tensor = seq_tensor.t()

    return seq_tensor, lengths


def mpn_padding(seqs, label, class_num, is_float=False, use_bert=False):
    lengths = [len(s) for s in seqs]

    seqs = [torch.Tensor(s) for s in seqs]
    batch_length = max(lengths)

    o1_tensor = torch.FloatTensor(len(seqs), batch_length, class_num).fill_(float(0)) if is_float \
        else torch.LongTensor(len(seqs), batch_length, class_num).fill_(0)
    o2_tensor = torch.FloatTensor(len(seqs), batch_length, class_num).fill_(float(0)) if is_float \
        else torch.LongTensor(len(seqs), batch_length, class_num).fill_(0)
    for i, label_ in enumerate(label):
        for attr in label_:
            if use_bert:
                o1_tensor[i, attr.value_pos_start + 1, attr.attr_type_id] = 1
                o2_tensor[i, attr.value_pos_end, attr.attr_type_id] = 1
            else:
                o1_tensor[i, attr.value_pos_start, attr.attr_type_id] = 1
                o2_tensor[i, attr.value_pos_end - 1, attr.attr_type_id] = 1

    return o1_tensor, o2_tensor


def _handle_pos_limit(pos, limit=30):
    for i, p in enumerate(pos):
        if p > limit:
            pos[i] = limit
        if p < -limit:
            pos[i] = -limit
    return [p + limit + 1 for p in pos]
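A short worked example of the relative-position clipping above with the default limit of 30 (input values are illustrative): positions are clipped to [-30, 30] and shifted by +31, so index 0 stays free for padding. The +1 / -1 offsets in mpn_padding likewise account for the [CLS] token in the BERT features and for the exclusive end offset in the character features.

print(_handle_pos_limit([-45, -30, -1, 0, 2, 40]))
# -> [1, 1, 30, 31, 33, 61]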
46   utils/file_util.py                          Normal file
@@ -0,0 +1,46 @@
import json
import logging
import os
import pickle
import sys


def pickle_dump_large_file(obj, filepath):
    max_bytes = 2 ** 31 - 1
    bytes_out = pickle.dumps(obj)
    n_bytes = sys.getsizeof(bytes_out)
    with open(filepath, 'wb') as f_out:
        for idx in range(0, n_bytes, max_bytes):
            f_out.write(bytes_out[idx:idx + max_bytes])


def pickle_load_large_file(filepath):
    max_bytes = 2 ** 31 - 1
    input_size = os.path.getsize(filepath)
    bytes_in = bytearray(0)
    with open(filepath, 'rb') as f_in:
        for _ in range(0, input_size, max_bytes):
            bytes_in += f_in.read(max_bytes)
    obj = pickle.loads(bytes_in)
    return obj


def save(filepath, obj, message=None):
    if message is not None:
        logging.info("Saving {}...".format(message))
    pickle_dump_large_file(obj, filepath)


def load(filepath):
    return pickle_load_large_file(filepath)


def read_json(path):
    with open(path, 'r') as f:
        return json.load(f)


def write_json(obj, path):
    with open(path, 'wb') as f:
        f.write(json.dumps(obj, indent=2, ensure_ascii=False).encode('utf-8'))
25   utils/optimizer_util.py                     Normal file
@@ -0,0 +1,25 @@
from torch import optim

from layers.encoders.transformers.bert.bert_optimization import BertAdam


def set_optimizer(args, model, train_steps=None):
    if args.use_bert:
        param_optimizer = list(model.named_parameters())
        param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=train_steps)
        return optimizer
    else:
        parameters_trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer = optim.Adam(parameters_trainable, lr=args.learning_rate)
        return optimizer
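In the BERT branch above, parameters are split into two groups so that biases and LayerNorm weights get no weight decay. A quick way to see which group each parameter would land in (model can be any torch.nn.Module; the snippet is illustrative and not part of the commit):

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
for name, _ in model.named_parameters():
    decay = 0.0 if any(nd in name for nd in no_decay) else 0.01
    print(name, '-> weight_decay =', decay)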