get to chinese model.

This commit is contained in:
shibing624 2023-05-02 23:46:43 +08:00
parent 25b1df6e8a
commit 15d385a241

View File

@ -77,8 +77,10 @@ class Similarity(SimilarityABC):
def __init__( def __init__(
self, self,
corpus: Union[List[str], Dict[str, str]] = None, corpus: Union[List[str], Dict[str, str]] = None,
model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", model_name_or_path="shibing624/text2vec-base-chinese",
encoder_type="MEAN",
max_seq_length=128, max_seq_length=128,
device=None,
): ):
""" """
Initialize the similarity object. Initialize the similarity object.
@ -90,7 +92,12 @@ class Similarity(SimilarityABC):
:param max_seq_length: Max sequence length for sentence model. :param max_seq_length: Max sequence length for sentence model.
""" """
if isinstance(model_name_or_path, str): if isinstance(model_name_or_path, str):
self.sentence_model = SentenceModel(model_name_or_path, max_seq_length=max_seq_length) self.sentence_model = SentenceModel(
model_name_or_path,
encoder_type=encoder_type,
max_seq_length=max_seq_length,
device=device
)
elif hasattr(model_name_or_path, "encode"): elif hasattr(model_name_or_path, "encode"):
self.sentence_model = model_name_or_path self.sentence_model = model_name_or_path
else: else: