commit 08e13761dd

README.md (147 changed lines)

@@ -1,28 +1,36 @@
-# Differentiable Neural Computers and Sparse Differentiable Neural Computers, for Pytorch
+# Differentiable Neural Computers and family, for Pytorch
+
+Includes:
+
+1. Differentiable Neural Computers (DNC)
+2. Sparse Access Memory (SAM)
+3. Sparse Differentiable Neural Computers (SDNC)

<!-- START doctoc generated TOC please keep comment here to allow auto update -->
<!-- DON'T EDIT THIS SECTION, INSTEAD RE-RUN doctoc TO UPDATE -->

-- [Differentiable Neural Computers and Sparse Differentiable Neural Computers, for Pytorch](#differentiable-neural-computers-and-sparse-differentiable-neural-computers-for-pytorch)
-  - [Install](#install)
-    - [From source](#from-source)
-  - [Architecture](#architecure)
-  - [Usage](#usage)
-    - [DNC](#dnc)
-      - [Example usage:](#example-usage)
-      - [Debugging:](#debugging)
-    - [SDNC](#sdnc)
-      - [Example usage:](#example-usage-1)
-      - [Debugging:](#debugging-1)
-  - [Example copy task](#example-copy-task)
-  - [General noteworthy stuff](#general-noteworthy-stuff)
+- [Install](#install)
+  - [From source](#from-source)
+- [Architecture](#architecure)
+- [Usage](#usage)
+  - [DNC](#dnc)
+    - [Example usage](#example-usage)
+    - [Debugging](#debugging)
+  - [SDNC](#sdnc)
+    - [Example usage](#example-usage-1)
+    - [Debugging](#debugging-1)
+  - [SAM](#sam)
+    - [Example usage](#example-usage-2)
+    - [Debugging](#debugging-2)
+- [Example copy task](#example-copy-task)
+- [General noteworthy stuff](#general-noteworthy-stuff)

<!-- END doctoc generated TOC please keep comment here to allow auto update -->

[![Build Status](https://travis-ci.org/ixaxaar/pytorch-dnc.svg?branch=master)](https://travis-ci.org/ixaxaar/pytorch-dnc) [![PyPI version](https://badge.fury.io/py/dnc.svg)](https://badge.fury.io/py/dnc)

This is an implementation of [Differentiable Neural Computers](http://people.idsia.ch/~rupesh/rnnsymposium2016/slides/graves.pdf), described in the paper [Hybrid computing using a neural network with dynamic external memory, Graves et al.](https://www.nature.com/articles/nature20101)
-and the Sparse version of the DNC (the SDNC) described in [Scaling Memory-Augmented Neural Networks with Sparse Reads and Writes](http://papers.nips.cc/paper/6298-scaling-memory-augmented-neural-networks-with-sparse-reads-and-writes.pdf).
+and Sparse DNCs (SDNCs) and Sparse Access Memory (SAM) described in [Scaling Memory-Augmented Neural Networks with Sparse Reads and Writes](http://papers.nips.cc/paper/6298-scaling-memory-augmented-neural-networks-with-sparse-reads-and-writes.pdf).

## Install
@@ -84,7 +92,7 @@ Following are the forward pass parameters:

| pass_through_memory | `True` | Whether to pass through memory |


-#### Example usage:
+#### Example usage

```python
from dnc import DNC

@@ -108,7 +116,7 @@ output, (controller_hidden, memory, read_vectors) = \
```


-#### Debugging:
+#### Debugging

The `debug` option causes the network to return its memory hidden vectors (numpy `ndarray`s) for the first batch each forward step.
These vectors can be analyzed or visualized, using visdom for example.
@@ -184,7 +192,7 @@ Following are the forward pass parameters:

| pass_through_memory | `True` | Whether to pass through memory |


-#### Example usage:
+#### Example usage

```python
from dnc import SDNC

@@ -209,7 +217,7 @@ output, (controller_hidden, memory, read_vectors) = \
```


-#### Debugging:
+#### Debugging

The `debug` option causes the network to return its memory hidden vectors (numpy `ndarray`s) for the first batch each forward step.
These vectors can be analyzed or visualized, using visdom for example.
@@ -252,6 +260,107 @@ Memory vectors returned by forward pass (`np.ndarray`):

| `debug_memory['write_weights']` | layer * time | nr_cells
| `debug_memory['usage']` | layer * time | nr_cells

### SAM

**Constructor Parameters**:

Following are the constructor parameters:

| Argument | Default | Description |
| --- | --- | --- |
| input_size | `None` | Size of the input vectors |
| hidden_size | `None` | Size of hidden units |
| rnn_type | `'lstm'` | Type of recurrent cells used in the controller |
| num_layers | `1` | Number of layers of recurrent units in the controller |
| num_hidden_layers | `2` | Number of hidden layers per layer of the controller |
| bias | `True` | Bias |
| batch_first | `True` | Whether data is fed batch first |
| dropout | `0` | Dropout between layers in the controller |
| bidirectional | `False` | If the controller is bidirectional (not yet implemented) |
| nr_cells | `5000` | Number of memory cells |
| read_heads | `4` | Number of read heads |
| sparse_reads | `4` | Number of sparse memory reads per read head |
| cell_size | `10` | Size of each memory cell |
| nonlinearity | `'tanh'` | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
| gpu_id | `-1` | ID of the GPU, -1 for CPU |
| independent_linears | `False` | Whether to use independent linear units to derive the interface vector |
| share_memory | `True` | Whether to share memory between controller layers |

Following are the forward pass parameters:

| Argument | Default | Description |
| --- | --- | --- |
| input | - | The input vector `(B*T*X)` or `(T*B*X)` |
| hidden | `(None,None,None)` | Hidden states `(controller hidden, memory hidden, read vectors)` |
| reset_experience | `False` | Whether to reset memory |
| pass_through_memory | `True` | Whether to pass through memory |


#### Example usage

```python
import torch
from dnc import SAM

rnn = SAM(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=4,
  nr_cells=100,
  cell_size=32,
  read_heads=4,
  sparse_reads=4,
  batch_first=True,
  gpu_id=0
)

(controller_hidden, memory, read_vectors) = (None, None, None)

output, (controller_hidden, memory, read_vectors) = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
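The returned `(controller_hidden, memory, read_vectors)` tuple can be passed back into the next forward call to carry the controller and memory state across sequences; passing `reset_experience=True` starts from a blank memory instead (the tests and the copy task below follow this pattern).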


#### Debugging

The `debug` option causes the network to return its memory hidden vectors (numpy `ndarray`s) for the first batch each forward step.
These vectors can be analyzed or visualized, using visdom for example.

```python
import torch
from dnc import SAM

rnn = SAM(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=4,
  nr_cells=100,
  cell_size=32,
  read_heads=4,
  batch_first=True,
  sparse_reads=4,
  gpu_id=0,
  debug=True
)

(controller_hidden, memory, read_vectors) = (None, None, None)

output, (controller_hidden, memory, read_vectors), debug_memory = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```

Memory vectors returned by forward pass (`np.ndarray`):

| Key | Y axis (dimensions) | X axis (dimensions) |
| --- | --- | --- |
| `debug_memory['memory']` | layer * time | nr_cells * cell_size
| `debug_memory['visible_memory']` | layer * time | (read_heads*sparse_reads+1) * nr_cells
| `debug_memory['read_positions']` | layer * time | read_heads*sparse_reads+1
| `debug_memory['read_weights']` | layer * time | read_heads * nr_cells
| `debug_memory['write_weights']` | layer * time | nr_cells
| `debug_memory['usage']` | layer * time | nr_cells
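
For instance, these arrays can be plotted as heatmaps (a minimal sketch, assuming a running visdom server and the `debug_memory` dict returned by the debug forward pass above):

```python
from visdom import Visdom

viz = Visdom()
# every debug key is a 2D array with one row per (layer * time) step
viz.heatmap(
  debug_memory['memory'],
  opts=dict(title='SAM memory', ylabel='layer * time', xlabel='nr_cells * cell_size')
)
```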


## Example copy task

The copy task, as described in the original paper, is included in the repo.
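It can be run against any of the three memory types, for instance `python ./tasks/copy_task.py -memory_type sam -visdom` (the `-memory_type` and `-visdom` flags are defined in the copy task script below; `dnc` and `sdnc` are the other accepted memory types).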

dnc/__init__.py

@@ -3,5 +3,7 @@

 from .dnc import DNC
 from .sdnc import SDNC
-from .dnc import Memory
-from .sdnc import SparseMemory
+from .sam import SAM
+from .memory import Memory
+from .sparse_memory import SparseMemory
+from .sparse_temporal_memory import SparseTemporalMemory

dnc/dnc.py

@@ -65,7 +65,6 @@ class DNC(nn.Module):
     self.r = self.read_heads

     self.read_vectors_size = self.r * self.w
-    self.interface_size = self.read_vectors_size + (3 * self.w) + (5 * self.r) + 3
     self.output_size = self.hidden_size

     self.nn_input_size = self.input_size + self.read_vectors_size

dnc/sam.py (new file, 125 lines)

@@ -0,0 +1,125 @@

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import numpy as np

from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import PackedSequence
from torch.nn.init import orthogonal, xavier_uniform

from .util import *
from .sparse_memory import SparseMemory

from .dnc import DNC


class SAM(DNC):

  def __init__(
      self,
      input_size,
      hidden_size,
      rnn_type='lstm',
      num_layers=1,
      num_hidden_layers=2,
      bias=True,
      batch_first=True,
      dropout=0,
      bidirectional=False,
      nr_cells=5000,
      sparse_reads=4,
      read_heads=4,
      cell_size=10,
      nonlinearity='tanh',
      gpu_id=-1,
      independent_linears=False,
      share_memory=True,
      debug=False,
      clip=20
  ):

    super(SAM, self).__init__(
        input_size=input_size,
        hidden_size=hidden_size,
        rnn_type=rnn_type,
        num_layers=num_layers,
        num_hidden_layers=num_hidden_layers,
        bias=bias,
        batch_first=batch_first,
        dropout=dropout,
        bidirectional=bidirectional,
        nr_cells=nr_cells,
        read_heads=read_heads,
        cell_size=cell_size,
        nonlinearity=nonlinearity,
        gpu_id=gpu_id,
        independent_linears=independent_linears,
        share_memory=share_memory,
        debug=debug,
        clip=clip
    )
    self.sparse_reads = sparse_reads

    # override the DNC's dense memories with sparse access memory
    self.memories = []

    for layer in range(self.num_layers):
      # memories for each layer
      if not self.share_memory:
        self.memories.append(
            SparseMemory(
                input_size=self.output_size,
                mem_size=self.nr_cells,
                cell_size=self.w,
                sparse_reads=self.sparse_reads,
                read_heads=self.read_heads,
                gpu_id=self.gpu_id,
                mem_gpu_id=self.gpu_id,
                independent_linears=self.independent_linears
            )
        )
        setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])

    # only one memory shared by all layers
    if self.share_memory:
      self.memories.append(
          SparseMemory(
              input_size=self.output_size,
              mem_size=self.nr_cells,
              cell_size=self.w,
              sparse_reads=self.sparse_reads,
              read_heads=self.read_heads,
              gpu_id=self.gpu_id,
              mem_gpu_id=self.gpu_id,
              independent_linears=self.independent_linears
          )
      )
      setattr(self, 'rnn_layer_memory_shared', self.memories[0])

  def _debug(self, mhx, debug_obj):
    if not debug_obj:
      debug_obj = {
          'memory': [],
          'visible_memory': [],
          'read_weights': [],
          'write_weights': [],
          'read_vectors': [],
          'least_used_mem': [],
          'usage': [],
          'read_positions': []
      }

    debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
    debug_obj['visible_memory'].append(mhx['visible_memory'][0].data.cpu().numpy())
    debug_obj['read_weights'].append(mhx['read_weights'][0].unsqueeze(0).data.cpu().numpy())
    debug_obj['write_weights'].append(mhx['write_weights'][0].unsqueeze(0).data.cpu().numpy())
    debug_obj['read_vectors'].append(mhx['read_vectors'][0].data.cpu().numpy())
    debug_obj['least_used_mem'].append(mhx['least_used_mem'][0].unsqueeze(0).data.cpu().numpy())
    debug_obj['usage'].append(mhx['usage'][0].unsqueeze(0).data.cpu().numpy())
    debug_obj['read_positions'].append(mhx['read_positions'][0].unsqueeze(0).data.cpu().numpy())

    return debug_obj

dnc/sdnc.py (210 changed lines)

@@ -12,10 +12,11 @@ from torch.nn.utils.rnn import PackedSequence
 from torch.nn.init import orthogonal, xavier_uniform

 from .util import *
-from .sparse_memory import SparseMemory
+from .sparse_temporal_memory import SparseTemporalMemory
+from .dnc import DNC


-class SDNC(nn.Module):
+class SDNC(DNC):

   def __init__(
       self,
@@ -40,58 +41,37 @@ class SDNC(nn.Module):
       debug=False,
       clip=20
   ):
-    super(SDNC, self).__init__()
+    # todo: separate weights and RNNs for the interface and output vectors
+    super(SDNC, self).__init__(
+        input_size=input_size,
+        hidden_size=hidden_size,
+        rnn_type=rnn_type,
+        num_layers=num_layers,
+        num_hidden_layers=num_hidden_layers,
+        bias=bias,
+        batch_first=batch_first,
+        dropout=dropout,
+        bidirectional=bidirectional,
+        nr_cells=nr_cells,
+        read_heads=read_heads,
+        cell_size=cell_size,
+        nonlinearity=nonlinearity,
+        gpu_id=gpu_id,
+        independent_linears=independent_linears,
+        share_memory=share_memory,
+        debug=debug,
+        clip=clip
+    )

-    self.input_size = input_size
-    self.hidden_size = hidden_size
-    self.rnn_type = rnn_type
-    self.num_layers = num_layers
-    self.num_hidden_layers = num_hidden_layers
-    self.bias = bias
-    self.batch_first = batch_first
-    self.dropout = dropout
-    self.bidirectional = bidirectional
-    self.nr_cells = nr_cells
     self.sparse_reads = sparse_reads
     self.temporal_reads = temporal_reads
-    self.read_heads = read_heads
-    self.cell_size = cell_size
-    self.nonlinearity = nonlinearity
-    self.gpu_id = gpu_id
-    self.independent_linears = independent_linears
-    self.share_memory = share_memory
-    self.debug = debug
-    self.clip = clip
-
-    self.w = self.cell_size
-    self.r = self.read_heads
-
-    self.read_vectors_size = self.r * self.w
-    self.output_size = self.hidden_size
-
-    self.nn_input_size = self.input_size + self.read_vectors_size
-    self.nn_output_size = self.output_size + self.read_vectors_size

-    self.rnns = []
     self.memories = []

     for layer in range(self.num_layers):
-      if self.rnn_type.lower() == 'rnn':
-        self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
-                                bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
-      elif self.rnn_type.lower() == 'gru':
-        self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
-                                self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
-      if self.rnn_type.lower() == 'lstm':
-        self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
-                                 self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
-      setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])
-
       # memories for each layer
       if not self.share_memory:
         self.memories.append(
-            SparseMemory(
+            SparseTemporalMemory(
                 input_size=self.output_size,
                 mem_size=self.nr_cells,
                 cell_size=self.w,
@@ -108,7 +88,7 @@ class SDNC(nn.Module):
     # only one memory shared by all layers
     if self.share_memory:
       self.memories.append(
-          SparseMemory(
+          SparseTemporalMemory(
              input_size=self.output_size,
              mem_size=self.nr_cells,
              cell_size=self.w,
@@ -122,45 +102,6 @@ class SDNC(nn.Module):
       )
       setattr(self, 'rnn_layer_memory_shared', self.memories[0])

-    # final output layer
-    self.output = nn.Linear(self.nn_output_size, self.input_size)
-    orthogonal(self.output.weight)
-
-    if self.gpu_id != -1:
-      [x.cuda(self.gpu_id) for x in self.rnns]
-      [x.cuda(self.gpu_id) for x in self.memories]
-
-  def _init_hidden(self, hx, batch_size, reset_experience):
-    # create empty hidden states if not provided
-    if hx is None:
-      hx = (None, None, None)
-    (chx, mhx, last_read) = hx
-
-    # initialize hidden state of the controller RNN
-    if chx is None:
-      h = cuda(T.zeros(self.num_hidden_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
-      xavier_uniform(h)
-
-      chx = [(h, h) if self.rnn_type.lower() == 'lstm' else h for x in range(self.num_layers)]
-
-    # Last read vectors
-    if last_read is None:
-      last_read = cuda(T.zeros(batch_size, self.w * self.r), gpu_id=self.gpu_id)
-
-    # memory states
-    if mhx is None:
-      if self.share_memory:
-        mhx = self.memories[0].reset(batch_size, erase=reset_experience)
-      else:
-        mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories]
-    else:
-      if self.share_memory:
-        mhx = self.memories[0].reset(batch_size, mhx, erase=reset_experience)
-      else:
-        mhx = [m.reset(batch_size, h, erase=reset_experience) for m, h in zip(self.memories, mhx)]
-
-    return chx, mhx, last_read
-
   def _debug(self, mhx, debug_obj):
     if not debug_obj:
       debug_obj = {
@@ -191,104 +132,3 @@ class SDNC(nn.Module):

     return debug_obj
-
-  def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
-    (chx, mhx) = hx
-
-    # pass through the controller layer
-    input, chx = self.rnns[layer](input.unsqueeze(1), chx)
-    input = input.squeeze(1)
-
-    # clip the controller output
-    if self.clip != 0:
-      output = T.clamp(input, -self.clip, self.clip)
-    else:
-      output = input
-
-    # the interface vector
-    ξ = output
-
-    # pass through memory
-    if pass_through_memory:
-      if self.share_memory:
-        read_vecs, mhx = self.memories[0](ξ, mhx)
-      else:
-        read_vecs, mhx = self.memories[layer](ξ, mhx)
-      # the read vectors
-      read_vectors = read_vecs.view(-1, self.w * self.r)
-    else:
-      read_vectors = None
-
-    return output, (chx, mhx, read_vectors)
-
-  def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
-    # handle packed data
-    is_packed = type(input) is PackedSequence
-    if is_packed:
-      input, lengths = pad(input)
-      max_length = lengths[0]
-    else:
-      max_length = input.size(1) if self.batch_first else input.size(0)
-      lengths = [input.size(1)] * max_length if self.batch_first else [input.size(0)] * max_length
-
-    batch_size = input.size(0) if self.batch_first else input.size(1)
-
-    # make the data time-first
-    if not self.batch_first:
-      input = input.transpose(0, 1)
-
-    controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
-
-    # concat input with last read (or padding) vectors
-    inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
-
-    # batched forward pass per element / word / etc
-    if self.debug:
-      viz = None
-
-    outs = [None] * max_length
-    read_vectors = None
-
-    # pass through time
-    for time in range(max_length):
-      # pass through layers
-      for layer in range(self.num_layers):
-        # this layer's hidden states
-        chx = controller_hidden[layer]
-        m = mem_hidden if self.share_memory else mem_hidden[layer]
-        # pass through controller
-        outs[time], (chx, m, read_vectors) = \
-            self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)
-
-        # debug memory
-        if self.debug:
-          viz = self._debug(m, viz)
-
-        # store the memory back (per layer or shared)
-        if self.share_memory:
-          mem_hidden = m
-        else:
-          mem_hidden[layer] = m
-        controller_hidden[layer] = chx
-
-        if read_vectors is not None:
-          # the controller output + read vectors go into next layer
-          outs[time] = T.cat([outs[time], read_vectors], 1)
-        else:
-          outs[time] = T.cat([outs[time], last_read], 1)
-        inputs[time] = outs[time]
-
-    if self.debug:
-      viz = {k: np.array(v) for k, v in viz.items()}
-      viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items()}
-
-    # pass through final output layer
-    inputs = [self.output(i) for i in inputs]
-    outputs = T.stack(inputs, 1 if self.batch_first else 0)
-
-    if is_packed:
-      outputs = pack(outputs, lengths)
-
-    if self.debug:
-      return outputs, (controller_hidden, mem_hidden, read_vectors), viz
-    else:
-      return outputs, (controller_hidden, mem_hidden, read_vectors)

dnc/sparse_memory.py

@@ -23,7 +23,6 @@ class SparseMemory(nn.Module):
       independent_linears=True,
       read_heads=4,
       sparse_reads=4,
-      temporal_reads=4,
       num_lists=None,
       index_checks=32,
       gpu_id=-1,
@@ -38,7 +37,6 @@ class SparseMemory(nn.Module):
     self.input_size = input_size
     self.independent_linears = independent_linears
     self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
-    self.KL = temporal_reads if self.mem_size > temporal_reads else self.mem_size
     self.read_heads = read_heads
     self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100)
     self.index_checks = index_checks
@@ -47,7 +45,7 @@ class SparseMemory(nn.Module):
     w = self.cell_size
     r = self.read_heads
     # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
-    self.c = (r * self.K) + (self.KL * 2) + 1
+    self.c = (r * self.K) + 1

     if self.independent_linears:
       self.read_query_transform = nn.Linear(self.input_size, w*r)
@@ -103,9 +101,6 @@ class SparseMemory(nn.Module):
           # warning can be a huge chunk of contiguous memory
           'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
           'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
-          'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
-          'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
-          'precedence': cuda(T.zeros(b, self.KL*2).fill_(δ), gpu_id=self.gpu_id),
           'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
           'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
           'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
@@ -117,9 +112,6 @@ class SparseMemory(nn.Module):
     else:
       hidden['memory'] = hidden['memory'].clone()
       hidden['visible_memory'] = hidden['visible_memory'].clone()
-      hidden['link_matrix'] = hidden['link_matrix'].clone()
-      hidden['rev_link_matrix'] = hidden['link_matrix'].clone()
-      hidden['precedence'] = hidden['precedence'].clone()
      hidden['read_weights'] = hidden['read_weights'].clone()
      hidden['write_weights'] = hidden['write_weights'].clone()
      hidden['read_vectors'] = hidden['read_vectors'].clone()
@@ -131,9 +123,6 @@ class SparseMemory(nn.Module):
     if erase:
       hidden['memory'].data.fill_(δ)
       hidden['visible_memory'].data.fill_(δ)
-      hidden['link_matrix'].data.zero_()
-      hidden['rev_link_matrix'].data.zero_()
-      hidden['precedence'].data.zero_()
      hidden['read_weights'].data.fill_(δ)
      hidden['write_weights'].data.fill_(δ)
      hidden['read_vectors'].data.fill_(δ)
@@ -163,32 +152,6 @@ class SparseMemory(nn.Module):

     return hidden

-  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence, temporal_read_positions):
-    write_weights_i = write_weights.unsqueeze(2)
-    # write_weights_j = write_weights.unsqueeze(1)
-
-    # precedence_i = precedence.unsqueeze(2)
-    precedence_j = precedence.unsqueeze(1)
-
-    (b, m, k) = link_matrix.size()
-    I = cuda(T.eye(m, k).unsqueeze(0).expand((b, m, k)), gpu_id=self.gpu_id)
-
-    # since only KL*2 entries are kept non-zero sparse, create the dense version from the sparse one
-    precedence_dense = cuda(T.zeros(b, m), gpu_id=self.gpu_id)
-    precedence_dense.scatter_(1, temporal_read_positions, precedence)
-    precedence_dense_i = precedence_dense.unsqueeze(2)
-
-    temporal_write_weights_j = write_weights.gather(1, temporal_read_positions).unsqueeze(1)
-
-    link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j
-
-    rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
-
-    return link_matrix.squeeze() * I, rev_link_matrix.squeeze() * I
-
-  def update_precedence(self, precedence, write_weights):
-    return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
-
   def write(self, interpolation_gate, write_vector, write_gate, hidden):

     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
@@ -217,24 +180,6 @@ class SparseMemory(nn.Module):
     hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)

-    # update link_matrix and precedence
-    (b, c) = write_weights.size()
-
-    # update link matrix
-    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
-    hidden['link_matrix'], hidden['rev_link_matrix'] = \
-        self.update_link_matrices(
-            hidden['link_matrix'],
-            hidden['rev_link_matrix'],
-            hidden['write_weights'],
-            hidden['precedence'],
-            temporal_read_positions
-        )
-
-    # update precedence vector
-    read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
-    hidden['precedence'] = self.update_precedence(hidden['precedence'], read_weights)
-
     return hidden

   def update_usage(self, read_positions, read_weights, write_weights, usage):
@@ -257,12 +202,7 @@ class SparseMemory(nn.Module):

     return usage, I

-  def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
-    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
-    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
-    return f, b
-
-  def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
+  def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage):
     b = keys.size(0)
     read_positions = []

@@ -283,12 +223,8 @@ class SparseMemory(nn.Module):
     # get the top KL entries
     max_length = int(least_used_mem[0, 0].data.cpu().numpy())

-    _, fp = T.topk(forward, self.KL, largest=True)
-    _, bp = T.topk(backward, self.KL, largest=True)
-
     # differentiable ops
-    # append forward and backward read positions, might lead to duplicates
-    read_positions = T.cat([read_positions, fp, bp], 1)
     read_positions = T.cat([read_positions, least_used_mem], 1)
     read_positions = T.clamp(read_positions, 0, max_length)

@@ -301,11 +237,6 @@ class SparseMemory(nn.Module):
     return read_vectors, read_positions, read_weights, visible_memory

   def read(self, read_query, hidden):
-    # get forward and backward weights
-    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
-    read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
-    forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights)
-
     # sparse read
     read_vectors, positions, read_weights, visible_memory = \
         self.read_from_sparse_memory(
@@ -313,9 +244,7 @@ class SparseMemory(nn.Module):
             hidden['memory'],
             hidden['indexes'],
             read_query,
             hidden['least_used_mem'],
-            hidden['usage'],
-            forward, backward,
-            hidden['read_positions']
+            hidden['usage']
         )

     hidden['read_positions'] = positions

dnc/sparse_temporal_memory.py (new file, 357 lines)

@@ -0,0 +1,357 @@

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import numpy as np
import math

from .flann_index import FLANNIndex
from .util import *
import time


class SparseTemporalMemory(nn.Module):

  def __init__(
      self,
      input_size,
      mem_size=512,
      cell_size=32,
      independent_linears=True,
      read_heads=4,
      sparse_reads=4,
      temporal_reads=4,
      num_lists=None,
      index_checks=32,
      gpu_id=-1,
      mem_gpu_id=-1
  ):
    super(SparseTemporalMemory, self).__init__()

    self.mem_size = mem_size
    self.cell_size = cell_size
    self.gpu_id = gpu_id
    self.mem_gpu_id = mem_gpu_id
    self.input_size = input_size
    self.independent_linears = independent_linears
    self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
    self.KL = temporal_reads if self.mem_size > temporal_reads else self.mem_size
    self.read_heads = read_heads
    self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100)
    self.index_checks = index_checks

    m = self.mem_size
    w = self.cell_size
    r = self.read_heads
    # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
    self.c = (r * self.K) + (self.KL * 2) + 1

    if self.independent_linears:
      self.read_query_transform = nn.Linear(self.input_size, w*r)
      self.write_vector_transform = nn.Linear(self.input_size, w)
      self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
      self.write_gate_transform = nn.Linear(self.input_size, 1)
      T.nn.init.orthogonal(self.read_query_transform.weight)
      T.nn.init.orthogonal(self.write_vector_transform.weight)
      T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
      T.nn.init.orthogonal(self.write_gate_transform.weight)
    else:
      self.interface_size = (r * w) + w + self.c + 1
      self.interface_weights = nn.Linear(self.input_size, self.interface_size)
      T.nn.init.orthogonal(self.interface_weights.weight)

    self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
    self.δ = 0.005  # minimum usage
    self.timestep = 0

  def rebuild_indexes(self, hidden, erase=False):
    b = hidden['memory'].size(0)

    # if indexes already exist, we reset them
    if 'indexes' in hidden:
      [x.reset() for x in hidden['indexes']]
    else:
      # create new indexes
      hidden['indexes'] = \
          [FLANNIndex(cell_size=self.cell_size,
                      nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
                      probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]

    # add existing memory into indexes
    pos = hidden['read_positions'].squeeze().data.cpu().numpy()
    if not erase:
      for n, i in enumerate(hidden['indexes']):
        i.reset()
        i.add(hidden['memory'][n], last=pos[n][-1])
    else:
      self.timestep = 0

    return hidden

  def reset(self, batch_size=1, hidden=None, erase=True):
    m = self.mem_size
    w = self.cell_size
    b = batch_size
    r = self.read_heads
    c = self.c

    if hidden is None:
      hidden = {
          # warning can be a huge chunk of contiguous memory
          'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
          'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
          'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
          'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
          'precedence': cuda(T.zeros(b, self.KL*2).fill_(δ), gpu_id=self.gpu_id),
          'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
          'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
          'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
          'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
          'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
          'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
      }
      hidden = self.rebuild_indexes(hidden, erase=True)
    else:
      hidden['memory'] = hidden['memory'].clone()
      hidden['visible_memory'] = hidden['visible_memory'].clone()
      hidden['link_matrix'] = hidden['link_matrix'].clone()
      hidden['rev_link_matrix'] = hidden['rev_link_matrix'].clone()
      hidden['precedence'] = hidden['precedence'].clone()
      hidden['read_weights'] = hidden['read_weights'].clone()
      hidden['write_weights'] = hidden['write_weights'].clone()
      hidden['read_vectors'] = hidden['read_vectors'].clone()
      hidden['least_used_mem'] = hidden['least_used_mem'].clone()
      hidden['usage'] = hidden['usage'].clone()
      hidden['read_positions'] = hidden['read_positions'].clone()
      hidden = self.rebuild_indexes(hidden, erase)

    if erase:
      hidden['memory'].data.fill_(δ)
      hidden['visible_memory'].data.fill_(δ)
      hidden['link_matrix'].data.zero_()
      hidden['rev_link_matrix'].data.zero_()
      hidden['precedence'].data.zero_()
      hidden['read_weights'].data.fill_(δ)
      hidden['write_weights'].data.fill_(δ)
      hidden['read_vectors'].data.fill_(δ)
      hidden['least_used_mem'].data.fill_(c+1+self.timestep)
      hidden['usage'].data.fill_(δ)
      hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()

    return hidden

  def write_into_sparse_memory(self, hidden):
    visible_memory = hidden['visible_memory']
    positions = hidden['read_positions'].squeeze()

    (b, m, w) = hidden['memory'].size()
    # update memory
    hidden['memory'].scatter_(1, positions.unsqueeze(2).expand(b, self.c, w), visible_memory)

    # non-differentiable operations
    pos = positions.data.cpu().numpy()
    for batch in range(b):
      # update indexes
      hidden['indexes'][batch].reset()
      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])

    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1

    return hidden

  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence, temporal_read_positions):
    write_weights_i = write_weights.unsqueeze(2)
    precedence_j = precedence.unsqueeze(1)

    (b, m, k) = link_matrix.size()
    I = cuda(T.eye(m, k).unsqueeze(0).expand((b, m, k)), gpu_id=self.gpu_id)

    # since only KL*2 entries are kept non-zero sparse, create the dense version from the sparse one
    precedence_dense = cuda(T.zeros(b, m), gpu_id=self.gpu_id)
    precedence_dense.scatter_(1, temporal_read_positions, precedence)
    precedence_dense_i = precedence_dense.unsqueeze(2)

    temporal_write_weights_j = write_weights.gather(1, temporal_read_positions).unsqueeze(1)

    link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j

    rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)

    return link_matrix.squeeze() * I, rev_link_matrix.squeeze() * I

  def update_precedence(self, precedence, write_weights):
    return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights

  def write(self, interpolation_gate, write_vector, write_gate, hidden):

    read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
    write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])

    hidden['usage'], I = self.update_usage(
        hidden['read_positions'],
        read_weights,
        write_weights,
        hidden['usage']
    )

    # either we write to previous read locations
    x = interpolation_gate * read_weights
    # or to a new location
    y = (1 - interpolation_gate) * I
    write_weights = write_gate * (x + y)

    # store the write weights
    hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)

    # erase matrix
    erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())

    # write into memory
    hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
    hidden = self.write_into_sparse_memory(hidden)

    # update link_matrix and precedence
    (b, c) = write_weights.size()

    # update link matrix
    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
    hidden['link_matrix'], hidden['rev_link_matrix'] = \
        self.update_link_matrices(
            hidden['link_matrix'],
            hidden['rev_link_matrix'],
            hidden['write_weights'],
            hidden['precedence'],
            temporal_read_positions
        )

    # update precedence vector
    read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
    hidden['precedence'] = self.update_precedence(hidden['precedence'], read_weights)

    return hidden

  def update_usage(self, read_positions, read_weights, write_weights, usage):
    (b, _) = read_positions.size()
    # usage is timesteps since a non-negligible memory access
    u = (read_weights + write_weights > self.δ).float()

    # usage before write
    relevant_usages = usage.gather(1, read_positions)

    # indicator of words with minimal memory usage
    minusage = T.min(relevant_usages, -1, keepdim=True)[0]
    minusage = minusage.expand(relevant_usages.size())
    I = (relevant_usages == minusage).float()

    # usage after write
    relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)

    usage.scatter_(1, read_positions, relevant_usages)

    return usage, I

  def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
    return f, b

  def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
    b = keys.size(0)
    read_positions = []

    # we search for k cells per read head
    for batch in range(b):
      distances, positions = indexes[batch].search(keys[batch])
      read_positions.append(positions)
    read_positions = T.stack(read_positions, 0)

    # add least used mem to read positions
    # TODO: explore possibility of reading co-locations or ranges and such
    (b, r, k) = read_positions.size()
    read_positions = var(read_positions).squeeze(1).view(b, -1)

    # no gradient here
    # temporal reads
    (b, m, w) = memory.size()
    # get the top KL entries
    max_length = int(least_used_mem[0, 0].data.cpu().numpy())

    _, fp = T.topk(forward, self.KL, largest=True)
    _, bp = T.topk(backward, self.KL, largest=True)

    # differentiable ops
    # append forward and backward read positions, might lead to duplicates
    read_positions = T.cat([read_positions, fp, bp], 1)
    read_positions = T.cat([read_positions, least_used_mem], 1)
    read_positions = T.clamp(read_positions, 0, max_length)

    visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w))

    read_weights = σ(θ(visible_memory, keys), 2)
    read_vectors = T.bmm(read_weights, visible_memory)
    read_weights = T.prod(read_weights, 1)

    return read_vectors, read_positions, read_weights, visible_memory

  def read(self, read_query, hidden):
    # get forward and backward weights
    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
    read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
    forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights)

    # sparse read
    read_vectors, positions, read_weights, visible_memory = \
        self.read_from_sparse_memory(
            hidden['memory'],
            hidden['indexes'],
            read_query,
            hidden['least_used_mem'],
            hidden['usage'],
            forward, backward,
            hidden['read_positions']
        )

    hidden['read_positions'] = positions
    hidden['read_weights'] = hidden['read_weights'].scatter_(1, positions, read_weights)
    hidden['read_vectors'] = read_vectors
    hidden['visible_memory'] = visible_memory

    return hidden['read_vectors'], hidden

  def forward(self, ξ, hidden):
    t = time.time()

    # ξ = ξ.detach()
    m = self.mem_size
    w = self.cell_size
    r = self.read_heads
    c = self.c
    b = ξ.size()[0]

    if self.independent_linears:
      # r read keys (b * r * w)
      read_query = self.read_query_transform(ξ).view(b, r, w)
      # write key (b * 1 * w)
      write_vector = self.write_vector_transform(ξ).view(b, 1, w)
      # write vector (b * 1 * r)
      interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
      # write gate (b * 1)
      write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
    else:
      ξ = self.interface_weights(ξ)
      # r read keys (b * r * w)
      read_query = ξ[:, :r*w].contiguous().view(b, r, w)
      # write key (b * 1 * w)
      write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
      # write vector (b * 1 * r)
      interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
      # write gate (b * 1)
      write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

    self.timestep += 1
    hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
    return self.read(read_query, hidden)

dnc/util.py

@@ -89,7 +89,10 @@ def σ(input, axis=1):
   trans_size = trans_input.size()

   input_2d = trans_input.contiguous().view(-1, trans_size[-1])
-  soft_max_2d = F.softmax(input_2d)
+  if '0.3' in T.__version__:
+    soft_max_2d = F.softmax(input_2d, -1)
+  else:
+    soft_max_2d = F.softmax(input_2d)
   soft_max_nd = soft_max_2d.view(*trans_size)
   return soft_max_nd.transpose(axis, len(input_size) - 1)

tasks/copy_task.py

@@ -24,6 +24,7 @@ from torch.nn.utils import clip_grad_norm
 from dnc.dnc import DNC
 from dnc.sdnc import SDNC
+from dnc.sam import SAM
 from dnc.util import *

 parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')
@@ -31,7 +32,7 @@ parser.add_argument('-input_size', type=int, default=6, help='dimension of input
 parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
 parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
 parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
-parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory')
+parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory: dnc | sdnc | sam')

 parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
 parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
@@ -55,6 +56,7 @@ parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='r
 parser.add_argument('-iterations', type=int, default=100000, metavar='N', help='total number of iteration')
 parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency')
 parser.add_argument('-check_freq', type=int, default=100, metavar='N', help='check point frequency')
+parser.add_argument('-visdom', action='store_true', help='plot memory content on visdom per -summarize_freq steps')

 args = parser.parse_args()
 print(args)
@@ -129,7 +131,7 @@ if __name__ == '__main__':
       cell_size=mem_size,
       read_heads=read_heads,
       gpu_id=args.cuda,
-      debug=True,
+      debug=args.visdom,
       batch_first=True,
       independent_linears=True
     )
@@ -147,7 +149,24 @@ if __name__ == '__main__':
       temporal_reads=args.temporal_reads,
       read_heads=args.read_heads,
       gpu_id=args.cuda,
-      debug=False,
+      debug=args.visdom,
       batch_first=True,
       independent_linears=False
     )
+  elif args.memory_type == 'sam':
+    rnn = SAM(
+      input_size=args.input_size,
+      hidden_size=args.nhid,
+      rnn_type=args.rnn_type,
+      num_layers=args.nlayer,
+      num_hidden_layers=args.nhlayer,
+      dropout=args.dropout,
+      nr_cells=mem_slot,
+      cell_size=mem_size,
+      sparse_reads=args.sparse_reads,
+      read_heads=args.read_heads,
+      gpu_id=args.cuda,
+      debug=args.visdom,
+      batch_first=True,
+      independent_linears=False
+    )
@@ -252,7 +271,7 @@ if __name__ == '__main__':
            xlabel='mem_slot'
          )
        )
-      else:
+      elif args.memory_type == 'sdnc':
        viz.heatmap(
          v['link_matrix'][-1].reshape(args.mem_slot, -1),
          opts=dict(
@@ -275,16 +294,17 @@ if __name__ == '__main__':
          )
        )

-        viz.heatmap(
-          v['precedence'],
-          opts=dict(
-            xtickstep=10,
-            ytickstep=2,
-            title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
-            ylabel='layer * time',
-            xlabel='mem_slot'
-          )
-        )
+      elif args.memory_type == 'sdnc' or args.memory_type == 'dnc':
+        viz.heatmap(
+          v['precedence'],
+          opts=dict(
+            xtickstep=10,
+            ytickstep=2,
+            title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
+            ylabel='layer * time',
+            xlabel='mem_slot'
+          )
+        )

      if args.memory_type == 'sdnc':
        viz.heatmap(

test/test_sam_gru.py (new file, 201 lines)

@@ -0,0 +1,201 @@

# #!/usr/bin/env python3
# # -*- coding: utf-8 -*-

import pytest
import numpy as np

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim
import numpy as np

import sys
import os
import math
import time
import functools
sys.path.insert(0, '.')

from dnc import SAM
from test_utils import generate_data, criterion


def test_rnn_1():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'gru'
  num_layers = 1
  num_hidden_layers = 1
  dropout = 0
  nr_cells = 100
  cell_size = 10
  read_heads = 1
  sparse_reads = 2
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 10
  length = 10

  rnn = SAM(
      input_size=input_size,
      hidden_size=hidden_size,
      rnn_type=rnn_type,
      num_layers=num_layers,
      num_hidden_layers=num_hidden_layers,
      dropout=dropout,
      nr_cells=nr_cells,
      cell_size=cell_size,
      read_heads=read_heads,
      sparse_reads=sparse_reads,
      gpu_id=gpu_id,
      debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  output, (chx, mhx, rv), v = rnn(input_data, None)
  output = output.transpose(0, 1)

  loss = criterion((output), target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([21, 10, 100])
  assert chx[0][0].size() == T.Size([10,100])
  # assert mhx['memory'].size() == T.Size([10,1,1])
  assert rv.size() == T.Size([10, 10])


def test_rnn_n():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'gru'
  num_layers = 3
  num_hidden_layers = 5
  dropout = 0.2
  nr_cells = 200
  cell_size = 17
  read_heads = 2
  sparse_reads = 4
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 20
  length = 13

  rnn = SAM(
      input_size=input_size,
      hidden_size=hidden_size,
      rnn_type=rnn_type,
      num_layers=num_layers,
      num_hidden_layers=num_hidden_layers,
      dropout=dropout,
      nr_cells=nr_cells,
      cell_size=cell_size,
      read_heads=read_heads,
      sparse_reads=sparse_reads,
      gpu_id=gpu_id,
      debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  output, (chx, mhx, rv), v = rnn(input_data, None)
  output = output.transpose(0, 1)

  loss = criterion((output), target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([27, 10, 100])
  assert chx[0].size() == T.Size([num_hidden_layers,10,100])
  # assert mhx['memory'].size() == T.Size([10,12,17])
  assert rv.size() == T.Size([10, 34])


def test_rnn_no_memory_pass():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'gru'
  num_layers = 3
  num_hidden_layers = 5
  dropout = 0.2
  nr_cells = 5000
  cell_size = 17
  sparse_reads = 3
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 20
  length = 13

  rnn = SAM(
      input_size=input_size,
      hidden_size=hidden_size,
      rnn_type=rnn_type,
      num_layers=num_layers,
      num_hidden_layers=num_hidden_layers,
      dropout=dropout,
      nr_cells=nr_cells,
      cell_size=cell_size,
      sparse_reads=sparse_reads,
      gpu_id=gpu_id,
      debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  (chx, mhx, rv) = (None, None, None)
  outputs = []
  for x in range(6):
    output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
    output = output.transpose(0, 1)
    outputs.append(output)

  output = functools.reduce(lambda x,y: x + y, outputs)
  loss = criterion((output), target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([27, 10, 100])
  assert chx[0].size() == T.Size([num_hidden_layers,10,100])
  # assert mhx['memory'].size() == T.Size([10,12,17])
  assert rv == None
201
test/test_sam_lstm.py
Normal file
201
test/test_sam_lstm.py
Normal file
@ -0,0 +1,201 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pytest
import numpy as np

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim

import sys
import os
import math
import time
import functools
sys.path.insert(0, '.')

from dnc import SAM
from test_utils import generate_data, criterion


def test_rnn_1():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'lstm'
  num_layers = 1
  num_hidden_layers = 1
  dropout = 0
  nr_cells = 100
  cell_size = 10
  read_heads = 1
  sparse_reads = 2
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 10
  length = 10

  rnn = SAM(
    input_size=input_size,
    hidden_size=hidden_size,
    rnn_type=rnn_type,
    num_layers=num_layers,
    num_hidden_layers=num_hidden_layers,
    dropout=dropout,
    nr_cells=nr_cells,
    cell_size=cell_size,
    read_heads=read_heads,
    sparse_reads=sparse_reads,
    gpu_id=gpu_id,
    debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  # Single forward/backward step; debug=True adds the memory internals v.
  output, (chx, mhx, rv), v = rnn(input_data, None)
  output = output.transpose(0, 1)

  loss = criterion(output, target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([21, 10, 100])
  assert chx[0][0][0].size() == T.Size([10, 100])
  # assert mhx['memory'].size() == T.Size([10,1,1])
  assert rv.size() == T.Size([10, 10])


def test_rnn_n():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'lstm'
  num_layers = 3
  num_hidden_layers = 5
  dropout = 0.2
  nr_cells = 200
  cell_size = 17
  read_heads = 2
  sparse_reads = 4
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 20
  length = 13

  rnn = SAM(
    input_size=input_size,
    hidden_size=hidden_size,
    rnn_type=rnn_type,
    num_layers=num_layers,
    num_hidden_layers=num_hidden_layers,
    dropout=dropout,
    nr_cells=nr_cells,
    cell_size=cell_size,
    read_heads=read_heads,
    sparse_reads=sparse_reads,
    gpu_id=gpu_id,
    debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  output, (chx, mhx, rv), v = rnn(input_data, None)
  output = output.transpose(0, 1)

  loss = criterion(output, target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([27, 10, 100])
  assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100])
  # assert mhx['memory'].size() == T.Size([10,12,17])
  assert rv.size() == T.Size([10, 34])


def test_rnn_no_memory_pass():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'lstm'
  num_layers = 3
  num_hidden_layers = 5
  dropout = 0.2
  nr_cells = 5000
  cell_size = 17
  sparse_reads = 3
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 20
  length = 13

  rnn = SAM(
    input_size=input_size,
    hidden_size=hidden_size,
    rnn_type=rnn_type,
    num_layers=num_layers,
    num_hidden_layers=num_hidden_layers,
    dropout=dropout,
    nr_cells=nr_cells,
    cell_size=cell_size,
    sparse_reads=sparse_reads,
    gpu_id=gpu_id,
    debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  # Repeated passes with memory bypassed; hidden state is threaded through.
  (chx, mhx, rv) = (None, None, None)
  outputs = []
  for x in range(6):
    output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
    output = output.transpose(0, 1)
    outputs.append(output)

  output = functools.reduce(lambda x, y: x + y, outputs)
  loss = criterion(output, target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([27, 10, 100])
  assert chx[0][0].size() == T.Size([num_hidden_layers, 10, 100])
  # assert mhx['memory'].size() == T.Size([10,12,17])
  assert rv is None
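Condensed from `test_sam_lstm.py` above, a minimal SAM round trip looks like the sketch below. The sizes here are illustrative assumptions rather than the values the tests assert; the four-way unpack matches the tests, since `debug=True` makes the forward pass return the memory internals `v` as an extra value.

```python
import torch as T
from torch.autograd import Variable as var
from dnc import SAM

# Illustrative sizes, not the ones used in the tests above.
rnn = SAM(
  input_size=64,
  hidden_size=64,
  rnn_type='lstm',
  num_layers=1,
  nr_cells=100,
  cell_size=16,
  sparse_reads=4,
  gpu_id=-1,  # run on CPU
  debug=True
)

x = var(T.randn(10, 21, 64))  # (batch, time, input_size), assumed batch-first
# debug=True adds the memory-internals value `v` to the return tuple.
output, (chx, mhx, rv), v = rnn(x, None)
```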
test/test_sam_rnn.py  (new file, 201 lines)
@@ -0,0 +1,201 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pytest
import numpy as np

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm
import torch.optim as optim

import sys
import os
import math
import time
import functools
sys.path.insert(0, '.')

from dnc import SAM
from test_utils import generate_data, criterion


def test_rnn_1():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'rnn'
  num_layers = 1
  num_hidden_layers = 1
  dropout = 0
  nr_cells = 100
  cell_size = 10
  read_heads = 1
  sparse_reads = 2
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 10
  length = 10

  rnn = SAM(
    input_size=input_size,
    hidden_size=hidden_size,
    rnn_type=rnn_type,
    num_layers=num_layers,
    num_hidden_layers=num_hidden_layers,
    dropout=dropout,
    nr_cells=nr_cells,
    cell_size=cell_size,
    read_heads=read_heads,
    sparse_reads=sparse_reads,
    gpu_id=gpu_id,
    debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  # Single forward/backward step; debug=True adds the memory internals v.
  output, (chx, mhx, rv), v = rnn(input_data, None)
  output = output.transpose(0, 1)

  loss = criterion(output, target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([21, 10, 100])
  assert chx[0][0].size() == T.Size([10, 100])
  # assert mhx['memory'].size() == T.Size([10,1,1])
  assert rv.size() == T.Size([10, 10])


def test_rnn_n():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'rnn'
  num_layers = 3
  num_hidden_layers = 5
  dropout = 0.2
  nr_cells = 200
  cell_size = 17
  read_heads = 2
  sparse_reads = 4
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 20
  length = 13

  rnn = SAM(
    input_size=input_size,
    hidden_size=hidden_size,
    rnn_type=rnn_type,
    num_layers=num_layers,
    num_hidden_layers=num_hidden_layers,
    dropout=dropout,
    nr_cells=nr_cells,
    cell_size=cell_size,
    read_heads=read_heads,
    sparse_reads=sparse_reads,
    gpu_id=gpu_id,
    debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  output, (chx, mhx, rv), v = rnn(input_data, None)
  output = output.transpose(0, 1)

  loss = criterion(output, target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([27, 10, 100])
  assert chx[0].size() == T.Size([num_hidden_layers, 10, 100])
  # assert mhx['memory'].size() == T.Size([10,12,17])
  assert rv.size() == T.Size([10, 34])


def test_rnn_no_memory_pass():
  T.manual_seed(1111)

  input_size = 100
  hidden_size = 100
  rnn_type = 'rnn'
  num_layers = 3
  num_hidden_layers = 5
  dropout = 0.2
  nr_cells = 5000
  cell_size = 17
  sparse_reads = 3
  gpu_id = -1
  debug = True
  lr = 0.001
  sequence_max_length = 10
  batch_size = 10
  cuda = gpu_id
  clip = 20
  length = 13

  rnn = SAM(
    input_size=input_size,
    hidden_size=hidden_size,
    rnn_type=rnn_type,
    num_layers=num_layers,
    num_hidden_layers=num_hidden_layers,
    dropout=dropout,
    nr_cells=nr_cells,
    cell_size=cell_size,
    sparse_reads=sparse_reads,
    gpu_id=gpu_id,
    debug=debug
  )

  optimizer = optim.Adam(rnn.parameters(), lr=lr)
  optimizer.zero_grad()

  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
  target_output = target_output.transpose(0, 1).contiguous()

  # Repeated passes with memory bypassed; hidden state is threaded through.
  (chx, mhx, rv) = (None, None, None)
  outputs = []
  for x in range(6):
    output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
    output = output.transpose(0, 1)
    outputs.append(output)

  output = functools.reduce(lambda x, y: x + y, outputs)
  loss = criterion(output, target_output)
  loss.backward()

  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
  optimizer.step()

  assert target_output.size() == T.Size([27, 10, 100])
  assert chx[0].size() == T.Size([num_hidden_layers, 10, 100])
  # assert mhx['memory'].size() == T.Size([10,12,17])
  assert rv is None
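Two regularities in the shape assertions across both SAM test files are worth making explicit. First, the time dimension of the generated data comes out as `2 * length + 1` (21 for `length = 10`, 27 for `length = 13`), consistent with a copy-style task of input, delimiter, and echoed output. Second, the debug read vector has shape `(batch_size, read_heads * cell_size)`: `[10, 10]` for one head of size 10, `[10, 34]` for two heads of size 17. Both relations are inferred from the asserted values, not read from `test_utils` or the SAM source; the hypothetical helpers below just restate them:

```python
# Hypothetical helpers restating shape relations inferred from the tests.
def expected_time_steps(length):
  # generate_data appears to emit input + delimiter + echoed output.
  return 2 * length + 1

def expected_read_vector_shape(batch_size, read_heads, cell_size):
  return (batch_size, read_heads * cell_size)

assert expected_time_steps(10) == 21 and expected_time_steps(13) == 27
assert expected_read_vector_shape(10, 1, 10) == (10, 10)
assert expected_read_vector_shape(10, 2, 17) == (10, 34)
```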
The following hunks appear identically for three further SDNC test files in this commit (their file headers were not preserved in this view); the hunk set is shown once:

@@ -36,6 +36,7 @@ def test_rnn_1():
   cell_size = 10
   read_heads = 1
   sparse_reads = 2
+  temporal_reads = 1
   gpu_id = -1
   debug = True
   lr = 0.001
@@ -56,6 +57,7 @@ def test_rnn_1():
     cell_size=cell_size,
     read_heads=read_heads,
     sparse_reads=sparse_reads,
+    temporal_reads=temporal_reads,
     gpu_id=gpu_id,
     debug=debug
   )
@@ -94,6 +96,7 @@ def test_rnn_n():
   cell_size = 17
   read_heads = 2
   sparse_reads = 4
+  temporal_reads = 3
   gpu_id = -1
   debug = True
   lr = 0.001
@@ -114,6 +117,7 @@ def test_rnn_n():
     cell_size=cell_size,
     read_heads=read_heads,
     sparse_reads=sparse_reads,
+    temporal_reads=temporal_reads,
     gpu_id=gpu_id,
     debug=debug
   )
@@ -151,6 +155,7 @@ def test_rnn_no_memory_pass():
   nr_cells = 5000
   cell_size = 17
   sparse_reads = 3
+  temporal_reads = 4
   gpu_id = -1
   debug = True
   lr = 0.001
@@ -170,6 +175,7 @@ def test_rnn_no_memory_pass():
     nr_cells=nr_cells,
     cell_size=cell_size,
     sparse_reads=sparse_reads,
+    temporal_reads=temporal_reads,
     gpu_id=gpu_id,
     debug=debug
   )
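The hunks above all make the same one-line change: each SDNC test fixture gains a `temporal_reads` argument next to `sparse_reads` and passes it through to the constructor. A minimal, illustrative SDNC construction with the new argument (the sizes are assumptions; the keyword itself is taken from the diff):

```python
from dnc import SDNC

# Illustrative sizes; temporal_reads is the argument these hunks add.
rnn = SDNC(
  input_size=64,
  hidden_size=64,
  rnn_type='lstm',
  num_layers=1,
  nr_cells=100,
  cell_size=16,
  sparse_reads=4,
  temporal_reads=2,
  gpu_id=-1
)
```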