# dnc-jupyter/copy/dnc/controller.py
# (162 lines, 6.4 KiB, Python; snapshot dated 2017-02-20 05:11:57 +08:00)
# Differentiable Neural Computer
# inspired by (http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html)
# some ideas taken from https://github.com/Mostafa-Samir/DNC-tensorflow
# Sam Greydanus. February 2017. MIT License.
import tensorflow as tf
import numpy as np
class Controller():
    def __init__(self, FLAGS):
        '''
        An interface that defines how the neural network "controller" interacts with the DNC

        Parameters:
        ----------
        FLAGS: a set of TensorFlow FlagValues which must include
            FLAGS.xlen: the length of the input vector of the controller
            FLAGS.ylen: the length of the output vector of the controller
            FLAGS.batch_size: the number of batches
            FLAGS.R: the number of DNC read heads
            FLAGS.W: the DNC "word length" (length of each DNC memory vector)
        '''
        self.xlen = FLAGS.xlen
        self.ylen = FLAGS.ylen
        self.batch_size = FLAGS.batch_size
        self.R = R = FLAGS.R
        self.W = W = FLAGS.W

        # chi is the controller input: external input concatenated with the
        # R previous read vectors of length W each
        self.chi_dim = self.xlen + self.W * self.R
        # zeta is the DNC interface vector; its length is the sum of all the
        # interface components split out in prepare_interface below
        self.zeta_dim = W*R + 3*W + 5*R + 3
        # flattened length of the R read vectors
        self.r_dim = W*R

        # define network vars
        self.params = {}
        with tf.name_scope("controller"):
            self.init_controller_params()
            self.controller_dim = self.get_controller_dim()
            init = tf.truncated_normal_initializer(stddev=0.075, dtype=tf.float32)
            # W_z: controller output -> interface vector zeta_hat
            # W_v: controller output -> DNC output contribution v
            # W_r: flattened read vectors -> DNC output contribution
            self.params['W_z'] = tf.get_variable("W_z", [self.controller_dim, self.zeta_dim], initializer=init)
            self.params['W_v'] = tf.get_variable("W_v", [self.controller_dim, self.ylen], initializer=init)
            self.params['W_r'] = tf.get_variable("W_r", [self.r_dim, self.ylen], initializer=init)

    def init_controller_params(self):
        '''
        Initializes all the parameters of the neural network controller.
        Subclasses must override this.
        '''
        raise NotImplementedError("init_controller_params does not exist")

    def nn_step(self, chi, state):
        '''
        Performs the feedforward step of the controller in order to get the DNC interface vector, zeta.
        Subclasses must override this.

        Parameters:
        ----------
        chi: Tensor (batch_size, chi_dim)
            the input concatenated with the previous output of the DNC
        state: LSTMStateTensor or another type of state tensor
        Returns: Tuple
            zeta_hat: Tensor (batch_size, controller_dim)
            next_state: LSTMStateTensor or another type of state tensor
        '''
        raise NotImplementedError("nn_step does not exist")

    def zero_state(self):
        '''
        Returns the initial state of the controller. If the controller is not
        recurrent, it still needs to return a dummy value. Subclasses must
        override this.

        Returns: LSTMStateTensor or another type of state tensor
            nn_state: LSTMStateTensor or another type of state tensor
        '''
        # NOTE: message previously said "get_state does not exist", which
        # named a method that is not defined on this class
        raise NotImplementedError("zero_state does not exist")

    def get_controller_dim(self):
        '''
        Feeds zeros through the controller and obtains an output in order to
        find the controller's output dimension

        Returns: int
            controller_dim: the output dimension of the controller
        '''
        test_chi = tf.zeros([self.batch_size, self.chi_dim])
        nn_output, nn_state = self.nn_step(test_chi, state=None)
        return nn_output.get_shape().as_list()[-1]

    def prepare_interface(self, zeta_hat):
        '''
        Packages the interface vector, zeta, as a dictionary of variables as
        described in the DNC Nature paper

        Parameters:
        ----------
        zeta_hat: Tensor (batch_size, zeta_dim)
            the interface vector before processing, zeta_hat
        Returns: dict
            zeta: variable names (string) mapping to tensors (Tensor)
        '''
        zeta = {}
        R, W = self.R, self.W
        # component widths sum to zeta_dim = W*R + 3*W + 5*R + 3
        splits = np.cumsum([0,W*R,R,W,1,W,W,R,1,1,3*R])
        vs = [zeta_hat[:, splits[i]:splits[i+1]] for i in range(len(splits)-1)]
        kappa_r = tf.reshape(vs[0], (-1, W, R))  # read keys
        beta_r = tf.reshape(vs[1], (-1, R))      # read strengths
        kappa_w = tf.reshape(vs[2], (-1, W, 1))  # write key
        beta_w = tf.reshape(vs[3], (-1, 1))      # write strength
        e = tf.reshape(vs[4], (-1, W))           # erase vector
        v = tf.reshape(vs[5], (-1, W))           # write vector
        f = tf.reshape(vs[6], (-1, R))           # free gates
        g_a = vs[7]                              # allocation gate
        g_w = vs[8]                              # write gate
        pi = tf.reshape(vs[9], (-1, 3, R))       # read modes
        zeta['kappa_r'] = kappa_r
        # oneplus: strengths must lie in [1, inf)
        zeta['beta_r'] = 1 + tf.nn.softplus(beta_r)
        zeta['kappa_w'] = kappa_w
        zeta['beta_w'] = 1 + tf.nn.softplus(beta_w)
        # gates and erase vector are squashed into [0, 1]
        zeta['e'] = tf.nn.sigmoid(e)
        zeta['v'] = v
        zeta['f'] = tf.nn.sigmoid(f)
        zeta['g_a'] = tf.nn.sigmoid(g_a)
        zeta['g_w'] = tf.nn.sigmoid(g_w)
        # softmax over the 3 read modes (backward, content, forward) per head
        zeta['pi'] = tf.nn.softmax(pi, 1)
        return zeta

    def step(self, X, r_prev, state):
        '''
        The sum of operations executed by the Controller at a given time step
        before interfacing with the DNC

        Parameters:
        ----------
        X: Tensor (batch_size, chi_dim)
            the input for this time step
        r_prev: previous output of the DNC
        state: LSTMStateTensor or another type of state tensor
        Returns: Tuple
            v: Tensor (batch_size, ylen)
                The controller's output (eventually added elementwise to the DNC output)
            zeta: The processed interface vector which the network will use to interact with the DNC
            nn_state: LSTMStateTensor or another type of state tensor
        '''
        r_prev = tf.reshape(r_prev, (-1, self.r_dim)) # flatten
        chi = tf.concat([X, r_prev], 1)
        nn_output, nn_state = self.nn_step(chi, state)
        v = tf.matmul(nn_output, self.params['W_v'])
        zeta_hat = tf.matmul(nn_output, self.params['W_z'])
        zeta = self.prepare_interface(zeta_hat)
        return v, zeta, nn_state

    def next_y_hat(self, v, r):
        '''
        The sum of operations executed by the Controller at a given time step
        after interacting with the DNC

        Parameters:
        ----------
        v: Tensor (batch_size, ylen)
            The controller's output (added elementwise to the DNC output)
        r: the current read vectors of the DNC
        Returns: Tensor (batch_size, ylen)
            y_hat: Tensor (batch_size, ylen)
                The DNC's output
        '''
        r = tf.reshape(r, (-1, self.W * self.R)) # flatten
        y_hat = v + tf.matmul(r, self.params['W_r'])
        return y_hat