diff --git a/README.md b/README.md index 0403975..2265025 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ 生成句向量不需要做fine tune,使用预先训练好的模型即可,可参考`extract_feature.py`的`main`方法,注意参数必须是一个list。 -第一次生成句向量时需要加载graph,速度比较慢,后续速度会很快 +首次生成句向量时需要加载graph,并在output_dir路径下生成一个新的graph文件,因此速度比较慢,再次调用速度会很快 ``` from bert.extrac_feature import BertVector bv = BertVector() diff --git a/args.py b/args.py index 3cc7441..f7db308 100644 --- a/args.py +++ b/args.py @@ -1,31 +1,26 @@ import os -from enum import Enum +import tensorflow as tf + +tf.logging.set_verbosity(tf.logging.INFO) file_path = os.path.dirname(__file__) model_dir = os.path.join(file_path, 'chinese_L-12_H-768_A-12/') config_name = os.path.join(model_dir, 'bert_config.json') ckpt_name = os.path.join(model_dir, 'bert_model.ckpt') - output_dir = os.path.join(model_dir, '../tmp/result/') - vocab_file = os.path.join(model_dir, 'vocab.txt') data_dir = os.path.join(model_dir, '../data/') -max_seq_len = 32 - -layer_indexes = [-2, -3, -4] - +num_train_epochs = 10 batch_size = 128 - -gpu_memory_fraction = 0.8 - learning_rate = 0.00005 -num_train_epochs = 10 +# gpu使用率 +gpu_memory_fraction = 0.8 -use_gpu = False -if use_gpu: - device_id = '0' -else: - device_id = '-1' +# 默认取倒数第二层的输出值作为句向量 +layer_indexes = [-2] + +# 序列的最大程度,单文本建议把该值调小 +max_seq_len = 32 diff --git a/extract_feature.py b/extract_feature.py index 5323167..9540f1e 100644 --- a/extract_feature.py +++ b/extract_feature.py @@ -1,12 +1,10 @@ -from graph import import_tf import modeling import tokenization from graph import optimize_graph import args from queue import Queue from threading import Thread - -tf = import_tf(0, True) +import tensorflow as tf class InputExample(object): diff --git a/graph.py b/graph.py index cc86a4b..ff5ec4c 100644 --- a/graph.py +++ b/graph.py @@ -1,19 +1,10 @@ -import os import tempfile import json import logging from termcolor import colored import modeling import args -import contextlib - - -def import_tf(device_id=-1, verbose=False): - 
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' if device_id < 0 else str(device_id) - os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbose else '3' - import tensorflow as tf - tf.logging.set_verbosity(tf.logging.DEBUG if verbose else tf.logging.ERROR) - return tf +import tensorflow as tf def set_logger(context, verbose=False): @@ -35,7 +26,6 @@ def optimize_graph(logger=None, verbose=False): logger = set_logger(colored('BERT_VEC', 'yellow'), verbose) try: # we don't need GPU for optimizing the graph - tf = import_tf(device_id=0, verbose=verbose) from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference # allow_soft_placement:自动选择运行设备 @@ -75,9 +65,7 @@ def optimize_graph(logger=None, verbose=False): tf.train.init_from_checkpoint(init_checkpoint, assignment_map) - minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30 mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1) - masked_reduce_max = lambda x, m: tf.reduce_max(minus_mask(x, m), axis=1) masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / ( tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10) @@ -113,7 +101,7 @@ def optimize_graph(logger=None, verbose=False): [n.name[:-2] for n in output_tensors], [dtype.as_datatype_enum for dtype in dtypes], False) - tmp_file = tempfile.NamedTemporaryFile('w', delete=False).name + tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name logger.info('write graph to a tmp file: %s' % tmp_file) with tf.gfile.GFile(tmp_file, 'wb') as f: f.write(tmp_g.SerializeToString()) diff --git a/similarity.py b/similarity.py index 0472070..6a77293 100644 --- a/similarity.py +++ b/similarity.py @@ -181,7 +181,7 @@ class BertSim: def model_fn_builder(self, bert_config, num_labels, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_one_hot_embeddings): - """Returns `model_fn` closure for TPUEstimator.""" + """Returns `model_fn` closure for TPUEstimator.""" def model_fn(features, 
labels, mode, params): # pylint: disable=unused-argument from tensorflow.python.estimator.model_fn import EstimatorSpec