get to chinese model.
This commit is contained in:
parent
25b1df6e8a
commit
15d385a241
@ -77,8 +77,10 @@ class Similarity(SimilarityABC):
|
||||
def __init__(
|
||||
self,
|
||||
corpus: Union[List[str], Dict[str, str]] = None,
|
||||
model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
||||
model_name_or_path="shibing624/text2vec-base-chinese",
|
||||
encoder_type="MEAN",
|
||||
max_seq_length=128,
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Initialize the similarity object.
|
||||
@ -90,7 +92,12 @@ class Similarity(SimilarityABC):
|
||||
:param max_seq_length: Max sequence length for sentence model.
|
||||
"""
|
||||
if isinstance(model_name_or_path, str):
|
||||
self.sentence_model = SentenceModel(model_name_or_path, max_seq_length=max_seq_length)
|
||||
self.sentence_model = SentenceModel(
|
||||
model_name_or_path,
|
||||
encoder_type=encoder_type,
|
||||
max_seq_length=max_seq_length,
|
||||
device=device
|
||||
)
|
||||
elif hasattr(model_name_or_path, "encode"):
|
||||
self.sentence_model = model_name_or_path
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user