Modify copy task and readme

ixaxaar 2017-12-18 12:38:45 +05:30
parent 264bdfb2f0
commit 60f2026d80
2 changed files with 142 additions and 16 deletions

README.md

@@ -1,4 +1,9 @@
# Differentiable Neural Computers and family, for Pytorch

Includes:

1. Differentiable Neural Computers (DNC)
2. Sparse Access Memory (SAM)
3. Sparse Differentiable Neural Computers (SDNC)

<!-- START doctoc generated TOC please keep comment here to allow auto update -->
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->
@@ -22,7 +27,7 @@

[![Build Status](https://travis-ci.org/ixaxaar/pytorch-dnc.svg?branch=master)](https://travis-ci.org/ixaxaar/pytorch-dnc) [![PyPI version](https://badge.fury.io/py/dnc.svg)](https://badge.fury.io/py/dnc)

This is an implementation of [Differentiable Neural Computers](http://people.idsia.ch/~rupesh/rnnsymposium2016/slides/graves.pdf), described in the paper [Hybrid computing using a neural network with dynamic external memory, Graves et al.](https://www.nature.com/articles/nature20101),
and of Sparse DNCs (SDNCs) and Sparse Access Memory (SAM), described in [Scaling Memory-Augmented Neural Networks with Sparse Reads and Writes](http://papers.nips.cc/paper/6298-scaling-memory-augmented-neural-networks-with-sparse-reads-and-writes.pdf).

## Install
@@ -252,6 +257,107 @@ Memory vectors returned by forward pass (`np.ndarray`):
| `debug_memory['write_weights']` | layer * time | nr_cells |
| `debug_memory['usage']` | layer * time | nr_cells |
### SAM
**Constructor Parameters**:
| Argument | Default | Description |
| --- | --- | --- |
| input_size | `None` | Size of the input vectors |
| hidden_size | `None` | Size of hidden units |
| rnn_type | `'lstm'` | Type of recurrent cells used in the controller |
| num_layers | `1` | Number of layers of recurrent units in the controller |
| num_hidden_layers | `2` | Number of hidden layers per layer of the controller |
| bias | `True` | Bias |
| batch_first | `True` | Whether data is fed batch first |
| dropout | `0` | Dropout between layers in the controller |
| bidirectional | `False` | Whether the controller is bidirectional (not yet implemented) |
| nr_cells | `5000` | Number of memory cells |
| read_heads | `4` | Number of read heads |
| sparse_reads | `4` | Number of sparse memory reads per read head |
| cell_size | `10` | Size of each memory cell |
| nonlinearity | `'tanh'` | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
| gpu_id | `-1` | ID of the GPU, -1 for CPU |
| independent_linears | `False` | Whether to use independent linear units to derive the interface vector |
| share_memory | `True` | Whether to share memory between controller layers |
**Forward Pass Parameters**:
| Argument | Default | Description |
| --- | --- | --- |
| input | - | The input vector `(B*T*X)` or `(T*B*X)` |
| hidden | `(None,None,None)` | Hidden states `(controller hidden, memory hidden, read vectors)` |
| reset_experience | `False` | Whether to reset memory |
| pass_through_memory | `True` | Whether to pass through memory |
#### Example usage:
```python
import torch
from dnc import SAM
rnn = SAM(
input_size=64,
hidden_size=128,
rnn_type='lstm',
num_layers=4,
nr_cells=100,
cell_size=32,
read_heads=4,
sparse_reads=4,
batch_first=True,
gpu_id=0
)
(controller_hidden, memory, read_vectors) = (None, None, None)

output, (controller_hidden, memory, read_vectors) = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
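
As a follow-up, here is a minimal sketch of a single training step built on the example above. The MSE loss, the Adam optimizer, and the random target are illustrative assumptions, not part of the `dnc` package:

```python
import torch
import torch.optim as optim

# illustrative choices (assumptions), not prescribed by this package
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(rnn.parameters(), lr=1e-4)

(controller_hidden, memory, read_vectors) = (None, None, None)
inputs = torch.randn(10, 4, 64)   # batch_first: (B, T, X)
targets = torch.randn(10, 4, 64)  # stand-in target, assuming output has the input's shape as in the copy task

optimizer.zero_grad()
output, (controller_hidden, memory, read_vectors) = \
  rnn(inputs, (controller_hidden, memory, read_vectors), reset_experience=True)
loss = criterion(output, targets)
loss.backward()
optimizer.step()
```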
#### Debugging:
Setting the `debug` option causes the network to also return its memory hidden vectors (numpy `ndarray`s) for the first batch of each forward step.
These vectors can be analyzed or visualized, for example with visdom (see the sketch after the table below).
```python
import torch
from dnc import SAM
rnn = SAM(
input_size=64,
hidden_size=128,
rnn_type='lstm',
num_layers=4,
nr_cells=100,
cell_size=32,
read_heads=4,
batch_first=True,
sparse_reads=4,
gpu_id=0,
debug=True
)
(controller_hidden, memory, read_vectors) = (None, None, None)

output, (controller_hidden, memory, read_vectors), debug_memory = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
Memory vectors returned by forward pass (`np.ndarray`):
| Key | Y axis (dimensions) | X axis (dimensions) |
| --- | --- | --- |
| `debug_memory['memory']` | layer * time | nr_cells * cell_size |
| `debug_memory['visible_memory']` | layer * time | (sparse_reads + 2 * temporal_reads + 1) * nr_cells |
| `debug_memory['read_positions']` | layer * time | sparse_reads + 2 * temporal_reads + 1 |
| `debug_memory['read_weights']` | layer * time | read_heads * nr_cells |
| `debug_memory['write_weights']` | layer * time | nr_cells |
| `debug_memory['usage']` | layer * time | nr_cells |
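
For example, one way to look at these matrices is a visdom heatmap, similar to what the copy task script does. This is a minimal sketch assuming a visdom server is already running (`python -m visdom.server`) and that `debug_memory` comes from the forward pass above:

```python
from visdom import Visdom

viz = Visdom()

# heatmap of the first batch's memory contents;
# rows are layer * time, columns are nr_cells * cell_size
viz.heatmap(
  debug_memory['memory'],
  opts=dict(
    title='Memory contents',
    ylabel='layer * time',
    xlabel='nr_cells * cell_size'
  )
)
```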
## Example copy task

The copy task, as described in the original paper, is included in the repo.

tasks/copy_task.py

@@ -24,6 +24,7 @@ from torch.nn.utils import clip_grad_norm
from dnc.dnc import DNC
from dnc.sdnc import SDNC
from dnc.sam import SAM
from dnc.util import *

parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')
@@ -31,7 +32,7 @@ parser.add_argument('-input_size', type=int, default=6, help='dimension of input
parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory: dnc | sdnc | sam')
parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
@@ -55,6 +56,7 @@ parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='r
parser.add_argument('-iterations', type=int, default=100000, metavar='N', help='total number of iteration')
parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency')
parser.add_argument('-check_freq', type=int, default=100, metavar='N', help='check point frequency')
parser.add_argument('-visdom', action='store_true', help='plot memory content on visdom per -summarize_freq steps')

args = parser.parse_args()
print(args)
@@ -129,7 +131,7 @@ if __name__ == '__main__':
      cell_size=mem_size,
      read_heads=read_heads,
      gpu_id=args.cuda,
      debug=args.visdom,
      batch_first=True,
      independent_linears=True
    )
@@ -147,7 +149,24 @@ if __name__ == '__main__':
      temporal_reads=args.temporal_reads,
      read_heads=args.read_heads,
      gpu_id=args.cuda,
      debug=args.visdom,
      batch_first=True,
      independent_linears=False
    )
  elif args.memory_type == 'sam':
    rnn = SAM(
      input_size=args.input_size,
      hidden_size=args.nhid,
      rnn_type=args.rnn_type,
      num_layers=args.nlayer,
      num_hidden_layers=args.nhlayer,
      dropout=args.dropout,
      nr_cells=mem_slot,
      cell_size=mem_size,
      sparse_reads=args.sparse_reads,
      read_heads=args.read_heads,
      gpu_id=args.cuda,
      debug=args.visdom,
      batch_first=True,
      independent_linears=False
    )
@@ -252,7 +271,7 @@ if __name__ == '__main__':
            xlabel='mem_slot'
          )
        )
      elif args.memory_type == 'sdnc':
        viz.heatmap(
          v['link_matrix'][-1].reshape(args.mem_slot, -1),
          opts=dict(
@@ -275,16 +294,17 @@
          )
        )
      elif args.memory_type == 'sdnc' or args.memory_type == 'dnc':
        viz.heatmap(
          v['precedence'],
          opts=dict(
            xtickstep=10,
            ytickstep=2,
            title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
            ylabel='layer * time',
            xlabel='mem_slot'
          )
        )
      if args.memory_type == 'sdnc':
        viz.heatmap(