Update extract_keras_bert_feature.py

This commit is contained in:
yongzhuo 2019-05-31 14:27:28 +08:00 committed by GitHub
parent 97bd11a7e7
commit d5b1d3f256
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -45,7 +45,7 @@ class KerasBertVector():
layer_dict = []
layer_0 = 7
for i in range(24):
layer_0 = layer_0 + 4
layer_0 = layer_0 + 8
layer_dict.append(layer_0)
# 输出它本身
if len(layer_indexes) == 0:
@ -58,7 +58,7 @@ class KerasBertVector():
encoder_layer = model.get_layer(index=layer_dict[-1]).output
# 否则遍历需要取的层把所有层的weight取出来并拼接起来shape:768*层数
else:
# layer_indexes must be [1,2,3,......12...24]
# layer_indexes must be [1,2,3,......12]
# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)]
else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层