serial-recall task

This commit is contained in:
Sam Greydanus 2017-02-28 22:11:13 -05:00
parent 879b556732
commit 3a4e8a124e
40 changed files with 2161 additions and 194 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -4,4 +4,4 @@ Scribe: Realistic Handriting in Tensorflow
loaded dataset: loaded dataset:
11895 individual data points 11895 individual data points
5947 batches 743 batches

View File

@ -0,0 +1,6 @@
model_checkpoint_path: "model.ckpt-34000"
all_model_checkpoint_paths: "model.ckpt-30000"
all_model_checkpoint_paths: "model.ckpt-31000"
all_model_checkpoint_paths: "model.ckpt-32000"
all_model_checkpoint_paths: "model.ckpt-33000"
all_model_checkpoint_paths: "model.ckpt-34000"

Binary file not shown.

Binary file not shown.

View File

@ -4,13 +4,13 @@ from controller import Controller
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
""" """
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes RNN (cell type LSTM) with 128 hidden layers
""" """
class RNNController(Controller): class RNNController(Controller):
def init_controller_params(self): def init_controller_params(self):
self.rnn_dim = 150 self.rnn_dim = 300
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim) self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False) self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False) self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)

Binary file not shown.

View File

@ -15,11 +15,11 @@ tf.app.flags.DEFINE_integer("ylen", 3, "output dimension")
tf.app.flags.DEFINE_integer("stroke_steps", stroke_steps, "Number of time steps for stroke data") tf.app.flags.DEFINE_integer("stroke_steps", stroke_steps, "Number of time steps for stroke data")
tf.app.flags.DEFINE_integer("ascii_steps", stroke_steps/25, "Sequence length") tf.app.flags.DEFINE_integer("ascii_steps", stroke_steps/25, "Sequence length")
tf.app.flags.DEFINE_integer("data_scale", 50, "How to scale stroke data") tf.app.flags.DEFINE_integer("data_scale", 50, "How to scale stroke data")
tf.app.flags.DEFINE_integer("batch_size", 2, "Size of batch in minibatch gradient descent") tf.app.flags.DEFINE_integer("batch_size", 16, "Size of batch in minibatch gradient descent")
tf.app.flags.DEFINE_integer("R", 1, "Number of DNC read heads") tf.app.flags.DEFINE_integer("R", 1, "Number of DNC read heads")
tf.app.flags.DEFINE_integer("W", 16, "Word length for DNC memory") tf.app.flags.DEFINE_integer("W", 100, "Word length for DNC memory")
tf.app.flags.DEFINE_integer("N", 10, "Number of words the DNC memory can store") tf.app.flags.DEFINE_integer("N", 8, "Number of words the DNC memory can store")
tf.app.flags.DEFINE_integer("train", True, "Train or sample???") tf.app.flags.DEFINE_integer("train", True, "Train or sample???")
tf.app.flags.DEFINE_integer("print_every", 100, "Print training info after this number of train steps") tf.app.flags.DEFINE_integer("print_every", 100, "Print training info after this number of train steps")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -4,13 +4,13 @@ from controller import Controller
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
""" """
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes RNN (cell type LSTM) with 128 hidden layers
""" """
class RNNController(Controller): class RNNController(Controller):
def init_controller_params(self): def init_controller_params(self):
self.rnn_dim = 64 self.rnn_dim = 128
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim) self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False) self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False) self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)

Binary file not shown.

View File

@ -1,4 +1,4 @@
model_checkpoint_path: "model.ckpt-10000" model_checkpoint_path: "model.ckpt-6000"
all_model_checkpoint_paths: "model.ckpt-8000" all_model_checkpoint_paths: "model.ckpt-4000"
all_model_checkpoint_paths: "model.ckpt-9000" all_model_checkpoint_paths: "model.ckpt-5000"
all_model_checkpoint_paths: "model.ckpt-10000" all_model_checkpoint_paths: "model.ckpt-6000"

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 1
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary file not shown.

26
serial-recall/rnn_controller.py Executable file
View File

@ -0,0 +1,26 @@
import numpy as np
import tensorflow as tf
from controller import Controller
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
"""
RNN (cell type LSTM) with 128 hidden units
"""
class RNNController(Controller):
def init_controller_params(self):
self.rnn_dim = 128
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
def nn_step(self, X, state):
X = tf.convert_to_tensor(X)
return self.lstm_cell(X, state)
def update_state(self, update):
return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
def get_state(self):
return LSTMStateTuple(self.output, self.state)

Binary file not shown.

View File

@ -0,0 +1,4 @@
model_checkpoint_path: "model.ckpt-10000"
all_model_checkpoint_paths: "model.ckpt-8000"
all_model_checkpoint_paths: "model.ckpt-9000"
all_model_checkpoint_paths: "model.ckpt-10000"

File diff suppressed because one or more lines are too long

Binary file not shown.

Before

Width:  |  Height:  |  Size: 85 KiB

After

Width:  |  Height:  |  Size: 83 KiB