Delete keras_bert_embedding.py

yongzhuo 2019-06-11 15:40:43 +08:00 committed by GitHub
parent 9e9ac2a1c7
commit 129b74bed4

@ -1,86 +0,0 @@
# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time :2019/5/8 20:04
# @author :Mo
# @function :embedding of bert keras
from ClassificationText.bert.args import gpu_memory_fraction, max_seq_len, layer_indexes
from conf.feature_config import config_name, ckpt_name, vocab_file
from FeatureProject.bert.layers_keras import NonMaskingLayer
from keras_bert import load_trained_model_from_checkpoint
import keras.backend.tensorflow_backend as ktf_keras
import keras.backend as k_keras
from keras.models import Model
from keras.layers import Add
import tensorflow as tf
import os
import logging as logger
# Module-level globals so the graph/model can be reused from django, flask, tornado, etc.
graph = None
model = None
# GPU device selection and per-process memory fraction
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
sess = tf.Session(config=config)
ktf_keras.set_session(sess)
class KerasBertEmbedding():
    def __init__(self):
        self.config_path, self.checkpoint_path, self.dict_path, self.max_seq_len = config_name, ckpt_name, vocab_file, max_seq_len

    def bert_encode(self, layer_indexes=[12]):
        # Keep the graph/model global so they can be reused from django, flask, tornado, etc.
        global graph
        graph = tf.get_default_graph()
        global model
        model = load_trained_model_from_checkpoint(self.config_path, self.checkpoint_path,
                                                   seq_len=self.max_seq_len)
        print(model.output)
        print(len(model.layers))
        # lay = model.layers
        # 104 Keras layers in total: the first 8 are the token/position/segment
        # embeddings, then each of the 12 transformer blocks adds 8 layers
        # (MultiHeadAttention, Dropout, Add, LayerNormalization, ...).
        layer_dict = []
        layer_0 = 7
        for i in range(12):
            layer_0 = layer_0 + 8
            layer_dict.append(layer_0)
        # No indexes given: return the model's own output
        if len(layer_indexes) == 0:
            encoder_layer = model.output
        # A single index: take that block's output; if the index is invalid,
        # default to the last block
        elif len(layer_indexes) == 1:
            if layer_indexes[0] in [i + 1 for i in range(12)]:
                encoder_layer = model.get_layer(index=layer_dict[layer_indexes[0] - 1]).output
            else:
                encoder_layer = model.get_layer(index=layer_dict[-1]).output
        # Otherwise collect the outputs of all requested blocks and sum them
        # element-wise (Add keeps the shape at hidden_size = 768)
        else:
            # layer_indexes must be within [1, 2, 3, ..., 12]
            # all_layers = [model.get_layer(index=lay).output if lay is not 1 else model.get_layer(index=lay).output[0] for lay in layer_indexes]
            all_layers = [model.get_layer(index=layer_dict[lay - 1]).output if lay in [i + 1 for i in range(12)]
                          else model.get_layer(index=layer_dict[-1]).output  # invalid index: default to the last block
                          for lay in layer_indexes]
            print(layer_indexes)
            print(all_layers)
            all_layers_select = []
            for all_layers_one in all_layers:
                all_layers_select.append(all_layers_one)
            encoder_layer = Add()(all_layers_select)
            print(encoder_layer.shape)
print("KerasBertEmbedding:")
print(encoder_layer.shape)
output_layer = NonMaskingLayer()(encoder_layer)
model = Model(model.inputs, output_layer)
# model.summary(120)
return model.inputs, model.output
if __name__ == "__main__":
    bert_vector = KerasBertEmbedding()
    pooled = bert_vector.bert_encode()
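
For reference, a minimal usage sketch of the tensors returned by bert_encode (not part of the deleted file). It assumes the same vocab_file and max_seq_len imported above, and uses keras_bert's own Tokenizer; the sample sentence and variable names are illustrative only.

import codecs
import numpy as np
from keras.models import Model
from keras_bert import Tokenizer

# Build the token -> id mapping from the BERT vocab file, as in keras_bert examples.
token_dict = {}
with codecs.open(vocab_file, 'r', 'utf8') as reader:
    for line in reader:
        token_dict[line.strip()] = len(token_dict)

tokenizer = Tokenizer(token_dict)
indices, segments = tokenizer.encode('this is a test', max_len=max_seq_len)

# Wrap the returned input/output tensors in a Model and run a prediction;
# the result has one contextual vector per token: shape (1, max_seq_len, 768).
bert_inputs, bert_output = KerasBertEmbedding().bert_encode(layer_indexes=[12])
embed_model = Model(bert_inputs, bert_output)
vector = embed_model.predict([np.array([indices]), np.array([segments])])
print(vector.shape)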