修改多sentence的bug

This commit is contained in:
joe 2019-01-30 19:05:44 +08:00
parent 7c3c1ad29c
commit b7cb27b59e
2 changed files with 14 additions and 7 deletions

4
.gitignore vendored
View File

@ -1,4 +1,6 @@
/chinese_L-12_H-768_A-12
tmp/
__pycache__/
.idea/
.idea/
data/data_merger.py
data/data.py

View File

@ -5,6 +5,9 @@ import args
from queue import Queue
from threading import Thread
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
class InputExample(object):
@ -43,7 +46,7 @@ class BertVector:
self.input_queue = Queue(maxsize=1)
self.output_queue = Queue(maxsize=1)
self.predict_thread = Thread(target=self.predict_from_queue, daemon=True)
self.predict_thread.start()
self.sentence_len = 0
def get_estimator(self):
from tensorflow.python.estimator.estimator import Estimator
@ -80,6 +83,8 @@ class BertVector:
self.output_queue.put(i)
def encode(self, sentence):
self.sentence_len = len(sentence)
self.predict_thread.start()
self.input_queue.put(sentence)
prediction = self.output_queue.get()
return prediction
@ -93,7 +98,7 @@ class BertVector:
'input_mask': tf.int32,
'input_type_ids': tf.int32},
output_shapes={
'unique_ids': (1,),
'unique_ids': (self.sentence_len,),
'input_ids': (None, self.max_seq_length),
'input_mask': (None, self.max_seq_length),
'input_type_ids': (None, self.max_seq_length)}))
@ -327,7 +332,7 @@ class BertVector:
if __name__ == "__main__":
bert = BertVector()
while True:
question = input('question: ')
vectors = bert.encode([question])
print(str(vectors))
# while True:
# question = input('question: ')
vectors = bert.encode(['你好', '哈哈'])
print(str(vectors))