repeat copy now works pretty well

This commit is contained in:
Sam Greydanus 2017-02-21 16:38:44 -05:00
parent a2521c76fa
commit 491bdd5485
38 changed files with 1690 additions and 818 deletions

View File

@ -3,11 +3,11 @@ DNC: Differentiable Neural Computer
Implements DeepMind's third Nature paper, [Hybrid computing using a neural network with dynamic external memory](http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html) by Graves et al. Implements DeepMind's third Nature paper, [Hybrid computing using a neural network with dynamic external memory](http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html) by Graves et al.
![DNC schema](copy/static/dnc_schema.png?raw=true) ![Repeat copy results](static/repeat_copy_results.png?raw=true)
Based on the paper's appendix, I sketched the [computational graph](https://docs.google.com/drawings/d/1Fc9eOH1wPw0PbBHWkEH39jik7h7HT9BWAE8ZhSr4hJc/edit?usp=sharing) Based on the paper's appendix, I sketched the [computational graph](https://docs.google.com/drawings/d/1Fc9eOH1wPw0PbBHWkEH39jik7h7HT9BWAE8ZhSr4hJc/edit?usp=sharing)
Based on Mostafa-Samir's code, I got the copy task to work ([Jupyter notebook](https://nbviewer.jupyter.org/github/greydanus/dnc/blob/master/copy/copy.ipynb)) I got the repeat-copy task to work ([Jupyter notebook](https://nbviewer.jupyter.org/github/greydanus/dnc/blob/master/repeat-copy/repeat-copy-nn.ipynb))
_This is a work in progress_ _This is a work in progress_
-------- --------

View File

@ -63,7 +63,7 @@ class Controller():
''' '''
raise NotImplementedError("nn_step does not exist") raise NotImplementedError("nn_step does not exist")
def zero_state(self): def get_state(self):
''' '''
Returns the initial state of the controller. If the controller is not recurrent, it still needs to return a dummy value Returns the initial state of the controller. If the controller is not recurrent, it still needs to return a dummy value
Returns: LSTMStateTensor or another type of state tensor Returns: LSTMStateTensor or another type of state tensor
@ -78,7 +78,7 @@ class Controller():
controller_dim: the output dimension of the controller controller_dim: the output dimension of the controller
''' '''
test_chi = tf.zeros([self.batch_size, self.chi_dim]) test_chi = tf.zeros([self.batch_size, self.chi_dim])
nn_output, nn_state = self.nn_step(test_chi, state=None) nn_output, nn_state = self.nn_step(test_chi, state=self.get_state())
return nn_output.get_shape().as_list()[-1] return nn_output.get_shape().as_list()[-1]
def prepare_interface(self, zeta_hat): def prepare_interface(self, zeta_hat):

Binary file not shown.

View File

@ -44,7 +44,7 @@ class DNC:
self.X_tensor_array = self.unstack_time_dim(self.X) self.X_tensor_array = self.unstack_time_dim(self.X)
# initialize states # initialize states
nn_state = self.controller.zero_state() nn_state = self.controller.get_state()
dnc_state = self.memory.zero_state() dnc_state = self.memory.zero_state()
# values for which we want a history # values for which we want a history
@ -63,7 +63,8 @@ class DNC:
(_, next_nn_state, next_dnc_state, dnc_hist) = output (_, next_nn_state, next_dnc_state, dnc_hist) = output
# write down the history # write down the history
with tf.control_dependencies(next_dnc_state): controller_dependencies = [self.controller.update_state(next_nn_state)]
with tf.control_dependencies(controller_dependencies):
self.dnc_hist = {self.hist_keys[i]: self.stack_time_dim(v) for i, v in enumerate(dnc_hist)} # convert to dict self.dnc_hist = {self.hist_keys[i]: self.stack_time_dim(v) for i, v in enumerate(dnc_hist)} # convert to dict
def step(self, time, nn_state, dnc_state, dnc_hist): def step(self, time, nn_state, dnc_state, dnc_hist):

Binary file not shown.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,2 +0,0 @@
model_checkpoint_path: "model.ckpt-5000"
all_model_checkpoint_paths: "model.ckpt-5000"

View File

@ -28,5 +28,5 @@ class NNController(Controller):
h2 = tf.nn.elu(z2) h2 = tf.nn.elu(z2)
return h2, state return h2, state
def zero_state(self): def get_state(self):
return LSTMStateTuple(tf.zeros(1), tf.zeros(1)) return LSTMStateTuple(tf.zeros(1), tf.zeros(1))

View File

@ -0,0 +1,2 @@
model_checkpoint_path: "model.ckpt-10000"
all_model_checkpoint_paths: "model.ckpt-10000"

Binary file not shown.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from controller import Controller from controller import Controller
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
""" """
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
@ -10,17 +10,17 @@ A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
class RNNController(Controller): class RNNController(Controller):
def init_controller_params(self): def init_controller_params(self):
rnn_dim = 64 self.rnn_dim = 64
init = tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32) self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
self.params['cell'] = tf.nn.rnn_cell.BasicLSTMCell(rnn_dim, initializer = init) self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
self.params['state'] = tf.Variable(tf.zeros([self.batch_size, rnn_dim]), trainable=False)
self.params['output'] = tf.Variable(tf.zeros([self.batch_size, rnn_dim]), trainable=False)
def nn_step(self, X, state): def nn_step(self, X, state):
X = tf.convert_to_tensor(X) X = tf.convert_to_tensor(X)
return self.params['cell'](X, state) return self.lstm_cell(X, state)
def zero_state(self): def update_state(self, update):
return (self.params['output'], self.params['state']) return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
def get_state(self):
return LSTMStateTuple(self.output, self.state)

Binary file not shown.

View File

@ -0,0 +1,2 @@
model_checkpoint_path: "model.ckpt-1000"
all_model_checkpoint_paths: "model.ckpt-1000"

Binary file not shown.

View File

Before

Width:  |  Height:  |  Size: 117 KiB

After

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 85 KiB