Update keras_bert_embedding.py

This commit is contained in:
yongzhuo 2019-06-01 18:40:11 +08:00 committed by GitHub
parent bc336fe684
commit 5745c15acc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -46,13 +46,13 @@ class KerasBertEmbedding():
# 一共12层 # 一共12层
layer_dict = [] layer_dict = []
layer_0 = 7 layer_0 = 7
for i in range(24): for i in range(12):
layer_0 = layer_0 + 8 layer_0 = layer_0 + 8
layer_dict.append(layer_0) layer_dict.append(layer_0)
# 输出它本身 # 输出它本身
if len(layer_indexes) == 0: if len(layer_indexes) == 0:
encoder_layer = model.output encoder_layer = model.output
# 分类如果只有一层就只取最后那一层的weight,取得不正确 # 分类如果只有一层就只取最后那一层的weight;取得不正确,就默认取最后一层
elif len(layer_indexes) == 1: elif len(layer_indexes) == 1:
if layer_indexes[0] in [i+1 for i in range(23)]: if layer_indexes[0] in [i+1 for i in range(23)]:
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
@ -67,7 +67,6 @@ class KerasBertEmbedding():
for lay in layer_indexes] for lay in layer_indexes]
print(layer_indexes) print(layer_indexes)
print(all_layers) print(all_layers)
# 其中layer==1的output是格式不对第二层输入input是list
all_layers_select = [] all_layers_select = []
for all_layers_one in all_layers: for all_layers_one in all_layers:
all_layers_select.append(all_layers_one) all_layers_select.append(all_layers_one)