Merge pull request #16 from ixaxaar/sparse

SDNC
This commit is contained in:
Russi Chatterjee 2017-12-11 00:59:33 +05:30 committed by GitHub
commit f14f50def0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 1644 additions and 54 deletions

1
.gitignore vendored
View File

@ -19,3 +19,4 @@ __pycache__/
dist/
dnc.egg-info/
tasks/checkpoints/
faiss/

View File

@ -2,10 +2,16 @@ language: python
python:
- "3.6"
# command to install dependencies
before_install:
- sudo apt-get -qq update
- sudo apt-get install -yqq software-properties-common git
- sudo apt-get install -yqq libopenblas-dev liblapack3 python3-numpy python3-dev swig
- sudo ln -s /usr/lib/libopenblas.so /usr/lib/libopenblas.so.3
install:
- pip install http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
- pip install numpy
- pip install visdom
- pip install -qqq http://download.pytorch.org/whl/cu75/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
- pip install -qqq numpy
- pip install -qqq visdom
- pip install -qqq pyflann3
# command to run tests
script:
- pytest
- pytest ./test

138
README.md
View File

@ -1,8 +1,11 @@
# Differentiable Neural Computer, for Pytorch
# Differentiable Neural Computers and Sparse Differentiable Neural Computers, for Pytorch
[![Build Status](https://travis-ci.org/ixaxaar/pytorch-dnc.svg?branch=master)](https://travis-ci.org/ixaxaar/pytorch-dnc) [![PyPI version](https://badge.fury.io/py/dnc.svg)](https://badge.fury.io/py/dnc)
This is an implementation of [Differentiable Neural Computers](http://people.idsia.ch/~rupesh/rnnsymposium2016/slides/graves.pdf), described in the paper [Hybrid computing using a neural network with dynamic external memory, Graves et al.](https://www.nature.com/articles/nature20101)
and the Sparse version of the DNC (the SDNC) described in [Scaling Memory-Augmented Neural Networks with Sparse Reads and Writes](http://papers.nips.cc/paper/6298-scaling-memory-augmented-neural-networks-with-sparse-reads-and-writes.pdf).
## Install
@ -10,7 +13,16 @@ This is an implementation of [Differentiable Neural Computers](http://people.ids
pip install dnc
```
To run the tests in the test directory, `pytest` is needed.
### From source
```
git clone https://github.com/ixaxaar/pytorch-dnc
cd pytorch-dnc
pip install -r ./requirements.txt
pip install -e .
```
`pytest` is required to run the tests.
## Architecture
@ -18,7 +30,11 @@ To run the tests in the test directory, `pytest` is needed.
## Usage
**Parameters**:
### DNC
**Constructor Parameters**:
Following are the constructor parameters:
@ -47,11 +63,11 @@ Following are the forward pass parameters:
| --- | --- | --- |
| input | - | The input vector `(B*T*X)` or `(T*B*X)` |
| hidden | `(None,None,None)` | Hidden states `(controller hidden, memory hidden, read vectors)` |
| reset_experience | `False` | Whether to reset memory (This is a parameter for the forward pass |
| pass_through_memory | `True` | Whether to pass through memory (This is a parameter for the forward pass |
| reset_experience | `False` | Whether to reset memory |
| pass_through_memory | `True` | Whether to pass through memory |
### Example usage:
#### Example usage:
```python
from dnc import DNC
@ -74,7 +90,8 @@ output, (controller_hidden, memory, read_vectors) = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
```
### Debugging:
#### Debugging:
The `debug` option causes the network to return its memory hidden vectors (numpy `ndarray`s) for the first batch of each forward step.
These vectors can be analyzed or visualized using, for example, visdom.
@ -112,6 +129,107 @@ Memory vectors returned by forward pass (`np.ndarray`):
| `debug_memory['write_weights']` | layer * time | nr_cells
| `debug_memory['usage_vector']` | layer * time | nr_cells
### SDNC
**Constructor Parameters**:
Following are the constructor parameters:
| Argument | Default | Description |
| --- | --- | --- |
| input_size | `None` | Size of the input vectors |
| hidden_size | `None` | Size of hidden units |
| rnn_type | `'lstm'` | Type of recurrent cells used in the controller |
| num_layers | `1` | Number of layers of recurrent units in the controller |
| num_hidden_layers | `2` | Number of hidden layers per layer of the controller |
| bias | `True` | Bias |
| batch_first | `True` | Whether data is fed batch first |
| dropout | `0` | Dropout between layers in the controller |
| bidirectional | `False` | If the controller is bidirectional (not yet implemented) |
| nr_cells | `5000` | Number of memory cells |
| read_heads | `4` | Number of read heads |
| sparse_reads | `10` | Number of sparse memory reads per read head |
| cell_size | `10` | Size of each memory cell |
| nonlinearity | `'tanh'` | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
| gpu_id | `-1` | ID of the GPU, -1 for CPU |
| independent_linears | `False` | Whether to use independent linear units to derive the interface vector |
| share_memory | `True` | Whether to share memory between controller layers |
Following are the forward pass parameters:
| Argument | Default | Description |
| --- | --- | --- |
| input | - | The input vector `(B*T*X)` or `(T*B*X)` |
| hidden | `(None,None,None)` | Hidden states `(controller hidden, memory hidden, read vectors)` |
| reset_experience | `False` | Whether to reset memory |
| pass_through_memory | `True` | Whether to pass through memory |
#### Example usage:
```python
from dnc import SDNC
rnn = SDNC(
input_size=64,
hidden_size=128,
rnn_type='lstm',
num_layers=4,
nr_cells=100,
cell_size=32,
read_heads=4,
sparse_reads=4,
batch_first=True,
gpu_id=0
)
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors) = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
```
#### Debugging:
The `debug` option causes the network to return its memory hidden vectors (numpy `ndarray`s) for the first batch of each forward step.
These vectors can be analyzed or visualized using, for example, visdom (see the sketch after the memory-vector table below).
```python
from dnc import SDNC
rnn = SDNC(
input_size=64,
hidden_size=128,
rnn_type='lstm',
num_layers=4,
nr_cells=100,
cell_size=32,
read_heads=4,
batch_first=True,
sparse_reads=4,
gpu_id=0,
debug=True
)
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors), debug_memory = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
```
Memory vectors returned by forward pass (`np.ndarray`):
| Key | Y axis (dimensions) | X axis (dimensions) |
| --- | --- | --- |
| `debug_memory['memory']` | layer * time | nr_cells * cell_size
| `debug_memory['visible_memory']` | layer * time | sparse_reads+1 * nr_cells
| `debug_memory['read_positions']` | layer * time | sparse_reads+1
| `debug_memory['read_weights']` | layer * time | read_heads * nr_cells
| `debug_memory['write_weights']` | layer * time | nr_cells
| `debug_memory['usage']` | layer * time | nr_cells
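As a hedged illustration of the visualization mentioned above (a minimal sketch, assuming a visdom server is running via `python -m visdom.server` and that `debug_memory` is the dict returned by the debug-enabled forward pass):

```python
from visdom import Visdom

viz = Visdom()
# memory contents: rows are layer * time, columns are nr_cells * cell_size
viz.heatmap(
    debug_memory['memory'],
    opts=dict(title='SDNC memory', ylabel='layer * time', xlabel='nr_cells * cell_size')
)
# sparse read weights: rows are layer * time, columns are read_heads * nr_cells
viz.heatmap(
    debug_memory['read_weights'],
    opts=dict(title='SDNC read weights', ylabel='layer * time', xlabel='read_heads * nr_cells')
)
```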
## Example copy task
The copy task, as described in the original paper, is included in the repo.
@ -121,6 +239,12 @@ From the project root:
python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (like original implementation)
python3 ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 32 -batch_size 1000 -optim adam -sequence_max_length 8 # (faster convergence)
For SDNCs:
python3 -B ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -memory_type sdnc -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 100 -mem_size 10 -read_heads 1 -sparse_reads 10 -batch_size 20 -optim adam -sequence_max_length 10
and for curriculum learning for SDNCs:
python3 -B ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -memory_type sdnc -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 100 -mem_size 10 -read_heads 1 -sparse_reads 4 -batch_size 20 -optim adam -sequence_max_length 4 -curriculum_increment 2 -curriculum_freq 10000
```
For the full set of options, see:

View File

@ -1,3 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from .dnc import DNC
from .sdnc import SDNC

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T
@ -12,6 +13,8 @@ from torch.nn.utils.rnn import PackedSequence
from .util import *
from .memory import *
from torch.nn.init import orthogonal, xavier_uniform
class DNC(nn.Module):
@ -113,6 +116,7 @@ class DNC(nn.Module):
# final output layer
self.output = nn.Linear(self.nn_output_size, self.input_size)
orthogonal(self.output.weight)
if self.gpu_id != -1:
[x.cuda(self.gpu_id) for x in self.rnns]
@ -126,7 +130,10 @@ class DNC(nn.Module):
# initialize hidden state of the controller RNN
if chx is None:
chx = [None for x in range(self.num_layers)]
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
xavier_uniform(h)
chx = [ (h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]
# Last read vectors
if last_read is None:

86
dnc/faiss_index.py Normal file
View File

@ -0,0 +1,86 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from faiss import faiss
from faiss.faiss import cast_integer_to_float_ptr as cast_float
from faiss.faiss import cast_integer_to_int_ptr as cast_int
from faiss.faiss import cast_integer_to_long_ptr as cast_long
from .util import *
class FAISSIndex(object):
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
super(FAISSIndex, self).__init__()
self.cell_size = cell_size
self.nr_cells = nr_cells
self.probes = probes
self.K = K
self.num_lists = num_lists
self.gpu_id = gpu_id
res = res if res else faiss.StandardGpuResources()
res.setTempMemoryFraction(0.01)
if self.gpu_id != -1:
res.initializeForDevice(self.gpu_id)
nr_samples = self.nr_cells * 100 * self.cell_size
train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10
# train = T.randn(self.nr_cells * 100, self.cell_size)
self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
self.index.setNumProbes(self.probes)
self.train(train)
def cuda(self, gpu_id):
self.gpu_id = gpu_id
def train(self, train):
train = ensure_gpu(train, -1)
T.cuda.synchronize()
self.index.train_c(self.nr_cells, cast_float(ptr(train)))
T.cuda.synchronize()
def reset(self):
T.cuda.synchronize()
self.index.reset()
T.cuda.synchronize()
def add(self, other, positions=None, last=-1):
other = ensure_gpu(other, self.gpu_id)
T.cuda.synchronize()
if positions is not None:
positions = ensure_gpu(positions, self.gpu_id)
assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
else:
other = other[:last, :]
self.index.add_c(other.size(0), cast_float(ptr(other)))
T.cuda.synchronize()
def search(self, query, k=None):
query = ensure_gpu(query, self.gpu_id)
k = k if k else self.K
(b,n) = query.size()
distances = T.FloatTensor(b, k)
labels = T.LongTensor(b, k)
if self.gpu_id != -1: distances = distances.cuda(self.gpu_id)
if self.gpu_id != -1: labels = labels.cuda(self.gpu_id)
T.cuda.synchronize()
self.index.search_c(
b,
cast_float(ptr(query)),
k,
cast_float(ptr(distances)),
cast_long(ptr(labels))
)
T.cuda.synchronize()
return (distances, (labels-1))
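Below is a hypothetical usage sketch of this wrapper (assumptions: a CUDA device and a GPU build of `faiss` are available; note that in this PR `SparseMemory` uses the FLANN index rather than this one):

```python
import torch as T
from dnc.faiss_index import FAISSIndex

# inner-product IVF index over 1024 cells of width 20, on GPU 0
index = FAISSIndex(cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, gpu_id=0)

memory = T.randn(1024, 20)                  # one row per memory cell (moved to GPU internally)
index.add(memory)

query = T.randn(1, 20)                      # a single read key
distances, positions = index.search(query)  # both (1, K); positions already offset back by 1
index.reset()                               # clear before re-adding the updated memory
```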

53
dnc/flann_index.py Normal file
View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import numpy as np
from pyflann import *
from .util import *
class FLANNIndex(object):
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_kdtrees=32, probes=32, gpu_id=-1):
super(FLANNIndex, self).__init__()
self.cell_size = cell_size
self.nr_cells = nr_cells
self.probes = probes
self.K = K
self.num_kdtrees = num_kdtrees
self.gpu_id = gpu_id
self.index = FLANN()
def add(self, other, positions=None, last=-1):
if isinstance(other, var):
other = other[:last, :].data.cpu().numpy()
elif isinstance(other, T.Tensor):
other = other[:last, :].cpu().numpy()
self.index.build_index(other, algorithm='kdtree', trees=self.num_kdtrees, checks=self.probes)
def search(self, query, k=None):
if isinstance(query, var):
query = query.data.cpu().numpy()
elif isinstance(query, T.Tensor):
query = query.cpu().numpy()
l, d = self.index.nn_index(query, num_neighbors=self.K if k is None else k)
distances = T.from_numpy(d).float()
labels = T.from_numpy(l).long()
if self.gpu_id != -1: distances = distances.cuda(self.gpu_id)
if self.gpu_id != -1: labels = labels.cuda(self.gpu_id)
return (distances, labels)
def reset(self):
self.index.delete_index()
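A minimal sketch of the FLANN-backed index (assumes `pyflann3` is installed; this mirrors `test/test_indexes.py` added below):

```python
import torch as T
from dnc.flann_index import FLANNIndex

index = FLANNIndex(cell_size=20, nr_cells=1024, K=4, num_kdtrees=32, probes=32, gpu_id=-1)

data = T.ones(30, 20)                                 # memory rows to index
index.add(data)                                       # builds a kd-tree index over the rows
distances, labels = index.search(T.ones(1, 20) * 7)   # both (1, K) torch tensors
index.reset()                                         # drop the index before rebuilding
```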

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T

284
dnc/sdnc.py Normal file
View File

@ -0,0 +1,284 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import numpy as np
from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import PackedSequence
from torch.nn.init import orthogonal, xavier_uniform
from .util import *
from .sparse_memory import SparseMemory
class SDNC(nn.Module):
def __init__(
self,
input_size,
hidden_size,
rnn_type='lstm',
num_layers=1,
num_hidden_layers=2,
bias=True,
batch_first=True,
dropout=0,
bidirectional=False,
nr_cells=5000,
sparse_reads=10,
read_heads=4,
cell_size=10,
nonlinearity='tanh',
gpu_id=-1,
independent_linears=False,
share_memory=True,
debug=False,
clip=20
):
super(SDNC, self).__init__()
# todo: separate weights and RNNs for the interface and output vectors
self.input_size = input_size
self.hidden_size = hidden_size
self.rnn_type = rnn_type
self.num_layers = num_layers
self.num_hidden_layers = num_hidden_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = dropout
self.bidirectional = bidirectional
self.nr_cells = nr_cells
self.sparse_reads = sparse_reads
self.read_heads = read_heads
self.cell_size = cell_size
self.nonlinearity = nonlinearity
self.gpu_id = gpu_id
self.independent_linears = independent_linears
self.share_memory = share_memory
self.debug = debug
self.clip = clip
self.w = self.cell_size
self.r = self.read_heads
self.read_vectors_size = self.r * self.w
self.output_size = self.hidden_size
self.nn_input_size = self.input_size + self.read_vectors_size
self.nn_output_size = self.output_size + self.read_vectors_size
self.rnns = []
self.memories = []
for layer in range(self.num_layers):
if self.rnn_type.lower() == 'rnn':
self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
elif self.rnn_type.lower() == 'gru':
self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
if self.rnn_type.lower() == 'lstm':
self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])
# memories for each layer
if not self.share_memory:
self.memories.append(
SparseMemory(
input_size=self.output_size,
mem_size=self.nr_cells,
cell_size=self.w,
sparse_reads=self.sparse_reads,
read_heads=self.read_heads,
gpu_id=self.gpu_id,
mem_gpu_id=self.gpu_id,
independent_linears=self.independent_linears
)
)
setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
# only one memory shared by all layers
if self.share_memory:
self.memories.append(
SparseMemory(
input_size=self.output_size,
mem_size=self.nr_cells,
cell_size=self.w,
sparse_reads=self.sparse_reads,
read_heads=self.read_heads,
gpu_id=self.gpu_id,
mem_gpu_id=self.gpu_id,
independent_linears=self.independent_linears
)
)
setattr(self, 'rnn_layer_memory_shared', self.memories[0])
# final output layer
self.output = nn.Linear(self.nn_output_size, self.input_size)
orthogonal(self.output.weight)
if self.gpu_id != -1:
[x.cuda(self.gpu_id) for x in self.rnns]
[x.cuda(self.gpu_id) for x in self.memories]
def _init_hidden(self, hx, batch_size, reset_experience):
# create empty hidden states if not provided
if hx is None:
hx = (None, None, None)
(chx, mhx, last_read) = hx
# initialize hidden state of the controller RNN
if chx is None:
h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
xavier_uniform(h)
chx = [ (h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]
# Last read vectors
if last_read is None:
last_read = cuda(T.zeros(batch_size, self.w * self.r), gpu_id=self.gpu_id)
# memory states
if mhx is None:
if self.share_memory:
mhx = self.memories[0].reset(batch_size, erase=reset_experience)
else:
mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories]
else:
if self.share_memory:
mhx = self.memories[0].reset(batch_size, mhx, erase=reset_experience)
else:
mhx = [m.reset(batch_size, h, erase=reset_experience) for m, h in zip(self.memories, mhx)]
return chx, mhx, last_read
def _debug(self, mhx, debug_obj):
if not debug_obj:
debug_obj = {
'memory': [],
'visible_memory': [],
'read_weights': [],
'write_weights': [],
'read_vectors': [],
'least_used_mem': [],
'usage': [],
'read_positions': []
}
debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
debug_obj['visible_memory'].append(mhx['visible_memory'][0].data.cpu().numpy())
debug_obj['read_weights'].append(mhx['read_weights'][0].unsqueeze(0).data.cpu().numpy())
debug_obj['write_weights'].append(mhx['write_weights'][0].unsqueeze(0).data.cpu().numpy())
debug_obj['read_vectors'].append(mhx['read_vectors'][0].data.cpu().numpy())
debug_obj['least_used_mem'].append(mhx['least_used_mem'][0].unsqueeze(0).data.cpu().numpy())
debug_obj['usage'].append(mhx['usage'][0].unsqueeze(0).data.cpu().numpy())
debug_obj['read_positions'].append(mhx['read_positions'][0].unsqueeze(0).data.cpu().numpy())
return debug_obj
def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
(chx, mhx) = hx
# pass through the controller layer
input, chx = self.rnns[layer](input.unsqueeze(1), chx)
input = input.squeeze(1)
# clip the controller output
if self.clip != 0:
output = T.clamp(input, -self.clip, self.clip)
else:
output = input
# the interface vector
ξ = output
# pass through memory
if pass_through_memory:
if self.share_memory:
read_vecs, mhx = self.memories[0](ξ, mhx)
else:
read_vecs, mhx = self.memories[layer](ξ, mhx)
# the read vectors
read_vectors = read_vecs.view(-1, self.w * self.r)
else:
read_vectors = None
return output, (chx, mhx, read_vectors)
def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
# handle packed data
is_packed = type(input) is PackedSequence
if is_packed:
input, lengths = pad(input)
max_length = lengths[0]
else:
max_length = input.size(1) if self.batch_first else input.size(0)
lengths = [input.size(1)] * max_length if self.batch_first else [input.size(0)] * max_length
batch_size = input.size(0) if self.batch_first else input.size(1)
if not self.batch_first:
input = input.transpose(0, 1)
# make the data time-first
controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
# concat input with last read (or padding) vectors
inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
# batched forward pass per element / word / etc
if self.debug:
viz = None
outs = [None] * max_length
read_vectors = None
# pass through time
for time in range(max_length):
# pass through layers
for layer in range(self.num_layers):
# this layer's hidden states
chx = controller_hidden[layer]
m = mem_hidden if self.share_memory else mem_hidden[layer]
# pass through controller
outs[time], (chx, m, read_vectors) = \
self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)
# debug memory
if self.debug:
viz = self._debug(m, viz)
# store the memory back (per layer or shared)
if self.share_memory:
mem_hidden = m
else:
mem_hidden[layer] = m
controller_hidden[layer] = chx
if read_vectors is not None:
# the controller output + read vectors go into next layer
outs[time] = T.cat([outs[time], read_vectors], 1)
else:
outs[time] = T.cat([outs[time], last_read], 1)
inputs[time] = outs[time]
if self.debug:
viz = {k: np.array(v) for k, v in viz.items()}
viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items()}
# pass through final output layer
inputs = [self.output(i) for i in inputs]
outputs = T.stack(inputs, 1 if self.batch_first else 0)
if is_packed:
outputs = pack(outputs, lengths)
if self.debug:
return outputs, (controller_hidden, mem_hidden, read_vectors), viz
else:
return outputs, (controller_hidden, mem_hidden, read_vectors)
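A hedged, CPU-only sketch of reusing the returned hidden state across iterations while detaching the memory dict from the graph, mirroring what `tasks/copy_task.py` does further below (all values here are illustrative):

```python
import torch as T
from torch.autograd import Variable as var
from dnc import SDNC

rnn = SDNC(input_size=64, hidden_size=128, num_layers=1, nr_cells=100,
           cell_size=32, read_heads=4, sparse_reads=4, gpu_id=-1)

(chx, mhx, rv) = (None, None, None)
for step in range(3):
    inputs = var(T.randn(10, 4, 64))    # batch-first input: (batch, time, input_size)
    output, (chx, mhx, rv) = rnn(inputs, (None, mhx, None), reset_experience=True)
    # detach memory from the graph so it does not grow across iterations
    mhx = {k: (v.detach() if isinstance(v, var) else v) for k, v in mhx.items()}
```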

278
dnc/sparse_memory.py Normal file
View File

@ -0,0 +1,278 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import numpy as np
import math
from .flann_index import FLANNIndex
from .util import *
import time
class SparseMemory(nn.Module):
def __init__(
self,
input_size,
mem_size=512,
cell_size=32,
independent_linears=True,
read_heads=4,
sparse_reads=10,
num_lists=None,
index_checks=32,
gpu_id=-1,
mem_gpu_id=-1
):
super(SparseMemory, self).__init__()
self.mem_size = mem_size
self.cell_size = cell_size
self.gpu_id = gpu_id
self.mem_gpu_id = mem_gpu_id
self.input_size = input_size
self.independent_linears = independent_linears
self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
self.read_heads = read_heads
self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100)
self.index_checks = index_checks
m = self.mem_size
w = self.cell_size
r = self.read_heads
c = r * self.K + 1
if self.independent_linears:
self.read_query_transform = nn.Linear(self.input_size, w*r)
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
T.nn.init.orthogonal(self.read_query_transform.weight)
T.nn.init.orthogonal(self.write_vector_transform.weight)
T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
T.nn.init.orthogonal(self.write_gate_transform.weight)
else:
self.interface_size = (r * w) + w + c + 1
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
T.nn.init.orthogonal(self.interface_weights.weight)
self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
self.δ = 0.005 # minimum usage
self.timestep = 0
def rebuild_indexes(self, hidden, erase=False):
b = hidden['memory'].size(0)
# if indexes already exist, we reset them
if 'indexes' in hidden:
[x.reset() for x in hidden['indexes']]
else:
# create new indexes
hidden['indexes'] = \
[FLANNIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
# add existing memory into indexes
pos = hidden['read_positions'].squeeze().data.cpu().numpy()
if not erase:
for n, i in enumerate(hidden['indexes']):
i.reset()
i.add(hidden['memory'][n], last=pos[n][-1])
else:
self.timestep = 0
return hidden
def reset(self, batch_size=1, hidden=None, erase=True):
m = self.mem_size
w = self.cell_size
b = batch_size
r = self.read_heads
c = r * self.K + 1
if hidden is None:
hidden = {
# warning can be a huge chunk of contiguous memory
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
}
hidden = self.rebuild_indexes(hidden, erase=True)
else:
hidden['memory'] = hidden['memory'].clone()
hidden['visible_memory'] = hidden['visible_memory'].clone()
hidden['read_weights'] = hidden['read_weights'].clone()
hidden['write_weights'] = hidden['write_weights'].clone()
hidden['read_vectors'] = hidden['read_vectors'].clone()
hidden['least_used_mem'] = hidden['least_used_mem'].clone()
hidden['usage'] = hidden['usage'].clone()
hidden['read_positions'] = hidden['read_positions'].clone()
hidden = self.rebuild_indexes(hidden, erase)
if erase:
hidden['memory'].data.fill_(δ)
hidden['visible_memory'].data.fill_(δ)
hidden['read_weights'].data.fill_(δ)
hidden['write_weights'].data.fill_(δ)
hidden['read_vectors'].data.fill_(δ)
hidden['least_used_mem'].data.fill_(c+1+self.timestep)
hidden['usage'].data.fill_(δ)
hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
return hidden
def write_into_sparse_memory(self, hidden):
visible_memory = hidden['visible_memory']
positions = hidden['read_positions'].squeeze()
(b, m, w) = hidden['memory'].size()
# update memory
hidden['memory'].scatter_(1, positions.unsqueeze(2).expand(b, self.read_heads*self.K+1, w), visible_memory)
# non-differentiable operations
pos = positions.data.cpu().numpy()
for batch in range(b):
# update indexes
hidden['indexes'][batch].reset()
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
hidden['least_used_mem'] = hidden['least_used_mem'] + 1 if self.timestep < self.mem_size else hidden['least_used_mem'] * 0
return hidden
def write(self, interpolation_gate, write_vector, write_gate, hidden):
read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
hidden['usage'], I = self.update_usage(
hidden['read_positions'],
read_weights,
write_weights,
hidden['usage']
)
# either we write to previous read locations
x = interpolation_gate * read_weights
# or to a new location
y = (1 - interpolation_gate) * I
write_weights = write_gate * (x + y)
# store the write weights
hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
hidden = self.write_into_sparse_memory(hidden)
return hidden
def update_usage(self, read_positions, read_weights, write_weights, usage):
(b, _) = read_positions.size()
# usage is timesteps since a non-negligible memory access
# todo store write weights of all mem and gather from that
u = (read_weights + write_weights > self.δ).float()
# usage before write
relevant_usages = usage.gather(1, read_positions)
# indicator of words with minimal memory usage
minusage = T.min(relevant_usages, -1, keepdim=True)[0]
minusage = minusage.expand(relevant_usages.size())
I = (relevant_usages == minusage).float()
# usage after write
relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)
usage.scatter_(1, read_positions, relevant_usages)
return usage, I
def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage):
b = keys.size(0)
read_positions = []
# we search for k cells per read head
for batch in range(b):
distances, positions = indexes[batch].search(keys[batch])
read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
read_positions = T.stack(read_positions, 0)
# add least used mem to read positions
# TODO: explore possibility of reading co-locations or ranges and such
(b, r, k) = read_positions.size()
read_positions = var(read_positions)
read_positions = T.cat([read_positions.view(b, -1), least_used_mem], 1)
# differentiable ops
(b, m, w) = memory.size()
visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))
read_weights = σ(θ(visible_memory, keys), 2)
read_vectors = T.bmm(read_weights, visible_memory)
read_weights = T.prod(read_weights, 1)
return read_vectors, read_positions, read_weights, visible_memory
def read(self, read_query, hidden):
# sparse read
read_vectors, positions, read_weights, visible_memory = \
self.read_from_sparse_memory(
hidden['memory'],
hidden['indexes'],
read_query,
hidden['least_used_mem'],
hidden['usage']
)
hidden['read_positions'] = positions
hidden['read_weights'] = hidden['read_weights'].scatter_(1, positions, read_weights)
hidden['read_vectors'] = read_vectors
hidden['visible_memory'] = visible_memory
return hidden['read_vectors'], hidden
def forward(self, ξ, hidden):
t = time.time()
# ξ = ξ.detach()
m = self.mem_size
w = self.cell_size
r = self.read_heads
c = r * self.K + 1
b = ξ.size()[0]
if self.independent_linears:
# r read keys (b * r * w)
read_query = self.read_query_transform(ξ).view(b, r, w)
# write key (b * 1 * w)
write_vector = self.write_vector_transform(ξ).view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
else:
ξ = self.interface_weights(ξ)
# r read keys (b * r * w)
read_query = ξ[:, :r*w].contiguous().view(b, r, w)
# write key (b * 1 * w)
write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
self.timestep += 1
hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
return self.read(read_query, hidden)
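For reference, a hypothetical sketch of driving `SparseMemory` directly on CPU, the same way `SDNC._layer_forward` does (reset to obtain the hidden dict with its per-batch FLANN indexes, then call it with an interface vector):

```python
import torch as T
from torch.autograd import Variable as var
from dnc.sparse_memory import SparseMemory

mem = SparseMemory(input_size=64, mem_size=512, cell_size=32,
                   read_heads=4, sparse_reads=10, gpu_id=-1)

hidden = mem.reset(batch_size=2)        # fresh hidden dict, one FLANN index per batch entry
ξ = var(T.randn(2, 64))                 # one timestep of controller output
read_vectors, hidden = mem(ξ, hidden)   # write into visible memory, then sparse read
# read_vectors has shape (batch, read_heads, cell_size), i.e. (2, 4, 32) here
```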

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T
@ -98,7 +99,7 @@ def σ(input, axis=1):
def register_nan_checks(model):
def check_grad(module, grad_input, grad_output):
# print(module) you can add this to see that the hook is called
print('hook called for ' + str(type(module)))
# print('hook called for ' + str(type(module)))
if any(np.all(np.isnan(gi.data.cpu().numpy())) for gi in grad_input if gi is not None):
print('NaN gradient in grad_input ' + type(module).__name__)
@ -129,3 +130,31 @@ def check_nan_gradient(name=''):
# assert 0, 'nan gradient'
return tensor
return f
def ptr(tensor):
if T.is_tensor(tensor):
return tensor.storage().data_ptr()
elif hasattr(tensor, 'data'):
return tensor.data.storage().data_ptr()
else:
return tensor
# TODO: EWW change this shit
def ensure_gpu(tensor, gpu_id):
if "cuda" in str(type(tensor)) and gpu_id != -1:
return tensor.cuda(gpu_id)
elif "cuda" in str(type(tensor)):
return tensor.cpu()
elif "Tensor" in str(type(tensor)) and gpu_id != -1:
return tensor.cuda(gpu_id)
elif "Tensor" in str(type(tensor)):
return tensor
elif type(tensor) is np.ndarray:
return cudavec(tensor, gpu_id=gpu_id).data
else:
return tensor
def print_gradient(x, name):
s = "Gradient of " + name + " ----------------------------------"
x.register_hook(lambda y: print(s, y.squeeze()))
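A small, illustrative sketch of the new helpers (CPU only; `ensure_gpu` leaves CPU tensors untouched when `gpu_id=-1`, and `ptr` exposes the storage pointer consumed by the FAISS wrapper):

```python
import torch as T
from dnc.util import ensure_gpu, ptr

x = ensure_gpu(T.randn(2, 3), gpu_id=-1)   # CPU tensor stays on the CPU
address = ptr(x)                           # integer address of x's underlying storage
```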

4
requirements.txt Normal file
View File

@ -0,0 +1,4 @@
pyflann3>=1.8.4.1
torch>=0.2.0.post1
numpy>=1.13.3
pytest>=3.0.0

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""A setuptools based setup module.
See:
@ -54,9 +55,9 @@ setup(
keywords='differentiable neural computer dnc memory network',
packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tasks']),
packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tasks', 'scripts']),
install_requires=['torch', 'numpy'],
install_requires=['torch', 'numpy', 'pyflann3'],
extras_require={
'dev': ['check-manifest'],

139
tasks/copy_task.py Normal file → Executable file
View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import warnings
warnings.filterwarnings('ignore')
@ -22,12 +23,15 @@ import torch.optim as optim
from torch.nn.utils import clip_grad_norm
from dnc.dnc import DNC
from dnc.sdnc import SDNC
from dnc.util import *
parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')
parser.add_argument('-input_size', type=int, default=6, help='dimension of input feature')
parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory')
parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
@ -36,11 +40,14 @@ parser.add_argument('-optim', type=str, default='adam', help='learning rule, sup
parser.add_argument('-clip', type=float, default=50, help='gradient clipping')
parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
parser.add_argument('-mem_size', type=int, default=16, help='memory dimension')
parser.add_argument('-mem_size', type=int, default=20, help='memory dimension')
parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')
parser.add_argument('-sparse_reads', type=int, default=10, help='number of sparse reads per read head')
parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
parser.add_argument('-curriculum_increment', type=int, default=0, metavar='N', help='amount by which sequence_max_length is incremented every curriculum_freq iterations')
parser.add_argument('-curriculum_freq', type=int, default=1000, metavar='N', help='number of iterations between curriculum increments of sequence_max_length')
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval')
@ -109,22 +116,44 @@ if __name__ == '__main__':
mem_size = args.mem_size
read_heads = args.read_heads
rnn = DNC(
input_size=args.input_size,
hidden_size=args.nhid,
rnn_type=args.rnn_type,
num_layers=args.nlayer,
num_hidden_layers=args.nhlayer,
dropout=args.dropout,
nr_cells=mem_slot,
cell_size=mem_size,
read_heads=read_heads,
gpu_id=args.cuda,
debug=True,
batch_first=True,
independent_linears=True
)
if args.memory_type == 'dnc':
rnn = DNC(
input_size=args.input_size,
hidden_size=args.nhid,
rnn_type=args.rnn_type,
num_layers=args.nlayer,
num_hidden_layers=args.nhlayer,
dropout=args.dropout,
nr_cells=mem_slot,
cell_size=mem_size,
read_heads=read_heads,
gpu_id=args.cuda,
debug=True,
batch_first=True,
independent_linears=True
)
elif args.memory_type == 'sdnc':
rnn = SDNC(
input_size=args.input_size,
hidden_size=args.nhid,
rnn_type=args.rnn_type,
num_layers=args.nlayer,
num_hidden_layers=args.nhlayer,
dropout=args.dropout,
nr_cells=mem_slot,
cell_size=mem_size,
sparse_reads=args.sparse_reads,
read_heads=args.read_heads,
gpu_id=args.cuda,
debug=True,
batch_first=True,
independent_linears=False
)
else:
raise Exception('Not recognized type of memory')
print(rnn)
# register_nan_checks(rnn)
if args.cuda != -1:
rnn = rnn.cuda(args.cuda)
@ -147,6 +176,7 @@ if __name__ == '__main__':
optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
(chx, mhx, rv) = (None, None, None)
for epoch in range(iterations + 1):
llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
optimizer.zero_grad()
@ -156,9 +186,9 @@ if __name__ == '__main__':
input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
if rnn.debug:
output, (chx, mhx, rv), v = rnn(input_data, None, pass_through_memory=True)
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
else:
output, (chx, mhx, rv) = rnn(input_data, None, pass_through_memory=True)
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
loss = criterion((output), target_output)
@ -170,6 +200,10 @@ if __name__ == '__main__':
summarize = (epoch % summarize_freq == 0)
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
increment_curriculum = (epoch != 0) and (epoch % args.curriculum_freq == 0)
# detach memory from graph
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
last_save_losses.append(loss_value)
@ -181,6 +215,16 @@ if __name__ == '__main__':
# print('2222222222222222222222222222222222222222222222')
# print(F.relu6(output))
llprint("\n\tAvg. Logistic Loss: %.4f\n" % (loss))
if np.isnan(loss):
raise Exception('nan Loss')
if summarize and rnn.debug:
loss = np.mean(last_save_losses)
# print(input_data)
# print("1111111111111111111111111111111111111111111111")
# print(target_output)
# print('2222222222222222222222222222222222222222222222')
# print(F.relu6(output))
last_save_losses = []
viz.heatmap(
@ -194,27 +238,40 @@ if __name__ == '__main__':
)
)
viz.heatmap(
v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
opts=dict(
xtickstep=10,
ytickstep=2,
title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='mem_slot',
xlabel='mem_slot'
)
)
if args.memory_type == 'dnc':
viz.heatmap(
v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
opts=dict(
xtickstep=10,
ytickstep=2,
title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='mem_slot',
xlabel='mem_slot'
)
)
viz.heatmap(
v['precedence'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot'
)
)
viz.heatmap(
v['precedence'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot'
)
)
if args.memory_type == 'sdnc':
viz.heatmap(
v['read_positions'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Read Positions, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot'
)
)
viz.heatmap(
v['read_weights'],
@ -239,7 +296,7 @@ if __name__ == '__main__':
)
viz.heatmap(
v['usage_vector'],
v['usage_vector'] if args.memory_type == 'dnc' else v['usage'],
opts=dict(
xtickstep=10,
ytickstep=2,
@ -249,6 +306,10 @@ if __name__ == '__main__':
)
)
if increment_curriculum:
sequence_max_length = sequence_max_length + args.curriculum_increment
print("Increasing max length to " + str(sequence_max_length))
if take_checkpoint:
llprint("\nSaving Checkpoint ... "),
check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pytest
import numpy as np

46
test/test_indexes.py Normal file
View File

@ -0,0 +1,46 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pytest
import numpy as np
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim
import numpy as np
import sys
import os
import math
import time
import functools
sys.path.insert(0, '.')
from pyflann import *
from dnc.flann_index import FLANNIndex
def test_indexes():
n = 30
cell_size=20
nr_cells=1024
K=10
probes=32
d = T.ones(n, cell_size)
q = T.ones(1, cell_size)
for gpu_id in (-1, -1):
i = FLANNIndex(cell_size=cell_size, nr_cells=nr_cells, K=K, probes=probes, gpu_id=gpu_id)
d = d if gpu_id == -1 else d.cuda(gpu_id)
i.add(d)
dist, labels = i.search(q*7)
assert dist.size() == T.Size([1,K])
assert labels.size() == T.Size([1, K])

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pytest
import numpy as np

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pytest
import numpy as np

201
test/test_sdnc_gru.py Normal file
View File

@ -0,0 +1,201 @@
# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-
import pytest
import numpy as np
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim
import numpy as np
import sys
import os
import math
import time
import functools
sys.path.insert(0, '.')
from dnc import SDNC
from test_utils import generate_data, criterion
def test_rnn_1():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'gru'
num_layers = 1
num_hidden_layers = 1
dropout = 0
nr_cells = 100
cell_size = 10
read_heads = 1
sparse_reads = 2
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 10
length = 10
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([21, 10, 100])
assert chx[0][0].size() == T.Size([10,100])
# assert mhx['memory'].size() == T.Size([10,1,1])
assert rv.size() == T.Size([10, 10])
def test_rnn_n():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'gru'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 200
cell_size = 17
read_heads = 2
sparse_reads = 4
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0].size() == T.Size([num_hidden_layers,10,100])
# assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 34])
def test_rnn_no_memory_pass():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'gru'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 5000
cell_size = 17
sparse_reads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)
output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0].size() == T.Size([num_hidden_layers,10,100])
# assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None

201
test/test_sdnc_lstm.py Normal file
View File

@ -0,0 +1,201 @@
# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-
import pytest
import numpy as np
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim
import numpy as np
import sys
import os
import math
import time
import functools
sys.path.insert(0, '.')
from dnc import SDNC
from test_utils import generate_data, criterion
def test_rnn_1():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'lstm'
num_layers = 1
num_hidden_layers = 1
dropout = 0
nr_cells = 100
cell_size = 10
read_heads = 1
sparse_reads = 2
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 10
length = 10
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([21, 10, 100])
assert chx[0][0][0].size() == T.Size([10,100])
# assert mhx['memory'].size() == T.Size([10,1,1])
assert rv.size() == T.Size([10, 10])
def test_rnn_n():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'lstm'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 200
cell_size = 17
read_heads = 2
sparse_reads = 4
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
# assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 34])
def test_rnn_no_memory_pass():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'lstm'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 5000
cell_size = 17
sparse_reads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)
output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
# assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None

201
test/test_sdnc_rnn.py Normal file
View File

@ -0,0 +1,201 @@
# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-
import pytest
import numpy as np
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim
import numpy as np
import sys
import os
import math
import time
import functools
sys.path.insert(0, '.')
from dnc import SDNC
from test_utils import generate_data, criterion
def test_rnn_1():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'rnn'
num_layers = 1
num_hidden_layers = 1
dropout = 0
nr_cells = 100
cell_size = 10
read_heads = 1
sparse_reads = 2
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 10
length = 10
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([21, 10, 100])
assert chx[0][0].size() == T.Size([10,100])
# assert mhx['memory'].size() == T.Size([10,1,1])
assert rv.size() == T.Size([10, 10])
def test_rnn_n():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'rnn'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 200
cell_size = 17
read_heads = 2
sparse_reads = 4
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None)
output = output.transpose(0, 1)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0].size() == T.Size([num_hidden_layers,10,100])
# assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 34])
def test_rnn_no_memory_pass():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'rnn'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 5000
cell_size = 17
sparse_reads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = SDNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
sparse_reads=sparse_reads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)
output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0].size() == T.Size([num_hidden_layers,10,100])
# assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T