bert-utils/graph.py
2019-01-29 18:31:51 +08:00

125 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import tempfile
import json
import logging
from termcolor import colored
import modeling
import args
import contextlib
def import_tf(device_id=-1, verbose=False):
    """Configure process-level environment variables, then import TensorFlow.

    Args:
        device_id: written to ``CUDA_VISIBLE_DEVICES`` as a string; any
            negative value is written as ``'-1'``.
        verbose: when true, ``TF_CPP_MIN_LOG_LEVEL`` is set to ``'0'`` and the
            TensorFlow Python logger to ``DEBUG``; otherwise ``'3'`` and
            ``ERROR``.

    Returns:
        The imported ``tensorflow`` module.
    """
    if device_id < 0:
        visible_devices = '-1'
    else:
        visible_devices = str(device_id)
    if verbose:
        cpp_log_level = '0'
    else:
        cpp_log_level = '3'

    # Both variables are written before the import below; presumably the
    # TensorFlow runtime reads them at import/initialisation time.
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = cpp_log_level

    import tensorflow as tf

    if verbose:
        py_verbosity = tf.logging.DEBUG
    else:
        py_verbosity = tf.logging.ERROR
    tf.logging.set_verbosity(py_verbosity)
    return tf
def set_logger(context, verbose=False):
    """Return the logger named ``context`` with a single console handler.

    Any handlers already attached to that logger are discarded and replaced
    by one ``StreamHandler``. Logger and handler share the same level:
    ``DEBUG`` when ``verbose`` is true, ``INFO`` otherwise.

    Args:
        context: logger name; it is also embedded in every formatted record.
        verbose: selects ``DEBUG`` instead of ``INFO``.

    Returns:
        The configured ``logging.Logger``.
    """
    level = logging.DEBUG if verbose else logging.INFO

    log = logging.getLogger(context)
    log.setLevel(level)

    # Record layout: first letter of the level, the context, then truncated
    # file name / function name and the line number, then the message.
    record_layout = (
        '%(levelname)-.1s:' + context
        + ':[%(filename).5s:%(funcName).3s:%(lineno)3d]:%(message)s'
    )

    stream = logging.StreamHandler()
    stream.setLevel(level)
    stream.setFormatter(logging.Formatter(record_layout, datefmt='%m-%d %H:%M:%S'))

    # Drop previously registered handlers so repeated calls do not stack them.
    log.handlers = []
    log.addHandler(stream)
    return log
