sigmoid
This commit is contained in:
parent
10bf265208
commit
3bb53e028d
@ -41,7 +41,7 @@
|
||||
- baidu_qa_2019(百度qa问答语料,只取title作为分类样本,17个类,有一个是空'',已经压缩上传)
|
||||
- baike_qa_train.csv
|
||||
- baike_qa_valid.csv
|
||||
-byte_multi_news(今日头条2018新闻标题多标签语料,1070个标签,fate233爬取, 地址为: [byte_multi_news](https://github.com/fate233/toutiao-multilevel-text-classfication-dataset))
|
||||
- byte_multi_news(今日头条2018新闻标题多标签语料,1070个标签,fate233爬取, 地址为: [byte_multi_news](https://github.com/fate233/toutiao-multilevel-text-classfication-dataset))
|
||||
-labels.csv
|
||||
-train.csv
|
||||
-valid.csv
|
||||
|
@ -52,7 +52,7 @@ def pred_input(path_hyper_parameter=path_hyper_parameters):
|
||||
pre = pt.prereocess_idx(pred[0])
|
||||
ls_nulti = []
|
||||
for ls in pre[0]:
|
||||
if ls[1] >= 0.1:
|
||||
if ls[1] >= 0.5:
|
||||
ls_nulti.append(ls)
|
||||
print(pre[0])
|
||||
print(ls_nulti)
|
||||
@ -71,7 +71,7 @@ def pred_input(path_hyper_parameter=path_hyper_parameters):
|
||||
pre = pt.prereocess_idx(pred[0])
|
||||
ls_nulti = []
|
||||
for ls in pre[0]:
|
||||
if ls[1] >= 0.1:
|
||||
if ls[1] >= 0.5:
|
||||
ls_nulti.append(ls)
|
||||
print(pre[0])
|
||||
print(ls_nulti)
|
||||
|
@ -42,7 +42,7 @@ def train(hyper_parameters=None, rate=1.0):
|
||||
'patience': 3, # 早停,2-3就好
|
||||
'lr': 1e-3, # 学习率, bert取5e-5, 其他取1e-3, 对训练会有比较大的影响, 如果准确率一直上不去,可以考虑调这个参数
|
||||
'l2': 1e-9, # l2正则化
|
||||
'activate_classify': 'softmax', # 最后一个layer, 即分类激活函数
|
||||
'activate_classify': 'sigmoid', # 'softmax', # 最后一个layer, 即分类激活函数
|
||||
'loss': 'categorical_crossentropy', # 损失函数, 可能有问题, 可以自己定义
|
||||
'metrics': 'top_k_categorical_accuracy', # 1070个类, 太多了先用topk, 这里数据k设置为最大:33
|
||||
# 'metrics': 'categorical_accuracy', # 保存更好模型的评价标准
|
||||
|
Loading…
Reference in New Issue
Block a user