Update keras_bert_embedding.py
This commit is contained in:
parent
bc336fe684
commit
5745c15acc
@ -46,13 +46,13 @@ class KerasBertEmbedding():
|
|||||||
# 一共12层
|
# 一共12层
|
||||||
layer_dict = []
|
layer_dict = []
|
||||||
layer_0 = 7
|
layer_0 = 7
|
||||||
for i in range(24):
|
for i in range(12):
|
||||||
layer_0 = layer_0 + 8
|
layer_0 = layer_0 + 8
|
||||||
layer_dict.append(layer_0)
|
layer_dict.append(layer_0)
|
||||||
# 输出它本身
|
# 输出它本身
|
||||||
if len(layer_indexes) == 0:
|
if len(layer_indexes) == 0:
|
||||||
encoder_layer = model.output
|
encoder_layer = model.output
|
||||||
# 分类如果只有一层,就只取最后那一层的weight,取得不正确
|
# 分类如果只有一层,就只取最后那一层的weight;取得不正确,就默认取最后一层
|
||||||
elif len(layer_indexes) == 1:
|
elif len(layer_indexes) == 1:
|
||||||
if layer_indexes[0] in [i+1 for i in range(23)]:
|
if layer_indexes[0] in [i+1 for i in range(23)]:
|
||||||
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
|
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
|
||||||
@ -67,7 +67,6 @@ class KerasBertEmbedding():
|
|||||||
for lay in layer_indexes]
|
for lay in layer_indexes]
|
||||||
print(layer_indexes)
|
print(layer_indexes)
|
||||||
print(all_layers)
|
print(all_layers)
|
||||||
# 其中layer==1的output是格式不对,第二层输入input是list
|
|
||||||
all_layers_select = []
|
all_layers_select = []
|
||||||
for all_layers_one in all_layers:
|
for all_layers_one in all_layers:
|
||||||
all_layers_select.append(all_layers_one)
|
all_layers_select.append(all_layers_one)
|
||||||
|
Loading…
Reference in New Issue
Block a user