111 lines
4.3 KiB
Python
111 lines
4.3 KiB
Python
|
# encoding:utf-8
|
|||
|
"""
|
|||
|
@author = 'XXY'
|
|||
|
@contact = '529379497@qq.com'
|
|||
|
@researchFie1d = 'NLP DL ML'
|
|||
|
@date= '2017/12/21 10:18'
|
|||
|
"""
|
|||
|
import json, os
|
|||
|
import jieba.posseg as pseg
|
|||
|
import logging
|
|||
|
from data_helper import *
|
|||
|
import numpy as np
|
|||
|
from sklearn.model_selection import train_test_split
|
|||
|
import pandas as pd
|
|||
|
import tensorflow as tf
|
|||
|
from tensorflow.contrib import learn
|
|||
|
from sklearn.metrics import classification_report, accuracy_score
|
|||
|
|
|||
|
logging.getLogger().setLevel(logging.INFO)
|
|||
|
|
|||
|
|
|||
|
def make_submission(file, prediction, encoding):
|
|||
|
valid_id = []
|
|||
|
label2int = {u"人类作者": 0, u"机器作者": 1, u"机器翻译": 2, u"自动摘要": 3}
|
|||
|
int2label = {0.0: u"人类作者", 1.0: u"机器作者", 2.0: u"机器翻译", 3.0: u"自动摘要"}
|
|||
|
|
|||
|
for line in open('./data/validation.txt'):
|
|||
|
text = json.loads(line.strip())
|
|||
|
valid_id.append(text['id'])
|
|||
|
result = pd.DataFrame({'id': valid_id, 'label': prediction})
|
|||
|
result['label'] = result['label'].apply(lambda x: int2label[x])
|
|||
|
print(result.head())
|
|||
|
result.to_csv(file, header=None, index=None, encoding=encoding)
|
|||
|
|
|||
|
|
|||
|
def predict_unseen_data():
|
|||
|
X_char = []
|
|||
|
X_punc = []
|
|||
|
y = []
|
|||
|
|
|||
|
# 读取字符特征
|
|||
|
print "读取char特征"
|
|||
|
with open('./data/validation_char.txt') as f:
|
|||
|
for line in f:
|
|||
|
temp = line.strip().split('\t')
|
|||
|
text = temp[0][1:-1].split(',')
|
|||
|
label = temp[1]
|
|||
|
X_char.append(text)
|
|||
|
y.append(label)
|
|||
|
|
|||
|
# 读取标点符号结构特征
|
|||
|
print "读取punc特征"
|
|||
|
with open('./data/validation_punc.txt') as f:
|
|||
|
for line in f:
|
|||
|
temp = line.strip().split('\t')
|
|||
|
text = temp[0][1:-1].split(',')
|
|||
|
X_punc.append(text)
|
|||
|
|
|||
|
print "读取全连接层特征"
|
|||
|
X_validation_fc_feat = pd.read_csv('./data/validation_fc_feat_norm_200.txt', sep=',')
|
|||
|
print"数据加载完毕!"
|
|||
|
|
|||
|
|
|||
|
X_validation_char = np.array(X_char)
|
|||
|
X_validation_punc = np.array(X_punc)
|
|||
|
X_validation_fc_feat = X_validation_fc_feat.values
|
|||
|
params = json.loads(open('./config/cnn_parameters.json').read())
|
|||
|
print "validation数据大小:", X_validation_char.shape, X_validation_punc.shape
|
|||
|
print "validation集加载完毕!"
|
|||
|
|
|||
|
checkpoint_dir = 'models/cnn_models/trained_model_1528419951/'
|
|||
|
if not checkpoint_dir.endswith('/'):
|
|||
|
checkpoint_dir += '/'
|
|||
|
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir + 'checkpoints') # 加载最近保存的模型
|
|||
|
# checkpoint_file = '/home/h325/data/Xxy/SMP_NEW/models/cnn_models/trained_model_1528257857/checkpoints/model-12400'
|
|||
|
logging.critical('Loaded the trained model: {}'.format(checkpoint_file))
|
|||
|
|
|||
|
graph = tf.Graph()
|
|||
|
with graph.as_default():
|
|||
|
session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
|
|||
|
sess = tf.Session(config=session_conf)
|
|||
|
|
|||
|
with sess.as_default():
|
|||
|
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
|
|||
|
saver.restore(sess, checkpoint_file)
|
|||
|
input_x_char = graph.get_operation_by_name("input/input_x_char").outputs[0]
|
|||
|
input_x_punc = graph.get_operation_by_name("input/input_x_punc").outputs[0]
|
|||
|
input_x_fc_feat = graph.get_operation_by_name('input/input_x_fc_feat').outputs[0]
|
|||
|
dropout_keep_prob = graph.get_operation_by_name("dropout/dropout_keep_prob").outputs[0]
|
|||
|
predictions = graph.get_operation_by_name("output/predictions").outputs[0]
|
|||
|
|
|||
|
dev_predictions = []
|
|||
|
for i in range(int(len(X_validation_char) / params['batch_size']) + 1):
|
|||
|
start_index = i * params['batch_size']
|
|||
|
end_index = min((i + 1) * params['batch_size'], len(X_validation_char))
|
|||
|
X_validation_char_batch = X_validation_char[start_index: end_index]
|
|||
|
X_validation_punc_batch = X_validation_punc[start_index: end_index]
|
|||
|
X_validation_fc_feat_batch = X_validation_fc_feat[start_index: end_index]
|
|||
|
prediction= sess.run(predictions, {input_x_char: X_validation_char_batch, input_x_punc:X_validation_punc_batch,
|
|||
|
input_x_fc_feat:X_validation_fc_feat_batch, dropout_keep_prob: 1.0})
|
|||
|
dev_predictions = np.concatenate([dev_predictions, prediction])
|
|||
|
|
|||
|
# for test_batch in test_batches:
|
|||
|
# X_test_batch = test_batch
|
|||
|
# prediction= sess.run(predictions, {input_x1: X_test_batch, dropout_keep_prob: 1.0})
|
|||
|
# dev_predictions = np.concatenate([dev_predictions, prediction])
|
|||
|
make_submission('results/cnn_char_punc_fc_feat_result4.csv', dev_predictions, encoding='utf-8')
|
|||
|
logging.critical('The prediction is complete')
|
|||
|
if __name__ == '__main__':
|
|||
|
predict_unseen_data()
|