fix xlnet
This commit is contained in:
parent
2fa1ac3430
commit
b4a3e17f6a
@ -20,6 +20,7 @@ import keras.backend as K
|
||||
import numpy as np
|
||||
import codecs
|
||||
import jieba
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
@ -341,6 +342,12 @@ class XlnetEmbedding(BaseEmbedding):
|
||||
self.batch_size = hyper_parameters['model'].get('batch_size', 2)
|
||||
super().__init__(hyper_parameters)
|
||||
|
||||
def build_config(self, path_config: str=None):
    """Load the model configuration JSON into ``self.configs``.

    ``self.configs`` is always reset to an empty dict first. When
    ``path_config`` is given, the JSON object stored in that file is
    merged into it.

    Args:
        path_config: Path to a JSON configuration file, or ``None`` to
            leave ``self.configs`` empty.

    Raises:
        OSError: If ``path_config`` cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    self.configs = {}
    if path_config is not None:
        # Open through a context manager so the handle is closed as soon as
        # parsing finishes instead of lingering until garbage collection;
        # JSON files are UTF-8, so do not rely on the platform default codec.
        with open(path_config, encoding="utf-8") as config_file:
            self.configs.update(json.load(config_file))
|
||||
|
||||
def build(self):
|
||||
from keras_xlnet import load_trained_model_from_checkpoint, set_custom_objects
|
||||
from keras_xlnet import Tokenizer, ATTENTION_TYPE_BI, ATTENTION_TYPE_UNI
|
||||
@ -352,7 +359,7 @@ class XlnetEmbedding(BaseEmbedding):
|
||||
|
||||
self.attention_type = self.xlnet_embed.get('attention_type', 'bi') # or 'uni'
|
||||
self.attention_type = ATTENTION_TYPE_BI if self.attention_type == 'bi' else ATTENTION_TYPE_UNI
|
||||
self.memory_len = self.xlnet_embed.get('memory_len', 0)
|
||||
self.memory_len = self.xlnet_embed.get('memory_len', 0)
|
||||
self.target_len = self.xlnet_embed.get('target_len', 5)
|
||||
print('load xlnet model start!')
|
||||
# 模型加载
|
||||
@ -366,38 +373,21 @@ class XlnetEmbedding(BaseEmbedding):
|
||||
mask_index=0)
|
||||
#
|
||||
set_custom_objects()
|
||||
self.build_config(self.config_path)
|
||||
# 字典加载
|
||||
self.tokenizer = Tokenizer(self.spiece_model)
|
||||
# debug时候查看layers
|
||||
self.model_layers = model.layers
|
||||
len_layers = self.model_layers.__len__()
|
||||
print(len_layers)
|
||||
# # debug时候查看layers
|
||||
# self.model_layers = model.layers
|
||||
# len_layers = self.model_layers.__len__()
|
||||
# print(len_layers)
|
||||
num_hidden_layers = self.configs.get("n_layer", 12)
|
||||
|
||||
layer_real = [i for i in range(25)] + [-i for i in range(25)]
|
||||
layer_real = [i for i in range(num_hidden_layers)] + [-i for i in range(num_hidden_layers)]
|
||||
# 简要判别一下
|
||||
self.layer_indexes = [i if i in layer_real else -2 for i in self.layer_indexes]
|
||||
|
||||
len_couche = int((len_layers - 6) / 10)
|
||||
# 一共246个layer
|
||||
# 每层10个layer(MultiHeadAttention,Dropout,Add,LayerNormalization),第一是9个layer的输入和embedding层
|
||||
# 一共24层
|
||||
layer_dict = []
|
||||
layer_0 = 7
|
||||
for i in range(len_couche):
|
||||
layer_0 = layer_0 + 10
|
||||
layer_dict.append(layer_0)
|
||||
layer_dict.append(247)
|
||||
# 测试 get_output_at
|
||||
# def get_number(index):
|
||||
# try:
|
||||
# model_node = model.get_output_at(node_index=index)
|
||||
# gg = 0
|
||||
# except:
|
||||
# print('node index wrong!')
|
||||
# print(index)
|
||||
# list_index = [i for i in range(25)] + [-i for i in range(25)]
|
||||
# for li in list_index:
|
||||
# get_number(li)
|
||||
output_layer = "FeedForward-Normal-{0}"
|
||||
layer_dict = [model.get_layer(output_layer.format(i + 1)).get_output_at(node_index=0)
|
||||
for i in range(num_hidden_layers)]
|
||||
|
||||
# 输出它本身
|
||||
if len(self.layer_indexes) == 0:
|
||||
@ -405,15 +395,14 @@ class XlnetEmbedding(BaseEmbedding):
|
||||
# 分类如果只有一层,取得不正确的话就取倒数第二层
|
||||
elif len(self.layer_indexes) == 1:
|
||||
if self.layer_indexes[0] in layer_real:
|
||||
encoder_layer = model.get_layer(index=layer_dict[self.layer_indexes[0]]).get_output_at(node_index=0)
|
||||
encoder_layer = layer_dict[self.layer_indexes[0]]
|
||||
else:
|
||||
encoder_layer = model.get_layer(index=layer_dict[-1]).get_output_at(node_index=0)
|
||||
encoder_layer = layer_dict[-1]
|
||||
# 否则遍历需要取的层,把所有层的weight取出来并加起来shape:768*层数
|
||||
else:
|
||||
# layer_indexes must be [0, 1, 2,3,......24]
|
||||
all_layers = [model.get_layer(index=layer_dict[lay]).get_output_at(node_index=0)
|
||||
if lay in layer_real
|
||||
else model.get_layer(index=layer_dict[-1]).get_output_at(node_index=0) # 如果给出不正确,就默认输出倒数第一层
|
||||
all_layers = [layer_dict[lay] if lay in layer_real
|
||||
else layer_dict[-1] # 如果给出不正确,就默认输出倒数第一层
|
||||
for lay in self.layer_indexes]
|
||||
print(self.layer_indexes)
|
||||
print(all_layers)
|
||||
|
Loading…
Reference in New Issue
Block a user