序列长度改为None
This commit is contained in:
parent
a066214e9c
commit
62e7727575
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,3 +4,4 @@ __pycache__/
|
||||
.idea/
|
||||
data/data_merger.py
|
||||
data/data.py
|
||||
.DS_Store
|
@ -47,7 +47,6 @@ class BertVector:
|
||||
self.output_queue = Queue(maxsize=1)
|
||||
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
|
||||
self.predict_thread.start()
|
||||
self.sentence_len = 0
|
||||
|
||||
def get_estimator(self):
|
||||
from tensorflow.python.estimator.estimator import Estimator
|
||||
@ -84,7 +83,6 @@ class BertVector:
|
||||
self.output_queue.put(i)
|
||||
|
||||
def encode(self, sentence):
|
||||
self.sentence_len = len(sentence)
|
||||
self.input_queue.put(sentence)
|
||||
prediction = self.output_queue.get()['encodes']
|
||||
return prediction
|
||||
@ -98,7 +96,7 @@ class BertVector:
|
||||
'input_mask': tf.int32,
|
||||
'input_type_ids': tf.int32},
|
||||
output_shapes={
|
||||
'unique_ids': (self.sentence_len,),
|
||||
'unique_ids': (None,),
|
||||
'input_ids': (None, self.max_seq_length),
|
||||
'input_mask': (None, self.max_seq_length),
|
||||
'input_type_ids': (None, self.max_seq_length)}).prefetch(10))
|
||||
|
Loading…
Reference in New Issue
Block a user