Update similarity.py
Update the predict method to support batch prediction.
This commit is contained in:
parent
1d5f3eb649
commit
0a2da688d1
@ -103,7 +103,10 @@ class SimProcessor(DataProcessor):
|
||||
return test_data
|
||||
|
||||
def get_sentence_examples(self, questions):
|
||||
for index, data in enumerate(questions):
|
||||
questions = questions[0] #self.input_queue.put([(sentences1, sentences2)])
|
||||
print(len(questions))
|
||||
for index,data in enumerate(zip(questions[0],questions[1])):
|
||||
# print(data)
|
||||
guid = 'test-%d' % index
|
||||
text_a = tokenization.convert_to_unicode(str(data[0]))
|
||||
text_b = tokenization.convert_to_unicode(str(data[1]))
|
||||
@ -301,7 +304,7 @@ class BertSim:
|
||||
'input_ids': (None, self.max_seq_length),
|
||||
'input_mask': (None, self.max_seq_length),
|
||||
'segment_ids': (None, self.max_seq_length),
|
||||
'label_ids': (1,)}).prefetch(10))
|
||||
'label_ids': (None,)}).prefetch(10))
|
||||
|
||||
def convert_examples_to_features(self, examples, label_list, max_seq_length, tokenizer):
|
||||
"""Convert a set of `InputExample`s to a list of `InputFeatures`."""
|
||||
@ -668,7 +671,37 @@ if __name__ == '__main__':
|
||||
sim.train()
|
||||
sim.set_mode(tf.estimator.ModeKeys.EVAL)
|
||||
sim.eval()
|
||||
# sim.set_mode(tf.estimator.ModeKeys.PREDICT)
|
||||
|
||||
#####预测测试
|
||||
sim.set_mode(tf.estimator.ModeKeys.PREDICT)
|
||||
import time
|
||||
results_1 = []
|
||||
t1 = time.time()
|
||||
for i in range(1000):
|
||||
if i % 2 ==0:
|
||||
x = bs.predict(["你{}好".format(i)],["您{}好".format(i)])[0][1]
|
||||
else:
|
||||
x = bs.predict(["你{}好".format(i)],["不{}好".format(i)])[0][1]
|
||||
results_1.append(x)
|
||||
t2 = time.time()
|
||||
print('predict one by one cost: {} seconds.'.format(str(t2 - t1)))
|
||||
t3 = time.time()
|
||||
########=====predict batch=============
|
||||
sentences_1 = []
|
||||
sentences_2 = []
|
||||
for i in range(1000):
|
||||
if i % 2 ==0:
|
||||
sentences_1.append("你{}好".format(i))
|
||||
sentences_2.append("您{}好".format(i))
|
||||
else:
|
||||
sentences_1.append("你{}好".format(i))
|
||||
sentences_2.append("不{}好".format(i))
|
||||
batch_results_1 = bs.predict(sentences_1,sentences_2)
|
||||
batch_results_1 = batch_results_1[:,1]
|
||||
t4 = time.time()
|
||||
print('predict batch cost: {} seconds.'.format(str(t4 - t3)))
|
||||
from scipy.stats import pearsonr
|
||||
print(pearsonr(results_1,batch_results_1)) ###(1.0,0.0)
|
||||
# while True:
|
||||
# sentence1 = input('sentence1: ')
|
||||
# sentence2 = input('sentence2: ')
|
||||
|
Loading…
Reference in New Issue
Block a user