优化句向量生成速度
This commit is contained in:
parent
98219427aa
commit
bcec40195c
5
args.py
5
args.py
@ -23,4 +23,7 @@ gpu_memory_fraction = 0.8
|
|||||||
layer_indexes = [-2]
|
layer_indexes = [-2]
|
||||||
|
|
||||||
# 序列的最大程度,单文本建议把该值调小
|
# 序列的最大程度,单文本建议把该值调小
|
||||||
max_seq_len = 32
|
max_seq_len = 5
|
||||||
|
|
||||||
|
# graph名字
|
||||||
|
graph_file = 'tmp/result/graph'
|
@ -39,7 +39,11 @@ class BertVector:
|
|||||||
self.max_seq_length = args.max_seq_len
|
self.max_seq_length = args.max_seq_len
|
||||||
self.layer_indexes = args.layer_indexes
|
self.layer_indexes = args.layer_indexes
|
||||||
self.gpu_memory_fraction = 1
|
self.gpu_memory_fraction = 1
|
||||||
|
if os.path.exists(args.graph_file):
|
||||||
|
self.graph_path = args.graph_file
|
||||||
|
else:
|
||||||
self.graph_path = optimize_graph()
|
self.graph_path = optimize_graph()
|
||||||
|
|
||||||
self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
|
self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.estimator = self.get_estimator()
|
self.estimator = self.get_estimator()
|
||||||
@ -75,7 +79,7 @@ class BertVector:
|
|||||||
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
|
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
|
||||||
|
|
||||||
return Estimator(model_fn=model_fn, config=RunConfig(session_config=config),
|
return Estimator(model_fn=model_fn, config=RunConfig(session_config=config),
|
||||||
params={'batch_size': self.batch_size})
|
params={'batch_size': self.batch_size}, model_dir='../tmp')
|
||||||
|
|
||||||
def predict_from_queue(self):
|
def predict_from_queue(self):
|
||||||
prediction = self.estimator.predict(input_fn=self.queue_predict_input_fn, yield_single_examples=False)
|
prediction = self.estimator.predict(input_fn=self.queue_predict_input_fn, yield_single_examples=False)
|
||||||
@ -330,7 +334,8 @@ class BertVector:
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
bert = BertVector()
|
bert = BertVector()
|
||||||
# while True:
|
|
||||||
# question = input('question: ')
|
while True:
|
||||||
vectors = bert.encode(['你好', '哈哈'])
|
question = input('question: ')
|
||||||
print(str(vectors))
|
v = bert.encode([question])
|
||||||
|
print(str(v))
|
||||||
|
6
graph.py
6
graph.py
@ -1,4 +1,3 @@
|
|||||||
import tempfile
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
@ -7,8 +6,6 @@ import args
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
||||||
|
|
||||||
|
|
||||||
def set_logger(context, verbose=False):
|
def set_logger(context, verbose=False):
|
||||||
logger = logging.getLogger(context)
|
logger = logging.getLogger(context)
|
||||||
@ -104,7 +101,8 @@ def optimize_graph(logger=None, verbose=False):
|
|||||||
[n.name[:-2] for n in output_tensors],
|
[n.name[:-2] for n in output_tensors],
|
||||||
[dtype.as_datatype_enum for dtype in dtypes],
|
[dtype.as_datatype_enum for dtype in dtypes],
|
||||||
False)
|
False)
|
||||||
tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
|
# tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
|
||||||
|
tmp_file = args.graph_file
|
||||||
logger.info('write graph to a tmp file: %s' % tmp_file)
|
logger.info('write graph to a tmp file: %s' % tmp_file)
|
||||||
with tf.gfile.GFile(tmp_file, 'wb') as f:
|
with tf.gfile.GFile(tmp_file, 'wb') as f:
|
||||||
f.write(tmp_g.SerializeToString())
|
f.write(tmp_g.SerializeToString())
|
||||||
|
Loading…
Reference in New Issue
Block a user