From bcec40195c92da851274115863d2b8f092060b91 Mon Sep 17 00:00:00 2001
From: joe
Date: Mon, 1 Jul 2019 10:07:54 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=8F=A5=E5=90=91=E9=87=8F?=
 =?UTF-8?q?=E7=94=9F=E6=88=90=E9=80=9F=E5=BA=A6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 args.py            |  5 ++++-
 extract_feature.py | 17 +++++++++++------
 graph.py           |  6 ++----
 3 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/args.py b/args.py
index f7db308..2ad011c 100644
--- a/args.py
+++ b/args.py
@@ -23,4 +23,7 @@ gpu_memory_fraction = 0.8
 layer_indexes = [-2]
 
 # 序列的最大程度,单文本建议把该值调小
-max_seq_len = 32
+max_seq_len = 5
+
+# graph名字
+graph_file = 'tmp/result/graph'
\ No newline at end of file
diff --git a/extract_feature.py b/extract_feature.py
index 6ff2409..9a7b350 100644
--- a/extract_feature.py
+++ b/extract_feature.py
@@ -39,7 +39,11 @@ class BertVector:
         self.max_seq_length = args.max_seq_len
         self.layer_indexes = args.layer_indexes
         self.gpu_memory_fraction = 1
-        self.graph_path = optimize_graph()
+        if os.path.exists(args.graph_file):
+            self.graph_path = args.graph_file
+        else:
+            self.graph_path = optimize_graph()
+
         self.tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file, do_lower_case=True)
         self.batch_size = batch_size
         self.estimator = self.get_estimator()
@@ -75,7 +79,7 @@
         config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1
 
         return Estimator(model_fn=model_fn, config=RunConfig(session_config=config),
-                         params={'batch_size': self.batch_size})
+                         params={'batch_size': self.batch_size}, model_dir='../tmp')
 
     def predict_from_queue(self):
         prediction = self.estimator.predict(input_fn=self.queue_predict_input_fn, yield_single_examples=False)
@@ -330,7 +334,8 @@
 
 if __name__ == "__main__":
     bert = BertVector()
-    # while True:
-    #     question = input('question: ')
-    vectors = bert.encode(['你好', '哈哈'])
-    print(str(vectors))
+
+    while True:
+        question = input('question: ')
+        v = bert.encode([question])
+        print(str(v))
diff --git a/graph.py b/graph.py
index 7313cb6..47443bf 100644
--- a/graph.py
+++ b/graph.py
@@ -1,4 +1,3 @@
-import tempfile
 import json
 import logging
 from termcolor import colored
@@ -7,8 +6,6 @@ import args
 import tensorflow as tf
 import os
 
-os.environ['CUDA_VISIBLE_DEVICES'] = '0'
-
 
 def set_logger(context, verbose=False):
     logger = logging.getLogger(context)
@@ -104,7 +101,8 @@ def optimize_graph(logger=None, verbose=False):
                 [n.name[:-2] for n in output_tensors],
                 [dtype.as_datatype_enum for dtype in dtypes],
                 False)
-        tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
+        # tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
+        tmp_file = args.graph_file
         logger.info('write graph to a tmp file: %s' % tmp_file)
         with tf.gfile.GFile(tmp_file, 'wb') as f:
             f.write(tmp_g.SerializeToString())