修改predict代码

This commit is contained in:
joe 2019-01-30 09:57:02 +08:00
parent 3c4ed2ffe7
commit 0ca26576a4

View File

@ -128,7 +128,7 @@ class BertSim:
def set_mode(self, mode): def set_mode(self, mode):
self.mode = mode self.mode = mode
self.estimator = self.get_estimator() self.estimator = self.get_estimator()
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.PREDICT:
self.input_queue = Queue(maxsize=1) self.input_queue = Queue(maxsize=1)
self.output_queue = Queue(maxsize=1) self.output_queue = Queue(maxsize=1)
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
@ -664,13 +664,13 @@ class BertSim:
if __name__ == '__main__': if __name__ == '__main__':
sim = BertSim() sim = BertSim()
sim.set_mode(tf.estimator.ModeKeys.TRAIN) # sim.set_mode(tf.estimator.ModeKeys.TRAIN)
sim.train() # sim.train()
sim.set_mode(tf.estimator.ModeKeys.EVAL) # sim.set_mode(tf.estimator.ModeKeys.EVAL)
sim.eval() # sim.eval()
# sim.set_mode(tf.estimator.ModeKeys.PREDICT) sim.set_mode(tf.estimator.ModeKeys.PREDICT)
# while True: while True:
# sentence1 = input('sentence1: ') sentence1 = input('sentence1: ')
# sentence2 = input('sentence2: ') sentence2 = input('sentence2: ')
# predict = sim.predict(sentence1, sentence2) predict = sim.predict(sentence1, sentence2)
# print(f'similarity{predict[0][1]}') print(f'similarity{predict[0][1]}')