diff --git a/extract_feature.py b/extract_feature.py index 820d061..ec7f295 100644 --- a/extract_feature.py +++ b/extract_feature.py @@ -86,7 +86,7 @@ class BertVector: def encode(self, sentence): self.sentence_len = len(sentence) self.input_queue.put(sentence) - prediction = self.output_queue.get()['encodes'][0] + prediction = self.output_queue.get()['encodes'] return prediction def queue_predict_input_fn(self):