This commit is contained in:
yongzhuo 2019-11-12 19:59:53 +08:00
parent ad471da0c6
commit 465649e716
3 changed files with 12 additions and 2 deletions

View File

@ -111,6 +111,7 @@ class KerasBertVector():
with graph.as_default():
predicts = model.predict([input_ids, input_type_ids], batch_size=1)
print(predicts.shape)
tokens_text = tokens_text if len(tokens_text) <= self.max_seq_len - 2 else tokens_text[:self.max_seq_len - 2]
for i, token in enumerate(tokens_text):
print(token, [len(predicts[0][i].tolist())], predicts[0][i].tolist())

View File

@ -16,11 +16,11 @@ def calculate_count():
bert_vector = KerasBertVector()
print("bert start ok!")
time_start = time.time()
for i in range(1000):
for i in range(10):
vector = bert_vector.bert_encode(["jy你知道吗我一直都很喜欢你呀在一起在一起在一起哈哈哈哈"])
time_end = time.time()
time_avg = (time_end-time_start)/1000
time_avg = (time_end-time_start)/10
print(vector)
print(time_avg)
# 0.12605296468734742 win10 gpu avg
@ -37,6 +37,11 @@ def sim_two_question():
import math
def cosine_distance(v1, v2): # 余弦距离
if type(v1)==list:
v1 = np.array(v1)
if type(v2)==list:
v2 = np.array(v2)
if v1.all() and v2.all():
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
else:

View File

@ -37,6 +37,10 @@ def sim_two_question():
import math
def cosine_distance(v1, v2): # 余弦距离
if type(v1)==list:
v1 = np.array(v1)
if type(v2)==list:
v2 = np.array(v2)
if v1.all() and v2.all():
return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
else: