# Differentiable Neural Computer
# inspired by (http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html)
# some ideas taken from https://github.com/Mostafa-Samir/DNC-tensorflow
# Sam Greydanus. February 2017. MIT License.

import tensorflow as tf
import numpy as np
|
class Controller():
    '''
    Abstract base class defining how a neural network "controller" interacts
    with the DNC, as described in the DNC Nature paper.

    Subclasses must implement init_controller_params, nn_step, and zero_state.
    '''

    def __init__(self, FLAGS):
        '''
        An interface that defines how the neural network "controller" interacts with the DNC

        Parameters:
        ----------
        FLAGS: a set of TensorFlow FlagValues which must include
            FLAGS.xlen: the length of the input vector of the controller
            FLAGS.ylen: the length of the output vector of the controller
            FLAGS.batch_size: the number of batches
            FLAGS.R: the number of DNC read heads
            FLAGS.W: the DNC "word length" (length of each DNC memory vector)
        '''
        self.xlen = FLAGS.xlen
        self.ylen = FLAGS.ylen
        self.batch_size = FLAGS.batch_size
        self.R = R = FLAGS.R
        self.W = W = FLAGS.W

        # chi: controller input = raw input concatenated with the R read
        # vectors (each of length W) from the previous time step
        self.chi_dim = self.xlen + self.W * self.R
        # zeta: the full DNC interface vector; see prepare_interface for the
        # breakdown of each component (keys, strengths, gates, read modes)
        self.zeta_dim = W*R + 3*W + 5*R + 3
        # flattened length of the R read vectors
        self.r_dim = W*R

        # define network vars
        self.params = {}
        with tf.name_scope("controller"):
            self.init_controller_params()
            self.controller_dim = self.get_controller_dim()

            init = tf.truncated_normal_initializer(stddev=0.075, dtype=tf.float32)
            # W_z: controller output -> interface vector (pre-activation)
            self.params['W_z'] = tf.get_variable("W_z", [self.controller_dim, self.zeta_dim], initializer=init)
            # W_v: controller output -> the controller's direct output contribution
            self.params['W_v'] = tf.get_variable("W_v", [self.controller_dim, self.ylen], initializer=init)
            # W_r: flattened read vectors -> output contribution (used in next_y_hat)
            self.params['W_r'] = tf.get_variable("W_r", [self.r_dim, self.ylen], initializer=init)

    def init_controller_params(self):
        '''
        Initializes all the parameters of the neural network controller
        '''
        raise NotImplementedError("init_controller_params must be implemented by a subclass")

    def nn_step(self, chi, state):
        '''
        Performs the feedforward step of the controller in order to get the DNC interface vector, zeta

        Parameters:
        ----------
        chi: Tensor (batch_size, chi_dim)
            the input concatenated with the previous output of the DNC
        state: LSTMStateTensor or another type of state tensor

        Returns: Tuple
            zeta_hat: Tensor (batch_size, controller_dim)
            next_state: LSTMStateTensor or another type of state tensor
        '''
        raise NotImplementedError("nn_step must be implemented by a subclass")

    def zero_state(self):
        '''
        Returns the initial state of the controller. If the controller is not recurrent, it still needs to return a dummy value

        Returns: LSTMStateTensor or another type of state tensor
            nn_state: LSTMStateTensor or another type of state tensor
        '''
        # bug fix: the original message named the wrong method ("get_state")
        raise NotImplementedError("zero_state must be implemented by a subclass")

    def get_controller_dim(self):
        '''
        Feeds zeros through the controller and obtains an output in order to find the controller's output dimension

        Returns: int
            controller_dim: the output dimension of the controller
        '''
        test_chi = tf.zeros([self.batch_size, self.chi_dim])
        nn_output, nn_state = self.nn_step(test_chi, state=None)
        return nn_output.get_shape().as_list()[-1]

    def prepare_interface(self, zeta_hat):
        '''
        Packages the interface vector, zeta, as a dictionary of variables as described in the DNC Nature paper

        Parameters:
        ----------
        zeta_hat: Tensor (batch_size, zeta_dim)
            the interface vector before processing, zeta_hat

        Returns: dict
            zeta: variable names (string) mapping to tensors (Tensor)
        '''
        zeta = {}
        R, W = self.R, self.W
        # offsets of each interface component inside zeta_hat, in order:
        # read keys, read strengths, write key, write strength, erase vector,
        # write vector, free gates, allocation gate, write gate, read modes
        splits = np.cumsum([0,W*R,R,W,1,W,W,R,1,1,3*R])
        vs = [zeta_hat[:, splits[i]:splits[i+1]] for i in range(len(splits)-1)]

        kappa_r = tf.reshape(vs[0], (-1, W, R))  # read keys
        beta_r = tf.reshape(vs[1], (-1, R))      # read strengths
        kappa_w = tf.reshape(vs[2], (-1, W, 1))  # write key
        beta_w = tf.reshape(vs[3], (-1, 1))      # write strength
        e = tf.reshape(vs[4], (-1, W))           # erase vector
        v = tf.reshape(vs[5], (-1, W))           # write vector
        f = tf.reshape(vs[6], (-1, R))           # free gates
        g_a = vs[7]                              # allocation gate
        g_w = vs[8]                              # write gate
        pi = tf.reshape(vs[9], (-1, 3, R))       # read modes

        # restrict each component to its valid domain (Nature paper):
        # strengths to [1, inf) via oneplus, gates to (0, 1) via sigmoid,
        # read modes to a distribution over the 3 modes via softmax
        zeta['kappa_r'] = kappa_r
        zeta['beta_r'] = 1 + tf.nn.softplus(beta_r)
        zeta['kappa_w'] = kappa_w
        zeta['beta_w'] = 1 + tf.nn.softplus(beta_w)
        zeta['e'] = tf.nn.sigmoid(e)
        zeta['v'] = v
        zeta['f'] = tf.nn.sigmoid(f)
        zeta['g_a'] = tf.nn.sigmoid(g_a)
        zeta['g_w'] = tf.nn.sigmoid(g_w)
        zeta['pi'] = tf.nn.softmax(pi, 1)

        return zeta

    def step(self, X, r_prev, state):
        '''
        The sum of operations executed by the Controller at a given time step before interfacing with the DNC

        Parameters:
        ----------
        X: Tensor (batch_size, xlen)
            the input for this time step
        r_prev: previous output of the DNC
        state: LSTMStateTensor or another type of state tensor

        Returns: Tuple
            v: Tensor (batch_size, ylen)
                The controller's output (eventually added elementwise to the DNC output)
            zeta: The processed interface vector which the network will use to interact with the DNC
            nn_state: LSTMStateTensor or another type of state tensor
        '''
        r_prev = tf.reshape(r_prev, (-1, self.r_dim)) # flatten
        chi = tf.concat([X, r_prev], 1)
        nn_output, nn_state = self.nn_step(chi, state)

        v = tf.matmul(nn_output, self.params['W_v'])
        zeta_hat = tf.matmul(nn_output, self.params['W_z'])
        zeta = self.prepare_interface(zeta_hat)
        return v, zeta, nn_state

    def next_y_hat(self, v, r):
        '''
        The sum of operations executed by the Controller at a given time step after interacting with the DNC

        Parameters:
        ----------
        v: Tensor (batch_size, ylen)
            The controller's output (added elementwise to the DNC output)
        r: the current output of the DNC (the R read vectors)

        Returns: Tensor (batch_size, ylen)
            y_hat: The DNC's output
        '''
        r = tf.reshape(r, (-1, self.W * self.R)) # flatten
        y_hat = v + tf.matmul(r, self.params['W_r'])
        return y_hat
|