Keras-TextClassification/keras_textclassification/m06_TextDCNN/predict.py
2019-10-13 09:08:31 +08:00

144 lines
5.3 KiB
Python

# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time :2019/6/12 14:11
# @author :Mo
# @function :
# Linux compatibility: append the directory three levels above this file to
# sys.path so that the `keras_textclassification` imports below resolve when
# this file is executed directly as a script.
import pathlib
import sys
import os
project_path = str(pathlib.Path(os.path.abspath(__file__)).parent.parent.parent)
sys.path.append(project_path)
# Paths (model files, fine-tune output, model directory, saved hyper-parameters)
from keras_textclassification.conf.path_config import path_model, path_fineture, path_model_dir, path_hyper_parameters
# Training / validation data paths
from keras_textclassification.conf.path_config import path_baidu_qa_2019_train, path_baidu_qa_2019_valid
# Data preprocessing and JSON loading helpers
from keras_textclassification.data_preprocess.text_preprocess import PreprocessText, read_and_process, load_json
# Model graph
from keras_textclassification.m06_TextDCNN.graph import DCNNGraph as Graph
# Model evaluation
from sklearn.metrics import classification_report
# Timing
import time
import numpy as np
def pred_tet(path_hyper_parameter=path_hyper_parameters, path_test=None, rate=1.0):
    """Evaluate the saved model on a labelled data set and print a report.

    Loads the hyper-parameters and the model, predicts a label for every
    selected sample one at a time, then prints a scikit-learn classification
    report followed by the elapsed wall-clock time.

    Args:
        path_hyper_parameter: Path of the JSON file with the hyper-parameters.
        path_test: Optional path of the evaluation data. When truthy it
            replaces ``hyper_parameters['data']['val_data']``.
        rate: Fraction of the data set to evaluate (``1.0`` means all of it).

    Returns:
        None. All results are printed to stdout.
    """
    hyper_parameters = load_json(path_hyper_parameter)
    if path_test:  # evaluation data path supplied by the caller
        hyper_parameters['data']['val_data'] = path_test
    time_start = time.time()

    # Build the graph and load the trained weights.
    graph = Graph(hyper_parameters)
    print("graph init ok!")
    graph.load_model()
    print("graph load ok!")
    ra_ed = graph.word_embedding

    # Read the labelled evaluation data.
    pt = PreprocessText()
    y, x = read_and_process(hyper_parameters['data']['val_data'])

    # Keep only the requested fraction of the corpus.
    len_rate = int(len(y) * rate)
    # NOTE(review): the slice starts at index 1, so the first sample is never
    # evaluated. Kept as in the original; confirm whether element 0 is a
    # header row or whether this is an off-by-one.
    x = x[1:len_rate]
    y = y[1:len_rate]

    # The embedding type does not change between samples, so test it once.
    is_bert = hyper_parameters['embedding_type'] == 'bert'

    y_pred = []
    for count, x_one in enumerate(x, start=1):
        ques_embed = ra_ed.sentence2idx(x_one)
        if is_bert:
            # BERT embeddings yield two inputs; each is wrapped in a
            # one-element batch.
            x_val = [np.array([ques_embed[0]]), np.array([ques_embed[1]])]
        else:
            x_val = ques_embed
        pred = graph.predict(x_val)
        pre = pt.prereocess_idx(pred[0])
        label_pred = pre[0][0][0]
        if count % 1000 == 0:  # progress output every 1000 samples
            print(label_pred)
        y_pred.append(label_pred)
    print("data pred ok!")

    # Convert gold and predicted labels to their integer indices.
    label_to_index = pt.l2i_i2l['l2i']
    index_to_label = pt.l2i_i2l['i2l']
    index_y = [label_to_index[label] for label in y]
    index_pred = [label_to_index[label] for label in y_pred]

    # classification_report orders its rows by sorted label value, and
    # target_names must follow that same order. Iterating a bare set gives no
    # ordering guarantee, so sort explicitly to keep names on the right rows.
    present_indices = sorted(set(index_pred + index_y))
    target_names = [index_to_label[str(i)] for i in present_indices]

    report_predict = classification_report(index_y, index_pred,
                                           target_names=target_names, digits=9)
    print(report_predict)
    print("耗时:" + str(time.time() - time_start))
def pred_input(path_hyper_parameter=path_hyper_parameters):
    """Classify a built-in demo sentence, then sentences typed on stdin.

    Loads the hyper-parameters and the model, prints the prediction for a
    fixed demo sentence, and then repeatedly prompts for a sentence, printing
    its token encoding and its prediction. The loop ends cleanly when stdin
    reaches end-of-file or the user presses Ctrl-C at the prompt.

    Args:
        path_hyper_parameter: Path of the JSON file with the hyper-parameters.

    Returns:
        None. All results are printed to stdout.
    """
    hyper_parameters = load_json(path_hyper_parameter)
    pt = PreprocessText()

    # Build the graph and load the trained weights.
    graph = Graph(hyper_parameters)
    graph.load_model()
    ra_ed = graph.word_embedding
    is_bert = hyper_parameters['embedding_type'] == 'bert'

    def _classify(ques, echo_tokens=False):
        """Encode one sentence, run the model, and return the decoded result."""
        ques_embed = ra_ed.sentence2idx(ques)
        if echo_tokens:
            print(ques_embed)
        if is_bert:
            # BERT embeddings yield two inputs; each is wrapped in a
            # one-element batch.
            x_val = [np.array([ques_embed[0]]), np.array([ques_embed[1]])]
        else:
            x_val = ques_embed
        pred = graph.predict(x_val)
        # Map the model output back to labels.
        return pt.prereocess_idx(pred[0])

    # One fixed demo question before the interactive loop starts.
    print(_classify('我要打王者荣耀'))

    while True:
        print("请输入: ")
        try:
            ques = input()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C: leave the loop without a traceback.
            break
        print(_classify(ques, echo_tokens=True))
if __name__ == "__main__":
    # Evaluate on the validation set first. Keep rate=1 when working with the
    # sample data, otherwise very little data may be left.
    pred_tet(path_test=path_baidu_qa_2019_valid, rate=1)
    # Then classify sentences entered interactively.
    pred_input()
# precision recall f1-score support
#
# 文化 0.000000000 0.000000000 0.000000000 7
# 电脑 0.000000000 0.000000000 0.000000000 51
# 体育 0.000000000 0.000000000 0.000000000 5
# 娱乐 0.074074074 0.050000000 0.059701493 40
# 电子 0.000000000 0.000000000 0.000000000 8
# 育儿 0.000000000 0.000000000 0.000000000 5
# 汽车 0.000000000 0.000000000 0.000000000 5
# 烦恼 0.000000000 0.000000000 0.000000000 20
# 教育 0.144578313 0.206896552 0.170212766 58
# 游戏 0.000000000 0.000000000 0.000000000 92
# 社会 0.000000000 0.000000000 0.000000000 12
# 商业 0.000000000 0.000000000 0.000000000 35
# 健康 0.103448276 0.098360656 0.100840336 61
# 生活 0.128571429 0.734693878 0.218844985 49
#
# accuracy 0.125000000 448
# macro avg 0.032190864 0.077853649 0.039257113 448
# weighted avg 0.053479576 0.125000000 0.065033627 448