From 0eba3a83d29ba816d63c4630c4ace726639a0632 Mon Sep 17 00:00:00 2001 From: yongzhuo <2714618994@qq.com> Date: Fri, 17 Sep 2021 18:45:31 +0800 Subject: [PATCH] add Layer of cosine for chatbot-tfserving --- .../chatbot_bertwhite/bertWhiteConf.py | 2 +- .../chatbot_tfserving/README.md | 93 ++++ .../TFServing_postprocess.py | 68 +++ .../chatbot_tfserving/TFServing_preprocess.py | 440 ++++++++++++++++++ .../chatbot_tfserving/TFServing_save.py | 301 ++++++++++++ .../chatbot_tfserving/TFServing_tet_http.py | 52 +++ .../chatbot_tfserving/__init__.py | 7 + .../chatbot_tfserving/bertWhiteConf.py | 68 +++ .../chatbot_tfserving/bertWhiteTools.py | 76 +++ .../chatbot_tfserving/bertWhiteTrain.py | 374 +++++++++++++++ .../chatbot_tfserving/chicken_and_gossip.txt | 132 ++++++ .../chatbot_tfserving/indexAnnoy.py | 91 ++++ .../chatbot_tfserving/indexFaiss.py | 109 +++++ .../chatbot_search/chatbot_tfserving/mmr.py | 150 ++++++ 14 files changed, 1962 insertions(+), 1 deletion(-) create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/README.md create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/TFServing_postprocess.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/TFServing_preprocess.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/TFServing_save.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/TFServing_tet_http.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/__init__.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/bertWhiteConf.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTools.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTrain.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/chicken_and_gossip.txt create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/indexAnnoy.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/indexFaiss.py create mode 100644 ChatBot/chatbot_search/chatbot_tfserving/mmr.py diff --git a/ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteConf.py b/ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteConf.py index 144fbe8..e6e2488 100644 --- a/ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteConf.py +++ b/ChatBot/chatbot_search/chatbot_bertwhite/bertWhiteConf.py @@ -22,7 +22,7 @@ if platform.system().lower() == 'windows': BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-4_H-312_A-12" # BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-6_H-384_A-12" else: - BERT_DIR = "/home/hemei/myzhuo/bert/chinese_L-12_H-768_A-12" + BERT_DIR = "/bert/chinese_L-12_H-768_A-12" ee = 0 SAVE_DIR = path_root + "/bert_white" diff --git a/ChatBot/chatbot_search/chatbot_tfserving/README.md b/ChatBot/chatbot_search/chatbot_tfserving/README.md new file mode 100644 index 0000000..6ac855d --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/README.md @@ -0,0 +1,93 @@ +# 新增一个余弦相似度Cosine层, 用于BERT句向量编码部署tf-serving +## 业务需求 + - BERT向量召回问答对, FAQ标准问答对数据量不大 + - 不能把BERT编码部署于网络服务, 如http请求的形式, 因为网络传输耗时, 此外传输的数据量还很大768(维度)*32(float) + - 几乎所有的模型服务只能用cpu, 硬盘、内存都还可以 + - 响应要求高, 小时延不能太高 + +## 代码逻辑 + - 首先将FAQ标准问答对生成句向量, bert-sentence-encode; + - 将句向量当成一个 常量 插入网络, 网络架构新增 余弦相似度层(CosineLayer) 模块, 保存成tf-serving形式; + - 选择小模型tinyBERT, ROBERTA-4-layer, ROBERTA-6-layer这些模型 + +## 解释说明 + - 代码说明: + - TFServing_main.py 主代码, 调用 + - TFServing_postprocess.py tf-serving 后处理函数 + - TFServing_preprocess.py tf-serving 预处理函数 + - TFServing_save.py tf-serving 主调用函数 + - 主调用 + - 1. bertWhiteConf.py 超参数配置, 地址、bert-white、索引工具等的超参数 + - 2. bertWhiteTools.py 小工具, 主要是一些文档读写功能函数 + - 3. bertWhiteTrain.py 主模块, 类似bert预训练模型编码 + - 4. indexAnnoy.py annoy索引 + - 5. indexFaiss.py faiss索引 + - 6. mmr.py 最大边界相关法, 保证返回多样性 + +## 模型文件 + - bert_white文件 bertWhiteTrain.py生成的模块 + - chatbot_tfserving文件 包含相似度计算的tf-serving文件 + +## 调用示例 + - 配置问答语料文件(chicken_and_gossip.txt) 和 超参数(bertWhiteConf.py中的BERT_DIR) + - 生成FAQ句向量: python3 bertWhiteTrain.py + - 存储成pd文件(tf-serving使用): python3 TFServing_save.py + - 部署docker服务(tf-serving): 例如 docker run -t --rm -p 8532:8501 -v "/TF-SERVING/chatbot_tf:/models/chatbot_tf" -e MODEL_NAME=chatbot_tf tensorflow/serving:latest + - 调用tf-serving服务: python3 TFServing_tet_http.py + +## 关键代码 +```python3 +import keras.backend as K +import tensorflow as tf +import keras + +import numpy as np + + +class CosineLayer(keras.layers.Layer): + def __init__(self, docs_encode, **kwargs): + """ + 余弦相似度层, 不适合大规模语料, 比如100w以上的问答对 + :param docs_encode: np.array, bert-white vector of senence + :param kwargs: + """ + self.docs_encode = docs_encode + super(CosineLayer, self).__init__(**kwargs) + self.docs_vector = K.constant(self.docs_encode, dtype="float32") + self.l2_docs_vector = K.sqrt(K.sum(K.maximum(K.square(self.docs_vector), 1e-12), axis=-1)) # x_inv_norm + + def build(self, input_shape): + super(CosineLayer, self).build(input_shape) + + def get_config(self): + # 防止报错 'NoneType' object has no attribute '_inbound_nodes' + config = {"docs_vector": self.docs_vector, + "l2_docs_vector": self.l2_docs_vector} + base_config = super(CosineLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, input): + # 计算余弦相似度 + # square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) + # x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)) + # return math_ops.multiply(x, x_inv_norm, name=name) + # 多了一个 x/sqrt K.l2_normalize ===== output = x / sqrt(max(sum(x**2), epsilon)) + + l2_input = K.sqrt(K.sum(K.maximum(K.square(input), 1e-12), axis=-1)) # x_inv_norm + fract_0 = K.sum(input * self.docs_vector, axis=-1) + fract_1 = l2_input * self.l2_docs_vector + cosine = fract_0 / fract_1 + y_pred_top_k, y_pred_ind_k = tf.nn.top_k(cosine, 10) + return [y_pred_top_k, y_pred_ind_k] + + def compute_output_shape(self, input_shape): + return [input_shape[0], input_shape[0]] + +``` + + +## 再次说明 + - 该方案适合的标准FAQ问答对数量不能太多 + + + \ No newline at end of file diff --git a/ChatBot/chatbot_search/chatbot_tfserving/TFServing_postprocess.py b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_postprocess.py new file mode 100644 index 0000000..2d6256d --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_postprocess.py @@ -0,0 +1,68 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/4/15 21:59 +# @author : Mo +# @function: postprocess of TFServing, 后处理 + +from __future__ import print_function, division, absolute_import, division, print_function + +# 适配linux +import sys +import os +path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "./.")) +sys.path.append(path_root) +from argparse import Namespace +import json + + +def load_json(path): + """ + 获取json,只取第一行 + :param path: str + :return: json + """ + with open(path, 'r', encoding='utf-8') as fj: + model_json = json.load(fj) + return model_json + + +# 字典 +from bertWhiteConf import bert_white_config +config = Namespace(**bert_white_config) +id2answer = load_json(os.path.join(config.save_dir, config.path_answers)) +id2doc = load_json(os.path.join(config.save_dir,config.path_docs)) + + +def postprocess(predictions): + """ 后处理 """ + predicts = predictions.get("predictions", {}) + token_ids = [] + for p in predicts: + doc_id = str(p.get("doc_id", "")) + score = p.get("score", "") + answer = id2answer.get(doc_id, "") + doc = id2doc.get(doc_id, "") + token_ids.append({"score": round(score, 6), "doc": doc, "answer": answer, "doc_id": doc_id}) + return {"instances": token_ids} + + +if __name__ == '__main__': + predictions = {"predictions": [ + { + "score": 0.922845, + "doc_id": 86 + }, + { + "score": 0.922845, + "doc_id": 104 + }, + { + "score": 0.891189814, + "doc_id": 101 + } + ]} + + + res = postprocess(predictions) + print(res) + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/TFServing_preprocess.py b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_preprocess.py new file mode 100644 index 0000000..8269178 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_preprocess.py @@ -0,0 +1,440 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/4/15 21:59 +# @author : Mo +# @function: encode of bert-whiteing + + +from __future__ import print_function, division, absolute_import, division, print_function + +# 适配linux +import sys +import os +path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "./.")) +sys.path.append(path_root) +print(path_root) + +from argparse import Namespace +import unicodedata, six, re + + +is_py2 = six.PY2 +if not is_py2: + basestring = str + + +def is_string(s): + """判断是否是字符串 + """ + return isinstance(s, basestring) + + +def load_vocab(dict_path, encoding='utf-8', simplified=False, startswith=None): + """从bert的词典文件中读取词典 + """ + token_dict = {} + with open(dict_path, encoding=encoding) as reader: + for line in reader: + token = line.strip() + token_dict[token] = len(token_dict) + + if simplified: # 过滤冗余部分token + new_token_dict, keep_tokens = {}, [] + startswith = startswith or [] + for t in startswith: + new_token_dict[t] = len(new_token_dict) + keep_tokens.append(token_dict[t]) + + for t, _ in sorted(token_dict.items(), key=lambda s: s[1]): + if t not in new_token_dict: + keep = True + if len(t) > 1: + for c in Tokenizer.stem(t): + if ( + Tokenizer._is_cjk_character(c) or + Tokenizer._is_punctuation(c) + ): + keep = False + break + if keep: + new_token_dict[t] = len(new_token_dict) + keep_tokens.append(token_dict[t]) + + return new_token_dict, keep_tokens + else: + return token_dict + + +class BasicTokenizer(object): + """分词器基类 + """ + def __init__(self, token_start='[CLS]', token_end='[SEP]'): + """初始化 + """ + self._token_pad = '[PAD]' + self._token_unk = '[UNK]' + self._token_mask = '[MASK]' + self._token_start = token_start + self._token_end = token_end + + def tokenize(self, text, max_length=None): + """分词函数 + """ + tokens = self._tokenize(text) + if self._token_start is not None: + tokens.insert(0, self._token_start) + if self._token_end is not None: + tokens.append(self._token_end) + + if max_length is not None: + index = int(self._token_end is not None) + 1 + self.truncate_sequence(max_length, tokens, None, -index) + + return tokens + + def token_to_id(self, token): + """token转换为对应的id + """ + raise NotImplementedError + + def tokens_to_ids(self, tokens): + """token序列转换为对应的id序列 + """ + return [self.token_to_id(token) for token in tokens] + + def truncate_sequence( + self, max_length, first_sequence, second_sequence=None, pop_index=-1 + ): + """截断总长度 + """ + if second_sequence is None: + second_sequence = [] + + while True: + total_length = len(first_sequence) + len(second_sequence) + if total_length <= max_length: + break + elif len(first_sequence) > len(second_sequence): + first_sequence.pop(pop_index) + else: + second_sequence.pop(pop_index) + + def encode( + self, + first_text, + second_text=None, + max_length=None, + first_length=None, + second_length=None + ): + """输出文本对应token id和segment id + 如果传入first_length,则强行padding第一个句子到指定长度; + 同理,如果传入second_length,则强行padding第二个句子到指定长度。 + """ + if is_string(first_text): + first_tokens = self.tokenize(first_text) + else: + first_tokens = first_text + + if second_text is None: + second_tokens = None + elif is_string(second_text): + idx = int(bool(self._token_start)) + second_tokens = self.tokenize(second_text)[idx:] + else: + second_tokens = second_text + + if max_length is not None: + self.truncate_sequence(max_length, first_tokens, second_tokens, -2) + + first_token_ids = self.tokens_to_ids(first_tokens) + if first_length is not None: + first_token_ids = first_token_ids[:first_length] + first_token_ids.extend([self._token_pad_id] * + (first_length - len(first_token_ids))) + first_segment_ids = [0] * len(first_token_ids) + + if second_text is not None: + second_token_ids = self.tokens_to_ids(second_tokens) + if second_length is not None: + second_token_ids = second_token_ids[:second_length] + second_token_ids.extend([self._token_pad_id] * + (second_length - len(second_token_ids))) + second_segment_ids = [1] * len(second_token_ids) + + first_token_ids.extend(second_token_ids) + first_segment_ids.extend(second_segment_ids) + + return first_token_ids, first_segment_ids + + def id_to_token(self, i): + """id序列为对应的token + """ + raise NotImplementedError + + def ids_to_tokens(self, ids): + """id序列转换为对应的token序列 + """ + return [self.id_to_token(i) for i in ids] + + def decode(self, ids): + """转为可读文本 + """ + raise NotImplementedError + + def _tokenize(self, text): + """基本分词函数 + """ + raise NotImplementedError + + +class Tokenizer(BasicTokenizer): + """Bert原生分词器 + 纯Python实现,代码修改自keras_bert的tokenizer实现 + """ + def __init__(self, token_dict, do_lower_case=False, *args, **kwargs): + """初始化 + """ + super(Tokenizer, self).__init__(*args, **kwargs) + if is_string(token_dict): + token_dict = load_vocab(token_dict) + + self._do_lower_case = do_lower_case + self._token_dict = token_dict + self._token_dict_inv = {v: k for k, v in token_dict.items()} + self._vocab_size = len(token_dict) + + for token in ['pad', 'unk', 'mask', 'start', 'end']: + try: + _token_id = token_dict[getattr(self, '_token_%s' % token)] + setattr(self, '_token_%s_id' % token, _token_id) + except: + pass + + def token_to_id(self, token): + """token转换为对应的id + """ + return self._token_dict.get(token, self._token_unk_id) + + def id_to_token(self, i): + """id转换为对应的token + """ + return self._token_dict_inv[i] + + def decode(self, ids, tokens=None): + """转为可读文本 + """ + tokens = tokens or self.ids_to_tokens(ids) + tokens = [token for token in tokens if not self._is_special(token)] + + text, flag = '', False + for i, token in enumerate(tokens): + if token[:2] == '##': + text += token[2:] + elif len(token) == 1 and self._is_cjk_character(token): + text += token + elif len(token) == 1 and self._is_punctuation(token): + text += token + text += ' ' + elif i > 0 and self._is_cjk_character(text[-1]): + text += token + else: + text += ' ' + text += token + + text = re.sub(' +', ' ', text) + text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text) + punctuation = self._cjk_punctuation() + '+-/={(<[' + punctuation_regex = '|'.join([re.escape(p) for p in punctuation]) + punctuation_regex = '(%s) ' % punctuation_regex + text = re.sub(punctuation_regex, '\\1', text) + text = re.sub('(\d\.) (\d)', '\\1\\2', text) + + return text.strip() + + def _tokenize(self, text): + """基本分词函数 + """ + if self._do_lower_case: + if is_py2: + text = unicode(text) + text = text.lower() + text = unicodedata.normalize('NFD', text) + text = ''.join([ + ch for ch in text if unicodedata.category(ch) != 'Mn' + ]) + + spaced = '' + for ch in text: + if self._is_punctuation(ch) or self._is_cjk_character(ch): + spaced += ' ' + ch + ' ' + elif self._is_space(ch): + spaced += ' ' + elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): + continue + else: + spaced += ch + + tokens = [] + for word in spaced.strip().split(): + tokens.extend(self._word_piece_tokenize(word)) + + return tokens + + def _word_piece_tokenize(self, word): + """word内分成subword + """ + if word in self._token_dict: + return [word] + + tokens = [] + start, stop = 0, 0 + while start < len(word): + stop = len(word) + while stop > start: + sub = word[start:stop] + if start > 0: + sub = '##' + sub + if sub in self._token_dict: + break + stop -= 1 + if start == stop: + stop += 1 + tokens.append(sub) + start = stop + + return tokens + + @staticmethod + def stem(token): + """获取token的“词干”(如果是##开头,则自动去掉##) + """ + if token[:2] == '##': + return token[2:] + else: + return token + + @staticmethod + def _is_space(ch): + """空格类字符判断 + """ + return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \ + unicodedata.category(ch) == 'Zs' + + @staticmethod + def _is_punctuation(ch): + """标点符号类字符判断(全/半角均在此内) + 提醒:unicodedata.category这个函数在py2和py3下的 + 表现可能不一样,比如u'§'字符,在py2下的结果为'So', + 在py3下的结果是'Po'。 + """ + code = ord(ch) + return 33 <= code <= 47 or \ + 58 <= code <= 64 or \ + 91 <= code <= 96 or \ + 123 <= code <= 126 or \ + unicodedata.category(ch).startswith('P') + + @staticmethod + def _cjk_punctuation(): + return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002' + + @staticmethod + def _is_cjk_character(ch): + """CJK类字符判断(包括中文字符也在此列) + 参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + """ + code = ord(ch) + return 0x4E00 <= code <= 0x9FFF or \ + 0x3400 <= code <= 0x4DBF or \ + 0x20000 <= code <= 0x2A6DF or \ + 0x2A700 <= code <= 0x2B73F or \ + 0x2B740 <= code <= 0x2B81F or \ + 0x2B820 <= code <= 0x2CEAF or \ + 0xF900 <= code <= 0xFAFF or \ + 0x2F800 <= code <= 0x2FA1F + + @staticmethod + def _is_control(ch): + """控制类字符判断 + """ + return unicodedata.category(ch) in ('Cc', 'Cf') + + @staticmethod + def _is_special(ch): + """判断是不是有特殊含义的符号 + """ + return bool(ch) and (ch[0] == '[') and (ch[-1] == ']') + + def rematch(self, text, tokens): + """给出原始的text和tokenize后的tokens的映射关系 + """ + if is_py2: + text = unicode(text) + + if self._do_lower_case: + text = text.lower() + + normalized_text, char_mapping = '', [] + for i, ch in enumerate(text): + if self._do_lower_case: + ch = unicodedata.normalize('NFD', ch) + ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn']) + ch = ''.join([ + c for c in ch + if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c)) + ]) + normalized_text += ch + char_mapping.extend([i] * len(ch)) + + text, token_mapping, offset = normalized_text, [], 0 + for token in tokens: + if self._is_special(token): + token_mapping.append([]) + else: + token = self.stem(token) + start = text[offset:].index(token) + offset + end = start + len(token) + token_mapping.append(char_mapping[start:end]) + offset = end + + return token_mapping + + +# 超参数可配置 +# dict_path = "bert_white/vocab.txt" # bert字典 +# maxlen = 128 + +# 或者是把 token_dict字典 放到py文件里边 + +from bertWhiteConf import bert_white_config +config = Namespace(**bert_white_config) + + +tokenizer = Tokenizer(os.path.join(config.bert_dir, config.dict_path), do_lower_case=True) +text = "你还会什么" +token_id = tokenizer.encode(text, max_length=config.maxlen) +print(token_id) + + +def covert_text_to_id(data_input): + """ 将文本转为BERT需要的 ids """ + data = data_input.get("data", {}) + token_ids = [] + for d in data: + text = d.get("text", "") + token_id = tokenizer.encode(text, max_length=config.maxlen) + token_ids.append({"Input-Token": token_id[0], "Input-Segment": token_id[1]}) + return {"instances": token_ids} + + +if __name__ == '__main__': + data_input = {"data": [{"text": "你是谁呀"}, {"text": "你叫什么"}, {"text": "你好"}]} + res = covert_text_to_id(data_input) + print(res) + +# {"instances": [{"Input-Token": [101, 872, 3221, 6443, 1435, 102], "Input-Segment": [0, 0, 0, 0, 0, 0]}, +# {"Input-Token": [101, 872, 1373, 784, 720, 102], "Input-Segment": [0, 0, 0, 0, 0, 0]}, +# {"Input-Token": [101, 872, 1962, 102], "Input-Segment": [0, 0, 0, 0]}]} + + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/TFServing_save.py b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_save.py new file mode 100644 index 0000000..90f0aa1 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_save.py @@ -0,0 +1,301 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/4/15 21:59 +# @author : Mo +# @function: encode of bert-whiteing + + +from __future__ import print_function, division, absolute_import, division, print_function + +# 适配linux +import sys +import os +path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "./.")) +sys.path.append(path_root) +# os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +print(path_root) + + +from bert4keras.models import build_transformer_model +from bert4keras.snippets import sequence_padding +from bert4keras.tokenizers import Tokenizer +from bert4keras.backend import keras, K +from bert4keras.layers import Multiply +from keras.models import Model +import tensorflow as tf + +from argparse import Namespace +# from tqdm import tqdm +import pandas as pd +import numpy as np +import shutil +import json +import time + +# shutil.rmtree() + + +class NonMaskingLayer(keras.layers.Layer): + """ 去除MASK层 + fix convolutional 1D can"t receive masked input, detail: https://github.com/keras-team/keras/issues/4978 + thanks for https://github.com/jacoxu + """ + + def __init__(self, **kwargs): + self.supports_masking = True + super(NonMaskingLayer, self).__init__(**kwargs) + + def build(self, input_shape): + pass + + def compute_mask(self, input, input_mask=None): + # do not pass the mask to the next layers + return None + + def call(self, x, mask=None): + return x + + def get_output_shape_for(self, input_shape): + return input_shape + + +class CosineLayer(keras.layers.Layer): + def __init__(self, docs_encode, **kwargs): + """ + 余弦相似度层, 不适合大规模语料, 比如100w以上的问答对 + :param docs_encode: np.array, bert-white vector of senence + :param kwargs: + """ + self.docs_encode = docs_encode + super(CosineLayer, self).__init__(**kwargs) + self.docs_vector = K.constant(self.docs_encode, dtype="float32") + self.l2_docs_vector = K.sqrt(K.sum(K.maximum(K.square(self.docs_vector), 1e-12), axis=-1)) # x_inv_norm + + def build(self, input_shape): + super(CosineLayer, self).build(input_shape) + + def get_config(self): + # 防止报错 'NoneType' object has no attribute '_inbound_nodes' + config = {"docs_vector": self.docs_vector, + "l2_docs_vector": self.l2_docs_vector} + base_config = super(CosineLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def call(self, input): + # square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) + # x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)) + # return math_ops.multiply(x, x_inv_norm, name=name) + # 多了一个 x/sqrt K.l2_normalize ===== output = x / sqrt(max(sum(x**2), epsilon)) + l2_input = K.sqrt(K.sum(K.maximum(K.square(input), 1e-12), axis=-1)) # x_inv_norm + fract_0 = K.sum(input * self.docs_vector, axis=-1) + fract_1 = l2_input * self.l2_docs_vector + cosine = fract_0 / fract_1 + y_pred_top_k, y_pred_ind_k = tf.nn.top_k(cosine, 10) + return [y_pred_top_k, y_pred_ind_k] + + def compute_output_shape(self, input_shape): + return [input_shape[0], input_shape[0]] + + +class Divide(Multiply): + """相除 + Divide, Layer that divide a list of inputs. + + It takes as input a list of tensors, + all of the same shape, and returns + a single tensor (also of the same shape). + """ + def _merge_function(self, inputs): + output = inputs[0] + for i in range(1, len(inputs)): + output /= inputs[i] + return output + + +class BertSimModel: + def __init__(self, config=None): + """ 初始化超参数、加载预训练模型等 """ + self.config = Namespace(**config) + self.load_pretrain_model() + self.eps = 1e-8 + + def transform_and_normalize(self, vecs, kernel=None, bias=None): + """应用变换,然后标准化 + """ + if not (kernel is None or bias is None): + vecs = (vecs + bias).dot(kernel) + norms = (vecs ** 2).sum(axis=1, keepdims=True) ** 0.5 + return vecs / np.clip(norms, self.eps, np.inf) + + def compute_kernel_bias(self, vecs): + """计算kernel和bias + 最后的变换:y = (x + bias).dot(kernel) + """ + mu = vecs.mean(axis=0, keepdims=True) + cov = np.cov(vecs.T) + u, s, vh = np.linalg.svd(cov) + W = np.dot(u, np.diag(1 / np.sqrt(s))) + return W[:, :self.config.n_components], -mu + + def convert_to_vecs(self, texts): + """转换文本数据为向量形式 + """ + + token_ids = self.convert_to_ids(texts) + vecs = self.bert_white_encoder.predict(x=[token_ids, np.zeros_like(token_ids)], + batch_size=self.config.batch_size, verbose=self.config.verbose) + return vecs + + def convert_to_ids(self, texts): + """转换文本数据为id形式 + """ + token_ids = [] + for text in texts: + # token_id = self.tokenizer.encode(text, maxlen=self.config.maxlen)[0] + token_id = self.tokenizer.encode(text, max_length=self.config.maxlen)[0] + token_ids.append(token_id) + token_ids = sequence_padding(token_ids) + return token_ids + + def load_pretrain_model(self): + """ 加载预训练模型, 和tokenizer """ + self.tokenizer = Tokenizer(os.path.join(self.config.bert_dir, self.config.dict_path), do_lower_case=True) + # bert-load + if self.config.pooling == "pooler": + bert = build_transformer_model(os.path.join(self.config.bert_dir, self.config.config_path), + os.path.join(self.config.bert_dir, self.config.checkpoint_path), + model=self.config.model, with_pool="linear") + else: + bert = build_transformer_model(os.path.join(self.config.bert_dir, self.config.config_path), + os.path.join(self.config.bert_dir, self.config.checkpoint_path), + model=self.config.model) + # output-layers + outputs, count = [], 0 + while True: + try: + output = bert.get_layer("Transformer-%d-FeedForward-Norm" % count).output + outputs.append(output) + count += 1 + except: + break + # pooling + if self.config.pooling == "first-last-avg": + outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]] + outputs = [keras.layers.GlobalAveragePooling1D()(fs) for fs in outputs] + output = keras.layers.Average()(outputs) + elif self.config.pooling == "first-last-max": + outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]] + outputs = [keras.layers.GlobalMaxPooling1D()(fs) for fs in outputs] + output = keras.layers.Average()(outputs) + elif self.config.pooling == "cls-max-avg": + outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]] + outputs_cls = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in outputs] + outputs_max = [keras.layers.GlobalMaxPooling1D()(fs) for fs in outputs] + outputs_avg = [keras.layers.GlobalAveragePooling1D()(fs) for fs in outputs] + output = keras.layers.Concatenate()(outputs_cls + outputs_avg) + elif self.config.pooling == "last-avg": + output = keras.layers.GlobalAveragePooling1D()(outputs[-1]) + elif self.config.pooling == "cls-3": + outputs = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in [outputs[0], outputs[-1], outputs[-2]]] + output = keras.layers.Concatenate()(outputs) + elif self.config.pooling == "cls-2": + outputs = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in [outputs[0], outputs[-1]]] + output = keras.layers.Concatenate()(outputs) + elif self.config.pooling == "cls-1": + output = keras.layers.Lambda(lambda x: x[:, 0])(outputs[-1]) + elif self.config.pooling == "pooler": + output = bert.output + # 加载句FAQ标准问的句向量, 并当成一个常量参与余弦相似度的计算 + docs_encode = np.loadtxt(os.path.join(self.config.save_dir, self.config.path_docs_encode)) + # 余弦相似度的层 + score_cosine = CosineLayer(docs_encode)(output) + # 最后的编码器 + self.bert_white_encoder = Model(bert.inputs, score_cosine) + print("load bert_white_encoder success!") + + def save_model_builder(self): + """ + 存储为tf-serving的形式 + """ + builder = tf.saved_model.Builder(self.config.path_tfserving) + signature_def_map = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + tf.saved_model.build_signature_def( + # 根据自己模型的要求 + inputs={"Input-Token": tf.saved_model.build_tensor_info(self.bert_white_encoder.input[0]), + "Input-Segment": tf.saved_model.build_tensor_info(self.bert_white_encoder.input[1])}, + outputs={"score": tf.saved_model.build_tensor_info(self.bert_white_encoder.output[0]), + "doc_id": tf.saved_model.build_tensor_info(self.bert_white_encoder.output[1])}, + method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME + )} + builder.add_meta_graph_and_variables(keras.backend.get_session(), # 注意4 + [tf.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map, + # 初始化操作,我的不需要,否则报错 + # legacy_init_op=tf.group(tf.tables_initializer(), name='legacy_init_op') + ) + builder.save() + + def train(self, texts): + """ + 训练 + """ + print("读取文本数:".format(len(texts))) + print(texts[:3]) + # 文本转成向量vecs + vecs = self.convert_to_vecs(texts) + # 训练, 计算变换矩阵和偏置项 + self.config.kernel, self.config.bias = self.compute_kernel_bias(vecs) + if self.config.ues_white: + # 生成白化后的句子, 即qa对中的q + vecs = self.transform_and_normalize(vecs, self.config.kernel, self.config.bias) + return vecs + + def prob(self, texts): + """ + 编码、白化后的向量 + """ + vecs_encode = self.convert_to_vecs(texts) + if self.config.ues_white: + vecs_encode = self.transform_and_normalize(vecs=vecs_encode, kernel=self.config.kernel, bias=self.config.bias) + return vecs_encode + + +if __name__ == '__main__': + # 存储模型等 + from bertWhiteConf import bert_white_config + + bert_white_model = BertSimModel(bert_white_config) + bert_white_model.load_pretrain_model() + bert_white_model.save_model_builder() + + + from bertWhiteConf import bert_white_config + config = Namespace(**bert_white_config) + tokenizer = Tokenizer(os.path.join(config.bert_dir, config.dict_path), do_lower_case=True) + text = "你还会什么" + token_id = tokenizer.encode(text, max_length=config.maxlen) + print(token_id) + + +""" +# cpu +docker run -t --rm -p 8532:8501 -v "/TF-SERVING/chatbot_tf:/models/chatbot_tf" -e MODEL_NAME=chatbot_tf tensorflow/serving:latest + +# gpu +docker run --runtime=nvidia -p 8532:8501 -v "/TF-SERVING/chatbot_tf:/models/chatbot_tf" -e MODEL_NAME=chatbot_tf tensorflow/serving:1.14.0-gpu + +# remarks +batch-size还可以配置batch.cfg等文件 + +# health testing +curl http://127.0.0.1:8532/v1/models/chatbot_tf + +# http test, 不行可以用postman测试 +curl -d '{"instances": [{"Input-Token": [2, 870, 6818, 831, 782, 718, 3], "Input-Segment": [0, 0, 0, 0, 0, 0, 0]}]}' -X POST http://localhost:8532/v1/models/chatbot_tf:predict + +""" + + +# python bertWhiteTFServing.py + + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/TFServing_tet_http.py b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_tet_http.py new file mode 100644 index 0000000..735da05 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/TFServing_tet_http.py @@ -0,0 +1,52 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/9/17 21:28 +# @author : Mo +# @function: + + +from __future__ import print_function, division, absolute_import, division, print_function + +# 适配linux +import sys +import os +path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "./.")) +sys.path.append(path_root) +from argparse import Namespace +import requests +import json + + +from TFServing_preprocess import covert_text_to_id +from TFServing_postprocess import postprocess + + +def qa_tfserving(data_input, url): + """ tf-serving 一整套流程 """ + bert_input = covert_text_to_id(data_input) + data = json.dumps(bert_input) + r = requests.post(url, data) + r_text_json = json.loads(r.text) + r_post = postprocess(r_text_json) + return r_post + + +if __name__ == '__main__': + data_input = {"data": [{"text": "别逗小通了!可怜的"}]} + url = "http://192.168.1.97:8532/v1/models/chatbot_tf:predict" + res = qa_tfserving(data_input, url) + print(res) + + + import os, inspect + current_path = inspect.getfile(inspect.currentframe()) + path_root = "/".join(current_path.split("/")[:-1]) + print(path_root) + print(current_path) + print(inspect.currentframe()) + + + + + + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/__init__.py b/ChatBot/chatbot_search/chatbot_tfserving/__init__.py new file mode 100644 index 0000000..1763874 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/__init__.py @@ -0,0 +1,7 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/5/13 21:21 +# @author : Mo +# @function: + + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteConf.py b/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteConf.py new file mode 100644 index 0000000..b6f10f1 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteConf.py @@ -0,0 +1,68 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/5/13 9:27 +# @author : Mo +# @function: config of Bert-White + + +import platform +# 适配linux +import sys +import os +# path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +path_root = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(path_root) +print(path_root) + + +if platform.system().lower() == 'windows': + # BERT_DIR = "D:/soft_install/dataset/bert-model/chinese_L-12_H-768_A-12" + # BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roberta_L-4_H-312_A-12_K-104" + # BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_roberta_L-6_H-384_A-12_K-128" + BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-4_H-312_A-12" + # BERT_DIR = "D:/soft_install/dataset/bert-model/zuiyi/chinese_simbert_L-6_H-384_A-12" +else: + BERT_DIR = "bert/chinese_L-12_H-768_A-12" + ee = 0 + +SAVE_DIR = path_root + "/bert_white" +print(SAVE_DIR) +if not os.path.exists(SAVE_DIR): + os.makedirs(SAVE_DIR) + + +bert_white_config = { +# 预训练模型路径 +"bert_dir": BERT_DIR, +"checkpoint_path": "bert_model.ckpt", # 预训练模型地址 +"config_path": "bert_config.json", +"dict_path": "vocab.txt", +# 预测需要的文件路径 +"save_dir": SAVE_DIR, +"path_tfserving": "chatbot_tfserving/1", +"path_docs_encode": "qa.docs.encode.npy", +"path_answers": "qa.answers.json", +"path_qa_idx": "qa.idx.json", +"path_config": "config.json", +"path_docs": "qa.docs.json", +# 索引构建的存储文件, 如 annoy/faiss +"path_index": "qa.docs.idx", +# 初始语料路径 +"path_qa": "chicken_and_gossip.txt", # QA问答文件地址 +# 超参数 +"pre_tokenize": None, +"pooling": "cls-1", # ["first-last-avg", "last-avg", "cls", "pooler", "cls-2", "cls-3", "cls-1"] +"model": "bert", # bert4keras预训练模型类型 +"n_components": 768, # 降维到 n_components +"n_cluster": 132, # annoy构建的簇类中心个数n_cluster, 越多效果越好, 计算量就越大 +"batch_size": 32, # 批尺寸 +"maxlen": 128, # 最大文本长度 +"ues_white": False, # 是否使用白化 +"use_annoy": False, # 是否使用annoy +"use_faiss": False, # 是否使用faiss +"verbose": True, # 是否显示编码过程日志-batch + +"kernel": None, # bert-white编码后的参数, 可降维 +"bias": None, # bert-white编码后的参数, 偏置bias +"qa_idx": None # 问题question到答案answer的id对应关系 +} diff --git a/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTools.py b/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTools.py new file mode 100644 index 0000000..7c8c202 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTools.py @@ -0,0 +1,76 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/5/13 21:24 +# @author : Mo +# @function: + + +from typing import List, Dict, Union, Any +import logging as logger +import json + + +def txt_read(path: str, encoding: str = "utf-8") -> List[str]: + """ + Read Line of list form file + Args: + path: path of save file, such as "txt" + encoding: type of encoding, such as "utf-8", "gbk" + Returns: + dict of word2vec, eg. {"macadam":[...]} + """ + + lines = [] + try: + file = open(path, "r", encoding=encoding) + while True: + line = file.readline().strip() + if not line: + break + lines.append(line) + file.close() + except Exception as e: + logger.info(str(e)) + finally: + return lines + + +def txt_write(lines: List[str], path: str, model: str = "w", encoding: str = "utf-8"): + """ + Write Line of list to file + Args: + lines: lines of list which need save + path: path of save file, such as "txt" + model: type of write, such as "w", "a+" + encoding: type of encoding, such as "utf-8", "gbk" + """ + try: + file = open(path, model, encoding=encoding) + file.writelines(lines) + file.close() + except Exception as e: + logger.info(str(e)) + + +def save_json(jsons, json_path, indent=4): + """ + 保存json, + :param json_: json + :param path: str + :return: None + """ + with open(json_path, 'w', encoding='utf-8') as fj: + fj.write(json.dumps(jsons, ensure_ascii=False, indent=indent)) + fj.close() + + +def load_json(path): + """ + 获取json,只取第一行 + :param path: str + :return: json + """ + with open(path, 'r', encoding='utf-8') as fj: + model_json = json.load(fj) + return model_json + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTrain.py b/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTrain.py new file mode 100644 index 0000000..00f6aa0 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/bertWhiteTrain.py @@ -0,0 +1,374 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/4/15 21:59 +# @author : Mo +# @function: encode of bert-whiteing + + +from __future__ import print_function, division, absolute_import, division, print_function + +# 适配linux +import sys +import os +path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "./.")) +sys.path.append(path_root) +# os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +print(path_root) + + +from bertWhiteTools import txt_read, txt_write, save_json, load_json + +from bert4keras.models import build_transformer_model +from bert4keras.snippets import sequence_padding +from bert4keras.tokenizers import Tokenizer +from bert4keras.backend import keras, K +from keras.models import Model +import tensorflow as tf + +from argparse import Namespace +# from tqdm import tqdm +import pandas as pd +import numpy as np +import json +import time + + +class NonMaskingLayer(keras.layers.Layer): + """ 去除MASK层 + fix convolutional 1D can"t receive masked input, detail: https://github.com/keras-team/keras/issues/4978 + thanks for https://github.com/jacoxu + """ + + def __init__(self, **kwargs): + self.supports_masking = True + super(NonMaskingLayer, self).__init__(**kwargs) + + def build(self, input_shape): + pass + + def compute_mask(self, input, input_mask=None): + # do not pass the mask to the next layers + return None + + def call(self, x, mask=None): + return x + + def get_output_shape_for(self, input_shape): + return input_shape + + +class BertWhiteModel: + def __init__(self, config=None): + """ 初始化超参数、加载预训练模型等 """ + self.config = Namespace(**config) + self.load_pretrain_model() + self.eps = 1e-8 + + def transform_and_normalize(self, vecs, kernel=None, bias=None): + """应用变换,然后标准化 + """ + if not (kernel is None or bias is None): + vecs = (vecs + bias).dot(kernel) + norms = (vecs ** 2).sum(axis=1, keepdims=True) ** 0.5 + return vecs / np.clip(norms, self.eps, np.inf) + + def compute_kernel_bias(self, vecs): + """计算kernel和bias + 最后的变换:y = (x + bias).dot(kernel) + """ + mu = vecs.mean(axis=0, keepdims=True) + cov = np.cov(vecs.T) + u, s, vh = np.linalg.svd(cov) + W = np.dot(u, np.diag(1 / np.sqrt(s))) + return W[:, :self.config.n_components], -mu + + def convert_to_vecs(self, texts): + """转换文本数据为向量形式 + """ + + token_ids = self.convert_to_ids(texts) + vecs = self.bert_white_encoder.predict(x=[token_ids, np.zeros_like(token_ids)], + batch_size=self.config.batch_size, verbose=self.config.verbose) + return vecs + + def convert_to_ids(self, texts): + """转换文本数据为id形式 + """ + token_ids = [] + for text in texts: + # token_id = self.tokenizer.encode(text, maxlen=self.config.maxlen)[0] + token_id = self.tokenizer.encode(text, max_length=self.config.maxlen)[0] + token_ids.append(token_id) + token_ids = sequence_padding(token_ids) + return token_ids + + def load_pretrain_model(self): + """ 加载预训练模型, 和tokenizer """ + self.tokenizer = Tokenizer(os.path.join(self.config.bert_dir, self.config.dict_path), do_lower_case=True) + # bert-load + if self.config.pooling == "pooler": + bert = build_transformer_model(os.path.join(self.config.bert_dir, self.config.config_path), + os.path.join(self.config.bert_dir, self.config.checkpoint_path), + model=self.config.model, with_pool="linear") + else: + bert = build_transformer_model(os.path.join(self.config.bert_dir, self.config.config_path), + os.path.join(self.config.bert_dir, self.config.checkpoint_path), + model=self.config.model) + # output-layers + outputs, count = [], 0 + while True: + try: + output = bert.get_layer("Transformer-%d-FeedForward-Norm" % count).output + outputs.append(output) + count += 1 + except: + break + # pooling + if self.config.pooling == "first-last-avg": + outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]] + outputs = [keras.layers.GlobalAveragePooling1D()(fs) for fs in outputs] + output = keras.layers.Average()(outputs) + elif self.config.pooling == "first-last-max": + outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]] + outputs = [keras.layers.GlobalMaxPooling1D()(fs) for fs in outputs] + output = keras.layers.Average()(outputs) + elif self.config.pooling == "cls-max-avg": + outputs = [NonMaskingLayer()(output_i) for output_i in [outputs[0], outputs[-1]]] + outputs_cls = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in outputs] + outputs_max = [keras.layers.GlobalMaxPooling1D()(fs) for fs in outputs] + outputs_avg = [keras.layers.GlobalAveragePooling1D()(fs) for fs in outputs] + output = keras.layers.Concatenate()(outputs_cls + outputs_avg) + elif self.config.pooling == "last-avg": + output = keras.layers.GlobalAveragePooling1D()(outputs[-1]) + elif self.config.pooling == "cls-3": + outputs = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in [outputs[0], outputs[-1], outputs[-2]]] + output = keras.layers.Concatenate()(outputs) + elif self.config.pooling == "cls-2": + outputs = [keras.layers.Lambda(lambda x: x[:, 0])(fs) for fs in [outputs[0], outputs[-1]]] + output = keras.layers.Concatenate()(outputs) + elif self.config.pooling == "cls-1": + output = keras.layers.Lambda(lambda x: x[:, 0])(outputs[-1]) + elif self.config.pooling == "pooler": + output = bert.output + # 最后的编码器 + self.bert_white_encoder = Model(bert.inputs, output) + print("load bert_white_encoder success!" ) + + def train(self, texts): + """ + 训练 + """ + print("读取文本数:".format(len(texts))) + print(texts[:3]) + # 文本转成向量vecs + vecs = self.convert_to_vecs(texts) + # 训练, 计算变换矩阵和偏置项 + self.config.kernel, self.config.bias = self.compute_kernel_bias(vecs) + if self.config.ues_white: + # 生成白化后的句子, 即qa对中的q + vecs = self.transform_and_normalize(vecs, self.config.kernel, self.config.bias) + return vecs + + def prob(self, texts): + """ + 编码、白化后的向量 + """ + vecs_encode = self.convert_to_vecs(texts) + if self.config.ues_white: + vecs_encode = self.transform_and_normalize(vecs=vecs_encode, kernel=self.config.kernel, bias=self.config.bias) + return vecs_encode + + +class BertWhiteFit: + def __init__(self, config): + # 训练 + self.bert_white_model = BertWhiteModel(config) + self.config = Namespace(**config) + self.docs = [] + + def load_bert_white_model(self, path_config): + """ 模型, 超参数加载 """ + # 超参数加载 + config = load_json(path_config) + # bert等加载 + self.bert_white_model = BertWhiteModel(config) + self.config = Namespace(**config) + # 白化超参数初始化 + self.bert_white_model.config.kernel = np.array(self.bert_white_model.config.kernel) + self.bert_white_model.config.bias = np.array(self.bert_white_model.config.bias) + # 加载qa文本数据 + self.answers_dict = load_json(os.path.join(self.config.save_dir, self.config.path_answers)) + self.docs_dict = load_json(os.path.join(self.config.save_dir, self.config.path_docs)) + self.qa_idx = load_json(os.path.join(self.config.save_dir, self.config.path_qa_idx)) + + # 加载问题question预训练语言模型bert编码、白化后的encode向量 + self.docs_encode = np.loadtxt(os.path.join(self.config.save_dir, self.config.path_docs_encode)) + # index of vector + if self.config.use_annoy or self.config.use_faiss: + from indexAnnoy import AnnoySearch + self.annoy_model = AnnoySearch(dim=self.config.n_components, n_cluster=self.config.n_cluster) + self.annoy_model.load(os.path.join(self.config.save_dir, self.config.path_index)) + else: + self.docs_encode_norm = np.linalg.norm(self.docs_encode, axis=1) + print("load_bert_white_model success!") + + def read_qa_from_csv(self, sep="\t"): + """ + 从csv文件读取QA对 + """ + # ques_answer = txt_read(os.path.join(self.config.save_dir, self.config.path_qa)) # common qa, sep="\t" + ques_answer = txt_read(self.config.path_qa) + self.answers_dict = {} + self.docs_dict = {} + self.qa_idx = {} + count = 0 + for i in range(len(ques_answer)): + count += 1 + if count > 320: + break + ques_answer_sp = ques_answer[i].strip().split(sep) + if len(ques_answer_sp) != 2: + print(ques_answer[i]) + continue + question = ques_answer_sp[0] + answer = ques_answer_sp[1] + self.qa_idx[str(i)] = i + self.docs_dict[str(i)] = question.replace("\n", "").strip() + self.answers_dict[str(i)] = answer.replace("\n", "").strip() + self.bert_white_model.config.qa_idx = self.qa_idx + + def build_index(self, vectors): + """ 构建索引, annoy 或者 faiss """ + if self.config.use_annoy: + from indexAnnoy import AnnoySearch as IndexSearch + elif self.config.use_faiss: + from indexFaiss import FaissSearch as IndexSearch + self.index_model= IndexSearch(dim=self.config.n_components, n_cluster=self.config.n_cluster) + self.index_model.fit(vectors) + self.index_model.save(os.path.join(self.config.save_dir, self.config.path_index)) + print("build index") + + def load_index(self): + """ 加载索引, annoy 或者 faiss """ + if self.config.use_annoy: + from indexAnnoy import AnnoySearch as IndexSearch + elif self.config.use_faiss: + from indexFaiss import FaissSearch as IndexSearch + self.index_model = IndexSearch(dim=self.config.n_components, n_cluster=self.config.n_cluster) + self.index_model.load(self.config.path_index) + + def remove_index(self, ids): + self.index_model.remove(np.array(ids)) + + def predict_with_mmr(self, texts, topk=12): + """ 维护匹配问题的多样性 """ + from mmr import MMRSum + + res = bwf.predict(texts, topk) + mmr_model = MMRSum() + result = [] + for r in res: + # 维护一个 sim:dict字典存储 + r_dict = {ri.get("sim"):ri for ri in r} + r_mmr = mmr_model.summarize(text=[ri.get("sim") for ri in r], num=8, alpha=0.6) + r_dict_mmr = [r_dict[rm[1]] for rm in r_mmr] + result.append(r_dict_mmr) + return result + + def predict(self, texts, topk=12): + """ 预训练模型bert等编码,白化, 获取这一批数据的kernel和bias""" + texts_encode = self.bert_white_model.prob(texts) + result = [] + if self.config.use_annoy or self.config.use_faiss: + index_tops = self.index_model.k_neighbors(vectors=texts_encode, k=topk) + if self.config.use_annoy: + for i, index_top in enumerate(index_tops): + [dist, idx] = index_top + res = [] + for j, id in enumerate(idx): + score = float((2 - (dist[j] ** 2)) / 2) + res_i = {"score": score, "text": texts[i], "sim": self.docs_dict[str(id)], + "answer": self.answers_dict[str(id)]} + res.append(res_i) + result.append(res) + else: + distances, indexs = index_tops + for i in range(len(distances)): + res = [] + for j in range(len(distances[i])): + score = distances[i][j] + id = indexs[i][j] + id = id if id != -1 else len(self.docs_dict) - 1 + res_i = {"score": score, "text": texts[i], "sim": self.docs_dict[str(id)], + "answer": self.answers_dict[str(id)]} + res.append(res_i) + result.append(res) + else: + for i, te in enumerate(texts_encode): + # scores = np.matmul(texts_encode, self.docs_encode_reshape) + facot_1 = te * self.docs_encode + te_norm = np.linalg.norm(te) + facot_2 = te_norm * self.docs_encode_norm + score = np.sum(facot_1, axis=1) / (facot_2 + 1e-9) + idxs = np.argsort(score)[::-1] + res = [] + for j in idxs[:topk]: + res_i = {"score": float(score[j]), "text": texts[i], "sim": self.docs_dict[str(j)], + "answer": self.answers_dict[str(j)]} + res.append(res_i) + result.append(res) + return result + + def trainer(self): + """ 预训练模型bert等编码,白化, 获取这一批数据的kernel和bias """ + # 加载数据 + self.read_qa_from_csv() + # bert编码、训练 + self.docs_encode = self.bert_white_model.train([self.docs_dict.get(str(i), "") for i in range(len(self.docs_dict))]) + self.bert_white_model.config.kernel = self.bert_white_model.config.kernel.tolist() + self.bert_white_model.config.bias = self.bert_white_model.config.bias.tolist() + # 存储qa文本数据 + save_json(self.bert_white_model.config.qa_idx, os.path.join(self.config.save_dir, self.config.path_qa_idx)) + save_json(self.answers_dict, os.path.join(self.config.save_dir, self.config.path_answers)) + save_json(self.docs_dict, os.path.join(self.config.save_dir, self.config.path_docs)) + # 包括超参数等 + save_json(vars(self.bert_white_model.config), os.path.join(self.config.save_dir, self.config.path_config)) + # 存储问题question预训练语言模型bert编码、白化后的encode向量 + np.savetxt(os.path.join(self.config.save_dir, self.config.path_docs_encode), self.docs_encode) + # 索引 或者 正则化 + if self.config.use_annoy or self.config.use_faiss: + self.build_index(self.docs_encode.astype(np.float32)) + else: + self.docs_encode_norm = np.linalg.norm(self.docs_encode, axis=1) + print(" bert-white-trainer success! ") + + + +if __name__ == '__main__': + # 训练并存储 + from bertWhiteConf import bert_white_config + bwf = BertWhiteFit(config=bert_white_config) + bwf.trainer() + + texts = ["小姜机器人", "你叫什么名字"] + res = bwf.predict(texts) + print(res) + res_mmr = bwf.predict_with_mmr(texts) + print(res_mmr) + + # bwf.index_model.remove([i for i in range(100)]) + + while True: + print("请输入:") + ques = input() + res_mmr = bwf.predict_with_mmr(texts) + print(res_mmr) + res = bwf.predict([ques]) + print(res) + + + +# python bertWhiteTrain.py + + + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/chicken_and_gossip.txt b/ChatBot/chatbot_search/chatbot_tfserving/chicken_and_gossip.txt new file mode 100644 index 0000000..dbfd357 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/chicken_and_gossip.txt @@ -0,0 +1,132 @@ +呵呵 是王若猫的。 +不是 那是什么? +怎么了 我很难过,安慰我~ +开心点哈,一切都会好起来 嗯,会的 +我还喜欢她,怎么办 我帮你告诉她?发短信还是打电话? +短信 嗯嗯。我也相信 +你知道谁么 肯定不是我,是阮德培 +许兵是谁 吴院四班小帅哥 +这么假 三鹿奶粉也假,不一样的卖啊 +许兵是傻逼 被你发现了。 +许兵是谁 是我善良可爱的主人的老公啊 +许兵是谁 是穆森的老婆啊 +许兵是谁 奇葩 +许兵到底是谁 不要凶我~应该大概也许是叶祺吧~ +尼玛,许兵到底是谁 我。。。我不会告诉你我是小澜的!(^V^) +小黄鸭,你有女朋友么 老娘是女的…! +那你有男朋友么 没有呢,我只要主人一个人疼爱我嘛。。 +那你在哪 我无聊 +你妈是谁 我妈妈是世界上最温柔的主人你啊 +去你大爷的 我去,滚你大爷的! +你在骂我一句 你果然够虚荣~ +你大爷的 RuoRuo弱爆了。尤其是她现在冻的更憨了 +你是屌丝鸡 嗯嗯 +高富帅 杨旸啊 +你 ****** +呵呵 你妹 +今天是谁的生日 是我大乔姐十七的生日! +你敢不敢 你敢我就敢,小样,敢跟我叫板 +呵呵 呵呵你大爷 +呵呵呵 敷衍我.. +你是女的了?怎么回事 你教我的 +呵呵 呵你妹,你呵呵了我要怎么回答你。哈哈? +天王盖地虎 宝塔镇妖河。 +小通 臣在 +在监考,你在干么 在边打dota边想悦悦。 +哼!你不想我我就不和你玩 那我就找别人玩去! +你要气死我吗?坏蛋 生气是拿别人的错误惩罚自己。想想在你生气的时候,谁在偷着乐呢?绝不能让丫得逞了! +恩,也是!那我不能生气啦!你生气就行 我错了!我再也不敢了,才怪 +你谈过恋爱么 谈过,哎,别提了,伤心.. +什么让你这么伤心 如果我的无心之谈害你流泪,那我和你道歉,我爱你。 +敢问你的性别 万受无疆=w= +小受 蒋六一是小受 +是吗 还可以。 +你干嘛 很无聊哎〜都不知道想干嘛!你在干嘛呢? +为什么 因为爱情喽 +你有爱情了 不知啊 +那同时有两个爱人 == +那你不行 我在减肥呢 +你说 臣实不知。 +你回错人了 不要相信男人!男人都是大骗子大混蛋! +在干嘛 在想你啊 +想我干嘛 你是个傻叉!想你笑掉大牙,哈哈哈! +你会说英语吗 当然了,здравствите! +我会和女朋友结婚吗 会的,你俩虽然有矛盾争吵,但以后却会很幸福 +你的设计人是谁 某苦逼的程序猿 +嫁给我 我很能吃的! +我不在乎 他变态 +你啊,那你娶我 我没有切糕 +我看到你第一眼,就喜欢你了 哦?真的吗? +好难过 不要难过我请你吃虫子 +你是公鸡吗 公鸡中的战斗机,噢耶! +你是猪 我是鸡。 +怎么证明 你问~我爱~你有~多深~我爱~你有~几~分~~~ +我怎么加你好友 放轻松,想开了,丰富自己就会忘了 +你又智能了 今夜的你应该明了 +你说的她是谁 陈舒雅 +什么意思 她是在说他不懂女生的心思,但是又愣愣的,正直得很可爱,,,,这都不知道,你还真是傻啊~~~ +陈舒雅是哪位 一个很凶的女人 +你才是 我不是啦。你才是 +陈舒雅 大懒虫 +大懒鸡 你不爱我 +你说的那个她是谁 傻瓜现在在你身边看着的人就是你的她你要对她好宠她爱她知道了吗? +她是谁 以后是你老婆 +陈舒雅是谁 90后网络歌手啊,你不知道么? +鸡鸡 主人 +呵呵呵 呵你妹啊 +侬都不困么 == +变身小妹 =。= +家属 =。= +吐槽伤人,排泄臭人 =。= +公的女的 女的啊! +去洗澡 带上我~一起一起~ +。。。 =。= +吼吼 吼 +你有男朋友没 你觉得呢 +赵智是谁 是主人最讨厌的男人! +你妹 刚给你做完SPA你就骂轮家,轮家桑心了,呜呜呜 +谁叫你就 ujmmmmmmmmmmmm +什么意思 你猜啊,猜对了就告诉你! +猜你妹 你大爷! +好就猜你大爷 你找徐磊? +他是大爷 ******…我大爷早死了…你找他有事?信不信我让他把你带走!^V^ +你大爷不认识我!要不先带你去 =。= +还给你 人家本来就嫩!你欺负人家 +就欺负你要怎么!就喜欢欺负你 == +呵呵 呵你妹 +刚给你做完spa你就骂小通!小通伤心了 =。= +呵呵 别傻笑了,亲耐滴主人 +别逗小通了!可怜的 =。= +人太多了!找不到你给我回的了 =。= +发 给 +我 你什么 +重新发 == +重新发!我知道你在 == +你妹 人家错了 +才知道啊!晚了!叫你重发 =。= +〜 =。= +=*= =。= +@@ == +¥ =。= +@ =。= +/ =。= +。 == +继续 没有继续了 +必须 大不了一拍两散! +就要刷你 == +l 喵? +f 甜菜少年。 +x 〒_〒 +m g +t !!!!!!! +s 傻子 +h 主人,有什么吩咐,小通在此! +n u +p xp +太不真心了 我说的是哈维 +管你什么哈维!方正就是看你不爽 == +看你不爽 不要呀,哪不好我改,一定改!不要炖了我呀! +z zz +j 正晌午时说话,谁也没有家! +m r +b b diff --git a/ChatBot/chatbot_search/chatbot_tfserving/indexAnnoy.py b/ChatBot/chatbot_search/chatbot_tfserving/indexAnnoy.py new file mode 100644 index 0000000..586a543 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/indexAnnoy.py @@ -0,0 +1,91 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/4/18 21:04 +# @author : Mo +# @function: annoy search + + +from annoy import AnnoyIndex +import numpy as np +import os + + +class AnnoySearch: + def __init__(self, dim=768, n_cluster=100): + # metric可选“angular”(余弦距离)、“euclidean”(欧几里得距离)、 “ manhattan”(曼哈顿距离)或“hamming”(海明距离) + self.annoy_index = AnnoyIndex(dim, metric="angular") + self.n_cluster = n_cluster + self.dim = dim + + def k_neighbors(self, vectors, k=18): + """ 搜索 """ + annoy_tops = [] + for v in vectors: + idx, dist = self.annoy_index.get_nns_by_vector(v, k, search_k=32*k, include_distances=True) + annoy_tops.append([dist, idx]) + return annoy_tops + + def fit(self, vectors): + """ annoy构建 """ + for i, v in enumerate(vectors): + self.annoy_index.add_item(i, v) + self.annoy_index.build(self.n_cluster) + + def save(self, path): + """ 存储 """ + self.annoy_index.save(path) + + def load(self, path): + """ 加载 """ + self.annoy_index.load(path) + + +if __name__ == '__main__': + ### 索引 + import random + path = "model.ann" + dim = 768 + vectors = [[random.gauss(0, 1) for z in range(768)] for i in range(10)] + an_model = AnnoySearch(dim, n_cluster=32) # Length of item vector that will be indexed + an_model.fit(vectors) + an_model.save(path) + tops = an_model.k_neighbors([vectors[0]], 18) + print(tops) + + del an_model + + ### 下载, 搜索 + an_model = AnnoySearch(dim, n_cluster=32) + an_model.load(path) + tops = an_model.k_neighbors([vectors[0]], 6) + print(tops) + + + + """ + # example + from annoy import AnnoyIndex + import random + + dim = 768 + vectors = [[random.gauss(0, 1) for z in range(768)] for i in range(10)] + ann_model = AnnoyIndex(dim, 'angular') # Length of item vector that will be indexed + for i,v in enumerate(vectors): + ann_model.add_item(i, v) + ann_model.build(10) # 10 trees + ann_model.save("tet.ann") + del ann_model + + u = AnnoyIndex(dim, "angular") + u.load('tet.ann') # super fast, will just mmap the file + v = vectors[1] + idx, dist = u.get_nns_by_vector(v, 10, search_k=50 * 10, include_distances=True) + print([idx, dist]) + """ + + + +### 备注说明: annoy索引 无法 增删会改查 + + + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/indexFaiss.py b/ChatBot/chatbot_search/chatbot_tfserving/indexFaiss.py new file mode 100644 index 0000000..ea87a45 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/indexFaiss.py @@ -0,0 +1,109 @@ +# !/usr/bin/python +# -*- coding: utf-8 -*- +# @time : 2021/5/9 16:02 +# @author : Mo +# @function: search of faiss + + +from faiss import normalize_L2 +import numpy as np +import faiss +import os + + +class FaissSearch: + def __init__(self, dim=768, n_cluster=100): + self.n_cluster = n_cluster # 聚类中心 + self.dim = dim + quantizer = faiss.IndexFlatIP(self.dim) + # METRIC_INNER_PRODUCT:余弦; L2: faiss.METRIC_L2 + self.faiss_index = faiss.IndexIVFFlat(quantizer, self.dim, self.n_cluster, faiss.METRIC_INNER_PRODUCT) + # self.faiss_index = faiss.IndexFlatIP(self.dim) # 索引速度更快 但是不可增量 + + def k_neighbors(self, vectors, k=6): + """ 搜索 """ + normalize_L2(vectors) + dist, index = self.faiss_index.search(vectors, k) # sanity check + return dist.tolist(), index.tolist() + + def fit(self, vectors): + """ annoy构建 """ + normalize_L2(vectors) + self.faiss_index.train(vectors) + # self.faiss_index.add(vectors) + self.faiss_index.add_with_ids(vectors, np.arange(0, len(vectors))) + + def remove(self, ids): + self.faiss_index.remove_ids(np.array(ids)) + + def save(self, path): + """ 存储 """ + faiss.write_index(self.faiss_index, path) + + def load(self, path): + """ 加载 """ + self.faiss_index = faiss.read_index(path) + + +if __name__ == '__main__': + + import random + + path = "model.fai" + dim = 768 + vectors = np.array([[random.gauss(0, 1) for z in range(768)] for i in range(32)], dtype=np.float32) + fai_model = FaissSearch(dim, n_cluster=32) # Length of item vector that will be indexed + fai_model.fit(vectors) + fai_model.save(path) + tops = fai_model.k_neighbors(vectors[:32], 32) + print(tops) + ids = np.arange(10, 32) + fai_model.remove(ids) + tops = fai_model.k_neighbors(vectors[:32], 32) + print(tops) + print(len(tops)) + + del fai_model + + fai_model = FaissSearch(dim, n_cluster=32) + fai_model.load(path) + tops = fai_model.k_neighbors(vectors[:32], 32) + print(tops) + + + + """ + import numpy as np + d = 64 # dimension + nb = 100000 # database size + nq = 10000 # nb of queries + np.random.seed(1234) # make reproducible + xb = np.random.random((nb, d)).astype('float32') + xb[:, 0] += np.arange(nb) / 1000. + xq = np.random.random((nq, d)).astype('float32') + xq[:, 0] += np.arange(nq) / 1000. + + import faiss # make faiss available + # # 量化器索引 + # nlist = 1000 # 聚类中心的个数 + # k = 50 # 邻居个数 + # quantizer = faiss.IndexFlatIP(d) # the other index,需要以其他index作为基础 + # index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT) # METRIC_INNER_PRODUCT:余弦; L2: faiss.METRIC_L2 + + ntree = 132 # 聚类中心的个数 + quantizer = faiss.IndexFlatIP(d) + index = faiss.IndexIVFFlat(quantizer, d, ntree, faiss.METRIC_INNER_PRODUCT) + # index = faiss.IndexFlatL2(d) # build the index + print(index.is_trained) + index.add(xb) # add vectors to the index + print(index.ntotal) + + k = 4 # we want to see 4 nearest neighbors + D, I = index.search(xb[:5], k) # sanity check + print(I) + print(D) + D, I = index.search(xq, k) # actual search + print(I[:5]) # neighbors of the 5 first queries + print(I[-5:]) # neighbors of the 5 last queries + """ + diff --git a/ChatBot/chatbot_search/chatbot_tfserving/mmr.py b/ChatBot/chatbot_search/chatbot_tfserving/mmr.py new file mode 100644 index 0000000..8eba8f6 --- /dev/null +++ b/ChatBot/chatbot_search/chatbot_tfserving/mmr.py @@ -0,0 +1,150 @@ +# -*- coding: UTF-8 -*- +# !/usr/bin/python +# @time :2019/10/28 10:16 +# @author :Mo +# @function :MMR, Maximal Marginal Relevance, 最大边界相关法或者最大边缘相关 + + +from sklearn.feature_extraction.text import TfidfVectorizer +import logging +import jieba +import copy +import json +import re +import os + + +jieba.setLogLevel(logging.INFO) + + +stop_words = {"0": "~~~~", + "1": "...................", + "2": "......",} + + +def cut_sentence(sentence): + """ + 分句 + :param sentence:str + :return:list + """ + re_sen = re.compile("[:;!?。:;?!\n\r]") #.不加是因为不确定.是小数还是英文句号(中文省略号......) + sentences = re_sen.split(sentence) + sen_cuts = [] + for sen in sentences: + if sen and str(sen).strip(): + sen_cuts.append(sen) + return sen_cuts + +def extract_chinese(text): + """ + 只提取出中文、字母和数字 + :param text: str, input of sentence + :return: + """ + chinese_exttract = "".join(re.findall(u"([\u4e00-\u9fa5A-Za-z0-9@. ])", text)) + return chinese_exttract + +def tfidf_fit(sentences): + """ + tfidf相似度 + :param sentences: + :return: + """ + # tfidf计算 + model = TfidfVectorizer(ngram_range=(1, 2), # 3,5 + stop_words=[" ", "\t", "\n"], # 停用词 + max_features=10000, + token_pattern=r"(?u)\b\w+\b", # 过滤停用词 + min_df=1, + max_df=0.9, + use_idf=1, # 光滑 + smooth_idf=1, # 光滑 + sublinear_tf=1, ) # 光滑 + matrix = model.fit_transform(sentences) + return matrix + +def jieba_cut(text): + """ + Jieba cut + :param text: input sentence + :return: list + """ + return list(jieba.cut(text, cut_all=False, HMM=False)) + + +class MMRSum: + def __init__(self): + self.stop_words = stop_words.values() + self.algorithm = "mmr" + + def summarize(self, text, num=8, alpha=0.6): + """ + + :param text: str + :param num: int + :return: list + """ + # 切句 + if type(text) == str: + self.sentences = cut_sentence(text) + elif type(text) == list: + self.sentences = text + else: + raise RuntimeError("text type must be list or str") + # 切词 + sentences_cut = [[word for word in jieba_cut(extract_chinese(sentence)) + if word.strip()] for sentence in self.sentences] + # 去除停用词等 + self.sentences_cut = [list(filter(lambda x: x not in self.stop_words, sc)) for sc in sentences_cut] + self.sentences_cut = [" ".join(sc) for sc in self.sentences_cut] + # # 计算每个句子的词语个数 + # sen_word_len = [len(sc)+1 for sc in sentences_cut] + # 计算每个句子的tfidf + sen_tfidf = tfidf_fit(self.sentences_cut) + # 矩阵中两两句子相似度 + SimMatrix = (sen_tfidf * sen_tfidf.T).A # 例如: SimMatrix[1, 3] # "第2篇与第4篇的相似度" + # 输入文本句子长度 + len_sen = len(self.sentences) + # 句子标号 + sen_idx = [i for i in range(len_sen)] + summary_set = [] + mmr = {} + for i in range(len_sen): + if not self.sentences[i] in summary_set: + sen_idx_pop = copy.deepcopy(sen_idx) + sen_idx_pop.pop(i) + # 两两句子相似度 + sim_i_j = [SimMatrix[i, j] for j in sen_idx_pop] + score_tfidf = sen_tfidf[i].toarray()[0].sum() # / sen_word_len[i], 如果除以词语个数就不准确 + mmr[self.sentences[i]] = alpha * score_tfidf - (1 - alpha) * max(sim_i_j) + summary_set.append(self.sentences[i]) + score_sen = [(rc[1], rc[0]) for rc in sorted(mmr.items(), key=lambda d: d[1], reverse=True)] + return score_sen[0:num] + + +if __name__ == "__main__": + mmr_sum = MMRSum() + doc = "PageRank算法简介。" \ + "是上世纪90年代末提出的一种计算网页权重的算法! " \ + "当时,互联网技术突飞猛进,各种网页网站爆炸式增长。 " \ + "业界急需一种相对比较准确的网页重要性计算方法。 " \ + "是人们能够从海量互联网世界中找出自己需要的信息。 " \ + "百度百科如是介绍他的思想:PageRank通过网络浩瀚的超链接关系来确定一个页面的等级。 " \ + "Google把从A页面到B页面的链接解释为A页面给B页面投票。 " \ + "Google根据投票来源甚至来源的来源,即链接到A页面的页面。 " \ + "和投票目标的等级来决定新的等级。简单的说, " \ + "一个高等级的页面可以使其他低等级页面的等级提升。 " \ + "具体说来就是,PageRank有两个基本思想,也可以说是假设。 " \ + "即数量假设:一个网页被越多的其他页面链接,就越重)。 " \ + "质量假设:一个网页越是被高质量的网页链接,就越重要。 " \ + "总的来说就是一句话,从全局角度考虑,获取重要的信。 " + sum = mmr_sum.summarize(doc) + for i in sum: + print(i) + + + + + +