repeat copy now works pretty well
This commit is contained in:
parent
a2521c76fa
commit
491bdd5485
@ -3,11 +3,11 @@ DNC: Differentiable Neural Computer
|
|||||||
|
|
||||||
Implements DeepMind's third nature paper, [Hybrid computing using a neural network with dynamic external memory](http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html) by Graves et. al.
|
Implements DeepMind's third nature paper, [Hybrid computing using a neural network with dynamic external memory](http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html) by Graves et. al.
|
||||||
|
|
||||||
![DNC schema](copy/static/dnc_schema.png?raw=true)
|
![Repeat copy results](static/repeat_copy_results.png?raw=true)
|
||||||
|
|
||||||
Based on the paper's appendix, I sketched the [computational graph](https://docs.google.com/drawings/d/1Fc9eOH1wPw0PbBHWkEH39jik7h7HT9BWAE8ZhSr4hJc/edit?usp=sharing)
|
Based on the paper's appendix, I sketched the [computational graph](https://docs.google.com/drawings/d/1Fc9eOH1wPw0PbBHWkEH39jik7h7HT9BWAE8ZhSr4hJc/edit?usp=sharing)
|
||||||
|
|
||||||
Based on Mostafa-Samir's code, I got the copy task to work ([Jupyter notebook](https://nbviewer.jupyter.org/github/greydanus/dnc/blob/master/copy/copy.ipynb))
|
I got the repeat-copy copy task to work ([Jupyter notebook](https://nbviewer.jupyter.org/github/greydanus/dnc/blob/master/repeat-copy/repeat-copy-nn.ipynb))
|
||||||
|
|
||||||
_This is a work in progress_
|
_This is a work in progress_
|
||||||
--------
|
--------
|
||||||
|
@ -63,7 +63,7 @@ class Controller():
|
|||||||
'''
|
'''
|
||||||
raise NotImplementedError("nn_step does not exist")
|
raise NotImplementedError("nn_step does not exist")
|
||||||
|
|
||||||
def zero_state(self):
|
def get_state(self):
|
||||||
'''
|
'''
|
||||||
Returns the initial state of the controller. If the controller is not recurrent, it still needs to return a dummy value
|
Returns the initial state of the controller. If the controller is not recurrent, it still needs to return a dummy value
|
||||||
Returns: LSTMStateTensor or another type of state tensor
|
Returns: LSTMStateTensor or another type of state tensor
|
||||||
@ -78,7 +78,7 @@ class Controller():
|
|||||||
controller_dim: the output dimension of the controller
|
controller_dim: the output dimension of the controller
|
||||||
'''
|
'''
|
||||||
test_chi = tf.zeros([self.batch_size, self.chi_dim])
|
test_chi = tf.zeros([self.batch_size, self.chi_dim])
|
||||||
nn_output, nn_state = self.nn_step(test_chi, state=None)
|
nn_output, nn_state = self.nn_step(test_chi, state=self.get_state())
|
||||||
return nn_output.get_shape().as_list()[-1]
|
return nn_output.get_shape().as_list()[-1]
|
||||||
|
|
||||||
def prepare_interface(self, zeta_hat):
|
def prepare_interface(self, zeta_hat):
|
||||||
|
Binary file not shown.
@ -44,7 +44,7 @@ class DNC:
|
|||||||
self.X_tensor_array = self.unstack_time_dim(self.X)
|
self.X_tensor_array = self.unstack_time_dim(self.X)
|
||||||
|
|
||||||
# initialize states
|
# initialize states
|
||||||
nn_state = self.controller.zero_state()
|
nn_state = self.controller.get_state()
|
||||||
dnc_state = self.memory.zero_state()
|
dnc_state = self.memory.zero_state()
|
||||||
|
|
||||||
# values for which we want a history
|
# values for which we want a history
|
||||||
@ -63,7 +63,8 @@ class DNC:
|
|||||||
(_, next_nn_state, next_dnc_state, dnc_hist) = output
|
(_, next_nn_state, next_dnc_state, dnc_hist) = output
|
||||||
|
|
||||||
# write down the history
|
# write down the history
|
||||||
with tf.control_dependencies(next_dnc_state):
|
controller_dependencies = [self.controller.update_state(next_nn_state)]
|
||||||
|
with tf.control_dependencies(controller_dependencies):
|
||||||
self.dnc_hist = {self.hist_keys[i]: self.stack_time_dim(v) for i, v in enumerate(dnc_hist)} # convert to dict
|
self.dnc_hist = {self.hist_keys[i]: self.stack_time_dim(v) for i, v in enumerate(dnc_hist)} # convert to dict
|
||||||
|
|
||||||
def step(self, time, nn_state, dnc_state, dnc_hist):
|
def step(self, time, nn_state, dnc_state, dnc_hist):
|
||||||
|
BIN
dnc/dnc.pyc
BIN
dnc/dnc.pyc
Binary file not shown.
File diff suppressed because one or more lines are too long
404
repeat-copy/.ipynb_checkpoints/repeat-copy-nn-checkpoint.ipynb
Normal file
404
repeat-copy/.ipynb_checkpoints/repeat-copy-nn-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
405
repeat-copy/.ipynb_checkpoints/repeat-copy-rnn-checkpoint.ipynb
Normal file
405
repeat-copy/.ipynb_checkpoints/repeat-copy-rnn-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -1,2 +0,0 @@
|
|||||||
model_checkpoint_path: "model.ckpt-5000"
|
|
||||||
all_model_checkpoint_paths: "model.ckpt-5000"
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -28,5 +28,5 @@ class NNController(Controller):
|
|||||||
h2 = tf.nn.elu(z2)
|
h2 = tf.nn.elu(z2)
|
||||||
return h2, state
|
return h2, state
|
||||||
|
|
||||||
def zero_state(self):
|
def get_state(self):
|
||||||
return LSTMStateTuple(tf.zeros(1), tf.zeros(1))
|
return LSTMStateTuple(tf.zeros(1), tf.zeros(1))
|
||||||
|
2
repeat-copy/nn_models/checkpoint
Normal file
2
repeat-copy/nn_models/checkpoint
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
model_checkpoint_path: "model.ckpt-10000"
|
||||||
|
all_model_checkpoint_paths: "model.ckpt-10000"
|
BIN
repeat-copy/nn_models/model.ckpt-10000.data-00000-of-00001
Normal file
BIN
repeat-copy/nn_models/model.ckpt-10000.data-00000-of-00001
Normal file
Binary file not shown.
BIN
repeat-copy/nn_models/model.ckpt-10000.index
Normal file
BIN
repeat-copy/nn_models/model.ckpt-10000.index
Normal file
Binary file not shown.
Binary file not shown.
453
repeat-copy/repeat-copy-nn.ipynb
Normal file
453
repeat-copy/repeat-copy-nn.ipynb
Normal file
File diff suppressed because one or more lines are too long
405
repeat-copy/repeat-copy-rnn.ipynb
Normal file
405
repeat-copy/repeat-copy-rnn.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,7 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from controller import Controller
|
from controller import Controller
|
||||||
|
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||||
|
|
||||||
"""
|
"""
|
||||||
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
|
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
|
||||||
@ -10,17 +10,17 @@ A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
|
|||||||
class RNNController(Controller):
|
class RNNController(Controller):
|
||||||
|
|
||||||
def init_controller_params(self):
|
def init_controller_params(self):
|
||||||
rnn_dim = 64
|
self.rnn_dim = 64
|
||||||
init = tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32)
|
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
|
||||||
|
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||||
self.params['cell'] = tf.nn.rnn_cell.BasicLSTMCell(rnn_dim, initializer = init)
|
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||||
self.params['state'] = tf.Variable(tf.zeros([self.batch_size, rnn_dim]), trainable=False)
|
|
||||||
self.params['output'] = tf.Variable(tf.zeros([self.batch_size, rnn_dim]), trainable=False)
|
|
||||||
|
|
||||||
|
|
||||||
def nn_step(self, X, state):
|
def nn_step(self, X, state):
|
||||||
X = tf.convert_to_tensor(X)
|
X = tf.convert_to_tensor(X)
|
||||||
return self.params['cell'](X, state)
|
return self.lstm_cell(X, state)
|
||||||
|
|
||||||
def zero_state(self):
|
def update_state(self, update):
|
||||||
return (self.params['output'], self.params['state'])
|
return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return LSTMStateTuple(self.output, self.state)
|
||||||
|
BIN
repeat-copy/rnn_controller.pyc
Normal file
BIN
repeat-copy/rnn_controller.pyc
Normal file
Binary file not shown.
2
repeat-copy/rnn_models/checkpoint
Normal file
2
repeat-copy/rnn_models/checkpoint
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
model_checkpoint_path: "model.ckpt-1000"
|
||||||
|
all_model_checkpoint_paths: "model.ckpt-1000"
|
BIN
repeat-copy/rnn_models/model.ckpt-1000.data-00000-of-00001
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-1000.data-00000-of-00001
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-1000.index
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-1000.index
Normal file
Binary file not shown.
Binary file not shown.
Before Width: | Height: | Size: 117 KiB After Width: | Height: | Size: 117 KiB |
BIN
static/repeat_copy_results.png
Normal file
BIN
static/repeat_copy_results.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 85 KiB |
Loading…
Reference in New Issue
Block a user