序列长度改为None

This commit is contained in:
terrifyzhao 2019-04-19 10:56:12 +08:00
parent a066214e9c
commit 62e7727575
2 changed files with 2 additions and 3 deletions

1
.gitignore vendored
View File

@ -4,3 +4,4 @@ __pycache__/
.idea/ .idea/
data/data_merger.py data/data_merger.py
data/data.py data/data.py
.DS_Store

View File

@ -47,7 +47,6 @@ class BertVector:
self.output_queue = Queue(maxsize=1) self.output_queue = Queue(maxsize=1)
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True) self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
self.predict_thread.start() self.predict_thread.start()
self.sentence_len = 0
def get_estimator(self): def get_estimator(self):
from tensorflow.python.estimator.estimator import Estimator from tensorflow.python.estimator.estimator import Estimator
@ -84,7 +83,6 @@ class BertVector:
self.output_queue.put(i) self.output_queue.put(i)
def encode(self, sentence): def encode(self, sentence):
self.sentence_len = len(sentence)
self.input_queue.put(sentence) self.input_queue.put(sentence)
prediction = self.output_queue.get()['encodes'] prediction = self.output_queue.get()['encodes']
return prediction return prediction
@ -98,7 +96,7 @@ class BertVector:
'input_mask': tf.int32, 'input_mask': tf.int32,
'input_type_ids': tf.int32}, 'input_type_ids': tf.int32},
output_shapes={ output_shapes={
'unique_ids': (self.sentence_len,), 'unique_ids': (None,),
'input_ids': (None, self.max_seq_length), 'input_ids': (None, self.max_seq_length),
'input_mask': (None, self.max_seq_length), 'input_mask': (None, self.max_seq_length),
'input_type_ids': (None, self.max_seq_length)}).prefetch(10)) 'input_type_ids': (None, self.max_seq_length)}).prefetch(10))