修改多sentence的bug
This commit is contained in:
parent
7c3c1ad29c
commit
b7cb27b59e
2
.gitignore
vendored
2
.gitignore
vendored
@ -2,3 +2,5 @@
|
||||
tmp/
|
||||
__pycache__/
|
||||
.idea/
|
||||
data/data_merger.py
|
||||
data/data.py
|
||||
|
@ -5,6 +5,9 @@ import args
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
import tensorflow as tf
|
||||
import os
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
|
||||
|
||||
|
||||
class InputExample(object):
|
||||
@ -43,7 +46,7 @@ class BertVector:
|
||||
self.input_queue = Queue(maxsize=1)
|
||||
self.output_queue = Queue(maxsize=1)
|
||||
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
|
||||
self.predict_thread.start()
|
||||
self.sentence_len = 0
|
||||
|
||||
def get_estimator(self):
|
||||
from tensorflow.python.estimator.estimator import Estimator
|
||||
@ -80,6 +83,8 @@ class BertVector:
|
||||
self.output_queue.put(i)
|
||||
|
||||
def encode(self, sentence):
|
||||
self.sentence_len = len(sentence)
|
||||
self.predict_thread.start()
|
||||
self.input_queue.put(sentence)
|
||||
prediction = self.output_queue.get()
|
||||
return prediction
|
||||
@ -93,7 +98,7 @@ class BertVector:
|
||||
'input_mask': tf.int32,
|
||||
'input_type_ids': tf.int32},
|
||||
output_shapes={
|
||||
'unique_ids': (1,),
|
||||
'unique_ids': (self.sentence_len,),
|
||||
'input_ids': (None, self.max_seq_length),
|
||||
'input_mask': (None, self.max_seq_length),
|
||||
'input_type_ids': (None, self.max_seq_length)}))
|
||||
@ -327,7 +332,7 @@ class BertVector:
|
||||
|
||||
if __name__ == "__main__":
|
||||
bert = BertVector()
|
||||
while True:
|
||||
question = input('question: ')
|
||||
vectors = bert.encode([question])
|
||||
# while True:
|
||||
# question = input('question: ')
|
||||
vectors = bert.encode(['你好', '哈哈'])
|
||||
print(str(vectors))
|
||||
|
Loading…
Reference in New Issue
Block a user