From 38c814ae24437da33b6a7b6a1e5d0792a408e2ab Mon Sep 17 00:00:00 2001 From: yongzhuo <2714618994@qq.com> Date: Thu, 2 Sep 2021 16:15:19 +0800 Subject: [PATCH] fix xlnet config --- README.md | 2 +- keras_textclassification/conf/path_config.py | 6 ++++- keras_textclassification/m00_Xlnet/graph.py | 6 +++-- keras_textclassification/m00_Xlnet/train.py | 23 ++++++++++---------- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 751fad3..7ffde18 100644 --- a/README.md +++ b/README.md @@ -79,7 +79,7 @@ step3: goto # Train&Usage(调用) and Predict&Usage(调用) keras-bert还可以加载百度版ernie(需转换,[https://github.com/ArthurRizar/tensorflow_ernie](https://github.com/ArthurRizar/tensorflow_ernie)), 哈工大版bert-wwm(tf框架,[https://github.com/ymcui/Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm)) - albert_base_zh/(brightmart训练的albert, 地址为https://github.com/brightmart/albert_zh) - - chinese_xlnet_mid_L-24_H-768_A-12/(哈工大预训练的中文xlnet模型[https://github.com/ymcui/Chinese-PreTrained-XLNet],24层) + - chinese_xlnet_base_L-12_H-768_A-12/(哈工大预训练的中文xlnet模型[https://github.com/ymcui/Chinese-PreTrained-XLNet],12层) - term_char.txt(已经上传, 项目中已全, wiki字典, 还可以用新华字典什么的) - term_word.txt(未上传, 项目中只有部分, 可参考词向量的) - w2v_model_merge_short.vec(未上传, 项目中只有部分, 词向量, 可以用自己的) diff --git a/keras_textclassification/conf/path_config.py b/keras_textclassification/conf/path_config.py index c4e58a5..3a01502 100644 --- a/keras_textclassification/conf/path_config.py +++ b/keras_textclassification/conf/path_config.py @@ -14,11 +14,15 @@ path_root = path_root.replace('\\', '/') path_embedding_random_char = path_root + '/data/embeddings/term_char.txt' path_embedding_random_word = path_root + '/data/embeddings/term_word.txt' path_embedding_bert = path_root + '/data/embeddings/chinese_L-12_H-768_A-12/' -path_embedding_xlnet = path_root + '/data/embeddings/chinese_xlnet_mid_L-24_H-768_A-12/' +path_embedding_xlnet = path_root + '/data/embeddings/chinese_xlnet_base_L-12_H-768_A-12/' path_embedding_albert = path_root + '/data/embeddings/albert_base_zh' path_embedding_vector_word2vec_char = path_root + '/data/embeddings/w2v_model_wiki_char.vec' path_embedding_vector_word2vec_word = path_root + '/data/embeddings/w2v_model_merge_short.vec' +# classify data of tnews +path_tnews_train = path_root + '/data/tnews/train.csv' +path_tnews_valid = path_root + '/data/tnews/dev.csv' + # classify data of baidu qa 2019 path_baidu_qa_2019_train = path_root + '/data/baidu_qa_2019/baike_qa_train.csv' path_baidu_qa_2019_valid = path_root + '/data/baidu_qa_2019/baike_qa_valid.csv' diff --git a/keras_textclassification/m00_Xlnet/graph.py b/keras_textclassification/m00_Xlnet/graph.py index dbaa739..31a41ad 100644 --- a/keras_textclassification/m00_Xlnet/graph.py +++ b/keras_textclassification/m00_Xlnet/graph.py @@ -37,7 +37,9 @@ class XlnetGraph(graph): super().create_model(hyper_parameters) embedding_output = self.word_embedding.output # x = embedding_output - x = Lambda(lambda x : x[:, -2:-1, :])(embedding_output) # 获取CLS + # x = Lambda(lambda x: x[:, -1:, :])(embedding_output) # 获取CLS + # x = Lambda(lambda x: x[:, -1:])(embedding_output) # 获取CLS + x = embedding_output # 直接output # # text cnn # bert_output_emmbed = SpatialDropout1D(rate=self.dropout)(embedding_output) # concat_out = [] @@ -52,7 +54,7 @@ class XlnetGraph(graph): # x = GlobalMaxPooling1D(name='TextCNN_MaxPool1D_{}'.format(index))(x) # concat_out.append(x) # x = Concatenate(axis=1)(concat_out) - # x = Dropout(self.dropout)(x) + x = Dropout(self.dropout)(x) x = Flatten()(x) # 最后就是softmax dense_layer = Dense(self.label, activation=self.activate_classify)(x) diff --git a/keras_textclassification/m00_Xlnet/train.py b/keras_textclassification/m00_Xlnet/train.py index 834919c..118d1f6 100644 --- a/keras_textclassification/m00_Xlnet/train.py +++ b/keras_textclassification/m00_Xlnet/train.py @@ -20,6 +20,7 @@ sys.path.append(project_path) from keras_textclassification.conf.path_config import path_model, path_fineture, path_model_dir, path_hyper_parameters # 训练验证数据地址 from keras_textclassification.conf.path_config import path_baidu_qa_2019_train, path_baidu_qa_2019_valid +# from keras_textclassification.conf.path_config import path_tnews_train, path_tnews_valid # 数据预处理, 删除文件目录下文件 from keras_textclassification.data_preprocess.text_preprocess import PreprocessText, delete_file # 模型图 @@ -43,18 +44,18 @@ def train(hyper_parameters=None, rate=1.0): 'trainable': True, # 暂不支持微调True, embedding是静态的还是动态的, 即控制可不可以微调 'level_type': 'char', # 级别, 最小单元, 字/词, 填 'char' or 'word' 'embedding_type': 'xlnet', # 级别, 嵌入类型, 还可以填'xlnet'、'random'、 'bert'、 'albert' or 'word2vec" - 'gpu_memory_fraction': 0.76, #gpu使用率 - 'model': {'label': 17, # 类别数 - 'batch_size': 2, # 批处理尺寸, 感觉原则上越大越好,尤其是样本不均衡的时候, batch_size设置影响比较大 + 'gpu_memory_fraction': 0.8, #gpu使用率 + 'model': {'label': 10, # 类别数 + 'batch_size': 16, # 批处理尺寸, 感觉原则上越大越好,尤其是样本不均衡的时候, batch_size设置影响比较大 'filters': [2, 3, 4, 5], # 卷积核尺寸 'filters_num': 300, # 卷积个数 text-cnn:300-600 'channel_size': 1, # CNN通道数 - 'dropout': 0.5, # 随机失活, 概率 + 'dropout': 0.1, # 随机失活, 概率 'decay_step': 1000, # 学习率衰减step, 每N个step衰减一次 - 'decay_rate': 0.9, # 学习率衰减系数, 乘法 + 'decay_rate': 0.99, # 学习率衰减系数, 乘法 'epochs': 20, # 训练最大轮次 'patience': 3, # 早停,2-3就好 - 'lr': 5e-5, # 学习率, 对训练会有比较大的影响, 如果准确率一直上不去,可以考虑调这个参数 + 'lr': 1e-4, # 学习率, 对训练会有比较大的影响, 如果准确率一直上不去,可以考虑调这个参数 'l2': 1e-9, # l2正则化 'activate_classify': 'softmax', # 最后一个layer, 即分类激活函数 'loss': 'categorical_crossentropy', # 损失函数 @@ -66,14 +67,14 @@ def train(hyper_parameters=None, rate=1.0): 'path_hyper_parameters': path_hyper_parameters, # 模型(包括embedding),超参数地址, 'path_fineture': path_fineture, # 保存embedding trainable地址, 例如字向量、词向量、bert向量等 }, - 'embedding': {'layer_indexes': [i for i in range(25)] + [-i for i in range(25)], # bert/xlnet取的层数,包括embedding层0,其他是正常的层 + 'embedding': {'layer_indexes': [-1], # bert/xlnet取的层数,包括embedding层0,其他是正常的层 # 'corpus_path': '', # embedding预训练数据地址,不配则会默认取conf里边默认的地址, keras-bert可以加载谷歌版bert,百度版ernie(需转换,https://github.com/ArthurRizar/tensorflow_ernie),哈工大版bert-wwm(tf框架,https://github.com/ymcui/Chinese-BERT-wwm) 'xlnet_embed':{'attention_type': 'bi', # or 'uni' - 'memory_len': 0, - 'target_len': 32,}, + 'memory_len': 32, + 'target_len': 32}, }, - 'data':{'train_data': path_baidu_qa_2019_train, # 训练数据 - 'val_data': path_baidu_qa_2019_valid # 验证数据 + 'data':{'train_data': path_baidu_qa_2019_train, # 训练数据 + 'val_data': path_baidu_qa_2019_valid # 验证数据 }, }