From b51aec0d4fa4f83960d957450259a421890ca47b Mon Sep 17 00:00:00 2001
From: yongzhuo <31341349+yongzhuo@users.noreply.github.com>
Date: Thu, 30 May 2019 21:51:00 +0800
Subject: [PATCH] Fix wrong layer selection in the bert Embedding layer
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../bert/keras_bert_embedding.py | 45 +++++++++++++++----
 1 file changed, 37 insertions(+), 8 deletions(-)

diff --git a/ClassificationText/bert/keras_bert_embedding.py b/ClassificationText/bert/keras_bert_embedding.py
index 64a22b7..00ceb10 100644
--- a/ClassificationText/bert/keras_bert_embedding.py
+++ b/ClassificationText/bert/keras_bert_embedding.py
@@ -4,12 +4,14 @@
 # @author   :Mo
 # @function :embedding of bert keras
 
-from conf.feature_config import gpu_memory_fraction, config_name, ckpt_name, vocab_file, max_seq_len, layer_indexes
+from ClassificationText.bert.args import gpu_memory_fraction, max_seq_len, layer_indexes
+from conf.feature_config import config_name, ckpt_name, vocab_file
 from FeatureProject.bert.layers_keras import NonMaskingLayer
 from keras_bert import load_trained_model_from_checkpoint
 import keras.backend.tensorflow_backend as ktf_keras
 import keras.backend as k_keras
 from keras.models import Model
+from keras.layers import Add
 import tensorflow as tf
 import os
 
@@ -29,7 +31,7 @@ class KerasBertEmbedding():
     def __init__(self):
         self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len
 
-    def bert_encode(self):
+    def bert_encode(self, layer_indexes=[12]):
         # keep the graph global so it can be called from django, flask, tornado, etc.
         global graph
         graph = tf.get_default_graph()
@@ -37,14 +39,40 @@ class KerasBertEmbedding():
         model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,
                                                    seq_len=self.max_seq_len)
         print(model.output)
-        # for classification, if only one layer is chosen, just take the last layer's weights
-        if len(layer_indexes) == 1:
-            encoder_layer = model.get_layer(index=len(model.layers)-1).output
+        print(len(model.layers))
+        # lay = model.layers
+        # 104 layers in total: the first 8 cover token, position, embedding, etc.;
+        # after that, every 4 layers (MultiHeadAttention, Dropout, Add, LayerNormalization)
+        # form one transformer block, 24 blocks in all
+        layer_dict = []
+        layer_0 = 7
+        for i in range(24):
+            layer_0 = layer_0 + 4
+            layer_dict.append(layer_0)
+        # an empty list returns the model's own output
+        if len(layer_indexes) == 0:
+            encoder_layer = model.output
+        # a single requested layer: the old code always took the raw last layer's
+        # weights, which was the wrong layer
+        elif len(layer_indexes) == 1:
+            if layer_indexes[0] in [i+1 for i in range(24)]:
+                encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]-1]).output
+            else:
+                encoder_layer = model.get_layer(index=layer_dict[-1]).output
         # otherwise iterate over the requested layers, fetch every layer's weights and concatenate them; shape: 768 * number of layers
         else:
-            # layer_indexes must be [1,2,3,......12]
-            all_layers = [model.get_layer(index=lay).output for lay in layer_indexes]
-            encoder_layer = k_keras.concatenate(all_layers, -1)
+            # layer_indexes must be within [1, 2, 3, ..., 24]
+            # all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
+            all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(24)]
+                          else model.get_layer(index=layer_dict[-1]).output  # fall back to the last block for out-of-range indexes
+                          for lay in layer_indexes]
+            print(layer_indexes)
+            print(all_layers)
+            # the output of layer==1 is not in the right format; the second layer's input is a list
+            all_layers_select = []
+            for all_layers_one in all_layers:
+                all_layers_select.append(all_layers_one)
+            encoder_layer = Add()(all_layers_select)
+            print(encoder_layer.shape)
         print("KerasBertEmbedding:")
         print(encoder_layer.shape)
         output_layer = NonMaskingLayer()(encoder_layer)
@@ -56,3 +84,4 @@ class KerasBertEmbedding():
 if __name__ == "__main__":
     bert_vector = KerasBertEmbedding()
     pooled = bert_vector.bert_encode()
+
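
Reviewer note (not part of the patch): a minimal standalone sketch of the index
arithmetic introduced above. It reproduces layer_dict and shows how a 1-based
transformer-block number maps to a keras-bert layer index; the counts (104
layers, 8 leading embedding layers, 4 keras layers per block, 24 blocks) are
taken from the patch's own comments and may differ for other checkpoints.

    # Standalone sketch of the layer_dict arithmetic; counts are assumptions
    # taken from the patch comments, not verified against every checkpoint.
    NUM_BLOCKS = 24          # transformer blocks in this model
    LAYERS_PER_BLOCK = 4     # MultiHeadAttention, Dropout, Add, LayerNormalization
    EMBEDDING_LAYERS = 8     # token, position, embedding, etc.

    def build_layer_dict():
        """Map 1-based block numbers to keras-bert layer indexes."""
        layer_dict = []
        layer_0 = EMBEDDING_LAYERS - 1       # 7, as in the patch
        for _ in range(NUM_BLOCKS):
            layer_0 += LAYERS_PER_BLOCK
            layer_dict.append(layer_0)
        return layer_dict

    layer_dict = build_layer_dict()
    assert layer_dict[0] == 11               # output layer of block 1
    assert layer_dict[-1] == 103             # block 24, last of the 104 layers (0-based)
    # Block b (1-based) lives at layer_dict[b - 1]; indexing with layer_dict[b]
    # is the off-by-one this revision corrects.
    print([(b, layer_dict[b - 1]) for b in (1, 12, 24)])   # [(1, 11), (12, 55), (24, 103)]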
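
With this change, bert_encode can be exercised as below. A hedged sketch,
assuming the config/checkpoint paths in conf.feature_config point at a valid
BERT checkpoint (the patch's comments imply a 24-block model), mirroring the
__main__ usage at the bottom of the file:

    from ClassificationText.bert.keras_bert_embedding import KerasBertEmbedding

    bert_vector = KerasBertEmbedding()
    # sum of the outputs of transformer blocks 23 and 24 (element-wise Add,
    # so the feature dimension stays constant instead of growing)
    pooled = bert_vector.bert_encode(layer_indexes=[23, 24])
    # layer_indexes=[] returns model.output unchanged; out-of-range indexes
    # fall back to block 24

Note the design trade-off: the old code concatenated the selected layers
(feature size 768 * number of layers), while the patch sums them with
keras.layers.Add, keeping the feature size constant.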