fix xlnet config
parent f215caa693
commit 38c814ae24
@@ -79,7 +79,7 @@ step3: goto # Train&Usage(invoke) and Predict&Usage(invoke)
 keras-bert can also load Baidu's ERNIE (conversion required, [https://github.com/ArthurRizar/tensorflow_ernie](https://github.com/ArthurRizar/tensorflow_ernie)),
 and HIT's BERT-wwm (TF framework, [https://github.com/ymcui/Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm))
 - albert_base_zh/ (the ALBERT trained by brightmart, from https://github.com/brightmart/albert_zh)
-- chinese_xlnet_mid_L-24_H-768_A-12/ (the Chinese XLNet pre-trained by HIT [https://github.com/ymcui/Chinese-PreTrained-XLNet], 24 layers)
+- chinese_xlnet_base_L-12_H-768_A-12/ (the Chinese XLNet pre-trained by HIT [https://github.com/ymcui/Chinese-PreTrained-XLNet], 12 layers)
 - term_char.txt (already uploaded, complete in this project; a wiki character vocabulary, a dictionary such as Xinhua would also do)
 - term_word.txt (not uploaded, only partial in this project; the word-vector vocabulary can serve as a reference)
 - w2v_model_merge_short.vec (not uploaded, only partial in this project; word vectors, you can use your own)
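The hunk above swaps the 24-layer mid checkpoint for the 12-layer base one in the README's resource list. As a quick sanity check that a downloaded checkpoint directory is complete, here is a minimal sketch (not part of the repo; the file names assume the usual Chinese-PreTrained-XLNet release layout):

```python
# Sketch: verify an unpacked XLNet checkpoint directory looks complete.
# File names assume the Chinese-PreTrained-XLNet release layout.
import os

xlnet_dir = 'data/embeddings/chinese_xlnet_base_L-12_H-768_A-12'
expected = ['xlnet_config.json', 'spiece.model', 'xlnet_model.ckpt.index']
missing = [f for f in expected if not os.path.exists(os.path.join(xlnet_dir, f))]
print('ok' if not missing else 'missing files: {}'.format(missing))
```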
@@ -14,11 +14,15 @@ path_root = path_root.replace('\\', '/')
 path_embedding_random_char = path_root + '/data/embeddings/term_char.txt'
 path_embedding_random_word = path_root + '/data/embeddings/term_word.txt'
 path_embedding_bert = path_root + '/data/embeddings/chinese_L-12_H-768_A-12/'
-path_embedding_xlnet = path_root + '/data/embeddings/chinese_xlnet_mid_L-24_H-768_A-12/'
+path_embedding_xlnet = path_root + '/data/embeddings/chinese_xlnet_base_L-12_H-768_A-12/'
 path_embedding_albert = path_root + '/data/embeddings/albert_base_zh'
 path_embedding_vector_word2vec_char = path_root + '/data/embeddings/w2v_model_wiki_char.vec'
 path_embedding_vector_word2vec_word = path_root + '/data/embeddings/w2v_model_merge_short.vec'
 
+# classify data of tnews
+path_tnews_train = path_root + '/data/tnews/train.csv'
+path_tnews_valid = path_root + '/data/tnews/dev.csv'
+
 # classify data of baidu qa 2019
 path_baidu_qa_2019_train = path_root + '/data/baidu_qa_2019/baike_qa_train.csv'
 path_baidu_qa_2019_valid = path_root + '/data/baidu_qa_2019/baike_qa_valid.csv'
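Because both the checkpoint directory and the dataset paths come from this config module, a small pre-flight check can fail fast before training starts. A sketch only, using the names defined in the hunk above:

```python
# Sketch: fail fast if any configured path is missing before training.
import os
from keras_textclassification.conf.path_config import (
    path_embedding_xlnet, path_baidu_qa_2019_train, path_baidu_qa_2019_valid)

for p in (path_embedding_xlnet, path_baidu_qa_2019_train, path_baidu_qa_2019_valid):
    if not os.path.exists(p):
        raise FileNotFoundError('configured path is missing: {}'.format(p))
```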
@@ -37,7 +37,9 @@ class XlnetGraph(graph):
         super().create_model(hyper_parameters)
         embedding_output = self.word_embedding.output
         # x = embedding_output
-        x = Lambda(lambda x: x[:, -2:-1, :])(embedding_output)  # take CLS
+        # x = Lambda(lambda x: x[:, -1:, :])(embedding_output)  # take CLS
+        # x = Lambda(lambda x: x[:, -1:])(embedding_output)  # take CLS
+        x = embedding_output  # use the full output directly
         # # text cnn
         # bert_output_emmbed = SpatialDropout1D(rate=self.dropout)(embedding_output)
         # concat_out = []
@@ -52,7 +54,7 @@ class XlnetGraph(graph):
         #     x = GlobalMaxPooling1D(name='TextCNN_MaxPool1D_{}'.format(index))(x)
         #     concat_out.append(x)
         # x = Concatenate(axis=1)(concat_out)
-        # x = Dropout(self.dropout)(x)
+        x = Dropout(self.dropout)(x)
         x = Flatten()(x)
         # finally the softmax classifier
         dense_layer = Dense(self.label, activation=self.activate_classify)(x)
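XLNet places its CLS token at the end of the sequence, which is why the commented alternatives slice from the tail: the removed `-2:-1` slice actually grabbed the token before CLS, while `-1:` addresses CLS itself. The commit instead keeps the whole sequence and flattens it after dropout. A minimal, self-contained sketch of both pooling choices, assuming standalone Keras (as this repo uses) and illustrative shapes:

```python
# Sketch of the two pooling options above; shapes are illustrative.
from keras.layers import Input, Lambda, Flatten, Dropout, Dense
from keras.models import Model

seq_len, hidden, n_classes = 32, 768, 10
seq = Input(shape=(seq_len, hidden))  # stands in for the XLNet sequence output

# Option A (commented out in the hunk): keep only the CLS position.
# XLNet appends CLS at the end of the sequence, so -1: selects it.
cls = Lambda(lambda t: t[:, -1:, :])(seq)              # (batch, 1, hidden)
model_cls = Model(seq, Dense(n_classes, activation='softmax')(Flatten()(cls)))

# Option B (what the commit keeps): the full sequence, then flatten.
x = Flatten()(Dropout(0.1)(seq))                       # (batch, seq_len * hidden)
model_full = Model(seq, Dense(n_classes, activation='softmax')(x))

print(model_cls.output_shape, model_full.output_shape)  # (None, 10) for both
```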
@@ -20,6 +20,7 @@ sys.path.append(project_path)
 from keras_textclassification.conf.path_config import path_model, path_fineture, path_model_dir, path_hyper_parameters
 # train/validation data paths
 from keras_textclassification.conf.path_config import path_baidu_qa_2019_train, path_baidu_qa_2019_valid
+# from keras_textclassification.conf.path_config import path_tnews_train, path_tnews_valid
 # data preprocessing; delete files under the directory
 from keras_textclassification.data_preprocess.text_preprocess import PreprocessText, delete_file
 # model graph
@@ -43,18 +44,18 @@ def train(hyper_parameters=None, rate=1.0):
         'trainable': True,  # fine-tuning (True) is not supported yet; whether the embedding is static or dynamic, i.e. whether it may be fine-tuned
         'level_type': 'char',  # granularity, the smallest unit, char/word; fill in 'char' or 'word'
         'embedding_type': 'xlnet',  # embedding type; can be 'xlnet', 'random', 'bert', 'albert' or 'word2vec'
-        'gpu_memory_fraction': 0.76,  # GPU memory fraction
-        'model': {'label': 17,  # number of classes
-                  'batch_size': 2,  # batch size; in principle the bigger the better, especially with imbalanced samples; this setting has a big impact
+        'gpu_memory_fraction': 0.8,  # GPU memory fraction
+        'model': {'label': 10,  # number of classes
+                  'batch_size': 16,  # batch size; in principle the bigger the better, especially with imbalanced samples; this setting has a big impact
                   'filters': [2, 3, 4, 5],  # convolution kernel sizes
                   'filters_num': 300,  # number of filters, text-cnn: 300-600
                   'channel_size': 1,  # number of CNN channels
-                  'dropout': 0.5,  # dropout probability
+                  'dropout': 0.1,  # dropout probability
                   'decay_step': 1000,  # learning-rate decay step, decay once every N steps
-                  'decay_rate': 0.9,  # learning-rate decay factor, multiplicative
+                  'decay_rate': 0.99,  # learning-rate decay factor, multiplicative
                   'epochs': 20,  # maximum number of training epochs
                   'patience': 3,  # early stopping, 2-3 is fine
-                  'lr': 5e-5,  # learning rate; has a large impact on training; if accuracy stalls, consider tuning this
+                  'lr': 1e-4,  # learning rate; has a large impact on training; if accuracy stalls, consider tuning this
                   'l2': 1e-9,  # l2 regularization
                   'activate_classify': 'softmax',  # activation of the last layer, i.e. the classification activation
                   'loss': 'categorical_crossentropy',  # loss function
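The learning-rate changes interact: assuming the usual exponential schedule `lr * decay_rate ** (step / decay_step)` (an assumption about how the repo applies `decay_step`/`decay_rate`), raising `decay_rate` from 0.9 to 0.99 keeps the now-larger learning rate high for far longer:

```python
# Back-of-the-envelope comparison of the old and new schedules,
# assuming exponential decay: lr * decay_rate ** (step / decay_step).
def effective_lr(lr, decay_rate, step, decay_step=1000):
    return lr * decay_rate ** (step / decay_step)

for step in (1000, 5000, 10000):
    old = effective_lr(5e-5, 0.90, step)
    new = effective_lr(1e-4, 0.99, step)
    print('step {:>6}: old {:.2e} -> new {:.2e}'.format(step, old, new))
```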
@@ -66,11 +67,11 @@ def train(hyper_parameters=None, rate=1.0):
         'path_hyper_parameters': path_hyper_parameters,  # path of the model (including embedding) hyper-parameters
         'path_fineture': path_fineture,  # path to save the trainable embedding, e.g. char vectors, word vectors, bert vectors
         },
-        'embedding': {'layer_indexes': [i for i in range(25)] + [-i for i in range(25)],  # which bert/xlnet layers to take; layer 0 is the embedding layer, the rest are normal layers
+        'embedding': {'layer_indexes': [-1],  # which bert/xlnet layers to take; layer 0 is the embedding layer, the rest are normal layers
                       # 'corpus_path': '',  # pretrained embedding path; if unset, the default from conf is used; keras-bert can load Google's BERT, Baidu's ERNIE (conversion required, https://github.com/ArthurRizar/tensorflow_ernie) and HIT's BERT-wwm (TF framework, https://github.com/ymcui/Chinese-BERT-wwm)
                       'xlnet_embed': {'attention_type': 'bi',  # or 'uni'
-                                      'memory_len': 0,
-                                      'target_len': 32,},
+                                      'memory_len': 32,
+                                      'target_len': 32},
                       },
         'data': {'train_data': path_baidu_qa_2019_train,  # training data
                  'val_data': path_baidu_qa_2019_valid  # validation data
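The `xlnet_embed` block (`attention_type`, `memory_len`, `target_len`) maps onto keras-xlnet's checkpoint loader; with `memory_len` at 0 the Transformer-XL recurrence memory is effectively disabled, which this commit fixes by matching it to the 32-token target length. A sketch of how these values typically reach the loader (paths are placeholders; the call follows the keras-xlnet README):

```python
# Sketch: how the xlnet_embed settings typically reach keras-xlnet.
# Paths are placeholders; the call follows the keras-xlnet README.
from keras_xlnet import load_trained_model_from_checkpoint, ATTENTION_TYPE_BI

xlnet_model = load_trained_model_from_checkpoint(
    config_path='chinese_xlnet_base_L-12_H-768_A-12/xlnet_config.json',
    checkpoint_path='chinese_xlnet_base_L-12_H-768_A-12/xlnet_model.ckpt',
    batch_size=16,          # 'batch_size' from the model block above
    memory_len=32,          # was 0, i.e. no recurrence memory; now 32
    target_len=32,          # tokens processed per step
    in_train_phase=False,   # build the inference-style graph
    attention_type=ATTENTION_TYPE_BI,  # 'bi' in the config above
)
```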