Update test cases and examples to match the 1.0.0 version.

shibing624 2022-03-11 17:01:19 +08:00
parent bc87a2cc58
commit fb630e3a79
7 changed files with 46 additions and 67 deletions

View File

@@ -42,4 +42,4 @@ for q_id, c in res.items():
     print('query:', sentences[q_id])
     print("search top 3:")
     for corpus_id, s in c.items():
-        print(f'\t{corpus[corpus_id]}: {s:.4f}')
+        print(f'\t{model.corpus[corpus_id]}: {s:.4f}')
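The change above switches the example from the local corpus list to the corpus kept on the model, so a hit looked up by corpus_id stays valid after add_corpus grows the index. A minimal sketch of the surrounding flow, assuming the example uses the Similarity class from similarities.similarity (as the test file further down does) and that most_similar(queries, topn=3) returns the {query_id: {corpus_id: score}} mapping the loop iterates over; the line that builds res sits outside this hunk:

from similarities.similarity import Similarity

sentences = ["This is a test1", "This is a test2"]
corpus = ["that is test4", "that is a test5", "that is a test6"]

model = Similarity()            # 1.0.0 style: no SentenceModel instance is passed in
model.add_corpus(corpus)        # the corpus is stored on the model, keyed by corpus_id
res = model.most_similar(sentences, topn=3)   # assumed signature, by analogy with the tests below
for q_id, c in res.items():
    print('query:', sentences[q_id])
    print("search top 3:")
    for corpus_id, s in c.items():
        print(f'\t{model.corpus[corpus_id]}: {s:.4f}')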

View File

@@ -47,4 +47,4 @@ for q_id, c in res.items():
     print('query:', sentences1[q_id])
     print("search top 3:")
     for corpus_id, s in c.items():
-        print(f'\t{corpus[corpus_id]}: {s:.4f}')
+        print(f'\t{model.corpus[corpus_id]}: {s:.4f}')

View File

@@ -20,24 +20,37 @@ from similarities.evaluation import evaluate

 random.seed(42)

 pwd_path = os.path.dirname(os.path.realpath(__file__))

-#### Download scifact.zip dataset and unzip the dataset
-dataset = "scifact"
-url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
-zip_file = os.path.join(pwd_path, "scifact.zip")
-if not os.path.exists(zip_file):
-    logger.info("Dataset not exists, downloading...")
-    http_get(url, zip_file)
-else:
-    logger.info("Dataset already exists, skipping download.")
+
+def get_scifact():
+    # Download scifact.zip dataset and unzip the dataset
+    dataset = "scifact"
+    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+    zip_file = os.path.join(pwd_path, "scifact.zip")
+    if not os.path.exists(zip_file):
+        logger.info("Dataset not exists, downloading...")
+        http_get(url, zip_file, extract=True)
+    else:
+        logger.info("Dataset already exists, skipping download.")
+    data_path = os.path.join(pwd_path, dataset)
+    return data_path


 def get_dbpedia():
     dataset = "dbpedia-entity"
     url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
-    out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "dbpedia-entity.zip")
-    data_path = http_get(url, out_dir)
+    zip_file = os.path.join(pwd_path, "dbpedia-entity.zip")
+    if not os.path.exists(zip_file):
+        logger.info("Dataset not exists, downloading...")
+        http_get(url, zip_file, extract=True)
+    else:
+        logger.info("Dataset already exists, skipping download.")
+    data_path = os.path.join(pwd_path, dataset)
     return data_path


-data_path = os.path.join(pwd_path, dataset)
+data_path = get_scifact()

 #### Loading test queries and corpus in DBPedia
 corpus, queries, qrels = SearchDataLoader(data_path).load(split="test")
 corpus_ids, query_ids = list(corpus), list(queries)
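With the download logic factored into helpers that return the extracted dataset directory, switching the benchmark corpus is presumably a one-line change; a short sketch assuming get_dbpedia() is meant to be used the same way as get_scifact():

# hypothetical: run the same benchmark on DBpedia-Entity instead of SciFact
data_path = get_dbpedia()
corpus, queries, qrels = SearchDataLoader(data_path).load(split="test")
corpus_ids, query_ids = list(corpus), list(queries)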

View File

@@ -8,12 +8,13 @@ Compute similarity:
 2. Retrieves most similar sentence of a query against a corpus of documents.
 """
-from typing import List, Union, Tuple, Dict, Any
+import os
+from typing import List, Union, Dict

 import numpy as np
 from loguru import logger
 from text2vec import SentenceModel

 from similarities.utils.util import cos_sim, semantic_search, dot_score

+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

View File

@@ -8,66 +8,32 @@ import sys
 import unittest

 sys.path.append('..')
-from text2vec import SentenceModel
 from similarities.fastsim import AnnoySimilarity
 from similarities.fastsim import HnswlibSimilarity

-sm = SentenceModel()


 class FastTestCase(unittest.TestCase):
     def test_sim_diff(self):
         a = '研究团队面向国家重大战略需求追踪国际前沿发展借鉴国际人工智能研究领域的科研模式有效整合创新资源解决复'
         b = '英汉互译比较语言学'
-        m = HnswlibSimilarity(sm)
-        r = m.similarity(a, b)[0]
+        m = HnswlibSimilarity()
+        r = float(m.similarity(a, b)[0])
         print(a, b, r)
-        self.assertTrue(abs(r - 0.1733) < 0.001)
-        m = AnnoySimilarity(sm)
-        r = m.similarity(a, b)[0]
+        self.assertTrue(abs(r - 0.4098) < 0.001)
+        m = AnnoySimilarity()
+        r = float(m.similarity(a, b)[0])
         print(a, b, r)
-        self.assertTrue(abs(r - 0.1733) < 0.001)
+        self.assertTrue(abs(r - 0.4098) < 0.001)

-    def test_empty(self):
-        m = HnswlibSimilarity(sm, embedding_size=384, corpus=[])
-        v = m._get_vector("This is test1")
-        print(v[:10], v.shape)
-        r = m.similarity("This is a test1", "that is a test5")
-        print(r)
-        print(m.distance("This is a test1", "that is a test5"))
-        m = AnnoySimilarity(sm)
-        r = m.similarity("This is a test1", "that is a test5")
-        self.assertTrue(r[0] > 0.0)

-    def test_hnsw_score(self):
-        list_of_docs = ["This is a test1", "This is a test2", "This is a test3", '刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
-        list_of_docs2 = ["that is test4", "that is a test5", "that is a test6", '刘若英个演员', '唱歌很好听', 'men喜欢这首歌']
-        m = HnswlibSimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
-        v = m._get_vector("This is test1")
-        print(v[:10], v.shape)
-        r = m.similarity("This is a test1", "that is a test5")
-        print(r)
-        self.assertTrue(r[0] > 0.5)
-        print(m.distance("This is a test1", "that is a test5"))
-        print(m.most_similar("This is a test4"))
-        print(m.most_similar("men喜欢这首歌"))
-        m.add_corpus(list_of_docs2)
-        print(m.most_similar("This is a test4"))
-        r = m.most_similar("men喜欢这首歌", topn=5)
-        print(r)
-        self.assertTrue(len(r[0]) == 5)

     def test_hnswlib_model_save_load(self):
         list_of_docs = ["This is a test1", "This is a test2", "This is a test3", '刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
         list_of_docs2 = ["that is test4", "that is a test5", "that is a test6", '刘若英个演员', '唱歌很好听', 'men喜欢这首歌']
-        m = HnswlibSimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
-        print(m.most_similar("This is a test4"))
-        print(m.most_similar("men喜欢这首歌"))
+        corpus_new = [i + str(id) for id, i in enumerate(list_of_docs * 10)]
+        m = HnswlibSimilarity(corpus=list_of_docs * 10)

-        m.add_corpus(list_of_docs2)
+        m.add_corpus(corpus_new)
+        m.build_index()
         print(m.most_similar("This is a test4"))
         print(m.most_similar("men喜欢这首歌"))
@@ -83,8 +49,8 @@ class FastTestCase(unittest.TestCase):
     def test_annoy_model(self):
         list_of_docs = ["This is a test1", "This is a test2", "This is a test3", '刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
         list_of_docs2 = ["that is test4", "that is a test5", "that is a test6", '刘若英个演员', '唱歌很好听', 'men喜欢这首歌']
-        m = AnnoySimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
+        corpus_new = [i + str(id) for id, i in enumerate(list_of_docs * 10)]
+        m = AnnoySimilarity(corpus=list_of_docs * 10)
         print(m)
         v = m._get_vector("This is test1")
         print(v[:10], v.shape)
@@ -93,6 +59,7 @@ class FastTestCase(unittest.TestCase):
         print(m.most_similar("This is a test4"))
         print(m.most_similar("men喜欢这首歌"))
         m.add_corpus(list_of_docs2)
+        m.add_corpus(corpus_new)
         m.build_index()
         print(m.most_similar("This is a test4"))
         r = m.most_similar("men喜欢这首歌", topn=1)
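Taken together, the fastsim changes drop the explicit SentenceModel argument: the approximate-nearest-neighbour wrappers now load their own encoder, take the corpus as a keyword argument, and build the index explicitly with build_index(). A hedged sketch of the 1.0.0-style usage these tests exercise (shorter document list, topn chosen arbitrarily):

from similarities.fastsim import AnnoySimilarity
from similarities.fastsim import HnswlibSimilarity

docs = ["This is a test1", "This is a test2", "This is a test3"]

m = HnswlibSimilarity(corpus=docs)          # no SentenceModel passed in
m.add_corpus(["that is test4", "that is a test5"])
m.build_index()                             # build the ANN index before searching
print(m.most_similar("This is a test4", topn=2))

m2 = AnnoySimilarity()                      # pairwise scoring without a corpus also works
print(float(m2.similarity("This is a test1", "that is a test5")[0]))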

View File

@@ -118,11 +118,10 @@ class LiteralCase(unittest.TestCase):
         """test_word2vec"""
         text1 = '刘若英是个演员'
         text2 = '他唱歌很好听'
-        wm = Word2Vec()
         list_of_corpus = ["This is a test1", "This is a test2", "This is a test3"]
         list_of_corpus2 = ["that is test4", "that is a test5", "that is a test6"]
         zh_list = ['刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌', '刘若英是个演员', '演戏很好看的人']
-        m = WordEmbeddingSimilarity(wm, list_of_corpus)
+        m = WordEmbeddingSimilarity(list_of_corpus)
         print(m.similarity(text1, text2))
         print(m.distance(text1, text2))
         m.add_corpus(list_of_corpus2 + zh_list)
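The word-embedding backend gets the same treatment: no Word2Vec instance is passed in, and the corpus goes straight to the constructor. A minimal sketch under the assumption that WordEmbeddingSimilarity lives in similarities.literalsim (the import is outside this hunk) and loads a default word2vec model when none is given:

from similarities.literalsim import WordEmbeddingSimilarity   # import path assumed

corpus = ["This is a test1", "This is a test2", "This is a test3"]
m = WordEmbeddingSimilarity(corpus)         # 1.0.0 style: corpus is the first argument
print(m.similarity('刘若英是个演员', '他唱歌很好听'))
print(m.distance('刘若英是个演员', '他唱歌很好听'))
m.add_corpus(["that is test4", "that is a test5"])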

View File

@@ -7,11 +7,9 @@ import sys
 import unittest

 sys.path.append('..')
-from text2vec import SentenceModel
 from similarities.similarity import Similarity

-sm = SentenceModel()
-m = Similarity(sm)
+m = Similarity()


 class SimScoreTestCase(unittest.TestCase):
@@ -19,9 +17,10 @@ class SimScoreTestCase(unittest.TestCase):
     def test_sim_diff(self):
         a = '研究团队面向国家重大战略需求追踪国际前沿发展借鉴国际人工智能研究领域的科研模式有效整合创新资源解决复'
         b = '英汉互译比较语言学'
-        r = m.similarity(a, b)[0]
+        r = m.similarity(a, b)[0][0]
+        r = float(r)
         print(a, b, r)
-        self.assertTrue(abs(r - 0.1733) < 0.001)
+        self.assertTrue(abs(r - 0.4098) < 0.001)

     def test_empty(self):
         v = m._get_vector("This is test1")
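As in the fastsim tests, the plain Similarity backend now builds its own text2vec model, and similarity(a, b) returns a matrix-shaped score (one row per left-hand text, one column per right-hand text), hence the new [0][0] indexing and float() conversion. A short sketch of reading a single pairwise score under that assumption, reusing the test's own sentence pair:

from similarities.similarity import Similarity

m = Similarity()                            # default model loaded internally
a = '研究团队面向国家重大战略需求追踪国际前沿发展借鉴国际人工智能研究领域的科研模式有效整合创新资源解决复'
b = '英汉互译比较语言学'
score = float(m.similarity(a, b)[0][0])     # 1x1 score matrix -> plain float
print(score)                                # the test above expects roughly 0.4098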