Update extract_keras_bert_feature.py
This commit is contained in:
parent
e5709962af
commit
5bda829469
@ -41,8 +41,8 @@ class KerasBertVector():
|
|||||||
# lay = model.layers
|
# lay = model.layers
|
||||||
#一共104个layer,其中前八层包括token,pos,embed等,
|
#一共104个layer,其中前八层包括token,pos,embed等,
|
||||||
# 每4层(MultiHeadAttention,Dropout,Add,LayerNormalization)
|
# 每4层(MultiHeadAttention,Dropout,Add,LayerNormalization)
|
||||||
# 一共24层
|
# 一共12层transformer
|
||||||
layer_dict = []
|
layer_dict = [7]
|
||||||
layer_0 = 7
|
layer_0 = 7
|
||||||
for i in range(12):
|
for i in range(12):
|
||||||
layer_0 = layer_0 + 8
|
layer_0 = layer_0 + 8
|
||||||
@ -50,18 +50,18 @@ class KerasBertVector():
|
|||||||
# 输出它本身
|
# 输出它本身
|
||||||
if len(layer_indexes) == 0:
|
if len(layer_indexes) == 0:
|
||||||
encoder_layer = model.output
|
encoder_layer = model.output
|
||||||
# 分类如果只有一层,就只取最后那一层的weight;取得不正确,就默认取最后一层输出
|
# 分类如果只有一层,就只取倒数第二层的weight;取得不正确,就默认取倒数第二层输出
|
||||||
elif len(layer_indexes) == 1:
|
elif len(layer_indexes) == 1:
|
||||||
if layer_indexes[0] in [i+1 for i in range(23)]:
|
if layer_indexes[0] in [i+1 for i in range(12)]:
|
||||||
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
|
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
|
||||||
else:
|
else:
|
||||||
encoder_layer = model.get_layer(index=layer_dict[-1]).output
|
encoder_layer = model.get_layer(index=layer_dict[-2]).output
|
||||||
# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数
|
# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数
|
||||||
else:
|
else:
|
||||||
# layer_indexes must be [1,2,3,......12]
|
# layer_indexes must be [1,2,3,......12]
|
||||||
# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
|
# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
|
||||||
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)]
|
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(12)]
|
||||||
else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层
|
else model.get_layer(index=layer_dict[-2]).output #如果给出不正确,就默认输出倒数第二层
|
||||||
for lay in layer_indexes]
|
for lay in layer_indexes]
|
||||||
print(layer_indexes)
|
print(layer_indexes)
|
||||||
print(all_layers)
|
print(all_layers)
|
||||||
|
Loading…
Reference in New Issue
Block a user