Modify configuration parameters
parent e0de69a52e
commit 7c3c1ad29c
@@ -12,7 +12,7 @@
Generating sentence vectors does not require fine-tuning; using a pre-trained model is enough. See the `main` method of `extract_feature.py`, and note that the argument must be a list.

The first time a sentence vector is generated, the graph has to be loaded, so it is relatively slow; later calls are fast
The first time a sentence vector is generated, the graph has to be loaded and a new graph file is created under the output_dir path, so it is relatively slow; calling it again is fast
```
from bert.extract_feature import BertVector
bv = BertVector()
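# Hypothetical call for illustration only: the encode() method name is an assumption,
# check extract_feature.py for the exact API; the argument must be a list, as noted above.
vectors = bv.encode(['今天天气不错'])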
```
args.py (27 changed lines)
@@ -1,31 +1,26 @@
import os
from enum import Enum
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)

file_path = os.path.dirname(__file__)

model_dir = os.path.join(file_path, 'chinese_L-12_H-768_A-12/')
config_name = os.path.join(model_dir, 'bert_config.json')
ckpt_name = os.path.join(model_dir, 'bert_model.ckpt')

output_dir = os.path.join(model_dir, '../tmp/result/')

vocab_file = os.path.join(model_dir, 'vocab.txt')
data_dir = os.path.join(model_dir, '../data/')

max_seq_len = 32

layer_indexes = [-2, -3, -4]

num_train_epochs = 10
batch_size = 128

gpu_memory_fraction = 0.8

learning_rate = 0.00005

num_train_epochs = 10
# GPU memory usage fraction
gpu_memory_fraction = 0.8

use_gpu = False
if use_gpu:
    device_id = '0'
else:
    device_id = '-1'
# by default, take the output of the second-to-last layer as the sentence vector
layer_indexes = [-2]

# maximum sequence length; for single (short) texts it is recommended to lower this value
max_seq_len = 32
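
The sketch below is not part of this commit; it only illustrates how configuration values such as `device_id` and `gpu_memory_fraction` are typically wired into a TF 1.x session. The `args` attributes follow the file above, everything else is illustrative.

```
import os

import args  # the configuration module shown above

# select (or hide) the GPU before TensorFlow initializes CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = args.device_id

import tensorflow as tf

# cap per-process GPU memory at the configured fraction
session_config = tf.ConfigProto(allow_soft_placement=True)
session_config.gpu_options.per_process_gpu_memory_fraction = args.gpu_memory_fraction

with tf.Session(config=session_config) as sess:
    pass  # build the BERT graph and run inference here
```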
extract_feature.py
@@ -1,12 +1,10 @@
from graph import import_tf
import modeling
import tokenization
from graph import optimize_graph
import args
from queue import Queue
from threading import Thread
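
# import TensorFlow via the import_tf helper defined in graph.py
# (device_id=0 selects the first GPU, True enables verbose logging)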
tf = import_tf(0, True)
import tensorflow as tf


class InputExample(object):
graph.py (16 changed lines)
@@ -1,19 +1,10 @@
import os
import tempfile
import json
import logging
from termcolor import colored
import modeling
import args
import contextlib


def import_tf(device_id=-1, verbose=False):
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1' if device_id < 0 else str(device_id)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbose else '3'
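    # note: these environment variables need to be set before TensorFlow is first imported
    # in order to take effect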
    import tensorflow as tf
    tf.logging.set_verbosity(tf.logging.DEBUG if verbose else tf.logging.ERROR)
    return tf
import tensorflow as tf


def set_logger(context, verbose=False):
@@ -35,7 +26,6 @@ def optimize_graph(logger=None, verbose=False):
    logger = set_logger(colored('BERT_VEC', 'yellow'), verbose)
    try:
        # we don't need GPU for optimizing the graph
        tf = import_tf(device_id=0, verbose=verbose)
        from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

        # allow_soft_placement: automatically choose an available device to run on
@@ -75,9 +65,7 @@ def optimize_graph(logger=None, verbose=False):

        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
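
        # pooling helpers that collapse per-token embeddings into a single sentence vector:
        #   masked_reduce_max  -- max over real tokens (padded positions pushed to -inf via the mask)
        #   masked_reduce_mean -- mean over real tokens (padded positions zeroed out, divided by the token count)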
        minus_mask = lambda x, m: x - tf.expand_dims(1.0 - m, axis=-1) * 1e30
        mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
        masked_reduce_max = lambda x, m: tf.reduce_max(minus_mask(x, m), axis=1)
        masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)

@@ -113,7 +101,7 @@ def optimize_graph(logger=None, verbose=False):
            [n.name[:-2] for n in output_tensors],
            [dtype.as_datatype_enum for dtype in dtypes],
            False)
        tmp_file = tempfile.NamedTemporaryFile('w', delete=False).name
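        # write the serialized graph under args.output_dir (instead of the system temp
        # directory), matching the README note above about a new graph file in output_dir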
        tmp_file = tempfile.NamedTemporaryFile('w', delete=False, dir=args.output_dir).name
        logger.info('write graph to a tmp file: %s' % tmp_file)
        with tf.gfile.GFile(tmp_file, 'wb') as f:
            f.write(tmp_g.SerializeToString())
@@ -181,7 +181,7 @@ class BertSim:
    def model_fn_builder(self, bert_config, num_labels, init_checkpoint, learning_rate,
                         num_train_steps, num_warmup_steps,
                         use_one_hot_embeddings):
        """Returns `model_fn` closure for TPUEstimator."""
        """Returns `model_fn` closure for TPUEstimator."""

        def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
            from tensorflow.python.estimator.model_fn import EstimatorSpec