序列长度改为None
This commit is contained in:
parent
a066214e9c
commit
62e7727575
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,3 +4,4 @@ __pycache__/
|
|||||||
.idea/
|
.idea/
|
||||||
data/data_merger.py
|
data/data_merger.py
|
||||||
data/data.py
|
data/data.py
|
||||||
|
.DS_Store
|
@ -47,7 +47,6 @@ class BertVector:
|
|||||||
self.output_queue = Queue(maxsize=1)
|
self.output_queue = Queue(maxsize=1)
|
||||||
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
|
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
|
||||||
self.predict_thread.start()
|
self.predict_thread.start()
|
||||||
self.sentence_len = 0
|
|
||||||
|
|
||||||
def get_estimator(self):
|
def get_estimator(self):
|
||||||
from tensorflow.python.estimator.estimator import Estimator
|
from tensorflow.python.estimator.estimator import Estimator
|
||||||
@ -84,7 +83,6 @@ class BertVector:
|
|||||||
self.output_queue.put(i)
|
self.output_queue.put(i)
|
||||||
|
|
||||||
def encode(self, sentence):
|
def encode(self, sentence):
|
||||||
self.sentence_len = len(sentence)
|
|
||||||
self.input_queue.put(sentence)
|
self.input_queue.put(sentence)
|
||||||
prediction = self.output_queue.get()['encodes']
|
prediction = self.output_queue.get()['encodes']
|
||||||
return prediction
|
return prediction
|
||||||
@ -98,7 +96,7 @@ class BertVector:
|
|||||||
'input_mask': tf.int32,
|
'input_mask': tf.int32,
|
||||||
'input_type_ids': tf.int32},
|
'input_type_ids': tf.int32},
|
||||||
output_shapes={
|
output_shapes={
|
||||||
'unique_ids': (self.sentence_len,),
|
'unique_ids': (None,),
|
||||||
'input_ids': (None, self.max_seq_length),
|
'input_ids': (None, self.max_seq_length),
|
||||||
'input_mask': (None, self.max_seq_length),
|
'input_mask': (None, self.max_seq_length),
|
||||||
'input_type_ids': (None, self.max_seq_length)}).prefetch(10))
|
'input_type_ids': (None, self.max_seq_length)}).prefetch(10))
|
||||||
|
Loading…
Reference in New Issue
Block a user