def optimize_graph(logger=None, verbose=False):
    """Build the BERT pooling graph, freeze it, optimize it and write it to disk.

    The function reads its settings from the ``args`` module (``config_name``,
    ``ckpt_name``, ``max_seq_len``, ``xla``, ``layer_indexes``), builds a
    ``modeling.BertModel`` on three int32 placeholders, mean-pools the selected
    encoder layer(s) over the sequence axis using the input mask, exposes the
    result as the node ``final_encodes``, converts variables to constants,
    runs ``optimize_for_inference`` and serializes the resulting GraphDef to a
    temporary file that is NOT deleted automatically.

    Args:
        logger: logger used for progress and error messages. When falsy, one
            is created with ``set_logger`` under the colored ``BERT_VEC``
            context.
        verbose: forwarded to ``set_logger`` and ``import_tf``.

    Returns:
        The path of the temporary file holding the serialized graph, or
        ``None`` if any step raised (the error is logged, not re-raised).
    """
    if not logger:
        logger = set_logger(colored('BERT_VEC', 'yellow'), verbose)
    try:
        # NOTE(review): the original comment said no GPU is needed here, yet
        # device_id=0 makes import_tf set CUDA_VISIBLE_DEVICES to '0'. The
        # value is kept as-is because the environment variable is process-wide
        # and later code may rely on it; confirm which one is intended.
        tf = import_tf(device_id=0, verbose=verbose)
        from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

        # allow_soft_placement: let TensorFlow choose the device automatically.
        config = tf.ConfigProto(allow_soft_placement=True)

        config_fp = args.config_name
        init_checkpoint = args.ckpt_name
        logger.info('model config: %s', config_fp)

        # Load the BERT configuration file.
        with tf.gfile.GFile(config_fp, 'r') as f:
            bert_config = modeling.BertConfig.from_dict(json.load(f))

        logger.info('build graph...')
        # input placeholders, not sure if they are friendly to XLA
        input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
        input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')
        input_type_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_type_ids')

        # XLA acceleration when requested; otherwise contextlib.suppress() with
        # no exception types serves as a do-nothing context manager.
        jit_scope = tf.contrib.compiler.jit.experimental_jit_scope if args.xla else contextlib.suppress

        with jit_scope():
            input_tensors = [input_ids, input_mask, input_type_ids]

            model = modeling.BertModel(
                config=bert_config,
                is_training=False,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=input_type_ids,
                use_one_hot_embeddings=False)

            # Collect all trainable variables and map them onto the checkpoint.
            tvars = tf.trainable_variables()
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            def masked_reduce_mean(x, m):
                # Zero out positions where the mask is 0, sum over axis 1 and
                # divide by the mask sum; the epsilon avoids division by zero
                # for an all-zero mask.
                masked = x * tf.expand_dims(m, axis=-1)
                return tf.reduce_sum(masked, axis=1) / (
                    tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)

            # (author's note) shared convolution kernel
            with tf.variable_scope("pooling"):
                if len(args.layer_indexes) == 1:
                    # Only one layer requested: use that encoder layer alone.
                    encoder_layer = model.all_encoder_layers[args.layer_indexes[0]]
                else:
                    # Several layers requested: gather them and concatenate on
                    # the last axis (author's note: shape is 768 * number of
                    # layers, which presumably assumes a hidden size of 768).
                    all_layers = [model.all_encoder_layers[l] for l in args.layer_indexes]
                    encoder_layer = tf.concat(all_layers, -1)

            input_mask = tf.cast(input_mask, tf.float32)

            # Sentence-vector generation (author's note: can be viewed as a
            # convolution-like operation whose kernel is input_mask).
            pooled = masked_reduce_mean(encoder_layer, input_mask)
            pooled = tf.identity(pooled, 'final_encodes')

            output_tensors = [pooled]
            tmp_g = tf.get_default_graph().as_graph_def()

        # Node names are the op names, i.e. the tensor names without the
        # ":<output index>" suffix.
        input_node_names = [t.op.name for t in input_tensors]
        output_node_names = [t.op.name for t in output_tensors]

        with tf.Session(config=config) as sess:
            logger.info('load parameters from checkpoint...')
            sess.run(tf.global_variables_initializer())
            logger.info('freeze...')
            tmp_g = tf.graph_util.convert_variables_to_constants(sess, tmp_g, output_node_names)
            dtypes = [t.dtype for t in input_tensors]
            logger.info('optimize...')
            # The trailing positional False is presumably the toco_compatible
            # flag -- confirm against the installed TensorFlow version.
            tmp_g = optimize_for_inference(
                tmp_g,
                input_node_names,
                output_node_names,
                [dtype.as_datatype_enum for dtype in dtypes],
                False)

        # Reserve a persistent temp-file name and close the handle right away
        # instead of leaving it for the garbage collector.
        with tempfile.NamedTemporaryFile('w', delete=False) as tmp_handle:
            tmp_file = tmp_handle.name
        logger.info('write graph to a tmp file: %s', tmp_file)
        with tf.gfile.GFile(tmp_file, 'wb') as f:
            f.write(tmp_g.SerializeToString())
        return tmp_file
    except Exception as e:
        # Best-effort contract kept: report the failure (now with traceback)
        # and hand None back to the caller rather than raising.
        logger.error('fail to optimize the graph!')
        logger.error(e, exc_info=True)
        return None