yeah it's lit
This commit is contained in:
parent
eec207bfcb
commit
879b556732
@ -77,8 +77,9 @@ class Controller():
|
||||
Returns: int
|
||||
controller_dim: the output dimension of the controller
|
||||
'''
|
||||
test_chi = tf.zeros([self.batch_size, self.chi_dim])
|
||||
nn_output, nn_state = self.nn_step(test_chi, state=self.get_state())
|
||||
with tf.variable_scope("dnc_scope") as scope:
|
||||
test_chi = tf.zeros([self.batch_size, self.chi_dim])
|
||||
nn_output, nn_state = self.nn_step(test_chi, state=self.get_state())
|
||||
return nn_output.get_shape().as_list()[-1]
|
||||
|
||||
def prepare_interface(self, zeta_hat):
|
||||
|
Binary file not shown.
53
dnc/dnc.py
53
dnc/dnc.py
@ -10,7 +10,7 @@ from memory import Memory
|
||||
import os
|
||||
|
||||
class DNC:
|
||||
def __init__(self, make_controller, FLAGS):
|
||||
def __init__(self, make_controller, FLAGS, input_steps=None):
|
||||
'''
|
||||
Builds a TensorFlow graph for the Differentiable Neural Computer. Uses TensorArrays and a while loop for efficiency
|
||||
Parameters:
|
||||
@ -40,33 +40,61 @@ class DNC:
|
||||
self.X = tf.placeholder(tf.float32, [batch_size, None, xlen], name='X')
|
||||
self.y = tf.placeholder(tf.float32, [batch_size, None, ylen], name='y')
|
||||
self.tsteps = tf.placeholder(tf.int32, name='tsteps')
|
||||
self.input_steps = input_steps if input_steps is not None else self.tsteps
|
||||
|
||||
self.X_tensor_array = self.unstack_time_dim(self.X)
|
||||
|
||||
# initialize states
|
||||
nn_state = self.controller.get_state()
|
||||
dnc_state = self.memory.zero_state()
|
||||
self.nn_state = self.controller.get_state()
|
||||
self.dnc_state = self.memory.zero_state()
|
||||
|
||||
# values for which we want a history
|
||||
self.hist_keys = ['y_hat', 'f', 'g_a', 'g_w', 'w_r', 'w_w', 'u']
|
||||
dnc_hist = [tf.TensorArray(tf.float32, self.tsteps) for _ in range(len(self.hist_keys))]
|
||||
dnc_hist = [tf.TensorArray(tf.float32, self.tsteps, clear_after_read=False) for _ in range(len(self.hist_keys))]
|
||||
|
||||
# loop through time
|
||||
with tf.variable_scope("while_loop") as scope:
|
||||
with tf.variable_scope("dnc_scope", reuse=True) as scope:
|
||||
time = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
output = tf.while_loop(
|
||||
cond=lambda time, *_: time < self.tsteps,
|
||||
body=self.step,
|
||||
loop_vars=(time, nn_state, dnc_state, dnc_hist),
|
||||
loop_vars=(time, self.nn_state, self.dnc_state, dnc_hist),
|
||||
)
|
||||
(_, next_nn_state, next_dnc_state, dnc_hist) = output
|
||||
(_, self.next_nn_state, self.next_dnc_state, dnc_hist) = output
|
||||
|
||||
# write down the history
|
||||
controller_dependencies = [self.controller.update_state(next_nn_state)]
|
||||
controller_dependencies = [self.controller.update_state(self.next_nn_state)]
|
||||
with tf.control_dependencies(controller_dependencies):
|
||||
self.dnc_hist = {self.hist_keys[i]: self.stack_time_dim(v) for i, v in enumerate(dnc_hist)} # convert to dict
|
||||
|
||||
def step2(self, time, nn_state, dnc_state, dnc_hist):
|
||||
|
||||
# map from tuple to dict for readability
|
||||
dnc_state = {self.memory.state_keys[i]: v for i, v in enumerate(dnc_state)}
|
||||
dnc_hist = {self.hist_keys[i]: v for i, v in enumerate(dnc_hist)}
|
||||
|
||||
# one full pass!
|
||||
X_t = self.X_tensor_array.read(time)
|
||||
v, zeta, next_nn_state = self.controller.step(X_t, dnc_state['r'], nn_state)
|
||||
next_dnc_state = self.memory.step(zeta, dnc_state)
|
||||
y_hat = self.controller.next_y_hat(v, next_dnc_state['r'])
|
||||
|
||||
dnc_hist['y_hat'] = dnc_hist['y_hat'].write(time, y_hat)
|
||||
dnc_hist['f'] = dnc_hist['f'].write(time, zeta['f'])
|
||||
dnc_hist['g_a'] = dnc_hist['g_a'].write(time, zeta['g_a'])
|
||||
dnc_hist['g_w'] = dnc_hist['g_w'].write(time, zeta['g_w'])
|
||||
dnc_hist['w_r'] = dnc_hist['w_r'].write(time, next_dnc_state['w_r'])
|
||||
dnc_hist['w_w'] = dnc_hist['w_w'].write(time, next_dnc_state['w_w'])
|
||||
dnc_hist['u'] = dnc_hist['u'].write(time, next_dnc_state['u'])
|
||||
|
||||
# map from dict to tuple for tf.while_loop :/
|
||||
next_dnc_state = [next_dnc_state[k] for k in self.memory.state_keys]
|
||||
dnc_hist = [dnc_hist[k] for k in self.hist_keys]
|
||||
|
||||
time += 1
|
||||
return time, next_nn_state, next_dnc_state, dnc_hist
|
||||
|
||||
def step(self, time, nn_state, dnc_state, dnc_hist):
|
||||
'''
|
||||
Performs the feedforward step of the DNC in order to get the DNC output
|
||||
@ -88,8 +116,15 @@ class DNC:
|
||||
dnc_state = {self.memory.state_keys[i]: v for i, v in enumerate(dnc_state)}
|
||||
dnc_hist = {self.hist_keys[i]: v for i, v in enumerate(dnc_hist)}
|
||||
|
||||
def use_prev_output():
|
||||
y_prev = tf.concat((dnc_hist['y_hat'].read(time-1), tf.zeros([self.batch_size, self.xlen - self.ylen])), axis=1)
|
||||
return tf.reshape(y_prev, (self.batch_size, self.xlen))
|
||||
|
||||
def use_input_array():
|
||||
return self.X_tensor_array.read(time)
|
||||
|
||||
# one full pass!
|
||||
X_t = self.X_tensor_array.read(time)
|
||||
X_t = tf.cond(time < self.input_steps, use_input_array, use_prev_output)
|
||||
v, zeta, next_nn_state = self.controller.step(X_t, dnc_state['r'], nn_state)
|
||||
next_dnc_state = self.memory.step(zeta, dnc_state)
|
||||
y_hat = self.controller.next_y_hat(v, next_dnc_state['r'])
|
||||
|
BIN
dnc/dnc.pyc
BIN
dnc/dnc.pyc
Binary file not shown.
6
handwriting/.ipynb_checkpoints/Untitled-checkpoint.ipynb
Normal file
6
handwriting/.ipynb_checkpoints/Untitled-checkpoint.ipynb
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"cells": [],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
6
handwriting/.ipynb_checkpoints/debug-checkpoint.ipynb
Normal file
6
handwriting/.ipynb_checkpoints/debug-checkpoint.ipynb
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"cells": [],
|
||||
"metadata": {},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
453
handwriting/.ipynb_checkpoints/handwriting-rnn-checkpoint.ipynb
Normal file
453
handwriting/.ipynb_checkpoints/handwriting-rnn-checkpoint.ipynb
Normal file
File diff suppressed because one or more lines are too long
610
handwriting/.ipynb_checkpoints/visualization-checkpoint.ipynb
Executable file
610
handwriting/.ipynb_checkpoints/visualization-checkpoint.ipynb
Executable file
File diff suppressed because one or more lines are too long
BIN
handwriting/data/strokes_training_data.cpkl
Normal file
BIN
handwriting/data/strokes_training_data.cpkl
Normal file
Binary file not shown.
453
handwriting/handwriting-rnn.ipynb
Normal file
453
handwriting/handwriting-rnn.ipynb
Normal file
File diff suppressed because one or more lines are too long
7
handwriting/logs/train_scribe.txt
Normal file
7
handwriting/logs/train_scribe.txt
Normal file
@ -0,0 +1,7 @@
|
||||
Scribe: Realistic Handriting in Tensorflow
|
||||
by Sam Greydanus
|
||||
|
||||
|
||||
loaded dataset:
|
||||
11895 individual data points
|
||||
5947 batches
|
BIN
handwriting/nn_controller.pyc
Normal file
BIN
handwriting/nn_controller.pyc
Normal file
Binary file not shown.
26
handwriting/rnn_controller.py
Executable file
26
handwriting/rnn_controller.py
Executable file
@ -0,0 +1,26 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from controller import Controller
|
||||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import LSTMStateTuple
|
||||
|
||||
"""
|
||||
A 1-Layer recurrent neural network (LSTM) with 64 hidden nodes
|
||||
"""
|
||||
|
||||
class RNNController(Controller):
|
||||
|
||||
def init_controller_params(self):
|
||||
self.rnn_dim = 150
|
||||
self.lstm_cell = tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(self.rnn_dim)
|
||||
self.state = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
self.output = tf.Variable(tf.zeros([self.batch_size, self.rnn_dim]), trainable=False)
|
||||
|
||||
def nn_step(self, X, state):
|
||||
X = tf.convert_to_tensor(X)
|
||||
return self.lstm_cell(X, state)
|
||||
|
||||
def update_state(self, update):
|
||||
return tf.group(self.output.assign(update[0]), self.state.assign(update[1]))
|
||||
|
||||
def get_state(self):
|
||||
return LSTMStateTuple(self.output, self.state)
|
BIN
handwriting/rnn_controller.pyc
Normal file
BIN
handwriting/rnn_controller.pyc
Normal file
Binary file not shown.
142
handwriting/train.py
Normal file
142
handwriting/train.py
Normal file
@ -0,0 +1,142 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import sys
|
||||
sys.path.insert(0, '../dnc')
|
||||
|
||||
from dnc import DNC
|
||||
from utils import *
|
||||
from rnn_controller import RNNController
|
||||
|
||||
# hyperparameters
|
||||
alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
stroke_steps = 175
|
||||
tf.app.flags.DEFINE_integer("xlen", len(alphabet) + 4, "Input dimension")
|
||||
tf.app.flags.DEFINE_integer("ylen", 3, "output dimension")
|
||||
tf.app.flags.DEFINE_integer("stroke_steps", stroke_steps, "Number of time steps for stroke data")
|
||||
tf.app.flags.DEFINE_integer("ascii_steps", stroke_steps/25, "Sequence length")
|
||||
tf.app.flags.DEFINE_integer("data_scale", 50, "How to scale stroke data")
|
||||
tf.app.flags.DEFINE_integer("batch_size", 2, "Size of batch in minibatch gradient descent")
|
||||
|
||||
tf.app.flags.DEFINE_integer("R", 1, "Number of DNC read heads")
|
||||
tf.app.flags.DEFINE_integer("W", 16, "Word length for DNC memory")
|
||||
tf.app.flags.DEFINE_integer("N", 10, "Number of words the DNC memory can store")
|
||||
|
||||
tf.app.flags.DEFINE_integer("train", True, "Train or sample???")
|
||||
tf.app.flags.DEFINE_integer("print_every", 100, "Print training info after this number of train steps")
|
||||
tf.app.flags.DEFINE_integer("iterations", 5000000, "Number of training iterations")
|
||||
tf.app.flags.DEFINE_float("lr", 1e-4, "Learning rate (alpha) for the model")
|
||||
tf.app.flags.DEFINE_float("momentum", .9, "RMSProp momentum")
|
||||
tf.app.flags.DEFINE_integer("save_every", 1000, "Save model after this number of train steps")
|
||||
tf.app.flags.DEFINE_string("save_path", "models/model.ckpt", "Where to save checkpoints")
|
||||
tf.app.flags.DEFINE_string("data_dir", "data/", "Where to save checkpoints")
|
||||
tf.app.flags.DEFINE_string("log_dir", "logs/", "Where to find data")
|
||||
tf.app.flags.DEFINE_string("alphabet", alphabet, "Viable characters")
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
|
||||
# modify dataloader
|
||||
def next_batch(FLAGS, dl):
|
||||
X_batch = []
|
||||
y_batch = []
|
||||
text = []
|
||||
_X_batch, _y_batch, _text, _one_hots = dl.next_batch()
|
||||
for i in range(FLAGS.batch_size):
|
||||
ascii_part = np.concatenate((_one_hots[i], np.zeros((FLAGS.ascii_steps, 3))), axis=1)
|
||||
X_stroke_part = np.concatenate((np.zeros((FLAGS.stroke_steps, len(FLAGS.alphabet)+1)), _X_batch[i]), axis=1)
|
||||
|
||||
X = np.concatenate((ascii_part, X_stroke_part), axis=0)
|
||||
y = np.concatenate((np.zeros((FLAGS.ascii_steps, 3)), _y_batch[i]), axis=0)
|
||||
X_batch.append(X) ; y_batch.append(y) ; text.append(_text)
|
||||
return [X_batch, y_batch, text]
|
||||
|
||||
logger = Logger(FLAGS)
|
||||
dl = Dataloader(FLAGS, logger, limit = 500)
|
||||
|
||||
|
||||
# helper funcs
|
||||
def binary_cross_entropy(y_hat, y):
|
||||
return tf.reduce_mean(-y*tf.log(y_hat) - (1-y)*tf.log(1-y_hat))
|
||||
|
||||
def llprint(message):
|
||||
sys.stdout.write(message)
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
# build graph
|
||||
sess = tf.InteractiveSession()
|
||||
|
||||
llprint("building graph...\n")
|
||||
optimizer = tf.train.RMSPropOptimizer(FLAGS.lr, momentum=FLAGS.momentum)
|
||||
dnc = DNC(RNNController, FLAGS, input_steps=FLAGS.ascii_steps)
|
||||
|
||||
llprint("defining loss...\n")
|
||||
y_hat, outputs = dnc.get_outputs()
|
||||
# TODO: fix this loss: l2 on [:,:,:2] and then do binary cross entropy on <EOS> tags
|
||||
loss = tf.nn.l2_loss(dnc.y - y_hat)*100./(FLAGS.batch_size*(FLAGS.ascii_steps+FLAGS.stroke_steps))
|
||||
|
||||
llprint("computing gradients...\n")
|
||||
gradients = optimizer.compute_gradients(loss)
|
||||
for i, (grad, var) in enumerate(gradients):
|
||||
if grad is not None:
|
||||
gradients[i] = (tf.clip_by_value(grad, -10, 10), var)
|
||||
|
||||
grad_op = optimizer.apply_gradients(gradients)
|
||||
|
||||
llprint("init variables... \n")
|
||||
sess.run(tf.global_variables_initializer())
|
||||
llprint("ready to train...")
|
||||
|
||||
|
||||
# model overiew
|
||||
# tf parameter overview
|
||||
total_parameters = 0 ; print "model overview..."
|
||||
for variable in tf.trainable_variables():
|
||||
shape = variable.get_shape()
|
||||
variable_parameters = 1
|
||||
for dim in shape:
|
||||
variable_parameters *= dim.value
|
||||
print '\tvariable "{}" has {} parameters' \
|
||||
.format(variable.name, variable_parameters)
|
||||
total_parameters += variable_parameters
|
||||
print "total of {} parameters".format(total_parameters)
|
||||
|
||||
|
||||
# load saved models
|
||||
global_step = 0
|
||||
saver = tf.train.Saver(tf.global_variables())
|
||||
load_was_success = True # yes, I'm being optimistic
|
||||
try:
|
||||
save_dir = '/'.join(FLAGS.save_path.split('/')[:-1])
|
||||
ckpt = tf.train.get_checkpoint_state(save_dir)
|
||||
load_path = ckpt.model_checkpoint_path
|
||||
saver.restore(sess, load_path)
|
||||
except:
|
||||
print "no saved model to load."
|
||||
load_was_success = False
|
||||
else:
|
||||
print "loaded model: {}".format(load_path)
|
||||
saver = tf.train.Saver(tf.global_variables())
|
||||
global_step = int(load_path.split('-')[-1]) + 1
|
||||
|
||||
|
||||
# train, baby, train
|
||||
loss_history = []
|
||||
for i in xrange(global_step, FLAGS.iterations + 1):
|
||||
llprint("\rIteration {}/{}".format(i, FLAGS.iterations))
|
||||
|
||||
X, y, text = next_batch(FLAGS, dl)
|
||||
tsteps = FLAGS.ascii_steps + FLAGS.stroke_steps
|
||||
|
||||
fetch = [loss, grad_op]
|
||||
feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}
|
||||
|
||||
step_loss, _ = sess.run(fetch, feed_dict=feed)
|
||||
loss_history.append(step_loss)
|
||||
global_step = i
|
||||
|
||||
if i % 100 == 0:
|
||||
llprint("\n\tloss: {:03.4f}\n".format(np.mean(loss_history)))
|
||||
loss_history = []
|
||||
if i % FLAGS.save_every == 0 and i is not 0:
|
||||
llprint("\n\tSAVING MODEL\n")
|
||||
saver.save(sess, FLAGS.save_path, global_step=global_step)
|
203
handwriting/utils.py
Normal file
203
handwriting/utils.py
Normal file
@ -0,0 +1,203 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import random
|
||||
import os
|
||||
import cPickle as pickle
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from utils import *
|
||||
|
||||
class Dataloader():
|
||||
def __init__(self, FLAGS, logger, limit = 500):
|
||||
self.data_dir = FLAGS.data_dir
|
||||
self.alphabet = FLAGS.alphabet
|
||||
self.batch_size = FLAGS.batch_size
|
||||
self.stroke_steps = FLAGS.stroke_steps
|
||||
self.data_scale = FLAGS.data_scale # scale data down by this factor
|
||||
self.ascii_steps = FLAGS.ascii_steps
|
||||
self.logger = logger
|
||||
self.limit = limit # removes large noisy gaps in the data
|
||||
|
||||
data_file = os.path.join(self.data_dir, "strokes_training_data.cpkl")
|
||||
stroke_dir = self.data_dir + "/lineStrokes"
|
||||
ascii_dir = self.data_dir + "/ascii"
|
||||
|
||||
if not (os.path.exists(data_file)) :
|
||||
self.logger.write("\tcreating training data cpkl file from raw source")
|
||||
self.preprocess(stroke_dir, ascii_dir, data_file)
|
||||
|
||||
self.load_preprocessed(data_file)
|
||||
self.reset_batch_pointer()
|
||||
|
||||
def preprocess(self, stroke_dir, ascii_dir, data_file):
|
||||
# create data file from raw xml files from iam handwriting source.
|
||||
self.logger.write("\tparsing dataset...")
|
||||
|
||||
# build the list of xml files
|
||||
filelist = []
|
||||
# Set the directory you want to start from
|
||||
rootDir = stroke_dir
|
||||
for dirName, subdirList, fileList in os.walk(rootDir):
|
||||
for fname in fileList:
|
||||
filelist.append(dirName+"/"+fname)
|
||||
|
||||
# function to read each individual xml file
|
||||
def getStrokes(filename):
|
||||
tree = ET.parse(filename)
|
||||
root = tree.getroot()
|
||||
|
||||
result = []
|
||||
|
||||
x_offset = 1e20
|
||||
y_offset = 1e20
|
||||
y_height = 0
|
||||
for i in range(1, 4):
|
||||
x_offset = min(x_offset, float(root[0][i].attrib['x']))
|
||||
y_offset = min(y_offset, float(root[0][i].attrib['y']))
|
||||
y_height = max(y_height, float(root[0][i].attrib['y']))
|
||||
y_height -= y_offset
|
||||
x_offset -= 100
|
||||
y_offset -= 100
|
||||
|
||||
for stroke in root[1].findall('Stroke'):
|
||||
points = []
|
||||
for point in stroke.findall('Point'):
|
||||
points.append([float(point.attrib['x'])-x_offset,float(point.attrib['y'])-y_offset])
|
||||
result.append(points)
|
||||
return result
|
||||
|
||||
# function to read each individual xml file
|
||||
def getAscii(filename, line_number):
|
||||
with open(filename, "r") as f:
|
||||
s = f.read()
|
||||
s = s[s.find("CSR"):]
|
||||
if len(s.split("\n")) > line_number+2:
|
||||
s = s.split("\n")[line_number+2]
|
||||
return s
|
||||
else:
|
||||
return ""
|
||||
|
||||
# converts a list of arrays into a 2d numpy int16 array
|
||||
def convert_stroke_to_array(stroke):
|
||||
n_point = 0
|
||||
for i in range(len(stroke)):
|
||||
n_point += len(stroke[i])
|
||||
stroke_data = np.zeros((n_point, 3), dtype=np.int16)
|
||||
|
||||
prev_x = 0
|
||||
prev_y = 0
|
||||
counter = 0
|
||||
|
||||
for j in range(len(stroke)):
|
||||
for k in range(len(stroke[j])):
|
||||
stroke_data[counter, 0] = int(stroke[j][k][0]) - prev_x
|
||||
stroke_data[counter, 1] = int(stroke[j][k][1]) - prev_y
|
||||
prev_x = int(stroke[j][k][0])
|
||||
prev_y = int(stroke[j][k][1])
|
||||
stroke_data[counter, 2] = 0
|
||||
if (k == (len(stroke[j])-1)): # end of stroke
|
||||
stroke_data[counter, 2] = 1
|
||||
counter += 1
|
||||
return stroke_data
|
||||
|
||||
# build stroke database of every xml file inside iam database
|
||||
strokes = []
|
||||
asciis = []
|
||||
for i in range(len(filelist)):
|
||||
if (filelist[i][-3:] == 'xml'):
|
||||
stroke_file = filelist[i]
|
||||
# print 'processing '+stroke_file
|
||||
stroke = convert_stroke_to_array(getStrokes(stroke_file))
|
||||
|
||||
ascii_file = stroke_file.replace("lineStrokes","ascii")[:-7] + ".txt"
|
||||
line_number = stroke_file[-6:-4]
|
||||
line_number = int(line_number) - 1
|
||||
ascii = getAscii(ascii_file, line_number)
|
||||
if len(ascii) > 10:
|
||||
strokes.append(stroke)
|
||||
asciis.append(ascii)
|
||||
else:
|
||||
self.logger.write("\tline length was too short. line was: " + ascii)
|
||||
|
||||
assert(len(strokes)==len(asciis)), "There should be a 1:1 correspondence between stroke data and ascii labels."
|
||||
f = open(data_file,"wb")
|
||||
pickle.dump([strokes,asciis], f, protocol=2)
|
||||
f.close()
|
||||
self.logger.write("\tfinished parsing dataset. saved {} lines".format(len(strokes)))
|
||||
|
||||
|
||||
def load_preprocessed(self, data_file):
|
||||
f = open(data_file,"rb")
|
||||
[self.raw_stroke_data, self.raw_ascii_data] = pickle.load(f)
|
||||
f.close()
|
||||
|
||||
# goes thru the list, and only keeps the text entries that have more than stroke_steps points
|
||||
self.stroke_data = []
|
||||
self.ascii_data = []
|
||||
counter = 0
|
||||
|
||||
for i in range(len(self.raw_stroke_data)):
|
||||
data = self.raw_stroke_data[i]
|
||||
if len(data) > (self.stroke_steps+2):
|
||||
# removes large gaps from the data
|
||||
data = np.minimum(data, self.limit)
|
||||
data = np.maximum(data, -self.limit)
|
||||
data = np.array(data,dtype=np.float32)
|
||||
data[:,0:2] /= self.data_scale
|
||||
|
||||
self.stroke_data.append(data)
|
||||
self.ascii_data.append(self.raw_ascii_data[i])
|
||||
|
||||
# minus 1, since we want the ydata to be a shifted version of x data
|
||||
self.num_batches = int(len(self.stroke_data) / self.batch_size)
|
||||
self.logger.write("\tloaded dataset:")
|
||||
self.logger.write("\t\t{} individual data points".format(len(self.stroke_data)))
|
||||
self.logger.write("\t\t{} batches".format(self.num_batches))
|
||||
|
||||
def next_batch(self):
|
||||
# returns a randomized, stroke_steps-sized portion of the training data
|
||||
x_batch = []
|
||||
y_batch = []
|
||||
ascii_list = []
|
||||
for i in xrange(self.batch_size):
|
||||
data = self.stroke_data[self.idx_perm[self.pointer]]
|
||||
idx = random.randint(0, len(data)-self.stroke_steps-2)
|
||||
x_batch.append(np.copy(data[:self.stroke_steps]))
|
||||
y_batch.append(np.copy(data[1:self.stroke_steps+1]))
|
||||
ascii_list.append(self.ascii_data[self.idx_perm[self.pointer]][:self.ascii_steps])
|
||||
self.tick_batch_pointer()
|
||||
one_hots = [to_one_hot(s, self.ascii_steps, self.alphabet) for s in ascii_list]
|
||||
return x_batch, y_batch, ascii_list, one_hots
|
||||
|
||||
def tick_batch_pointer(self):
|
||||
self.pointer += 1
|
||||
if (self.pointer >= len(self.stroke_data)):
|
||||
self.reset_batch_pointer()
|
||||
def reset_batch_pointer(self):
|
||||
self.idx_perm = np.random.permutation(len(self.stroke_data))
|
||||
self.pointer = 0
|
||||
|
||||
# utility function for converting input ascii characters into vectors the network can understand.
|
||||
# index position 0 means "unknown"
|
||||
def to_one_hot(s, ascii_steps, alphabet):
|
||||
steplimit=3e3; s = s[:3e3] if len(s) > 3e3 else s # clip super-long strings
|
||||
seq = [alphabet.find(char) + 1 for char in s]
|
||||
if len(seq) >= ascii_steps:
|
||||
seq = seq[:ascii_steps]
|
||||
else:
|
||||
seq = seq + [0]*(ascii_steps - len(seq))
|
||||
one_hot = np.zeros((ascii_steps,len(alphabet)+1))
|
||||
one_hot[np.arange(ascii_steps),seq] = 1
|
||||
return one_hot
|
||||
|
||||
# abstraction for logging
|
||||
class Logger():
|
||||
def __init__(self, FLAGS):
|
||||
self.logf = '{}train_scribe.txt'.format(FLAGS.log_dir) if FLAGS.train else '{}sample_scribe.txt'.format(FLAGS.log_dir)
|
||||
with open(self.logf, 'w') as f: f.write("Scribe: Realistic Handriting in Tensorflow\n by Sam Greydanus\n\n\n")
|
||||
|
||||
def write(self, s, print_it=True):
|
||||
if print_it:
|
||||
print s
|
||||
with open(self.logf, 'a') as f:
|
||||
f.write(s + "\n")
|
BIN
handwriting/utils.pyc
Normal file
BIN
handwriting/utils.pyc
Normal file
Binary file not shown.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -1,2 +1,4 @@
|
||||
model_checkpoint_path: "model.ckpt-7000"
|
||||
all_model_checkpoint_paths: "model.ckpt-7000"
|
||||
model_checkpoint_path: "model.ckpt-10000"
|
||||
all_model_checkpoint_paths: "model.ckpt-8000"
|
||||
all_model_checkpoint_paths: "model.ckpt-9000"
|
||||
all_model_checkpoint_paths: "model.ckpt-10000"
|
||||
|
BIN
repeat-copy/rnn_models/model.ckpt-10000.data-00000-of-00001
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-10000.data-00000-of-00001
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-10000.index
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-10000.index
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-10000.meta
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-10000.meta
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-8000.data-00000-of-00001
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-8000.data-00000-of-00001
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-8000.index
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-8000.index
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-8000.meta
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-8000.meta
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-9000.data-00000-of-00001
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-9000.data-00000-of-00001
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-9000.index
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-9000.index
Normal file
Binary file not shown.
BIN
repeat-copy/rnn_models/model.ckpt-9000.meta
Normal file
BIN
repeat-copy/rnn_models/model.ckpt-9000.meta
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user