{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Free Recall Task\n",
"### Differentiable Neural Computer (DNC) using an RNN Controller\n",
"\n",
"<a href=\"https://goo.gl/6eiJFc\"><img src=\"../static/dnc_schema.png\" alt=\"DNC schema\" style=\"width: 700px;\"/></a>\n",
"\n",
"**Sam Greydanus $\\cdot$ February 2017 $\\cdot$ MIT License.**\n",
"\n",
"The DNC represents the state of the art in differentiable memory. This notebook is inspired by this [Nature paper](https://goo.gl/6eiJFc), and some ideas are taken from [this GitHub repo](https://github.com/Mostafa-Samir/DNC-tensorflow)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Brain analogy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are interesting parallels between the memory mechanisms of a DNC and the functional capabilities of the mammalian hippocampus. DNC memory modification is fast and can be one-shot, resembling the associative long-term potentiation of hippocampal CA3 and CA1 synapses. The hippocampal dentate gyrus, a region known to support neurogenesis, has been proposed to increase representational sparsity, thereby enhancing memory capacity: usage-based memory allocation and sparse weightings may provide similar facilities in our model. Human 'free recall' experiments demonstrate the increased probability of item recall in the same order as first presented, a hippocampus-dependent phenomenon accounted for by the temporal context model, bearing some similarity to the formation of temporal links."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import sys\n",
"sys.path.insert(0, '../dnc')\n",
"\n",
"from dnc import DNC\n",
"from rnn_controller import RNNController\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hyperparameters"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"xydim = 6\n",
"tf.app.flags.DEFINE_integer(\"xlen\", xydim, \"Input dimension\")\n",
"tf.app.flags.DEFINE_integer(\"ylen\", xydim, \"output dimension\")\n",
"tf.app.flags.DEFINE_integer(\"length\", 5, \"Sequence length\")\n",
"tf.app.flags.DEFINE_integer(\"batch_size\", 3, \"Size of batch in minibatch gradient descent\")\n",
"\n",
"tf.app.flags.DEFINE_integer(\"R\", 1, \"Number of DNC read heads\")\n",
"tf.app.flags.DEFINE_integer(\"W\", 10, \"Word length for DNC memory\")\n",
"tf.app.flags.DEFINE_integer(\"N\", 7, \"Number of words the DNC memory can store\")\n",
"\n",
"tf.app.flags.DEFINE_integer(\"print_every\", 100, \"Print training info after this number of train steps\")\n",
"tf.app.flags.DEFINE_integer(\"iterations\", 30000, \"Number of training iterations\")\n",
"tf.app.flags.DEFINE_float(\"lr\", 1e-4, \"Learning rate (alpha) for the model\")\n",
"tf.app.flags.DEFINE_float(\"momentum\", .9, \"RMSProp momentum\")\n",
"tf.app.flags.DEFINE_integer(\"save_every\", 1000, \"Save model after this number of train steps\")\n",
"tf.app.flags.DEFINE_string(\"save_path\", \"rnn_models/model.ckpt\", \"Where to save checkpoints\")\n",
"FLAGS = tf.app.flags.FLAGS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x10ad67690>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def get_sequence(length, dim):\n",
"    # random binary pattern followed by length+3 blank rows: 2*length+3 rows in total\n",
"    X = np.concatenate((np.random.randint(2, size=(length,dim)), np.zeros((length + 3,dim))))\n",
"    X = np.vstack(X) ; X[:,dim-1] = 0 # last channel is reserved for the recall marker\n",
"    \n",
"    X = np.concatenate((X[-1:,:],X[:-1,:])) # shift down one row so the input starts with a blank\n",
"    y = np.concatenate((X[-(length + 2):,:],X[:-(length + 2),:])) # target: the pattern moved to the last `length` rows\n",
"    markers = range(length+1, X.shape[0], 2*length+3)\n",
"    X[markers,dim-1] = 1 # flag the step right after the pattern\n",
"    return X, y\n",
"    \n",
"def next_batch(batch_size, length, dim):\n",
"    X_batch = []\n",
"    y_batch = []\n",
"    for _ in range(batch_size):\n",
"        X, y = get_sequence(length, dim)\n",
"        X_batch.append(X) ; y_batch.append(y)\n",
"    return [X_batch, y_batch]\n",
"\n",
"batch = next_batch(1, FLAGS.length, FLAGS.xlen)\n",
"plt.imshow(batch[0][0].T - batch[1][0].T, interpolation='none')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Helper functions"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def binary_cross_entropy(y, y_hat):\n",
"    return tf.reduce_mean(-y*tf.log(y_hat) - (1-y)*tf.log(1-y_hat))\n",
"\n",
"def llprint(message):\n",
"    sys.stdout.write(message)\n",
"    sys.stdout.flush()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def free_recall_loss(y, y_hat, tsteps):\n",
"    # sorry this dimension stuff is uuuuugly but we have to because it's batched\n",
"    y = tf.expand_dims(y, [1]) # [batch, 1, target step, dim]\n",
"    y_hat = tf.expand_dims(y_hat, [1])\n",
"    \n",
"    y_hat = tf.tile(y_hat,[1,tsteps,1,1])\n",
"    y_hat = tf.transpose(y_hat, [0,2,1,3]) # [batch, predicted step, copy, dim]\n",
"    \n",
"    y_minus = -y*tf.log(y_hat) - (1-y)*tf.log(1-y_hat) # binary cross entropy of every prediction against every target\n",
"    y_minus = tf.reduce_sum(y_minus, axis=-1)\n",
"    y_minus = tf.reduce_min(y_minus, axis=1) # keep the best-matching prediction for each target item\n",
"    \n",
"    return tf.reduce_sum(y_minus) / (FLAGS.batch_size*FLAGS.length)"
]
},
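{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss above is order-invariant: every target item is scored against every predicted timestep, and only the best match counts. The cell below is a minimal NumPy sketch of the same idea, not the graph code used for training. The name `free_recall_loss_np` and the `eps` clipping value are assumptions made for illustration, and the sketch normalizes by the shape of `y` rather than by `FLAGS.batch_size*FLAGS.length`. A shuffled but otherwise perfect recall should give a loss close to zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def free_recall_loss_np(y, y_hat, eps=1e-6):\n",
"    # hypothetical helper: y and y_hat both have shape [batch, steps, dim]\n",
"    y_hat = np.clip(y_hat, eps, 1 - eps)      # avoid log(0)\n",
"    t = y[:, None, :, :]                      # targets vary along axis 2\n",
"    p = y_hat[:, :, None, :]                  # predictions vary along axis 1\n",
"    bce = -t*np.log(p) - (1-t)*np.log(1-p)    # [batch, predicted step, target step, dim]\n",
"    cost = bce.sum(axis=-1)                   # cost of pairing each prediction with each target\n",
"    best = cost.min(axis=1)                   # best-matching prediction for every target item\n",
"    return best.sum() / float(y.shape[0] * y.shape[1])\n",
"\n",
"# the order of recall is shuffled, but every item is recalled exactly\n",
"y_demo = np.random.randint(2, size=(3, 5, 6)).astype(float)\n",
"shuffled = y_demo[:, np.random.permutation(5), :]\n",
"print(free_recall_loss_np(y_demo, shuffled))"
]
},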
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"permute by: [0 4 3 1 2]\n",
"guessed permutation: [0 4 3 1 2]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAGMAAAB1CAYAAABatF8TAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAB65JREFUeJztnWuoXcUVx3//xOC7LXh9gVVbrIIoUeOzmgSjEPFDBBFf\nhaKiEh/9EBAfIApBpFAMFdNAY6lXESv9YMBSzNUY+5AY0l41eqlWfMRnvDURrxAjJnH5YfbJPZ6c\ns8+ZObN3hu75webeM2fPnnX22vNae80amRmZNJixtwXITJOVkRBZGQmRlZEQWRkJkZWREFkZCZGV\nkRD7VHlxSYcAC4FNwNdVlpUw+wHHAmNmtrX0TDOr7ACuBiwfGHB1v/sVVDMk3QLcBhwBbAR+ZWb/\n6nLqJvfnUmCky9ergYs8S4+Zp8wUNIar1N/nRh72KmUL8JT7d1OZhBDQTEm6AngAuBHYACwBxiQd\nb2ZbOk4vmqYR4MguV9uvR3oZMfOUKaN7nrKS+0jWt5kO6cCXAL83s8fM7E1gMfAVcF3AtTJteClD\n0ixgDvB8K81c57AGOCeuaM3Dt2aMADOByY70SVz/kRmCSoe206zGtajtnFQcvqSb54fAnzrSfMbz\nvsrYAuwCDu9IPxz4tHe2i/DvdHtxck15/JVxcZe0zcDKAfN7NVNmtgMYBy5opUlS8Xmdz7UyexLS\nTC0DRiWNMz20PQAYjShXI/FWhpn9WdIIsBTXPL0KLDSzz2IL1zSCOnAzWwGsiCxL48lW24SoaWjb\nspUNxj0srU6UNpZyTy3lDIp3zZA0V9LTkj6W9K2kRVUI1kRCmqkDcZ32zfg87pm+hIymVuOm1K05\nRiYSuQNPiKyMhKhpNDVGPENhurwOTHSkVWkoDGQh8QyF6XIye5okKzMUZqol5B34gcBxQGsk9VNJ\ns4HPzezDmMI1jZBm6nTgBaan1Q8U6Y+S34MPRcg84+/k5q0S8k1NiJpGU2K6i+lPXQa8ugySg+Lr\nqnOXpA2SvpQ0KWmVpOOrEq5p+DZTc4GHgLOAC4FZwLOS9o8tWBPxaqbM7HsOEJKuAf6Hc2x7MZ5Y\nzWTYDvxHuOHt5xFkaTzByijM578FXjSz/8QTqbkMM5paAZwInNv/1F4ehSHOZemyVwyFkpbjHOjm\nmtnm/jliehSmy7CGwhDb1HLgEmC+mX3gmz/TGy9lSFoBXAUsArZJavncTplZU9fsRcO3A18M/AD4\nG/BJ23F5XLGaie88I9uyKiTf3ISoxVB4AyuTHEvVY5DcDCUrZNvxNRQulrRR0lRxrJPkuw440wPf\nZupD4A7gNJw9ai3wtKQTYwvWRHw78L92JN0t6SacFTebRIYkuM+QNAM3pN0X+Gc0iRpMyAz8JOAl\nnLHpK+ByM3s7tmBNJKRmvAnMxq20vQx4UtJ8M3ulV4Zm+BOCMxOGmwpDvEN2Au8WH1+RdCZwEy6W\nSFea4U8I3R+xioa2JdeYGeE6jcfXUHg/8AzwAXAw8AtgHnBffNGah28zdRjOc/BIYAp4Dbfs+IXY\ngjUR33nG9VUJksmGwqSoyaPQj/8nj8LBx1JD1gxJdxbLj5cNc52MYxhXnTNwc4uN8cRpNkHKkHQQ\n8DhwPfBFVIkaTGjN+B3wFzNbG1OYphNiKLwSOAW3gikTEd8Z+FE4l84Li6hsA9EUQ+FwZkL/mjEH\nOBR4uS1UxUxgnqRbgX2ty45aTTEUDmcm9FfGGvb0YBwF3gB+3U0RmcHxNYdso+P1qqRtwFYzeyOm\nYE0khjkk14ZIDG0OMbMFMQTJZENhUqjKPlfSacA43IDfeKqumGJ1tLC7x1NzzOzlsjN9PQrvLQyD\n7Uf2l4pESJ8xgQu33Xp8d8YTp9mEKGNnju5cDSEd+M+KMKrvSHpc0o+jS9VQfJWxHrgGZ+FYDPwE\n+EcRgyozJL4z8LG2jxOSNgDv43xuH+mds5ep
sP3voLyO/5LlkDwT+Mv2DM5ppp0KPQrbMbMpSW/h\nIrOV0MtU+CT+P3iCsBtbhzKmgCs70mryKCze+B1XlJgZEt95xm8kzZN0jKSfA6uAHey51VAmAN9m\n6ijgCeAQ4DNcJJ2z++5ZmhkI3w78Ks/rF71258aWLb6mewtXZg7plaeMXnnKzCGxytn92ztHMF3E\nyRvtJrPRbtWGwrwFtccW1JUqI+NHfp+REFkZCZGVkRBZGQmRlZEQe0UZkm6R9J6k7ZLWF8sLys73\n2o4uJDJ1jCA1w65XqV0Zkq7AbfNwL3Aqbn3HWLFfbC98t6MLiUw9VJCaKOtVqpyB95iVrwcebPss\n4CPg9gHzfwss8ixzpMh3nme+rcC1A5x3EPBfYAFub5FlIfem1pohaRbuqXu+lVb4564BzqmwaK/I\n1JJmFEsfBg1SE2W9St0LLEdwXuuTHemTwAlVFOgTmTokSE3M9SpJrnaNjEdkar8gNaHrVXpSc38x\nC/cyalFH+iiwKnafASzHvaM/OlDe54CVJd9fAuwCvil+145Cvlaaku0ziqdnHOcEB+xuRi4A1sUs\nqy0y9fkWHpm6X5Ca1nqVU3A1ajbwb9zi09lFfzgwe6OZWgaMShoHNgBLgANwtaMrvtvRhUSmDglS\nY7HXq9TZTLVV75tx7zi24zrM0/ucP5/p6t9+/LGkKes8dxfwy5Iy/oCLo7Ud+BR4FlgQ8NvWEji0\nze8zEiLbphIiKyMhsjISIisjIbIyEiIrIyGyMhIiKyMhsjISIisjIbIyEuI7aIsTsxBe7NYAAAAA\nSUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x119267310>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAGMAAAB1CAYAAABatF8TAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAB5BJREFUeJztnWusXFUVx3//YgMCPhKuIAmKGsSEYAqCD4S2oZCU+KHE\nxABiYtBQUkA/NDE+EgOxIYTE0ECsTWyNXg1R4gebYNReLOUhqU21YvVG0CgioOVKS7wmpcS2LD7s\nM73TuTNzZ+/Z53Q3Z/2Sk8nsOfvsNbNmv9Zee22ZGU4ZLDreAjhzuDIKwpVREK6MgnBlFIQroyBc\nGQXhyiiIN9X5cElnACuB54DX6iyrYE4B3gNMmdn+oXeaWW0XcCNgfmHAjQv9Xkk1Q9LtwJeAdwJ7\ngC+a2W/73PpcePkkMNHn4ylCxTmW1WweWHb/HLCZW4ZIvBW4ZsjndebZB/wUjv4Wg4lWhqTrgXuB\nW4BdwFpgStL5Zrav5/aqaZoAzu7ztFP6pve7c3iO9FwN5lmwmU7pwNcC3zGzH5rZM8Aa4FXg8wnP\ncrqIUoakxcAlwCOdNAudwzbgsryitY/YmjEBnATM9KTPEPoPZwxqHdrOMUVoT7u5sLriiM+Rmisl\nz9uAH/ekjT6ij1XGPuAIcFZP+lnAS4OzrSS+M+xPmjI+2FCeT/RJ2wtsGil3VDNlZoeA3cBVnTRJ\nqt7viHmWM5+UZmo9MClpN3ND21OByYxytZJoZZjZTyRNAOsIzdMfgJVm9nJu4dpGUgduZhuBjZll\naT1utS2Ihoa2TRDvcnQH62qQ41j2whBL27FE1wxJSyU9JOlfkl6XtCr2GU5/Upqp0wid9m2k/B2d\ngaSMprYSbMWdOYaTCe/AC8KVURAnnKGwZKarq5uYhf+GlJHPUFgy/f5etQ5tnfpIWQM/DTgP6Iyk\n3idpCfCKmb2QU7i2kdJMXQo8ypwLyr1V+g/wdfCxSJlnPI43b7XgP2pBFGkoTJnWpxj91nFHQkmx\njD6einXV+ZqkXZL+J2lG0hZJ56eI6MwntplaCnwL+ChwNbAYeFjSm3ML1kaimikzO8b9QdJNwH8I\njm1P5hOrnYzbgb+dMLx9JYMsrSdZGZX5/D7gSTP7cz6R2ss4o6mNwAXA5Qvf2g5D4bimwtT9GRsI\n7nNLzWzvwjnaYSgc11SYYpvaAFwLLDez52PzO4OJUoakjcCngVXAAUkdn9tZM2vrnr1sxHbga4C3\nAo8B/+66rssrVjuJnWe4LatG/MctiEYMhavZHDWW+kYjBrwT3KNQ0hpJeyTNVtcOSbH7c50BxDZT\nLwBfAT5EsEdtBx6SdEFuwdpIbAf+856kr0u6lWDFdZPImCT3GZIWEYa0JwO/ziZRi0mZgV8I/IZg\nbHoVuM7M/pZbsDaSUjOeAZYQ9tl+CnhQ0nIze2pQBjcTjkaKd8hh4Nnq7VOSPgLcCoMjqbiZcDRy\nTPoWEaImOGMSayi8G/gl8DzwFuAzwDLgrvyitY/YZupMgufg2cAs8EfCtuNHcwvWRmLnGTfXJYjj\nhsKiaMRQKOK8BJsw4MEJ7lHYi6SvVtuP14/zHCcwjqvOhwlziz35xGk3ScqQdDrwAHAz8N+sErWY\n1JrxbeBnZrY9pzBtJ8VQeANwEWEHk5OR2Bn4OQSXzqurqGwjsZX+hsKUwHNl06xH4SXAO4Dfd4Wq\nOAlYJukLwMnW50Sta2iHobBpj8JtzP9DTwJPA/f0U4QzOrHmkAP0LK9KOgDsN7OncwrWRnKYQ7w2\nZGJsc4iZrcghiOOGwqJoxFC4idWUOZ5qInbZ6GXEehTeWRkGuy/3l8pESs2YJoTb7qj8cD5x2k2K\nMg57dOd6SOnA31+FUf27pAckvSu7VC0lVhk7gZsIrlBrgPcCT1QxqJwxiZ2BT3W9nZa0C/gnwef2\n+4NzDvIp7H4dlemG8vyJeFPmLwhOM93UvPW4
g5nNSvorITLbEAb5FD5IucqYJl4Zs4T9p93UdJhJ\nL9WK33lVic6YxM4zvilpmaRzJX0c2AIcYv5BQ04Csc3UOcCPgDOAlwmRdD624JmlzkjEduC9DeJC\nVL1278GWHV4jvoXLmWeYqSJXOUe/e+8IZj5+0G45B+2qzsU5P4IaiDiCulZlOHH4ekZBuDIKwpVR\nEK6MgnBlFMRxUYak2yX9Q9JBSTur7QXD7o86ji4lMnWOIDXj7ldpXBmSricc83AncDFhf8dUdV7s\nIGKPo0uJTD1WkJos+1XqnIEPmJXvBO7vei/gReDLI+Z/HVgVWeZEle+KyHz7gc+NcN/pwF+AFYSz\nRdan/DaN1gxJiwn/ukc6aZV/7jbgshqLjopMLWlRtfVh1CA1WfarNH1kwwTBa32mJ30G+EAdBcZE\npk4JUpNzv0qR52dkJiIydVyQmtT9KgNpuL9YTFiMWtWTPglsyd1nABsIa/TvTpT3V8CmIZ9fCxwB\n/l99r0OVfJ00FdtnVP+e3QQnOOBoM3IVsCNnWV2Rqa+09MjUCwWp6exXuYhQo5YAvyNsPl1S9Ycj\nczyaqfXApKTdwC5gLXAqoXb0JfY4upTI1ClBaiz3fpUmm6mu6n0bYY3jIKHDvHSB+5czV/27r+8N\nacp67z0CfHZIGd8lxNE6CLwEPAysSPhu20kc2vp6RkG4baogXBkF4cooCFdGQbgyCsKVURCujIJw\nZRSEK6MgXBkF4cooiDcAgcEJ8jLx8x4AAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x119367110>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def remove_duplicates(k, l):\n",
"    # start at 1 so k[-i] and k[-i-1] are always neighbours (i = 0 would compare the first and last entries)\n",
"    i = 1 ; removed_duplicate = False\n",
"    while i < len(k) and not removed_duplicate:\n",
"        if k[-i] == k[-i-1]:\n",
"            ri = np.random.randint(2) # break the tie at random\n",
"            k = np.delete(k,-i-ri) ; l = np.delete(l,-i-ri) ; removed_duplicate = True\n",
"            k,l = remove_duplicates(k,l)\n",
"        i+=1\n",
"    return k,l\n",
"\n",
"def guess_recall_order(real_y, pred_y, tsteps):\n",
"    # sorry this is uuuuugly but we have to because it's batched\n",
"    real_y = np.tile(real_y, [1,1,1,1]) ; real_y = np.transpose(real_y, (1,0,2,3))\n",
"    pred_y = np.tile(pred_y,[1,1,1,1]) ; pred_y = np.transpose(pred_y, (1,0,2,3))\n",
"    \n",
"    pred_y = np.tile(pred_y,[1,tsteps,1,1])\n",
"    pred_y = np.transpose(pred_y, (0,2,1,3))\n",
"    \n",
"    real_y = real_y[0,:,:,:] ; pred_y = pred_y[0,:,:,:] # only the first sequence in the batch is scored\n",
"    y_minus = .5*(real_y - pred_y)**2\n",
"    y_minus = np.sum(y_minus, axis=-1) # squared distance between every prediction and every target\n",
"    y_mins = np.amin(y_minus, axis=1)\n",
"    \n",
"    k, l = np.where(y_minus == np.tile(y_mins,[tsteps,1]).T) # l: closest target index for each predicted step k\n",
"    k, l = remove_duplicates(k, l)\n",
"    return l\n",
"\n",
"# test\n",
"X, real_y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"real_y = np.stack(real_y)[:,-FLAGS.length:,:]\n",
"\n",
"p = np.random.permutation(real_y[0].shape[0]) ; print 'permute by: ', p\n",
"pred_y = [y_i[p,:] for y_i in real_y]\n",
"pred_y = np.stack(pred_y)\n",
"\n",
"plt.figure(0, figsize=[1,1]) ; plt.imshow(real_y[0,:,:].T, interpolation='none')\n",
"plt.figure(1, figsize=[1,1]) ; plt.imshow(pred_y[0,:,:].T, interpolation='none')\n",
"\n",
"print \"guessed permutation: \", guess_recall_order(real_y, pred_y, FLAGS.length)"
]
},
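{
"cell_type": "markdown",
"metadata": {},
"source": [
"`guess_recall_order` recovers the order of recall by matching each predicted row to its nearest target row. A compact way to express that matching is an `argmin` over pairwise squared distances, sketched below under the assumption that no two target rows are identical (so no tie-breaking is needed). `nearest_target_order` is a hypothetical name used only for this illustration. The two printed arrays should agree."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def nearest_target_order(real, pred):\n",
"    # hypothetical helper: real and pred both have shape [steps, dim]\n",
"    diff = pred[:, None, :] - real[None, :, :]    # [predicted step, target step, dim]\n",
"    dist = 0.5 * (diff ** 2).sum(axis=-1)         # squared distance between every pair\n",
"    return dist.argmin(axis=1)                    # closest target index for each prediction\n",
"\n",
"real_demo = np.eye(5)                             # five distinct one-hot items\n",
"perm = np.random.permutation(5)\n",
"print(perm)\n",
"print(nearest_target_order(real_demo, real_demo[perm]))"
]
},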
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Build graph, initialize everything"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"building graph...\n",
"defining loss...\n",
"computing gradients...\n",
"init variables... \n",
"ready to train..."
]
}
],
"source": [
"sess = tf.InteractiveSession()\n",
"\n",
"llprint(\"building graph...\\n\")\n",
"optimizer = tf.train.RMSPropOptimizer(FLAGS.lr, momentum=FLAGS.momentum)\n",
"dnc = DNC(RNNController, FLAGS)\n",
"\n",
"llprint(\"defining loss...\\n\")\n",
"y_hat, outputs = dnc.get_outputs()\n",
"y_hat = tf.clip_by_value(tf.sigmoid(y_hat), 1e-6, 1. - 1e-6) # avoid infinity\n",
"rlen = (dnc.tsteps-3)/2\n",
"loss = free_recall_loss(dnc.y[:,-rlen:,:], y_hat[:,-rlen:,:], tsteps=rlen)\n",
"\n",
"llprint(\"computing gradients...\\n\")\n",
"gradients = optimizer.compute_gradients(loss)\n",
"for i, (grad, var) in enumerate(gradients):\n",
"    if grad is not None:\n",
"        gradients[i] = (tf.clip_by_value(grad, -10, 10), var)\n",
" \n",
"grad_op = optimizer.apply_gradients(gradients)\n",
"\n",
"llprint(\"init variables... \\n\")\n",
"sess.run(tf.global_variables_initializer())\n",
"llprint(\"ready to train...\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model overview...\n",
"\tvariable \"dnc_scope/basic_lstm_cell/weights:0\" has 73728 parameters\n",
"\tvariable \"dnc_scope/basic_lstm_cell/biases:0\" has 512 parameters\n",
"\tvariable \"W_z:0\" has 6144 parameters\n",
"\tvariable \"W_v:0\" has 768 parameters\n",
"\tvariable \"W_r:0\" has 60 parameters\n",
"total of 81212 parameters\n"
]
}
],
"source": [
"# tf parameter overview\n",
"total_parameters = 0 ; print \"model overview...\"\n",
"for variable in tf.trainable_variables():\n",
"    shape = variable.get_shape()\n",
"    variable_parameters = 1\n",
"    for dim in shape:\n",
"        variable_parameters *= dim.value\n",
"    print '\\tvariable \"{}\" has {} parameters' \\\n",
"        .format(variable.name, variable_parameters)\n",
"    total_parameters += variable_parameters\n",
"print \"total of {} parameters\".format(total_parameters)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loaded model: rnn_models/model.ckpt-30000\n"
]
}
],
"source": [
"global_step = 0\n",
"saver = tf.train.Saver(tf.global_variables())\n",
"load_was_success = True # yes, I'm being optimistic\n",
"try:\n",
"    save_dir = '/'.join(FLAGS.save_path.split('/')[:-1])\n",
"    ckpt = tf.train.get_checkpoint_state(save_dir)\n",
"    load_path = ckpt.model_checkpoint_path\n",
"    saver.restore(sess, load_path)\n",
"except Exception: # e.g. ckpt is None when no checkpoint has been saved yet\n",
"    print \"no saved model to load.\"\n",
"    load_was_success = False\n",
"else:\n",
"    print \"loaded model: {}\".format(load_path)\n",
"    saver = tf.train.Saver(tf.global_variables())\n",
"    global_step = int(load_path.split('-')[-1]) + 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train loop"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"loss_history = []\n",
"for i in xrange(global_step, FLAGS.iterations + 1):\n",
"    llprint(\"\\rIteration {}/{}\".format(i, FLAGS.iterations))\n",
"\n",
"    rlen = np.random.randint(1, FLAGS.length + 1)\n",
"    X, y = next_batch(FLAGS.batch_size, rlen, FLAGS.xlen)\n",
"    tsteps = 2*rlen+3\n",
"\n",
"    fetch = [loss, grad_op]\n",
"    feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}\n",
"\n",
"    step_loss, _ = sess.run(fetch, feed_dict=feed)\n",
"    loss_history.append(step_loss)\n",
"    global_step = i\n",
"\n",
"    if i % FLAGS.print_every == 0:\n",
"        llprint(\"\\n\\tloss: {:03.4f}\\n\".format(np.mean(loss_history)))\n",
"        loss_history = []\n",
"    if i % FLAGS.save_every == 0 and i != 0:\n",
"        llprint(\"\\n\\tSAVING MODEL\\n\")\n",
"        saver.save(sess, FLAGS.save_path, global_step=global_step)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.figure.Figure at 0x11a05afd0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X, y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"tsteps = 2*FLAGS.length+3\n",
"\n",
"feed = {dnc.X: X, dnc.y: y, dnc.tsteps: tsteps}\n",
"fetch = [outputs['y_hat'], outputs['w_w'], outputs['w_r'], outputs['f'], outputs['g_a']]\n",
"[_y_hat, _w_w, _w_r, _f, _g_a] = sess.run(fetch, feed)\n",
"_y_hat = np.clip(_y_hat, 1e-6, 1-1e-6)\n",
"_y = y[0] ; _X = X[0]\n",
"\n",
"fig, ((ax1,ax2),(ax3,ax5),(ax4,ax6),) = plt.subplots(nrows=3, ncols=2)\n",
"plt.rcParams['savefig.facecolor'] = \"0.8\"\n",
"fs = 12 # font size\n",
"fig.set_figwidth(10)\n",
"fig.set_figheight(5)\n",
"\n",
"ax1.imshow(_X.T - _y.T, interpolation='none') ; ax1.set_title('input ($X$) and target ($y$)')\n",
"ax2.imshow(_y_hat[0,-FLAGS.length:,:].T, interpolation='none') ; ax2.set_title('prediction ($\\hat y$)')\n",
"\n",
"ax3.imshow(_w_w[0,:,:].T, interpolation='none') ; ax3.set_title('write weighting ($w_w$)')\n",
"ax4.imshow(_w_r[0,:,:,0].T, interpolation='none') ; ax4.set_title('read weighting ($w_r$)')\n",
"\n",
"ax5.imshow(_f[0,:,:].T, interpolation='none') ; ax5.set_title('free gate ($f$)') ; ax5.set_aspect(3)\n",
"ax6.imshow(_g_a[0,:,:].T, interpolation='none') ; ax6.set_title('allocation gate ($g_a$)') ; ax6.set_aspect(3)\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"FLAGS.length = 12"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"starting free recall trials: \n",
"\ttrial #0\n",
"\ttrial #100\n",
"\ttrial #200\n",
"\ttrial #300\n",
"\ttrial #400\n",
"\ttrial #500\n",
"\ttrial #600\n",
"\ttrial #700\n",
"\ttrial #800\n",
"\ttrial #900\n",
"\ttrial #1000\n",
"\ttrial #1100\n",
"\ttrial #1200\n",
"\ttrial #1300\n",
"\ttrial #1400\n",
"\ttrial #1500\n",
"\ttrial #1600\n",
"\ttrial #1700\n",
"\ttrial #1800\n",
"\ttrial #1900\n",
"\ttrial #2000\n",
"\ttrial #2100\n",
"\ttrial #2200\n",
"\ttrial #2300\n",
"\ttrial #2400\n",
"\ttrial #2500\n",
"\ttrial #2600\n",
"\ttrial #2700\n",
"\ttrial #2800\n",
"\ttrial #2900\n",
"\ttrial #3000\n",
"\ttrial #3100\n",
"\ttrial #3200\n",
"\ttrial #3300\n",
"\ttrial #3400\n",
"\ttrial #3500\n",
"\ttrial #3600\n",
"\ttrial #3700\n",
"\ttrial #3800\n",
"\ttrial #3900\n",
"\ttrial #4000\n",
"\ttrial #4100\n",
"\ttrial #4200\n",
"\ttrial #4300\n",
"\ttrial #4400\n",
"\ttrial #4500\n",
"\ttrial #4600\n",
"\ttrial #4700\n",
"\ttrial #4800\n",
"\ttrial #4900\n"
]
}
],
"source": [
"recall_orders = []\n",
"trials = 5000 ; print \"starting free recall trials: \"\n",
"for i in range(trials):\n",
"    X, y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"    tsteps = 2*FLAGS.length+3\n",
"\n",
"    feed = {dnc.X: X, dnc.tsteps: tsteps}\n",
"    _y_hat = outputs['y_hat'].eval(feed)\n",
"    _y_hat = np.clip(_y_hat, 1e-6, 1-1e-6)\n",
"    _y = y[0] ; _X = X[0]\n",
"    \n",
"    real_y = np.tile(_y[-FLAGS.length:,:], [1,1,1])\n",
"    pred_y = _y_hat[0:1,-FLAGS.length:,:]\n",
"    order = guess_recall_order(real_y, pred_y, FLAGS.length)\n",
"    \n",
"    recall_orders.append(order)\n",
"    if (i%100 == 0): print(\"\\ttrial #{}\".format(i))\n",
"recall_orders = np.stack(recall_orders)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArsAAAHUCAYAAAA6KeCQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAIABJREFUeJzt3XmYXGWd9//3Jwm7EEGQyIiCO25ooo77Mqi4/NRxGTUu\nOOjzOO4anVHncRTUcRdwRXFFR42DjjrgqCjuGy6JwKggyuaObIYdQvL9/XFOQ6WoTndVV6c7p9+v\n6zpXqk7dp863Th+aT991n/ukqpAkSZK6aNFcFyBJkiTNFsOuJEmSOsuwK0mSpM4y7EqSJKmzDLuS\nJEnqLMOuJEmSOsuwK0mSpM4y7EqSJKmzDLuSJEnqLMOupHkpya2SfDXJX5NsSPLoua5pPkryj0k2\nJrlZz7pvJfnGNLffKcm5SVbOXpXTquPm7ec4aJrtNyZ5zWzXNZuSHJjkkiQ3mutapC4z7EoLSJJn\ntCFhYrkiyR+SfCXJC5PcYMA2h7Rt/5Rk+wGvn53k2AHrt0uyKsmJbWC9Ismvkrw7ya2nUe7HgTsA\n/w94OvDTET7yQlDt0r9uul4CXAx8emwVjW6TupM8PMkhm2k7b+93n+TeSb6X5LL2v513Jtmpt01V\nHQ/8BvjXualSWhiWzHUBkra4Al4NnA1sAywDHgi8A3hpkkdX1f8O2O7GwHOBIwa83ybanqrjgbsC\nXwQ+CVwK3BZ4MvB/gesF557ttwfuCby+qo6c/kfTMJIsAV4EHFZVcxocq+qcJDsA63tWPwJ4HvDa\nAZvsAFyzJWobVpK7ACcAvwRWATcF/gW4FfDIvuZHAW9LckhVXbZFC5UWCMOutDB9parW9jx/S5IH\nAv8D/HeS/arqqr5tTgL+JcmRA17r9zFgf+DxVfWF3heSvBp4wxTb37j9d90U7UiyY1VdPlW72ZRk\n+6q6ci5rGNGjgN2Bz8x1IQBVdXXfqgzRdj55I3Ah8ICJAJvkHOADSR5cVSf0tP0v4N3APwBHb+lC\npYXAYQySAKiqbwGvB24OPK3/ZeB1NL3Az93c+yS5B02P3If6g267n/VV9fLNbH8ITa9zAW9vh1Cc\n2b52aPt8vySfSnIh8N2ebW+b5LNJLmiHTfwkyaMG7GNpknck+W2SK5P8OsnLk0warnq2PTvJsUke\n2r7/FcCze15/WpKfJrm8rWN1kpsOeJ+/TfKlJBcmuTTJyUle1PP6nZJ8NMkZ7Wf5U5IPJ9ltqhqH\n8Bjg7Ko6q6+2o9uxpPsmOb6t7w/tHyr9n2PHJIf1HMvTkrxsQLuHJPlukova9z4tyRt6Xt9kzG6S\nj9L06k6Mz92YZENP++uN2U1y1yRfTrKu3ccJSf62r83EUJ57Jzk8yV/az/e5jGHsbJKdgQcD/9HX\nU/tx4DLgib3tq+o84BSan4WkWWDPrqRe/0HTK/VQ4MN9r30X+Abw8iTv20zv7qNpguonRqzhv4CL\naIZVfAr4Es0QCLhuyMRngNNpxjoGIMkdgO8BvwfexHXB4gtJHldV/9222wH4DnAT4P3A74B7t9ss\nA146RX0F3K6t7SjgA8Cv2vd+Fc0fBZ8GPgjsQTNM4NtJ7lpVF7ftHgIcB/yx/Zx/Bvaj+Yr7Xe1+\nHgLsC3ykff0OwD8BtwfuNUWN03VvYO2A9UXTGfIV4Ic0X8E/DHhtksVVdWhP2+OABwAfAk4GDqT5\nWn6vqnoZQJLbt+1OohlCcxXNV/r33kxt7wf2ogmOT2Uzvbw9+/gOzbcBb6YZ4vBPwLeS3L+qftK3\nybtpel8PBfahGW7wHuDaC/XaMbaTDrfpsX7iZwvcieb/rWt6G1TV+iQn0Qzt6bcGw640e6rKxcVl\ngSzAM4ANwPLNtLkI+GnP80PabXYD7gdsBF7c8/pZwLE9z/+rbb/LDOq8ebufl/atP6Rd/x8DtjkB\n+BmwpG/994DTep7/G80FWbfoa/dG4Grgb6ao
7az28z24b/3NaMabvqJv/e3b931l+3wRcCZwBrDz\nZvaz3YB1T2r3fZ8BP9Ob9az7JvCNKT7H4na7tw547aPta0f0rT8OuALYrX3+mPbn8cq+dsfQhM19\n2+cvbt9v12n8zA/qWfduYMMk7TcCr+l5/vm2tpv3rFtGE36/2Xe8NtIM5el9v8Pan9POtelx2DiN\n5Rs92zy+/2fU89p/An8YsP6V7Ta7j/rfjIuLy+SLwxgk9bsU2HnQC1X1XZog9fIk202y/S7tv5fM\nQm3Q9Doe1bsiya7Ag2h6fJcmudHEAnwVuHWSm7TNn0DTS72ur93XaXrk7j+NGs6qTcddQhNyAnym\n733/Avy6rQ9gOU1P4juqatJjVD0952lmtrgR8KN2H8unUeNUdmvf66LNtHlv3/P3ANvR9LZCM1zl\nGppQ2uswmlD/8Pb5X9t/HzudoSLDSrKIpif881V1zsT6qvozTQ/8fbPpTCNF0yPf67s0fwDcvGfd\nW2g+61RL77CNHdp/B33zcWXP670mfga7D/6EkmbCYQyS+t0AOHczrx8KfBt4DvDOAa9PfJ27c8/j\ncTur7/mtaILb64F/H9C+aC56+xNwa5qvms/bTLth9z9RwyKaqaQGve/EBVW3aJ//YnM7aAP8oTS9\nub01FbB0GjVO12ThcyNND3Sv09t/92n/vRnwx7r+LAKntv9OBMf/BJ5FM7TjzUm+DnwO+GxVjWMW\niD2AHXvq669lEbB3T13QDF/pNRE4d51YUVWnAacNWcsV7b+D/hjcvuf1XhM/g3k7lZq0NTPsSrpW\nkr+hCVKDAhvQ9O4m+RZN7+5RA5pMhIM7Ad8fe5GN/sAw8S3V22mmPBvkNz1tv0bTazco6A0KTFPt\nf+J9N9KMbd044PVLB6zbnM/QTL/2VpqxsJe2+zie8VxcfCFNuNp1qoYzVc1MFfdP8iCacckPownx\nX0/y0DEF3mFtmGT9tedEkl0Y3BPb7+qqmgjLf2rf4yYD2t2EZpx2v4mfwfnT2JekIRl2JfU6iCYA\nfWWKdofSDGf4pwGvHUdz4djTmL2w22+iB3J9VU1157AzgBtU1TfHXMMZNCHn7Kqa9I+FnnZ3pLng\n73qS3BD4O+DVVdU7Y8GtxlVsVW1IcgbNRXCDLKLphe79LLdt/53o2T4HOCDJTn29u/v1vN67z2/S\nnDf/nORfaXrhH8Qkx4Hp93SeB1zeU1+v/Wj++OjvyZ2Od9KM8Z3Kt2h+XgA/pxnacTfgsxMNkmwD\n3IWml7vfvsD5VXXBCDVKmoJjdiUBkOTvaC7eOpNmnOOkquo7NEMZXkHf1epVdSJNWP4/Sa53hXmS\nbZO8bVx1t/s8jyZw/FOSZQP22TsW8hjgXkkeOqDd0iSLRyzjczShauAdv3qmDFtLExZfkmSy4QgT\nvY79v6NXMd6vun9IE8om84IBz6/munD6JZpOk/52q2iOxZfh2iEZ/U6mCf2Tjf2GZkaNiR7WSVXV\nRpqx2Y/JprdN3pNmdoXvVtWwPeswwpjdamZlOAF4Wja9Y9pBwE4051+/FTQ/C0mzwJ5daeEJ8Igk\n+9H8DtiTplfqITQh7NE1vQn7X0vTSzfIQTRft/9Xki/SXPx1Gc142SfTXCX/LzP5EAM8n+Yio/9N\n8kGa0L4nzTRdf8N1Uz69jWZ6tC8mOZpm2qedgDsDj6MZj3rhsDuvqjOT/BvwxiT7Al+guUjvFsDf\n01xUd3hVVZLnAscCJ7Xzyf6JZjqz21fVw6vqkiTfoRkqsi3wB5rp4PZhiim4hvTfNKHsVgN6o68C\nHtYeox/RXIz2cOANPT2Qx9GcA29oP/PE1GOPopnJYaIH+DVJ7k9z05JzaH4uzwV+SzNbxmTW0Hze\ndyc5nmZmhkE9o9D8ofZg4PtJjqT5g+HZwLZA/7zOkx3DTdaPOGYX4FU032p8J8kHaMYLvxQ4vqq+\ntskOkz1o
zr3+i/wkjYlhV1p4iutuv3o1TbD7X5r5YI8ecLHR4Dep+naSb9PMXlB9r52f5N40NwV4\nEs3X1dvSfJX8RZq5Zad
2017-03-02 01:49:41 +08:00
"text/plain": [
2017-03-02 07:11:23 +08:00
"<matplotlib.figure.Figure at 0x11bef3d50>"
2017-03-02 01:49:41 +08:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "(base64-encoded PNG of the recall plot omitted)",
"text/plain": [
"<matplotlib.figure.Figure at 0x11df76e50>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def plot_recall(probs, pos):\n",
" # print \"Is normalized: \", np.abs(1-np.sum(position_probs)) < 1e-6\n",
2017-03-02 07:11:23 +08:00
" plt.figure(pos, figsize=[8,5])\n",
" plt.axis((0,4,0.0,1.0))\n",
2017-03-02 01:49:41 +08:00
" plt.title(\"DNC free recall (position={})\".format(pos))\n",
" plt.xlabel(\"Serial position\") ; plt.ylabel(\"Recall probability\")\n",
" plt.plot(range(len(probs)), probs) ; plt.show()\n",
"\n",
2017-03-02 07:11:23 +08:00
"for pos in range(1):\n",
2017-03-02 01:49:41 +08:00
" position_probs = [np.sum(i == recall_orders[:,pos])/float(trials) for i in range(FLAGS.length)]\n",
" plot_recall(position_probs, pos=pos)\n",
" \n",
"p0_probs = np.asarray([np.sum(i == recall_orders[:,0])/float(trials) for i in range(FLAGS.length)])\n",
"p1_probs = np.asarray([np.sum(i == recall_orders[:,1])/float(trials) for i in range(FLAGS.length)])\n",
"plot_recall(.5*(p0_probs + p1_probs), pos=12)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# # X, y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"# # y = np.stack(y)\n",
"# # y = y[:,FLAGS.length + 3:,:]\n",
"# # # p = np.random.permutation(FLAGS.length)\n",
"# # y_shuffle = y #y[:,p,:] ; print 'permute by: ', p\n",
"# # # y_shuffle[:,-1,-1] = 4\n",
"\n",
"# # plt.figure(0, figsize=[1,1])\n",
"# # plt.imshow(y[0,:,:].T, interpolation='none')\n",
"# # plt.figure(1, figsize=[1,1])\n",
"# # plt.imshow(y_shuffle[0,:,:].T, interpolation='none')\n",
"# # plt.show()\n",
"\n",
"# def guess_recall_order(real_y, pred_y, FLAGS):\n",
"#     # sorry this is uuuuugly but we have to because it's batched\n",
"#     real_y = np.tile(real_y, [1,1,1,1]) ; real_y = np.transpose(real_y, (1,0,2,3))\n",
"#     pred_y = np.tile(pred_y,[1,1,1,1]) ; pred_y = np.transpose(pred_y, (1,0,2,3))\n",
"\n",
"#     pred_y = np.tile(pred_y,[1,FLAGS.length,1,1])\n",
"#     pred_y = np.transpose(pred_y, (0,2,1,3))\n",
"\n",
"#     # real_y = real_y[0,:,:,:] ; pred_y = pred_y[0,:,:,:]\n",
"#     y_minus = .5*(real_y - pred_y)**2\n",
"#     y_minus = np.sum(y_minus, axis=-1)\n",
"#     y_mins = np.amin(y_minus, axis=1)\n",
"\n",
"#     k, l = np.where(y_minus == y_mins)\n",
"#     i = 0\n",
"#     while i < len(k):\n",
"#         if k[-i] == k[-i-1]:\n",
"#             k = np.delete(k,-i) ; l = np.delete(l,-i)\n",
"#         i+=1\n",
"#     return l\n",
"\n",
"# X, real_y = next_batch(FLAGS.batch_size, FLAGS.length, FLAGS.xlen)\n",
"# real_y = np.stack(real_y)[:,-FLAGS.length:,:]\n",
"\n",
"# p = np.random.permutation(y_i.shape[0]) ; print 'permute by: ', p\n",
"# pred_y = [y_i[p,:] for y_i in real_y]\n",
"# pred_y = np.stack(pred_y)\n",
"# # pred_y[:,-2:,-1] = 4\n",
"\n",
"# plt.figure(0, figsize=[1,1])\n",
"# plt.imshow(real_y[0,:,:].T, interpolation='none')\n",
"# plt.figure(1, figsize=[1,1])\n",
"# plt.imshow(pred_y[0,:,:].T, interpolation='none')\n",
"# plt.show()\n",
"\n",
"# np_loss(real_y, pred_y, FLAGS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def scatter(idx, vals, target):\n",
" \"\"\"target[idx] += vals, but allowing for repeats in idx\"\"\"\n",
" np.add.at(target, idx.ravel(), vals.ravel())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 1
}