修改bert的Embedding层layer取错问题
This commit is contained in:
parent
4858620cc1
commit
b51aec0d4f
@ -4,12 +4,14 @@
|
|||||||
# @author :Mo
|
# @author :Mo
|
||||||
# @function :embedding of bert keras
|
# @function :embedding of bert keras
|
||||||
|
|
||||||
from conf.feature_config import gpu_memory_fraction, config_name, ckpt_name, vocab_file, max_seq_len, layer_indexes
|
from ClassificationText.bert.args import gpu_memory_fraction, max_seq_len, layer_indexes
|
||||||
|
from conf.feature_config import config_name, ckpt_name, vocab_file
|
||||||
from FeatureProject.bert.layers_keras import NonMaskingLayer
|
from FeatureProject.bert.layers_keras import NonMaskingLayer
|
||||||
from keras_bert import load_trained_model_from_checkpoint
|
from keras_bert import load_trained_model_from_checkpoint
|
||||||
import keras.backend.tensorflow_backend as ktf_keras
|
import keras.backend.tensorflow_backend as ktf_keras
|
||||||
import keras.backend as k_keras
|
import keras.backend as k_keras
|
||||||
from keras.models import Model
|
from keras.models import Model
|
||||||
|
from keras.layers import Add
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -29,7 +31,7 @@ class KerasBertEmbedding():
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len
|
self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len
|
||||||
|
|
||||||
def bert_encode(self):
|
def bert_encode(self, layer_indexes=[12]):
|
||||||
# 全局使用,使其可以django、flask、tornado等调用
|
# 全局使用,使其可以django、flask、tornado等调用
|
||||||
global graph
|
global graph
|
||||||
graph = tf.get_default_graph()
|
graph = tf.get_default_graph()
|
||||||
@ -37,14 +39,40 @@ class KerasBertEmbedding():
|
|||||||
model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,
|
model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,
|
||||||
seq_len=self.max_seq_len)
|
seq_len=self.max_seq_len)
|
||||||
print(model.output)
|
print(model.output)
|
||||||
# 分类如果只选一层,就只取最后那一层的weight
|
print(len(model.layers))
|
||||||
if len(layer_indexes) == 1:
|
# lay = model.layers
|
||||||
encoder_layer = model.get_layer(index=len(model.layers)-1).output
|
#一共104个layer,其中前八层包括token,pos,embed等,
|
||||||
|
# 每4层(MultiHeadAttention,Dropout,Add,LayerNormalization)
|
||||||
|
# 一共24层
|
||||||
|
layer_dict = []
|
||||||
|
layer_0 = 7
|
||||||
|
for i in range(24):
|
||||||
|
layer_0 = layer_0 + 4
|
||||||
|
layer_dict.append(layer_0)
|
||||||
|
# 输出它本身
|
||||||
|
if len(layer_indexes) == 0:
|
||||||
|
encoder_layer = model.output
|
||||||
|
# 分类如果只有一层,就只取最后那一层的weight,取得不正确
|
||||||
|
elif len(layer_indexes) == 1:
|
||||||
|
if layer_indexes[0] in [i+1 for i in range(23)]:
|
||||||
|
encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0]]).output
|
||||||
|
else:
|
||||||
|
encoder_layer = model.get_layer(index=layer_dict[-1]).output
|
||||||
# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数
|
# 否则遍历需要取的层,把所有层的weight取出来并拼接起来shape:768*层数
|
||||||
else:
|
else:
|
||||||
# layer_indexes must be [1,2,3,......12]
|
# layer_indexes must be [1,2,3,......12...24]
|
||||||
all_layers = [model.get_layer(index=lay).output for lay in layer_indexes]
|
# all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
|
||||||
encoder_layer = k_keras.concatenate(all_layers, -1)
|
all_layers = [model.get_layer(index=layer_dict[lay-1]).output if lay in [i+1 for i in range(23)]
|
||||||
|
else model.get_layer(index=layer_dict[-1]).output #如果给出不正确,就默认输出最后一层
|
||||||
|
for lay in layer_indexes]
|
||||||
|
print(layer_indexes)
|
||||||
|
print(all_layers)
|
||||||
|
# 其中layer==1的output是格式不对,第二层输入input是list
|
||||||
|
all_layers_select = []
|
||||||
|
for all_layers_one in all_layers:
|
||||||
|
all_layers_select.append(all_layers_one)
|
||||||
|
encoder_layer = Add()(all_layers_select)
|
||||||
|
print(encoder_layer.shape)
|
||||||
print("KerasBertEmbedding:")
|
print("KerasBertEmbedding:")
|
||||||
print(encoder_layer.shape)
|
print(encoder_layer.shape)
|
||||||
output_layer = NonMaskingLayer()(encoder_layer)
|
output_layer = NonMaskingLayer()(encoder_layer)
|
||||||
@ -56,3 +84,4 @@ class KerasBertEmbedding():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
bert_vector = KerasBertEmbedding()
|
bert_vector = KerasBertEmbedding()
|
||||||
pooled = bert_vector.bert_encode()
|
pooled = bert_vector.bert_encode()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user