Switch the default model to the Chinese model.
This commit is contained in:
parent
25b1df6e8a
commit
15d385a241
@ -77,8 +77,10 @@ class Similarity(SimilarityABC):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
corpus: Union[List[str], Dict[str, str]] = None,
|
corpus: Union[List[str], Dict[str, str]] = None,
|
||||||
model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
model_name_or_path="shibing624/text2vec-base-chinese",
|
||||||
|
encoder_type="MEAN",
|
||||||
max_seq_length=128,
|
max_seq_length=128,
|
||||||
|
device=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the similarity object.
|
Initialize the similarity object.
|
||||||
@ -90,7 +92,12 @@ class Similarity(SimilarityABC):
|
|||||||
:param max_seq_length: Max sequence length for sentence model.
|
:param max_seq_length: Max sequence length for sentence model.
|
||||||
"""
|
"""
|
||||||
if isinstance(model_name_or_path, str):
|
if isinstance(model_name_or_path, str):
|
||||||
self.sentence_model = SentenceModel(model_name_or_path, max_seq_length=max_seq_length)
|
self.sentence_model = SentenceModel(
|
||||||
|
model_name_or_path,
|
||||||
|
encoder_type=encoder_type,
|
||||||
|
max_seq_length=max_seq_length,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
elif hasattr(model_name_or_path, "encode"):
|
elif hasattr(model_name_or_path, "encode"):
|
||||||
self.sentence_model = model_name_or_path
|
self.sentence_model = model_name_or_path
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user