From b7cb27b59e0d9945048ab256a2947c83a7722837 Mon Sep 17 00:00:00 2001 From: joe Date: Wed, 30 Jan 2019 19:05:44 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=A4=9Asentence=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 +++- extract_feature.py | 17 +++++++++++------ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 0e62227..e0ae6f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ /chinese_L-12_H-768_A-12 tmp/ __pycache__/ -.idea/ \ No newline at end of file +.idea/ +data/data_merger.py +data/data.py diff --git a/extract_feature.py b/extract_feature.py index 9540f1e..a9cf2b8 100644 --- a/extract_feature.py +++ b/extract_feature.py @@ -5,6 +5,9 @@ import args from queue import Queue from threading import Thread import tensorflow as tf +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '1' class InputExample(object): @@ -43,7 +46,7 @@ class BertVector: self.input_queue = Queue(maxsize=1) self.output_queue = Queue(maxsize=1) self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) - self.predict_thread.start() + self.sentence_len = 0 def get_estimator(self): from tensorflow.python.estimator.estimator import Estimator @@ -80,6 +83,8 @@ class BertVector: self.output_queue.put(i) def encode(self, sentence): + self.sentence_len = len(sentence) + self.predict_thread.start() self.input_queue.put(sentence) prediction = self.output_queue.get() return prediction @@ -93,7 +98,7 @@ class BertVector: 'input_mask': tf.int32, 'input_type_ids': tf.int32}, output_shapes={ - 'unique_ids': (1,), + 'unique_ids': (self.sentence_len,), 'input_ids': (None, self.max_seq_length), 'input_mask': (None, self.max_seq_length), 'input_type_ids': (None, self.max_seq_length)})) @@ -327,7 +332,7 @@ class BertVector: if __name__ == "__main__": bert = BertVector() - while True: - question = input('question: ') - vectors = bert.encode([question]) - print(str(vectors)) + # while True: + # question = input('question: ') + vectors = bert.encode(['你好', '哈哈']) + print(str(vectors))