serial-recall task
This commit is contained in:
parent
879b556732
commit
3a4e8a124e
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -4,4 +4,4 @@ Scribe: Realistic Handriting in Tensorflow
|
||||
|
||||
loaded dataset:
|
||||
11895 individual data points
|
||||
5947 batches
|
||||
743 batches
|
||||
|
6
handwriting/models/checkpoint
Normal file
6
handwriting/models/checkpoint
Normal file
@ -0,0 +1,6 @@
|
||||
model_checkpoint_path: "model.ckpt-34000"
|
||||
all_model_checkpoint_paths: "model.ckpt-30000"
|
||||
all_model_checkpoint_paths: "model.ckpt-31000"
|
||||
all_model_checkpoint_paths: "model.ckpt-32000"
|
||||
all_model_checkpoint_paths: "model.ckpt-33000"
|
||||
all_model_checkpoint_paths: "model.ckpt-34000"
|
BIN
handwriting/models/model.ckpt-34000.data-00000-of-00001
Normal file
BIN
handwriting/models/model.ckpt-34000.data-00000-of-00001
Normal file
Binary file not shown.
BIN
handwriting/models/model.ckpt-34000.index
Normal file
BIN
handwriting/models/model.ckpt-34000.index
Normal file
Binary file not shown.
BIN
handwriting/models/model.ckpt-34000.meta
Normal file
BIN
handwriting/models/model.ckpt-34000.meta
Normal file
Binary file not shown.
@ -4,13 +4,13 @@ from controller import Controller
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||
|
||||
"""
|
||||
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
|
||||
RNN (cell type LSTM) with 128 hidden layers
|
||||
"""
|
||||
|
||||
class RNNController(Controller):
|
||||
|
||||
def init_controller_params(self):
|
||||
self.rnn_dim = 150
|
||||
self.rnn_dim = 300
|
||||
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
|
||||
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
|
Binary file not shown.
@ -15,11 +15,11 @@ tf.app.flags.DEFINE_integer("ylen", 3, "output dimension")
|
||||
tf.app.flags.DEFINE_integer("stroke_steps", stroke_steps, "Number of time steps for stroke data")
|
||||
tf.app.flags.DEFINE_integer("ascii_steps", stroke_steps/25, "Sequence length")
|
||||
tf.app.flags.DEFINE_integer("data_scale", 50, "How to scale stroke data")
|
||||
tf.app.flags.DEFINE_integer("batch_size", 2, "Size of batch in minibatch gradient descent")
|
||||
tf.app.flags.DEFINE_integer("batch_size", 16, "Size of batch in minibatch gradient descent")
|
||||
|
||||
tf.app.flags.DEFINE_integer("R", 1, "Number of DNC read heads")
|
||||
tf.app.flags.DEFINE_integer("W", 16, "Word length for DNC memory")
|
||||
tf.app.flags.DEFINE_integer("N", 10, "Number of words the DNC memory can store")
|
||||
tf.app.flags.DEFINE_integer("W", 100, "Word length for DNC memory")
|
||||
tf.app.flags.DEFINE_integer("N", 8, "Number of words the DNC memory can store")
|
||||
|
||||
tf.app.flags.DEFINE_integer("train", True, "Train or sample???")
|
||||
tf.app.flags.DEFINE_integer("print_every", 100, "Print training info after this number of train steps")
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -4,13 +4,13 @@ from controller import Controller
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||
|
||||
"""
|
||||
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
|
||||
RNN (cell type LSTM) with 128 hidden layers
|
||||
"""
|
||||
|
||||
class RNNController(Controller):
|
||||
|
||||
def init_controller_params(self):
|
||||
self.rnn_dim = 64
|
||||
self.rnn_dim = 128
|
||||
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
|
||||
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
|
Binary file not shown.
@ -1,4 +1,4 @@
|
||||
model_checkpoint_path: "model.ckpt-10000"
|
||||
all_model_checkpoint_paths: "model.ckpt-8000"
|
||||
all_model_checkpoint_paths: "model.ckpt-9000"
|
||||
all_model_checkpoint_paths: "model.ckpt-10000"
|
||||
model_checkpoint_path: "model.ckpt-6000"
|
||||
all_model_checkpoint_paths: "model.ckpt-4000"
|
||||
all_model_checkpoint_paths: "model.ckpt-5000"
|
||||
all_model_checkpoint_paths: "model.ckpt-6000"
|
||||
|
BIN
repeat-copy/rnn_models/model.ckpt-6000.data-00000-of-00001
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-6000.data-00000-of-00001
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-6000.index
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-6000.index
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-6000.meta
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-6000.meta
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
6
serial-recall/.ipynb_checkpoints/debug-checkpoint.ipynb
Normal file
6
serial-recall/.ipynb_checkpoints/debug-checkpoint.ipynb
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"cells": [],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
380
serial-recall/.ipynb_checkpoints/repeat-copy-nn-checkpoint.ipynb
Normal file
380
serial-recall/.ipynb_checkpoints/repeat-copy-nn-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
610
serial-recall/.ipynb_checkpoints/visualization-checkpoint.ipynb
Executable file
610
serial-recall/.ipynb_checkpoints/visualization-checkpoint.ipynb
Executable file
File diff suppressed because one or more lines are too long
BIN
serial-recall/nn_controller.pyc
Normal file
BIN
serial-recall/nn_controller.pyc
Normal file
Binary file not shown.
26
serial-recall/rnn_controller.py
Executable file
26
serial-recall/rnn_controller.py
Executable file
@ -0,0 +1,26 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from controller import Controller
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||
|
||||
"""
|
||||
RNN (cell type LSTM) with 128 hidden units
|
||||
"""
|
||||
|
||||
class RNNController(Controller):
|
||||
|
||||
def init_controller_params(self):
|
||||
self.rnn_dim = 128
|
||||
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
|
||||
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
|
||||
def nn_step(self, X, state):
|
||||
X = tf.convert_to_tensor(X)
|
||||
return self.lstm_cell(X, state)
|
||||
|
||||
def update_state(self, update):
|
||||
return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
|
||||
|
||||
def get_state(self):
|
||||
return LSTMStateTuple(self.output, self.state)
|
BIN
serial-recall/rnn_controller.pyc
Normal file
BIN
serial-recall/rnn_controller.pyc
Normal file
Binary file not shown.
4
serial-recall/rnn_models/checkpoint
Normal file
4
serial-recall/rnn_models/checkpoint
Normal file
@ -0,0 +1,4 @@
|
||||
model_checkpoint_path: "model.ckpt-10000"
|
||||
all_model_checkpoint_paths: "model.ckpt-8000"
|
||||
all_model_checkpoint_paths: "model.ckpt-9000"
|
||||
all_model_checkpoint_paths: "model.ckpt-10000"
|
557
serial-recall/serial-recall-rnn.ipynb
Normal file
557
serial-recall/serial-recall-rnn.ipynb
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
Before Width: | Height: | Size: 85 KiB After Width: | Height: | Size: 83 KiB |
Loading…
Reference in New Issue
Block a user