add CapsuleNet-bojone of NLP

parent 0b1ed10d2b
commit 164be1eb0a
@@ -148,7 +148,7 @@ class PreprocessText:
         save_json(l2i_i2l, path_fast_text_model_l2i_i2l)
 
         len_ql = int(rate * len(ques))
-        if len_ql <= 5000:  # not applied when sampling, so the corpus stays large enough for training
+        if len_ql <= 500:  # not applied when sampling, so the corpus stays large enough for training
             len_ql = len(ques)
 
         x = []
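A minimal sketch of how the loosened threshold above interacts with the sampling rate that the train scripts later in this commit pass in; the numbers are made up for illustration and are not part of the commit:

# rate=0.01 samples 1% of the corpus; the guard now only falls back
# to the full corpus when the sample would be tiny (<= 500 questions)
ques = list(range(100000))      # pretend corpus of 100k questions
rate = 0.01
len_ql = int(rate * len(ques))  # 1000 -> kept, since 1000 > 500
if len_ql <= 500:               # with the old 5000 threshold this would trigger
    len_ql = len(ques)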
@@ -7,9 +7,11 @@ uncommenting them and commenting their counterparts.
 
 Author: Xifeng Guo, E-mail: `guoxifeng1990@163.com`, Github: `https://github.com/XifengGuo/CapsNet-Keras`
 """
 
+
+from keras.layers import Activation, Layer
+from keras import initializers, layers
 import keras.backend as K
 import tensorflow as tf
-from keras import initializers, layers
 
 
 class Length(layers.Layer):
@@ -192,14 +194,94 @@ def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
     return layers.Lambda(squash, )(outputs)
 
 
 """
-# The following is another way to implement primary capsule layer. This is much slower.
-# Apply Conv2D `n_channels` times and concatenate all capsules
-def PrimaryCap(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
+def PrimaryCap_nchannels(inputs, dim_capsule, n_channels, kernel_size, strides, padding):
+    # The following is another way to implement a primary capsule layer. This is much slower.
+    # Apply Conv2D `n_channels` times and concatenate all capsules.
     outputs = []
     for _ in range(n_channels):
         output = layers.Conv2D(filters=dim_capsule, kernel_size=kernel_size, strides=strides, padding=padding)(inputs)
         outputs.append(layers.Reshape([output.get_shape().as_list()[1] ** 2, dim_capsule])(output))
     outputs = layers.Concatenate(axis=1)(outputs)
     return layers.Lambda(squash)(outputs)
 """
 
 
+
+
+def squash_bojone(x, axis=-1):
+    """
+    activation of squash
+    :param x: vector
+    :param axis: int
+    :return: vector
+    """
+    s_squared_norm = K.sum(K.square(x), axis, keepdims=True)
+    scale = K.sqrt(s_squared_norm + K.epsilon())
+    return x / scale
+
+
+class Capsule_bojone(Layer):
+    """
+    # author: bojone
+    # explain: A Capsule Implement with Pure Keras
+    # github: https://github.com/bojone/Capsule/blob/master/Capsule_Keras.py
+    """
+    def __init__(self, num_capsule, dim_capsule, routings=3, kernel_size=(9, 1),
+                 share_weights=True, activation='default', **kwargs):
+        super(Capsule_bojone, self).__init__(**kwargs)
+        self.num_capsule = num_capsule
+        self.dim_capsule = dim_capsule
+        self.routings = routings
+        self.kernel_size = kernel_size
+        self.share_weights = share_weights
+        if activation == 'default':
+            self.activation = squash_bojone
+        else:
+            self.activation = Activation(activation)
+
+    def build(self, input_shape):
+        super(Capsule_bojone, self).build(input_shape)
+        input_dim_capsule = input_shape[-1]
+        if self.share_weights:
+            self.W = self.add_weight(name='capsule_kernel',
+                                     shape=(1, input_dim_capsule,
+                                            self.num_capsule * self.dim_capsule),
+                                     # shape=self.kernel_size,
+                                     initializer='glorot_uniform',
+                                     trainable=True)
+        else:
+            input_num_capsule = input_shape[-2]
+            self.W = self.add_weight(name='capsule_kernel',
+                                     shape=(input_num_capsule,
+                                            input_dim_capsule,
+                                            self.num_capsule * self.dim_capsule),
+                                     initializer='glorot_uniform',
+                                     trainable=True)
+
+    def call(self, u_vecs):
+        if self.share_weights:
+            u_hat_vecs = K.conv1d(u_vecs, self.W)
+        else:
+            u_hat_vecs = K.local_conv1d(u_vecs, self.W, [1], [1])
+
+        batch_size = K.shape(u_vecs)[0]
+        input_num_capsule = K.shape(u_vecs)[1]
+        u_hat_vecs = K.reshape(u_hat_vecs, (batch_size, input_num_capsule,
+                                            self.num_capsule, self.dim_capsule))
+        u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
+        # final u_hat_vecs.shape = [None, num_capsule, input_num_capsule, dim_capsule]
+
+        b = K.zeros_like(u_hat_vecs[:, :, :, 0])  # shape = [None, num_capsule, input_num_capsule]
+        outputs = None
+        for i in range(self.routings):
+            b = K.permute_dimensions(b, (0, 2, 1))  # shape = [None, input_num_capsule, num_capsule]
+            c = K.softmax(b)
+            c = K.permute_dimensions(c, (0, 2, 1))
+            b = K.permute_dimensions(b, (0, 2, 1))
+            outputs = self.activation(K.batch_dot(c, u_hat_vecs, [2, 2]))
+            if i < self.routings - 1:
+                b = K.batch_dot(outputs, u_hat_vecs, [2, 3])
+
+        return outputs
+
+    def compute_output_shape(self, input_shape):
+        return (None, self.num_capsule, self.dim_capsule)
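Two notes on the layer added above, plus a hedged usage sketch. squash_bojone is not Hinton's squash ||s||^2 / (1 + ||s||^2) * s / ||s||; it simply L2-normalizes the capsule vector, x / sqrt(||x||^2 + eps), which is bojone's variant. Capsule_bojone maps a [batch, input_num_capsule, input_dim_capsule] tensor to [batch, num_capsule, dim_capsule], so it can sit on top of any sequence encoder. The model below is a minimal sketch assuming Keras 2.x and the module above; the shapes and names are illustrative, not part of the commit:

# minimal usage sketch of Capsule_bojone on top of an embedding
from keras.layers import Input, Embedding
from keras.models import Model

seq = Input(shape=(50,))          # token ids, len_max=50 (illustrative)
emb = Embedding(20000, 300)(seq)  # -> (batch, 50, 300)
caps = Capsule_bojone(num_capsule=16, dim_capsule=16,
                      routings=5, share_weights=True)(emb)  # -> (batch, 16, 16)
model = Model(inputs=seq, outputs=caps)
model.summary()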
@@ -90,7 +90,7 @@ def train(hyper_parameters=None, rate=1.0):
 
 
 if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 under the sample setting, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 under the sample setting, otherwise the training corpus may be very small
     # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20, gpu<=0.87 should be enough to fine-tune bert.
     # One pass over the full data (batch_size=32) already reaches 80% accuracy (validation set), which is quite good.
     # An error seen on win10; just configure gpu, len_max and batch_size a bit smaller: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
@@ -44,7 +44,7 @@ class TextCNNGraph(graph):
                                )(conv)
             conv_pools.append(pooled)
         # concatenate
-        x = Concatenate(axis=1)(conv_pools)
+        x = Concatenate(axis=-1)(conv_pools)
         x = Flatten()(x)
         x = Dropout(self.dropout)(x)
         output = Dense(units=self.label, activation=self.activate_classify)(x)
@@ -98,7 +98,7 @@ def train(hyper_parameters=None, rate=1.0):
 
 
 if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 under the sample setting, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 under the sample setting, otherwise the training corpus may be very small
     # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20, gpu<=0.87 should be enough to fine-tune bert.
     # One pass over the full data (batch_size=32) already reaches 80% accuracy (validation set), which is quite good.
     # An error seen on win10; just configure gpu, len_max and batch_size a bit smaller: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
@@ -93,7 +93,7 @@ def train(hyper_parameters=None, rate=1.0):
 
 
 if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 under the sample setting, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 under the sample setting, otherwise the training corpus may be very small
     # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20, gpu<=0.87 should be enough to fine-tune bert.
     # One pass over the full data (batch_size=32) already reaches 80% accuracy (validation set), which is quite good.
     # An error seen on win10; just configure gpu, len_max and batch_size a bit smaller: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
@@ -5,8 +5,9 @@
 # @function :graph of base
 
 
-from keras_textclassification.keras_layers.capsule import CapsuleLayer, PrimaryCap, Length, Mask
-from keras.layers import Conv2D, MaxPool2D, Concatenate
+from keras_textclassification.keras_layers.capsule import Capsule_bojone, CapsuleLayer, PrimaryCap, Length, Mask
+from keras.layers import Conv1D, Conv2D, MaxPool2D, Concatenate, SpatialDropout1D
 from keras.layers import Bidirectional, LSTM, GRU
 from keras.layers import Dropout, Dense, Flatten
 from keras.optimizers import Adam
 from keras.layers import Reshape
@@ -23,8 +24,9 @@ class CapsuleNetGraph(graph):
         initialization
         :param hyper_parameters: json, hyper parameters
         """
-        self.routings = hyper_parameters['model'].get('routings', 1)
+        self.routings = hyper_parameters['model'].get('routings', 5)
         self.dim_capsule = hyper_parameters['model'].get('dim_capsule', 16)
+        self.num_capsule = hyper_parameters['model'].get('num_capsule', 16)
         super().__init__(hyper_parameters)
 
     def create_model(self, hyper_parameters):
@@ -35,43 +37,101 @@ class CapsuleNetGraph(graph):
         """
         super().create_model(hyper_parameters)
+        embedding = self.word_embedding.output
+        embed_layer = SpatialDropout1D(self.dropout)(embedding)
+        conv_pools = []
+        for filter in self.filters:
+            x = Conv1D(filters=self.filters_num,
+                       kernel_size=filter,
+                       padding='valid',
+                       kernel_initializer='normal',
+                       activation='relu',
+                       )(embed_layer)
+            capsule = Capsule_bojone(num_capsule=self.num_capsule,
+                                     dim_capsule=self.dim_capsule,
+                                     routings=self.routings,
+                                     kernel_size=(filter, 1),
+                                     share_weights=True)(x)
+            conv_pools.append(capsule)
+        capsule = Concatenate(axis=-1)(conv_pools)
+        capsule = Flatten()(capsule)
+        capsule = Dropout(self.dropout)(capsule)
+        output = Dense(self.label, activation='sigmoid')(capsule)
+        self.model = Model(inputs=self.word_embedding.input, outputs=output)
+        self.model.summary(120)
+
+    def create_model_gru(self, hyper_parameters):
+        """
+        build the network, bi-gru + capsule
+        :param hyper_parameters: json, hyper parameters of network
+        :return: tensor, model
+        """
+        super().create_model(hyper_parameters)
+        embedding = self.word_embedding.output
+        embed_layer = SpatialDropout1D(self.dropout)(embedding)
+        x = Bidirectional(GRU(self.filters_num,
+                              activation='relu',
+                              dropout=self.dropout,
+                              recurrent_dropout=self.dropout,
+                              return_sequences=True))(embed_layer)
+        capsule = Capsule_bojone(num_capsule=self.num_capsule,
+                                 dim_capsule=self.dim_capsule,
+                                 routings=self.routings,
+                                 kernel_size=(self.filters[0], 1),
+                                 share_weights=True)(x)
+        capsule = Flatten()(capsule)
+        capsule = Dropout(self.dropout)(capsule)
+        output = Dense(self.label, activation='sigmoid')(capsule)
+        self.model = Model(inputs=self.word_embedding.input, outputs=output)
+        self.model.summary(120)
+
+    def create_model_basic(self, hyper_parameters):
+        """
+        build the network, the original version, extremely slow
+        :param hyper_parameters: json, hyper parameters of network
+        :return: tensor, model
+        """
+        super().create_model(hyper_parameters)
         embedding = self.word_embedding.output
         embedding = SpatialDropout1D(self.dropout)(embedding)
         embedding_reshape = Reshape((self.len_max, self.embed_size, 1))(embedding)
         conv_pools = []
         for filter in self.filters:
             # Layer 1: Just a conventional Conv2D layer
             conv1 = Conv2D(filters=self.filters_num,
                            kernel_size=(filter, self.embed_size),
-                           strides=1,
+                           strides=(1, 1),
                            padding='valid',
                            activation='relu',)(embedding_reshape)
 
+            conv_shape = K.int_shape(conv1)
+            conv1 = Reshape((conv_shape[1], conv_shape[3], conv_shape[2]))(conv1)
             # Layer 2: Conv2D layer with `squash` activation, then reshape to [None, num_capsule, dim_capsule]
             primarycaps = PrimaryCap(inputs=conv1,
                                      dim_capsule=self.dim_capsule,
                                      n_channels=self.channel_size,
-                                     kernel_size=(self.len_max - filter + 1, 1),
+                                     kernel_size=(filter, conv_shape[3]),
                                      strides=(1, 1),
                                      padding='valid')
             # Layer 3: Capsule layer. Routing algorithm works here.
-            digitcaps = CapsuleLayer(num_capsule=self.label,
-                                     dim_capsule=int(self.dim_capsule * 2),
-                                     routings=self.routings, )(primarycaps)
+            digitcaps = CapsuleLayer(num_capsule=self.num_capsule,  # self.label,
+                                     dim_capsule=int(self.dim_capsule * 2),
+                                     routings=self.routings, )(primarycaps)
             conv_pools.append(digitcaps)
         # concatenate
         x = Concatenate(axis=-1)(conv_pools)
-        # x = Flatten()(x)
-        # dense_layer = Dense(self.label, activation=self.activate_classify)(x)
+        x = Flatten()(x)
+        x = Dropout(self.dropout)(x)
+        dense_layer = Dense(self.label, activation=self.activate_classify)(x)
         # out_caps = [dense_layer]
         # Layer 4: This is an auxiliary layer to replace each capsule with its length. Just to match the true label's shape.
         # If using tensorflow, this will not be necessary. :)
-        out_caps = Length(name='capsnet')(x)
+        # out_caps = Length(name='capsnet')(x)
 
 
         # and finally the softmax
-        self.model = Model(inputs=self.word_embedding.input, outputs=out_caps)
+        self.model = Model(inputs=self.word_embedding.input, outputs=dense_layer)
         self.model.summary(120)
 
-    def create_model_old(self, hyper_parameters):
+    def create_model_margin(self, hyper_parameters):
         """
         build the network
         :param hyper_parameters: json, hyper parameters of network
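For orientation, here is a shape walk-through of the new create_model branch under the sample hyper parameters in the training script below (len_max=50, embed_size=300, filters=[2, 3, 4, 5], filters_num=300, num_capsule=dim_capsule=16, label=17); the numbers follow from those settings and are illustrative, not asserted by the commit:

# embedding              -> (batch, 50, 300)
# SpatialDropout1D       -> (batch, 50, 300)
# Conv1D(kernel_size=k)  -> (batch, 50 - k + 1, 300)  one branch per k in [2, 3, 4, 5]
# Capsule_bojone         -> (batch, 16, 16)           per branch
# Concatenate(axis=-1)   -> (batch, 16, 64)           4 branches stacked on the last axis
# Flatten                -> (batch, 1024)
# Dense(17, 'sigmoid')   -> (batch, 17)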
@@ -32,16 +32,16 @@ def train(hyper_parameters=None, rate=1.0):
     if not hyper_parameters:
         hyper_parameters = {
             'len_max': 50,  # max sentence length, fixed; 20-50 recommended
-            'embed_size': 96,  # char/word embedding dimension
+            'embed_size': 300,  # char/word embedding dimension
             'vocab_size': 20000,  # placeholder, the code overwrites it
             'trainable': False,  # whether the embedding is static or trainable
             'level_type': 'char',  # granularity, the smallest unit, char or word, use 'char' or 'word'
             'embedding_type': 'random',  # embedding type, can be 'random', 'bert' or 'word2vec'
             'model': {'label': 17,  # number of classes
-                      'batch_size': 64,  # batch size
-                      'filters': [3, 4, 5],  # convolution kernel sizes
+                      'batch_size': 16,  # batch size
+                      'filters': [2, 3, 4, 5],  # convolution kernel sizes
                       'filters_num': 300,  # number of filters, text-cnn: 300-600
-                      'channel_size': 3,  # number of CNN channels
+                      'channel_size': 16,  # number of CNN channels
                       'dropout': 0.5,  # dropout rate
                       'decay_step': 100,  # learning-rate decay step, decay once every N steps
                       'decay_rate': 0.9,  # learning-rate decay factor, multiplicative
@@ -56,8 +56,9 @@ def train(hyper_parameters=None, rate=1.0):
                       # model path; basis for saving when loss decreases, save_best_only=True, save_weights_only=True
                       'path_hyper_parameters': path_hyper_parameters,  # path of the model (including embedding) hyper parameters
                       'path_fineture': path_fineture,  # path to save the trainable embedding, e.g. char, word or bert vectors
-                      'routings': 1,
+                      'routings': 5,
                       'dim_capsule': 16,
+                      'num_capsule': 16
                       },
             'embedding': {'layer_indexes': [12],  # which bert layers to take
                           'corpus_path': '',  # path of the embedding pretraining data; if unset, the default path in conf is used
@@ -89,7 +90,7 @@ def train(hyper_parameters=None, rate=1.0):
 
 
 if __name__ == "__main__":
-    train(rate=0.001)  # set to 1 under the sample setting, otherwise the training corpus may be very small
+    train(rate=0.01)  # set to 1 under the sample setting, otherwise the training corpus may be very small
    # Note: on a 4G 080Ti GPU under win10, batch_size=32, len_max=20, gpu<=0.87 should be enough to fine-tune bert.
    # One pass over the full data (batch_size=32) already reaches 80% accuracy (validation set), which is quite good.
    # An error seen on win10; just configure gpu, len_max and batch_size a bit smaller: failed to allocate 3.56G (3822520832 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory