This commit is contained in:
loujie0822 2020-05-09 18:33:48 +08:00
parent 56def3425e
commit 0815cdd656
2 changed files with 17 additions and 10 deletions

View File

@@ -52,12 +52,14 @@ class Data:
self.biword_alphabet_size = 0
self.label_alphabet_size = 0
self.bertpath = 'transformer_cpt/bert/'
### hyperparameters
self.HP_iteration = 100
self.HP_batch_size = 10
self.HP_hidden_dim = 128
self.HP_dropout = 0.5
self.HP_lstm_layer = 2
self.HP_lstm_layer = 1
self.HP_bilstm = True
self.HP_gpu = True
self.HP_lr = 0.015
@@ -82,6 +84,7 @@ class Data:
print(" Norm word emb: %s" % (self.norm_word_emb))
print(" Norm biword emb: %s" % (self.norm_biword_emb))
print(" Norm gaz emb: %s" % (self.norm_gaz_emb))
print(" bert file is : %s" % (self.bert_type))
print(" Train instance number: %s" % (len(self.train_texts)))
print(" Dev instance number: %s" % (len(self.dev_texts)))
print(" Test instance number: %s" % (len(self.test_texts)))
@@ -134,7 +137,7 @@ class Data:
pairs = line.strip().split('\t')
if len(pairs) == 1:
word = ' '
print('word == ')
# print('word == ')
else:
word = pairs[0]
if self.number_normalized:
@@ -216,18 +219,18 @@ class Data:
if name == "train":
self.train_texts, self.train_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH)
self.MAX_SENTENCE_LENGTH, self.bertpath)
elif name == "dev":
self.dev_texts, self.dev_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH)
self.MAX_SENTENCE_LENGTH, self.bertpath)
elif name == "test":
self.test_texts, self.test_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH)
self.MAX_SENTENCE_LENGTH, self.bertpath)
elif name == "raw":
self.raw_texts, self.raw_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
self.label_alphabet, self.number_normalized,
self.MAX_SENTENCE_LENGTH)
self.MAX_SENTENCE_LENGTH, self.bertpath)
else:
print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name))

View File

@@ -19,8 +19,8 @@ def normalize_word(word):
def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized,
max_sent_length):
tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True)
max_sent_length, bertpath):
tokenizer = BertTokenizer.from_pretrained(bertpath, do_lower_case=True)
in_lines = open(input_file, 'r', encoding="utf-8").readlines()
instence_texts = []
@@ -36,8 +36,12 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
for idx in range(len(in_lines)):
line = in_lines[idx]
if len(line) > 2:
pairs = line.strip().split()
word = pairs[0]
pairs = line.strip().split('\t')
if len(pairs) == 1:
word = ' '
# print('word == ')
else:
word = pairs[0]
if number_normalized:
word = normalize_word(word)
label = pairs[-1]