Update extract_keras_bert_feature.py

This commit is contained in:
yongzhuo 2019-07-23 22:25:16 +08:00 committed by GitHub
parent e5709962af
commit 5bda829469
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -41,8 +41,8 @@ class KerasBertVector():
# lay = model.layers
# 一共104个layer,其中前八层包括token、pos、embed等
# 每4层MultiHeadAttention,Dropout,Add,LayerNormalization
# 一共24
layer_dict = []
# 一共12层transformer
layer_dict = [7]
layer_0 = 7
for i in range(12):
layer_0 = layer_0 + 8
@@ -50,18 +50,18 @@ class KerasBertVector():
# 输出它本身
if len(layer_indexes) == 0:
encoder_layer = model.output
# 分类如果只有一层,就只取最后那一层的weight;取得不正确就默认取最后一层输出
# 分类如果只有一层,就只取倒数第二层的weight;取得不正确就默认取倒数第二层
elif len(layer_indexes) == 1:
if layer_indexes[0] in [i+1 for i in range(23)]:
if layer_indexes[0] in [i+1 for i in range(12)]:
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
else:
encoder_layer = model.get_layer(index=layer_dict[-1]).output
encoder_layer = model.get_layer(index=layer_dict[-2]).output
# 否则遍历需要取的层,把所有层的weight取出来并拼接起来,shape: 768*层数
else:
# each entry of layer_indexes must be one of 1, 2, 3, ..., 12
# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)]
else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(12)]
else model.get_layer(index=layer_dict[-2]).output #如果给出不正确,就默认输出最后一层
for lay in layer_indexes]
print(layer_indexes)
print(all_layers)