bert-utils/graph.py

113 lines
4.8 KiB
Python
Raw Normal View History

2019-01-29 18:31:51 +08:00
import json
import logging
from termcolor import colored
import modeling
import args
2019-01-30 11:39:49 +08:00
import tensorflow as tf
import os
2019-01-29 18:31:51 +08:00
def set_logger(context, verbose=False):
    """Return a console logger named ``context``.

    Every record is rendered as
    ``<L>:<context>:[<file>:<func>:<line>]:<message>`` where ``<L>`` is the
    first letter of the level name, ``<file>`` is cut to 5 characters and
    ``<func>`` to 3 characters.

    Any handlers already attached to the logger are dropped, so calling this
    function repeatedly with the same ``context`` does not duplicate output.

    Args:
        context: Logger name; also embedded verbatim in every log line.
        verbose: Emit DEBUG records when true, otherwise INFO and above.

    Returns:
        The configured ``logging.Logger``.
    """
    level = logging.DEBUG if verbose else logging.INFO

    # ``context`` becomes part of a %-style format string, so a literal '%'
    # must be doubled; otherwise formatting a record raises ValueError.
    escaped_context = context.replace('%', '%%')

    # datefmt is kept from the original; it only takes effect if
    # %(asctime)s is ever added to the format string.
    formatter = logging.Formatter(
        '%(levelname)-.1s:' + escaped_context
        + ':[%(filename).5s:%(funcName).3s:%(lineno)3d]:%(message)s',
        datefmt='%m-%d %H:%M:%S')

    console_handler = logging.StreamHandler()
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)

    logger = logging.getLogger(context)
    logger.setLevel(level)
    # Replace (rather than extend) the handler list so this stays idempotent.
    logger.handlers = []
    logger.addHandler(console_handler)
    return logger
def optimize_graph(logger=None, verbose=False):
    """Build the BERT sentence-encoding graph, freeze it and write it to disk.

    The graph reads the ``input_ids``, ``input_mask`` and ``input_type_ids``
    placeholders and exposes one output named ``final_encodes``: the
    mask-weighted mean, over the sequence axis, of the encoder layer(s)
    selected by ``args.layer_indexes``. Variables are initialised from
    ``args.ckpt_name``, converted to constants, and the result is passed
    through ``optimize_for_inference`` before being serialised to
    ``args.graph_file``.

    Args:
        logger: Logger used for progress messages. When falsy, one is created
            with ``set_logger``.
        verbose: Only used when the logger is created here; enables DEBUG.

    Returns:
        The path the graph was written to (``args.graph_file``), or ``None``
        when any step failed. Failures are logged with their traceback and
        are not re-raised.
    """
    if not logger:
        logger = set_logger(colored('BERT_VEC', 'yellow'), verbose)
    try:
        # Imported inside the try block so that an import failure is logged
        # and reported like every other failure of this function.
        from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference

        tf.gfile.MakeDirs(args.output_dir)

        config_fp = args.config_name
        logger.info('model config: %s', config_fp)
        # Load the BERT configuration file.
        with tf.gfile.GFile(config_fp, 'r') as f:
            bert_config = modeling.BertConfig.from_dict(json.load(f))

        logger.info('build graph...')
        # input placeholders, not sure if they are friendly to XLA
        input_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_ids')
        input_mask = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_mask')
        input_type_ids = tf.placeholder(tf.int32, (None, args.max_seq_len), 'input_type_ids')

        jit_scope = tf.contrib.compiler.jit.experimental_jit_scope

        # NOTE(review): the nesting below was reconstructed from statement
        # order; confirm the extent of the jit and "pooling" scopes against
        # the repository copy.
        with jit_scope():
            # Captured before input_mask is rebound to its float cast below,
            # so this list still holds the original int32 placeholders.
            input_tensors = [input_ids, input_mask, input_type_ids]

            model = modeling.BertModel(
                config=bert_config,
                is_training=False,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=input_type_ids,
                use_one_hot_embeddings=False)

            # Collect the trainable variables and map them onto the checkpoint.
            tvars = tf.trainable_variables()
            init_checkpoint = args.ckpt_name
            assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
                tvars, init_checkpoint)
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

            with tf.variable_scope("pooling"):
                if len(args.layer_indexes) == 1:
                    # A single requested layer is used as is.
                    encoder_layer = model.all_encoder_layers[args.layer_indexes[0]]
                else:
                    # Several requested layers are concatenated along the last
                    # axis (the original note gives 768 * number of layers).
                    all_layers = [model.all_encoder_layers[idx] for idx in args.layer_indexes]
                    encoder_layer = tf.concat(all_layers, -1)

            input_mask = tf.cast(input_mask, tf.float32)

            # Sentence vector: zero out masked positions, sum over the
            # sequence axis and divide by the mask total. The epsilon keeps
            # the division finite when the mask is all zeros.
            masked_sum = tf.reduce_sum(
                encoder_layer * tf.expand_dims(input_mask, axis=-1), axis=1)
            mask_total = tf.reduce_sum(input_mask, axis=1, keepdims=True)
            pooled = masked_sum / (mask_total + 1e-10)
            pooled = tf.identity(pooled, 'final_encodes')

            output_tensors = [pooled]
            tmp_g = tf.get_default_graph().as_graph_def()

        # Node names expected by the freezing/optimising tools; the owning
        # op's name is the tensor name without its ":<index>" suffix.
        input_names = [t.op.name for t in input_tensors]
        output_names = [t.op.name for t in output_tensors]

        # allow_soft_placement lets TensorFlow pick the device automatically.
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config) as sess:
            logger.info('load parameters from checkpoint...')
            sess.run(tf.global_variables_initializer())

            logger.info('freeze...')
            tmp_g = tf.graph_util.convert_variables_to_constants(sess, tmp_g, output_names)

            logger.info('optimize...')
            tmp_g = optimize_for_inference(
                tmp_g,
                input_names,
                output_names,
                [t.dtype.as_datatype_enum for t in input_tensors],
                False)

        tmp_file = args.graph_file
        logger.info('write graph to a tmp file: %s', tmp_file)
        with tf.gfile.GFile(tmp_file, 'wb') as f:
            f.write(tmp_g.SerializeToString())
        return tmp_file
    except Exception as e:
        # Best-effort by design: report the failure (with traceback) and let
        # the caller see None instead of propagating the exception.
        logger.error('fail to optimize the graph!')
        logger.exception(e)
        return None