This commit is contained in:
loujie0822 2020-05-09 18:33:48 +08:00
parent 56def3425e
commit 0815cdd656
2 changed files with 17 additions and 10 deletions

View File

@ -52,12 +52,14 @@ class Data:
self.biword_alphabet_size = 0 self.biword_alphabet_size = 0
self.label_alphabet_size = 0 self.label_alphabet_size = 0
self.bertpath = 'transformer_cpt/bert/'
### hyperparameters ### hyperparameters
self.HP_iteration = 100 self.HP_iteration = 100
self.HP_batch_size = 10 self.HP_batch_size = 10
self.HP_hidden_dim = 128 self.HP_hidden_dim = 128
self.HP_dropout = 0.5 self.HP_dropout = 0.5
self.HP_lstm_layer = 2 self.HP_lstm_layer = 1
self.HP_bilstm = True self.HP_bilstm = True
self.HP_gpu = True self.HP_gpu = True
self.HP_lr = 0.015 self.HP_lr = 0.015
@ -82,6 +84,7 @@ class Data:
print(" Norm word emb: %s" % (self.norm_word_emb)) print(" Norm word emb: %s" % (self.norm_word_emb))
print(" Norm biword emb: %s" % (self.norm_biword_emb)) print(" Norm biword emb: %s" % (self.norm_biword_emb))
print(" Norm gaz emb: %s" % (self.norm_gaz_emb)) print(" Norm gaz emb: %s" % (self.norm_gaz_emb))
print(" bert file is : %s" % (self.bert_type))
print(" Train instance number: %s" % (len(self.train_texts))) print(" Train instance number: %s" % (len(self.train_texts)))
print(" Dev instance number: %s" % (len(self.dev_texts))) print(" Dev instance number: %s" % (len(self.dev_texts)))
print(" Test instance number: %s" % (len(self.test_texts))) print(" Test instance number: %s" % (len(self.test_texts)))
@ -134,7 +137,7 @@ class Data:
pairs = line.strip().split('\t') pairs = line.strip().split('\t')
if len(pairs) == 1: if len(pairs) == 1:
word = ' ' word = ' '
print('word == ') # print('word == ')
else: else:
word = pairs[0] word = pairs[0]
if self.number_normalized: if self.number_normalized:
@ -216,18 +219,18 @@ class Data:
if name == "train": if name == "train":
self.train_texts, self.train_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet, self.train_texts, self.train_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized, self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH) self.MAX_SENTENCE_LENGTH, self.bertpath)
elif name == "dev": elif name == "dev":
self.dev_texts, self.dev_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet, self.dev_texts, self.dev_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized, self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH) self.MAX_SENTENCE_LENGTH, self.bertpath)
elif name == "test": elif name == "test":
self.test_texts, self.test_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet, self.test_texts, self.test_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized, self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH) self.MAX_SENTENCE_LENGTH, self.bertpath)
elif name == "raw": elif name == "raw":
self.raw_texts, self.raw_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet, self.raw_texts, self.raw_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized, self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH) self.MAX_SENTENCE_LENGTH, self.bertpath)
else: else:
print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name)) print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name))

View File

@ -19,8 +19,8 @@ def normalize_word(word):
def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized, def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized,
max_sent_length): max_sent_length, bertpath):
tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True) tokenizer = BertTokenizer.from_pretrained(bertpath, do_lower_case=True)
in_lines = open(input_file, 'r', encoding="utf-8").readlines() in_lines = open(input_file, 'r', encoding="utf-8").readlines()
instence_texts = [] instence_texts = []
@ -36,8 +36,12 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
for idx in range(len(in_lines)): for idx in range(len(in_lines)):
line = in_lines[idx] line = in_lines[idx]
if len(line) > 2: if len(line) > 2:
pairs = line.strip().split() pairs = line.strip().split('\t')
word = pairs[0] if len(pairs) == 1:
word = ' '
# print('word == ')
else:
word = pairs[0]
if number_normalized: if number_normalized:
word = normalize_word(word) word = normalize_word(word)
label = pairs[-1] label = pairs[-1]