fix sim
This commit is contained in:
parent
ad471da0c6
commit
465649e716
@ -111,6 +111,7 @@ class KerasBertVector():
|
||||
with graph.as_default():
|
||||
predicts = model.predict([input_ids, input_type_ids], batch_size=1)
|
||||
print(predicts.shape)
|
||||
tokens_text = tokens_text if len(tokens_text) <= self.max_seq_len - 2 else tokens_text[:self.max_seq_len - 2]
|
||||
for i, token in enumerate(tokens_text):
|
||||
print(token, [len(predicts[0][i].tolist())], predicts[0][i].tolist())
|
||||
|
||||
|
@ -16,11 +16,11 @@ def calculate_count():
|
||||
bert_vector = KerasBertVector()
|
||||
print("bert start ok!")
|
||||
time_start = time.time()
|
||||
for i in range(1000):
|
||||
for i in range(10):
|
||||
vector = bert_vector.bert_encode(["jy,你知道吗,我一直都很喜欢你呀,在一起在一起在一起,哈哈哈哈"])
|
||||
|
||||
time_end = time.time()
|
||||
time_avg = (time_end-time_start)/1000
|
||||
time_avg = (time_end-time_start)/10
|
||||
print(vector)
|
||||
print(time_avg)
|
||||
# 0.12605296468734742 win10 gpu avg
|
||||
@ -37,6 +37,11 @@ def sim_two_question():
|
||||
import math
|
||||
|
||||
def cosine_distance(v1, v2): # 余弦距离
|
||||
if type(v1)==list:
|
||||
v1 = np.array(v1)
|
||||
if type(v2)==list:
|
||||
v2 = np.array(v2)
|
||||
|
||||
if v1.all() and v2.all():
|
||||
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
||||
else:
|
||||
|
@ -37,6 +37,10 @@ def sim_two_question():
|
||||
import math
|
||||
|
||||
def cosine_distance(v1, v2): # 余弦距离
|
||||
if type(v1)==list:
|
||||
v1 = np.array(v1)
|
||||
if type(v2)==list:
|
||||
v2 = np.array(v2)
|
||||
if v1.all() and v2.all():
|
||||
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user