# Differentiable Neural Computer
# inspired by (http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html)
# some ideas taken from https://github.com/Mostafa-Samir/DNC-tensorflow
# Sam Greydanus. February 2017. MIT License.

import tensorflow as tf
import numpy as np
from memory import Memory
import os
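# A rough usage sketch (hypothetical names: `MyController` is a user-defined Controller subclass and
# `FLAGS` supplies the fields listed in DNC.__init__ below):
#
#     dnc = DNC(MyController, FLAGS)
#     y_hat, dnc_hist = dnc.get_outputs()
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         prediction = sess.run(y_hat, feed_dict={dnc.X: X_batch, dnc.tsteps: n_steps})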
class DNC:
    def __init__(self, make_controller, FLAGS, input_steps=None):
        '''
        Builds a TensorFlow graph for the Differentiable Neural Computer. Uses TensorArrays and a while loop for efficiency.

        Parameters:
        ----------
        make_controller: Controller
            A class which inherits from the Controller class; the controller object is constructed inside this function
        FLAGS: a set of TensorFlow FlagValues which must include
            FLAGS.xlen: the length of the input vector of the controller
            FLAGS.ylen: the length of the output vector of the controller
            FLAGS.batch_size: the number of sequences in each training batch
            FLAGS.R: the number of DNC read heads
            FLAGS.W: the DNC "word length" (length of each DNC memory vector)
            FLAGS.N: the number of DNC word vectors (corresponds to memory size)
        '''
        self.xlen = xlen = FLAGS.xlen
        self.ylen = ylen = FLAGS.ylen
        self.batch_size = batch_size = FLAGS.batch_size
        self.R = R = FLAGS.R
        self.W = W = FLAGS.W
        self.N = N = FLAGS.N

        # create 1) the DNC's memory and 2) the DNC's controller
        self.memory = Memory(R, W, N, batch_size)
        self.controller = make_controller(FLAGS)

        # input data placeholders
        self.X = tf.placeholder(tf.float32, [batch_size, None, xlen], name='X')
        self.y = tf.placeholder(tf.float32, [batch_size, None, ylen], name='y')
        self.tsteps = tf.placeholder(tf.int32, name='tsteps')
        self.input_steps = input_steps if input_steps is not None else self.tsteps

        self.X_tensor_array = self.unstack_time_dim(self.X)

        # initialize states
        self.nn_state = self.controller.get_state()
        self.dnc_state = self.memory.zero_state()

        # values for which we want a history
        self.hist_keys = ['y_hat', 'f', 'g_a', 'g_w', 'w_r', 'w_w', 'u']
        dnc_hist = [tf.TensorArray(tf.float32, self.tsteps, clear_after_read=False) for _ in range(len(self.hist_keys))]
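        # note: clear_after_read=False lets the same element be read more than once
        # (e.g. y_hat[time-1] inside the loop and again when the history is stacked at the end)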
        # loop through time
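        # (tf.while_loop carries the time step, the controller state, the DNC memory state, and the
        # history TensorArrays through self.step until time reaches tsteps)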
        with tf.variable_scope("dnc_scope", reuse=True) as scope:
            time = tf.constant(0, dtype=tf.int32)

            output = tf.while_loop(
                cond=lambda time, *_: time < self.tsteps,
                body=self.step,
                loop_vars=(time, self.nn_state, self.dnc_state, dnc_hist),
                )
            (_, self.next_nn_state, self.next_dnc_state, dnc_hist) = output

        # write down the history
        controller_dependencies = [self.controller.update_state(self.next_nn_state)]
        with tf.control_dependencies(controller_dependencies):
            self.dnc_hist = {self.hist_keys[i]: self.stack_time_dim(v) for i, v in enumerate(dnc_hist)}  # convert to dict

    def step2(self, time, nn_state, dnc_state, dnc_hist):
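        # simpler variant of step(): it always reads X_t from the input TensorArray and never feeds
        # the previous prediction back in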
        # map from tuple to dict for readability
        dnc_state = {self.memory.state_keys[i]: v for i, v in enumerate(dnc_state)}
        dnc_hist = {self.hist_keys[i]: v for i, v in enumerate(dnc_hist)}

        # one full pass!
        X_t = self.X_tensor_array.read(time)
        v, zeta, next_nn_state = self.controller.step(X_t, dnc_state['r'], nn_state)
        next_dnc_state = self.memory.step(zeta, dnc_state)
        y_hat = self.controller.next_y_hat(v, next_dnc_state['r'])

        dnc_hist['y_hat'] = dnc_hist['y_hat'].write(time, y_hat)
        dnc_hist['f'] = dnc_hist['f'].write(time, zeta['f'])
        dnc_hist['g_a'] = dnc_hist['g_a'].write(time, zeta['g_a'])
        dnc_hist['g_w'] = dnc_hist['g_w'].write(time, zeta['g_w'])
        dnc_hist['w_r'] = dnc_hist['w_r'].write(time, next_dnc_state['w_r'])
        dnc_hist['w_w'] = dnc_hist['w_w'].write(time, next_dnc_state['w_w'])
        dnc_hist['u'] = dnc_hist['u'].write(time, next_dnc_state['u'])

        # map from dict to tuple for tf.while_loop :/
        next_dnc_state = [next_dnc_state[k] for k in self.memory.state_keys]
        dnc_hist = [dnc_hist[k] for k in self.hist_keys]

        time += 1
        return time, next_nn_state, next_dnc_state, dnc_hist

    def step(self, time, nn_state, dnc_state, dnc_hist):
        '''
        Performs the feedforward step of the DNC in order to get the DNC output

        Parameters:
        ----------
        time: scalar int32 Tensor
            the current time step of the model
        nn_state: LSTMStateTuple or another type of state tensor
            for the controller network
        dnc_state: Tuple
            set of 7 Tensors which define the current state of the DNC (M, u, p, L, w_w, w_r, r) ...see paper
        dnc_hist: Tuple
            set of 7 TensorArrays which track the historical states of the DNC (y_hat, f, g_a, g_w, w_r, w_w, u). Good for visualization
        Returns: Tuple
            same as input parameters, but updated for the current time step
        '''
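        # For reference, in the paper these are: the memory matrix M (N x W), usage vector u (N),
        # precedence weighting p (N), temporal link matrix L (N x N), write weighting w_w (N),
        # R read weightings w_r (each length N) and R read vectors r (each length W); the exact
        # tensor layouts here depend on memory.py's conventions.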
        # map from tuple to dict for readability
        dnc_state = {self.memory.state_keys[i]: v for i, v in enumerate(dnc_state)}
        dnc_hist = {self.hist_keys[i]: v for i, v in enumerate(dnc_hist)}

        def use_prev_output():
            y_prev = tf.concat((dnc_hist['y_hat'].read(time-1), tf.zeros([self.batch_size, self.xlen - self.ylen])), axis=1)
            return tf.reshape(y_prev, (self.batch_size, self.xlen))

        def use_input_array():
            return self.X_tensor_array.read(time)
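        # after self.input_steps steps, the DNC is fed its own previous prediction
        # (zero-padded from ylen up to xlen) instead of the next input vector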
        # one full pass!
        X_t = tf.cond(time < self.input_steps, use_input_array, use_prev_output)
        v, zeta, next_nn_state = self.controller.step(X_t, dnc_state['r'], nn_state)
        next_dnc_state = self.memory.step(zeta, dnc_state)
        y_hat = self.controller.next_y_hat(v, next_dnc_state['r'])

        dnc_hist['y_hat'] = dnc_hist['y_hat'].write(time, y_hat)
        dnc_hist['f'] = dnc_hist['f'].write(time, zeta['f'])
        dnc_hist['g_a'] = dnc_hist['g_a'].write(time, zeta['g_a'])
        dnc_hist['g_w'] = dnc_hist['g_w'].write(time, zeta['g_w'])
        dnc_hist['w_r'] = dnc_hist['w_r'].write(time, next_dnc_state['w_r'])
        dnc_hist['w_w'] = dnc_hist['w_w'].write(time, next_dnc_state['w_w'])
        dnc_hist['u'] = dnc_hist['u'].write(time, next_dnc_state['u'])

        # map from dict to tuple for tf.while_loop :/
        next_dnc_state = [next_dnc_state[k] for k in self.memory.state_keys]
        dnc_hist = [dnc_hist[k] for k in self.hist_keys]

        time += 1
        return time, next_nn_state, next_dnc_state, dnc_hist

    def get_outputs(self):
        '''
        Allows user to access the output of the DNC after all time steps have been executed

        Returns: tuple
        y_hat: Tensor (batch_size, tsteps, ylen)
            The DNC's output at every time step
        dnc_hist: Dictionary
            Set of Tensors which contain the values of (y_hat, f, g_a, g_w, w_r, w_w, u) for all time steps
        '''
        return self.dnc_hist['y_hat'], self.dnc_hist

    def stack_time_dim(self, v):
        '''
        Stacks a TensorArray along its time dimension, then transposes so that the time dimension is at index [1]

        Parameters:
        ----------
        v: TensorArray [(batch_size, ...), ...]
            An array of n-dimensional tensors where, for each, the first dimension is the batch dimension

        Returns:
        u: Tensor (batch_size, tsteps, ...)
            The stacked tensor with index [1] as the time dimension
        '''
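        # e.g. a TensorArray holding tsteps tensors of shape (batch_size, W) becomes a single
        # Tensor of shape (batch_size, tsteps, W)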
        stacked = v.stack()
        return tf.transpose(stacked, [1, 0] + list(range(2, len(stacked.get_shape()))))

    def unstack_time_dim(self, v):
        '''
        Unstacks a Tensor along its time dimension into a TensorArray

        Parameters:
        ----------
        v: Tensor (batch_size, tsteps, ...)
            An n-dimensional tensor where dim[0] is the batch dimension and dim[1] is the time dimension

        Returns:
        u: TensorArray [(batch_size, ...), ...]
            An array of n-dimensional tensors where, for each, the first dimension is the batch dimension
        '''
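        # e.g. a Tensor of shape (batch_size, tsteps, W) becomes a TensorArray of tsteps tensors,
        # each of shape (batch_size, W)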
        array = tf.TensorArray(dtype=v.dtype, size=self.tsteps)
        make_time_dim_first = [1, 0] + list(range(2, len(v.get_shape())))
        v_T = tf.transpose(v, make_time_dim_first)
        return array.unstack(v_T)