From 5bda829469ef046ca241d2e59fee6d34d4419fe5 Mon Sep 17 00:00:00 2001 From: yongzhuo <31341349+yongzhuo@users.noreply.github.com> Date: Tue, 23 Jul 2019 22:25:16 +0800 Subject: [PATCH] Update extract_keras_bert_feature.py --- FeatureProject/bert/extract_keras_bert_feature.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/FeatureProject/bert/extract_keras_bert_feature.py b/FeatureProject/bert/extract_keras_bert_feature.py index 9b05cb9..bd23448 100644 --- a/FeatureProject/bert/extract_keras_bert_feature.py +++ b/FeatureProject/bert/extract_keras_bert_feature.py @@ -41,8 +41,8 @@ class KerasBertVector(): # lay = model.layers #一共104个layer,其中前八层包括token,pos,embed等, # 每4层(MultiHeadAttention,Dropout,Add,LayerNormalization) - # 一共24层 - layer_dict = [] + # 一共12层transformer + layer_dict = [7] layer_0 = 7 for i in range(12): layer_0 = layer_0 + 8 @@ -50,18 +50,18 @@ class KerasBertVector(): # 输出它本身 if len(layer_indexes) == 0: encoder_layer = model.output - # 分类如果只有一层,就只取最后那一层的weight;取得不正确,就默认取最后一层输出 + # 分类如果只有一层,就只取倒数第二层的weight;取得不正确,就默认取倒数第二层出 elif len(layer_indexes) == 1: - if layer_indexes[0] in [i+1 for i in range(23)]: + if layer_indexes[0] in [i+1 for i in range(12)]: encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output else: - encoder_layer = model.get_layer(index=layer_dict[-1]).output + encoder_layer = model.get_layer(index=layer_dict[-2]).output # 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数 else: # layer_indexes must be [1,2,3,......12] # all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes] - all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)] - else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层 + all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(12)] + else model.get_layer(index=layer_dict[-2]).output #如果给出不正确,就默认输出最后一层 for lay in layer_indexes] print(layer_indexes) print(all_layers)