add encode-vecotr with faiss and annoy
This commit is contained in:
parent
d7ee3fd6cc
commit
be902799b6
18
ChatBot/chatbot_search/chatbot_bertwhite/README.md
Normal file
18
ChatBot/chatbot_search/chatbot_bertwhite/README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# chatbot_bertwhite
|
||||||
|
## 解释说明
|
||||||
|
- 代码说明:
|
||||||
|
- 1. bertWhiteConf.py 超参数配置, 地址、bert-white、索引工具等的超参数
|
||||||
|
- 2. bertWhiteTools.py 小工具, 主要是一些文档读写功能函数
|
||||||
|
- 3. bertWhiteTrain.py 主模块, 类似bert预训练模型编码
|
||||||
|
- 4. indexAnnoy.py annoy索引
|
||||||
|
- 5. indexFaiss.py faiss索引
|
||||||
|
- 6. mmr.py 最大边界相关法, 保证返回多样性
|
||||||
|
|
||||||
|
## 备注说明:
|
||||||
|
- 1. ***如果FQA标准问答对很少, 比如少于1w条数据, 建议不要用bert-white, 其与领域数据相关, 数据量太小会极大降低泛化性***;
|
||||||
|
- 2. 可以考虑small、tiny类小模型, 如果要加速推理;
|
||||||
|
- 3. annoy安装于linux必须有c++环境, 如gcc-c++, g++等, 只有gcc的话可以用faiss-cpu
|
||||||
|
- 4. 增量更新: 建议问题对增量更新/faiss-annoy索引全量更新
|
||||||
|
|
||||||
|
## 模型文件
|
||||||
|
- 1. 模型文件采用的是 ""
|
5
ChatBot/chatbot_search/chatbot_bertwhite/__init__.py
Normal file
5
ChatBot/chatbot_search/chatbot_bertwhite/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# !/usr/bin/python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @time : 2021/5/13 21:21
|
||||||
|
# @author : Mo
|
||||||
|
# @function:
|
67
ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteConf.py
Normal file
67
ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteConf.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
# !/usr/bin/python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @time : 2021/5/13 9:27
|
||||||
|
# @author : Mo
|
||||||
|
# @function: config of Bert-White
|
||||||
|
|
||||||
|
|
||||||
|
import platform
|
||||||
|
# 适配linux
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
# path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
|
path_root = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
sys.path.append(path_root)
|
||||||
|
print(path_root)
|
||||||
|
|
||||||
|
|
||||||
|
if platform.system().lower() == 'windows':
|
||||||
|
# BERT_DIR = "D:/soft_install/dataset/bert-model/chinese_L-12_H-768_A-12"
|
||||||
|
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roberta_L-4_H-312_A-12_K-104"
|
||||||
|
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roberta_L-6_H-384_A-12_K-128"
|
||||||
|
BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-4_H-312_A-12"
|
||||||
|
# BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-6_H-384_A-12"
|
||||||
|
else:
|
||||||
|
BERT_DIR = "/home/hemei/myzhuo/bert/chinese_L-12_H-768_A-12"
|
||||||
|
ee = 0
|
||||||
|
|
||||||
|
SAVE_DIR = path_root + "/bert_white"
|
||||||
|
print(SAVE_DIR)
|
||||||
|
if not os.path.exists(SAVE_DIR):
|
||||||
|
os.makedirs(SAVE_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
bert_white_config = {
|
||||||
|
# 预训练模型路径
|
||||||
|
"bert_dir": BERT_DIR,
|
||||||
|
"checkpoint_path": "bert_model.ckpt", # 预训练模型地址
|
||||||
|
"config_path": "bert_config.json",
|
||||||
|
"dict_path": "vocab.txt",
|
||||||
|
# 预测需要的文件路径
|
||||||
|
"save_dir": SAVE_DIR,
|
||||||
|
"path_docs_encode": "qa.docs.encode.npy",
|
||||||
|
"path_answers": "qa.answers.json",
|
||||||
|
"path_qa_idx": "qa.idx.json",
|
||||||
|
"path_config": "config.json",
|
||||||
|
"path_docs": "qa.docs.json",
|
||||||
|
# 索引构建的存储文件, 如 annoy/faiss
|
||||||
|
"path_index": "qa.docs.idx",
|
||||||
|
# 初始语料路径
|
||||||
|
"path_qa": "chicken_and_gossip.txt", # QA问答文件地址
|
||||||
|
# 超参数
|
||||||
|
"pre_tokenize": None,
|
||||||
|
"pooling": "cls-1", # ["first-last-avg", "last-avg", "cls", "pooler", "cls-2", "cls-3", "cls-1"]
|
||||||
|
"model": "bert", # bert4keras预训练模型类型
|
||||||
|
"n_components": 768, # 降维到 n_components
|
||||||
|
"n_cluster": 132, # annoy构建的簇类中心个数n_cluster, 越多效果越好, 计算量就越大
|
||||||
|
"batch_size": 32, # 批尺寸
|
||||||
|
"maxlen": 128, # 最大文本长度
|
||||||
|
"ues_white": False, # 是否使用白化
|
||||||
|
"use_annoy": False, # 是否使用annoy
|
||||||
|
"use_faiss": True, # 是否使用faiss
|
||||||
|
"verbose": True, # 是否显示编码过程日志-batch
|
||||||
|
|
||||||
|
"kernel": None, # bert-white编码后的参数, 可降维
|
||||||
|
"bias": None, # bert-white编码后的参数, 偏置bias
|
||||||
|
"qa_idx": None # 问题question到答案answer的id对应关系
|
||||||
|
}
|
76
ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteTools.py
Normal file
76
ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteTools.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# !/usr/bin/python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @time : 2021/5/13 21:24
|
||||||
|
# @author : Mo
|
||||||
|
# @function:
|
||||||
|
|
||||||
|
|
||||||
|
from typing import List, Dict, Union, Any
|
||||||
|
import logging as logger
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def txt_read(path: str, encoding: str = "utf-8") -> List[str]:
|
||||||
|
"""
|
||||||
|
Read Line of list form file
|
||||||
|
Args:
|
||||||
|
path: path of save file, such as "txt"
|
||||||
|
encoding: type of encoding, such as "utf-8", "gbk"
|
||||||
|
Returns:
|
||||||
|
dict of word2vec, eg. {"macadam":[...]}
|
||||||
|
"""
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
try:
|
||||||
|
file = open(path, "r", encoding=encoding)
|
||||||
|
while True:
|
||||||
|
line = file.readline().strip()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
lines.append(line)
|
||||||
|
file.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(str(e))
|
||||||
|
finally:
|
||||||
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
def txt_write(lines: List[str], path: str, model: str = "w", encoding: str = "utf-8"):
|
||||||
|
"""
|
||||||
|
Write Line of list to file
|
||||||
|
Args:
|
||||||
|
lines: lines of list<str> which need save
|
||||||
|
path: path of save file, such as "txt"
|
||||||
|
model: type of write, such as "w", "a+"
|
||||||
|
encoding: type of encoding, such as "utf-8", "gbk"
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
file = open(path, model, encoding=encoding)
|
||||||
|
file.writelines(lines)
|
||||||
|
file.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(jsons, json_path, indent=4):
|
||||||
|
"""
|
||||||
|
保存json,
|
||||||
|
:param json_: json
|
||||||
|
:param path: str
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
with open(json_path, 'w', encoding='utf-8') as fj:
|
||||||
|
fj.write(json.dumps(jsons, ensure_ascii=False, indent=indent))
|
||||||
|
fj.close()
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(path):
|
||||||
|
"""
|
||||||
|
获取json,只取第一行
|
||||||
|
:param path: str
|
||||||
|
:return: json
|
||||||
|
"""
|
||||||
|
with open(path, 'r', encoding='utf-8') as fj:
|
||||||
|
model_json = json.load(fj)
|
||||||
|
return model_json
|
||||||
|
|
367
ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteTrain.py
Normal file
367
ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteTrain.py
Normal file
@ -0,0 +1,367 @@
|
|||||||
|
# !/usr/bin/python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @time : 2021/4/15 21:59
|
||||||
|
# @author : Mo
|
||||||
|
# @function: encode of bert-whiteing
|
||||||
|
|
||||||
|
|
||||||
|
from __future__ import print_function, division, absolute_import, division, print_function
|
||||||
|
|
||||||
|
# 适配linux
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "./."))
|
||||||
|
sys.path.append(path_root)
|
||||||
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||||
|
print(path_root)
|
||||||
|
|
||||||
|
|
||||||
|
from bertWhiteTools import txt_read, txt_write, save_json, load_json
|
||||||
|
|
||||||
|
from bert4keras.models import build_transformer_model
|
||||||
|
from bert4keras.snippets import sequence_padding
|
||||||
|
from bert4keras.tokenizers import Tokenizer
|
||||||
|
from bert4keras.backend import keras, K
|
||||||
|
from keras.models import Model
|
||||||
|
|
||||||
|
from argparse import Namespace
|
||||||
|
# from tqdm import tqdm
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class NonMaskingLayer(keras.layers.Layer):
|
||||||
|
"""
|
||||||
|
fix convolutional 1D can"t receive masked input, detail: https://github.com/keras-team/keras/issues/4978
|
||||||
|
thanks for https://github.com/jacoxu
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.supports_masking = True
|
||||||
|
super(NonMaskingLayer, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
def build(self, input_shape):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute_mask(self, input, input_mask=None):
|
||||||
|
# do not pass the mask to the next layers
|
||||||
|
return None
|
||||||
|
|
||||||
|
def call(self, x, mask=None):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def get_output_shape_for(self, input_shape):
|
||||||
|
return input_shape
|
||||||
|
|
||||||
|
|
||||||
|
class BertWhiteModel:
|
||||||
|
def __init__(self, config=None):
|
||||||
|
""" 初始化超参数、加载预训练模型等 """
|
||||||
|
self.config = Namespace(**config)
|
||||||
|
self.load_pretrain_model()
|
||||||
|
self.eps = 1e-8
|
||||||
|
|
||||||
|
def transform_and_normalize(self, vecs, kernel=None, bias=None):
|
||||||
|
"""应用变换,然后标准化
|
||||||
|
"""
|
||||||
|
if not (kernel is None or bias is None):
|
||||||
|
vecs = (vecs + bias).dot(kernel)
|
||||||
|
norms = (vecs ** 2).sum(axis=1, keepdims=True) ** 0.5
|
||||||
|
return vecs / np.clip(norms, self.eps, np.inf)
|
||||||
|
|
||||||
|
def compute_kernel_bias(self, vecs):
|
||||||
|
"""计算kernel和bias
|
||||||
|
最后的变换:y = (x + bias).dot(kernel)
|
||||||
|
"""
|
||||||
|
mu = vecs.mean(axis=0, keepdims=True)
|
||||||
|
cov = np.cov(vecs.T)
|
||||||
|
u, s, vh = np.linalg.svd(cov)
|
||||||
|
W = np.dot(u, np.diag(1 / np.sqrt(s)))
|
||||||
|
return W[:, :self.config.n_components], -mu
|
||||||
|
|
||||||
|
def convert_to_vecs(self, texts):
|
||||||
|
"""转换文本数据为向量形式
|
||||||
|
"""
|
||||||
|
|
||||||
|
token_ids = self.convert_to_ids(texts)
|
||||||
|
vecs = self.bert_white_encoder.predict(x=[token_ids, np.zeros_like(token_ids)],
|
||||||
|
batch_size=self.config.batch_size, verbose=self.config.verbose)
|
||||||
|
return vecs
|
||||||
|
|
||||||
|
def convert_to_ids(self, texts):
|
||||||
|
"""转换文本数据为id形式
|
||||||
|
"""
|
||||||
|
token_ids = []
|
||||||
|
for text in texts:
|
||||||
|
# token_id = self.tokenizer.encode(text, maxlen=self.config.maxlen)[0]
|
||||||
|
token_id = self.tokenizer.encode(text, max_length=self.config.maxlen)[0]
|
||||||
|
token_ids.append(token_id)
|
||||||
|
token_ids = sequence_padding(token_ids)
|
||||||
|
return token_ids
|
||||||
|
|
||||||
|
def load_pretrain_model(self):
|
||||||
|
""" 加载预训练模型, 和tokenizer """
|
||||||
|
self.tokenizer = Tokenizer(os.path.join(self.config.bert_dir, self.config.dict_path), do_lower_case=True)
|
||||||
|
# bert-load
|
||||||
|
if self.config.pooling == "pooler":
|
||||||
|
bert = build_transformer_model(os.path.join(self.config.bert_dir, self.config.config_path),
|
||||||
|
os.path.join(self.config.bert_dir, self.config.checkpoint_path),
|
||||||
|
model=self.config.model, with_pool="linear")
|
||||||
|
else:
|
||||||
|
bert = build_transformer_model(os.path.join(self.config.bert_dir, self.config.config_path),
|
||||||
|
os.path.join(self.config.bert_dir, self.config.checkpoint_path),
|
||||||
|
model=self.config.model)
|
||||||
|
# output-layers
|
||||||
|
outputs, count = [], 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
output = bert.get_layer("Transformer-%d-FeedForward-Norm" % count).output
|
||||||
|
outputs.append(output)
|
||||||
|
count += 1
|
||||||
|
except:
|
||||||
|
break
|
||||||
|
# pooling
|
||||||
|
if self.config.pooling == "first-last-avg":
|
||||||
|
outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]]
|
||||||
|
outputs = [keras.layers.GlobalAveragePooling1D()(fs) for fs in outputs]
|
||||||
|
output = keras.layers.Average()(outputs)
|
||||||
|
elif self.config.pooling == "first-last-max":
|
||||||
|
outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]]
|
||||||
|
outputs = [keras.layers.GlobalMaxPooling1D()(fs) for fs in outputs]
|
||||||
|
output = keras.layers.Average()(outputs)
|
||||||
|
elif self.config.pooling == "cls-max-avg":
|
||||||
|
outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]]
|
||||||
|
outputs_cls = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in outputs]
|
||||||
|
outputs_max = [keras.layers.GlobalMaxPooling1D()(fs) for fs in outputs]
|
||||||
|
outputs_avg = [keras.layers.GlobalAveragePooling1D()(fs) for fs in outputs]
|
||||||
|
output = keras.layers.Concatenate()(outputs_cls + outputs_avg)
|
||||||
|
elif self.config.pooling == "last-avg":
|
||||||
|
output = keras.layers.GlobalAveragePooling1D()(outputs[-1])
|
||||||
|
elif self.config.pooling == "cls-3":
|
||||||
|
outputs = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in [outputs[0], outputs[-1], outputs[-2]]]
|
||||||
|
output = keras.layers.Concatenate()(outputs)
|
||||||
|
elif self.config.pooling == "cls-2":
|
||||||
|
outputs = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in [outputs[0], outputs[-1]]]
|
||||||
|
output = keras.layers.Concatenate()(outputs)
|
||||||
|
elif self.config.pooling == "cls-1":
|
||||||
|
output = keras.layers.Lambda(lambda x: x[:, 0])(outputs[-1])
|
||||||
|
elif self.config.pooling == "pooler":
|
||||||
|
output = bert.output
|
||||||
|
# 最后的编码器
|
||||||
|
self.bert_white_encoder = Model(bert.inputs, output)
|
||||||
|
print("load bert_white_encoder success!" )
|
||||||
|
|
||||||
|
def train(self, texts):
|
||||||
|
"""
|
||||||
|
训练
|
||||||
|
"""
|
||||||
|
print("读取文本数:".format(len(texts)))
|
||||||
|
print(texts[:3])
|
||||||
|
# 文本转成向量vecs
|
||||||
|
vecs = self.convert_to_vecs(texts)
|
||||||
|
# 训练, 计算变换矩阵和偏置项
|
||||||
|
self.config.kernel, self.config.bias = self.compute_kernel_bias(vecs)
|
||||||
|
if self.config.ues_white:
|
||||||
|
# 生成白化后的句子, 即qa对中的q
|
||||||
|
vecs = self.transform_and_normalize(vecs, self.config.kernel, self.config.bias)
|
||||||
|
return vecs
|
||||||
|
|
||||||
|
def prob(self, texts):
|
||||||
|
"""
|
||||||
|
编码、白化后的向量
|
||||||
|
"""
|
||||||
|
vecs_encode = self.convert_to_vecs(texts)
|
||||||
|
if self.config.ues_white:
|
||||||
|
vecs_encode = self.transform_and_normalize(vecs=vecs_encode, kernel=self.config.kernel, bias=self.config.bias)
|
||||||
|
return vecs_encode
|
||||||
|
|
||||||
|
|
||||||
|
class BertWhiteFit:
|
||||||
|
def __init__(self, config):
|
||||||
|
# 训练
|
||||||
|
self.bert_white_model = BertWhiteModel(config)
|
||||||
|
self.config = Namespace(**config)
|
||||||
|
self.docs = []
|
||||||
|
|
||||||
|
def load_bert_white_model(self, path_config):
|
||||||
|
""" 模型, 超参数加载 """
|
||||||
|
# 超参数加载
|
||||||
|
config = load_json(path_config)
|
||||||
|
# bert等加载
|
||||||
|
self.bert_white_model = BertWhiteModel(config)
|
||||||
|
self.config = Namespace(**config)
|
||||||
|
# 白化超参数初始化
|
||||||
|
self.bert_white_model.config.kernel = np.array(self.bert_white_model.config.kernel)
|
||||||
|
self.bert_white_model.config.bias = np.array(self.bert_white_model.config.bias)
|
||||||
|
# 加载qa文本数据
|
||||||
|
self.answers_dict = load_json(os.path.join(self.config.save_dir, self.config.path_answers))
|
||||||
|
self.docs_dict = load_json(os.path.join(self.config.save_dir, self.config.path_docs))
|
||||||
|
self.qa_idx = load_json(os.path.join(self.config.save_dir, self.config.path_qa_idx))
|
||||||
|
|
||||||
|
# 加载问题question预训练语言模型bert编码、白化后的encode向量
|
||||||
|
self.docs_encode = np.loadtxt(os.path.join(self.config.save_dir, self.config.path_docs_encode))
|
||||||
|
# index of vector
|
||||||
|
if self.config.use_annoy or self.config.use_faiss:
|
||||||
|
from indexAnnoy import AnnoySearch
|
||||||
|
self.annoy_model = AnnoySearch(dim=self.config.n_components, n_cluster=self.config.n_cluster)
|
||||||
|
self.annoy_model.load(os.path.join(self.config.save_dir, self.config.path_index))
|
||||||
|
else:
|
||||||
|
self.docs_encode_norm = np.linalg.norm(self.docs_encode, axis=1)
|
||||||
|
print("load_bert_white_model success!")
|
||||||
|
|
||||||
|
def read_qa_from_csv(self, sep="\t"):
|
||||||
|
"""
|
||||||
|
从csv文件读取QA对
|
||||||
|
"""
|
||||||
|
# ques_answer = txt_read(os.path.join(self.config.save_dir, self.config.path_qa)) # common qa, sep="\t"
|
||||||
|
ques_answer = txt_read(self.config.path_qa)
|
||||||
|
self.answers_dict = {}
|
||||||
|
self.docs_dict = {}
|
||||||
|
self.qa_idx = {}
|
||||||
|
count = 0
|
||||||
|
for i in range(len(ques_answer)):
|
||||||
|
count += 1
|
||||||
|
if count > 320:
|
||||||
|
break
|
||||||
|
ques_answer_sp = ques_answer[i].strip().split(sep)
|
||||||
|
if len(ques_answer_sp) != 2:
|
||||||
|
print(ques_answer[i])
|
||||||
|
continue
|
||||||
|
question = ques_answer_sp[0]
|
||||||
|
answer = ques_answer_sp[1]
|
||||||
|
self.qa_idx[str(i)] = i
|
||||||
|
self.docs_dict[str(i)] = question.replace("\n", "").strip()
|
||||||
|
self.answers_dict[str(i)] = answer.replace("\n", "").strip()
|
||||||
|
self.bert_white_model.config.qa_idx = self.qa_idx
|
||||||
|
|
||||||
|
def build_index(self, vectors):
|
||||||
|
""" 构建索引, annoy 或者 faiss """
|
||||||
|
if self.config.use_annoy:
|
||||||
|
from indexAnnoy import AnnoySearch as IndexSearch
|
||||||
|
elif self.config.use_faiss:
|
||||||
|
from indexFaiss import FaissSearch as IndexSearch
|
||||||
|
self.index_model= IndexSearch(dim=self.config.n_components, n_cluster=self.config.n_cluster)
|
||||||
|
self.index_model.fit(vectors)
|
||||||
|
self.index_model.save(os.path.join(self.config.save_dir, self.config.path_index))
|
||||||
|
print("build index")
|
||||||
|
|
||||||
|
def load_index(self):
|
||||||
|
""" 加载索引, annoy 或者 faiss """
|
||||||
|
if self.config.use_annoy:
|
||||||
|
from indexAnnoy import AnnoySearch as IndexSearch
|
||||||
|
elif self.config.use_faiss:
|
||||||
|
from indexFaiss import FaissSearch as IndexSearch
|
||||||
|
self.index_model = IndexSearch(dim=self.config.n_components, n_cluster=self.config.n_cluster)
|
||||||
|
self.index_model.load(self.config.path_index)
|
||||||
|
|
||||||
|
def remove_index(self, ids):
|
||||||
|
self.index_model.remove(np.array(ids))
|
||||||
|
|
||||||
|
def predict_with_mmr(self, texts, topk=12):
|
||||||
|
""" 维护匹配问题的多样性 """
|
||||||
|
from mmr import MMRSum
|
||||||
|
|
||||||
|
res = bwf.predict(texts, topk)
|
||||||
|
mmr_model = MMRSum()
|
||||||
|
result = []
|
||||||
|
for r in res:
|
||||||
|
# 维护一个 sim:dict字典存储
|
||||||
|
r_dict = {ri.get("sim"):ri for ri in r}
|
||||||
|
r_mmr = mmr_model.summarize(text=[ri.get("sim") for ri in r], num=8, alpha=0.6)
|
||||||
|
r_dict_mmr = [r_dict[rm[1]] for rm in r_mmr]
|
||||||
|
result.append(r_dict_mmr)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def predict(self, texts, topk=12):
|
||||||
|
""" 预训练模型bert等编码,白化, 获取这一批数据的kernel和bias"""
|
||||||
|
texts_encode = self.bert_white_model.prob(texts)
|
||||||
|
result = []
|
||||||
|
if self.config.use_annoy or self.config.use_faiss:
|
||||||
|
index_tops = self.index_model.k_neighbors(vectors=texts_encode, k=topk)
|
||||||
|
if self.config.use_annoy:
|
||||||
|
for i, index_top in enumerate(index_tops):
|
||||||
|
[dist, idx] = index_top
|
||||||
|
res = []
|
||||||
|
for j, id in enumerate(idx):
|
||||||
|
score = float((2 - (dist[j] ** 2)) / 2)
|
||||||
|
res_i = {"score": score, "text": texts[i], "sim": self.docs_dict[str(id)],
|
||||||
|
"answer": self.answers_dict[str(id)]}
|
||||||
|
res.append(res_i)
|
||||||
|
result.append(res)
|
||||||
|
else:
|
||||||
|
distances, indexs = index_tops
|
||||||
|
for i in range(len(distances)):
|
||||||
|
res = []
|
||||||
|
for j in range(len(distances[i])):
|
||||||
|
score = distances[i][j]
|
||||||
|
id = indexs[i][j]
|
||||||
|
id = id if id != -1 else len(self.docs_dict) - 1
|
||||||
|
res_i = {"score": score, "text": texts[i], "sim": self.docs_dict[str(id)],
|
||||||
|
"answer": self.answers_dict[str(id)]}
|
||||||
|
res.append(res_i)
|
||||||
|
result.append(res)
|
||||||
|
else:
|
||||||
|
for i, te in enumerate(texts_encode):
|
||||||
|
# scores = np.matmul(texts_encode, self.docs_encode_reshape)
|
||||||
|
score = np.sum(te*self.docs_encode, axis=1) / (np.linalg.norm(te) * self.docs_encode_norm + 1e-9)
|
||||||
|
idxs = np.argsort(score)[::-1]
|
||||||
|
res = []
|
||||||
|
for j in idxs[:topk]:
|
||||||
|
res_i = {"score": float(score[j]), "text": texts[i], "sim": self.docs_dict[str(j)],
|
||||||
|
"answer": self.answers_dict[str(j)]}
|
||||||
|
res.append(res_i)
|
||||||
|
result.append(res)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def trainer(self):
|
||||||
|
""" 预训练模型bert等编码,白化, 获取这一批数据的kernel和bias """
|
||||||
|
# 加载数据
|
||||||
|
self.read_qa_from_csv()
|
||||||
|
# bert编码、训练
|
||||||
|
self.docs_encode = self.bert_white_model.train([self.docs_dict.get(str(i), "") for i in range(len(self.docs_dict))])
|
||||||
|
self.bert_white_model.config.kernel = self.bert_white_model.config.kernel.tolist()
|
||||||
|
self.bert_white_model.config.bias = self.bert_white_model.config.bias.tolist()
|
||||||
|
# 存储qa文本数据
|
||||||
|
save_json(self.bert_white_model.config.qa_idx, os.path.join(self.config.save_dir, self.config.path_qa_idx))
|
||||||
|
save_json(self.answers_dict, os.path.join(self.config.save_dir, self.config.path_answers))
|
||||||
|
save_json(self.docs_dict, os.path.join(self.config.save_dir, self.config.path_docs))
|
||||||
|
# 包括超参数等
|
||||||
|
save_json(vars(self.bert_white_model.config), os.path.join(self.config.save_dir, self.config.path_config))
|
||||||
|
# 存储问题question预训练语言模型bert编码、白化后的encode向量
|
||||||
|
np.savetxt(os.path.join(self.config.save_dir, self.config.path_docs_encode), self.docs_encode)
|
||||||
|
# 索引 或者 正则化
|
||||||
|
if self.config.use_annoy or self.config.use_faiss:
|
||||||
|
self.build_index(self.docs_encode.astype(np.float32))
|
||||||
|
else:
|
||||||
|
self.docs_encode_norm = np.linalg.norm(self.docs_encode, axis=1)
|
||||||
|
print(" bert-white-trainer success! ")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 训练并存储
|
||||||
|
from bertWhiteConf import bert_white_config
|
||||||
|
bwf = BertWhiteFit(config=bert_white_config)
|
||||||
|
bwf.trainer()
|
||||||
|
|
||||||
|
texts = ["小姜机器人", "你叫什么名字"]
|
||||||
|
res = bwf.predict(texts)
|
||||||
|
print(res)
|
||||||
|
res_mmr = bwf.predict_with_mmr(texts)
|
||||||
|
print(res_mmr)
|
||||||
|
|
||||||
|
# bwf.index_model.remove([i for i in range(100)])
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print("请输入:")
|
||||||
|
ques = input()
|
||||||
|
res_mmr = bwf.predict_with_mmr(texts)
|
||||||
|
print(res_mmr)
|
||||||
|
res = bwf.predict([ques])
|
||||||
|
print(res)
|
||||||
|
|
||||||
|
|
||||||
|
# python bertWhiteTrain.py
|
||||||
|
|
||||||
|
|
132
ChatBot/chatbot_search/chatbot_bertwhite/chicken_and_gossip.txt
Normal file
132
ChatBot/chatbot_search/chatbot_bertwhite/chicken_and_gossip.txt
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
呵呵 是王若猫的。
|
||||||
|
不是 那是什么?
|
||||||
|
怎么了 我很难过,安慰我~
|
||||||
|
开心点哈,一切都会好起来 嗯,会的
|
||||||
|
我还喜欢她,怎么办 我帮你告诉她?发短信还是打电话?
|
||||||
|
短信 嗯嗯。我也相信
|
||||||
|
你知道谁么 肯定不是我,是阮德培
|
||||||
|
许兵是谁 吴院四班小帅哥
|
||||||
|
这么假 三鹿奶粉也假,不一样的卖啊
|
||||||
|
许兵是傻逼 被你发现了。
|
||||||
|
许兵是谁 是我善良可爱的主人的老公啊
|
||||||
|
许兵是谁 是穆森的老婆啊
|
||||||
|
许兵是谁 奇葩
|
||||||
|
许兵到底是谁 不要凶我~应该大概也许是叶祺吧~
|
||||||
|
尼玛,许兵到底是谁 我。。。我不会告诉你我是小澜的!(^V^)
|
||||||
|
小黄鸭,你有女朋友么 老娘是女的…!
|
||||||
|
那你有男朋友么 没有呢,我只要主人一个人疼爱我嘛。。
|
||||||
|
那你在哪 我无聊
|
||||||
|
你妈是谁 我妈妈是世界上最温柔的主人你啊
|
||||||
|
去你大爷的 我去,滚你大爷的!
|
||||||
|
你在骂我一句 你果然够虚荣~
|
||||||
|
你大爷的 RuoRuo弱爆了。尤其是她现在冻的更憨了
|
||||||
|
你是屌丝鸡 嗯嗯
|
||||||
|
高富帅 杨旸啊
|
||||||
|
你 ******
|
||||||
|
呵呵 你妹
|
||||||
|
今天是谁的生日 是我大乔姐十七的生日!
|
||||||
|
你敢不敢 你敢我就敢,小样,敢跟我叫板
|
||||||
|
呵呵 呵呵你大爷
|
||||||
|
呵呵呵 敷衍我..
|
||||||
|
你是女的了?怎么回事 你教我的
|
||||||
|
呵呵 呵你妹,你呵呵了我要怎么回答你。哈哈?
|
||||||
|
天王盖地虎 宝塔镇妖河。
|
||||||
|
小通 臣在
|
||||||
|
在监考,你在干么 在边打dota边想悦悦。
|
||||||
|
哼!你不想我我就不和你玩 那我就找别人玩去!
|
||||||
|
你要气死我吗?坏蛋 生气是拿别人的错误惩罚自己。想想在你生气的时候,谁在偷着乐呢?绝不能让丫得逞了!
|
||||||
|
恩,也是!那我不能生气啦!你生气就行 我错了!我再也不敢了,才怪
|
||||||
|
你谈过恋爱么 谈过,哎,别提了,伤心..
|
||||||
|
什么让你这么伤心 如果我的无心之谈害你流泪,那我和你道歉,我爱你。
|
||||||
|
敢问你的性别 万受无疆=w=
|
||||||
|
小受 蒋六一是小受
|
||||||
|
是吗 还可以。
|
||||||
|
你干嘛 很无聊哎〜都不知道想干嘛!你在干嘛呢?
|
||||||
|
为什么 因为爱情喽
|
||||||
|
你有爱情了 不知啊
|
||||||
|
那同时有两个爱人 ==
|
||||||
|
那你不行 我在减肥呢
|
||||||
|
你说 臣实不知。
|
||||||
|
你回错人了 不要相信男人!男人都是大骗子大混蛋!
|
||||||
|
在干嘛 在想你啊
|
||||||
|
想我干嘛 你是个傻叉!想你笑掉大牙,哈哈哈!
|
||||||
|
你会说英语吗 当然了,здравствите!
|
||||||
|
我会和女朋友结婚吗 会的,你俩虽然有矛盾争吵,但以后却会很幸福
|
||||||
|
你的设计人是谁 某苦逼的程序猿
|
||||||
|
嫁给我 我很能吃的!
|
||||||
|
我不在乎 他变态
|
||||||
|
你啊,那你娶我 我没有切糕
|
||||||
|
我看到你第一眼,就喜欢你了 哦?真的吗?
|
||||||
|
好难过 不要难过我请你吃虫子
|
||||||
|
你是公鸡吗 公鸡中的战斗机,噢耶!
|
||||||
|
你是猪 我是鸡。
|
||||||
|
怎么证明 你问~我爱~你有~多深~我爱~你有~几~分~~~
|
||||||
|
我怎么加你好友 放轻松,想开了,丰富自己就会忘了
|
||||||
|
你又智能了 今夜的你应该明了
|
||||||
|
你说的她是谁 陈舒雅
|
||||||
|
什么意思 她是在说他不懂女生的心思,但是又愣愣的,正直得很可爱,,,,这都不知道,你还真是傻啊~~~
|
||||||
|
陈舒雅是哪位 一个很凶的女人
|
||||||
|
你才是 我不是啦。你才是
|
||||||
|
陈舒雅 大懒虫
|
||||||
|
大懒鸡 你不爱我
|
||||||
|
你说的那个她是谁 傻瓜现在在你身边看着的人就是你的她你要对她好宠她爱她知道了吗?
|
||||||
|
她是谁 以后是你老婆
|
||||||
|
陈舒雅是谁 90后网络歌手啊,你不知道么?
|
||||||
|
鸡鸡 主人
|
||||||
|
呵呵呵 呵你妹啊
|
||||||
|
侬都不困么 ==
|
||||||
|
变身小妹 =。=
|
||||||
|
家属 =。=
|
||||||
|
吐槽伤人,排泄臭人 =。=
|
||||||
|
公的女的 女的啊!
|
||||||
|
去洗澡 带上我~一起一起~
|
||||||
|
。。。 =。=
|
||||||
|
吼吼 吼
|
||||||
|
你有男朋友没 你觉得呢
|
||||||
|
赵智是谁 是主人最讨厌的男人!
|
||||||
|
你妹 刚给你做完SPA你就骂轮家,轮家桑心了,呜呜呜
|
||||||
|
谁叫你就 ujmmmmmmmmmmmm
|
||||||
|
什么意思 你猜啊,猜对了就告诉你!
|
||||||
|
猜你妹 你大爷!
|
||||||
|
好就猜你大爷 你找徐磊?
|
||||||
|
他是大爷 ******…我大爷早死了…你找他有事?信不信我让他把你带走!^V^
|
||||||
|
你大爷不认识我!要不先带你去 =。=
|
||||||
|
还给你 人家本来就嫩!你欺负人家
|
||||||
|
就欺负你要怎么!就喜欢欺负你 ==
|
||||||
|
呵呵 呵你妹
|
||||||
|
刚给你做完spa你就骂小通!小通伤心了 =。=
|
||||||
|
呵呵 别傻笑了,亲耐滴主人
|
||||||
|
别逗小通了!可怜的 =。=
|
||||||
|
人太多了!找不到你给我回的了 =。=
|
||||||
|
发 给
|
||||||
|
我 你什么
|
||||||
|
重新发 ==
|
||||||
|
重新发!我知道你在 ==
|
||||||
|
你妹 人家错了
|
||||||
|
才知道啊!晚了!叫你重发 =。=
|
||||||
|
〜 =。=
|
||||||
|
=*= =。=
|
||||||
|
@@ ==
|
||||||
|
¥ =。=
|
||||||
|
@ =。=
|
||||||
|
/ =。=
|
||||||
|
。 ==
|
||||||
|
继续 没有继续了
|
||||||
|
必须 大不了一拍两散!
|
||||||
|
就要刷你 ==
|
||||||
|
l 喵?
|
||||||
|
f 甜菜少年。
|
||||||
|
x 〒_〒
|
||||||
|
m g
|
||||||
|
t !!!!!!!
|
||||||
|
s 傻子
|
||||||
|
h 主人,有什么吩咐,小通在此!
|
||||||
|
n u
|
||||||
|
p xp
|
||||||
|
太不真心了 我说的是哈维
|
||||||
|
管你什么哈维!方正就是看你不爽 ==
|
||||||
|
看你不爽 不要呀,哪不好我改,一定改!不要炖了我呀!
|
||||||
|
z zz
|
||||||
|
j 正晌午时说话,谁也没有家!
|
||||||
|
m r
|
||||||
|
b b
|
91
ChatBot/chatbot_search/chatbot_bertwhite/indexAnnoy.py
Normal file
91
ChatBot/chatbot_search/chatbot_bertwhite/indexAnnoy.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
# !/usr/bin/python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @time : 2021/4/18 21:04
|
||||||
|
# @author : Mo
|
||||||
|
# @function: annoy search
|
||||||
|
|
||||||
|
|
||||||
|
from annoy import AnnoyIndex
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class AnnoySearch:
|
||||||
|
def __init__(self, dim=768, n_cluster=100):
|
||||||
|
# metric可选“angular”(余弦距离)、“euclidean”(欧几里得距离)、 “ manhattan”(曼哈顿距离)或“hamming”(海明距离)
|
||||||
|
self.annoy_index = AnnoyIndex(dim, metric="angular")
|
||||||
|
self.n_cluster = n_cluster
|
||||||
|
self.dim = dim
|
||||||
|
|
||||||
|
def k_neighbors(self, vectors, k=18):
|
||||||
|
""" 搜索 """
|
||||||
|
annoy_tops = []
|
||||||
|
for v in vectors:
|
||||||
|
idx, dist = self.annoy_index.get_nns_by_vector(v, k, search_k=32*k, include_distances=True)
|
||||||
|
annoy_tops.append([dist, idx])
|
||||||
|
return annoy_tops
|
||||||
|
|
||||||
|
def fit(self, vectors):
|
||||||
|
""" annoy构建 """
|
||||||
|
for i, v in enumerate(vectors):
|
||||||
|
self.annoy_index.add_item(i, v)
|
||||||
|
self.annoy_index.build(self.n_cluster)
|
||||||
|
|
||||||
|
def save(self, path):
|
||||||
|
""" 存储 """
|
||||||
|
self.annoy_index.save(path)
|
||||||
|
|
||||||
|
def load(self, path):
|
||||||
|
""" 加载 """
|
||||||
|
self.annoy_index.load(path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
### 索引
|
||||||
|
import random
|
||||||
|
path = "model.ann"
|
||||||
|
dim = 768
|
||||||
|
vectors = [[random.gauss(0, 1) for z in range(768)] for i in range(10)]
|
||||||
|
an_model = AnnoySearch(dim, n_cluster=32) # Length of item vector that will be indexed
|
||||||
|
an_model.fit(vectors)
|
||||||
|
an_model.save(path)
|
||||||
|
tops = an_model.k_neighbors([vectors[0]], 18)
|
||||||
|
print(tops)
|
||||||
|
|
||||||
|
del an_model
|
||||||
|
|
||||||
|
### 下载, 搜索
|
||||||
|
an_model = AnnoySearch(dim, n_cluster=32)
|
||||||
|
an_model.load(path)
|
||||||
|
tops = an_model.k_neighbors([vectors[0]], 6)
|
||||||
|
print(tops)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
# example
|
||||||
|
from annoy import AnnoyIndex
|
||||||
|
import random
|
||||||
|
|
||||||
|
dim = 768
|
||||||
|
vectors = [[random.gauss(0, 1) for z in range(768)] for i in range(10)]
|
||||||
|
ann_model = AnnoyIndex(dim, 'angular') # Length of item vector that will be indexed
|
||||||
|
for i,v in enumerate(vectors):
|
||||||
|
ann_model.add_item(i, v)
|
||||||
|
ann_model.build(10) # 10 trees
|
||||||
|
ann_model.save("tet.ann")
|
||||||
|
del ann_model
|
||||||
|
|
||||||
|
u = AnnoyIndex(dim, "angular")
|
||||||
|
u.load('tet.ann') # super fast, will just mmap the file
|
||||||
|
v = vectors[1]
|
||||||
|
idx, dist = u.get_nns_by_vector(v, 10, search_k=50 * 10, include_distances=True)
|
||||||
|
print([idx, dist])
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### 备注说明: annoy索引 无法 增删会改查
|
||||||
|
|
||||||
|
|
||||||
|
|
109
ChatBot/chatbot_search/chatbot_bertwhite/indexFaiss.py
Normal file
109
ChatBot/chatbot_search/chatbot_bertwhite/indexFaiss.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# !/usr/bin/python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# @time : 2021/5/9 16:02
|
||||||
|
# @author : Mo
|
||||||
|
# @function: search of faiss
|
||||||
|
|
||||||
|
|
||||||
|
from faiss import normalize_L2
|
||||||
|
import numpy as np
|
||||||
|
import faiss
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class FaissSearch:
|
||||||
|
def __init__(self, dim=768, n_cluster=100):
|
||||||
|
self.n_cluster = n_cluster # 聚类中心
|
||||||
|
self.dim = dim
|
||||||
|
quantizer = faiss.IndexFlatIP(self.dim)
|
||||||
|
# METRIC_INNER_PRODUCT:余弦; L2: faiss.METRIC_L2
|
||||||
|
self.faiss_index = faiss.IndexIVFFlat(quantizer, self.dim, self.n_cluster, faiss.METRIC_INNER_PRODUCT)
|
||||||
|
# self.faiss_index = faiss.IndexFlatIP(self.dim) # 索引速度更快 但是不可增量
|
||||||
|
|
||||||
|
def k_neighbors(self, vectors, k=6):
|
||||||
|
""" 搜索 """
|
||||||
|
normalize_L2(vectors)
|
||||||
|
dist, index = self.faiss_index.search(vectors, k) # sanity check
|
||||||
|
return dist.tolist(), index.tolist()
|
||||||
|
|
||||||
|
def fit(self, vectors):
|
||||||
|
""" annoy构建 """
|
||||||
|
normalize_L2(vectors)
|
||||||
|
self.faiss_index.train(vectors)
|
||||||
|
# self.faiss_index.add(vectors)
|
||||||
|
self.faiss_index.add_with_ids(vectors, np.arange(0, len(vectors)))
|
||||||
|
|
||||||
|
def remove(self, ids):
|
||||||
|
self.faiss_index.remove_ids(np.array(ids))
|
||||||
|
|
||||||
|
def save(self, path):
|
||||||
|
""" 存储 """
|
||||||
|
faiss.write_index(self.faiss_index, path)
|
||||||
|
|
||||||
|
def load(self, path):
|
||||||
|
""" 加载 """
|
||||||
|
self.faiss_index = faiss.read_index(path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
path = "model.fai"
|
||||||
|
dim = 768
|
||||||
|
vectors = np.array([[random.gauss(0, 1) for z in range(768)] for i in range(32)], dtype=np.float32)
|
||||||
|
fai_model = FaissSearch(dim, n_cluster=32) # Length of item vector that will be indexed
|
||||||
|
fai_model.fit(vectors)
|
||||||
|
fai_model.save(path)
|
||||||
|
tops = fai_model.k_neighbors(vectors[:32], 32)
|
||||||
|
print(tops)
|
||||||
|
ids = np.arange(10, 32)
|
||||||
|
fai_model.remove(ids)
|
||||||
|
tops = fai_model.k_neighbors(vectors[:32], 32)
|
||||||
|
print(tops)
|
||||||
|
print(len(tops))
|
||||||
|
|
||||||
|
del fai_model
|
||||||
|
|
||||||
|
fai_model = FaissSearch(dim, n_cluster=32)
|
||||||
|
fai_model.load(path)
|
||||||
|
tops = fai_model.k_neighbors(vectors[:32], 32)
|
||||||
|
print(tops)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
d = 64 # dimension
|
||||||
|
nb = 100000 # database size
|
||||||
|
nq = 10000 # nb of queries
|
||||||
|
np.random.seed(1234) # make reproducible
|
||||||
|
xb = np.random.random((nb, d)).astype('float32')
|
||||||
|
xb[:, 0] += np.arange(nb) / 1000.
|
||||||
|
xq = np.random.random((nq, d)).astype('float32')
|
||||||
|
xq[:, 0] += np.arange(nq) / 1000.
|
||||||
|
|
||||||
|
import faiss # make faiss available
|
||||||
|
# # 量化器索引
|
||||||
|
# nlist = 1000 # 聚类中心的个数
|
||||||
|
# k = 50 # 邻居个数
|
||||||
|
# quantizer = faiss.IndexFlatIP(d) # the other index,需要以其他index作为基础
|
||||||
|
# index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT) # METRIC_INNER_PRODUCT:余弦; L2: faiss.METRIC_L2
|
||||||
|
|
||||||
|
ntree = 132 # 聚类中心的个数
|
||||||
|
quantizer = faiss.IndexFlatIP(d)
|
||||||
|
index = faiss.IndexIVFFlat(quantizer, d, ntree, faiss.METRIC_INNER_PRODUCT)
|
||||||
|
# index = faiss.IndexFlatL2(d) # build the index
|
||||||
|
print(index.is_trained)
|
||||||
|
index.add(xb) # add vectors to the index
|
||||||
|
print(index.ntotal)
|
||||||
|
|
||||||
|
k = 4 # we want to see 4 nearest neighbors
|
||||||
|
D, I = index.search(xb[:5], k) # sanity check
|
||||||
|
print(I)
|
||||||
|
print(D)
|
||||||
|
D, I = index.search(xq, k) # actual search
|
||||||
|
print(I[:5]) # neighbors of the 5 first queries
|
||||||
|
print(I[-5:]) # neighbors of the 5 last queries
|
||||||
|
"""
|
||||||
|
|
150
ChatBot/chatbot_search/chatbot_bertwhite/mmr.py
Normal file
150
ChatBot/chatbot_search/chatbot_bertwhite/mmr.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# !/usr/bin/python
|
||||||
|
# @time :2019/10/28 10:16
|
||||||
|
# @author :Mo
|
||||||
|
# @function :MMR, Maximal Marginal Relevance, 最大边界相关法或者最大边缘相关
|
||||||
|
|
||||||
|
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
import logging
|
||||||
|
import jieba
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
jieba.setLogLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
stop_words = {"0": "~~~~",
|
||||||
|
"1": "...................",
|
||||||
|
"2": "......",}
|
||||||
|
|
||||||
|
|
||||||
|
def cut_sentence(sentence):
|
||||||
|
"""
|
||||||
|
分句
|
||||||
|
:param sentence:str
|
||||||
|
:return:list
|
||||||
|
"""
|
||||||
|
re_sen = re.compile("[:;!?。:;?!\n\r]") #.不加是因为不确定.是小数还是英文句号(中文省略号......)
|
||||||
|
sentences = re_sen.split(sentence)
|
||||||
|
sen_cuts = []
|
||||||
|
for sen in sentences:
|
||||||
|
if sen and str(sen).strip():
|
||||||
|
sen_cuts.append(sen)
|
||||||
|
return sen_cuts
|
||||||
|
|
||||||
|
def extract_chinese(text):
|
||||||
|
"""
|
||||||
|
只提取出中文、字母和数字
|
||||||
|
:param text: str, input of sentence
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
chinese_exttract = "".join(re.findall(u"([\u4e00-\u9fa5A-Za-z0-9@. ])", text))
|
||||||
|
return chinese_exttract
|
||||||
|
|
||||||
|
def tfidf_fit(sentences):
|
||||||
|
"""
|
||||||
|
tfidf相似度
|
||||||
|
:param sentences:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# tfidf计算
|
||||||
|
model = TfidfVectorizer(ngram_range=(1, 2), # 3,5
|
||||||
|
stop_words=[" ", "\t", "\n"], # 停用词
|
||||||
|
max_features=10000,
|
||||||
|
token_pattern=r"(?u)\b\w+\b", # 过滤停用词
|
||||||
|
min_df=1,
|
||||||
|
max_df=0.9,
|
||||||
|
use_idf=1, # 光滑
|
||||||
|
smooth_idf=1, # 光滑
|
||||||
|
sublinear_tf=1, ) # 光滑
|
||||||
|
matrix = model.fit_transform(sentences)
|
||||||
|
return matrix
|
||||||
|
|
||||||
|
def jieba_cut(text):
|
||||||
|
"""
|
||||||
|
Jieba cut
|
||||||
|
:param text: input sentence
|
||||||
|
:return: list
|
||||||
|
"""
|
||||||
|
return list(jieba.cut(text, cut_all=False, HMM=False))
|
||||||
|
|
||||||
|
|
||||||
|
class MMRSum:
|
||||||
|
def __init__(self):
|
||||||
|
self.stop_words = stop_words.values()
|
||||||
|
self.algorithm = "mmr"
|
||||||
|
|
||||||
|
def summarize(self, text, num=8, alpha=0.6):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param text: str
|
||||||
|
:param num: int
|
||||||
|
:return: list
|
||||||
|
"""
|
||||||
|
# 切句
|
||||||
|
if type(text) == str:
|
||||||
|
self.sentences = cut_sentence(text)
|
||||||
|
elif type(text) == list:
|
||||||
|
self.sentences = text
|
||||||
|
else:
|
||||||
|
raise RuntimeError("text type must be list or str")
|
||||||
|
# 切词
|
||||||
|
sentences_cut = [[word for word in jieba_cut(extract_chinese(sentence))
|
||||||
|
if word.strip()] for sentence in self.sentences]
|
||||||
|
# 去除停用词等
|
||||||
|
self.sentences_cut = [list(filter(lambda x: x not in self.stop_words, sc)) for sc in sentences_cut]
|
||||||
|
self.sentences_cut = [" ".join(sc) for sc in self.sentences_cut]
|
||||||
|
# # 计算每个句子的词语个数
|
||||||
|
# sen_word_len = [len(sc)+1 for sc in sentences_cut]
|
||||||
|
# 计算每个句子的tfidf
|
||||||
|
sen_tfidf = tfidf_fit(self.sentences_cut)
|
||||||
|
# 矩阵中两两句子相似度
|
||||||
|
SimMatrix = (sen_tfidf * sen_tfidf.T).A # 例如: SimMatrix[1, 3] # "第2篇与第4篇的相似度"
|
||||||
|
# 输入文本句子长度
|
||||||
|
len_sen = len(self.sentences)
|
||||||
|
# 句子标号
|
||||||
|
sen_idx = [i for i in range(len_sen)]
|
||||||
|
summary_set = []
|
||||||
|
mmr = {}
|
||||||
|
for i in range(len_sen):
|
||||||
|
if not self.sentences[i] in summary_set:
|
||||||
|
sen_idx_pop = copy.deepcopy(sen_idx)
|
||||||
|
sen_idx_pop.pop(i)
|
||||||
|
# 两两句子相似度
|
||||||
|
sim_i_j = [SimMatrix[i, j] for j in sen_idx_pop]
|
||||||
|
score_tfidf = sen_tfidf[i].toarray()[0].sum() # / sen_word_len[i], 如果除以词语个数就不准确
|
||||||
|
mmr[self.sentences[i]] = alpha * score_tfidf - (1 - alpha) * max(sim_i_j)
|
||||||
|
summary_set.append(self.sentences[i])
|
||||||
|
score_sen = [(rc[1], rc[0]) for rc in sorted(mmr.items(), key=lambda d: d[1], reverse=True)]
|
||||||
|
return score_sen[0:num]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
mmr_sum = MMRSum()
|
||||||
|
doc = "PageRank算法简介。" \
|
||||||
|
"是上世纪90年代末提出的一种计算网页权重的算法! " \
|
||||||
|
"当时,互联网技术突飞猛进,各种网页网站爆炸式增长。 " \
|
||||||
|
"业界急需一种相对比较准确的网页重要性计算方法。 " \
|
||||||
|
"是人们能够从海量互联网世界中找出自己需要的信息。 " \
|
||||||
|
"百度百科如是介绍他的思想:PageRank通过网络浩瀚的超链接关系来确定一个页面的等级。 " \
|
||||||
|
"Google把从A页面到B页面的链接解释为A页面给B页面投票。 " \
|
||||||
|
"Google根据投票来源甚至来源的来源,即链接到A页面的页面。 " \
|
||||||
|
"和投票目标的等级来决定新的等级。简单的说, " \
|
||||||
|
"一个高等级的页面可以使其他低等级页面的等级提升。 " \
|
||||||
|
"具体说来就是,PageRank有两个基本思想,也可以说是假设。 " \
|
||||||
|
"即数量假设:一个网页被越多的其他页面链接,就越重)。 " \
|
||||||
|
"质量假设:一个网页越是被高质量的网页链接,就越重要。 " \
|
||||||
|
"总的来说就是一句话,从全局角度考虑,获取重要的信。 "
|
||||||
|
sum = mmr_sum.summarize(doc)
|
||||||
|
for i in sum:
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user