Fix wrong layer selection in the BERT Embedding layer

yongzhuo 2019-05-30 21:51:00 +08:00 committed by GitHub
parent 4858620cc1
commit b51aec0d4f


@@ -4,12 +4,14 @@
 # @author   :Mo
 # @function :embedding of bert keras
-from conf.feature_config import gpu_memory_fraction, config_name, ckpt_name, vocab_file, max_seq_len, layer_indexes
+from ClassificationText.bert.args import gpu_memory_fraction, max_seq_len, layer_indexes
+from conf.feature_config import config_name, ckpt_name, vocab_file
 from FeatureProject.bert.layers_keras import NonMaskingLayer
 from keras_bert import load_trained_model_from_checkpoint
 import keras.backend.tensorflow_backend as ktf_keras
 import keras.backend as k_keras
 from keras.models import Model
+from keras.layers import Add
 import tensorflow as tf
 import os
@@ -29,7 +31,7 @@ class KerasBertEmbedding():
     def __init__(self):
         self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len

-    def bert_encode(self):
+    def bert_encode(self, layer_indexes=[12]):
         # kept global so that django, flask, tornado, etc. can call it
         global graph
         graph = tf.get_default_graph()
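A side note on the new signature: `layer_indexes=[12]` is a mutable default argument, evaluated once at definition time. `bert_encode` never mutates it, so it is harmless here, but a minimal sketch of the usual `None`-sentinel idiom (shown as a standalone function purely for illustration) would be:

    def bert_encode(layer_indexes=None):
        # create the default per call instead of sharing one list across all calls
        if layer_indexes is None:
            layer_indexes = [12]
        return layer_indexes

    print(bert_encode())        # [12]
    print(bert_encode([4, 8]))  # [4, 8]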
@@ -37,14 +39,40 @@
         model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,
                                                    seq_len=self.max_seq_len)
         print(model.output)
-        # for classification, if only one layer is selected, just take the last layer's weight
-        if len(layer_indexes) == 1:
-            encoder_layer = model.get_layer(index=len(model.layers)-1).output
+        print(len(model.layers))
+        # lay = model.layers
+        # 104 layers in total; the first eight include token, pos, embed, etc.
+        # then every 4 layers form one block: MultiHeadAttention, Dropout, Add, LayerNormalization
+        # 24 blocks in total
+        layer_dict = []
+        layer_0 = 7
+        for i in range(24):
+            layer_0 = layer_0 + 4
+            layer_dict.append(layer_0)
+        # if no layer is requested, output the model itself
+        if len(layer_indexes) == 0:
+            encoder_layer = model.output
+        # for classification with a single layer, just taking the last raw layer's weight (as before) was incorrect
+        elif len(layer_indexes) == 1:
+            if layer_indexes[0] in [i+1 for i in range(23)]:
+                encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
+            else:
+                encoder_layer = model.get_layer(index=layer_dict[-1]).output
         # otherwise iterate over the requested layers, pull out every layer's weight and concatenate them, shape: 768 * number of layers
         else:
-            # layer_indexes must be [1,2,3,......12]
-            all_layers = [model.get_layer(index=lay).output for lay in layer_indexes]
-            encoder_layer = k_keras.concatenate(all_layers, -1)
+            # layer_indexes must be [1,2,3,......12...24]
+            # all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
+            all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)]
+                          else model.get_layer(index=layer_dict[-1]).output  # if an invalid index is given, default to the last layer
+                          for lay in layer_indexes]
+            print(layer_indexes)
+            print(all_layers)
+            # the output of layer==1 has the wrong format; the next layer's input is a list
+            all_layers_select = []
+            for all_layers_one in all_layers:
+                all_layers_select.append(all_layers_one)
+            encoder_layer = Add()(all_layers_select)
+            print(encoder_layer.shape)
         print("KerasBertEmbedding:")
         print(encoder_layer.shape)
         output_layer = NonMaskingLayer()(encoder_layer)
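For reference, the `layer_dict` loop maps transformer block k (1-based) to keras layer index 7 + 4*k, i.e. the LayerNormalization layer that closes each 4-layer block. Note also that the single-layer branch indexes `layer_dict[layer_indexes[0]]` while the multi-layer branch uses `layer_dict[lay-1]`, so requesting block k in the single-layer case actually returns the output of block k+1. A minimal sketch verifying the index arithmetic, under the 104-layer keras-bert graph the commit's comments describe (8 embedding-related layers plus 4 x 24 block layers):

    # equivalent to the loop: layer_0 starts at 7 and grows by 4 per block
    layer_dict = [7 + 4 * (k + 1) for k in range(24)]
    assert len(layer_dict) == 24
    assert layer_dict[0] == 11    # end of block 1 (indices 8-11, after 8 embedding layers)
    assert layer_dict[-1] == 103  # end of block 24, the last layer of a 104-layer model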
@@ -56,3 +84,4 @@ class KerasBertEmbedding():
 if __name__ == "__main__":
     bert_vector = KerasBertEmbedding()
     pooled = bert_vector.bert_encode()
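A hypothetical usage sketch of the new interface; it assumes `config_name`, `ckpt_name` and `vocab_file` in `conf.feature_config` already point at a downloaded BERT checkpoint, which this diff does not set up. Since `Add()` sums the selected layers element-wise, the combined output stays 768-dimensional; the unchanged context comment about shape 768 * number of layers describes the previous `concatenate` behaviour:

    bert_vector = KerasBertEmbedding()
    pooled_default = bert_vector.bert_encode()                       # default: block 12 only
    pooled_summed = bert_vector.bert_encode(layer_indexes=[11, 12])  # Add() over blocks 11 and 12
    pooled_raw = bert_vector.bert_encode(layer_indexes=[])           # model.output unchanged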