Update extract_keras_bert_feature.py

This commit is contained in:
yongzhuo 2019-07-23 22:25:16 +08:00 committed by GitHub
parent e5709962af
commit 5bda829469
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -41,8 +41,8 @@ class KerasBertVector():
# lay = model.layers # lay = model.layers
#一共104个layer其中前八层包括token,pos,embed等 #一共104个layer其中前八层包括token,pos,embed等
# 每4层MultiHeadAttention,Dropout,Add,LayerNormalization # 每4层MultiHeadAttention,Dropout,Add,LayerNormalization
# 一共24 # 一共12层transformer
layer_dict = [] layer_dict = [7]
layer_0 = 7 layer_0 = 7
for i in range(12): for i in range(12):
layer_0 = layer_0 + 8 layer_0 = layer_0 + 8
@ -50,18 +50,18 @@ class KerasBertVector():
# 输出它本身 # 输出它本身
if len(layer_indexes) == 0: if len(layer_indexes) == 0:
encoder_layer = model.output encoder_layer = model.output
# 分类如果只有一层,就只取最后那一层的weight取得不正确就默认取最后一层输 # 分类如果只有一层,就只取倒数第二层的weight取得不正确就默认取倒数第二层
elif len(layer_indexes) == 1: elif len(layer_indexes) == 1:
if layer_indexes[0] in [i+1 for i in range(23)]: if layer_indexes[0] in [i+1 for i in range(12)]:
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
else: else:
encoder_layer = model.get_layer(index=layer_dict[-1]).output encoder_layer = model.get_layer(index=layer_dict[-2]).output
# 否则遍历需要取的层把所有层的weight取出来并拼接起来shape:768*层数 # 否则遍历需要取的层把所有层的weight取出来并拼接起来shape:768*层数
else: else:
# layer_indexes must be [1,2,3,......12] # layer_indexes must be [1,2,3,......12]
# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes] # all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)] all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(12)]
else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层 else model.get_layer(index=layer_dict[-2]).output #如果给出不正确,就默认输出倒数第二层
for lay in layer_indexes] for lay in layer_indexes]
print(layer_indexes) print(layer_indexes)
print(all_layers) print(all_layers)