bert ner
This commit is contained in:
parent
56def3425e
commit
0815cdd656
@ -52,12 +52,14 @@ class Data:
|
||||
self.biword_alphabet_size = 0
|
||||
self.label_alphabet_size = 0
|
||||
|
||||
self.bertpath = 'transformer_cpt/bert/'
|
||||
|
||||
### hyperparameters
|
||||
self.HP_iteration = 100
|
||||
self.HP_batch_size = 10
|
||||
self.HP_hidden_dim = 128
|
||||
self.HP_dropout = 0.5
|
||||
self.HP_lstm_layer = 2
|
||||
self.HP_lstm_layer = 1
|
||||
self.HP_bilstm = True
|
||||
self.HP_gpu = True
|
||||
self.HP_lr = 0.015
|
||||
@ -82,6 +84,7 @@ class Data:
|
||||
print(" Norm word emb: %s" % (self.norm_word_emb))
|
||||
print(" Norm biword emb: %s" % (self.norm_biword_emb))
|
||||
print(" Norm gaz emb: %s" % (self.norm_gaz_emb))
|
||||
print(" bert file is : %s" % (self.bert_type))
|
||||
print(" Train instance number: %s" % (len(self.train_texts)))
|
||||
print(" Dev instance number: %s" % (len(self.dev_texts)))
|
||||
print(" Test instance number: %s" % (len(self.test_texts)))
|
||||
@ -134,7 +137,7 @@ class Data:
|
||||
pairs = line.strip().split('\t')
|
||||
if len(pairs) == 1:
|
||||
word = ' '
|
||||
print('word == ')
|
||||
# print('word == ')
|
||||
else:
|
||||
word = pairs[0]
|
||||
if self.number_normalized:
|
||||
@ -216,18 +219,18 @@ class Data:
|
||||
if name == "train":
|
||||
self.train_texts, self.train_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||
self.label_alphabet, self.number_normalized,
|
||||
self.MAX_SENTENCE_LENGTH)
|
||||
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||
elif name == "dev":
|
||||
self.dev_texts, self.dev_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||
self.label_alphabet, self.number_normalized,
|
||||
self.MAX_SENTENCE_LENGTH)
|
||||
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||
elif name == "test":
|
||||
self.test_texts, self.test_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||
self.label_alphabet, self.number_normalized,
|
||||
self.MAX_SENTENCE_LENGTH)
|
||||
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||
elif name == "raw":
|
||||
self.raw_texts, self.raw_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||
self.label_alphabet, self.number_normalized,
|
||||
self.MAX_SENTENCE_LENGTH)
|
||||
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||
else:
|
||||
print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name))
|
||||
|
@ -19,8 +19,8 @@ def normalize_word(word):
|
||||
|
||||
|
||||
def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized,
|
||||
max_sent_length):
|
||||
tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True)
|
||||
max_sent_length, bertpath):
|
||||
tokenizer = BertTokenizer.from_pretrained(bertpath, do_lower_case=True)
|
||||
|
||||
in_lines = open(input_file, 'r', encoding="utf-8").readlines()
|
||||
instence_texts = []
|
||||
@ -36,8 +36,12 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
|
||||
for idx in range(len(in_lines)):
|
||||
line = in_lines[idx]
|
||||
if len(line) > 2:
|
||||
pairs = line.strip().split()
|
||||
word = pairs[0]
|
||||
pairs = line.strip().split('\t')
|
||||
if len(pairs) == 1:
|
||||
word = ' '
|
||||
# print('word == ')
|
||||
else:
|
||||
word = pairs[0]
|
||||
if number_normalized:
|
||||
word = normalize_word(word)
|
||||
label = pairs[-1]
|
||||
|
Loading…
Reference in New Issue
Block a user