From 62e772757508010adc85d7d0db841d83326fe0a4 Mon Sep 17 00:00:00 2001 From: terrifyzhao Date: Fri, 19 Apr 2019 10:56:12 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BA=8F=E5=88=97=E9=95=BF=E5=BA=A6=E6=94=B9?= =?UTF-8?q?=E4=B8=BANone?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + extract_feature.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index e0ae6f4..c3abd94 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__/ .idea/ data/data_merger.py data/data.py +.DS_Store \ No newline at end of file diff --git a/extract_feature.py b/extract_feature.py index ec7f295..6ff2409 100644 --- a/extract_feature.py +++ b/extract_feature.py @@ -47,7 +47,6 @@ class BertVector: self.output_queue = Queue(maxsize=1) self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) self.predict_thread.start() - self.sentence_len = 0 def get_estimator(self): from tensorflow.python.estimator.estimator import Estimator @@ -84,7 +83,6 @@ class BertVector: self.output_queue.put(i) def encode(self, sentence): - self.sentence_len = len(sentence) self.input_queue.put(sentence) prediction = self.output_queue.get()['encodes'] return prediction @@ -98,7 +96,7 @@ class BertVector: 'input_mask': tf.int32, 'input_type_ids': tf.int32}, output_shapes={ - 'unique_ids': (self.sentence_len,), + 'unique_ids': (None,), 'input_ids': (None, self.max_seq_length), 'input_mask': (None, self.max_seq_length), 'input_type_ids': (None, self.max_seq_length)}).prefetch(10))