diff --git a/similarity.py b/similarity.py index 17283d5..4d6e792 100644 --- a/similarity.py +++ b/similarity.py @@ -128,7 +128,7 @@ class BertSim: def set_mode(self, mode): self.mode = mode self.estimator = self.get_estimator() - if mode == tf.estimator.ModeKeys.TRAIN: + if mode == tf.estimator.ModeKeys.PREDICT: self.input_queue = Queue(maxsize=1) self.output_queue = Queue(maxsize=1) self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) @@ -664,13 +664,13 @@ class BertSim: if __name__ == '__main__': sim = BertSim() - sim.set_mode(tf.estimator.ModeKeys.TRAIN) - sim.train() - sim.set_mode(tf.estimator.ModeKeys.EVAL) - sim.eval() - # sim.set_mode(tf.estimator.ModeKeys.PREDICT) - # while True: - # sentence1 = input('sentence1: ') - # sentence2 = input('sentence2: ') - # predict = sim.predict(sentence1, sentence2) - # print(f'similarity:{predict[0][1]}') + # sim.set_mode(tf.estimator.ModeKeys.TRAIN) + # sim.train() + # sim.set_mode(tf.estimator.ModeKeys.EVAL) + # sim.eval() + sim.set_mode(tf.estimator.ModeKeys.PREDICT) + while True: + sentence1 = input('sentence1: ') + sentence2 = input('sentence2: ') + predict = sim.predict(sentence1, sentence2) + print(f'similarity:{predict[0][1]}')