nlp_xiaojiang/ChatBot/chatbot_search/chatbot_sentence_vec_by_bert.py
2019-08-24 23:40:02 +08:00

88 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time :2019/5/12 13:16
# @author :Mo
# @function :chatbot based search, encode sentence_vec by bert
def _search_top_k(query_vec, doc_vecs, topk=5):
    """Rank candidate sentence vectors against a single query vector.

    Each candidate's score is its dot product with the query divided by the
    candidate's own L2 norm. The query norm is not divided out, so the scores
    are not true cosine similarities, but the ranking is identical because the
    query norm is the same for every candidate.

    Args:
        query_vec: 1-D vector for the query sentence.
        doc_vecs: 2-D array-like with one row per candidate sentence.
        topk: maximum number of indices to return.

    Returns:
        A ``(scores, topk_idx)`` tuple: the 1-D array of scores for every
        candidate, and the indices of the ``topk`` highest-scoring
        candidates, best first.
    """
    import numpy as np
    query_vec = np.asarray(query_vec)
    doc_vecs = np.asarray(doc_vecs)
    scores = np.sum(query_vec * doc_vecs, axis=1) / np.linalg.norm(doc_vecs, axis=1)
    topk_idx = np.argsort(scores)[::-1][:topk]
    return scores, topk_idx


def chatbot_sentence_vec_by_bert_own():
    """Interactive retrieval demo using the project's own Keras BERT encoder.

    Each corpus entry returned by ``txtRead`` is split on '\\t', and its
    first field is taken as a standard question. Only the first 100 entries
    are used. Those questions are encoded with ``KerasBertVector`` and the
    matrix is written to "doc_vecs_chicken_and_gossip" with ``np.savetxt``.

    The function then answers a built-in sample query, followed by queries
    read from stdin. For each query it prints the ``topk`` best-matching
    corpus entries with their scores. The loop ends on EOF or Ctrl-C.
    """
    from FeatureProject.bert.extract_keras_bert_feature import KerasBertVector
    from conf.path_config import chicken_and_gossip_path
    from utils.text_tools import txtRead
    import numpy as np
    # Load the data; only the first 100 standard questions are used here.
    topk = 5
    matrix_ques_save_path = "doc_vecs_chicken_and_gossip"
    questions = txtRead(chicken_and_gossip_path, encodeType='utf-8')
    ques = [line.split('\t')[0] for line in questions[0:100]]
    # Encode the standard questions into BERT sentence vectors (once, as an ndarray).
    bert_vector = KerasBertVector()
    ques_basic_vecs = np.array(bert_vector.bert_encode(ques))
    # Online, generate this file once offline and just load it:
    #     ques_basic_vecs = np.loadtxt(matrix_ques_save_path)
    np.savetxt(matrix_ques_save_path, ques_basic_vecs)

    def answer(query, show_vec=False):
        # Encode one query, rank the stored questions and print the top matches.
        query_bert_vec = np.array(bert_vector.bert_encode([query])[0])
        if show_vec:
            print(query_bert_vec)
        # A dense matrix product is fast; an ANN index (e.g. annoy) would be faster still.
        qq_score, topk_idx = _search_top_k(query_bert_vec, ques_basic_vecs, topk)
        for idx in topk_idx:
            print('小姜机器人回答检索: %s\t%s' % (qq_score[idx], questions[idx]))

    answer("小姜机器人是什么", show_vec=True)
    while True:
        print("你的问题:")
        try:
            query = input()
        except (EOFError, KeyboardInterrupt):
            # End of input or Ctrl-C: leave the loop instead of crashing.
            break
        answer(query)
def chatbot_sentence_vec_by_bert_bertasserver():
    """Interactive retrieval demo using a bert-as-service server on localhost.

    Each corpus entry returned by ``txtRead`` is split on '\\t', and its
    first field is taken as a standard question. Only the first 100 entries
    are used. Those questions are encoded through ``BertClient`` and the
    matrix is written to "doc_vecs_chicken_and_gossip" with ``np.savetxt``.

    The function then reads queries from stdin and prints the ``topk``
    best-matching corpus entries with their scores. Empty queries are
    skipped. The loop ends on EOF or Ctrl-C, and the client connection is
    always closed.
    """
    from conf.path_config import chicken_and_gossip_path
    from bert_serving.client import BertClient
    from utils.text_tools import txtRead
    import numpy as np
    topk = 5
    matrix_ques_save_path = "doc_vecs_chicken_and_gossip"
    questions = txtRead(chicken_and_gossip_path, encodeType='utf-8')
    ques = [line.split('\t')[0] for line in questions[0:100]]
    bc = BertClient(ip='localhost')
    try:
        doc_vecs = np.array(bc.encode(ques))
        np.savetxt(matrix_ques_save_path, doc_vecs)
        # matrix_ques = np.loadtxt(matrix_ques_save_path)
        while True:
            try:
                query = input('你问: ')
            except (EOFError, KeyboardInterrupt):
                # End of input or Ctrl-C: leave the loop instead of crashing.
                break
            if not query.strip():
                # bert-as-service rejects empty strings, so don't send them.
                continue
            # Convert the query's own vector (the original referenced an undefined name here).
            query_vec = np.array(bc.encode([query])[0])
            # Dot product divided by each candidate's norm; the ranking matches cosine similarity.
            score = np.sum(query_vec * doc_vecs, axis=1) / np.linalg.norm(doc_vecs, axis=1)
            topk_idx = np.argsort(score)[::-1][:topk]
            for idx in topk_idx:
                print('小姜机器人回答: %s\t%s' % (score[idx], questions[idx]))
    finally:
        # Release the client's connections to the server.
        bc.close()
if __name__ == "__main__":
    # Default demo: sentence vectors from the in-process Keras BERT encoder.
    # To query a running bert-as-service server instead, call
    # chatbot_sentence_vec_by_bert_bertasserver() here.
    chatbot_sentence_vec_by_bert_own()
# result
# 小姜机器人是什么
# Tokens: ['[CLS]', '小', '姜', '机', '器', '人', '是', '什', '么', '[SEP]']
# (1, 32, 768)
# [CLS] [768] [1.0393640995025635, -0.31394684314727783, -0.08567211031913757, -0.12281288206577301,