Update test cases to adjust for the 1.0.0 version.
commit fb630e3a79
parent bc87a2cc58
@@ -42,4 +42,4 @@ for q_id, c in res.items():
     print('query:', sentences[q_id])
     print("search top 3:")
     for corpus_id, s in c.items():
-        print(f'\t{corpus[corpus_id]}: {s:.4f}')
+        print(f'\t{model.corpus[corpus_id]}: {s:.4f}')

@@ -47,4 +47,4 @@ for q_id, c in res.items():
     print('query:', sentences1[q_id])
     print("search top 3:")
     for corpus_id, s in c.items():
-        print(f'\t{corpus[corpus_id]}: {s:.4f}')
+        print(f'\t{model.corpus[corpus_id]}: {s:.4f}')
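The two hunks above make the same one-line fix: in 1.0.0 the model object owns its corpus, so results are resolved through model.corpus rather than a separate local dict. A minimal sketch of the updated loop; the Similarity import and the most_similar call are assumptions from the library's usual API, only the loop body is confirmed by this diff:

# Sketch of the 1.0.0-style result loop. Assumed: the `Similarity`
# import path and `most_similar` returning {query_id: {corpus_id: score}};
# the inner loop mirrors the diff above.
from similarities import Similarity

corpus = ['花呗更改绑定银行卡', '我什么时候开通了花呗', '俄罗斯警告乌克兰反对欧盟制裁']
sentences = ['如何更换花呗绑定银行卡']

model = Similarity(corpus=corpus)            # the model now keeps the corpus
res = model.most_similar(sentences, topn=3)
for q_id, c in res.items():
    print('query:', sentences[q_id])
    print("search top 3:")
    for corpus_id, s in c.items():
        # look the document up through the model, not a local variable
        print(f'\t{model.corpus[corpus_id]}: {s:.4f}')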
@@ -20,24 +20,37 @@ from similarities.evaluation import evaluate
 random.seed(42)
 
 pwd_path = os.path.dirname(os.path.realpath(__file__))
 
 
-#### Download scifact.zip dataset and unzip the dataset
-dataset = "scifact"
-url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
-zip_file = os.path.join(pwd_path, "scifact.zip")
-if not os.path.exists(zip_file):
-    logger.info("Dataset not exists, downloading...")
-    http_get(url, zip_file)
-else:
-    logger.info("Dataset already exists, skipping download.")
+def get_scifact():
+    # Download scifact.zip dataset and unzip the dataset
+    dataset = "scifact"
+    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
+    zip_file = os.path.join(pwd_path, "scifact.zip")
+    if not os.path.exists(zip_file):
+        logger.info("Dataset not exists, downloading...")
+        http_get(url, zip_file, extract=True)
+    else:
+        logger.info("Dataset already exists, skipping download.")
+    data_path = os.path.join(pwd_path, dataset)
+    return data_path
 
 
 def get_dbpedia():
     dataset = "dbpedia-entity"
     url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
-    out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "dbpedia-entity.zip")
-    data_path = http_get(url, out_dir)
+    zip_file = os.path.join(pwd_path, "dbpedia-entity.zip")
+    if not os.path.exists(zip_file):
+        logger.info("Dataset not exists, downloading...")
+        http_get(url, zip_file, extract=True)
+    else:
+        logger.info("Dataset already exists, skipping download.")
+    data_path = os.path.join(pwd_path, dataset)
+    return data_path
 
-data_path = os.path.join(pwd_path, dataset)
 
+data_path = get_scifact()
 #### Loading test queries and corpus in DBPedia
 corpus, queries, qrels = SearchDataLoader(data_path).load(split="test")
 corpus_ids, query_ids = list(corpus), list(queries)
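Both helpers above follow the same download-once pattern, now with extract=True so http_get also unzips the archive. A hedged sketch of the shared pattern; pwd_path, logger and http_get are taken from the script's existing imports, and the generic helper name is hypothetical:

# Hypothetical generic form of get_scifact()/get_dbpedia(); only
# http_get(url, path, extract=True) is confirmed by this diff.
def get_beir_dataset(dataset):
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    zip_file = os.path.join(pwd_path, "{}.zip".format(dataset))
    if not os.path.exists(zip_file):
        logger.info("Dataset not exists, downloading...")
        http_get(url, zip_file, extract=True)  # download and unzip in one step
    else:
        logger.info("Dataset already exists, skipping download.")
    return os.path.join(pwd_path, dataset)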
@@ -8,12 +8,13 @@ Compute similarity:
 2. Retrieves most similar sentence of a query against a corpus of documents.
 """
 
-from typing import List, Union, Tuple, Dict, Any
+import os
+from typing import List, Union, Dict
 
 import numpy as np
 from loguru import logger
 from text2vec import SentenceModel
 
 from similarities.utils.util import cos_sim, semantic_search, dot_score
 
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
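For context on the added os.environ line: when two libraries each load their own copy of the Intel OpenMP runtime (libiomp5), the process can abort with the well-known "OMP: Error #15" message; KMP_DUPLICATE_LIB_OK="TRUE" is the common workaround, and it only takes effect if set before the offending libraries load, which is why it sits in the import block. A sketch of the ordering requirement; torch is an assumed example of a library that bundles OpenMP:

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # must be set before the imports below

import numpy as np  # these imports may each pull in an OpenMP runtime;
import torch        # without the flag a duplicate libiomp5 can abort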
@@ -8,66 +8,32 @@ import sys
 import unittest
 
 sys.path.append('..')
-from text2vec import SentenceModel
 from similarities.fastsim import AnnoySimilarity
 from similarities.fastsim import HnswlibSimilarity
 
-sm = SentenceModel()
-
 
 class FastTestCase(unittest.TestCase):
 
     def test_sim_diff(self):
         a = '研究团队面向国家重大战略需求追踪国际前沿发展借鉴国际人工智能研究领域的科研模式有效整合创新资源解决复'
         b = '英汉互译比较语言学'
-        m = HnswlibSimilarity(sm)
-        r = m.similarity(a, b)[0]
+        m = HnswlibSimilarity()
+        r = float(m.similarity(a, b)[0])
         print(a, b, r)
-        self.assertTrue(abs(r - 0.1733) < 0.001)
-        m = AnnoySimilarity(sm)
-        r = m.similarity(a, b)[0]
+        self.assertTrue(abs(r - 0.4098) < 0.001)
+        m = AnnoySimilarity()
+        r = float(m.similarity(a, b)[0])
         print(a, b, r)
-        self.assertTrue(abs(r - 0.1733) < 0.001)
+        self.assertTrue(abs(r - 0.4098) < 0.001)
 
-    def test_empty(self):
-        m = HnswlibSimilarity(sm, embedding_size=384, corpus=[])
-        v = m._get_vector("This is test1")
-        print(v[:10], v.shape)
-        r = m.similarity("This is a test1", "that is a test5")
-        print(r)
-        print(m.distance("This is a test1", "that is a test5"))
-
-        m = AnnoySimilarity(sm)
-        r = m.similarity("This is a test1", "that is a test5")
-        self.assertTrue(r[0] > 0.0)
-
-    def test_hnsw_score(self):
-        list_of_docs = ["This is a test1", "This is a test2", "This is a test3", '刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
-        list_of_docs2 = ["that is test4", "that is a test5", "that is a test6", '刘若英个演员', '唱歌很好听', 'men喜欢这首歌']
-
-        m = HnswlibSimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
-        v = m._get_vector("This is test1")
-        print(v[:10], v.shape)
-        r = m.similarity("This is a test1", "that is a test5")
-        print(r)
-        self.assertTrue(r[0] > 0.5)
-        print(m.distance("This is a test1", "that is a test5"))
-        print(m.most_similar("This is a test4"))
-        print(m.most_similar("men喜欢这首歌"))
-        m.add_corpus(list_of_docs2)
-        print(m.most_similar("This is a test4"))
-        r = m.most_similar("men喜欢这首歌", topn=5)
-        print(r)
-        self.assertTrue(len(r[0]) == 5)
-
     def test_hnswlib_model_save_load(self):
         list_of_docs = ["This is a test1", "This is a test2", "This is a test3", '刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
         list_of_docs2 = ["that is test4", "that is a test5", "that is a test6", '刘若英个演员', '唱歌很好听', 'men喜欢这首歌']
 
-        m = HnswlibSimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
-        print(m.most_similar("This is a test4"))
-        print(m.most_similar("men喜欢这首歌"))
+        corpus_new = [i + str(id) for id, i in enumerate(list_of_docs * 10)]
+        m = HnswlibSimilarity(corpus=list_of_docs * 10)
         m.add_corpus(list_of_docs2)
+        m.add_corpus(corpus_new)
+        m.build_index()
         print(m.most_similar("This is a test4"))
         print(m.most_similar("men喜欢这首歌"))
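The updated tests above show the 1.0.0 construction flow: no SentenceModel is passed in (the class loads a default embedder itself), documents are added with add_corpus, and the ANN index is rebuilt explicitly with build_index before querying. A minimal sketch of that flow, assuming the defaults the tests rely on:

from similarities.fastsim import HnswlibSimilarity

docs = ["This is a test1", "刘若英是个演员", "他唱歌很好听"]
m = HnswlibSimilarity(corpus=docs)   # default embedding model loaded internally
m.add_corpus(["that is a test5", "men喜欢这首歌"])
m.build_index()                      # rebuild the HNSW index over the full corpus
print(m.most_similar("This is a test4", topn=3))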
@@ -83,8 +49,8 @@ class FastTestCase(unittest.TestCase):
     def test_annoy_model(self):
         list_of_docs = ["This is a test1", "This is a test2", "This is a test3", '刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
         list_of_docs2 = ["that is test4", "that is a test5", "that is a test6", '刘若英个演员', '唱歌很好听', 'men喜欢这首歌']
 
-        m = AnnoySimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
+        corpus_new = [i + str(id) for id, i in enumerate(list_of_docs * 10)]
+        m = AnnoySimilarity(corpus=list_of_docs * 10)
         print(m)
         v = m._get_vector("This is test1")
         print(v[:10], v.shape)

@@ -93,6 +59,7 @@ class FastTestCase(unittest.TestCase):
         print(m.most_similar("This is a test4"))
         print(m.most_similar("men喜欢这首歌"))
         m.add_corpus(list_of_docs2)
+        m.add_corpus(corpus_new)
+        m.build_index()
         print(m.most_similar("This is a test4"))
         r = m.most_similar("men喜欢这首歌", topn=1)
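The assertion len(r[0]) == 5 in the deleted test, together with the result loops at the top of this commit, suggests most_similar returns {query_id: {corpus_id: score}}; r[0] is then the hit dict for the first query. A hedged sketch of consuming AnnoySimilarity results under that assumption:

from similarities.fastsim import AnnoySimilarity

m = AnnoySimilarity(corpus=["刘若英是个演员", "他唱歌很好听", "men喜欢这首歌"])
m.build_index()
r = m.most_similar("men喜欢这首歌", topn=1)  # assumed shape: {query_id: {corpus_id: score}}
for q_id, hits in r.items():
    for corpus_id, score in hits.items():
        print(m.corpus[corpus_id], f"{score:.4f}")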
@@ -118,11 +118,10 @@ class LiteralCase(unittest.TestCase):
         """test_word2vec"""
         text1 = '刘若英是个演员'
         text2 = '他唱歌很好听'
-        wm = Word2Vec()
         list_of_corpus = ["This is a test1", "This is a test2", "This is a test3"]
         list_of_corpus2 = ["that is test4", "that is a test5", "that is a test6"]
         zh_list = ['刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌', '刘若英是个演员', '演戏很好看的人']
-        m = WordEmbeddingSimilarity(wm, list_of_corpus)
+        m = WordEmbeddingSimilarity(list_of_corpus)
         print(m.similarity(text1, text2))
         print(m.distance(text1, text2))
         m.add_corpus(list_of_corpus2 + zh_list)
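The word-embedding backend gets the same 1.0.0 treatment: the caller no longer builds a Word2Vec model; WordEmbeddingSimilarity loads a default one itself and takes the corpus directly. A short sketch; the module path is an assumption, since the import is not shown in this diff:

from similarities.literalsim import WordEmbeddingSimilarity  # assumed path

m = WordEmbeddingSimilarity(["This is a test1", "This is a test2"])
m.add_corpus(["that is test4", "刘若英是个演员"])
print(m.similarity('刘若英是个演员', '他唱歌很好听'))  # word2vec model loaded internally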
@@ -7,11 +7,9 @@ import sys
 import unittest
 
 sys.path.append('..')
-from text2vec import SentenceModel
 from similarities.similarity import Similarity
 
-sm = SentenceModel()
-m = Similarity(sm)
+m = Similarity()
 
 
 class SimScoreTestCase(unittest.TestCase):

@@ -19,9 +17,10 @@ class SimScoreTestCase(unittest.TestCase):
     def test_sim_diff(self):
         a = '研究团队面向国家重大战略需求追踪国际前沿发展借鉴国际人工智能研究领域的科研模式有效整合创新资源解决复'
         b = '英汉互译比较语言学'
-        r = m.similarity(a, b)[0]
+        r = m.similarity(a, b)[0][0]
+        r = float(r)
         print(a, b, r)
-        self.assertTrue(abs(r - 0.1733) < 0.001)
+        self.assertTrue(abs(r - 0.4098) < 0.001)
 
     def test_empty(self):
         v = m._get_vector("This is test1")
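The [0][0] indexing plus float() in the second hunk indicates that in 1.0.0 similarity() returns a score matrix of shape (len(a), len(b)) whose entries are tensor/numpy scalars rather than Python floats. A sketch of that reading:

from similarities.similarity import Similarity

m = Similarity()
scores = m.similarity(["query one"], ["doc one", "doc two"])
s = float(scores[0][0])   # row 0 = first query, column 0 = first doc
print(f"{s:.4f}")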