update save to index.
This commit is contained in:
parent d2405d940d
commit e666de1e63
@@ -1,7 +1,6 @@
[![PyPI version](https://badge.fury.io/py/similarities.svg)](https://badge.fury.io/py/similarities)
[![Downloads](https://pepy.tech/badge/similarities)](https://pepy.tech/project/similarities)
[![Contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![GitHub contributors](https://img.shields.io/github/contributors/shibing624/similarities.svg)](https://github.com/shibing624/similarities/graphs/contributors)
[![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![python_version](https://img.shields.io/badge/Python-3.5%2B-green.svg)](requirements.txt)
[![GitHub issues](https://img.shields.io/github/issues/shibing624/similarities.svg)](https://github.com/shibing624/similarities/issues)
@@ -37,8 +37,8 @@ print('-' * 50 + '\n')
model.add_corpus(corpus)
res = model.most_similar(queries=sentences, topn=3)
print(res)
for q_id, c in res.items():
for q_id, id_score_dict in res.items():
print('query:', sentences[q_id])
print("search top 3:")
for corpus_id, s in c.items():
for corpus_id, s in id_score_dict.items():
print(f'\t{model.corpus[corpus_id]}: {s:.4f}')
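For context on the rename above: most_similar returns a nested dict keyed by query index, so id_score_dict maps corpus ids to similarity scores. A minimal sketch of the expected shape (scores are made up for illustration):

res = model.most_similar(queries=["如何更换花呗绑定银行卡"], topn=3)
# res looks roughly like {0: {12: 0.92, 3: 0.88, 7: 0.75}}, i.e. {query_id: {corpus_id: score}}
for q_id, id_score_dict in res.items():
    for corpus_id, score in id_score_dict.items():
        print(model.corpus[corpus_id], score)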
@@ -41,6 +41,8 @@ corpus = [
]

model.add_corpus(corpus)
model.save_index('en_corpus_emb.json')
model.load_index('en_corpus_emb.json')
res = model.most_similar(queries=sentences1, topn=3)
print(res)
for q_id, c in res.items():
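A rough usage sketch of the save_index / load_index pair introduced above (the variable names, corpus and file name here are illustrative, not part of the diff):

from similarities import Similarity

model = Similarity()
model.add_corpus(["This is a test.", "Another sentence."])
model.save_index('en_corpus_emb.json')   # persist corpus texts and embeddings as JSON

model2 = Similarity()
model2.load_index('en_corpus_emb.json')  # restore corpus and embeddings without re-encoding
print(model2.most_similar("A test sentence.", topn=3))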
@@ -5,7 +5,6 @@
"""
import datetime
import os
import pathlib
import random
import sys

@@ -13,7 +12,7 @@ from loguru import logger

sys.path.append('../..')
from similarities import BM25Similarity
from similarities.utils import http_get
from similarities.utils.get_file import http_get
from similarities.data_loader import SearchDataLoader
from similarities.evaluation import evaluate
@@ -13,7 +13,7 @@ from loguru import logger

sys.path.append('../..')
from similarities import Similarity
from similarities.utils import http_get
from similarities.utils.get_file import http_get
from similarities.data_loader import SearchDataLoader
from similarities.evaluation import evaluate
@@ -23,6 +23,34 @@ corpus = [
]


def annoy_demo():
corpus_new = [i + str(id) for id, i in enumerate(corpus * 10)]
model = AnnoySimilarity(corpus=corpus_new)
print(model)
similarity_score = model.similarity(sentences[0], sentences[1])
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
model.add_corpus(corpus)
model.build_index()
model.save_index('annoy_model.bin')
print(model.most_similar("men喜欢这首歌"))
# Semantic Search batch
del model
model = AnnoySimilarity()
model.load_index('annoy_model.bin')
print(model.most_similar("men喜欢这首歌"))
queries = ["如何更换花呗绑定银行卡", "men喜欢这首歌"]
res = model.most_similar(queries, topn=3)
print(res)
for q_id, c in res.items():
print('query:', queries[q_id])
print("search top 3:")
for corpus_id, s in c.items():
print(f'\t{model.corpus[corpus_id]}: {s:.4f}')

# os.remove('annoy_model.bin')
print('-' * 50 + '\n')


def hnswlib_demo():
corpus_new = [i + str(id) for id, i in enumerate(corpus * 10)]
print(corpus_new)
@@ -32,8 +60,12 @@ def hnswlib_demo():
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
model.add_corpus(corpus)
model.build_index()
model.save_index('test.model')
model.save_index('hnsw_model.bin')
print(model.most_similar("men喜欢这首歌"))
# Semantic Search batch
del model
model = HnswlibSimilarity()
model.load_index('hnsw_model.bin')
print(model.most_similar("men喜欢这首歌"))
queries = ["如何更换花呗绑定银行卡", "men喜欢这首歌"]
res = model.most_similar(queries, topn=3)
@@ -44,34 +76,10 @@ def hnswlib_demo():
for corpus_id, s in c.items():
print(f'\t{model.corpus[corpus_id]}: {s:.4f}')

os.remove('test.model')
print('-' * 50 + '\n')


def annoy_demo():
corpus_new = [i + str(id) for id, i in enumerate(corpus * 10)]
model = AnnoySimilarity(corpus=corpus_new)
print(model)
similarity_score = model.similarity(sentences[0], sentences[1])
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
model.add_corpus(corpus)
model.build_index()
model.save_index('test.model')
# Semantic Search batch
print(model.most_similar("men喜欢这首歌"))
queries = ["如何更换花呗绑定银行卡", "men喜欢这首歌"]
res = model.most_similar(queries, topn=3)
print(res)
for q_id, c in res.items():
print('query:', queries[q_id])
print("search top 3:")
for corpus_id, s in c.items():
print(f'\t{model.corpus[corpus_id]}: {s:.4f}')

os.remove('test.model')
# os.remove('hnsw_model.bin')
print('-' * 50 + '\n')


if __name__ == '__main__':
hnswlib_demo()
annoy_demo()
hnswlib_demo()
setup.py
@@ -4,7 +4,7 @@ import sys
from setuptools import setup, find_packages

# Avoids IDE errors, but actual version is read from version.py
__version__ = None
__version__ = ""
exec(open('similarities/version.py').read())

if sys.version_info < (3,):
@@ -33,10 +33,6 @@ setup(
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
keywords='similarities,Chinese Text Similarity Calculation Tool,similarity,word2vec',
@@ -19,12 +19,11 @@ class AnnoySimilarity(Similarity):
self,
corpus: Union[List[str], Dict[str, str]] = None,
model_name_or_path="shibing624/text2vec-base-chinese",
embedding_size: int = 768,
n_trees: int = 256
):
super().__init__(corpus, model_name_or_path)
self.index = None
self.embedding_size = embedding_size
self.embedding_size = self.get_sentence_embedding_dimension()
self.n_trees = n_trees
if corpus is not None and self.corpus_embeddings:
self.build_index()
@@ -35,9 +34,8 @@ class AnnoySimilarity(Similarity):
base += f", corpus size: {len(self.corpus)}"
return base

def build_index(self):
"""Build Annoy index after add new documents."""
# Create Annoy Index
def create_index(self):
"""Create Annoy Index."""
try:
from annoy import AnnoyIndex
except ImportError:
@@ -45,37 +43,49 @@ class AnnoySimilarity(Similarity):

# Creating the annoy index
self.index = AnnoyIndex(self.embedding_size, 'angular')
logger.debug(f"Init Annoy index, embedding_size: {self.embedding_size}")

logger.info(f"Init Annoy index, embedding_size: {self.embedding_size}")
def build_index(self):
"""Build Annoy index after add new documents."""
self.create_index()
logger.debug(f"Building index with {self.n_trees} trees.")

for i in range(len(self.corpus_embeddings)):
self.index.add_item(i, self.corpus_embeddings[i])
self.index.build(self.n_trees)

def save_index(self, index_path: str):
def save_index(self, index_path: str = "annoy_index.bin"):
"""Save the annoy index to disk."""
if self.index and index_path:
logger.info(f"Saving index to: {index_path}")
if index_path:
if self.index is None:
self.build_index()
self.index.save(index_path)
corpus_emb_json_path = index_path + ".json"
super().save_index(corpus_emb_json_path)
logger.info(f"Saving Annoy index to: {index_path}, corpus embedding to: {corpus_emb_json_path}")
else:
logger.warning("No index path given. Index not saved.")

def load_index(self, index_path: str):
def load_index(self, index_path: str = "annoy_index.bin"):
"""Load Annoy Index from disc."""
if index_path and os.path.exists(index_path):
logger.info(f"Loading index from: {index_path}")
corpus_emb_json_path = index_path + ".json"
logger.info(f"Loading index from: {index_path}, corpus embedding from: {corpus_emb_json_path}")
super().load_index(corpus_emb_json_path)
if self.index is None:
self.create_index()
self.index.load(index_path)
else:
logger.warning("No index path given. Index not loaded.")

def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10):
def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10,
score_function: str = "cos_sim"):
"""Find the topn most similar texts to the query against the corpus."""
result = {}
if self.corpus_embeddings and self.index is None:
logger.warning(f"No index found. Please add corpus and build index first, e.g. with `build_index()`."
f"Now returning slow search result.")
return super().most_similar(queries, topn)
return super().most_similar(queries, topn, score_function=score_function)
if not self.corpus_embeddings:
logger.error("No corpus_embeddings found. Please add corpus first, e.g. with `add_corpus()`.")
return result
@@ -91,7 +101,7 @@ class AnnoySimilarity(Similarity):
corpus_ids, distances = self.index.get_nns_by_vector(queries_embeddings[idx], topn, include_distances=True)
for corpus_id, distance in zip(corpus_ids, distances):
score = 1 - (distance ** 2) / 2
result[qid][self.corpus_ids_map[corpus_id]] = score
result[qid][corpus_id] = score

return result
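To make the new Annoy save/load flow concrete, a small sketch mirroring annoy_demo() above (file names, corpus and import path are illustrative, not taken from the diff; assumes the annoy package is installed). save_index writes the Annoy binary plus a <index_path>.json with the corpus embeddings, and load_index restores both before searching:

from similarities import AnnoySimilarity

model = AnnoySimilarity(corpus=["如何更换花呗绑定银行卡", "花呗更改绑定银行卡"])
model.build_index()
model.save_index('annoy_model.bin')    # also writes annoy_model.bin.json with corpus embeddings

model2 = AnnoySimilarity()
model2.load_index('annoy_model.bin')   # reloads corpus embeddings, then the Annoy index
print(model2.most_similar("花呗绑定银行卡", topn=3))
# Returned scores use score = 1 - d ** 2 / 2, which recovers cosine similarity
# from Annoy's angular distance d = sqrt(2 * (1 - cos)).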
@@ -106,10 +116,10 @@ class HnswlibSimilarity(Similarity):
self,
corpus: Union[List[str], Dict[str, str]] = None,
model_name_or_path="shibing624/text2vec-base-chinese",
embedding_size: int = 768, ef_construction: int = 400, M: int = 64, ef: int = 50
ef_construction: int = 400, M: int = 64, ef: int = 50
):
super().__init__(corpus, model_name_or_path)
self.embedding_size = embedding_size
self.embedding_size = self.get_sentence_embedding_dimension()
self.ef_construction = ef_construction
self.M = M
self.ef = ef
@@ -123,52 +133,64 @@ class HnswlibSimilarity(Similarity):
base += f", corpus size: {len(self.corpus)}"
return base

def build_index(self):
"""Build Hnswlib index after add new documents."""
# Create hnswlib Index
def create_index(self):
"""Create Hnswlib Index."""
try:
import hnswlib
except ImportError:
raise ImportError("Hnswlib is not installed. Please install it first, e.g. with `pip install hnswlib`.")

# We use Inner Product (dot-product) as Index. We will normalize our vectors to unit length,
# then is Inner Product equal to cosine similarity
# Creating the hnswlib index
self.index = hnswlib.Index(space='cosine', dim=self.embedding_size)
self.index.init_index(max_elements=len(self.corpus_embeddings), ef_construction=self.ef_construction, M=self.M)
# Controlling the recall by setting ef:
self.index.set_ef(self.ef) # ef should always be > top_k_hits
logger.debug(f"Init Hnswlib index, embedding_size: {self.embedding_size}")

def build_index(self):
"""Build Hnswlib index after add new documents."""
# Init the HNSWLIB index
logger.info(f"Creating HNSWLIB index, max_elements: {len(self.corpus)}")
self.create_index()
logger.info(f"Building HNSWLIB index, max_elements: {len(self.corpus)}")
logger.debug(f"Parameters Required: M: {self.M}")
logger.debug(f"Parameters Required: ef_construction: {self.ef_construction}")
logger.debug(f"Parameters Required: ef(>topn): {self.ef}")

self.index.init_index(max_elements=len(self.corpus_embeddings), ef_construction=self.ef_construction, M=self.M)
# Then we train the index to find a suitable clustering
self.index.add_items(self.corpus_embeddings, list(range(len(self.corpus_embeddings))))
# Controlling the recall by setting ef:
self.index.set_ef(self.ef) # ef should always be > top_k_hits

def save_index(self, index_path: str):
"""Save the annoy index to disk."""
if self.index and index_path:
logger.info(f"Saving index to: {index_path}")
def save_index(self, index_path: str = "hnswlib_index.bin"):
"""Save the index to disk."""
if index_path:
if self.index is None:
self.build_index()
self.index.save_index(index_path)
corpus_emb_json_path = index_path + ".json"
super().save_index(corpus_emb_json_path)
logger.info(f"Saving hnswlib index to: {index_path}, corpus embedding to: {corpus_emb_json_path}")
else:
logger.warning("No index path given. Index not saved.")

def load_index(self, index_path: str):
"""Load Annoy Index from disc."""
def load_index(self, index_path: str = "hnswlib_index.bin"):
"""Load Index from disc."""
if index_path and os.path.exists(index_path):
logger.info(f"Loading index from: {index_path}")
corpus_emb_json_path = index_path + ".json"
logger.info(f"Loading index from: {index_path}, corpus embedding from: {corpus_emb_json_path}")
super().load_index(corpus_emb_json_path)
if self.index is None:
self.create_index()
self.index.load_index(index_path)
else:
logger.warning("No index path given. Index not loaded.")

def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10):
def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10,
score_function: str = "cos_sim"):
"""Find the topn most similar texts to the query against the corpus."""
result = {}
if self.corpus_embeddings and self.index is None:
logger.warning(f"No index found. Please add corpus and build index first, e.g. with `build_index()`."
f"Now returning slow search result.")
return super().most_similar(queries, topn)
return super().most_similar(queries, topn, score_function=score_function)
if not self.corpus_embeddings:
logger.error("No corpus_embeddings found. Please add corpus first, e.g. with `add_corpus()`.")
return result
@@ -186,6 +208,6 @@ class HnswlibSimilarity(Similarity):
hits = [{'corpus_id': id, 'score': 1 - distance} for id, distance in zip(corpus_ids[i], distances[i])]
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
for hit in hits:
result[qid][self.corpus_ids_map[hit['corpus_id']]] = hit['score']
result[qid][hit['corpus_id']] = hit['score']

return result
@@ -38,7 +38,6 @@ class ClipSimilarity(SimilarityABC):
self.clip_model = CLIPModel(model_name_or_path) # load the CLIP model
self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
self.corpus = {}
self.corpus_ids_map = {}
self.corpus_embeddings = []
if corpus is not None:
self.add_corpus(corpus)
@@ -59,11 +58,17 @@ class ClipSimilarity(SimilarityABC):
img = img.convert('RGB')
return img

def _get_vector(self, text_or_img: Union[List[Image.Image], Image.Image, str, List[str]],
show_progress_bar: bool = False):
def _get_vector(
self,
text_or_img: Union[List[Image.Image], Image.Image, str, List[str]],
batch_size: int = 128,
show_progress_bar: bool = False,
):
"""
Returns the embeddings for a batch of images.
:param text_or_img: list of str or str or Image.Image or image list
:param text_or_img: list of str or Image.Image or image list
:param batch_size: batch size
:param show_progress_bar: show progress bar
:return: np.ndarray, embeddings for the given images
"""
if isinstance(text_or_img, str):
@@ -72,7 +77,7 @@ class ClipSimilarity(SimilarityABC):
text_or_img = [text_or_img]
if isinstance(text_or_img, list) and isinstance(text_or_img[0], Image.Image):
text_or_img = [self._convert_to_rgb(i) for i in text_or_img]
return self.clip_model.encode(text_or_img, batch_size=128, show_progress_bar=show_progress_bar)
return self.clip_model.encode(text_or_img, batch_size=batch_size, show_progress_bar=show_progress_bar)

def add_corpus(self, corpus: Union[List[Image.Image], Dict[str, Image.Image]]):
"""
@@ -93,7 +98,6 @@ class ClipSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}
logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_embeddings = self._get_vector(list(corpus_new.values()), show_progress_bar=True).tolist()
if self.corpus_embeddings:
@@ -147,7 +151,7 @@ class ClipSimilarity(SimilarityABC):
all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn)
for idx, hits in enumerate(all_hits):
for hit in hits[0:topn]:
result[queries_ids_map[idx]][self.corpus_ids_map[hit['corpus_id']]] = hit['score']
result[queries_ids_map[idx]][hit['corpus_id']] = hit['score']

return result

@@ -201,7 +205,6 @@ class ImageHashSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}
logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_embeddings = []
for doc_fp in tqdm(list(corpus_new.values()), desc="Calculating corpus image hash"):
@@ -317,7 +320,6 @@ class SiftSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}
logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_embeddings = []
for img in tqdm(list(corpus_new.values()), desc="Calculating corpus image SIFT"):
@@ -36,7 +36,7 @@ class SimHashSimilarity(SimilarityABC):

def __init__(self, corpus: Union[List[str], Dict[str, str]] = None):
self.corpus = {}
self.corpus_ids_map = {}

self.corpus_embeddings = []
if corpus is not None:
self.add_corpus(corpus)
@@ -71,7 +71,7 @@ class SimHashSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_texts = list(corpus_new.values())
corpus_embeddings = []
@@ -198,7 +198,7 @@ class TfidfSimilarity(SimilarityABC):
def __init__(self, corpus: Union[List[str], Dict[str, str]] = None):
super().__init__()
self.corpus = {}
self.corpus_ids_map = {}

self.corpus_embeddings = []
self.tfidf = TFIDF()
if corpus is not None:
@@ -234,7 +234,7 @@ class TfidfSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_texts = list(corpus_new.values())
corpus_embeddings = []
@@ -280,7 +280,7 @@ class TfidfSimilarity(SimilarityABC):
all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn)
for idx, hits in enumerate(all_hits):
for hit in hits[0:topn]:
result[queries_ids_map[idx]][self.corpus_ids_map[hit['corpus_id']]] = hit['score']
result[queries_ids_map[idx]][hit['corpus_id']] = hit['score']

return result
@@ -294,7 +294,7 @@ class BM25Similarity(SimilarityABC):
def __init__(self, corpus: Union[List[str], Dict[str, str]] = None):
super().__init__()
self.corpus = {}
self.corpus_ids_map = {}

self.bm25 = None
self.default_stopwords = load_stopwords(default_stopwords_file)
if corpus is not None:
@@ -330,7 +330,7 @@ class BM25Similarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_texts = list(corpus_new.values())
corpus_seg = [jieba.lcut(d) for d in corpus_texts]
@@ -360,7 +360,7 @@ class BM25Similarity(SimilarityABC):
q_res = [{'corpus_id': corpus_id, 'score': score} for corpus_id, score in enumerate(scores)]
q_res = sorted(q_res, key=lambda x: x['score'], reverse=True)[:topn]
for res in q_res:
corpus_id = self.corpus_ids_map[res['corpus_id']]
corpus_id = res['corpus_id']
result[qid][corpus_id] = res['score']

return result
@@ -385,7 +385,7 @@ class WordEmbeddingSimilarity(SimilarityABC):
else:
raise ValueError("model_name_or_path must be ~text2vec.Word2Vec or Word2Vec model name")
self.corpus = {}
self.corpus_ids_map = {}

self.corpus_embeddings = []
if corpus is not None:
self.add_corpus(corpus)
@@ -420,7 +420,7 @@ class WordEmbeddingSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_texts = list(corpus_new.values())
corpus_embeddings = self._get_vector(corpus_texts, show_progress_bar=True).tolist()
@@ -463,7 +463,7 @@ class WordEmbeddingSimilarity(SimilarityABC):
all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn)
for idx, hits in enumerate(all_hits):
for hit in hits[0:topn]:
result[queries_ids_map[idx]][self.corpus_ids_map[hit['corpus_id']]] = hit['score']
result[queries_ids_map[idx]][hit['corpus_id']] = hit['score']

return result
@@ -479,7 +479,7 @@ class CilinSimilarity(SimilarityABC):
super().__init__()
self.cilin_dict = self.load_cilin_dict(cilin_path) # Cilin(词林) semantic dictionary
self.corpus = {}
self.corpus_ids_map = {}

if corpus is not None:
self.add_corpus(corpus)

@@ -513,7 +513,7 @@ class CilinSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start add new docs: {len(corpus_new)}")
logger.info(f"Add {len(corpus)} docs, total: {len(self.corpus)}")

@@ -636,7 +636,7 @@ class HownetSimilarity(SimilarityABC):
def __init__(self, corpus: Union[List[str], Dict[str, str]] = None, hownet_path: str = default_hownet_path):
self.hownet_dict = self.load_hownet_dict(hownet_path) # semantic dictionary
self.corpus = {}
self.corpus_ids_map = {}

if corpus is not None:
self.add_corpus(corpus)

@@ -670,7 +670,7 @@ class HownetSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start add new docs: {len(corpus_new)}")
logger.info(f"Add {len(corpus)} docs, total: {len(self.corpus)}")

@@ -766,7 +766,7 @@ class SameCharsSimilarity(SimilarityABC):
def __init__(self, corpus: Union[List[str], Dict[str, str]] = None):
super().__init__()
self.corpus = {}
self.corpus_ids_map = {}

if corpus is not None:
self.add_corpus(corpus)

@@ -800,7 +800,7 @@ class SameCharsSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start add new docs: {len(corpus_new)}")
logger.info(f"Add {len(corpus)} docs, total: {len(self.corpus)}")

@@ -862,7 +862,7 @@ class SequenceMatcherSimilarity(SimilarityABC):
def __init__(self, corpus: Union[List[str], Dict[str, str]] = None):
super().__init__()
self.corpus = {}
self.corpus_ids_map = {}

if corpus is not None:
self.add_corpus(corpus)

@@ -896,7 +896,7 @@ class SequenceMatcherSimilarity(SimilarityABC):
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}

logger.info(f"Start add new docs: {len(corpus_new)}")
logger.info(f"Add {len(corpus)} docs, total: {len(self.corpus)}")
@@ -9,6 +9,7 @@ Compute similarity:
"""

import os
import json
from typing import List, Union, Dict

import numpy as np
@@ -66,7 +67,7 @@ class SimilarityABC:

class Similarity(SimilarityABC):
"""
Bert similarity:
Sentence Similarity:
1. Compute the similarity between two sentences
2. Retrieves most similar sentence of a query against a corpus of documents.

@@ -76,14 +77,17 @@ class Similarity(SimilarityABC):
def __init__(
self,
corpus: Union[List[str], Dict[str, str]] = None,
model_name_or_path="shibing624/text2vec-base-chinese",
model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
max_seq_length=128,
):
"""
Initialize the similarity object.
:param model_name_or_path: bert model name or path, can be: ['bert-base-uncased', 'bert-base-chinese', ...]
default "shibing624/text2vec-base-chinese", refer: https://github.com/shibing624/text2vec
:param model_name_or_path: Transformer model name or path, like:
'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', 'bert-base-uncased', 'bert-base-chinese',
'shibing624/text2vec-base-chinese', ...
model in HuggingFace Model Hub and release from https://github.com/shibing624/text2vec
:param corpus: Corpus of documents to use for similarity queries.
:param max_seq_length: Max sequence length for sentence model.
"""
if isinstance(model_name_or_path, str):
self.sentence_model = SentenceModel(model_name_or_path, max_seq_length=max_seq_length)
@@ -93,7 +97,6 @@ class Similarity(SimilarityABC):
raise ValueError("model_name_or_path is transformers model name or path")
self.score_functions = {'cos_sim': cos_sim, 'dot': dot_score}
self.corpus = {}
self.corpus_ids_map = {}
self.corpus_embeddings = []
if corpus is not None:
self.add_corpus(corpus)
@@ -108,42 +111,54 @@
base += f", corpus size: {len(self.corpus)}"
return base

def get_sentence_embedding_dimension(self):
"""
Get the dimension of the sentence embeddings.

Returns
-------
int or None
The dimension of the sentence embeddings, or None if it cannot be determined.
"""
if hasattr(self.sentence_model, "get_sentence_embedding_dimension"):
return self.sentence_model.get_sentence_embedding_dimension()
else:
return getattr(self.sentence_model.bert.pooler.dense, "out_features", None)

def add_corpus(self, corpus: Union[List[str], Dict[str, str]]):
"""
Extend the corpus with new documents.

Parameters
----------
corpus : list of str or dict
:param corpus: corpus of documents to use for similarity queries.
:return: self.corpus, self.corpus embeddings
"""
corpus_new = {}
new_corpus = {}
start_id = len(self.corpus) if self.corpus else 0
if isinstance(corpus, list):
corpus = list(set(corpus))
for id, doc in enumerate(corpus):
if doc not in list(self.corpus.values()):
corpus_new[start_id + id] = doc
else:
for id, doc in corpus.items():
if doc not in list(self.corpus.values()):
corpus_new[id] = doc
self.corpus.update(corpus_new)
self.corpus_ids_map = {i: id for i, id in enumerate(list(self.corpus.keys()))}
logger.info(f"Start computing corpus embeddings, new docs: {len(corpus_new)}")
corpus_embeddings = self._get_vector(list(corpus_new.values()), show_progress_bar=True).tolist()
if self.corpus_embeddings:
self.corpus_embeddings += corpus_embeddings
else:
self.corpus_embeddings = corpus_embeddings
logger.info(f"Add {len(corpus)} docs, total: {len(self.corpus)}, emb size: {len(self.corpus_embeddings)}")
for id, doc in enumerate(corpus):
if isinstance(corpus, list):
if doc not in self.corpus.values():
new_corpus[start_id + id] = doc
else:
if doc not in self.corpus.values():
new_corpus[id] = doc
self.corpus.update(new_corpus)
logger.info(f"Start computing corpus embeddings, new docs: {len(new_corpus)}")
corpus_embeddings = self._get_vector(list(new_corpus.values()), show_progress_bar=True).tolist()
self.corpus_embeddings = self.corpus_embeddings + corpus_embeddings \
if self.corpus_embeddings else corpus_embeddings
logger.info(f"Add {len(new_corpus)} docs, total: {len(self.corpus)}, emb len: {len(self.corpus_embeddings)}")

def _get_vector(self, sentences: Union[str, List[str]], show_progress_bar: bool = False) -> np.ndarray:
def _get_vector(
self,
sentences: Union[str, List[str]],
batch_size: int = 64,
show_progress_bar: bool = False,
) -> np.ndarray:
"""
Returns the embeddings for a batch of sentences.
:param sentences:
:return:
"""
return self.sentence_model.encode(sentences, show_progress_bar=show_progress_bar)
return self.sentence_model.encode(sentences, batch_size=batch_size, show_progress_bar=show_progress_bar)

def similarity(self, a: Union[str, List[str]], b: Union[str, List[str]], score_function: str = "cos_sim"):
"""
@@ -166,25 +181,60 @@ class Similarity(SimilarityABC):
"""Compute cosine distance between two texts."""
return 1 - self.similarity(a, b)

def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10):
def most_similar(self, queries: Union[str, List[str], Dict[str, str]], topn: int = 10,
score_function: str = "cos_sim"):
"""
Find the topn most similar texts to the queries against the corpus.
:param queries: str or list of str
:param topn: int
:param score_function: function to compute similarity, default cos_sim
:return: Dict[str, Dict[str, float]], {query_id: {corpus_id: similarity_score}, ...}
"""
if isinstance(queries, str) or not hasattr(queries, '__len__'):
queries = [queries]
if isinstance(queries, list):
queries = {id: query for id, query in enumerate(queries)}
if score_function not in self.score_functions:
raise ValueError(f"score function: {score_function} must be either (cos_sim) for cosine similarity"
" or (dot) for dot product")
score_function = self.score_functions[score_function]
result = {qid: {} for qid, query in queries.items()}
queries_ids_map = {i: id for i, id in enumerate(list(queries.keys()))}
queries_texts = list(queries.values())
queries_embeddings = self._get_vector(queries_texts)
corpus_embeddings = np.array(self.corpus_embeddings, dtype=np.float32)
all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn)
all_hits = semantic_search(queries_embeddings, corpus_embeddings, top_k=topn, score_function=score_function)
for idx, hits in enumerate(all_hits):
for hit in hits[0:topn]:
result[queries_ids_map[idx]][self.corpus_ids_map[hit['corpus_id']]] = hit['score']
result[queries_ids_map[idx]][hit['corpus_id']] = hit['score']

return result

def save_index(self, index_path: str = "corpus_emb.json"):
"""
Save corpus embeddings to json file.
:param index_path: json file path
:return:
"""
corpus_emb = {id: {"doc": self.corpus[id], "doc_emb": emb} for id, emb in
zip(self.corpus.keys(), self.corpus_embeddings)}
with open(index_path, "w", encoding="utf-8") as f:
json.dump(corpus_emb, f)
logger.debug(f"Save corpus embeddings to file: {index_path}.")

def load_index(self, index_path: str = "corpus_emb.json"):
"""
Load corpus embeddings from json file.
:param index_path: json file path
:return: list of corpus embeddings, dict of corpus ids map, dict of corpus
"""
try:
with open(index_path, "r", encoding="utf-8") as f:
corpus_emb = json.load(f)
corpus_embeddings = []
for id, corpus_dict in corpus_emb.items():
self.corpus[int(id)] = corpus_dict["doc"]
corpus_embeddings.append(corpus_dict["doc_emb"])
self.corpus_embeddings = corpus_embeddings
except (IOError, json.JSONDecodeError):
logger.error("Error: Could not load corpus embeddings from file.")
@@ -3,12 +3,3 @@
@author:XuMing(xuming624@qq.com)
@description:
"""

from .distance import *
from .get_file import *
from .imagehash import *
from .ngram_util import *
from .rank_bm25 import *
from .tfidf import *
from .tokenizer import *
from .util import *
@@ -1,174 +0,0 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description:
"""


class NgramUtil(object):
@staticmethod
def unigrams(words):
"""
Input: a list of words, e.g., ["I", "am", "Denny"]
Output: a list of unigram
"""
assert type(words) == list
return words

@staticmethod
def bigrams(words, join_string, skip=0):
"""
Input: a list of words, e.g., ["I", "am", "Denny"]
Output: a list of bigram, e.g., ["I_am", "am_Denny"]
"""
assert type(words) == list
L = len(words)
if L > 1:
lst = []
for i in range(L - 1):
for k in range(1, skip + 2):
if i + k < L:
lst.append(join_string.join([words[i], words[i + k]]))
else:
# set it as unigram
lst = NgramUtil.unigrams(words)
return lst

@staticmethod
def trigrams(words, join_string, skip=0):
"""
Input: a list of words, e.g., ["I", "am", "Denny"]
Output: a list of trigram, e.g., ["I_am_Denny"]
"""
assert type(words) == list
L = len(words)
if L > 2:
lst = []
for i in range(L - 2):
for k1 in range(1, skip + 2):
for k2 in range(1, skip + 2):
if i + k1 < L and i + k1 + k2 < L:
lst.append(join_string.join([words[i], words[i + k1], words[i + k1 + k2]]))
else:
# set it as bigram
lst = NgramUtil.bigrams(words, join_string, skip)
return lst

@staticmethod
def fourgrams(words, join_string):
"""
Input: a list of words, e.g., ["I", "am", "Denny", "boy"]
Output: a list of trigram, e.g., ["I_am_Denny_boy"]
"""
assert type(words) == list
L = len(words)
if L > 3:
lst = []
for i in range(L - 3):
lst.append(join_string.join([words[i], words[i + 1], words[i + 2], words[i + 3]]))
else:
# set it as trigram
lst = NgramUtil.trigrams(words, join_string)
return lst

@staticmethod
def uniterms(words):
return NgramUtil.unigrams(words)

@staticmethod
def biterms(words, join_string):
"""
Input: a list of words, e.g., ["I", "am", "Denny", "boy"]
Output: a list of biterm, e.g., ["I_am", "I_Denny", "I_boy", "am_Denny", "am_boy", "Denny_boy"]
"""
assert type(words) == list
L = len(words)
if L > 1:
lst = []
for i in range(L - 1):
for j in range(i + 1, L):
lst.append(join_string.join([words[i], words[j]]))
else:
# set it as uniterm
lst = NgramUtil.uniterms(words)
return lst

@staticmethod
def triterms(words, join_string):
"""
Input: a list of words, e.g., ["I", "am", "Denny", "boy"]
Output: a list of triterm, e.g., ["I_am_Denny", "I_am_boy", "I_Denny_boy", "am_Denny_boy"]
"""
assert type(words) == list
L = len(words)
if L > 2:
lst = []
for i in range(L - 2):
for j in range(i + 1, L - 1):
for k in range(j + 1, L):
lst.append(join_string.join([words[i], words[j], words[k]]))
else:
# set it as biterm
lst = NgramUtil.biterms(words, join_string)
return lst

@staticmethod
def fourterms(words, join_string):
"""
Input: a list of words, e.g., ["I", "am", "Denny", "boy", "ha"]
Output: a list of fourterm, e.g., ["I_am_Denny_boy", "I_am_Denny_ha", "I_am_boy_ha", "I_Denny_boy_ha", "am_Denny_boy_ha"]
"""
assert type(words) == list
L = len(words)
if L > 3:
lst = []
for i in range(L - 3):
for j in range(i + 1, L - 2):
for k in range(j + 1, L - 1):
for l in range(k + 1, L):
lst.append(join_string.join([words[i], words[j], words[k], words[l]]))
else:
# set it as triterm
lst = NgramUtil.triterms(words, join_string)
return lst

@staticmethod
def ngrams(words, ngram, join_string=" "):
"""
wrapper for ngram
"""
if ngram == 1:
return NgramUtil.unigrams(words)
elif ngram == 2:
return NgramUtil.bigrams(words, join_string)
elif ngram == 3:
return NgramUtil.trigrams(words, join_string)
elif ngram == 4:
return NgramUtil.fourgrams(words, join_string)
elif ngram == 12:
unigram = NgramUtil.unigrams(words)
bigram = [x for x in NgramUtil.bigrams(words, join_string) if len(x.split(join_string)) == 2]
return unigram + bigram
elif ngram == 123:
unigram = NgramUtil.unigrams(words)
bigram = [x for x in NgramUtil.bigrams(words, join_string) if len(x.split(join_string)) == 2]
trigram = [x for x in NgramUtil.trigrams(words, join_string) if len(x.split(join_string)) == 3]
return unigram + bigram + trigram
elif ngram == 1234:
unigram = NgramUtil.unigrams(words)
bigram = [x for x in NgramUtil.bigrams(words, join_string) if len(x.split(join_string)) == 2]
trigram = [x for x in NgramUtil.trigrams(words, join_string) if len(x.split(join_string)) == 3]
fourgram = [x for x in NgramUtil.fourgrams(words, join_string) if len(x.split(join_string)) == 4]
return unigram + bigram + trigram + fourgram

@staticmethod
def nterms(words, nterm, join_string=" "):
"""wrapper for nterm"""
if nterm == 1:
return NgramUtil.uniterms(words)
elif nterm == 2:
return NgramUtil.biterms(words, join_string)
elif nterm == 3:
return NgramUtil.triterms(words, join_string)
elif nterm == 4:
return NgramUtil.fourterms(words, join_string)
@@ -1,31 +0,0 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description:
"""
import os
import jieba
import logging


class JiebaTokenizer(object):
def __init__(self, dict_path='', custom_word_freq_dict=None):
self.model = jieba
self.model.default_logger.setLevel(logging.ERROR)
# initialize the main dictionary
if os.path.exists(dict_path):
self.model.set_dictionary(dict_path)
# load the user-defined custom dictionary
if custom_word_freq_dict:
for w, f in custom_word_freq_dict.items():
self.model.add_word(w, freq=f)

def tokenize(self, sentence, cut_all=False, HMM=True):
"""
Segment the sentence into words.
:param sentence: the sentence to tokenize
:param cut_all: full mode, off by default
:param HMM: whether to enable NER recognition, on by default
:return: A list of strings.
"""
return self.model.lcut(sentence, cut_all=cut_all, HMM=HMM)
@@ -103,8 +103,8 @@ def semantic_search(
This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings.
It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries.

:param query_embeddings: A 2 dimensional tensor with the query embeddings.
:param corpus_embeddings: A 2 dimensional tensor with the corpus embeddings.
:param query_embeddings: A 2-dimensional tensor with the query embeddings.
:param corpus_embeddings: A 2-dimensional tensor with the corpus embeddings.
:param query_chunk_size: Process 100 queries simultaneously. Increasing that value increases the speed, but
requires more memory.
:param corpus_chunk_size: Scans the corpus 100k entries at a time. Increasing that value increases the speed,
@@ -4,4 +4,4 @@
@description:
"""

__version__ = '1.0.4'
__version__ = '1.0.5'