first commit

loujie0822 2021-03-01 17:39:41 +08:00
parent 9d3a201f1d
commit 314471f02a
9 changed files with 34 additions and 30 deletions

View File

@@ -65,7 +65,7 @@ class GazLSTM(nn.Module):
if self.use_bert:
self.bert_encoder = BertModel.from_pretrained('transformer_cpt/bert/')
# self.xlnet_encoder = XLNetModel.from_pretrained('transformer_cpt/chinese_xlnet_base_pytorch')
self.xlnet_encoder = XLNetModel.from_pretrained('transformer_cpt/chinese_xlnet_base_pytorch')
self.bert_encoder_wwm = BertModel.from_pretrained('transformer_cpt/chinese_roberta_wwm_ext_pytorch/')
for p in self.bert_encoder.parameters():
p.requires_grad = False
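This hunk enables an XLNet encoder next to the existing BERT and RoBERTa-wwm checkpoints (all loaded from local directories under transformer_cpt/) while keeping the BERT parameters frozen. The forward pass is not part of this diff; a minimal sketch of how such frozen encoders are commonly consumed, assuming 768-dimensional base models and that torch is already imported by the module:

    # Hedged sketch only -- the real forward() is not shown in this commit.
    def _encode_context(self, batch_bert, bert_mask, batch_xlnet):
        with torch.no_grad():  # no gradients needed for the frozen pretrained encoders
            bert_out = self.bert_encoder(batch_bert, attention_mask=bert_mask)[0]
            xlnet_out = self.xlnet_encoder(batch_xlnet)[0]
        bert_out = bert_out[:, 1:-1, :]  # drop [CLS]/[SEP] so lengths match the word sequence
        return torch.cat([bert_out, xlnet_out], dim=-1)  # (batch, seq_len, 768 * 2)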

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.nn as nn
from transformers import AlbertModel
from transformers import AlbertPreTrainedModel
from transformers.modeling_albert import AlbertPreTrainedModel
from layers.encoders.transformers.bert.layernorm import ConditionalLayerNorm
from utils.data_util import batch_gather
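The AlbertPreTrainedModel import path varies across transformers releases, so this hunk pins it to the layout the repo targets. A version-tolerant pattern (an illustration, not repo code) would be:

    try:
        from transformers.modeling_albert import AlbertPreTrainedModel  # layout used by this repo
    except ImportError:
        from transformers import AlbertPreTrainedModel  # releases that export it at the package root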

View File

@@ -14,7 +14,7 @@ NULLKEY = "-null-"
class Data:
def __init__(self):
self.MAX_SENTENCE_LENGTH = 1000
self.MAX_SENTENCE_LENGTH = 400
self.MAX_WORD_LENGTH = -1
self.number_normalized = True
self.norm_word_emb = True
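The per-sentence cap changes to 400 characters. An illustrative sanity check (not repo code): even after the two BERT special tokens are appended during batching, such sentences stay within BERT's 512-position limit.

    MAX_SENTENCE_LENGTH = 400
    assert MAX_SENTENCE_LENGTH + 2 <= 512  # [CLS] + tokens + [SEP] still fit BERT's position embeddings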

View File

@@ -14,7 +14,7 @@ NULLKEY = "-null-"
class Data:
def __init__(self):
self.MAX_SENTENCE_LENGTH = 250
self.MAX_SENTENCE_LENGTH = 400
self.MAX_WORD_LENGTH = -1
self.number_normalized = True
self.norm_word_emb = True

View File

@@ -42,8 +42,4 @@ class Gazetteer:
exit(0)
def size(self):
return len(self.ent2type)
return len(self.ent2type)

View File

@@ -164,9 +164,9 @@ def evaluate(data, model, name):
if not instance:
continue
batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, bert_mask = batchify_with_label(
batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, batch_xlnet, bert_mask = batchify_with_label(
instance, data.HP_gpu, data.device)
tag_seq = model(batch_word, batch_biword, batch_wordlen, mask, batch_bert, bert_mask)
tag_seq = model(batch_word, batch_biword, batch_wordlen, mask, batch_bert, batch_xlnet, bert_mask)
pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet)
pred_results += pred_label
@@ -198,6 +198,7 @@ def batchify_with_label(input_batch_list, gpu, device):
labels = [sent[2] for sent in input_batch_list]
### bert tokens
bert_ids = [sent[3] for sent in input_batch_list]
xlnet_ids = [sent[4] for sent in input_batch_list]
word_seq_lengths = torch.LongTensor(list(map(len, words)))
max_seq_len = word_seq_lengths.max()
@@ -207,10 +208,11 @@ def batchify_with_label(input_batch_list, gpu, device):
mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte()
### bert seq tensor
bert_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len + 2))).long()
xlnet_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
bert_mask = autograd.Variable(torch.zeros((batch_size, max_seq_len + 2))).long()
for b, (seq, biseq, label, seqlen, bert_id) in enumerate(
zip(words, biwords, labels, word_seq_lengths, bert_ids)):
for b, (seq, biseq, label, seqlen, bert_id, xlnet_id) in enumerate(
zip(words, biwords, labels, word_seq_lengths, bert_ids, xlnet_ids)):
word_seq_tensor[b, :seqlen] = torch.LongTensor(seq)
biword_seq_tensor[b, :seqlen] = torch.LongTensor(biseq)
label_seq_tensor[b, :seqlen] = torch.LongTensor(label)
@@ -218,6 +220,7 @@ def batchify_with_label(input_batch_list, gpu, device):
bert_mask[b, :seqlen + 2] = torch.LongTensor([1] * int(seqlen + 2))
##bert
bert_seq_tensor[b, :seqlen + 2] = torch.LongTensor(bert_id)
xlnet_seq_tensor[b, :seqlen] = torch.LongTensor(xlnet_id)
if gpu:
word_seq_tensor = word_seq_tensor.cuda(device)
@@ -226,9 +229,10 @@ def batchify_with_label(input_batch_list, gpu, device):
label_seq_tensor = label_seq_tensor.cuda(device)
mask = mask.cuda(device)
bert_seq_tensor = bert_seq_tensor.cuda(device)
xlnet_seq_tensor = xlnet_seq_tensor.cuda(device)
bert_mask = bert_mask.cuda(device)
return word_seq_tensor, biword_seq_tensor, word_seq_lengths, label_seq_tensor, mask, bert_seq_tensor, bert_mask
return word_seq_tensor, biword_seq_tensor, word_seq_lengths, label_seq_tensor, mask, bert_seq_tensor, xlnet_seq_tensor, bert_mask
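batchify_with_label now also builds and returns an XLNet id tensor: it is padded to max_seq_len with no special tokens, while the BERT tensor keeps its max_seq_len + 2 layout, so every call site unpacks one extra value. A minimal usage sketch with the names used at the call sites:

    # Hedged usage sketch: unpacking the extended return value.
    (batch_word, batch_biword, batch_wordlen, batch_label, mask,
     batch_bert, batch_xlnet, bert_mask) = batchify_with_label(instance, data.HP_gpu, data.device)
    # batch_bert  shape: (batch_size, max_seq_len + 2)  -- [CLS] ... [SEP]
    # batch_xlnet shape: (batch_size, max_seq_len)      -- bare subword ids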
def train(data, save_model_dir, seg=True, debug=False, transfer=False):
@@ -248,8 +252,7 @@ def train(data, save_model_dir, seg=True, debug=False, transfer=False):
del model_dict['hidden2tag.weight']
del model_dict['hidden2tag.bias']
del model_dict['crf.transitions']
model.load_state_dict(model_dict,strict=False)
model.load_state_dict(model_dict, strict=False)
print("finish building model.")
@@ -310,12 +313,12 @@ def train(data, save_model_dir, seg=True, debug=False, transfer=False):
if not instance:
continue
batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, bert_mask = batchify_with_label(
batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, batch_xlnet, bert_mask = batchify_with_label(
instance, data.HP_gpu, data.device)
instance_count += 1
loss, tag_seq = model.neg_log_likelihood_loss(batch_word, batch_biword, batch_wordlen, mask,
batch_label, batch_bert, bert_mask)
batch_label, batch_bert, batch_xlnet, bert_mask)
right, whole = predict_check(tag_seq, batch_label, mask)
right_token += right
@@ -324,7 +327,7 @@ def train(data, save_model_dir, seg=True, debug=False, transfer=False):
total_loss += loss.data
batch_loss += loss
if end % 500 == 0:
if end % 200 == 0:
temp_time = time.time()
temp_cost = temp_time - temp_start
temp_start = temp_time
@@ -470,7 +473,7 @@ if __name__ == '__main__':
parser.add_argument('--use_char', dest='use_char', action='store_true', default=False)
# parser.set_defaults(use_biword=False)
parser.add_argument('--use_count', action='store_true', default=True)
parser.add_argument('--use_bert', action='store_true', default=True)
parser.add_argument('--use_bert', action='store_true', default=False)
args = parser.parse_args()
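One behavioural note on the --use_bert change: with action='store_true', a default of True made the flag impossible to switch off from the command line, whereas defaulting to False lets the flag actually control the option. Illustration of the argparse semantics (not repo code):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--use_bert', action='store_true', default=False)
    print(parser.parse_args([]).use_bert)              # False
    print(parser.parse_args(['--use_bert']).use_bert)  # True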
@@ -520,7 +523,7 @@ if __name__ == '__main__':
data.warm_up = args.warm_up
data.device = args.device
data.bert_finetune = args.bert_finetune
data.transfer=args.transfer
data.transfer = args.transfer
data.HP_lstm_layer = args.lstm_layer
@@ -559,7 +562,7 @@ if __name__ == '__main__':
data.show_data_summary()
print('data.use_biword=', data.use_bigram)
print('data.HP_batch_size=', data.HP_batch_size)
train(data, save_model_dir, seg, debug=False,transfer=data.transfer)
train(data, save_model_dir, seg, debug=False, transfer=data.transfer)
elif status == 'test':
print('Loading processed data')
with open(save_data_name, 'rb') as fp:

View File

@@ -14,7 +14,7 @@ NULLKEY = "-null-"
class Data:
def __init__(self):
self.MAX_SENTENCE_LENGTH = 1000
self.MAX_SENTENCE_LENGTH = 400
self.MAX_WORD_LENGTH = -1
self.number_normalized = True
self.norm_word_emb = True
@@ -25,7 +25,7 @@ class Data:
self.biword_alphabet = Alphabet('biword', min_freq=self.min_freq)
self.label_alphabet = Alphabet('label', True)
self.device = 0
self.transfer=False
self.transfer = False
self.biword_count = {}
@@ -45,6 +45,7 @@ class Data:
self.dev_split_index = []
self.use_bigram = True
self.use_bert = False
self.word_emb_dim = 50
self.biword_emb_dim = 50
@@ -80,6 +81,7 @@ class Data:
print(" MAX WORD LENGTH: %s" % (self.MAX_WORD_LENGTH))
print(" Number normalized: %s" % (self.number_normalized))
print(" Use bigram: %s" % (self.use_bigram))
print(" Use bert : %s" % (self.use_bert))
print(" Word alphabet size: %s" % (self.word_alphabet_size))
print(" Biword alphabet size: %s" % (self.biword_alphabet_size))
print(" Label alphabet size: %s" % (self.label_alphabet_size))
@@ -132,7 +134,7 @@ class Data:
self.fix_alphabet()
print("Refresh label alphabet finished: old:%s -> new:%s" % (old_size, self.label_alphabet_size))
def build_alphabet(self, input_file,only_label=False,use_label=True):
def build_alphabet(self, input_file, only_label=False, use_label=True):
in_lines = open(input_file, 'r', encoding="utf-8").readlines()
seqlen = 0
for idx in tqdm(range(len(in_lines))):

View File

@@ -4,6 +4,7 @@ import numpy as np
# from transformers.tokenization_bert import BertTokenizer
from transformers import BertTokenizer
from transformers.tokenization_xlnet import XLNetTokenizer
NULLKEY = "-null-"
@@ -21,7 +22,8 @@ def normalize_word(word):
def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized,
max_sent_length, bertpath):
tokenizer = BertTokenizer.from_pretrained(bertpath, do_lower_case=True)
xlnet_tokenizer = XLNetTokenizer.from_pretrained('transformer_cpt/chinese_xlnet_base_pytorch/',
add_special_tokens=False)
in_lines = open(input_file, 'r', encoding="utf-8").readlines()
instence_texts = []
instence_Ids = []
@@ -67,6 +69,7 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
texts = ['[CLS]'] + words[:max_sent_length] + ['[SEP]']
bert_text_ids = tokenizer.convert_tokens_to_ids(texts)
xlnet_text_ids = xlnet_tokenizer.convert_tokens_to_ids(words[:max_sent_length])
instence_texts.append([words, biwords, labels])
word_Ids = word_Ids[:max_sent_length]
@@ -74,7 +77,7 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
label_Ids = label_Ids[:max_sent_length]
assert len(texts) - 2 == len(word_Ids)
instence_Ids.append([word_Ids, biword_Ids, label_Ids, bert_text_ids])
instence_Ids.append([word_Ids, biword_Ids, label_Ids, bert_text_ids, xlnet_text_ids])
words = []
biwords = []
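read_instance now emits two parallel subword id lists per sentence: BERT ids over the [CLS]/[SEP]-wrapped characters and XLNet ids over the bare characters, which is why the BERT tensors in batchify_with_label carry the +2 offset. An illustrative sketch, assuming the two tokenizers loaded earlier in this file:

    # Illustration only: the two id lists appended to instence_Ids differ by the special tokens.
    words = ['中', '国']                                                              # character-level tokens
    bert_text_ids = tokenizer.convert_tokens_to_ids(['[CLS]'] + words + ['[SEP]'])   # len(words) + 2 ids
    xlnet_text_ids = xlnet_tokenizer.convert_tokens_to_ids(words)                    # len(words) ids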

View File

@@ -38,9 +38,9 @@ class Trainer(object):
self.model.to(self.device)
if args.train_mode == "predict":
self.resume(args)
# logging.info('total gpu num is {}'.format(self.n_gpu))
# if self.n_gpu > 1:
# self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
logging.info('total gpu num is {}'.format(self.n_gpu))
if self.n_gpu > 1:
self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
train_dataloader, dev_dataloader = data_loaders
train_eval, dev_eval = examples
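The trainer re-enables (uncomments) the multi-GPU path by wrapping the model in nn.DataParallel when more than one device is visible; the hard-coded device_ids=[0, 1] assumes exactly those GPUs exist. A slightly more general sketch (an illustration, not repo code):

    import torch
    import torch.nn as nn

    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        # replicate across all visible devices instead of hard-coding [0, 1]
        model = nn.DataParallel(model.cuda(), device_ids=list(range(n_gpu)))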