优化句向量生成速度

This commit is contained in:
joe 2019-07-01 10:07:54 +08:00
parent 98219427aa
commit bcec40195c
3 changed files with 17 additions and 11 deletions

View File

@ -23,4 +23,7 @@ gpu_memory_fraction = 0.8
layer_indexes = [-2]
# 序列的最大程度,单文本建议把该值调小
max_seq_len = 32
max_seq_len = 5
# graph名字
graph_file = 'tmp/result/graph'

View File

@ -39,7 +39,11 @@ class BertVector:
self.max_seq_length = args.max_seq_len
self.layer_indexes = args.layer_indexes
self.gpu_memory_fraction = 1
if os.path.exists(args.graph_file):
self.graph_path = args.graph_file
else:
self.graph_path = optimize_graph()
self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
self.batch_size = batch_size
self.estimator = self.get_estimator()
@ -75,7 +79,7 @@ class BertVector:
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
return Estimator(model_fn=model_fn, config=RunConfig(session_config=config),
params={'batch_size': self.batch_size})
params={'batch_size': self.batch_size}, model_dir='../tmp')
def predict_from_queue(self):
prediction = self.estimator.predict(input_fn=self.queue_predict_input_fn, yield_single_examples=False)
@ -330,7 +334,8 @@ class BertVector:
if __name__ == "__main__":
bert = BertVector()
# while True:
# question = input('question: ')
vectors = bert.encode(['你好', '哈哈'])
print(str(vectors))
while True:
question = input('question: ')
v = bert.encode([question])
print(str(v))

View File

@ -1,4 +1,3 @@
import tempfile
import json
import logging
from termcolor import colored
@ -7,8 +6,6 @@ import args
import tensorflow as tf
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def set_logger(context, verbose=False):
logger = logging.getLogger(context)
@ -104,7 +101,8 @@ def optimize_graph(logger=None, verbose=False):
[n.name[:-2] for n in output_tensors],
[dtype.as_datatype_enum for dtype in dtypes],
False)
tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
# tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
tmp_file = args.graph_file
logger.info('write graph to a tmp file: %s' % tmp_file)
with tf.gfile.GFile(tmp_file, 'wb') as f:
f.write(tmp_g.SerializeToString())