dnc-jupyter/repeat-copy/repeat-copy-rnn.ipynb

389 lines
59 KiB
Plaintext
Raw Permalink Normal View History

2017-02-22 05:38:44 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Repeat Copy Task\n",
"### Differentiable Neural Computer (DNC) using an RNN Controller\n",
"\n",
"<a href=\"http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html\"><img src=\"../static/dnc_schema.png\" alt=\"DNC schema\" style=\"width: 700px;\"/></a>\n",
"\n",
"**Sam Greydanus $\\cdot$ February 2017 $\\cdot$ MIT License.**\n",
"\n",
"The DNC represents the state of the art in differentiable memory. This implementation is inspired by this [Nature paper](http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html). Some ideas are taken from [this GitHub repo](https://github.com/Mostafa-Samir/DNC-tensorflow)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import sys\n",
"sys.path.insert(0, '../dnc')\n",
"\n",
"from dnc import DNC\n",
"from rnn_controller import RNNController\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
2017-02-22 06:18:13 +08:00
"collapsed": false
2017-02-22 05:38:44 +08:00
},
"outputs": [],
"source": [
"xydim = 6\n",
"tf.app.flags.DEFINE_integer(\"xlen\", xydim, \"Input dimension\")\n",
"tf.app.flags.DEFINE_integer(\"ylen\", xydim, \"output dimension\")\n",
"tf.app.flags.DEFINE_integer(\"length\", 5, \"Sequence length\")\n",
"tf.app.flags.DEFINE_integer(\"reps\", 3, \"Number of repeats for copy task\")\n",
2017-03-01 11:11:13 +08:00
"tf.app.flags.DEFINE_integer(\"batch_size\", 16, \"Size of batch in minibatch gradient descent\")\n",
2017-02-22 05:38:44 +08:00
"\n",
"tf.app.flags.DEFINE_integer(\"R\", 1, \"Number of DNC read heads\")\n",
"tf.app.flags.DEFINE_integer(\"W\", 10, \"Word length for DNC memory\")\n",
"tf.app.flags.DEFINE_integer(\"N\", 7, \"Number of words the DNC memory can store\")\n",
"\n",
"tf.app.flags.DEFINE_integer(\"print_every\", 100, \"Print training info after this number of train steps\")\n",
2017-03-01 11:11:13 +08:00
"tf.app.flags.DEFINE_integer(\"iterations\", 20000, \"Number of training iterations\")\n",
2017-02-22 05:38:44 +08:00
"tf.app.flags.DEFINE_float(\"lr\", 1e-4, \"Learning rate (alpha) for the model\")\n",
"tf.app.flags.DEFINE_float(\"momentum\", .9, \"RMSProp momentum\")\n",
"tf.app.flags.DEFINE_integer(\"save_every\", 1000, \"Save model after this number of train steps\")\n",
2017-02-22 06:18:13 +08:00
"tf.app.flags.DEFINE_string(\"save_path\", \"rnn_models/model.ckpt\", \"Where to save checkpoints\")\n",
2017-02-22 05:38:44 +08:00
"FLAGS = tf.app.flags.FLAGS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
2017-03-01 11:11:13 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfMAAABwCAYAAAAKXJmJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAADztJREFUeJzt3W2sHNV5wPH/cw2xEwhGtVOwShpISIpsVzTYIaUJL41t\nOSIqtFJFcKkqIBECkn5w1ZYgkEgpL2kjYlqoq6oNOBXh0lRtFFADxgbaFBFKsRNaX0NReAm02BSD\ncCRjO9j39MPspfdt792zd2Z3Z/f/k/bDzs7OPmees3N2dmaeiZQSkiSpvoa6HYAkSZobB3NJkmrO\nwVySpJpzMJckqeYczCVJqjkHc0mSas7BXJKkmnMwlySp5hzMJUmquSOqXHhELALWAi8CB6r8LEmS\n+swC4ERgc0rp9ZlmbGswj4gvAL8PHA88BfxuSunfp5l1LfDNdj5DkiQBcBFw90wzZA/mEfFZ4Bbg\nMuAJYD2wOSI+klLaM2n2FwHOv+vXWHzKogkvbFn/EGs2rJqy/KGVm7LiGd12Sdb82TJr1zeLfzPF\nL5vJcuPPLaV/x8q8FH9u26G8D2hiy/qtrNmwesr0nPizY3/y7az5y+przdpad81y1ay93eprVSot\ntyVtR5oZffLiUpZf1naqauVtB6dv8aVP5vXNqvr+nqf38J3fvg8aY+lM2tkzXw/8VUrpbwEi4nLg\nM8ClwJ9OmvcAwOJTFrHktOMnvDD/2PlTpkH+QfzRaZZRqtwvYZPpC4Al00zPjT//vjhHZs295LS8\nAbGZ+Qunz29e/Lmx/zRr/rL6WrO21l2zXC1o8t3tVl+rUmm5LWk70kzudqTq7VTVytsOTt/i/L5Z\ned+f9TB1Vp+JiCOBFcBDY9NScdu1rcAZudFJkqS5y/0BuBiYB7w6afqrFMfPJUlSh3lpmiRJNZd7\nzHwPcBg4btL044Ddzd60Zf1DzD92/oRpC3/+mMyPrrfl3Q6gw5atW9rtEDpmkNoKsOzCwWnvoOV2\n0LZTvdTiHcMjjAzvnDDt4N6DLb8/azBPKb0dEduAVcC9ABERjed/3ux9azas6ssThHL0TpfpjOXr\nlnU7hI4ZpLYCLBug9g5abgdtO9VLLV6+btmU/rZr+26+vuLOlt7fztnsXwM2NQb1sUvT3gNsamNZ\nkiRpjrIH85TStyJiMXA9xd/rPwTWppReKzs4SZI0u7YqwKWUNgIbS45FkiS1wbPZJUmquXbKuZ4J\n/AFF8ZglwK+nlO6d5U3FowWjh7+UFc9Q3Jw1P62F8Y7ceKrW4mp8xzWH8yoN3RB5lYxy12duPFXK\nDL3n+trQvK/kLX/06qz58+U1uJf62o3z8mK5drTafpzdN4fy+sLgyVujNw5l9s0e0M6e+VEUx8mv\nBLKL6kmSpHK1cwLcA8AD8M5laZIkqYs8Zi5JUs05mEuSVHMO5pIk1Vxb15nn2rJ+K/MXTqzNvmzd\n0oErlShJ0nQ6Wpu9XWs2rB742uySJDXT8drsEXEUcDL/f+HeByPiVOCNlNLLucuTJElz086e+Urg\nEYprzBNwS2P6N4BLS4pLkiS1qJ3rzP8FT5yTJKlnOChLklRzWXvmEXE18BvAKcB+4DHgqpTSszO+\nMaXi0YLcetO5smutZxa5Gx2ttpZ7i6vxHbk1p3Pl1lrPWZ3XVFz/umqjhyqutd5j9w2o2jWH8vpD\nTt/PvmdA5hex6trpuaU4e60Od/Z2rfLa6fVbo7l75mcCtwEfB1YDRwIPRsS7yw5MkiS1JmvPPKV0\n7vjnEXEx8L8Ud1B7tLywJElSq+Z6zPxYiv8X3ighFkmS1Ia2B/PGHdNuBR5NKe2cbX5JklSNuVSA\n2wgsBT5RUiySJKkNbQ3mEXE7cC5wZkpp12zzb1n/EPOPnVSb/cKlLF+3tJ2PlySpr3S8NntjID8f\nODul9FIr71mzYZW12SVJaqKjtdkjYiOwDjgP
2BcRxzVe2ptSOpCzLEmSVI7cE+AuB44B/hl4Zdzj\ngnLDkiRJrcq9ztzyr5Ik9RgHZ0mSai73mPnlwBXAiY1JI8D1KaUHZnljy0W5c+tNz8usZ51d/3r0\n6qz5c2u558pdfG7N6RvnvStz/rwayddm1FvPXpXdL4880VBmXf+qa7nn9uUeE5m7Hjm13Kvsx9DG\nPRsy+3LVdf2Hhm7OW35mX8veruXet6Hie1rkbmdviFaX3/oQnbtn/jJwFXAaRQnXh4F7I8JrzCRJ\n6pLcY+b/NGnStRFxBcWNV6wCJ0lSF7RdAS4ihijOYp8P/GtpEUmSpCztFI1ZDnwfWAC8BVyQUvpR\n2YFJkqTWtHM2+zPAqcDpwO3APRHx0VKjkiRJLcveM08pHQKebzz9QUScTnGG+2XN3rNl/VbmL5xU\nm33d0iml6yRJGkw7Go/xWi+sOpe7po0ZAubNNMOaDautzS5JUlPLG4/xdgF/3dK7c68zvwm4H3gJ\neC9wEXAWcEPOciRJUnly98x/FvgGsATYC/wHsDal9EjZgUmSpNbkXmf++aoCkSRJ7bE2uyRJNTen\nE+Ai4kvATcCtKaXfKyWizHrWh3NrDGfWMCblFfXttXrZ2fWsD/80a/7cGsY5qzO7XnZm7Lml3Cuv\nZ525/J6rRZ8ZUHY966EK62X32LocOiJ3O5LZdzKLoWf3zR5z4xGZfSez9ntVteV3bT/EHStbW2bb\ne+YR8TGKy9GeancZkiRp7toazCPiaOAu4PPAm6VGJEmSsrS7Z/4XwH0ppYfLDEaSJOVrpzb7hcAv\nAS3+ky9JkqqUWzTmBOBWYHVKKfPu8JIkqQq5e+YrgPcB2yPeOX9vHnBWRHwRmJ/S1POVrc0uSVJz\nI8MjjNyzc8K0A28ebPn9uYP5VuAXJ03bBDwNfGW6gRyszS5J0kyWrVvGskk7uLu27+aOlXe29P7c\nCnD7gAk/HSJiH/B6SunpnGVJkqRylFEBrsfKLUiSNFjmfAvUlNKnyghEkiS1J5oc5i5n4RGnAds+\nt+2S3jlmntve3Dp9o5nlX7PLNlZb/jVXlauz6lRV/gED1tfqvDor72u5ei2gmpe17rXV2apd23fz\n9RV3AqxIKW2fad6sv9kj4rqIGJ302Dn7OyVJUlXa+Zt9B7AKGPvtcqi8cCRJUq52BvNDKaXXSo9E\nkiS1pZ2z2T8cEf8TEc9FxF0R8f7So5IkSS3LHcwfBy4G1gKXAycB34uIo0qOS5IktSi3aMzmcU93\nRMQTwI+BC4CmZWqmK+e68APH8Onb1uZ8fK3tGN7J8nVLux1Gx4wMj0ypZtSvdgyPDFRp4kFq7yD1\nYxis3EJv5XfH8AgjwxPPJz+4t7pyrhOklPZGxLPAyTPNN10517877+/n8tG1M3LPgA3m9+zsmS9J\n1UaGdw7YBnBw2jtI/RgGK7fQW/ldvm7ZlHU/7tK0Wc2pAlxEHE0xkO+ay3IkSVL7cq8z/2pEnBUR\nH4iIXwG+DbwNDFcSnSRJmlXu3+wnAHcDi4DXgEeBX04pvV52YJIkqTW5J8Cty1z+AoA9T++Z8sLB\nvQfZtX135uJK0KUSm83aO5QZzmg31tkMmq3OA29O396eKrFZ0gc07cu9Vs61pL7WrL39WM61jH7c\nli7VHy2rL/fadq3n8tuicWPngtnmrbo2+28B36zsAyRJ6n8XpZTunmmGqgfzRRTXpL8IHKjsgyRJ\n6j8LgBOBzbMdzq50MJckSdWb06VpkiSp+xzMJUmqOQdzSZJqzsFckqSa68pgHhFfiIgXImJ/RDwe\nER/rRhxVi4jrImJ00mPn7O/sfRFxZkTc27gd7mhEnDfNPNdHxCsR8VZEbImIGWv497LZ2hsRd06T\n6+92K965iIirI+KJiPhJRLwaEd+OiI9MM19f5LeV9vZLfiPi8oh4KiL2Nh6PRcSnJ83TF3mF2dvb\nL3mFLgzm
EfFZ4BbgOuCjwFPA5ohY3OlYOmQHcBxwfOPxye6GU5qjgB8CVwJTLomIiKuALwKXAacD\n+yjy/K5OBlmiGdvbcD8
2017-02-22 05:38:44 +08:00
"text/plain": [
2017-03-01 11:11:13 +08:00
"<matplotlib.figure.Figure at 0x10fc9e590>"
2017-02-22 05:38:44 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def get_sequence(length, reps, dim):\n",
" X = [np.concatenate((np.random.randint(2, size=(length,dim)), np.zeros((length + 3,dim)))) for _ in range(reps)]\n",
" X = np.vstack(X) ; X[:,dim-1] = 0\n",
" \n",
" X = np.concatenate((X[-1:,:],X[:-1,:]))\n",
" y = np.concatenate((X[-(length + 2):,:],X[:-(length + 2),:]))\n",
" markers = range(length+1, X.shape[0], 2*length+3)\n",
" X[markers,dim-1] = 1\n",
" return X, y\n",
" \n",
"def next_batch(batch_size, length, reps, dim):\n",
" X_batch = []\n",
" y_batch = []\n",
" for _ in range(batch_size):\n",
" X, y = get_sequence(length, reps, dim)\n",
" X_batch.append(X) ; y_batch.append(y)\n",
" return [X_batch, y_batch]\n",
"\n",
"batch = next_batch(1, FLAGS.length, FLAGS.reps, FLAGS.xlen)\n",
"plt.imshow(batch[0][0].T - batch[1][0].T, interpolation='none')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper functions"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def binary_cross_entropy(y_hat, y):\n",
" return tf.reduce_mean(-y*tf.log(y_hat) - (1-y)*tf.log(1-y_hat))\n",
"\n",
"def llprint(message):\n",
" sys.stdout.write(message)\n",
" sys.stdout.flush()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build graph, initialize everything"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"building graph...\n",
"defining loss...\n",
"computing gradients...\n",
"init variables... \n",
"ready to train..."
]
}
],
"source": [
"sess = tf.InteractiveSession()\n",
"\n",
"llprint(\"building graph...\\n\")\n",
"optimizer = tf.train.RMSPropOptimizer(FLAGS.lr, momentum=FLAGS.momentum)\n",
"dnc = DNC(RNNController, FLAGS)\n",
"\n",
"llprint(\"defining loss...\\n\")\n",
"y_hat, outputs = dnc.get_outputs()\n",
"y_hat = tf.clip_by_value(tf.sigmoid(y_hat), 1e-6, 1. - 1e-6)\n",
"loss = binary_cross_entropy(y_hat, dnc.y)\n",
"\n",
"llprint(\"computing gradients...\\n\")\n",
"gradients = optimizer.compute_gradients(loss)\n",
"for i, (grad, var) in enumerate(gradients):\n",
" if grad is not None:\n",
" gradients[i] = (tf.clip_by_value(grad, -10, 10), var)\n",
" \n",
"grad_op = optimizer.apply_gradients(gradients)\n",
"\n",
"llprint(\"init variables... \\n\")\n",
"sess.run(tf.global_variables_initializer())\n",
"llprint(\"ready to train...\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model overview...\n",
2017-03-01 11:11:13 +08:00
"\tvariable \"dnc_scope/basic_lstm_cell/weights:0\" has 73728 parameters\n",
"\tvariable \"dnc_scope/basic_lstm_cell/biases:0\" has 512 parameters\n",
"\tvariable \"W_z:0\" has 6144 parameters\n",
"\tvariable \"W_v:0\" has 768 parameters\n",
2017-02-22 05:38:44 +08:00
"\tvariable \"W_r:0\" has 60 parameters\n",
2017-03-01 11:11:13 +08:00
"total of 81212 parameters\n"
2017-02-22 05:38:44 +08:00
]
}
],
"source": [
"# tf parameter overview\n",
"total_parameters = 0 ; print \"model overview...\"\n",
"for variable in tf.trainable_variables():\n",
" shape = variable.get_shape()\n",
" variable_parameters = 1\n",
" for dim in shape:\n",
" variable_parameters *= dim.value\n",
" print '\\tvariable \"{}\" has {} parameters' \\\n",
" .format(variable.name, variable_parameters)\n",
" total_parameters += variable_parameters\n",
"print \"total of {} parameters\".format(total_parameters)"
]
},
{
"cell_type": "code",
2017-02-22 06:18:13 +08:00
"execution_count": 7,
2017-02-22 05:38:44 +08:00
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2017-03-01 11:11:13 +08:00
"loaded model: rnn_models/model.ckpt-6000\n"
2017-02-22 05:38:44 +08:00
]
}
],
"source": [
"global_step = 0\n",
"saver = tf.train.Saver(tf.global_variables())\n",
"load_was_success = True # yes, I'm being optimistic\n",
"try:\n",
" save_dir = '/'.join(FLAGS.save_path.split('/')[:-1])\n",
" ckpt = tf.train.get_checkpoint_state(save_dir)\n",
" load_path = ckpt.model_checkpoint_path\n",
" saver.restore(sess, load_path)\n",
"except:\n",
" print \"no saved model to load.\"\n",
" load_was_success = False\n",
"else:\n",
" print \"loaded model: {}\".format(load_path)\n",
" saver = tf.train.Saver(tf.global_variables())\n",
" global_step = int(load_path.split('-')[-1]) + 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train loop"
]
},
{
"cell_type": "code",
2017-02-22 06:18:13 +08:00
"execution_count": 8,
2017-02-22 05:38:44 +08:00
"metadata": {
"collapsed": false
},
2017-03-01 11:11:13 +08:00
"outputs": [],
2017-02-22 05:38:44 +08:00
"source": [
"loss_history = []\n",
"for i in xrange(global_step, FLAGS.iterations + 1):\n",
" llprint(\"\\rIteration {}/{}\".format(i, FLAGS.iterations))\n",
"\n",
" rlen = np.random.randint(1, FLAGS.length + 1)\n",
" rreps = np.random.randint(1, FLAGS.reps + 1)\n",
" X, y = next_batch(FLAGS.batch_size, rlen, rreps, FLAGS.xlen)\n",
" tsteps = rreps*(2*rlen+3)\n",
"\n",
" fetch = [loss, grad_op]\n",
" feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}\n",
"\n",
" step_loss, _ = sess.run(fetch, feed_dict=feed)\n",
" loss_history.append(step_loss)\n",
" global_step = i\n",
"\n",
" if i % 100 == 0:\n",
" llprint(\"\\n\\tloss: {:03.4f}\\n\".format(np.mean(loss_history)))\n",
" loss_history = []\n",
" if i % FLAGS.save_every == 0 and i is not 0:\n",
" llprint(\"\\n\\tSAVING MODEL\\n\")\n",
" saver.save(sess, FLAGS.save_path, global_step=global_step)"
]
},
{
"cell_type": "code",
2017-03-01 11:11:13 +08:00
"execution_count": 12,
2017-02-22 05:38:44 +08:00
"metadata": {
"collapsed": false
},
2017-02-22 06:18:13 +08:00
"outputs": [
{
"data": {
2017-03-01 11:11:13 +08:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA94AAAHHCAYAAABJK4BRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzs3Xm8XEWd///XOyEkyC5bcAUFF0hEAoqgKCMojDOK2yAB\nFXBBXEaH+Y0IAyaQGYGRURQZRlEhOGyDoyguBEXhK4iCEGQIgVHZlTUsYQ1L7uf3R50Lnb7d957q\n7tPbfT8fj34k93SdOp86dbqr6yxVigjMzMzMzMzMrBpTeh2AmZmZmZmZ2TBzx9vMzMzMzMysQu54\nm5mZmZmZmVXIHW8zMzMzMzOzCrnjbWZmZmZmZlYhd7zNzMzMzMzMKuSOt5mZmZmZmVmF3PE2MzMz\nMzMzq5A73mZmZmZmZmYVcsfbzMzMzMzMrELueJuZmZmZmZlVyB1v6wpJ+0sakfSiXsfSDkmHSFra\ngXw+JulWSdM6EVenSDpS0kiv4+gXrdR3v9atmZn1XqN2ttO/kfrhN9d47aek1epjc9tpk4E73tYt\nUbz6gqQdJc2XtE7GOmsDhwDHdiCEhcDqwMc6kFcnlaqnVvZfL3W5vhfSn3VrZma916idzf6NNEG7\n1tPfXOO1n5IEnAycIun5NW8txG2nDTl3vK1bvgOsERG39TqQwk7APGC9jHU+DEwFzq5/Q9K+ki4r\nzjA/KWlhzXv/WSy/X9IPJG0UEU8ApwH/2F4xeqaV/ddLHa3v8QxB3ZqZWXe18htpvHat17+5xms/\n/40U317AMZLWB7edNjm4421dEcmTvY6jhlpYZ3/gvEbliIgzImIn4BpgaUTsX/P2CcB5wIsj4p0R\ncW+x/BxgM0m7tBBLr7Wy/ybOVHpOFfnS4fouYZDr1szM6lTYPrX6G6lpu9YHv7n2p3n7+fWIuDgi\n7gc+Seqgj3LbaUPNHW/rivrnjUafcZL0UkkLJT0g6UFJp0iaUbfuaNqXSzpH0nJJyyR9RdL0mnQL\nJd3cYNurPE8laT7wxeLPW4q8V473LJSkzYBXARdOUNRvArMlbVestxHwaeA9EfFwbcKIWAzcD+w5\nQZ5IepGkkyTdIOmxovznSHpxo7KW3K9vkPQ7SY9L+qOkAyeKo1hv3P3XQqyvlHSmpPuBS2re30XS\nlbXxNXk27nlF+e6StELSEkkHlI23SRk3o1x9N5RTt2Zm1h0Zvycmap/GbXdq0pVqZ+t/I9Vs49uS\n/lJs46aibV2tRDvc8BlvSdtKOr8o98OSLpS0Q5N9NOHviCZl2Yxx2s+IuKnm/w9HxLKav9122lBb\nrdcB2KRR/7zR6P/PAW4CDgXmAB8B7gYOa5L25iLt60gd2vVIZ1YbbaPZtr8HvAzYG/gMcF+x/F6a\n26nIY/E4aQDOAI4DPiLpeuBo4OCIWNkk/WLg9RPkCfAaUpnPAv4MbAZ8ArhI0lYRsaJIV2q/SpoF\nXADcQ7pVbRpwZPH3RCbaf7mxfhf4QxGbivi2Bc4H7gA+T/qu+jywrGY9JG0MXA6sJN1ZsAz4a+Db\nktaOiBOA708QbyNl63s8ZevWzMy6I+f3BDRun8q0O7nt7Cq/UyRtCvwOWAf4BvB/wPOB9wLPYeJ2\nbczvIUlbAb8ClpOevX6a9Dz1xZLeGBG/a7CPJvp91ki77afbThteEeGXX5W/gP1IjdSLir/nAyPA\nyXXpvgfcU7dsNO3365afWOQ5q/j7VOCmBtueD6ysW/b/1cZTIv4FRfrnlEi7kNSwfQOYOUHarwOP\nlMhzeoNlry32y74N9tW4+xU4F3gUeH7NspcDT9XvqybxNN1/LcT6Xw3Snwc8DGxSs+wlwJO18QHf\nInXu16tb/0zSWfPpna5v4P2kkytfqVm2OunH2aa5deuXX3755Vd3Xhm/J8Zrn8q2O6XbWcb+Rjqt\nSLftOGUZrx3er/69Ip7HSY+9jS6bWfxeuajB
Pprw91mTuBq2n247/fIrfKu59VSQOqe1LgE2kLRW\ng7T/Ubfsa6Qz0G+rJrxVbAA8HRGPlUj7TWBtYElE3DVB2geANSa6fSvSoCPAM9NwPJd0JvpB0pno\nVZIzzn6VNAV4K3BuRPylZhv/Rzo735Z2Yy3i2xX4QUTcXZPvTaSr4LXeDfwImCppg9EX8DNg3Qbb\nK6thfUt6Oanz/33gQzVvbQdsDNTWd6m6NTOzrir7e6JRWwrjtzvrAXPaaWcliXSr9XkRcXVWyZrn\nOQV4SxHPrTXx3EU6YfCGut9dOb/P6o1pP912miXueFuv1Y+4+UDx7/oN0v6p7u8bSWdlN+twTO3a\njnQ2fG6JtKODo4w77YekGZIWSLoNeIJ0a9s9pM7lug1WGW+/bgSswdj9Cel2tra0EGv9c/kbjxPf\nM8uUnp9fDziQdHtd7euUmrw66fXAD4BdgCtrlu8EXBYRtfVYqm7NzKzryv6eWKV9KtHuBKndaaed\n3Yh0i/l1E6TLsRHpFvU/NHjvelJ/4IV1y3N+n03EbacZfsbbeq/Zs89lRqFuNA9mI1ObLM9xH7Ca\npDUj4tFmiSTtRTpz++/A8ZK2jojxGs/1gcdqrxI3cSLp1rHjgd+Sbg0L4L9pfAKtnf3artxYH29x\nO6N5nU66La+R/20x74b1HRGnAEh6J+nRhlE7kZ6dq1W2bs3MrLea/X6ob5/Ktjud+N3Ra63+jhjT\nfrrtNEvc8bZBsiVwa83fW5AawdEz0g/QeD7LzRosyz2TekPx7+bAkkYJJL2J9Ez3CZLWIw1eciBp\n4JNmNiedbZ7Ie4CFEXFIzfam09o82veSfkxs2eC9V5TMY7z9126s9wArSPVbrzbme0nPgU+NiF9O\nkGfH6ltpztHtSIPcjNoR+FJdHmXr1szMumui3xPNlGp3ilu7W21n7wUeAmZNkC6nXbsXeIz0jHm9\nV5Ku9t+ekd94GrafbjvNfKu5DQ6R5nus9WlSw7Oo+PtGYN1iJNG0UhoZ9J0N8hu9ilm2M/ibIobt\nGwaXRgvdNYrRTCPiQdJAJO+vnaKkgTnAZSW2v5Kxn9dP08JZ9YgYIT1j9k5JLxhdLumVpGfSyhhv\n/7UVaxHfhUV8M2vi2wLYoy7d94D3SNq6Ph9JG5aMt5Hx6vulwIMRcXuxnZeQbqH/XV26snVrZmbd\nM97vifpxRFZRtt1pp50tbrv+AfB2SeONU1K6XSvi+RmwZ92UZZuQHou7JCIemSifkpq1n247bdLz\nFW8bJJtL+iGpo70TsC9wekRcW7x/NvBvwA8knQCsCRxEep6qvvG6itQwHC3pbNLooedFRMPbniPi\nZklLgN1Io5Y/Q9JrgcMi4l11q32riPF9wHfq81Sa6/u5pAZ2Ij8GPiDpIWAp6SzxrqTnp1sxn9SJ\nvVTSSaRpTj5FOjv9qhLrj7f/OhHrkaQfJ5dJ+k/Sd9Uni/i2qUl3KOmZscslfbPY3nNJZ9XfDIx2\nvjtW36SrHdMkqfiB9HHgioh4ajRBZt2amVl3Nfs90fCOtjpl25122tl/Jg2G9itJJ5OuAD+PdLX4\n9RHxEJntGnAEqU37dRHPStJdeasDhzRZJ9s47afbTpv0fMXbBkWQOrBPAMeQ5sw8gTSvZEoQcT/p\n6vajpA74B0gN5I/HZBZxJakRehXpeaMzSYOPjOcU0hno6ZCeVZL0K+DXwBslPXMWW9LrSKOkBnCC\npO8Vo3vX+jvg1oi4uET5P03qvO9Den58E1Kj9ggtDEBSnKx4K+m27qNIc5fOo2RjN8H++0y7sUbE\nYtIPlvtJU5N8iNQZ/wXpNvTRdPeQpio7BXgXaZ+Pzsd6SE26tuu7Jq//Ix17X5Z0KGmE24vq1s2p\nWzMz654Jf0+Mu3L5dqfldjYi7gB2IM0jvg/wVdJ0XL8k3TKe3a5FxFJgZ+Ba0m+jz5Nurd+lyKuT\nxrSfbjvN
QKsOJGjWfyTNJzVWGxWd617FsQ7pdvZDIuLUidJPkNfqwC3A0RFxYgfCmxQknQtsFRGN\nnlPr9LbG1LekacCXga9
2017-02-22 06:18:13 +08:00
"text/plain": [
2017-03-01 11:11:13 +08:00
"<matplotlib.figure.Figure at 0x1217958d0>"
2017-02-22 06:18:13 +08:00
]
},
"metadata": {},
"output_type": "display_data"
}
],
2017-02-22 05:38:44 +08:00
"source": [
"X, y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.reps, FLAGS.xlen)\n",
"tsteps = FLAGS.reps*(2*FLAGS.length+3)\n",
"\n",
"feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}\n",
"fetch = [outputs['y_hat'], outputs['w_w'], outputs['w_r'], outputs['f'], outputs['g_a']]\n",
"[_y_hat, _w_w, _w_r, _f, _g_a] = sess.run(fetch, feed)\n",
2017-03-01 11:11:13 +08:00
"_y_hat = np.clip(_y_hat, 1e-6, 1. - 1e-6)\n",
2017-02-22 05:38:44 +08:00
"_y = y[0] ; _X = X[0]\n",
"\n",
"fig, ((ax1,ax2),(ax3,ax5),(ax4,ax6),) = plt.subplots(nrows=3, ncols=2)\n",
"plt.rcParams['savefig.facecolor'] = \"0.8\"\n",
"fs = 12 # font size\n",
"fig.set_figwidth(10)\n",
"fig.set_figheight(5)\n",
"\n",
"ax1.imshow(_X.T - _y.T, interpolation='none') ; ax1.set_title('input ($X$) and target ($y$)')\n",
"ax2.imshow(_y_hat[0,:,:].T, interpolation='none') ; ax2.set_title('prediction ($\\hat y$)')\n",
"\n",
"ax3.imshow(_w_w[0,:,:].T, interpolation='none') ; ax3.set_title('write weighting ($w_w$)')\n",
"ax4.imshow(_w_r[0,:,:,0].T, interpolation='none') ; ax4.set_title('read weighting ($w_r$)')\n",
"\n",
"ax5.imshow(_f[0,:,:].T, interpolation='none') ; ax5.set_title('free gate ($f$)') ; ax5.set_aspect(3)\n",
"ax6.imshow(_g_a[0,:,:].T, interpolation='none') ; ax6.set_title('allocation gate ($g_a$)') ; ax6.set_aspect(3)\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}