Update extract_keras_bert_feature.py
This commit is contained in:
parent
97bd11a7e7
commit
d5b1d3f256
@ -45,7 +45,7 @@ class KerasBertVector():
|
||||
layer_dict = []
|
||||
layer_0 = 7
|
||||
for i in range(24):
|
||||
layer_0 = layer_0 + 4
|
||||
layer_0 = layer_0 + 8
|
||||
layer_dict.append(layer_0)
|
||||
# 输出它本身
|
||||
if len(layer_indexes) == 0:
|
||||
@ -58,7 +58,7 @@ class KerasBertVector():
|
||||
encoder_layer = model.get_layer(index=layer_dict[-1]).output
|
||||
# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数
|
||||
else:
|
||||
# layer_indexes must be [1,2,3,......12...24]
|
||||
# layer_indexes must be [1,2,3,......12]
|
||||
# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
|
||||
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)]
|
||||
else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层
|
||||
|
Loading…
Reference in New Issue
Block a user