fix sim
This commit is contained in:
parent
ad471da0c6
commit
465649e716
@ -111,6 +111,7 @@ class KerasBertVector():
|
|||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
predicts = model.predict([input_ids, input_type_ids], batch_size=1)
|
predicts = model.predict([input_ids, input_type_ids], batch_size=1)
|
||||||
print(predicts.shape)
|
print(predicts.shape)
|
||||||
|
tokens_text = tokens_text if len(tokens_text) <= self.max_seq_len - 2 else tokens_text[:self.max_seq_len - 2]
|
||||||
for i, token in enumerate(tokens_text):
|
for i, token in enumerate(tokens_text):
|
||||||
print(token, [len(predicts[0][i].tolist())], predicts[0][i].tolist())
|
print(token, [len(predicts[0][i].tolist())], predicts[0][i].tolist())
|
||||||
|
|
||||||
|
@ -16,11 +16,11 @@ def calculate_count():
|
|||||||
bert_vector = KerasBertVector()
|
bert_vector = KerasBertVector()
|
||||||
print("bert start ok!")
|
print("bert start ok!")
|
||||||
time_start = time.time()
|
time_start = time.time()
|
||||||
for i in range(1000):
|
for i in range(10):
|
||||||
vector = bert_vector.bert_encode(["jy,你知道吗,我一直都很喜欢你呀,在一起在一起在一起,哈哈哈哈"])
|
vector = bert_vector.bert_encode(["jy,你知道吗,我一直都很喜欢你呀,在一起在一起在一起,哈哈哈哈"])
|
||||||
|
|
||||||
time_end = time.time()
|
time_end = time.time()
|
||||||
time_avg = (time_end-time_start)/1000
|
time_avg = (time_end-time_start)/10
|
||||||
print(vector)
|
print(vector)
|
||||||
print(time_avg)
|
print(time_avg)
|
||||||
# 0.12605296468734742 win10 gpu avg
|
# 0.12605296468734742 win10 gpu avg
|
||||||
@ -37,6 +37,11 @@ def sim_two_question():
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
def cosine_distance(v1, v2): # 余弦距离
|
def cosine_distance(v1, v2): # 余弦距离
|
||||||
|
if type(v1)==list:
|
||||||
|
v1 = np.array(v1)
|
||||||
|
if type(v2)==list:
|
||||||
|
v2 = np.array(v2)
|
||||||
|
|
||||||
if v1.all() and v2.all():
|
if v1.all() and v2.all():
|
||||||
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
||||||
else:
|
else:
|
||||||
|
@ -37,6 +37,10 @@ def sim_two_question():
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
def cosine_distance(v1, v2): # 余弦距离
|
def cosine_distance(v1, v2): # 余弦距离
|
||||||
|
if type(v1)==list:
|
||||||
|
v1 = np.array(v1)
|
||||||
|
if type(v2)==list:
|
||||||
|
v2 = np.array(v2)
|
||||||
if v1.all() and v2.all():
|
if v1.all() and v2.all():
|
||||||
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user