bert ner
This commit is contained in:
parent
56def3425e
commit
0815cdd656
@ -52,12 +52,14 @@ class Data:
|
|||||||
self.biword_alphabet_size = 0
|
self.biword_alphabet_size = 0
|
||||||
self.label_alphabet_size = 0
|
self.label_alphabet_size = 0
|
||||||
|
|
||||||
|
self.bertpath = 'transformer_cpt/bert/'
|
||||||
|
|
||||||
### hyperparameters
|
### hyperparameters
|
||||||
self.HP_iteration = 100
|
self.HP_iteration = 100
|
||||||
self.HP_batch_size = 10
|
self.HP_batch_size = 10
|
||||||
self.HP_hidden_dim = 128
|
self.HP_hidden_dim = 128
|
||||||
self.HP_dropout = 0.5
|
self.HP_dropout = 0.5
|
||||||
self.HP_lstm_layer = 2
|
self.HP_lstm_layer = 1
|
||||||
self.HP_bilstm = True
|
self.HP_bilstm = True
|
||||||
self.HP_gpu = True
|
self.HP_gpu = True
|
||||||
self.HP_lr = 0.015
|
self.HP_lr = 0.015
|
||||||
@ -82,6 +84,7 @@ class Data:
|
|||||||
print(" Norm word emb: %s" % (self.norm_word_emb))
|
print(" Norm word emb: %s" % (self.norm_word_emb))
|
||||||
print(" Norm biword emb: %s" % (self.norm_biword_emb))
|
print(" Norm biword emb: %s" % (self.norm_biword_emb))
|
||||||
print(" Norm gaz emb: %s" % (self.norm_gaz_emb))
|
print(" Norm gaz emb: %s" % (self.norm_gaz_emb))
|
||||||
|
print(" bert file is : %s" % (self.bert_type))
|
||||||
print(" Train instance number: %s" % (len(self.train_texts)))
|
print(" Train instance number: %s" % (len(self.train_texts)))
|
||||||
print(" Dev instance number: %s" % (len(self.dev_texts)))
|
print(" Dev instance number: %s" % (len(self.dev_texts)))
|
||||||
print(" Test instance number: %s" % (len(self.test_texts)))
|
print(" Test instance number: %s" % (len(self.test_texts)))
|
||||||
@ -134,7 +137,7 @@ class Data:
|
|||||||
pairs = line.strip().split('\t')
|
pairs = line.strip().split('\t')
|
||||||
if len(pairs) == 1:
|
if len(pairs) == 1:
|
||||||
word = ' '
|
word = ' '
|
||||||
print('word == ')
|
# print('word == ')
|
||||||
else:
|
else:
|
||||||
word = pairs[0]
|
word = pairs[0]
|
||||||
if self.number_normalized:
|
if self.number_normalized:
|
||||||
@ -216,18 +219,18 @@ class Data:
|
|||||||
if name == "train":
|
if name == "train":
|
||||||
self.train_texts, self.train_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
self.train_texts, self.train_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||||
self.label_alphabet, self.number_normalized,
|
self.label_alphabet, self.number_normalized,
|
||||||
self.MAX_SENTENCE_LENGTH)
|
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||||
elif name == "dev":
|
elif name == "dev":
|
||||||
self.dev_texts, self.dev_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
self.dev_texts, self.dev_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||||
self.label_alphabet, self.number_normalized,
|
self.label_alphabet, self.number_normalized,
|
||||||
self.MAX_SENTENCE_LENGTH)
|
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||||
elif name == "test":
|
elif name == "test":
|
||||||
self.test_texts, self.test_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
self.test_texts, self.test_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||||
self.label_alphabet, self.number_normalized,
|
self.label_alphabet, self.number_normalized,
|
||||||
self.MAX_SENTENCE_LENGTH)
|
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||||
elif name == "raw":
|
elif name == "raw":
|
||||||
self.raw_texts, self.raw_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
self.raw_texts, self.raw_Ids = read_instance(input_file, self.word_alphabet, self.biword_alphabet,
|
||||||
self.label_alphabet, self.number_normalized,
|
self.label_alphabet, self.number_normalized,
|
||||||
self.MAX_SENTENCE_LENGTH)
|
self.MAX_SENTENCE_LENGTH, self.bertpath)
|
||||||
else:
|
else:
|
||||||
print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name))
|
print("Error: you can only generate train/dev/test instance! Illegal input:%s" % (name))
|
||||||
|
@ -19,8 +19,8 @@ def normalize_word(word):
|
|||||||
|
|
||||||
|
|
||||||
def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized,
|
def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, number_normalized,
|
||||||
max_sent_length):
|
max_sent_length, bertpath):
|
||||||
tokenizer = BertTokenizer.from_pretrained('transformer_cpt/bert/', do_lower_case=True)
|
tokenizer = BertTokenizer.from_pretrained(bertpath, do_lower_case=True)
|
||||||
|
|
||||||
in_lines = open(input_file, 'r', encoding="utf-8").readlines()
|
in_lines = open(input_file, 'r', encoding="utf-8").readlines()
|
||||||
instence_texts = []
|
instence_texts = []
|
||||||
@ -36,8 +36,12 @@ def read_instance(input_file, word_alphabet, biword_alphabet, label_alphabet, nu
|
|||||||
for idx in range(len(in_lines)):
|
for idx in range(len(in_lines)):
|
||||||
line = in_lines[idx]
|
line = in_lines[idx]
|
||||||
if len(line) > 2:
|
if len(line) > 2:
|
||||||
pairs = line.strip().split()
|
pairs = line.strip().split('\t')
|
||||||
word = pairs[0]
|
if len(pairs) == 1:
|
||||||
|
word = ' '
|
||||||
|
# print('word == ')
|
||||||
|
else:
|
||||||
|
word = pairs[0]
|
||||||
if number_normalized:
|
if number_normalized:
|
||||||
word = normalize_word(word)
|
word = normalize_word(word)
|
||||||
label = pairs[-1]
|
label = pairs[-1]
|
||||||
|
Loading…
Reference in New Issue
Block a user