diff --git a/.gitignore b/.gitignore index e0ae6f4..c3abd94 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__/ .idea/ data/data_merger.py data/data.py +.DS_Store \ No newline at end of file diff --git a/extract_feature.py b/extract_feature.py index ec7f295..6ff2409 100644 --- a/extract_feature.py +++ b/extract_feature.py @@ -47,7 +47,6 @@ class BertVector: self.output_queue = Queue(maxsize=1) self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) self.predict_thread.start() - self.sentence_len = 0 def get_estimator(self): from tensorflow.python.estimator.estimator import Estimator @@ -84,7 +83,6 @@ class BertVector: self.output_queue.put(i) def encode(self, sentence): - self.sentence_len = len(sentence) self.input_queue.put(sentence) prediction = self.output_queue.get()['encodes'] return prediction @@ -98,7 +96,7 @@ class BertVector: 'input_mask': tf.int32, 'input_type_ids': tf.int32}, output_shapes={ - 'unique_ids': (self.sentence_len,), + 'unique_ids': (None,), 'input_ids': (None, self.max_seq_length), 'input_mask': (None, self.max_seq_length), 'input_type_ids': (None, self.max_seq_length)}).prefetch(10))