add CapsuleNet-bojone for NLP

yongzhuo 2019-07-19 19:00:03 +08:00
parent 0b1ed10d2b
commit 164be1eb0a
8 changed files with 174 additions and 31 deletions

View File

@@ -148,7 +148,7 @@ class PreprocessText:
        save_json(l2i_i2l, path_fast_text_model_l2i_i2l)
        len_ql = int(rate * len(ques))
-       if len_ql <= 5000:  # not applied when sampling; keeps the corpus large enough for training
+       if len_ql <= 500:  # not applied when sampling; keeps the corpus large enough for training
            len_ql = len(ques)
        x = []

View File

@@ -7,9 +7,11 @@ uncommenting them and commenting their counterparts.
Author: Xifeng Guo, E-mail: `guoxifeng1990@163.com`, Github: `https://github.com/XifengGuo/CapsNet-Keras`
"""
+ from keras.layers import Activation, Layer
- from keras import initializers, layers
import keras.backend as K
import tensorflow as tf
+ from keras import initializers, layers
class Length(layers.Layer):
@@ -192,14 +194,94 @@ def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
    return layers.Lambda(squash)(outputs)
"""
- # The following is another way to implement primary capsule layer. This is much slower.
- # Apply Conv2D `n_channels` times and concatenate all capsules
- def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
+ def PrimaryCap_nchannels(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
+     # The following is another way to implement primary capsule layer. This is much slower.
+     # Apply Conv2D `n_channels` times and concatenate all capsules
    outputs = []
    for _ in range(n_channels):
        output = layers.Conv2D(filters=dim_capsule, kernel_size=kernel_size, strides=strides, padding=padding)(inputs)
        outputs.append(layers.Reshape([output.get_shape().as_list()[1] ** 2, dim_capsule])(output))
    outputs = layers.Concatenate(axis=1)(outputs)
    return layers.Lambda(squash)(outputs)
"""
def squash_bojone(x, axis=-1):
    """
    squash activation
    :param x: vector
    :param axis: int
    :return: vector
    """
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True)
    scale = K.sqrt(s_squared_norm + K.epsilon())
    return x / scale
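# Note: this variant simply L2-normalizes x along `axis`; the original CapsNet squash
# (Sabour et al., 2017) additionally scales by ||s||^2 / (1 + ||s||^2).
# Quick sanity check (a sketch; needs a TF/Keras backend session):
#   K.eval(squash_bojone(K.constant([[3.0, 4.0]])))  # -> approx. [[0.6, 0.8]]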

class Capsule_bojone(Layer):
    """
    # author: bojone
    # explain: A Capsule Implement with Pure Keras
    # github: https://github.com/bojone/Capsule/blob/master/Capsule_Keras.py
    """
    def __init__(self, num_capsule, dim_capsule, routings=3, kernel_size=(9, 1),
                 share_weights=True, activation='default', **kwargs):
        super(Capsule_bojone, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.kernel_size = kernel_size
        self.share_weights = share_weights
        if activation == 'default':
            self.activation = squash_bojone
        else:
            self.activation = Activation(activation)

    def build(self, input_shape):
        super(Capsule_bojone, self).build(input_shape)
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(1, input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     # shape=self.kernel_size,
                                     initializer='glorot_uniform',
                                     trainable=True)
        else:
            input_num_capsule = input_shape[-2]
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(input_num_capsule,
                                            input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)

    def call(self, u_vecs):
        if self.share_weights:
            u_hat_vecs = K.conv1d(u_vecs, self.W)
        else:
            u_hat_vecs = K.local_conv1d(u_vecs, self.W, [1], [1])
        batch_size = K.shape(u_vecs)[0]
        input_num_capsule = K.shape(u_vecs)[1]
        u_hat_vecs = K.reshape(u_hat_vecs, (batch_size, input_num_capsule,
                                            self.num_capsule, self.dim_capsule))
        u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
        # final u_hat_vecs.shape = [None, num_capsule, input_num_capsule, dim_capsule]
        b = K.zeros_like(u_hat_vecs[:, :, :, 0])  # shape = [None, num_capsule, input_num_capsule]
        outputs = None
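        # dynamic routing: iteratively reweight each input capsule's contribution
        # (the coupling softmax over b) by its agreement with the current outputs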
        for i in range(self.routings):
            b = K.permute_dimensions(b, (0, 2, 1))  # shape = [None, input_num_capsule, num_capsule]
            c = K.softmax(b)
            c = K.permute_dimensions(c, (0, 2, 1))
            b = K.permute_dimensions(b, (0, 2, 1))
            outputs = self.activation(K.batch_dot(c, u_hat_vecs, [2, 2]))
            if i < self.routings - 1:
                b = K.batch_dot(outputs, u_hat_vecs, [2, 3])
        return outputs

    def compute_output_shape(self, input_shape):
        return (None, self.num_capsule, self.dim_capsule)
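
A minimal usage sketch (illustrative only, not part of this commit; the 100x64 input and the capsule sizes are arbitrary, and it assumes the imports at the top of this file):

    from keras.layers import Input
    from keras.models import Model

    inp = Input(shape=(100, 64))  # (timesteps, features), demo sizes
    caps = Capsule_bojone(num_capsule=10, dim_capsule=16, routings=3)(inp)
    model = Model(inputs=inp, outputs=caps)
    print(model.output_shape)  # -> (None, 10, 16)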

View File

@@ -90,7 +90,7 @@ def train(hyper_parameters=None, rate=1.0):
if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 when sampling, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 when sampling, otherwise the training corpus may be very small
    # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20 and gpu<=0.87 should be enough to fine-tune bert.
    # One epoch on the full data (batch_size=32) reaches 80% accuracy on the validation set, which is decent.
    # An error seen on win10; setting gpu, len_max and batch_size smaller fixes it: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory

View File

@@ -44,7 +44,7 @@ class TextCNNGraph(graph):
                           )(conv)
            conv_pools.append(pooled)
        # concatenate
-       x = Concatenate(axis=1)(conv_pools)
+       x = Concatenate(axis=-1)(conv_pools)
        x = Flatten()(x)
        x = Dropout(self.dropout)(x)
        output = Dense(units=self.label, activation=self.activate_classify)(x)

View File

@@ -98,7 +98,7 @@ def train(hyper_parameters=None, rate=1.0):
if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 when sampling, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 when sampling, otherwise the training corpus may be very small
    # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20 and gpu<=0.87 should be enough to fine-tune bert.
    # One epoch on the full data (batch_size=32) reaches 80% accuracy on the validation set, which is decent.
    # An error seen on win10; setting gpu, len_max and batch_size smaller fixes it: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory

View File

@@ -93,7 +93,7 @@ def train(hyper_parameters=None, rate=1.0):
if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 when sampling, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 when sampling, otherwise the training corpus may be very small
    # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20 and gpu<=0.87 should be enough to fine-tune bert.
    # One epoch on the full data (batch_size=32) reaches 80% accuracy on the validation set, which is decent.
    # An error seen on win10; setting gpu, len_max and batch_size smaller fixes it: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory

View File

@@ -5,8 +5,9 @@
# @function :graph of base
- from keras_textclassification.keras_layers.capsule import CapsuleLayer, PrimaryCap, Length, Mask
- from keras.layers import Conv2D, MaxPool2D, Concatenate
+ from keras_textclassification.keras_layers.capsule import Capsule_bojone, CapsuleLayer, PrimaryCap, Length, Mask
+ from keras.layers import Conv1D, Conv2D, MaxPool2D, Concatenate, SpatialDropout1D
+ from keras.layers import Bidirectional, LSTM, GRU
from keras.layers import Dropout, Dense, Flatten
from keras.optimizers import Adam
from keras.layers import Reshape
@@ -23,8 +24,9 @@ class CapsuleNetGraph(graph):
        Initialization
        :param hyper_parameters: json, hyper parameters
        """
-       self.routings = hyper_parameters['model'].get('routings', 1)
+       self.routings = hyper_parameters['model'].get('routings', 5)
        self.dim_capsule = hyper_parameters['model'].get('dim_capsule', 16)
+       self.num_capsule = hyper_parameters['model'].get('num_capsule', 16)
        super().__init__(hyper_parameters)

    def create_model(self, hyper_parameters):
@@ -35,43 +37,101 @@ class CapsuleNetGraph(graph):
"""
super().create_model(hyper_parameters)
embedding = self.word_embedding.output
embed_layer = SpatialDropout1D(self.dropout)(embedding)
conv_pools = []
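        # one Conv1D + capsule branch per kernel width; the branches are concatenated below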
        for filter in self.filters:
            x = Conv1D(filters=self.filters_num,
                       kernel_size=filter,
                       padding='valid',
                       kernel_initializer='normal',
                       activation='relu',
                       )(embed_layer)
            capsule = Capsule_bojone(num_capsule=self.num_capsule,
                                     dim_capsule=self.dim_capsule,
                                     routings=self.routings,
                                     kernel_size=(filter, 1),
                                     share_weights=True)(x)
            conv_pools.append(capsule)
        capsule = Concatenate(axis=-1)(conv_pools)
        capsule = Flatten()(capsule)
        capsule = Dropout(self.dropout)(capsule)
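        # classification head; sigmoid is hard-coded here (self.activate_classify is used
        # elsewhere) -- for single-label classification a softmax may be preferable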
        output = Dense(self.label, activation='sigmoid')(capsule)
        self.model = Model(inputs=self.word_embedding.input, outputs=output)
        self.model.summary(120)

    def create_model_gru(self, hyper_parameters):
        """
        Build the network: bi-gru + capsule
        :param hyper_parameters: json, hyper parameters of network
        :return: tensor, model
        """
        super().create_model(hyper_parameters)
        embedding = self.word_embedding.output
        embed_layer = SpatialDropout1D(self.dropout)(embedding)
        x = Bidirectional(GRU(self.filters_num,
                              activation='relu',
                              dropout=self.dropout,
                              recurrent_dropout=self.dropout,
                              return_sequences=True))(embed_layer)
        capsule = Capsule_bojone(num_capsule=self.num_capsule,
                                 dim_capsule=self.dim_capsule,
                                 routings=self.routings,
                                 kernel_size=(self.filters[0], 1),
                                 share_weights=True)(x)
        capsule = Flatten()(capsule)
        capsule = Dropout(self.dropout)(capsule)
        output = Dense(self.label, activation='sigmoid')(capsule)
        self.model = Model(inputs=self.word_embedding.input, outputs=output)
        self.model.summary(120)

    def create_model_basic(self, hyper_parameters):
        """
        Build the network, original version, extremely slow
        :param hyper_parameters: json, hyper parameters of network
        :return: tensor, model
        """
        super().create_model(hyper_parameters)
        embedding = self.word_embedding.output
        embedding = SpatialDropout1D(self.dropout)(embedding)
        embedding_reshape = Reshape((self.len_max, self.embed_size, 1))(embedding)
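        # add a trailing channel axis so the (len_max, embed_size) text "image" can feed Conv2D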
        conv_pools = []
        for filter in self.filters:
            # Layer 1: just a conventional Conv2D layer
            conv1 = Conv2D(filters=self.filters_num,
                           kernel_size=(filter, self.embed_size),
-                          strides=1,
+                          strides=(1, 1),
                           padding='valid',
                           activation='relu',)(embedding_reshape)
            conv_shape = K.int_shape(conv1)
            conv1 = Reshape((conv_shape[1], conv_shape[3], conv_shape[2]))(conv1)
            # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
            primarycaps = PrimaryCap(inputs=conv1,
                                     dim_capsule=self.dim_capsule,
                                     n_channels=self.channel_size,
-                                    kernel_size=(self.len_max - filter + 1, 1),
+                                    kernel_size=(filter, conv_shape[3]),
                                     strides=(1, 1),
                                     padding='valid')
            # Layer 3: capsule layer; the routing algorithm works here.
-           digitcaps = CapsuleLayer(num_capsule=self.label,
-                                    dim_capsule=int(self.dim_capsule * 2),
-                                    routings=self.routings, )(primarycaps)
+           digitcaps = CapsuleLayer(num_capsule=self.num_capsule,  # self.label,
+                                    dim_capsule=int(self.dim_capsule * 2),
+                                    routings=self.routings, )(primarycaps)
            conv_pools.append(digitcaps)
        # concatenate
        x = Concatenate(axis=-1)(conv_pools)
-       # x = Flatten()(x)
-       # dense_layer = Dense(self.label, activation=self.activate_classify)(x)
+       x = Flatten()(x)
+       x = Dropout(self.dropout)(x)
+       dense_layer = Dense(self.label, activation=self.activate_classify)(x)
        # out_caps = [dense_layer]
        # Layer 4: an auxiliary layer to replace each capsule with its length, just to match the true label's shape.
        # If using tensorflow, this will not be necessary. :)
-       out_caps = Length(name='capsnet')(x)
+       # out_caps = Length(name='capsnet')(x)
        # and finally the softmax
-       self.model = Model(inputs=self.word_embedding.input, outputs=out_caps)
+       self.model = Model(inputs=self.word_embedding.input, outputs=dense_layer)
        self.model.summary(120)

-   def create_model_old(self, hyper_parameters):
+   def create_model_margin(self, hyper_parameters):
        """
        Build the network
        :param hyper_parameters: json, hyper parameters of network

View File

@@ -32,16 +32,16 @@ def train(hyper_parameters=None, rate=1.0):
    if not hyper_parameters:
        hyper_parameters = {
            'len_max': 50,  # max sentence length, fixed; 20-50 recommended
-           'embed_size': 96,  # char/word embedding dimension
+           'embed_size': 300,  # char/word embedding dimension
            'vocab_size': 20000,  # an arbitrary value here; the code updates it
            'trainable': False,  # whether the embedding is static or trainable
            'level_type': 'char',  # granularity, the smallest unit, char/word; use 'char' or 'word'
            'embedding_type': 'random',  # embedding type; can be 'random', 'bert' or 'word2vec'
            'model': {'label': 17,  # number of classes
-                     'batch_size': 64,  # batch size
-                     'filters': [3, 4, 5],  # convolution kernel sizes
+                     'batch_size': 16,  # batch size
+                     'filters': [2, 3, 4, 5],  # convolution kernel sizes
                      'filters_num': 300,  # number of filters; text-cnn: 300-600
-                     'channel_size': 3,  # number of CNN channels
+                     'channel_size': 16,  # number of CNN channels
                      'dropout': 0.5,  # dropout rate
                      'decay_step': 100,  # learning-rate decay step; decay every N steps
                      'decay_rate': 0.9,  # learning-rate decay factor, multiplicative
@@ -56,8 +56,9 @@
            # model path; basis for saving when loss decreases: save_best_only=True, save_weights_only=True
            'path_hyper_parameters': path_hyper_parameters,  # path of the hyperparameters for the model (including embedding)
            'path_fineture': path_fineture,  # path for saving the trainable embedding, e.g. char, word or bert vectors
-           'routings': 1,
+           'routings': 5,
            'dim_capsule': 16,
+           'num_capsule': 16
            },
        'embedding': {'layer_indexes': [12],  # which bert layers to take
                      'corpus_path': '',  # path of embedding pretraining data; if unset, the default path in conf is used
@@ -89,7 +90,7 @@ def train(hyper_parameters=None, rate=1.0):
if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 when sampling, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 when sampling, otherwise the training corpus may be very small
    # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20 and gpu<=0.87 should be enough to fine-tune bert.
    # One epoch on the full data (batch_size=32) reaches 80% accuracy on the validation set, which is decent.
    # An error seen on win10; setting gpu, len_max and batch_size smaller fixes it: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
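    # Shape check under these defaults (a sketch, following create_model above): each Capsule_bojone
    # branch outputs (batch, num_capsule=16, dim_capsule=16); Concatenate(axis=-1) over the four
    # kernel widths [2, 3, 4, 5] gives (batch, 16, 64), and Flatten feeds 16 * 64 = 1024 features
    # into the final Dense(17) layer.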