fix xlnet config

This commit is contained in:
yongzhuo 2021-09-02 16:15:19 +08:00
parent f215caa693
commit 38c814ae24
4 changed files with 22 additions and 15 deletions

View File

@ -79,7 +79,7 @@ step3: goto # Train&Usage(调用) and Predict&Usage(调用)
keras-bert还可以加载百度版ernie(需转换,[https://github.com/ArthurRizar/tensorflow_ernie](https://github.com/ArthurRizar/tensorflow_ernie)), keras-bert还可以加载百度版ernie(需转换,[https://github.com/ArthurRizar/tensorflow_ernie](https://github.com/ArthurRizar/tensorflow_ernie)),
哈工大版bert-wwm(tf框架[https://github.com/ymcui/Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm)) 哈工大版bert-wwm(tf框架[https://github.com/ymcui/Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm))
- albert_base_zh/(brightmart训练的albert, 地址为https://github.com/brightmart/albert_zh) - albert_base_zh/(brightmart训练的albert, 地址为https://github.com/brightmart/albert_zh)
- chinese_xlnet_mid_L-24_H-768_A-12/(哈工大预训练的中文xlnet模型[https://github.com/ymcui/Chinese-PreTrained-XLNet],24层) - chinese_xlnet_base_L-12_H-768_A-12/(哈工大预训练的中文xlnet模型[https://github.com/ymcui/Chinese-PreTrained-XLNet],12层)
- term_char.txt(已经上传, 项目中已全, wiki字典, 还可以用新华字典什么的) - term_char.txt(已经上传, 项目中已全, wiki字典, 还可以用新华字典什么的)
- term_word.txt(未上传, 项目中只有部分, 可参考词向量的) - term_word.txt(未上传, 项目中只有部分, 可参考词向量的)
- w2v_model_merge_short.vec(未上传, 项目中只有部分, 词向量, 可以用自己的) - w2v_model_merge_short.vec(未上传, 项目中只有部分, 词向量, 可以用自己的)

View File

@ -14,11 +14,15 @@ path_root = path_root.replace('\\', '/')
path_embedding_random_char = path_root + '/data/embeddings/term_char.txt' path_embedding_random_char = path_root + '/data/embeddings/term_char.txt'
path_embedding_random_word = path_root + '/data/embeddings/term_word.txt' path_embedding_random_word = path_root + '/data/embeddings/term_word.txt'
path_embedding_bert = path_root + '/data/embeddings/chinese_L-12_H-768_A-12/' path_embedding_bert = path_root + '/data/embeddings/chinese_L-12_H-768_A-12/'
path_embedding_xlnet = path_root + '/data/embeddings/chinese_xlnet_mid_L-24_H-768_A-12/' path_embedding_xlnet = path_root + '/data/embeddings/chinese_xlnet_base_L-12_H-768_A-12/'
path_embedding_albert = path_root + '/data/embeddings/albert_base_zh' path_embedding_albert = path_root + '/data/embeddings/albert_base_zh'
path_embedding_vector_word2vec_char = path_root + '/data/embeddings/w2v_model_wiki_char.vec' path_embedding_vector_word2vec_char = path_root + '/data/embeddings/w2v_model_wiki_char.vec'
path_embedding_vector_word2vec_word = path_root + '/data/embeddings/w2v_model_merge_short.vec' path_embedding_vector_word2vec_word = path_root + '/data/embeddings/w2v_model_merge_short.vec'
# classify data of tnews
path_tnews_train = path_root + '/data/tnews/train.csv'
path_tnews_valid = path_root + '/data/tnews/dev.csv'
# classify data of baidu qa 2019 # classify data of baidu qa 2019
path_baidu_qa_2019_train = path_root + '/data/baidu_qa_2019/baike_qa_train.csv' path_baidu_qa_2019_train = path_root + '/data/baidu_qa_2019/baike_qa_train.csv'
path_baidu_qa_2019_valid = path_root + '/data/baidu_qa_2019/baike_qa_valid.csv' path_baidu_qa_2019_valid = path_root + '/data/baidu_qa_2019/baike_qa_valid.csv'

View File

@ -37,7 +37,9 @@ class XlnetGraph(graph):
super().create_model(hyper_parameters) super().create_model(hyper_parameters)
embedding_output = self.word_embedding.output embedding_output = self.word_embedding.output
# x = embedding_output # x = embedding_output
x = Lambda(lambda x : x[:, -2:-1, :])(embedding_output) # 获取CLS # x = Lambda(lambda x: x[:, -1:, :])(embedding_output) # 获取CLS
# x = Lambda(lambda x: x[:, -1:])(embedding_output) # 获取CLS
x = embedding_output # 直接output
# # text cnn # # text cnn
# bert_output_emmbed = SpatialDropout1D(rate=self.dropout)(embedding_output) # bert_output_emmbed = SpatialDropout1D(rate=self.dropout)(embedding_output)
# concat_out = [] # concat_out = []
@ -52,7 +54,7 @@ class XlnetGraph(graph):
# x = GlobalMaxPooling1D(name='TextCNN_MaxPool1D_{}'.format(index))(x) # x = GlobalMaxPooling1D(name='TextCNN_MaxPool1D_{}'.format(index))(x)
# concat_out.append(x) # concat_out.append(x)
# x = Concatenate(axis=1)(concat_out) # x = Concatenate(axis=1)(concat_out)
# x = Dropout(self.dropout)(x) x = Dropout(self.dropout)(x)
x = Flatten()(x) x = Flatten()(x)
# 最后就是softmax # 最后就是softmax
dense_layer = Dense(self.label, activation=self.activate_classify)(x) dense_layer = Dense(self.label, activation=self.activate_classify)(x)

View File

@ -20,6 +20,7 @@ sys.path.append(project_path)
from keras_textclassification.conf.path_config import path_model, path_fineture, path_model_dir, path_hyper_parameters from keras_textclassification.conf.path_config import path_model, path_fineture, path_model_dir, path_hyper_parameters
# 训练验证数据地址 # 训练验证数据地址
from keras_textclassification.conf.path_config import path_baidu_qa_2019_train, path_baidu_qa_2019_valid from keras_textclassification.conf.path_config import path_baidu_qa_2019_train, path_baidu_qa_2019_valid
# from keras_textclassification.conf.path_config import path_tnews_train, path_tnews_valid
# 数据预处理, 删除文件目录下文件 # 数据预处理, 删除文件目录下文件
from keras_textclassification.data_preprocess.text_preprocess import PreprocessText, delete_file from keras_textclassification.data_preprocess.text_preprocess import PreprocessText, delete_file
# 模型图 # 模型图
@ -43,18 +44,18 @@ def train(hyper_parameters=None, rate=1.0):
'trainable': True, # 暂不支持微调True, embedding是静态的还是动态的, 即控制可不可以微调 'trainable': True, # 暂不支持微调True, embedding是静态的还是动态的, 即控制可不可以微调
'level_type': 'char', # 级别, 最小单元, 字/词, 填 'char' or 'word' 'level_type': 'char', # 级别, 最小单元, 字/词, 填 'char' or 'word'
'embedding_type': 'xlnet', # 级别, 嵌入类型, 还可以填'xlnet'、'random'、 'bert'、 'albert' or 'word2vec" 'embedding_type': 'xlnet', # 级别, 嵌入类型, 还可以填'xlnet'、'random'、 'bert'、 'albert' or 'word2vec"
'gpu_memory_fraction': 0.76, #gpu使用率 'gpu_memory_fraction': 0.8, #gpu使用率
'model': {'label': 17, # 类别数 'model': {'label': 10, # 类别数
'batch_size': 2, # 批处理尺寸, 感觉原则上越大越好,尤其是样本不均衡的时候, batch_size设置影响比较大 'batch_size': 16, # 批处理尺寸, 感觉原则上越大越好,尤其是样本不均衡的时候, batch_size设置影响比较大
'filters': [2, 3, 4, 5], # 卷积核尺寸 'filters': [2, 3, 4, 5], # 卷积核尺寸
'filters_num': 300, # 卷积个数 text-cnn:300-600 'filters_num': 300, # 卷积个数 text-cnn:300-600
'channel_size': 1, # CNN通道数 'channel_size': 1, # CNN通道数
'dropout': 0.5, # 随机失活, 概率 'dropout': 0.1, # 随机失活, 概率
'decay_step': 1000, # 学习率衰减step, 每N个step衰减一次 'decay_step': 1000, # 学习率衰减step, 每N个step衰减一次
'decay_rate': 0.9, # 学习率衰减系数, 乘法 'decay_rate': 0.99, # 学习率衰减系数, 乘法
'epochs': 20, # 训练最大轮次 'epochs': 20, # 训练最大轮次
'patience': 3, # 早停,2-3就好 'patience': 3, # 早停,2-3就好
'lr': 5e-5, # 学习率, 对训练会有比较大的影响, 如果准确率一直上不去,可以考虑调这个参数 'lr': 1e-4, # 学习率, 对训练会有比较大的影响, 如果准确率一直上不去,可以考虑调这个参数
'l2': 1e-9, # l2正则化 'l2': 1e-9, # l2正则化
'activate_classify': 'softmax', # 最后一个layer, 即分类激活函数 'activate_classify': 'softmax', # 最后一个layer, 即分类激活函数
'loss': 'categorical_crossentropy', # 损失函数 'loss': 'categorical_crossentropy', # 损失函数
@ -66,14 +67,14 @@ def train(hyper_parameters=None, rate=1.0):
'path_hyper_parameters': path_hyper_parameters, # 模型(包括embedding),超参数地址, 'path_hyper_parameters': path_hyper_parameters, # 模型(包括embedding),超参数地址,
'path_fineture': path_fineture, # 保存embedding trainable地址, 例如字向量、词向量、bert向量等 'path_fineture': path_fineture, # 保存embedding trainable地址, 例如字向量、词向量、bert向量等
}, },
'embedding': {'layer_indexes': [i for i in range(25)] + [-i for i in range(25)], # bert/xlnet取的层数,包括embedding层0其他是正常的层 'embedding': {'layer_indexes': [-1], # bert/xlnet取的层数,包括embedding层0其他是正常的层
# 'corpus_path': '', # embedding预训练数据地址,不配则会默认取conf里边默认的地址, keras-bert可以加载谷歌版bert,百度版ernie(需转换https://github.com/ArthurRizar/tensorflow_ernie),哈工大版bert-wwm(tf框架https://github.com/ymcui/Chinese-BERT-wwm) # 'corpus_path': '', # embedding预训练数据地址,不配则会默认取conf里边默认的地址, keras-bert可以加载谷歌版bert,百度版ernie(需转换https://github.com/ArthurRizar/tensorflow_ernie),哈工大版bert-wwm(tf框架https://github.com/ymcui/Chinese-BERT-wwm)
'xlnet_embed':{'attention_type': 'bi', # or 'uni' 'xlnet_embed':{'attention_type': 'bi', # or 'uni'
'memory_len': 0, 'memory_len': 32,
'target_len': 32,}, 'target_len': 32},
}, },
'data':{'train_data': path_baidu_qa_2019_train, # 训练数据 'data':{'train_data': path_baidu_qa_2019_train, # 训练数据
'val_data': path_baidu_qa_2019_valid # 验证数据 'val_data': path_baidu_qa_2019_valid # 验证数据
}, },
} }