From 1b83777e7662511dd1a65c6351d854720c03532c Mon Sep 17 00:00:00 2001 From: yongzhuo <31341349+yongzhuo@users.noreply.github.com> Date: Sat, 1 Jun 2019 18:41:16 +0800 Subject: [PATCH] Update extract_keras_bert_feature.py --- FeatureProject/bert/extract_keras_bert_feature.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/FeatureProject/bert/extract_keras_bert_feature.py b/FeatureProject/bert/extract_keras_bert_feature.py index 5df3814..9b05cb9 100644 --- a/FeatureProject/bert/extract_keras_bert_feature.py +++ b/FeatureProject/bert/extract_keras_bert_feature.py @@ -44,13 +44,13 @@ class KerasBertVector(): # 一共24层 layer_dict = [] layer_0 = 7 - for i in range(24): + for i in range(12): layer_0 = layer_0 + 8 layer_dict.append(layer_0) # 输出它本身 if len(layer_indexes) == 0: encoder_layer = model.output - # 分类如果只有一层,就只取最后那一层的weight,取得不正确 + # 分类如果只有一层,就只取最后那一层的weight;取得不正确,就默认取最后一层输出 elif len(layer_indexes) == 1: if layer_indexes[0] in [i+1 for i in range(23)]: encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output @@ -65,7 +65,6 @@ class KerasBertVector(): for lay in layer_indexes] print(layer_indexes) print(all_layers) - # 其中layer==1的output是格式不对,第二层输入input是list all_layers_select = [] for all_layers_one in all_layers: all_layers_select.append(all_layers_one)