From 15d385a241418393f9310450962e8b27659aeb9d Mon Sep 17 00:00:00 2001 From: shibing624 Date: Tue, 2 May 2023 23:46:43 +0800 Subject: [PATCH] get to chinese model. --- similarities/similarity.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/similarities/similarity.py b/similarities/similarity.py index e412541..8f26632 100644 --- a/similarities/similarity.py +++ b/similarities/similarity.py @@ -77,8 +77,10 @@ class Similarity(SimilarityABC): def __init__( self, corpus: Union[List[str], Dict[str, str]] = None, - model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", + model_name_or_path="shibing624/text2vec-base-chinese", + encoder_type="MEAN", max_seq_length=128, + device=None, ): """ Initialize the similarity object. @@ -90,7 +92,12 @@ class Similarity(SimilarityABC): :param max_seq_length: Max sequence length for sentence model. """ if isinstance(model_name_or_path, str): - self.sentence_model = SentenceModel(model_name_or_path, max_seq_length=max_seq_length) + self.sentence_model = SentenceModel( + model_name_or_path, + encoder_type=encoder_type, + max_seq_length=max_seq_length, + device=device + ) elif hasattr(model_name_or_path, "encode"): self.sentence_model = model_name_or_path else: