Update similarity.py

Update the predict method to support batch prediction.
This commit is contained in:
hahajinbu 2019-10-14 14:50:44 +08:00 committed by GitHub
parent 1d5f3eb649
commit 0a2da688d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -103,7 +103,10 @@ class SimProcessor(DataProcessor):
return test_data
def get_sentence_examples(self, questions):
for index, data in enumerate(questions):
questions = questions[0] #self.input_queue.put([(sentences1, sentences2)])
print(len(questions))
for index,data in enumerate(zip(questions[0],questions[1])):
# print(data)
guid = 'test-%d' % index
text_a = tokenization.convert_to_unicode(str(data[0]))
text_b = tokenization.convert_to_unicode(str(data[1]))
@ -301,7 +304,7 @@ class BertSim:
'input_ids': (None, self.max_seq_length),
'input_mask': (None, self.max_seq_length),
'segment_ids': (None, self.max_seq_length),
'label_ids': (1,)}).prefetch(10))
'label_ids': (None,)}).prefetch(10))
def convert_examples_to_features(self, examples, label_list, max_seq_length, tokenizer):
"""Convert a set of `InputExample`s to a list of `InputFeatures`."""
@ -668,7 +671,37 @@ if __name__ == '__main__':
sim.train()
sim.set_mode(tf.estimator.ModeKeys.EVAL)
sim.eval()
# sim.set_mode(tf.estimator.ModeKeys.PREDICT)
# ---- Prediction test ----
# Scores the same 1000 sentence pairs twice -- once with one predict() call
# per pair, once with a single batched call -- then prints the wall-clock
# time of each pass and the Pearson correlation between the two score lists.
sim.set_mode(tf.estimator.ModeKeys.PREDICT)
import time
from scipy.stats import pearsonr

# Both sentences of pair i are the decimal string of i.
sentences_1 = [str(i) for i in range(1000)]
sentences_2 = list(sentences_1)

# NOTE(review): the calls below previously went through a name `bs` that is
# not defined anywhere in the visible code; `sim` is the object that was just
# switched to PREDICT mode, so it is used here -- confirm no separate `bs`
# instance was intended.

# ---- predict one pair per call ----
t1 = time.time()
# [0][1]: row 0 (the only pair in the call), column 1 of the returned scores.
# Presumably column 1 is the "similar" class probability -- TODO confirm.
results_1 = [
    sim.predict([first], [second])[0][1]
    for first, second in zip(sentences_1, sentences_2)
]
t2 = time.time()
print('predict one by one cost: {} seconds.'.format(str(t2 - t1)))

# ---- predict the whole batch in one call ----
t3 = time.time()
batch_results_1 = sim.predict(sentences_1, sentences_2)
# Keep column 1 of every row so the values line up with results_1.
batch_results_1 = batch_results_1[:, 1]
t4 = time.time()
print('predict batch cost: {} seconds.'.format(str(t4 - t3)))

# If batching does not change the scores, this prints about (1.0, 0.0).
print(pearsonr(results_1, batch_results_1))
# while True:
# sentence1 = input('sentence1: ')
# sentence2 = input('sentence2: ')