add image demo.
This commit is contained in:
parent
5bd0f05749
commit
a121c3e0d1
109
README.md
109
README.md
@ -57,7 +57,7 @@ python3 setup.py install
|
||||
|
||||
# Usage
|
||||
|
||||
### 1. 计算两个句子的相似度值
|
||||
### 1. 文本语义相似度计算
|
||||
|
||||
```shell
|
||||
>>> from similarities import Similarity
|
||||
@ -70,7 +70,7 @@ similarity score: 0.8551
|
||||
|
||||
> 余弦值`score`范围是[-1, 1],值越大越相似。
|
||||
|
||||
### 2. 文档集中相似文本搜索
|
||||
### 2. 文本语义匹配搜索
|
||||
|
||||
一般在文档候选集中找与query最相似的文本,常用于QA场景的问句相似匹配、文本相似检索等任务。
|
||||
|
||||
@ -78,38 +78,40 @@ similarity score: 0.8551
|
||||
中文示例[examples/base_demo.py](./examples/base_demo.py)
|
||||
|
||||
```python
|
||||
import sys
|
||||
|
||||
sys.path.append('..')
|
||||
from similarities import Similarity
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = Similarity("shibing624/text2vec-base-chinese")
|
||||
# 1.Compute cosine similarity between two sentences.
|
||||
sentences = ['如何更换花呗绑定银行卡',
|
||||
'花呗更改绑定银行卡']
|
||||
corpus = [
|
||||
'花呗更改绑定银行卡',
|
||||
'我什么时候开通了花呗',
|
||||
'俄罗斯警告乌克兰反对欧盟协议',
|
||||
'暴风雨掩埋了东北部;新泽西16英寸的降雪',
|
||||
'中央情报局局长访问以色列叙利亚会谈',
|
||||
'人在巴基斯坦基地的炸弹袭击中丧生',
|
||||
]
|
||||
similarity_score = model.similarity(sentences[0], sentences[1])
|
||||
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
|
||||
|
||||
# 2.Compute similarity between two list
|
||||
similarity_scores = model.similarity(sentences, corpus)
|
||||
print(similarity_scores.numpy())
|
||||
for i in range(len(sentences)):
|
||||
for j in range(len(corpus)):
|
||||
print(f"{sentences[i]} vs {corpus[j]}, score: {similarity_scores.numpy()[i][j]:.4f}")
|
||||
# 1.Compute cosine similarity between two sentences.
|
||||
sentences = ['如何更换花呗绑定银行卡',
|
||||
'花呗更改绑定银行卡']
|
||||
corpus = [
|
||||
'花呗更改绑定银行卡',
|
||||
'我什么时候开通了花呗',
|
||||
'俄罗斯警告乌克兰反对欧盟协议',
|
||||
'暴风雨掩埋了东北部;新泽西16英寸的降雪',
|
||||
'中央情报局局长访问以色列叙利亚会谈',
|
||||
'人在巴基斯坦基地的炸弹袭击中丧生',
|
||||
]
|
||||
model = Similarity("shibing624/text2vec-base-chinese")
|
||||
print(model)
|
||||
similarity_score = model.similarity(sentences[0], sentences[1])
|
||||
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
|
||||
|
||||
# 3.Semantic Search
|
||||
m = Similarity("shibing624/text2vec-base-chinese", corpus=corpus)
|
||||
q = '如何更换花呗绑定银行卡'
|
||||
print(m.most_similar(q, topn=5))
|
||||
print("query:", q)
|
||||
for i in m.most_similar(q, topn=5):
|
||||
print('\t', i)
|
||||
# 2.Compute similarity between two list
|
||||
similarity_scores = model.similarity(sentences, corpus)
|
||||
print(similarity_scores.numpy())
|
||||
for i in range(len(sentences)):
|
||||
for j in range(len(corpus)):
|
||||
print(f"{sentences[i]} vs {corpus[j]}, score: {similarity_scores.numpy()[i][j]:.4f}")
|
||||
|
||||
# 3.Semantic Search
|
||||
model.add_corpus(corpus)
|
||||
q = '如何更换花呗绑定银行卡'
|
||||
print("query:", q)
|
||||
for i in model.most_similar(q, topn=5):
|
||||
print('\t', i)
|
||||
```
|
||||
|
||||
output:
|
||||
@ -143,22 +145,22 @@ query: 如何更换花呗绑定银行卡
|
||||
英文示例[examples/base_english_demo.py](./examples/base_english_demo.py)
|
||||
|
||||
|
||||
### 3. 快速近似匹配搜索
|
||||
### 3. 快速近似语义匹配搜索
|
||||
|
||||
支持Annoy、Hnswlib的近似匹配搜索,常用于百万数据集的匹配搜索任务。
|
||||
支持Annoy、Hnswlib的近似语义匹配搜索,常用于百万数据集的匹配搜索任务。
|
||||
|
||||
|
||||
示例[examples/fast_sim_demo.py](./examples/fast_sim_demo.py)
|
||||
|
||||
|
||||
### 4. 基于字面的文本相似度计算
|
||||
### 4. 基于字面的文本相似度计算和匹配搜索
|
||||
|
||||
支持同义词词林(Cilin)、知网Hownet、词向量(WordEmbedding)、Tfidf、Simhash、BM25等算法的相似度计算和匹配搜索,常用于文本匹配冷启动。
|
||||
支持同义词词林(Cilin)、知网Hownet、词向量(WordEmbedding)、Tfidf、SimHash、BM25等算法的相似度计算和字面匹配搜索,常用于文本匹配冷启动。
|
||||
|
||||
示例[examples/literal_sim_demo.py](./examples/literal_sim_demo.py)
|
||||
|
||||
```python
|
||||
from similarities.literalsim import SimHashSimilarity, TfidfSimilarity, BM25Similarity,
|
||||
from similarities.literalsim import SimHashSimilarity, TfidfSimilarity, BM25Similarity, \
|
||||
WordEmbeddingSimilarity, CilinSimilarity, HownetSimilarity
|
||||
|
||||
text1 = "如何更换花呗绑定银行卡"
|
||||
@ -166,7 +168,7 @@ text2 = "花呗更改绑定银行卡"
|
||||
|
||||
m = TfidfSimilarity()
|
||||
print(text1, text2, ' sim score: ', m.similarity(text1, text2))
|
||||
print('distance:', m.distance(text1, text2))
|
||||
|
||||
zh_list = ['刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌', '我不是演员吗']
|
||||
m.add_corpus(zh_list)
|
||||
print(m.most_similar('刘若英是演员'))
|
||||
@ -175,11 +177,42 @@ print(m.most_similar('刘若英是演员'))
|
||||
output:
|
||||
```shell
|
||||
如何更换花呗绑定银行卡 花呗更改绑定银行卡 sim score: 0.8203384355246909
|
||||
distance: 0.17966156447530912
|
||||
|
||||
[(0, '刘若英是个演员', 0.9847577834309504), (3, '我不是演员吗', 0.7056381915655814), (1, '他唱歌很好听', 0.5), (2, 'women喜欢这首歌', 0.5)]
|
||||
```
|
||||
|
||||
### 5. 图像相似度计算和匹配搜索
|
||||
|
||||
支持[CLIP](similarities/imagesim.py)、pHash、SIFT等算法的图像相似度计算和匹配搜索。
|
||||
|
||||
示例[examples/image_demo.py](./examples/image_demo.py)
|
||||
|
||||
```python
|
||||
import sys
|
||||
import glob
|
||||
|
||||
sys.path.append('..')
|
||||
from similarities.imagesim import ImageHashSimilarity, SiftSimilarity, ClipSimilarity
|
||||
|
||||
image_fp1 = 'data/image1.png'
|
||||
image_fp2 = 'data/image12-like-image1.png'
|
||||
m = ClipSimilarity()
|
||||
print(m)
|
||||
print(m.similarity(image_fp1, image_fp2))
|
||||
# add corpus
|
||||
m.add_corpus(glob.glob('data/*.jpg') + glob.glob('data/*.png'))
|
||||
r = m.most_similar(image_fp1)
|
||||
print(r)
|
||||
```
|
||||
|
||||
output:
|
||||
```shell
|
||||
0.9579
|
||||
|
||||
[(6, 'data/image1.png', 1.0), (0, 'data/image12-like-image1.png', 0.9579654335975647), (4, 'data/image8-like-image1.png', 0.9326782822608948), ... ]
|
||||
```
|
||||
![image_sim](docs/image_sim.png)
|
||||
|
||||
# Contact
|
||||
|
||||
- Issue(建议):[![GitHub issues](https://img.shields.io/github/issues/shibing624/similarities.svg)](https://github.com/shibing624/similarities/issues)
|
||||
|
BIN
docs/image_sim.png
Normal file
BIN
docs/image_sim.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 367 KiB |
@ -10,34 +10,32 @@ import sys
|
||||
sys.path.append('..')
|
||||
from similarities import Similarity
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 1.Compute cosine similarity between two sentences.
|
||||
sentences = ['如何更换花呗绑定银行卡',
|
||||
'花呗更改绑定银行卡']
|
||||
corpus = [
|
||||
'花呗更改绑定银行卡',
|
||||
'我什么时候开通了花呗',
|
||||
'俄罗斯警告乌克兰反对欧盟协议',
|
||||
'暴风雨掩埋了东北部;新泽西16英寸的降雪',
|
||||
'中央情报局局长访问以色列叙利亚会谈',
|
||||
'人在巴基斯坦基地的炸弹袭击中丧生',
|
||||
]
|
||||
model = Similarity("shibing624/text2vec-base-chinese")
|
||||
print(model)
|
||||
similarity_score = model.similarity(sentences[0], sentences[1])
|
||||
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
|
||||
# 1.Compute cosine similarity between two sentences.
|
||||
sentences = ['如何更换花呗绑定银行卡',
|
||||
'花呗更改绑定银行卡']
|
||||
corpus = [
|
||||
'花呗更改绑定银行卡',
|
||||
'我什么时候开通了花呗',
|
||||
'俄罗斯警告乌克兰反对欧盟协议',
|
||||
'暴风雨掩埋了东北部;新泽西16英寸的降雪',
|
||||
'中央情报局局长访问以色列叙利亚会谈',
|
||||
'人在巴基斯坦基地的炸弹袭击中丧生',
|
||||
]
|
||||
model = Similarity("shibing624/text2vec-base-chinese")
|
||||
print(model)
|
||||
similarity_score = model.similarity(sentences[0], sentences[1])
|
||||
print(f"{sentences[0]} vs {sentences[1]}, score: {float(similarity_score):.4f}")
|
||||
|
||||
# 2.Compute similarity between two list
|
||||
similarity_scores = model.similarity(sentences, corpus)
|
||||
print(similarity_scores.numpy())
|
||||
for i in range(len(sentences)):
|
||||
for j in range(len(corpus)):
|
||||
print(f"{sentences[i]} vs {corpus[j]}, score: {similarity_scores.numpy()[i][j]:.4f}")
|
||||
# 2.Compute similarity between two list
|
||||
similarity_scores = model.similarity(sentences, corpus)
|
||||
print(similarity_scores.numpy())
|
||||
for i in range(len(sentences)):
|
||||
for j in range(len(corpus)):
|
||||
print(f"{sentences[i]} vs {corpus[j]}, score: {similarity_scores.numpy()[i][j]:.4f}")
|
||||
|
||||
# 3.Semantic Search
|
||||
model.add_corpus(corpus)
|
||||
q = '如何更换花呗绑定银行卡'
|
||||
print(model.most_similar(q, topn=5))
|
||||
print("query:", q)
|
||||
for i in model.most_similar(q, topn=5):
|
||||
print('\t', i)
|
||||
# 3.Semantic Search
|
||||
model.add_corpus(corpus)
|
||||
q = '如何更换花呗绑定银行卡'
|
||||
print("query:", q)
|
||||
for i in model.most_similar(q, topn=5):
|
||||
print('\t', i)
|
||||
|
@ -20,8 +20,6 @@ def hnswlib():
|
||||
|
||||
m = HnswlibSimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
|
||||
print(m)
|
||||
v = m._get_vector("This is test1")
|
||||
print(v[:10], v.shape)
|
||||
print(m.similarity("This is a test1", "that is a test5"))
|
||||
print(m.distance("This is a test1", "that is a test5"))
|
||||
print(m.most_similar("This is a test4"))
|
||||
@ -44,8 +42,6 @@ def annoy():
|
||||
|
||||
m = AnnoySimilarity(sm, embedding_size=384, corpus=list_of_docs * 10)
|
||||
print(m)
|
||||
v = m._get_vector("This is test1")
|
||||
print(v[:10], v.shape)
|
||||
print(m.similarity("This is a test1", "that is a test5"))
|
||||
print(m.distance("This is a test1", "that is a test5"))
|
||||
print(m.most_similar("This is a test4"))
|
||||
|
@ -14,8 +14,6 @@ def phash_demo(image_fp1, image_fp2):
|
||||
m = ImageHashSimilarity(hash_function='phash')
|
||||
print(m)
|
||||
print(m.similarity(image_fp1, image_fp2))
|
||||
m.most_similar(image_fp1)
|
||||
# no corpus
|
||||
m.add_corpus(glob.glob('data/*.jpg') + glob.glob('data/*.png'))
|
||||
r = m.most_similar(image_fp1)
|
||||
print(r)
|
||||
@ -23,8 +21,6 @@ def phash_demo(image_fp1, image_fp2):
|
||||
m = ImageHashSimilarity(hash_function='average_hash')
|
||||
print(m)
|
||||
print(m.similarity(image_fp1, image_fp2))
|
||||
m.most_similar(image_fp1)
|
||||
# no corpus
|
||||
m.add_corpus(glob.glob('data/*.jpg') + glob.glob('data/*.png'))
|
||||
r = m.most_similar(image_fp1)
|
||||
print(r)
|
||||
@ -35,7 +31,7 @@ def sift_demo(image_fp1, image_fp2):
|
||||
print(m)
|
||||
print(m.similarity(image_fp1, image_fp2))
|
||||
m.most_similar(image_fp1)
|
||||
# no corpus
|
||||
# add corpus
|
||||
m.add_corpus(glob.glob('data/*.jpg'))
|
||||
m.add_corpus(glob.glob('data/*.png'))
|
||||
r = m.most_similar(image_fp1)
|
||||
@ -46,8 +42,7 @@ def clip_demo(image_fp1, image_fp2):
|
||||
m = ClipSimilarity()
|
||||
print(m)
|
||||
print(m.similarity(image_fp1, image_fp2))
|
||||
m.most_similar(image_fp1)
|
||||
# no corpus
|
||||
# add corpus
|
||||
m.add_corpus(glob.glob('data/*.jpg') + glob.glob('data/*.png'))
|
||||
r = m.most_similar(image_fp1)
|
||||
print(r)
|
||||
@ -59,4 +54,4 @@ if __name__ == "__main__":
|
||||
|
||||
phash_demo(image_fp1, image_fp2)
|
||||
sift_demo(image_fp1, image_fp2)
|
||||
clip_demo(image_fp1, image_fp2)
|
||||
clip_demo(image_fp1, image_fp2) # the best result
|
||||
|
@ -20,6 +20,7 @@ def main():
|
||||
text1 = '刘若英是个演员'
|
||||
text2 = '他唱歌很好听'
|
||||
m = SimHashSimilarity()
|
||||
print(m)
|
||||
print(m.similarity(text1, text2))
|
||||
print(m.distance(text1, text2))
|
||||
print(m.most_similar('刘若英是演员'))
|
||||
@ -30,6 +31,7 @@ def main():
|
||||
text1 = "如何更换花呗绑定银行卡"
|
||||
text2 = "花呗更改绑定银行卡"
|
||||
m = TfidfSimilarity()
|
||||
print(m)
|
||||
print(text1, text2, ' sim score: ', m.similarity(text1, text2))
|
||||
print('distance:', m.distance(text1, text2))
|
||||
zh_list = ['刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌', '我不是演员吗']
|
||||
@ -37,11 +39,13 @@ def main():
|
||||
print(m.most_similar('刘若英是演员'))
|
||||
|
||||
m = BM25Similarity()
|
||||
print(m)
|
||||
zh_list = ['刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌', '我不是演员吗']
|
||||
m.add_corpus(zh_list)
|
||||
print(m.most_similar('刘若英是演员'))
|
||||
|
||||
wm = Word2Vec()
|
||||
print(m)
|
||||
list_of_corpus = ["This is a test1", "This is a test2", "This is a test3"]
|
||||
list_of_corpus2 = ["that is test4", "that is a test5", "that is a test6"]
|
||||
m = WordEmbeddingSimilarity(wm, list_of_corpus)
|
||||
@ -55,6 +59,7 @@ def main():
|
||||
text1 = '周杰伦是一个歌手'
|
||||
text2 = '刘若英是个演员'
|
||||
m = CilinSimilarity()
|
||||
print(m)
|
||||
print(m.similarity(text1, text2))
|
||||
print(m.distance(text1, text2))
|
||||
zh_list = ['刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
|
||||
@ -62,6 +67,7 @@ def main():
|
||||
print(m.most_similar('刘若英是演员'))
|
||||
|
||||
m = HownetSimilarity()
|
||||
print(m)
|
||||
print(m.similarity(text1, text2))
|
||||
print(m.distance(text1, text2))
|
||||
zh_list = ['刘若英是个演员', '他唱歌很好听', 'women喜欢这首歌']
|
||||
|
@ -7,7 +7,6 @@ refer: https://colab.research.google.com/drive/1leOzG-AQw5MkzgA4qNW5fb3yc-oJ4Lo4
|
||||
Adjust the code to compare similarity score and search.
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
from typing import List, Union
|
||||
|
||||
import cv2
|
||||
@ -22,7 +21,100 @@ from similarities.utils.distance import hamming_distance
|
||||
from similarities.utils.imagehash import phash, dhash, whash, average_hash
|
||||
from similarities.utils.util import cos_sim
|
||||
|
||||
pwd_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
class ClipSimilarity:
|
||||
"""
|
||||
Compute CLIP similarity between two images and retrieves most
|
||||
similar image for a given image corpus.
|
||||
|
||||
CLIP: https://github.com/openai/CLIP.git
|
||||
"""
|
||||
|
||||
def __init__(self, corpus: List[str] = None, model_name_or_path: str = 'clip-ViT-B-32'):
|
||||
self.corpus = []
|
||||
self.clip_model = SentenceTransformer(model_name_or_path) # load the CLIP model
|
||||
self.corpus_embeddings = []
|
||||
if corpus is not None:
|
||||
self.add_corpus(corpus)
|
||||
|
||||
def __len__(self):
|
||||
"""Get length of corpus."""
|
||||
return len(self.corpus)
|
||||
|
||||
def __str__(self):
|
||||
base = f"Similarity: {self.__class__.__name__}, matching_model: CLIP"
|
||||
if self.corpus:
|
||||
base += f", corpus size: {len(self.corpus)}"
|
||||
return base
|
||||
|
||||
def add_corpus(self, corpus: List[str]):
|
||||
"""
|
||||
Extend the corpus with new documents.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
corpus : list of str
|
||||
"""
|
||||
self.corpus += corpus
|
||||
corpus_embeddings = self._get_vector(corpus).tolist()
|
||||
if self.corpus_embeddings:
|
||||
self.corpus_embeddings += corpus_embeddings
|
||||
else:
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
logger.info(f"Add corpus size: {len(corpus)}, total size: {len(self.corpus)}")
|
||||
|
||||
def _convert_to_rgb(self, img):
|
||||
"""Convert image to RGB mode."""
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
return img
|
||||
|
||||
def _get_vector(self, img_paths: Union[str, List[str]]):
|
||||
"""
|
||||
Returns the embeddings for a batch of images.
|
||||
:param img_paths:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(img_paths, str):
|
||||
img_paths = [img_paths]
|
||||
imgs = [Image.open(filepath) for filepath in img_paths]
|
||||
imgs = [self._convert_to_rgb(img) for img in imgs]
|
||||
return self.clip_model.encode(imgs, batch_size=128, convert_to_tensor=False, show_progress_bar=True)
|
||||
|
||||
def similarity(self, fp1: str, fp2: str):
|
||||
"""
|
||||
Compute similarity between two image files.
|
||||
:param fp1: image file path 1
|
||||
:param fp2: image file path 2
|
||||
:return: similarity score
|
||||
"""
|
||||
emb1 = self._get_vector(fp1)
|
||||
emb2 = self._get_vector(fp2)
|
||||
similarity_score = float(cos_sim(emb1, emb2))
|
||||
|
||||
return similarity_score
|
||||
|
||||
def distance(self, fp1: str, fp2: str):
|
||||
"""Compute distance between two image files."""
|
||||
return 1 - self.similarity(fp1, fp2)
|
||||
|
||||
def most_similar(self, query_fp: str, topn: int = 10):
|
||||
"""
|
||||
Find the topn most similar images to the query against the corpus.
|
||||
:param query_fp: str
|
||||
:param topn: int
|
||||
:return: list of tuples (id, image_path, similarity)
|
||||
"""
|
||||
result = []
|
||||
q_emb = self._get_vector(query_fp)
|
||||
|
||||
# Computes the cosine-similarity between the query embedding and all image embeddings.
|
||||
hits = semantic_search(q_emb, np.array(self.corpus_embeddings, dtype=np.float32), top_k=topn)
|
||||
hits = hits[0] # Get the first query result when query is string
|
||||
|
||||
for hit in hits[:topn]:
|
||||
result.append((hit['corpus_id'], self.corpus[hit['corpus_id']], hit['score']))
|
||||
return result
|
||||
|
||||
|
||||
class ImageHashSimilarity:
|
||||
@ -240,99 +332,3 @@ class SiftSimilarity:
|
||||
result.append((corpus_id, doc, score))
|
||||
result.sort(key=lambda x: x[2], reverse=True)
|
||||
return result[:topn]
|
||||
|
||||
|
||||
class ClipSimilarity:
|
||||
"""
|
||||
Compute CLIP similarity between two images and retrieves most
|
||||
similar image for a given image corpus.
|
||||
|
||||
CLIP: https://github.com/openai/CLIP.git
|
||||
"""
|
||||
|
||||
def __init__(self, corpus: List[str] = None, model_name_or_path: str = 'clip-ViT-B-32'):
|
||||
self.corpus = []
|
||||
self.clip_model = SentenceTransformer(model_name_or_path) # load the CLIP model
|
||||
self.corpus_embeddings = []
|
||||
if corpus is not None:
|
||||
self.add_corpus(corpus)
|
||||
|
||||
def __len__(self):
|
||||
"""Get length of corpus."""
|
||||
return len(self.corpus)
|
||||
|
||||
def __str__(self):
|
||||
base = f"Similarity: {self.__class__.__name__}, matching_model: CLIP"
|
||||
if self.corpus:
|
||||
base += f", corpus size: {len(self.corpus)}"
|
||||
return base
|
||||
|
||||
def add_corpus(self, corpus: List[str]):
|
||||
"""
|
||||
Extend the corpus with new documents.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
corpus : list of str
|
||||
"""
|
||||
self.corpus += corpus
|
||||
corpus_embeddings = self._get_vector(corpus).tolist()
|
||||
if self.corpus_embeddings:
|
||||
self.corpus_embeddings += corpus_embeddings
|
||||
else:
|
||||
self.corpus_embeddings = corpus_embeddings
|
||||
logger.info(f"Add corpus size: {len(corpus)}, total size: {len(self.corpus)}")
|
||||
|
||||
def _convert_to_rgb(self, img):
|
||||
"""Convert image to RGB mode."""
|
||||
if img.mode != 'RGB':
|
||||
img = img.convert('RGB')
|
||||
return img
|
||||
|
||||
def _get_vector(self, img_paths: Union[str, List[str]]):
|
||||
"""
|
||||
Returns the embeddings for a batch of images.
|
||||
:param img_paths:
|
||||
:return:
|
||||
"""
|
||||
if isinstance(img_paths, str):
|
||||
img_paths = [img_paths]
|
||||
imgs = [Image.open(filepath) for filepath in img_paths]
|
||||
imgs = [self._convert_to_rgb(img) for img in imgs]
|
||||
return self.clip_model.encode(imgs, batch_size=128, convert_to_tensor=False, show_progress_bar=True)
|
||||
|
||||
def similarity(self, fp1: str, fp2: str):
|
||||
"""
|
||||
Compute similarity between two image files.
|
||||
:param fp1: image file path 1
|
||||
:param fp2: image file path 2
|
||||
:return: similarity score
|
||||
"""
|
||||
emb1 = self._get_vector(fp1)
|
||||
emb2 = self._get_vector(fp2)
|
||||
similarity_score = float(cos_sim(emb1, emb2))
|
||||
|
||||
return similarity_score
|
||||
|
||||
def distance(self, fp1: str, fp2: str):
|
||||
"""Compute distance between two image files."""
|
||||
return 1 - self.similarity(fp1, fp2)
|
||||
|
||||
def most_similar(self, query_fp: str, topn: int = 10):
|
||||
"""
|
||||
Find the topn most similar images to the query against the corpus.
|
||||
:param query_fp: str
|
||||
:param topn: int
|
||||
:return: list of tuples (id, image_path, similarity)
|
||||
"""
|
||||
result = []
|
||||
q_emb = self._get_vector(query_fp)
|
||||
|
||||
# Computes the cosine-similarity between the query embedding and all image embeddings.
|
||||
hits = semantic_search(q_emb, np.array(self.corpus_embeddings, dtype=np.float32), top_k=topn)
|
||||
hits = hits[0] # Get the first query result when query is string
|
||||
|
||||
for hit in hits[:topn]:
|
||||
result.append((hit['corpus_id'], self.corpus[hit['corpus_id']], hit['score']))
|
||||
|
||||
return result[:topn]
|
||||
|
@ -111,5 +111,4 @@ class Similarity:
|
||||
|
||||
for hit in hits[0:topn]:
|
||||
result.append((hit['corpus_id'], self.corpus[hit['corpus_id']], hit['score']))
|
||||
|
||||
return result
|
||||
|
Loading…
Reference in New Issue
Block a user