修改predict代码

This commit is contained in:
joe 2019-01-30 09:57:02 +08:00
parent 3c4ed2ffe7
commit 0ca26576a4

View File

@ -128,7 +128,7 @@ class BertSim:
def set_mode(self, mode):
self.mode = mode
self.estimator = self.get_estimator()
if mode == tf.estimator.ModeKeys.TRAIN:
if mode == tf.estimator.ModeKeys.PREDICT:
self.input_queue = Queue(maxsize=1)
self.output_queue = Queue(maxsize=1)
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
@ -664,13 +664,13 @@ class BertSim:
if __name__ == '__main__':
sim = BertSim()
sim.set_mode(tf.estimator.ModeKeys.TRAIN)
sim.train()
sim.set_mode(tf.estimator.ModeKeys.EVAL)
sim.eval()
# sim.set_mode(tf.estimator.ModeKeys.PREDICT)
# while True:
# sentence1 = input('sentence1: ')
# sentence2 = input('sentence2: ')
# predict = sim.predict(sentence1, sentence2)
# print(f'similarity{predict[0][1]}')
# sim.set_mode(tf.estimator.ModeKeys.TRAIN)
# sim.train()
# sim.set_mode(tf.estimator.ModeKeys.EVAL)
# sim.eval()
sim.set_mode(tf.estimator.ModeKeys.PREDICT)
while True:
sentence1 = input('sentence1: ')
sentence2 = input('sentence2: ')
predict = sim.predict(sentence1, sentence2)
print(f'similarity{predict[0][1]}')