优化句向量生成速度
This commit is contained in:
parent
98219427aa
commit
bcec40195c
5
args.py
5
args.py
@ -23,4 +23,7 @@ gpu_memory_fraction = 0.8
|
||||
layer_indexes = [-2]
|
||||
|
||||
# 序列的最大程度,单文本建议把该值调小
|
||||
max_seq_len = 32
|
||||
max_seq_len = 5
|
||||
|
||||
# graph名字
|
||||
graph_file = 'tmp/result/graph'
|
@ -39,7 +39,11 @@ class BertVector:
|
||||
self.max_seq_length = args.max_seq_len
|
||||
self.layer_indexes = args.layer_indexes
|
||||
self.gpu_memory_fraction = 1
|
||||
self.graph_path = optimize_graph()
|
||||
if os.path.exists(args.graph_file):
|
||||
self.graph_path = args.graph_file
|
||||
else:
|
||||
self.graph_path = optimize_graph()
|
||||
|
||||
self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
|
||||
self.batch_size = batch_size
|
||||
self.estimator = self.get_estimator()
|
||||
@ -75,7 +79,7 @@ class BertVector:
|
||||
config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
|
||||
|
||||
return Estimator(model_fn=model_fn, config=RunConfig(session_config=config),
|
||||
params={'batch_size': self.batch_size})
|
||||
params={'batch_size': self.batch_size}, model_dir='../tmp')
|
||||
|
||||
def predict_from_queue(self):
|
||||
prediction = self.estimator.predict(input_fn=self.queue_predict_input_fn, yield_single_examples=False)
|
||||
@ -330,7 +334,8 @@ class BertVector:
|
||||
|
||||
if __name__ == "__main__":
|
||||
bert = BertVector()
|
||||
# while True:
|
||||
# question = input('question: ')
|
||||
vectors = bert.encode(['你好', '哈哈'])
|
||||
print(str(vectors))
|
||||
|
||||
while True:
|
||||
question = input('question: ')
|
||||
v = bert.encode([question])
|
||||
print(str(v))
|
||||
|
6
graph.py
6
graph.py
@ -1,4 +1,3 @@
|
||||
import tempfile
|
||||
import json
|
||||
import logging
|
||||
from termcolor import colored
|
||||
@ -7,8 +6,6 @@ import args
|
||||
import tensorflow as tf
|
||||
import os
|
||||
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||||
|
||||
|
||||
def set_logger(context, verbose=False):
|
||||
logger = logging.getLogger(context)
|
||||
@ -104,7 +101,8 @@ def optimize_graph(logger=None, verbose=False):
|
||||
[n.name[:-2] for n in output_tensors],
|
||||
[dtype.as_datatype_enum for dtype in dtypes],
|
||||
False)
|
||||
tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
|
||||
# tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
|
||||
tmp_file = args.graph_file
|
||||
logger.info('write graph to a tmp file: %s' % tmp_file)
|
||||
with tf.gfile.GFile(tmp_file, 'wb') as f:
|
||||
f.write(tmp_g.SerializeToString())
|
||||
|
Loading…
Reference in New Issue
Block a user