yeah it's lit

This commit is contained in:
Sam Greydanus 2017-02-26 20:45:35 -05:00
parent eec207bfcb
commit 879b556732
30 changed files with 2023 additions and 48 deletions


@@ -77,8 +77,9 @@ class Controller():
Returns: int
controller_dim: the output dimension of the controller
'''
test_chi = tf.zeros([self.batch_size, self.chi_dim])
nn_output, nn_state = self.nn_step(test_chi, state=self.get_state())
with tf.variable_scope("dnc_scope") as scope:
test_chi = tf.zeros([self.batch_size, self.chi_dim])
nn_output, nn_state = self.nn_step(test_chi, state=self.get_state())
return nn_output.get_shape().as_list()[-1]
def prepare_interface(self, zeta_hat):
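
The new tf.variable_scope("dnc_scope") around this dry-run probe appears to exist so that the controller's weights are created once here and then shared by the reuse=True scope that dnc.py opens around its while loop (see below). A minimal sketch of that create-then-reuse pattern, assuming TensorFlow 1.x; the layer and tensor names are illustrative only, not part of the code above:

import tensorflow as tf  # TensorFlow 1.x assumed

def build_layer(x):
    # tf.get_variable either creates "w" or looks it up, depending on
    # whether the enclosing variable_scope has reuse enabled
    w = tf.get_variable("w", shape=[3, 4])
    return tf.matmul(x, w)

x = tf.zeros([2, 3])
with tf.variable_scope("dnc_scope"):               # first call: weights are created
    y1 = build_layer(x)
with tf.variable_scope("dnc_scope", reuse=True):   # later calls: the same weights are shared
    y2 = build_layer(x)

assert len(tf.trainable_variables()) == 1          # one weight matrix, used twice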

Binary file not shown.


@@ -10,7 +10,7 @@ from memory import Memory
import os
class DNC:
def __init__(self, make_controller, FLAGS):
def __init__(self, make_controller, FLAGS, input_steps=None):
'''
Builds a TensorFlow graph for the Differentiable Neural Computer. Uses TensorArrays and a while loop for efficiency
Parameters:
@@ -40,33 +40,61 @@ class DNC:
self.X = tf.placeholder(tf.float32, [batch_size, None, xlen], name='X')
self.y = tf.placeholder(tf.float32, [batch_size, None, ylen], name='y')
self.tsteps = tf.placeholder(tf.int32, name='tsteps')
self.input_steps = input_steps if input_steps is not None else self.tsteps
self.X_tensor_array = self.unstack_time_dim(self.X)
# initialize states
nn_state = self.controller.get_state()
dnc_state = self.memory.zero_state()
self.nn_state = self.controller.get_state()
self.dnc_state = self.memory.zero_state()
# values for which we want a history
self.hist_keys = ['y_hat', 'f', 'g_a', 'g_w', 'w_r', 'w_w', 'u']
dnc_hist = [tf.TensorArray(tf.float32, self.tsteps) for _ in range(len(self.hist_keys))]
dnc_hist = [tf.TensorArray(tf.float32, self.tsteps, clear_after_read=False) for _ in range(len(self.hist_keys))]
# loop through time
with tf.variable_scope("while_loop") as scope:
with tf.variable_scope("dnc_scope", reuse=True) as scope:
time = tf.constant(0, dtype=tf.int32)
output = tf.while_loop(
cond=lambda time, *_: time < self.tsteps,
body=self.step,
loop_vars=(time, nn_state, dnc_state, dnc_hist),
loop_vars=(time, self.nn_state, self.dnc_state, dnc_hist),
)
(_, next_nn_state, next_dnc_state, dnc_hist) = output
(_, self.next_nn_state, self.next_dnc_state, dnc_hist) = output
# write down the history
controller_dependencies = [self.controller.update_state(next_nn_state)]
controller_dependencies = [self.controller.update_state(self.next_nn_state)]
with tf.control_dependencies(controller_dependencies):
self.dnc_hist = {self.hist_keys[i]: self.stack_time_dim(v) for i, v in enumerate(dnc_hist)} # convert to dict
def step2(self, time, nn_state, dnc_state, dnc_hist):
# map from tuple to dict for readability
dnc_state = {self.memory.state_keys[i]: v for i, v in enumerate(dnc_state)}
dnc_hist = {self.hist_keys[i]: v for i, v in enumerate(dnc_hist)}
# one full pass!
X_t = self.X_tensor_array.read(time)
v, zeta, next_nn_state = self.controller.step(X_t, dnc_state['r'], nn_state)
next_dnc_state = self.memory.step(zeta, dnc_state)
y_hat = self.controller.next_y_hat(v, next_dnc_state['r'])
dnc_hist['y_hat'] = dnc_hist['y_hat'].write(time, y_hat)
dnc_hist['f'] = dnc_hist['f'].write(time, zeta['f'])
dnc_hist['g_a'] = dnc_hist['g_a'].write(time, zeta['g_a'])
dnc_hist['g_w'] = dnc_hist['g_w'].write(time, zeta['g_w'])
dnc_hist['w_r'] = dnc_hist['w_r'].write(time, next_dnc_state['w_r'])
dnc_hist['w_w'] = dnc_hist['w_w'].write(time, next_dnc_state['w_w'])
dnc_hist['u'] = dnc_hist['u'].write(time, next_dnc_state['u'])
# map from dict to tuple for tf.while_loop :/
next_dnc_state = [next_dnc_state[k] for k in self.memory.state_keys]
dnc_hist = [dnc_hist[k] for k in self.hist_keys]
time += 1
return time, next_nn_state, next_dnc_state, dnc_hist
def step(self, time, nn_state, dnc_state, dnc_hist):
'''
Performs the feedforward step of the DNC in order to get the DNC output
@@ -88,8 +116,15 @@ class DNC:
dnc_state = {self.memory.state_keys[i]: v for i, v in enumerate(dnc_state)}
dnc_hist = {self.hist_keys[i]: v for i, v in enumerate(dnc_hist)}
def use_prev_output():
y_prev = tf.concat((dnc_hist['y_hat'].read(time-1), tf.zeros([self.batch_size, self.xlen - self.ylen])), axis=1)
return tf.reshape(y_prev, (self.batch_size, self.xlen))
def use_input_array():
return self.X_tensor_array.read(time)
# one full pass!
X_t = self.X_tensor_array.read(time)
X_t = tf.cond(time < self.input_steps, use_input_array, use_prev_output)
v, zeta, next_nn_state = self.controller.step(X_t, dnc_state['r'], nn_state)
next_dnc_state = self.memory.step(zeta, dnc_state)
y_hat = self.controller.next_y_hat(v, next_dnc_state['r'])
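
With the new input_steps argument, the while loop reads real inputs from the TensorArray for the first input_steps steps (the ascii conditioning phase) and afterwards feeds the previous prediction back in, zero-padded up to the input width, exactly as the tf.cond above selects. A rough NumPy sketch of that schedule, with a stand-in step function replacing the controller/memory pass (run_feedback_loop and step_fn are illustrative names, not part of the code above):

import numpy as np

def run_feedback_loop(X, input_steps, step_fn, ylen):
    batch, tsteps, xlen = X.shape
    y_prev = np.zeros((batch, ylen))
    outputs = []
    for t in range(tsteps):
        if t < input_steps:
            x_t = X[:, t, :]                              # conditioning phase: real input
        else:
            pad = np.zeros((batch, xlen - ylen))
            x_t = np.concatenate([y_prev, pad], axis=1)   # generation phase: feed back y_hat
        y_prev = step_fn(x_t)                             # stand-in for controller + memory step
        outputs.append(y_prev)
    return np.stack(outputs, axis=1)

# toy usage: a "step" that just echoes the first ylen input features
out = run_feedback_loop(np.ones((2, 10, 7)), input_steps=4,
                        step_fn=lambda x: x[:, :3], ylen=3)
print(out.shape)  # (2, 10, 3)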

Binary file not shown.


@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 1
}


@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 1
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary file not shown.

File diff suppressed because one or more lines are too long


@@ -0,0 +1,7 @@
Scribe: Realistic Handwriting in TensorFlow
by Sam Greydanus
loaded dataset:
11895 individual data points
5947 batches

Binary file not shown.

handwriting/rnn_controller.py Executable file

@@ -0,0 +1,26 @@
import numpy as np
import tensorflow as tf
from controller import Controller
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
"""
A 1-layer recurrent neural network (LSTM) with 150 hidden units
"""
class RNNController(Controller):
def init_controller_params(self):
self.rnn_dim = 150
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
def nn_step(self, X, state):
X = tf.convert_to_tensor(X)
return self.lstm_cell(X, state)
def update_state(self, update):
return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
def get_state(self):
return LSTMStateTuple(self.output, self.state)
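
RNNController keeps its LSTM state in non-trainable Variables, and update_state returns the assign ops that write the new state back, which is how the state survives between session.run calls. A stripped-down sketch of that pattern, assuming TensorFlow 1.x; the state update here is a stand-in for a real controller step:

import tensorflow as tf  # TensorFlow 1.x assumed

state = tf.Variable(tf.zeros([2, 150]), trainable=False)
new_state = state + 1.0                       # stand-in for one nn_step
update_op = tf.assign(state, new_state)       # what update_state(...) would return

with tf.control_dependencies([update_op]):
    y = tf.identity(new_state)                # anything here runs only after the write

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(y)
    print(sess.run(state)[0, 0])              # 1.0 -- the new state persisted after the run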

Binary file not shown.

handwriting/train.py Normal file

@@ -0,0 +1,142 @@
import tensorflow as tf
import numpy as np
import sys
sys.path.insert(0, '../dnc')
from dnc import DNC
from utils import *
from rnn_controller import RNNController
# hyperparameters
alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
stroke_steps = 175
tf.app.flags.DEFINE_integer("xlen", len(alphabet) + 4, "Input dimension")
tf.app.flags.DEFINE_integer("ylen", 3, "output dimension")
tf.app.flags.DEFINE_integer("stroke_steps", stroke_steps, "Number of time steps for stroke data")
tf.app.flags.DEFINE_integer("ascii_steps", stroke_steps/25, "Sequence length")
tf.app.flags.DEFINE_integer("data_scale", 50, "How to scale stroke data")
tf.app.flags.DEFINE_integer("batch_size", 2, "Size of batch in minibatch gradient descent")
tf.app.flags.DEFINE_integer("R", 1, "Number of DNC read heads")
tf.app.flags.DEFINE_integer("W", 16, "Word length for DNC memory")
tf.app.flags.DEFINE_integer("N", 10, "Number of words the DNC memory can store")
tf.app.flags.DEFINE_integer("train", True, "Train or sample???")
tf.app.flags.DEFINE_integer("print_every", 100, "Print training info after this number of train steps")
tf.app.flags.DEFINE_integer("iterations", 5000000, "Number of training iterations")
tf.app.flags.DEFINE_float("lr", 1e-4, "Learning rate (alpha) for the model")
tf.app.flags.DEFINE_float("momentum", .9, "RMSProp momentum")
tf.app.flags.DEFINE_integer("save_every", 1000, "Save model after this number of train steps")
tf.app.flags.DEFINE_string("save_path", "models/model.ckpt", "Where to save checkpoints")
tf.app.flags.DEFINE_string("data_dir", "data/", "Where to save checkpoints")
tf.app.flags.DEFINE_string("log_dir", "logs/", "Where to find data")
tf.app.flags.DEFINE_string("alphabet", alphabet, "Viable characters")
FLAGS = tf.app.flags.FLAGS
# modify dataloader
def next_batch(FLAGS, dl):
X_batch = []
y_batch = []
text = []
_X_batch, _y_batch, _text, _one_hots = dl.next_batch()
for i in range(FLAGS.batch_size):
ascii_part = np.concatenate((_one_hots[i], np.zeros((FLAGS.ascii_steps, 3))), axis=1)
X_stroke_part = np.concatenate((np.zeros((FLAGS.stroke_steps, len(FLAGS.alphabet)+1)), _X_batch[i]), axis=1)
X = np.concatenate((ascii_part, X_stroke_part), axis=0)
y = np.concatenate((np.zeros((FLAGS.ascii_steps, 3)), _y_batch[i]), axis=0)
X_batch.append(X) ; y_batch.append(y) ; text.append(_text[i])
return [X_batch, y_batch, text]
logger = Logger(FLAGS)
dl = Dataloader(FLAGS, logger, limit = 500)
# helper funcs
def binary_cross_entropy(y_hat, y):
return tf.reduce_mean(-y*tf.log(y_hat) - (1-y)*tf.log(1-y_hat))
def llprint(message):
sys.stdout.write(message)
sys.stdout.flush()
# build graph
sess = tf.InteractiveSession()
llprint("building graph...\n")
optimizer = tf.train.RMSPropOptimizer(FLAGS.lr, momentum=FLAGS.momentum)
dnc = DNC(RNNController, FLAGS, input_steps=FLAGS.ascii_steps)
llprint("defining loss...\n")
y_hat, outputs = dnc.get_outputs()
# TODO: fix this loss: l2 on [:,:,:2] and then do binary cross entropy on <EOS> tags (sketched after this file)
loss = tf.nn.l2_loss(dnc.y - y_hat)*100./(FLAGS.batch_size*(FLAGS.ascii_steps+FLAGS.stroke_steps))
llprint("computing gradients...\n")
gradients = optimizer.compute_gradients(loss)
for i, (grad, var) in enumerate(gradients):
if grad is not None:
gradients[i] = (tf.clip_by_value(grad, -10, 10), var)
grad_op = optimizer.apply_gradients(gradients)
llprint("init variables... \n")
sess.run(tf.global_variables_initializer())
llprint("ready to train...")
# tf parameter overview
total_parameters = 0 ; print "model overview..."
for variable in tf.trainable_variables():
shape = variable.get_shape()
variable_parameters = 1
for dim in shape:
variable_parameters *= dim.value
print '\tvariable "{}" has {} parameters' \
.format(variable.name, variable_parameters)
total_parameters += variable_parameters
print "total of {} parameters".format(total_parameters)
# load saved models
global_step = 0
saver = tf.train.Saver(tf.global_variables())
load_was_success = True # yes, I'm being optimistic
try:
save_dir = '/'.join(FLAGS.save_path.split('/')[:-1])
ckpt = tf.train.get_checkpoint_state(save_dir)
load_path = ckpt.model_checkpoint_path
saver.restore(sess, load_path)
except:
print "no saved model to load."
load_was_success = False
else:
print "loaded model: {}".format(load_path)
saver = tf.train.Saver(tf.global_variables())
global_step = int(load_path.split('-')[-1]) + 1
# train, baby, train
loss_history = []
for i in xrange(global_step, FLAGS.iterations + 1):
llprint("\rIteration {}/{}".format(i, FLAGS.iterations))
X, y, text = next_batch(FLAGS, dl)
tsteps = FLAGS.ascii_steps + FLAGS.stroke_steps
fetch = [loss, grad_op]
feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}
step_loss, _ = sess.run(fetch, feed_dict=feed)
loss_history.append(step_loss)
global_step = i
if i % FLAGS.print_every == 0:
llprint("\n\tloss: {:03.4f}\n".format(np.mean(loss_history)))
loss_history = []
if i % FLAGS.save_every == 0 and i != 0:
llprint("\n\tSAVING MODEL\n")
saver.save(sess, FLAGS.save_path, global_step=global_step)
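
The TODO in the loss definition above calls for L2 on the pen offsets plus binary cross-entropy on the end-of-stroke bit, rather than plain L2 over all three channels. One way that could look, reusing dnc, y_hat, and the binary_cross_entropy helper defined above, and assuming the third output channel is an unsquashed logit; this is a hedged sketch, not the loss the script currently runs:

offset_loss = tf.nn.l2_loss(dnc.y[:, :, :2] - y_hat[:, :, :2])
eos_prob = tf.clip_by_value(tf.sigmoid(y_hat[:, :, 2]), 1e-6, 1. - 1e-6)  # squash <EOS> channel, avoid log(0)
eos_loss = binary_cross_entropy(eos_prob, dnc.y[:, :, 2])
mixed_loss = offset_loss / tf.cast(tf.size(y_hat), tf.float32) + eos_loss  # could stand in for `loss` above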

handwriting/utils.py Normal file

@@ -0,0 +1,203 @@
import numpy as np
import math
import random
import os
import cPickle as pickle
import xml.etree.ElementTree as ET
from utils import *
class Dataloader():
def __init__(self, FLAGS, logger, limit = 500):
self.data_dir = FLAGS.data_dir
self.alphabet = FLAGS.alphabet
self.batch_size = FLAGS.batch_size
self.stroke_steps = FLAGS.stroke_steps
self.data_scale = FLAGS.data_scale # scale data down by this factor
self.ascii_steps = FLAGS.ascii_steps
self.logger = logger
self.limit = limit # removes large noisy gaps in the data
data_file = os.path.join(self.data_dir, "strokes_training_data.cpkl")
stroke_dir = self.data_dir + "/lineStrokes"
ascii_dir = self.data_dir + "/ascii"
if not (os.path.exists(data_file)) :
self.logger.write("\tcreating training data cpkl file from raw source")
self.preprocess(stroke_dir, ascii_dir, data_file)
self.load_preprocessed(data_file)
self.reset_batch_pointer()
def preprocess(self, stroke_dir, ascii_dir, data_file):
# create data file from raw xml files from iam handwriting source.
self.logger.write("\tparsing dataset...")
# build the list of xml files
filelist = []
# Set the directory you want to start from
rootDir = stroke_dir
for dirName, subdirList, fileList in os.walk(rootDir):
for fname in fileList:
filelist.append(dirName+"/"+fname)
# function to read each individual xml file
def getStrokes(filename):
tree = ET.parse(filename)
root = tree.getroot()
result = []
x_offset = 1e20
y_offset = 1e20
y_height = 0
for i in range(1, 4):
x_offset = min(x_offset, float(root[0][i].attrib['x']))
y_offset = min(y_offset, float(root[0][i].attrib['y']))
y_height = max(y_height, float(root[0][i].attrib['y']))
y_height -= y_offset
x_offset -= 100
y_offset -= 100
for stroke in root[1].findall('Stroke'):
points = []
for point in stroke.findall('Point'):
points.append([float(point.attrib['x'])-x_offset,float(point.attrib['y'])-y_offset])
result.append(points)
return result
# function to read each individual xml file
def getAscii(filename, line_number):
with open(filename, "r") as f:
s = f.read()
s = s[s.find("CSR"):]
if len(s.split("\n")) > line_number+2:
s = s.split("\n")[line_number+2]
return s
else:
return ""
# converts a list of arrays into a 2d numpy int16 array
def convert_stroke_to_array(stroke):
n_point = 0
for i in range(len(stroke)):
n_point += len(stroke[i])
stroke_data = np.zeros((n_point, 3), dtype=np.int16)
prev_x = 0
prev_y = 0
counter = 0
for j in range(len(stroke)):
for k in range(len(stroke[j])):
stroke_data[counter, 0] = int(stroke[j][k][0]) - prev_x
stroke_data[counter, 1] = int(stroke[j][k][1]) - prev_y
prev_x = int(stroke[j][k][0])
prev_y = int(stroke[j][k][1])
stroke_data[counter, 2] = 0
if (k == (len(stroke[j])-1)): # end of stroke
stroke_data[counter, 2] = 1
counter += 1
return stroke_data
# build stroke database of every xml file inside iam database
strokes = []
asciis = []
for i in range(len(filelist)):
if (filelist[i][-3:] == 'xml'):
stroke_file = filelist[i]
# print 'processing '+stroke_file
stroke = convert_stroke_to_array(getStrokes(stroke_file))
ascii_file = stroke_file.replace("lineStrokes","ascii")[:-7] + ".txt"
line_number = stroke_file[-6:-4]
line_number = int(line_number) - 1
ascii = getAscii(ascii_file, line_number)
if len(ascii) > 10:
strokes.append(stroke)
asciis.append(ascii)
else:
self.logger.write("\tline length was too short. line was: " + ascii)
assert(len(strokes)==len(asciis)), "There should be a 1:1 correspondence between stroke data and ascii labels."
f = open(data_file,"wb")
pickle.dump([strokes,asciis], f, protocol=2)
f.close()
self.logger.write("\tfinished parsing dataset. saved {} lines".format(len(strokes)))
def load_preprocessed(self, data_file):
f = open(data_file,"rb")
[self.raw_stroke_data, self.raw_ascii_data] = pickle.load(f)
f.close()
# goes through the list and only keeps the stroke entries that have more than stroke_steps points
self.stroke_data = []
self.ascii_data = []
counter = 0
for i in range(len(self.raw_stroke_data)):
data = self.raw_stroke_data[i]
if len(data) > (self.stroke_steps+2):
# removes large gaps from the data
data = np.minimum(data, self.limit)
data = np.maximum(data, -self.limit)
data = np.array(data,dtype=np.float32)
data[:,0:2] /= self.data_scale
self.stroke_data.append(data)
self.ascii_data.append(self.raw_ascii_data[i])
# minus 1, since we want the ydata to be a shifted version of x data
self.num_batches = int(len(self.stroke_data) / self.batch_size)
self.logger.write("\tloaded dataset:")
self.logger.write("\t\t{} individual data points".format(len(self.stroke_data)))
self.logger.write("\t\t{} batches".format(self.num_batches))
def next_batch(self):
# returns a randomized, stroke_steps-sized portion of the training data
x_batch = []
y_batch = []
ascii_list = []
for i in xrange(self.batch_size):
data = self.stroke_data[self.idx_perm[self.pointer]]
idx = random.randint(0, len(data)-self.stroke_steps-2)
x_batch.append(np.copy(data[:self.stroke_steps]))
y_batch.append(np.copy(data[1:self.stroke_steps+1]))
ascii_list.append(self.ascii_data[self.idx_perm[self.pointer]][:self.ascii_steps])
self.tick_batch_pointer()
one_hots = [to_one_hot(s, self.ascii_steps, self.alphabet) for s in ascii_list]
return x_batch, y_batch, ascii_list, one_hots
def tick_batch_pointer(self):
self.pointer += 1
if (self.pointer >= len(self.stroke_data)):
self.reset_batch_pointer()
def reset_batch_pointer(self):
self.idx_perm = np.random.permutation(len(self.stroke_data))
self.pointer = 0
# utility function for converting input ascii characters into vectors the network can understand.
# index position 0 means "unknown"
def to_one_hot(s, ascii_steps, alphabet):
steplimit = 3000; s = s[:steplimit] if len(s) > steplimit else s # clip super-long strings
seq = [alphabet.find(char) + 1 for char in s]
if len(seq) >= ascii_steps:
seq = seq[:ascii_steps]
else:
seq = seq + [0]*(ascii_steps - len(seq))
one_hot = np.zeros((ascii_steps,len(alphabet)+1))
one_hot[np.arange(ascii_steps),seq] = 1
return one_hot
# abstraction for logging
class Logger():
def __init__(self, FLAGS):
self.logf = '{}train_scribe.txt'.format(FLAGS.log_dir) if FLAGS.train else '{}sample_scribe.txt'.format(FLAGS.log_dir)
with open(self.logf, 'w') as f: f.write("Scribe: Realistic Handwriting in TensorFlow\n by Sam Greydanus\n\n\n")
def write(self, s, print_it=True):
if print_it:
print s
with open(self.logf, 'a') as f:
f.write(s + "\n")
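
A quick usage check of the to_one_hot helper above: characters map to alphabet.find(c) + 1, index 0 is reserved for unknown characters, and strings are clipped or zero-padded out to ascii_steps rows. This snippet assumes to_one_hot and numpy (as np) are in scope as in utils.py:

alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
one_hot = to_one_hot("hi!", ascii_steps=5, alphabet=alphabet)
print(one_hot.shape)                 # (5, 53)
print(np.argmax(one_hot, axis=1))    # [8 9 0 0 0] -- 'h'->8, 'i'->9, '!' and padding land on "unknown"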

handwriting/utils.pyc Normal file

Binary file not shown.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary file not shown.


@@ -1,2 +1,4 @@
model_checkpoint_path: "model.ckpt-7000"
all_model_checkpoint_paths: "model.ckpt-7000"
model_checkpoint_path: "model.ckpt-10000"
all_model_checkpoint_paths: "model.ckpt-8000"
all_model_checkpoint_paths: "model.ckpt-9000"
all_model_checkpoint_paths: "model.ckpt-10000"

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.