{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Repeat Copy Task\n",
"### Differentiable Neural Computer (DNC) using an RNN Controller\n",
"\n",
"<a href=\"https://goo.gl/6eiJFc\"><img src=\"../static/dnc_schema.png\" alt=\"DNC schema\" style=\"width: 700px;\"/></a>\n",
"\n",
"**Sam Greydanus $\\cdot$ February 2017 $\\cdot$ MIT License.**\n",
"\n",
"The DNC represented the state of the art in differentiable memory at the time of writing (February 2017). This notebook is inspired by this [Nature paper](https://goo.gl/6eiJFc) and takes some ideas from [this GitHub repo](https://github.com/Mostafa-Samir/DNC-tensorflow)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Brain analogy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the [Nature paper](https://goo.gl/6eiJFc):\n",
"\n",
"> [...] there are interesting parallels between the memory mechanisms of a DNC and the functional capabilities of the mammalian hippocampus. DNC memory modification is fast and can be one-shot, resembling the associative long-term potentiation of hippocampal CA3 and CA1 synapses. The hippocampal dentate gyrus, a region known to support neurogenesis, has been proposed to increase representational sparsity, thereby enhancing memory capacity: usage-based memory allocation and sparse weightings may provide similar facilities in our model. Human 'free recall' experiments demonstrate the increased probability of item recall in the same order as first presented—a hippocampus-dependent phenomenon accounted for by the temporal context model, bearing some similarity to the formation of temporal links."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import sys\n",
"sys.path.insert(0, '../dnc')\n",
"\n",
"from dnc import DNC\n",
"from rnn_controller import RNNController\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"xydim = 6\n",
"tf.app.flags.DEFINE_integer(\"xlen\", xydim, \"Input dimension\")\n",
"tf.app.flags.DEFINE_integer(\"ylen\", xydim, \"Output dimension\")\n",
"tf.app.flags.DEFINE_integer(\"length\", 5, \"Sequence length\")\n",
"tf.app.flags.DEFINE_integer(\"batch_size\", 3, \"Size of batch in minibatch gradient descent\")\n",
"\n",
"tf.app.flags.DEFINE_integer(\"R\", 1, \"Number of DNC read heads\")\n",
"tf.app.flags.DEFINE_integer(\"W\", 10, \"Word length for DNC memory\")\n",
"tf.app.flags.DEFINE_integer(\"N\", 7, \"Number of words the DNC memory can store\")\n",
"\n",
"tf.app.flags.DEFINE_integer(\"print_every\", 100, \"Print training info after this number of train steps\")\n",
"tf.app.flags.DEFINE_integer(\"iterations\", 30000, \"Number of training iterations\")\n",
"tf.app.flags.DEFINE_float(\"lr\", 1e-4, \"Learning rate (alpha) for the model\")\n",
"tf.app.flags.DEFINE_float(\"momentum\", .9, \"RMSProp momentum\")\n",
"tf.app.flags.DEFINE_integer(\"save_every\", 1000, \"Save model after this number of train steps\")\n",
"tf.app.flags.DEFINE_string(\"save_path\", \"rnn_models/model.ckpt\", \"Where to save checkpoints\")\n",
"FLAGS = tf.app.flags.FLAGS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x103d98190>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
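"def get_sequence(length, dim):\n",
"    # a random binary pattern of `length` steps followed by `length + 3` blank steps\n",
"    X = np.concatenate((np.random.randint(2, size=(length,dim)), np.zeros((length + 3,dim))))\n",
"    X[:,dim-1] = 0 # the last channel is reserved for the recall marker\n",
"    \n",
"    X = np.concatenate((X[-1:,:],X[:-1,:])) # shift by one step so that step 0 is blank\n",
"    y = np.concatenate((X[-(length + 2):,:],X[:-(length + 2),:])) # target: the pattern replayed in the last `length` steps\n",
"    markers = range(length+1, X.shape[0], 2*length+3)\n",
"    X[markers,dim-1] = 1 # marker step that cues the model to start recalling\n",
"    return X, y\n",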
" \n",
"def next_batch(batch_size, length, dim):\n",
" X_batch = []\n",
" y_batch = []\n",
" for _ in range(batch_size):\n",
" X, y = get_sequence(length, dim)\n",
" X_batch.append(X) ; y_batch.append(y)\n",
" return [X_batch, y_batch]\n",
"\n",
"batch = next_batch(1, FLAGS.length, FLAGS.xlen)\n",
"plt.imshow(batch[0][0].T - batch[1][0].T, interpolation='none')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper functions"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def binary_cross_entropy(y, y_hat):\n",
" return tf.reduce_mean(-y*tf.log(y_hat) - (1-y)*tf.log(1-y_hat))\n",
"\n",
"def llprint(message):\n",
" sys.stdout.write(message)\n",
" sys.stdout.flush()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def free_recall_loss(y, y_hat, tsteps): \n",
"    # broadcast so that every target step is compared with every predicted step (the tensors are batched)\n",
" y = tf.expand_dims(y, [1])\n",
" y_hat = tf.expand_dims(y_hat, [1])\n",
" \n",
" y_hat = tf.tile(y_hat,[1,tsteps,1,1])\n",
" y_hat = tf.transpose(y_hat, [0,2,1,3])\n",
" \n",
" y_minus = -y*tf.log(y_hat) - (1-y)*tf.log(1-y_hat) # binary cross entropy loss\n",
" y_minus = tf.reduce_sum(y_minus, axis=-1)\n",
" y_minus = tf.reduce_min(y_minus, axis=1)\n",
" \n",
" return tf.reduce_sum(y_minus) / (FLAGS.batch_size*FLAGS.length)"
]
},
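{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the same order-invariant (free recall) loss, for intuition only. `free_recall_loss_np` is a hypothetical helper, not part of the `dnc` package, and it normalizes by the number of target steps rather than by `FLAGS.length` (the two agree when the sequence length equals `FLAGS.length`). Each target step is scored against its best-matching predicted step, so a prediction that recalls the right items in a shuffled order should get the same loss as one that recalls them in order."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def free_recall_loss_np(y, y_hat, eps=1e-6):\n",
"    # y, y_hat: arrays of shape (batch, tsteps, dim); y_hat holds probabilities\n",
"    y_hat = np.clip(y_hat, eps, 1 - eps)\n",
"    t = y[:, None, :, :]      # (batch, 1, tsteps, dim): target step j\n",
"    p = y_hat[:, :, None, :]  # (batch, tsteps, 1, dim): predicted step i\n",
"    # bce[b, i, j] = binary cross-entropy between predicted step i and target step j\n",
"    bce = np.sum(-t * np.log(p) - (1 - t) * np.log(1 - p), axis=-1)\n",
"    # keep the best-matching prediction for every target, then average over batch and steps\n",
"    return np.sum(np.min(bce, axis=1)) / (y.shape[0] * y.shape[1])\n",
"\n",
"# a shuffled but otherwise perfect recall scores the same as an in-order one\n",
"y = np.random.randint(2, size=(3, 5, 6)).astype(float)\n",
"in_order = free_recall_loss_np(y, y)\n",
"shuffled = free_recall_loss_np(y, y[:, np.random.permutation(5), :])\n",
"print(np.isclose(in_order, shuffled))"
]
},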
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"permute by: [4 2 1 0 3]\n",
"guessed permutation: [4 2 1 0 3]\n"
]
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x10eb2b510>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x116a82c50>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def remove_duplicates(k, l):\n",
" i = 0 ; removed_duplicate = False\n",
" while i < len(k) and not removed_duplicate: \n",
" if k[-i] == k[-i-1]:\n",
" ri = np.random.randint(2)\n",
" k = np.delete(k,-i-ri) ; l = np.delete(l,-i-ri) ; removed_duplicate = True\n",
" k,l = remove_duplicates(k,l)\n",
" i+=1\n",
" return k,l\n",
"\n",
"def guess_recall_order(real_y, pred_y, tsteps):\n",
"    # broadcast so that every predicted step is compared with every target step (the arrays are batched)\n",
" real_y = np.tile(real_y, [1,1,1,1]) ; real_y = np.transpose(real_y, (1,0,2,3))\n",
" pred_y = np.tile(pred_y,[1,1,1,1]) ; pred_y = np.transpose(pred_y, (1,0,2,3))\n",
" \n",
" pred_y = np.tile(pred_y,[1,tsteps,1,1])\n",
" pred_y = np.transpose(pred_y, (0,2,1,3))\n",
" \n",
" real_y = real_y[0,:,:,:] ; pred_y = pred_y[0,:,:,:]\n",
" y_minus = .5*(real_y - pred_y)**2\n",
" y_minus = np.sum(y_minus, axis=-1)\n",
" y_mins = np.amin(y_minus, axis=1)\n",
" \n",
" k, l = np.where(y_minus == np.tile(y_mins,[tsteps,1]).T)\n",
" k, l = remove_duplicates(k, l)\n",
" return l\n",
"\n",
"# test\n",
"X, real_y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"real_y = np.stack(real_y)[:,-FLAGS.length:,:]\n",
"\n",
"p = np.random.permutation(real_y.shape[1]) ; print 'permute by: ', p\n",
"pred_y = [y_i[p,:] for y_i in real_y]\n",
"pred_y = np.stack(pred_y)\n",
"\n",
"plt.figure(0, figsize=[1,1]) ; plt.imshow(real_y[0,:,:].T, interpolation='none')\n",
"plt.figure(1, figsize=[1,1]) ; plt.imshow(pred_y[0,:,:].T, interpolation='none')\n",
"\n",
"print \"guessed permutation: \", guess_recall_order(real_y, pred_y, FLAGS.length)"
]
},
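{
"cell_type": "markdown",
"metadata": {},
"source": [
"`guess_recall_order` matches each predicted step to its nearest target independently, so two predictions can in principle map to the same target. As an alternative sketch (a swapped-in technique, not the method used above), the Hungarian algorithm in SciPy's `scipy.optimize.linear_sum_assignment` enforces a one-to-one matching and always returns a valid permutation. `recall_order_hungarian` is a hypothetical helper name, and the cell assumes SciPy is installed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.optimize import linear_sum_assignment\n",
"\n",
"def recall_order_hungarian(real_y, pred_y):\n",
"    # real_y, pred_y: arrays of shape (tsteps, dim) for a single sequence\n",
"    # cost[i, j] = squared error between predicted step i and target step j\n",
"    cost = 0.5 * np.sum((pred_y[:, None, :] - real_y[None, :, :]) ** 2, axis=-1)\n",
"    rows, cols = linear_sum_assignment(cost)\n",
"    order = np.empty(len(rows), dtype=int)\n",
"    order[rows] = cols  # order[i] = target step recalled at predicted step i\n",
"    return order\n",
"\n",
"real = np.random.randint(2, size=(5, 6)).astype(float)\n",
"p = np.random.permutation(5)\n",
"print(recall_order_hungarian(real, real[p]))  # recovers p when the rows of real are distinct"
]
},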
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build graph, initialize everything"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"building graph...\n",
"defining loss...\n",
"computing gradients...\n",
"init variables... \n",
"ready to train..."
]
}
],
"source": [
"sess = tf.InteractiveSession()\n",
"\n",
"llprint(\"building graph...\\n\")\n",
"optimizer = tf.train.RMSPropOptimizer(FLAGS.lr, momentum=FLAGS.momentum)\n",
"dnc = DNC(RNNController, FLAGS)\n",
"\n",
"llprint(\"defining loss...\\n\")\n",
"y_hat, outputs = dnc.get_outputs()\n",
"y_hat = tf.clip_by_value(tf.sigmoid(y_hat), 1e-6, 1. - 1e-6) # avoid infinity\n",
"rlen = (dnc.tsteps-3)//2 # integer division keeps the slice index an integer\n",
"loss = free_recall_loss(dnc.y[:,-rlen:,:], y_hat[:,-rlen:,:], tsteps=rlen)\n",
"\n",
"llprint(\"computing gradients...\\n\")\n",
"gradients = optimizer.compute_gradients(loss)\n",
"for i, (grad, var) in enumerate(gradients):\n",
" if grad is not None:\n",
" gradients[i] = (tf.clip_by_value(grad, -10, 10), var)\n",
" \n",
"grad_op = optimizer.apply_gradients(gradients)\n",
"\n",
"llprint(\"init variables... \\n\")\n",
"sess.run(tf.global_variables_initializer())\n",
"llprint(\"ready to train...\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model overview...\n",
"\tvariable \"dnc_scope/basic_lstm_cell/weights:0\" has 73728 parameters\n",
"\tvariable \"dnc_scope/basic_lstm_cell/biases:0\" has 512 parameters\n",
"\tvariable \"W_z:0\" has 6144 parameters\n",
"\tvariable \"W_v:0\" has 768 parameters\n",
"\tvariable \"W_r:0\" has 60 parameters\n",
"total of 81212 parameters\n"
]
}
],
"source": [
"# tf parameter overview\n",
"total_parameters = 0 ; print \"model overview...\"\n",
"for variable in tf.trainable_variables():\n",
" shape = variable.get_shape()\n",
" variable_parameters = 1\n",
" for dim in shape:\n",
" variable_parameters *= dim.value\n",
" print '\\tvariable \"{}\" has {} parameters' \\\n",
" .format(variable.name, variable_parameters)\n",
" total_parameters += variable_parameters\n",
"print \"total of {} parameters\".format(total_parameters)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded model: rnn_models/model.ckpt-30000\n"
]
}
],
"source": [
"global_step = 0\n",
"saver = tf.train.Saver(tf.global_variables())\n",
"load_was_success = True # yes, I'm being optimistic\n",
"try:\n",
" save_dir = '/'.join(FLAGS.save_path.split('/')[:-1])\n",
" ckpt = tf.train.get_checkpoint_state(save_dir)\n",
" load_path = ckpt.model_checkpoint_path\n",
" saver.restore(sess, load_path)\n",
"except Exception:\n",
" print \"no saved model to load.\"\n",
" load_was_success = False\n",
"else:\n",
" print \"loaded model: {}\".format(load_path)\n",
" saver = tf.train.Saver(tf.global_variables())\n",
" global_step = int(load_path.split('-')[-1]) + 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train loop"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"loss_history = []\n",
"for i in xrange(global_step, FLAGS.iterations + 1):\n",
" llprint(\"\\rIteration {}/{}\".format(i, FLAGS.iterations))\n",
"\n",
" rlen = np.random.randint(1, FLAGS.length + 1)\n",
" X, y = next_batch(FLAGS.batch_size, rlen, FLAGS.xlen)\n",
" tsteps = 2*rlen+3\n",
"\n",
" fetch = [loss, grad_op]\n",
" feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}\n",
"\n",
" step_loss, _ = sess.run(fetch, feed_dict=feed)\n",
" loss_history.append(step_loss)\n",
" global_step = i\n",
"\n",
"    if i % FLAGS.print_every == 0:\n",
" llprint(\"\\n\\tloss: {:03.4f}\\n\".format(np.mean(loss_history)))\n",
" loss_history = []\n",
"    if i % FLAGS.save_every == 0 and i != 0:\n",
" llprint(\"\\n\\tSAVING MODEL\\n\")\n",
" saver.save(sess, FLAGS.save_path, global_step=global_step)"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x1180be390>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X, y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"tsteps = 2*FLAGS.length+3\n",
"\n",
"feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}\n",
"fetch = [outputs['y_hat'], outputs['w_w'], outputs['w_r'], outputs['f'], outputs['g_a']]\n",
"[_y_hat, _w_w, _w_r, _f, _g_a] = sess.run(fetch, feed)\n",
"_y_hat = np.clip(_y_hat, 1e-6, 1-1e-6)\n",
"_y = y[0] ; _X = X[0]\n",
"\n",
"fig, ((ax1,ax2),(ax3,ax5),(ax4,ax6),) = plt.subplots(nrows=3, ncols=2)\n",
"plt.rcParams['savefig.facecolor'] = \"0.8\"\n",
"fs = 12 # font size\n",
"fig.set_figwidth(10)\n",
"fig.set_figheight(5)\n",
"\n",
"ax1.imshow(_X.T - _y.T, interpolation='none') ; ax1.set_title('input ($X$) and target ($y$)')\n",
"ax2.imshow(_y_hat[0,-FLAGS.length:,:].T, interpolation='none') ; ax2.set_title('prediction ($\\hat y$)')\n",
"\n",
"ax3.imshow(_w_w[0,:,:].T, interpolation='none') ; ax3.set_title('write weighting ($w_w$)')\n",
"ax4.imshow(_w_r[0,:,:,0].T, interpolation='none') ; ax4.set_title('read weighting ($w_r$)')\n",
"\n",
"ax5.imshow(_f[0,:,:].T, interpolation='none') ; ax5.set_title('free gate ($f$)') ; ax5.set_aspect(3)\n",
"ax6.imshow(_g_a[0,:,:].T, interpolation='none') ; ax6.set_title('allocation gate ($g_a$)') ; ax6.set_aspect(3)\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"starting free recall trials: \n",
"\ttrial #0\n",
"\ttrial #100\n",
"\ttrial #200\n",
"\ttrial #300\n",
"\ttrial #400\n",
"\ttrial #500\n",
"\ttrial #600\n",
"\ttrial #700\n",
"\ttrial #800\n",
"\ttrial #900\n"
]
}
],
"source": [
"recall_orders = []\n",
"trials = 1000 ; print \"starting free recall trials: \"\n",
"for i in range(trials):\n",
" X, y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
" tsteps = 2*FLAGS.length+3\n",
"\n",
" feed = {dnc.X: X, dnc.tsteps: tsteps}\n",
" _y_hat = outputs['y_hat'].eval(feed)\n",
" _y_hat = np.clip(_y_hat, 1e-6, 1-1e-6)\n",
" _y = y[0] ; _X = X[0]\n",
" \n",
" real_y = np.tile(_y[-FLAGS.length:,:], [1,1,1])\n",
" pred_y = _y_hat[0:1,-FLAGS.length:,:]\n",
" order = guess_recall_order(real_y, pred_y, FLAGS.length)\n",
" \n",
" recall_orders.append(order)\n",
" if (i%100 == 0): print(\"\\ttrial #{}\".format(i))\n",
"recall_orders = np.stack(recall_orders)"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x117b997d0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x117091290>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x1171a4350>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x11681aa10>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x117185450>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x117b99b50>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_recall(probs, pos, title=None):\n",
"    # Plot recall probability against serial position.\n",
"    # `pos` also selects the matplotlib figure number.\n",
"    plt.figure(pos, figsize=[4,1.5])\n",
"    plt.title(title or \"DNC free recall (position={})\".format(pos))\n",
"    plt.xlabel(\"Serial position\") ; plt.ylabel(\"Recall probability\")\n",
"    plt.plot(range(len(probs)), probs) ; plt.show()\n",
"\n",
"def recall_probs(pos):\n",
"    # For each i, the fraction of trials in which recall_orders[:, pos] equals i.\n",
"    return np.asarray([np.sum(i == recall_orders[:,pos])/float(trials) for i in range(FLAGS.length)])\n",
"\n",
"for pos in range(FLAGS.length):\n",
"    plot_recall(recall_probs(pos), pos=pos)\n",
"\n",
"# Average of the first two positions; pos=12 only picks the figure number.\n",
"plot_recall(.5*(recall_probs(0) + recall_probs(1)), pos=12,\n",
"            title=\"DNC free recall (positions 0 and 1)\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# # X, y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"# # y = np.stack(y)\n",
"# # y = y[:,FLAGS.length + 3:,:]\n",
"# # # p = np.random.permutation(FLAGS.length)\n",
"# # y_shuffle = y #y[:,p,:] ; print 'permute by: ', p\n",
"# # # y_shuffle[:,-1,-1] = 4\n",
"\n",
"# # plt.figure(0, figsize=[1,1])\n",
"# # plt.imshow(y[0,:,:].T, interpolation='none')\n",
"# # plt.figure(1, figsize=[1,1])\n",
"# # plt.imshow(y_shuffle[0,:,:].T, interpolation='none')\n",
"# # plt.show()\n",
"\n",
"# def guess_recall_order(real_y, pred_y, FLAGS):\n",
"# # sorry this is uuuuugly but we have to because it's batched\n",
"# real_y = np.tile(real_y, [1,1,1,1]) ; real_y = np.transpose(real_y, (1,0,2,3))\n",
"# pred_y = np.tile(pred_y,[1,1,1,1]) ; pred_y = np.transpose(pred_y, (1,0,2,3))\n",
" \n",
"# pred_y = np.tile(pred_y,[1,FLAGS.length,1,1])\n",
"# pred_y = np.transpose(pred_y, (0,2,1,3))\n",
" \n",
"# # real_y = real_y[0,:,:,:] ; pred_y = pred_y[0,:,:,:]\n",
"# y_minus = .5*(real_y - pred_y)**2\n",
"# y_minus = np.sum(y_minus, axis=-1)\n",
"# y_mins = np.amin(y_minus, axis=1)\n",
" \n",
"# k, l = np.where(y_minus == y_mins)\n",
"# i = 0\n",
"# while i < len(k): \n",
"# if k[-i] == k[-i-1]:\n",
"# k = np.delete(k,-i) ; l = np.delete(l,-i)\n",
"# i+=1\n",
"# return l\n",
"\n",
"# X, real_y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"# real_y = np.stack(real_y)[:,-FLAGS.length:,:]\n",
"\n",
"# p = np.random.permutation(real_y.shape[1]) ; print 'permute by: ', p\n",
"# pred_y = [y_i[p,:] for y_i in real_y]\n",
"# pred_y = np.stack(pred_y)\n",
"# # pred_y[:,-2:,-1] = 4\n",
"\n",
"# plt.figure(0, figsize=[1,1])\n",
"# plt.imshow(real_y[0,:,:].T, interpolation='none')\n",
"# plt.figure(1, figsize=[1,1])\n",
"# plt.imshow(pred_y[0,:,:].T, interpolation='none')\n",
"# plt.show()\n",
"\n",
"# guess_recall_order(real_y, pred_y, FLAGS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def scatter(idx, vals, target):\n",
" \"\"\"target[idx] += vals, but allowing for repeats in idx\"\"\"\n",
" np.add.at(target, idx.ravel(), vals.ravel())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
}
},
"nbformat": 4,
"nbformat_minor": 1
}