first commit
This commit is contained in:
parent 9d3a201f1d
commit 314471f02a
@@ -65,7 +65,7 @@ class GazLSTM(nn.Module):

         if self.use_bert:
             self.bert_encoder = BertModel.from_pretrained('transformer_cpt/bert/')
-            # self.xlnet_encoder = XLNetModel.from_pretrained('transformer_cpt/chinese_xlnet_base_pytorch')
+            self.xlnet_encoder = XLNetModel.from_pretrained('transformer_cpt/chinese_xlnet_base_pytorch')
             self.bert_encoder_wwm = BertModel.from_pretrained('transformer_cpt/chinese_roberta_wwm_ext_pytorch/')
             for p in self.bert_encoder.parameters():
                 p.requires_grad = False
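Note: the hunk above enables an XLNet encoder next to the frozen BERT encoder. As a rough sketch only (the class name, the forward pass, and the idea of concatenating the two hidden states are assumptions, not taken from this commit; it assumes the transformers 2.x tuple-return API used elsewhere in this repo, where outputs[0] is the last hidden state), holding and combining the two pretrained encoders could look like this:

import torch
import torch.nn as nn
from transformers import BertModel, XLNetModel

class DualEncoderSketch(nn.Module):
    # Hypothetical helper, not part of the diff above.
    def __init__(self, bert_path='transformer_cpt/bert/',
                 xlnet_path='transformer_cpt/chinese_xlnet_base_pytorch'):
        super().__init__()
        self.bert_encoder = BertModel.from_pretrained(bert_path)
        self.xlnet_encoder = XLNetModel.from_pretrained(xlnet_path)
        for p in self.bert_encoder.parameters():
            p.requires_grad = False  # BERT stays frozen, as in the hunk

    def forward(self, bert_ids, bert_mask, xlnet_ids):
        # bert_ids carry [CLS]/[SEP]; drop them so both outputs align with the raw characters.
        bert_out = self.bert_encoder(bert_ids, attention_mask=bert_mask)[0][:, 1:-1, :]
        xlnet_out = self.xlnet_encoder(xlnet_ids)[0]
        return torch.cat([bert_out, xlnet_out], dim=-1)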
@@ -11,7 +11,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from transformers import AlbertModel
-from transformers import AlbertPreTrainedModel
+from transformers.modeling_albert import AlbertPreTrainedModel

 from layers.encoders.transformers.bert.layernorm import ConditionalLayerNorm
 from utils.data_util import batch_gather
@@ -14,7 +14,7 @@ NULLKEY = "-null-"

 class Data:
     def __init__(self):
-        self.MAX_SENTENCE_LENGTH = 1000
+        self.MAX_SENTENCE_LENGTH = 400
         self.MAX_WORD_LENGTH = -1
         self.number_normalized = True
         self.norm_word_emb = True
@@ -14,7 +14,7 @@ NULLKEY = "-null-"

 class Data:
     def __init__(self):
-        self.MAX_SENTENCE_LENGTH = 250
+        self.MAX_SENTENCE_LENGTH = 400
         self.MAX_WORD_LENGTH = -1
         self.number_normalized = True
         self.norm_word_emb = True
@@ -43,7 +43,3 @@ class Gazetteer:

     def size(self):
         return len(self.ent2type)
-
-
-
-
@@ -164,9 +164,9 @@ def evaluate(data, model, name):

         if not instance:
             continue
-        batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, bert_mask = batchify_with_label(
+        batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, batch_xlnet, bert_mask = batchify_with_label(
             instance, data.HP_gpu, data.device)
-        tag_seq = model(batch_word, batch_biword, batch_wordlen, mask, batch_bert, bert_mask)
+        tag_seq = model(batch_word, batch_biword, batch_wordlen, mask, batch_bert, batch_xlnet, bert_mask)

         pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet)
         pred_results += pred_label
@@ -198,6 +198,7 @@ def batchify_with_label(input_batch_list, gpu, device):
     labels = [sent[2] for sent in input_batch_list]
     ### bert tokens
     bert_ids = [sent[3] for sent in input_batch_list]
+    xlnet_ids = [sent[4] for sent in input_batch_list]

     word_seq_lengths = torch.LongTensor(list(map(len, words)))
     max_seq_len = word_seq_lengths.max()
@@ -207,10 +208,11 @@ def batchify_with_label(input_batch_list, gpu, device):
     mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte()
     ### bert seq tensor
     bert_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len + 2))).long()
+    xlnet_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
     bert_mask = autograd.Variable(torch.zeros((batch_size, max_seq_len + 2))).long()

-    for b, (seq, biseq, label, seqlen, bert_id) in enumerate(
-            zip(words, biwords, labels, word_seq_lengths, bert_ids)):
+    for b, (seq, biseq, label, seqlen, bert_id, xlnet_id) in enumerate(
+            zip(words, biwords, labels, word_seq_lengths, bert_ids, xlnet_ids)):
         word_seq_tensor[b, :seqlen] = torch.LongTensor(seq)
         biword_seq_tensor[b, :seqlen] = torch.LongTensor(biseq)
         label_seq_tensor[b, :seqlen] = torch.LongTensor(label)
@@ -218,6 +220,7 @@ def batchify_with_label(input_batch_list, gpu, device):
         bert_mask[b, :seqlen + 2] = torch.LongTensor([1] * int(seqlen + 2))
         ##bert
         bert_seq_tensor[b, :seqlen + 2] = torch.LongTensor(bert_id)
+        xlnet_seq_tensor[b, :seqlen] = torch.LongTensor(xlnet_id)

     if gpu:
         word_seq_tensor = word_seq_tensor.cuda(device)
@@ -226,9 +229,10 @@ def batchify_with_label(input_batch_list, gpu, device):
         label_seq_tensor = label_seq_tensor.cuda(device)
         mask = mask.cuda(device)
         bert_seq_tensor = bert_seq_tensor.cuda(device)
+        xlnet_seq_tensor = xlnet_seq_tensor.cuda(device)
         bert_mask = bert_mask.cuda(device)

-    return word_seq_tensor, biword_seq_tensor, word_seq_lengths, label_seq_tensor, mask, bert_seq_tensor, bert_mask
+    return word_seq_tensor, biword_seq_tensor, word_seq_lengths, label_seq_tensor, mask, bert_seq_tensor, xlnet_seq_tensor, bert_mask


 def train(data, save_model_dir, seg=True, debug=False, transfer=False):
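To make the new tensor shapes concrete, here is a stripped-down sketch of the padding scheme batchify_with_label follows after this change: BERT ids are padded to max_seq_len + 2 (room for [CLS]/[SEP]) while XLNet ids are padded to max_seq_len. Variable names mirror the hunks above; everything else is a simplified assumption.

import torch

def pad_ids_sketch(bert_ids, xlnet_ids, seq_lens):
    # bert_ids[i] has length seq_lens[i] + 2 ([CLS]/[SEP]); xlnet_ids[i] has length seq_lens[i].
    batch_size = len(seq_lens)
    max_seq_len = max(seq_lens)
    bert_seq_tensor = torch.zeros((batch_size, max_seq_len + 2), dtype=torch.long)
    xlnet_seq_tensor = torch.zeros((batch_size, max_seq_len), dtype=torch.long)
    bert_mask = torch.zeros((batch_size, max_seq_len + 2), dtype=torch.long)
    for b, (bert_id, xlnet_id, seqlen) in enumerate(zip(bert_ids, xlnet_ids, seq_lens)):
        bert_seq_tensor[b, :seqlen + 2] = torch.LongTensor(bert_id)
        xlnet_seq_tensor[b, :seqlen] = torch.LongTensor(xlnet_id)
        bert_mask[b, :seqlen + 2] = 1
    return bert_seq_tensor, xlnet_seq_tensor, bert_mask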
@@ -248,8 +252,7 @@ def train(data, save_model_dir, seg=True, debug=False, transfer=False):
         del model_dict['hidden2tag.weight']
         del model_dict['hidden2tag.bias']
         del model_dict['crf.transitions']
-        model.load_state_dict(model_dict,strict=False)
-
+        model.load_state_dict(model_dict, strict=False)

     print("finish building model.")

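The transfer branch above strips the label-space-specific parameters before loading, which is why strict=False is needed. A minimal self-contained sketch of that pattern (the function name and checkpoint path are placeholders, not part of the commit):

import torch
import torch.nn as nn

def load_for_transfer(model: nn.Module, ckpt_path: str) -> None:
    state = torch.load(ckpt_path, map_location='cpu')
    # Drop the tagging head and CRF transitions, the same keys the hunk deletes,
    # so a model with a different label set can reuse the remaining weights.
    for key in ('hidden2tag.weight', 'hidden2tag.bias', 'crf.transitions'):
        state.pop(key, None)
    # strict=False tolerates the keys that were just removed.
    model.load_state_dict(state, strict=False)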
@@ -310,12 +313,12 @@ def train(data, save_model_dir, seg=True, debug=False, transfer=False):
             if not instance:
                 continue

-            batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, bert_mask = batchify_with_label(
+            batch_word, batch_biword, batch_wordlen, batch_label, mask, batch_bert, batch_xlnet, bert_mask = batchify_with_label(
                 instance, data.HP_gpu, data.device)

             instance_count += 1
             loss, tag_seq = model.neg_log_likelihood_loss(batch_word, batch_biword, batch_wordlen, mask,
-                                                          batch_label, batch_bert, bert_mask)
+                                                          batch_label, batch_bert, batch_xlnet, bert_mask)

             right, whole = predict_check(tag_seq, batch_label, mask)
             right_token += right
@@ -324,7 +327,7 @@ def train(data, save_model_dir, seg=True, debug=False, transfer=False):
             total_loss += loss.data
             batch_loss += loss

-            if end % 500 == 0:
+            if end % 200 == 0:
                 temp_time = time.time()
                 temp_cost = temp_time - temp_start
                 temp_start = temp_time
@@ -470,7 +473,7 @@ if __name__ == '__main__':
     parser.add_argument('--use_char', dest='use_char', action='store_true', default=False)
     # parser.set_defaults(use_biword=False)
     parser.add_argument('--use_count', action='store_true', default=True)
-    parser.add_argument('--use_bert', action='store_true', default=True)
+    parser.add_argument('--use_bert', action='store_true', default=False)

     args = parser.parse_args()

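Because the flag uses action='store_true' with default=False after this change, BERT is only enabled when --use_bert is passed explicitly on the command line. A quick argparse check:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use_bert', action='store_true', default=False)

print(parser.parse_args([]).use_bert)              # False: flag omitted
print(parser.parse_args(['--use_bert']).use_bert)  # True: flag passed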
@@ -520,7 +523,7 @@ if __name__ == '__main__':
     data.warm_up = args.warm_up
     data.device = args.device
     data.bert_finetune = args.bert_finetune
-    data.transfer=args.transfer
+    data.transfer = args.transfer

     data.HP_lstm_layer = args.lstm_layer

@@ -559,7 +562,7 @@ if __name__ == '__main__':
         data.show_data_summary()
         print('data.use_biword=', data.use_bigram)
         print('data.HP_batch_size=', data.HP_batch_size)
-        train(data, save_model_dir, seg, debug=False,transfer=data.transfer)
+        train(data, save_model_dir, seg, debug=False, transfer=data.transfer)
     elif status == 'test':
         print('Loading processed data')
         with open(save_data_name, 'rb') as fp:
@@ -14,7 +14,7 @@ NULLKEY = "-null-"

 class Data:
     def __init__(self):
-        self.MAX_SENTENCE_LENGTH = 1000
+        self.MAX_SENTENCE_LENGTH = 400
         self.MAX_WORD_LENGTH = -1
         self.number_normalized = True
         self.norm_word_emb = True
@@ -25,7 +25,7 @@ class Data:
         self.biword_alphabet = Alphabet('biword', min_freq=self.min_freq)
         self.label_alphabet = Alphabet('label', True)
         self.device = 0
-        self.transfer=False
+        self.transfer = False

         self.biword_count = {}

@@ -45,6 +45,7 @@ class Data:
         self.dev_split_index = []

         self.use_bigram = True
+        self.use_bert = False
         self.word_emb_dim = 50
         self.biword_emb_dim = 50

@@ -80,6 +81,7 @@ class Data:
         print(" MAX WORD LENGTH: %s" % (self.MAX_WORD_LENGTH))
         print(" Number normalized: %s" % (self.number_normalized))
         print(" Use bigram: %s" % (self.use_bigram))
+        print(" Use bert : %s" % (self.use_bert))
         print(" Word alphabet size: %s" % (self.word_alphabet_size))
         print(" Biword alphabet size: %s" % (self.biword_alphabet_size))
         print(" Label alphabet size: %s" % (self.label_alphabet_size))
@@ -132,7 +134,7 @@ class Data:
         self.fix_alphabet()
         print("Refresh label alphabet finished: old:%s -> new:%s" % (old_size, self.label_alphabet_size))

-    def build_alphabet(self, input_file,only_label=False,use_label=True):
+    def build_alphabet(self, input_file, only_label=False, use_label=True):
         in_lines = open(input_file, 'r', encoding="utf-8").readlines()
         seqlen = 0
         for idx in tqdm(range(len(in_lines))):
@@ -4,6 +4,7 @@ import numpy as np

 # from transformers.tokenization_bert import BertTokenizer
 from transformers import BertTokenizer
+from transformers.tokenization_xlnet import XLNetTokenizer

 NULLKEY = "-null-"

@@ -21,7 +22,8 @@ def normalize_word(word):
 def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized,
                   max_sent_length, bertpath):
     tokenizer = BertTokenizer.from_pretrained(bertpath, do_lower_case=True)
-
+    xlnet_tokenizer = XLNetTokenizer.from_pretrained('transformer_cpt/chinese_xlnet_base_pytorch/',
+                                                     add_special_tokens=False)
     in_lines = open(input_file, 'r', encoding="utf-8").readlines()
     instence_texts = []
     instence_Ids = []
@@ -67,6 +69,7 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
                texts = ['[CLS]'] + words[:max_sent_length] + ['[SEP]']

                bert_text_ids = tokenizer.convert_tokens_to_ids(texts)
+               xlnet_text_ids = xlnet_tokenizer.convert_tokens_to_ids(words[:max_sent_length])
                instence_texts.append([words, biwords, labels])

                word_Ids = word_Ids[:max_sent_length]
@@ -74,7 +77,7 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
                label_Ids = label_Ids[:max_sent_length]

                assert len(texts) - 2 == len(word_Ids)
-               instence_Ids.append([word_Ids, biword_Ids, label_Ids, bert_text_ids])
+               instence_Ids.append([word_Ids, biword_Ids, label_Ids, bert_text_ids, xlnet_text_ids])

                words = []
                biwords = []
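For each sentence, read_instance now stores two id sequences side by side: bert_text_ids (wrapped in [CLS]/[SEP]) and xlnet_text_ids (the bare characters, same length as the sentence). A small sketch of that bookkeeping; the checkpoint paths are taken from the hunks above, the sample characters are made up, and the add_special_tokens kwarg is omitted:

from transformers import BertTokenizer
from transformers.tokenization_xlnet import XLNetTokenizer

bert_tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True)
xlnet_tokenizer = XLNetTokenizer.from_pretrained('transformer_cpt/chinese_xlnet_base_pytorch/')

words = ['中', '国', '人']  # one Chinese character per token
bert_text_ids = bert_tokenizer.convert_tokens_to_ids(['[CLS]'] + words + ['[SEP]'])
xlnet_text_ids = xlnet_tokenizer.convert_tokens_to_ids(words)
assert len(bert_text_ids) - 2 == len(xlnet_text_ids)  # same length bookkeeping as in read_instance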
@@ -38,9 +38,9 @@ class Trainer(object):
         self.model.to(self.device)
         if args.train_mode == "predict":
             self.resume(args)
-        # logging.info('total gpu num is {}'.format(self.n_gpu))
-        # if self.n_gpu > 1:
-        #     self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])
+        logging.info('total gpu num is {}'.format(self.n_gpu))
+        if self.n_gpu > 1:
+            self.model = nn.DataParallel(self.model.cuda(), device_ids=[0, 1])

         train_dataloader, dev_dataloader = data_loaders
         train_eval, dev_eval = examples
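The Trainer hunk re-enables the standard multi-GPU path. For reference, plain torch.nn.DataParallel usage looks like this (the toy module and device ids are placeholders, not the repo's actual model):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the real model
if torch.cuda.device_count() > 1:
    # Replicates the module on each listed GPU and splits every batch along dim 0.
    model = nn.DataParallel(model.cuda(), device_ids=[0, 1])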