commit adbb195e27

README.md (54 changes)
@@ -22,7 +22,10 @@ Includes:
 - [SAM](#sam)
 - [Example usage](#example-usage-2)
 - [Debugging](#debugging-2)
-- [Example copy task](#example-copy-task)
+- [Tasks](#tasks)
+- [Copy task (with curriculum and generalization)](#copy-task-with-curriculum-and-generalization)
+- [Generalizing Addition task](#generalizing-addition-task)
+- [Generalizing Argmax task](#generalizing-argmax-task)
 - [Code Structure](#code-structure)
 - [General noteworthy stuff](#general-noteworthy-stuff)
 
@@ -48,6 +51,12 @@ pip install -r ./requirements.txt
 pip install -e .
 ```
 
+For using fully GPU based SDNCs or SAMs, install FAISS:
+
+```bash
+conda install faiss-gpu -c pytorch
+```
+
 `pytest` is required to run the tests
 
 ## Architecture
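A quick way to sanity-check the FAISS install (a sketch, not part of this commit; it uses only standard FAISS calls):

```python
import faiss
import numpy as np

index = faiss.IndexFlatL2(32)                         # exact L2 index over 32-d vectors
index.add(np.random.randn(100, 32).astype('float32'))
distances, ids = index.search(np.random.randn(1, 32).astype('float32'), 4)
print(ids)                                            # ids of the 4 nearest stored vectors
```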
@@ -362,7 +371,9 @@ Memory vectors returned by forward pass (`np.ndarray`):
 | `debug_memory['usage']` | layer * time | nr_cells
 
 
-## Example copy task
+## Tasks
 
+### Copy task (with curriculum and generalization)
+
 The copy task, as described in the original paper, is included in the repo.
 
@@ -370,13 +381,13 @@ From the project root:
 ```bash
 python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (like original implementation)
 
-python3 ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 32 -batch_size 1000 -optim adam -sequence_max_length 8 # (faster convergence)
+python ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 32 -batch_size 1000 -optim adam -sequence_max_length 8 # (faster convergence)
 
 For SDNCs:
-python3 -B ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -memory_type sdnc -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 100 -mem_size 10 -read_heads 1 -sparse_reads 10 -batch_size 20 -optim adam -sequence_max_length 10
+python ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -memory_type sdnc -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 100 -mem_size 10 -read_heads 1 -sparse_reads 10 -batch_size 20 -optim adam -sequence_max_length 10
 
 and for curriculum learning for SDNCs:
-python3 -B ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -memory_type sdnc -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 100 -mem_size 10 -read_heads 1 -sparse_reads 4 -temporal_reads 4 -batch_size 20 -optim adam -sequence_max_length 4 -curriculum_increment 2 -curriculum_freq 10000
+python ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -memory_type sdnc -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 100 -mem_size 10 -read_heads 1 -sparse_reads 4 -temporal_reads 4 -batch_size 20 -optim adam -sequence_max_length 4 -curriculum_increment 2 -curriculum_freq 10000
 ```
 
 For the full set of options, see:
@@ -403,6 +414,30 @@ The visdom dashboard shows memory as a heatmap for batch 0 every `-summarize_freq` steps:
 
 ![Visdom dashboard](./docs/dnc-mem-debug.png)
 
+### Generalizing Addition task
+
+The adding task is as described in [this github pull request](https://github.com/Mostafa-Samir/DNC-tensorflow/pull/4#issue-199369192).
+This task
+- creates one-hot vectors of size `input_size`, each representing a number
+- feeds a sentence of them to a network
+- the output of which is added to get the sum of the decoded outputs
+
+The task first trains the network for sentences of size ~100, and then tests if the network generalizes for lengths ~1000.
+
+```bash
+python ./tasks/adding_task.py -cuda 0 -lr 0.0001 -rnn_type lstm -memory_type sam -nlayer 1 -nhlayer 1 -nhid 100 -dropout 0 -mem_slot 1000 -mem_size 32 -read_heads 1 -sparse_reads 4 -batch_size 20 -optim rmsprop -input_size 3 -sequence_max_length 100
+```
+
+### Generalizing Argmax task
+
+The second adding task is similar to the first one, except that the network's output at the last time step is expected to be the argmax of the input.
+
+```bash
+python ./tasks/argmax_task.py -cuda 0 -lr 0.0001 -rnn_type lstm -memory_type dnc -nlayer 1 -nhlayer 1 -nhid 100 -dropout 0 -mem_slot 100 -mem_size 10 -read_heads 2 -batch_size 1 -optim rmsprop -sequence_max_length 15 -input_size 10 -iterations 10000
+```
+
+
 ## Code Structure
 
 1. DNCs:
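To make the addition task's encoding concrete (described in the `### Generalizing Addition task` section above), here is a minimal sketch of the input construction; it mirrors `onehot` and `generate_data` in `tasks/adding_task.py`, with `input_size=6` assumed:

```python
import numpy as np

def onehot(x, n):
  ret = np.zeros(n).astype(np.float32)
  ret[x] = 1.0
  return ret

numbers = [2, 4, 1]                       # the "sentence" of numbers to add
xs = [onehot(x, 6) for x in numbers]      # one one-hot vector per number
xs.append(onehot(5, 6))                   # the last symbol marks end of input
target = float(sum(numbers))              # the decoded outputs are summed and
                                          # regressed onto this value
```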
@@ -436,6 +471,15 @@ make -j 4
 sudo make install
 ```
 
+FAISS can be installed using:
+
+```bash
+conda install faiss-gpu -c pytorch
+```
+
+FAISS is much faster, has a GPU implementation and is interoperable with pytorch tensors.
+We try to use FAISS by default, in absence of which we fall back to FLANN.
+
 2. An alternative to FLANN is [FAISS](https://github.com/facebookresearch/faiss), which is much faster and interoperable with torch cuda tensors (but is difficult to distribute, see [dnc/faiss_index.py](dnc/faiss_index.py)).
 3. `nan`s in the gradients are common, try with different batch sizes
 
dnc/dnc.py (43 changes)
@@ -271,3 +271,46 @@ class DNC(nn.Module):
       return outputs, (controller_hidden, mem_hidden, read_vectors), viz
     else:
       return outputs, (controller_hidden, mem_hidden, read_vectors)
+
+  def __repr__(self):
+    s = "\n----------------------------------------\n"
+    s += '{name}({input_size}, {hidden_size}'
+    if self.rnn_type != 'lstm':
+      s += ', rnn_type={rnn_type}'
+    if self.num_layers != 1:
+      s += ', num_layers={num_layers}'
+    if self.num_hidden_layers != 2:
+      s += ', num_hidden_layers={num_hidden_layers}'
+    if self.bias != True:
+      s += ', bias={bias}'
+    if self.batch_first != True:
+      s += ', batch_first={batch_first}'
+    if self.dropout != 0:
+      s += ', dropout={dropout}'
+    if self.bidirectional != False:
+      s += ', bidirectional={bidirectional}'
+    if self.nr_cells != 5:
+      s += ', nr_cells={nr_cells}'
+    if self.read_heads != 2:
+      s += ', read_heads={read_heads}'
+    if self.cell_size != 10:
+      s += ', cell_size={cell_size}'
+    if self.nonlinearity != 'tanh':
+      s += ', nonlinearity={nonlinearity}'
+    if self.gpu_id != -1:
+      s += ', gpu_id={gpu_id}'
+    if self.independent_linears != False:
+      s += ', independent_linears={independent_linears}'
+    if self.share_memory != True:
+      s += ', share_memory={share_memory}'
+    if self.debug != False:
+      s += ', debug={debug}'
+    if self.clip != 20:
+      s += ', clip={clip}'
+
+    s += ")\n" + super(DNC, self).__repr__() + \
+      "\n----------------------------------------\n"
+    return s.format(name=self.__class__.__name__, **self.__dict__)
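For context, this `__repr__` is what makes `print(rnn)` in the task scripts render the layer's non-default hyperparameters. A sketch of the effect, assuming the defaults encoded in the conditionals above:

```python
from dnc.dnc import DNC

rnn = DNC(input_size=64, hidden_size=128, nr_cells=16, read_heads=4)
# prints roughly: DNC(64, 128, nr_cells=16, read_heads=4), followed by the
# usual nn.Module summary of the controller layers
print(rnn)
```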
dnc/faiss_index.py

@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-from faiss import faiss
+import faiss
 
-from faiss.faiss import cast_integer_to_float_ptr as cast_float
-from faiss.faiss import cast_integer_to_int_ptr as cast_int
-from faiss.faiss import cast_integer_to_long_ptr as cast_long
+from faiss import cast_integer_to_float_ptr as cast_float
+from faiss import cast_integer_to_int_ptr as cast_int
+from faiss import cast_integer_to_long_ptr as cast_long
 
 from .util import *
 
@@ -21,16 +21,16 @@ class FAISSIndex(object):
     self.num_lists = num_lists
     self.gpu_id = gpu_id
 
-    res = res if res else faiss.StandardGpuResources()
-    res.setTempMemoryFraction(0.01)
+    # BEWARE: if this variable gets deallocated, FAISS crashes
+    self.res = res if res else faiss.StandardGpuResources()
+    self.res.setTempMemoryFraction(0.01)
     if self.gpu_id != -1:
-      res.initializeForDevice(self.gpu_id)
+      self.res.initializeForDevice(self.gpu_id)
 
     nr_samples = self.nr_cells * 100 * self.cell_size
-    train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10
-    # train = T.randn(self.nr_cells * 100, self.cell_size)
+    train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size)
 
-    self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
+    self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_L2)
     self.index.setNumProbes(self.probes)
     self.train(train)
 
@@ -48,7 +48,7 @@ class FAISSIndex(object):
     self.index.reset()
     T.cuda.synchronize()
 
-  def add(self, other, positions=None, last=-1):
+  def add(self, other, positions=None, last=None):
     other = ensure_gpu(other, self.gpu_id)
 
     T.cuda.synchronize()
@@ -57,7 +57,7 @@ class FAISSIndex(object):
       assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
       self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
     else:
-      other = other[:last, :]
+      other = other[:last, :] if last is not None else other
       self.index.add_c(other.size(0), cast_float(ptr(other)))
     T.cuda.synchronize()
 
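A note on the `res` → `self.res` change above: `faiss.GpuIndexIVFFlat` keeps a raw pointer to the `StandardGpuResources` object without owning it, so the Python wrapper must outlive the index. A minimal sketch of the safe pattern (`Wrapper` is a hypothetical name; the FAISS calls are the same ones used in the diff):

```python
import faiss

class Wrapper(object):
  def __init__(self, dim=32, num_lists=8):
    # keep the resources object on self: if it were a local and got
    # garbage-collected while the index is alive, FAISS would dereference
    # freed memory and crash
    self.res = faiss.StandardGpuResources()
    self.index = faiss.GpuIndexIVFFlat(self.res, dim, num_lists, faiss.METRIC_L2)
```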
dnc/sparse_memory.py

@@ -8,7 +8,6 @@ import torch.nn.functional as F
 import numpy as np
 import math
 
-from .flann_index import FLANNIndex
 from .util import *
 import time
 
@@ -44,11 +43,12 @@ class SparseMemory(nn.Module):
     m = self.mem_size
     w = self.cell_size
     r = self.read_heads
-    # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
+    # The visible memory size: (K * R read heads, forward and backward
+    # temporal reads of size KL and least used memory cell)
     self.c = (r * self.K) + 1
 
     if self.independent_linears:
-      self.read_query_transform = nn.Linear(self.input_size, w*r)
+      self.read_query_transform = nn.Linear(self.input_size, w * r)
       self.write_vector_transform = nn.Linear(self.input_size, w)
       self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
       self.write_gate_transform = nn.Linear(self.input_size, 1)
@@ -72,11 +72,20 @@ class SparseMemory(nn.Module):
     if 'indexes' in hidden:
       [x.reset() for x in hidden['indexes']]
     else:
-      # create new indexes
-      hidden['indexes'] = \
-          [FLANNIndex(cell_size=self.cell_size,
-                      nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
-                      probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      # create new indexes, try to use FAISS, fall back to FLANN
+      try:
+        from .faiss_index import FAISSIndex
+        hidden['indexes'] = \
+            [FAISSIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      except Exception as e:
+        print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`")
+        from .flann_index import FLANNIndex
+        hidden['indexes'] = \
+            [FLANNIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
 
     # add existing memory into indexes
     pos = hidden['read_positions'].squeeze().data.cpu().numpy()
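The guarded-import pattern in isolation (a sketch; the code above catches `Exception` rather than `ImportError` because importing FAISS can succeed while GPU initialization still fails):

```python
try:
  from dnc.faiss_index import FAISSIndex as Index   # fast, GPU-backed
except Exception:
  from dnc.flann_index import FLANNIndex as Index   # pure-CPU fallback
```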
@@ -104,7 +113,7 @@ class SparseMemory(nn.Module):
         'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
-        'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
+        'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
         'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
@@ -126,15 +135,16 @@ class SparseMemory(nn.Module):
       hidden['read_weights'].data.fill_(δ)
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
-      hidden['least_used_mem'].data.fill_(c+1+self.timestep)
+      hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
       hidden['usage'].data.fill_(δ)
-      hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+      hidden['read_positions'] = cuda(
+          T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
 
     return hidden
 
   def write_into_sparse_memory(self, hidden):
     visible_memory = hidden['visible_memory']
-    positions = hidden['read_positions'].squeeze()
+    positions = hidden['read_positions']
 
     (b, m, w) = hidden['memory'].size()
     # update memory
@@ -147,8 +157,9 @@ class SparseMemory(nn.Module):
         hidden['indexes'][batch].reset()
         hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
 
-    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1
-    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
+    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
+                                1) if mem_limit_reached else hidden['least_used_mem'] + 1
 
     return hidden
 
@@ -177,7 +188,8 @@ class SparseMemory(nn.Module):
     erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
 
     # write into memory
-    hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
+    hidden['visible_memory'] = hidden['visible_memory'] * \
+        (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)
 
     return hidden
@@ -240,11 +252,11 @@ class SparseMemory(nn.Module):
     # sparse read
     read_vectors, positions, read_weights, visible_memory = \
         self.read_from_sparse_memory(
            hidden['memory'],
            hidden['indexes'],
            read_query,
            hidden['least_used_mem'],
            hidden['usage']
        )
 
     hidden['read_positions'] = positions
@@ -276,11 +288,11 @@ class SparseMemory(nn.Module):
     else:
       ξ = self.interface_weights(ξ)
       # r read keys (b * r * w)
-      read_query = ξ[:, :r*w].contiguous().view(b, r, w)
+      read_query = ξ[:, :r * w].contiguous().view(b, r, w)
       # write key (b * 1 * w)
-      write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
+      write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
       # write vector (b * 1 * r)
-      interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
+      interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
       # write gate (b * 1)
       write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
 
dnc/sparse_temporal_memory.py

@@ -46,11 +46,12 @@ class SparseTemporalMemory(nn.Module):
     m = self.mem_size
     w = self.cell_size
     r = self.read_heads
-    # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
+    # The visible memory size: (K * R read heads, forward and backward
+    # temporal reads of size KL and least used memory cell)
     self.c = (r * self.K) + (self.KL * 2) + 1
 
     if self.independent_linears:
-      self.read_query_transform = nn.Linear(self.input_size, w*r)
+      self.read_query_transform = nn.Linear(self.input_size, w * r)
       self.write_vector_transform = nn.Linear(self.input_size, w)
       self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
       self.write_gate_transform = nn.Linear(self.input_size, 1)
@@ -75,10 +76,19 @@ class SparseTemporalMemory(nn.Module):
       [x.reset() for x in hidden['indexes']]
     else:
       # create new indexes
-      hidden['indexes'] = \
-          [FLANNIndex(cell_size=self.cell_size,
-                      nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
-                      probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      try:
+        from .faiss_index import FAISSIndex
+        hidden['indexes'] = \
+            [FAISSIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      except Exception as e:
+        print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`")
+        from .flann_index import FLANNIndex
+        hidden['indexes'] = \
+            [FLANNIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
 
     # add existing memory into indexes
     pos = hidden['read_positions'].squeeze().data.cpu().numpy()
@@ -103,13 +113,13 @@ class SparseTemporalMemory(nn.Module):
         # warning can be a huge chunk of contiguous memory
         'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
         'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
-        'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
-        'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
-        'precedence': cuda(T.zeros(b, self.KL*2).fill_(δ), gpu_id=self.gpu_id),
+        'link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id),
+        'rev_link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id),
+        'precedence': cuda(T.zeros(b, self.KL * 2).fill_(δ), gpu_id=self.gpu_id),
         'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
-        'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
+        'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
         'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
@@ -137,15 +147,16 @@ class SparseTemporalMemory(nn.Module):
       hidden['read_weights'].data.fill_(δ)
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
-      hidden['least_used_mem'].data.fill_(c+1+self.timestep)
+      hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
       hidden['usage'].data.fill_(δ)
-      hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+      hidden['read_positions'] = cuda(
+          T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
 
     return hidden
 
   def write_into_sparse_memory(self, hidden):
     visible_memory = hidden['visible_memory']
-    positions = hidden['read_positions'].squeeze()
+    positions = hidden['read_positions']
 
     (b, m, w) = hidden['memory'].size()
     # update memory
@@ -158,8 +169,9 @@ class SparseTemporalMemory(nn.Module):
         hidden['indexes'][batch].reset()
         hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
 
-    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1
-    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
+    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
+                                1) if mem_limit_reached else hidden['least_used_mem'] + 1
 
     return hidden
 
@@ -179,9 +191,10 @@ class SparseTemporalMemory(nn.Module):
 
     link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j
 
-    rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
+    rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + \
+        (temporal_write_weights_j * precedence_dense_i)
 
-    return link_matrix.squeeze() * I, rev_link_matrix.squeeze() * I
+    return link_matrix * I, rev_link_matrix * I
 
   def update_precedence(self, precedence, write_weights):
     return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
@@ -211,22 +224,23 @@ class SparseTemporalMemory(nn.Module):
     erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
 
     # write into memory
-    hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
+    hidden['visible_memory'] = hidden['visible_memory'] * \
+        (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)
 
     # update link_matrix and precedence
     (b, c) = write_weights.size()
 
     # update link matrix
-    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:]
     hidden['link_matrix'], hidden['rev_link_matrix'] = \
         self.update_link_matrices(
            hidden['link_matrix'],
            hidden['rev_link_matrix'],
            hidden['write_weights'],
            hidden['precedence'],
            temporal_read_positions
        )
 
     # update precedence vector
     read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
@@ -255,8 +269,8 @@ class SparseTemporalMemory(nn.Module):
     return usage, I
 
   def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
-    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
-    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
+    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
+    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
     return f, b
 
   def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
@@ -299,20 +313,20 @@ class SparseTemporalMemory(nn.Module):
 
   def read(self, read_query, hidden):
     # get forward and backward weights
-    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:]
     read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
     forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights)
 
     # sparse read
     read_vectors, positions, read_weights, visible_memory = \
         self.read_from_sparse_memory(
            hidden['memory'],
            hidden['indexes'],
            read_query,
            hidden['least_used_mem'],
            hidden['usage'],
            forward, backward,
            hidden['read_positions']
        )
 
     hidden['read_positions'] = positions
@@ -344,11 +358,11 @@ class SparseTemporalMemory(nn.Module):
     else:
       ξ = self.interface_weights(ξ)
       # r read keys (b * r * w)
-      read_query = ξ[:, :r*w].contiguous().view(b, r, w)
+      read_query = ξ[:, :r * w].contiguous().view(b, r, w)
       # write key (b * 1 * w)
-      write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
+      write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
       # write vector (b * 1 * r)
-      interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
+      interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
       # write gate (b * 1)
       write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
 
dnc/util.py

@@ -138,7 +138,7 @@ def ptr(tensor):
   if T.is_tensor(tensor):
     return tensor.storage().data_ptr()
   elif hasattr(tensor, 'data'):
-    return tensor.data.storage().data_ptr()
+    return tensor.clone().data.storage().data_ptr()
   else:
     return tensor
 
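Why `clone()` before taking the raw pointer: a tensor view shares the storage of its base, so `storage().data_ptr()` can point at memory that does not start with the visible elements; `clone()` materializes a fresh contiguous buffer. A sketch (behavior as in the PyTorch versions this code targets):

```python
import torch as T

base = T.randn(10, 4)
view = base[2:5]
print(view.storage().data_ptr() == base.storage().data_ptr())   # True: shared buffer
copy = view.clone()
print(copy.storage().data_ptr() == base.storage().data_ptr())   # False: fresh buffer
```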
tasks/adding_task.py (new file, 267 lines)
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import getopt
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from visdom import Visdom
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.join('..', '..'))
|
||||||
|
|
||||||
|
import torch as T
|
||||||
|
from torch.autograd import Variable as var
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.optim as optim
|
||||||
|
|
||||||
|
from torch.nn.utils import clip_grad_norm
|
||||||
|
|
||||||
|
from dnc.dnc import DNC
|
||||||
|
from dnc.sdnc import SDNC
|
||||||
|
from dnc.sam import SAM
|
||||||
|
from dnc.util import *
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')
|
||||||
|
parser.add_argument('-input_size', type=int, default=6, help='dimension of input feature')
|
||||||
|
parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
|
||||||
|
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
|
||||||
|
parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
|
||||||
|
parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory: dnc | sdnc | sam')
|
||||||
|
|
||||||
|
parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
|
||||||
|
parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
|
||||||
|
parser.add_argument('-lr', type=float, default=1e-4, help='initial learning rate')
|
||||||
|
parser.add_argument('-optim', type=str, default='adam', help='learning rule, supports adam|rmsprop')
|
||||||
|
parser.add_argument('-clip', type=float, default=50, help='gradient clipping')
|
||||||
|
|
||||||
|
parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
|
||||||
|
parser.add_argument('-mem_size', type=int, default=20, help='memory dimension')
|
||||||
|
parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
|
||||||
|
parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')
|
||||||
|
parser.add_argument('-sparse_reads', type=int, default=10, help='number of sparse reads per read head')
|
||||||
|
parser.add_argument('-temporal_reads', type=int, default=2, help='number of temporal reads')
|
||||||
|
|
||||||
|
parser.add_argument('-sequence_max_length', type=int, default=1000, metavar='N', help='sequence_max_length')
|
||||||
|
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
|
||||||
|
|
||||||
|
parser.add_argument('-iterations', type=int, default=2000, metavar='N', help='total number of iteration')
|
||||||
|
parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency')
|
||||||
|
parser.add_argument('-check_freq', type=int, default=100, metavar='N', help='check point frequency')
|
||||||
|
parser.add_argument('-visdom', action='store_true', help='plot memory content on visdom per -summarize_freq steps')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
viz = Visdom()
|
||||||
|
# assert viz.check_connection()
|
||||||
|
|
||||||
|
if args.cuda != -1:
|
||||||
|
print('Using CUDA.')
|
||||||
|
T.manual_seed(1111)
|
||||||
|
else:
|
||||||
|
print('Using CPU.')
|
||||||
|
|
||||||
|
def llprint(message):
|
||||||
|
sys.stdout.write(message)
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def onehot(x, n):
|
||||||
|
ret = np.zeros(n).astype(np.float32)
|
||||||
|
ret[x] = 1.0
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def generate_data(length, size):
|
||||||
|
|
||||||
|
content = np.random.randint(0, size - 1, length)
|
||||||
|
|
||||||
|
seqlen = length + 1
|
||||||
|
x_seq_list = [float('nan')] * seqlen
|
||||||
|
sums = 0.0
|
||||||
|
sums_text = ""
|
||||||
|
for i in range(seqlen):
|
||||||
|
if (i < length):
|
||||||
|
x_seq_list[i] = onehot(content[i], size)
|
||||||
|
sums += content[i]
|
||||||
|
sums_text += str(content[i]) + " + "
|
||||||
|
else:
|
||||||
|
x_seq_list[i] = onehot(size - 1, size)
|
||||||
|
|
||||||
|
x_seq_list = np.array(x_seq_list)
|
||||||
|
x_seq_list = x_seq_list.reshape((1,) + x_seq_list.shape)
|
||||||
|
sums = np.array(sums)
|
||||||
|
sums = sums.reshape(1, 1, 1)
|
||||||
|
|
||||||
|
return cudavec(x_seq_list, gpu_id=args.cuda).float(), cudavec(sums, gpu_id=args.cuda).float(), sums_text
|
||||||
|
|
||||||
|
|
||||||
|
def cross_entropy(prediction, target):
|
||||||
|
return (prediction - target) ** 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
|
||||||
|
dirname = os.path.dirname(__file__)
|
||||||
|
ckpts_dir = os.path.join(dirname, 'checkpoints')
|
||||||
|
|
||||||
|
input_size = args.input_size
|
||||||
|
memory_type = args.memory_type
|
||||||
|
lr = args.lr
|
||||||
|
clip = args.clip
|
||||||
|
batch_size = args.batch_size
|
||||||
|
sequence_max_length = args.sequence_max_length
|
||||||
|
cuda = args.cuda
|
||||||
|
iterations = args.iterations
|
||||||
|
summarize_freq = args.summarize_freq
|
||||||
|
check_freq = args.check_freq
|
||||||
|
visdom = args.visdom
|
||||||
|
|
||||||
|
from_checkpoint = None
|
||||||
|
|
||||||
|
if args.memory_type == 'dnc':
|
||||||
|
rnn = DNC(
|
||||||
|
input_size=args.input_size,
|
||||||
|
hidden_size=args.nhid,
|
||||||
|
rnn_type=args.rnn_type,
|
||||||
|
num_layers=args.nlayer,
|
||||||
|
num_hidden_layers=args.nhlayer,
|
||||||
|
dropout=args.dropout,
|
||||||
|
nr_cells=args.mem_slot,
|
||||||
|
cell_size=args.mem_size,
|
||||||
|
read_heads=args.read_heads,
|
||||||
|
gpu_id=args.cuda,
|
||||||
|
debug=args.visdom,
|
||||||
|
batch_first=True,
|
||||||
|
independent_linears=True
|
||||||
|
)
|
||||||
|
elif args.memory_type == 'sdnc':
|
||||||
|
rnn = SDNC(
|
||||||
|
input_size=args.input_size,
|
||||||
|
hidden_size=args.nhid,
|
||||||
|
rnn_type=args.rnn_type,
|
||||||
|
num_layers=args.nlayer,
|
||||||
|
num_hidden_layers=args.nhlayer,
|
||||||
|
dropout=args.dropout,
|
||||||
|
nr_cells=args.mem_slot,
|
||||||
|
cell_size=args.mem_size,
|
||||||
|
sparse_reads=args.sparse_reads,
|
||||||
|
temporal_reads=args.temporal_reads,
|
||||||
|
read_heads=args.read_heads,
|
||||||
|
gpu_id=args.cuda,
|
||||||
|
debug=args.visdom,
|
||||||
|
batch_first=True,
|
||||||
|
independent_linears=False
|
||||||
|
)
|
||||||
|
elif args.memory_type == 'sam':
|
||||||
|
rnn = SAM(
|
||||||
|
input_size=args.input_size,
|
||||||
|
hidden_size=args.nhid,
|
||||||
|
rnn_type=args.rnn_type,
|
||||||
|
num_layers=args.nlayer,
|
||||||
|
num_hidden_layers=args.nhlayer,
|
||||||
|
dropout=args.dropout,
|
||||||
|
nr_cells=args.mem_slot,
|
||||||
|
cell_size=args.mem_size,
|
||||||
|
sparse_reads=args.sparse_reads,
|
||||||
|
read_heads=args.read_heads,
|
||||||
|
gpu_id=args.cuda,
|
||||||
|
debug=args.visdom,
|
||||||
|
batch_first=True,
|
||||||
|
independent_linears=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception('Not recognized type of memory')
|
||||||
|
|
||||||
|
if args.cuda != -1:
|
||||||
|
rnn = rnn.cuda(args.cuda)
|
||||||
|
|
||||||
|
print(rnn)
|
||||||
|
|
||||||
|
last_save_losses = []
|
||||||
|
|
||||||
|
if args.optim == 'adam':
|
||||||
|
optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
|
||||||
|
elif args.optim == 'adamax':
|
||||||
|
optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
|
||||||
|
elif args.optim == 'rmsprop':
|
||||||
|
optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10) # 0.0001
|
||||||
|
elif args.optim == 'sgd':
|
||||||
|
optimizer = optim.SGD(rnn.parameters(), lr=args.lr) # 0.01
|
||||||
|
elif args.optim == 'adagrad':
|
||||||
|
optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
|
||||||
|
elif args.optim == 'adadelta':
|
||||||
|
optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
|
||||||
|
|
||||||
|
last_100_losses = []
|
||||||
|
|
||||||
|
(chx, mhx, rv) = (None, None, None)
|
||||||
|
for epoch in range(iterations + 1):
|
||||||
|
llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
|
||||||
|
optimizer.zero_grad()
|
||||||
|
# We use for training just (sequence_max_length / 10) examples
|
||||||
|
random_length = np.random.randint(2, (sequence_max_length) + 1)
|
||||||
|
input_data, target_output, sums_text = generate_data(random_length, input_size)
|
||||||
|
|
||||||
|
if rnn.debug:
|
||||||
|
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
|
else:
|
||||||
|
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
|
|
||||||
|
output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
|
||||||
|
loss = cross_entropy(output, target_output)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
|
||||||
|
optimizer.step()
|
||||||
|
loss_value = loss.data[0]
|
||||||
|
|
||||||
|
# detach memory from graph
|
||||||
|
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
||||||
|
|
||||||
|
summarize = (epoch % summarize_freq == 0)
|
||||||
|
take_checkpoint = (epoch != 0) and (epoch % iterations == 0)
|
||||||
|
|
||||||
|
last_100_losses.append(loss_value)
|
||||||
|
|
||||||
|
if summarize:
|
||||||
|
llprint("\rIteration %d/%d" % (epoch, iterations))
|
||||||
|
llprint("\nAvg. Logistic Loss: %.4f\n" % (np.mean(last_100_losses)))
|
||||||
|
output = output.data.cpu().numpy()
|
||||||
|
print("Real value: ", ' = ' + str(int(target_output[0])))
|
||||||
|
print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
|
||||||
|
last_100_losses = []
|
||||||
|
|
||||||
|
if take_checkpoint:
|
||||||
|
llprint("\nSaving Checkpoint ... "),
|
||||||
|
check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))
|
||||||
|
cur_weights = rnn.state_dict()
|
||||||
|
T.save(cur_weights, check_ptr)
|
||||||
|
llprint("Done!\n")
|
||||||
|
|
||||||
|
llprint("\nTesting generalization...\n")
|
||||||
|
|
||||||
|
rnn.eval()
|
||||||
|
|
||||||
|
for i in range(int((iterations + 1) / 10)):
|
||||||
|
llprint("\nIteration %d/%d" % (i, iterations))
|
||||||
|
# We test now the learned generalization using sequence_max_length examples
|
||||||
|
random_length = np.random.randint(2, int(sequence_max_length) * 10 + 1)
|
||||||
|
input_data, target_output, sums_text = generate_data(random_length, input_size)
|
||||||
|
|
||||||
|
if rnn.debug:
|
||||||
|
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
|
else:
|
||||||
|
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
|
|
||||||
|
output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
|
||||||
|
output = output.data.cpu().numpy()
|
||||||
|
print("\nReal value: ", ' = ' + str(int(target_output[0])))
|
||||||
|
print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
|
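Two details worth noting in the script above, with a sketch of what the loss actually sees: the whole output sequence is collapsed to one scalar per batch entry before being regressed onto the true sum, and despite its name `cross_entropy()` computes a squared error:

```python
import torch as T

output = T.randn(1, 8, 6)                  # (batch, time, input_size)
prediction = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
print(prediction.shape)                    # torch.Size([1, 1, 1])

def cross_entropy(prediction, target):     # squared error, not a cross entropy
  return (prediction - target) ** 2
```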
tasks/argmax_task.py (new file, 283 lines)
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import warnings
+warnings.filterwarnings('ignore')
+
+import numpy as np
+import getopt
+import sys
+import os
+import math
+import time
+import argparse
+from visdom import Visdom
+
+sys.path.insert(0, os.path.join('..', '..'))
+
+import torch as T
+from torch.autograd import Variable as var
+import torch.nn.functional as F
+import torch.optim as optim
+
+from torch.nn.utils import clip_grad_norm
+
+from dnc.dnc import DNC
+from dnc.sdnc import SDNC
+from dnc.sam import SAM
+from dnc.util import *
+
+parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')
+parser.add_argument('-input_size', type=int, default=6, help='dimension of input feature')
+parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
+parser.add_argument('-nhid', type=int, default=100, help='number of hidden units of the inner nn')
+parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
+parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory: dnc | sdnc | sam')
+
+parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
+parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
+parser.add_argument('-lr', type=float, default=1e-4, help='initial learning rate')
+parser.add_argument('-optim', type=str, default='adam', help='learning rule, supports adam|rmsprop')
+parser.add_argument('-clip', type=float, default=50, help='gradient clipping')
+
+parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
+parser.add_argument('-mem_size', type=int, default=20, help='memory dimension')
+parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
+parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')
+parser.add_argument('-sparse_reads', type=int, default=10, help='number of sparse reads per read head')
+parser.add_argument('-temporal_reads', type=int, default=2, help='number of temporal reads')
+
+parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
+parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
+
+parser.add_argument('-iterations', type=int, default=2000, metavar='N', help='total number of iteration')
+parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency')
+parser.add_argument('-check_freq', type=int, default=100, metavar='N', help='check point frequency')
+parser.add_argument('-visdom', action='store_true', help='plot memory content on visdom per -summarize_freq steps')
+
+args = parser.parse_args()
+print(args)
+
+viz = Visdom()
+# assert viz.check_connection()
+
+if args.cuda != -1:
+  print('Using CUDA.')
+  T.manual_seed(1111)
+else:
+  print('Using CPU.')
+
+
+def llprint(message):
+  sys.stdout.write(message)
+  sys.stdout.flush()
+
+
+def onehot(x, n):
+  ret = np.zeros(n).astype(np.float32)
+  ret[x] = 1.0
+  return ret
+
+
+def generate_data(length, size):
+
+  content = np.random.randint(0, size - 1, length)
+
+  seqlen = length + 1
+  x_seq_list = [float('nan')] * seqlen
+  max_value = 0
+  max_ind = 0
+  for i in range(seqlen):
+    if (i < length):
+      x_seq_list[i] = onehot(content[i], size)
+      if (max_value <= content[i]):
+        max_value = content[i]
+        max_ind = i
+    else:
+      x_seq_list[i] = onehot(size - 1, size)
+
+  x_seq_list = np.array(x_seq_list)
+  x_seq_list = x_seq_list.reshape((1,) + x_seq_list.shape)
+  x_seq_list = np.reshape(x_seq_list, (1, -1, size))
+
+  target_output = np.zeros((1, 1, seqlen), dtype=np.float32)
+  target_output[:, -1, -1] = max_ind
+  target_output = np.reshape(target_output, (1, -1, 1))
+
+  weights_vec = np.zeros((1, 1, seqlen), dtype=np.float32)
+  weights_vec[:, -1, -1] = 1.0
+  weights_vec = np.reshape(weights_vec, (1, -1, 1))
+
+  return cudavec(x_seq_list, gpu_id=args.cuda).float(), \
+      cudavec(target_output, gpu_id=args.cuda).float(), \
+      cudavec(weights_vec, gpu_id=args.cuda)
+
+
+if __name__ == '__main__':
+
+  dirname = os.path.dirname(__file__)
+  ckpts_dir = os.path.join(dirname, 'checkpoints')
+
+  input_size = args.input_size
+  memory_type = args.memory_type
+  lr = args.lr
+  clip = args.clip
+  batch_size = args.batch_size
+  sequence_max_length = args.sequence_max_length
+  cuda = args.cuda
+  iterations = args.iterations
+  summarize_freq = args.summarize_freq
+  check_freq = args.check_freq
+  visdom = args.visdom
+
+  from_checkpoint = None
+
+  if args.memory_type == 'dnc':
+    rnn = DNC(
+        input_size=args.input_size,
+        hidden_size=args.nhid,
+        rnn_type=args.rnn_type,
+        num_layers=args.nlayer,
+        num_hidden_layers=args.nhlayer,
+        dropout=args.dropout,
+        nr_cells=args.mem_slot,
+        cell_size=args.mem_size,
+        read_heads=args.read_heads,
+        gpu_id=args.cuda,
+        debug=args.visdom,
+        batch_first=True,
+        independent_linears=False
+    )
+  elif args.memory_type == 'sdnc':
+    rnn = SDNC(
+        input_size=args.input_size,
+        hidden_size=args.nhid,
+        rnn_type=args.rnn_type,
+        num_layers=args.nlayer,
+        num_hidden_layers=args.nhlayer,
+        dropout=args.dropout,
+        nr_cells=args.mem_slot,
+        cell_size=args.mem_size,
+        sparse_reads=args.sparse_reads,
+        temporal_reads=args.temporal_reads,
+        read_heads=args.read_heads,
+        gpu_id=args.cuda,
+        debug=args.visdom,
+        batch_first=True,
+        independent_linears=False
+    )
+  elif args.memory_type == 'sam':
+    rnn = SAM(
+        input_size=args.input_size,
+        hidden_size=args.nhid,
+        rnn_type=args.rnn_type,
+        num_layers=args.nlayer,
+        num_hidden_layers=args.nhlayer,
+        dropout=args.dropout,
+        nr_cells=args.mem_slot,
+        cell_size=args.mem_size,
+        sparse_reads=args.sparse_reads,
+        read_heads=args.read_heads,
+        gpu_id=args.cuda,
+        debug=args.visdom,
+        batch_first=True,
+        independent_linears=False
+    )
+  else:
+    raise Exception('Not recognized type of memory')
+
+  if args.cuda != -1:
+    rnn = rnn.cuda(args.cuda)
+
+  print(rnn)
+
+  last_save_losses = []
+
+  if args.optim == 'adam':
+    optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])  # 0.0001
+  elif args.optim == 'adamax':
+    optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])  # 0.0001
+  elif args.optim == 'rmsprop':
+    optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10)  # 0.0001
+  elif args.optim == 'sgd':
+    optimizer = optim.SGD(rnn.parameters(), lr=args.lr)  # 0.01
+  elif args.optim == 'adagrad':
+    optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
+  elif args.optim == 'adadelta':
+    optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
+
+  last_100_losses = []
+
+  (chx, mhx, rv) = (None, None, None)
+  for epoch in range(iterations + 1):
+    llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
+    optimizer.zero_grad()
+
+    # We use for training just (sequence_max_length / 10) examples
+    random_length = np.random.randint(2, (sequence_max_length) + 1)
+    input_data, target_output, loss_weights = generate_data(random_length, input_size)
+
+    if rnn.debug:
+      output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
+    else:
+      output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
+
+    loss = T.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
+
+    loss.backward()
+
+    T.nn.utils.clip_grad_norm(rnn.parameters(), args.clip)
+    optimizer.step()
+    loss_value = loss.data[0]
+
+    # detach memory from graph
+    mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
+
+    summarize = (epoch % summarize_freq == 0)
+    take_checkpoint = (epoch != 0) and (epoch % iterations == 0)
+
+    last_100_losses.append(loss_value)
+
+    try:
+      if summarize:
+        output = (loss_weights * output).sum().data.cpu().numpy()[0]
+        target_output = target_output.sum().data.cpu().numpy()
+
+        llprint("\rIteration %d/%d" % (epoch, iterations))
+        llprint("\nAvg. Logistic Loss: %.4f\n" % (np.mean(last_100_losses)))
+        print(target_output)
+        print("Real value: ", ' = ' + str(int(target_output[0])))
+        print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
+        last_100_losses = []
+
+      if take_checkpoint:
+        llprint("\nSaving Checkpoint ... "),
+        check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))
+        cur_weights = rnn.state_dict()
+        T.save(cur_weights, check_ptr)
+        llprint("Done!\n")
+    except Exception as e:
+      pass
+
+  llprint("\nTesting generalization...\n")
+
+  rnn.eval()
+
+  for i in range(int((iterations + 1) / 10)):
+    llprint("\nIteration %d/%d" % (i, iterations))
+    # We test now the learned generalization using sequence_max_length examples
+    random_length = np.random.randint(2, sequence_max_length * 2 + 1)
+    input_data, target_output, loss_weights = generate_data(random_length, input_size)
+
+    if rnn.debug:
+      output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
+    else:
+      output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
+
+    output = output[:, -1, :].sum().data.cpu().numpy()[0]
+    target_output = target_output.sum().data.cpu().numpy()
+
+    try:
+      print("\nReal value: ", ' = ' + str(int(target_output[0])))
+      print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
+    except Exception as e:
+      pass
tasks/copy_task.py

@@ -51,7 +51,6 @@ parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
 parser.add_argument('-curriculum_increment', type=int, default=0, metavar='N', help='sequence_max_length incrementor per 1K iterations')
 parser.add_argument('-curriculum_freq', type=int, default=1000, metavar='N', help='sequence_max_length incrementor per 1K iterations')
 parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
-parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval')
 
 parser.add_argument('-iterations', type=int, default=100000, metavar='N', help='total number of iteration')
 parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency')
@@ -183,12 +182,10 @@ if __name__ == '__main__':
 
   if args.optim == 'adam':
     optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])  # 0.0001
-  if args.optim == 'sparseadam':
-    optimizer = optim.SparseAdam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])  # 0.0001
-  if args.optim == 'adamax':
+  elif args.optim == 'adamax':
     optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])  # 0.0001
   elif args.optim == 'rmsprop':
-    optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, eps=1e-10)  # 0.0001
+    optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, momentum=0.9, eps=1e-10)  # 0.0001
   elif args.optim == 'sgd':
     optimizer = optim.SGD(rnn.parameters(), lr=args.lr)  # 0.01
   elif args.optim == 'adagrad':
@@ -361,3 +358,24 @@ if __name__ == '__main__':
       cur_weights = rnn.state_dict()
       T.save(cur_weights, check_ptr)
       llprint("Done!\n")
+
+  for i in range(int((iterations + 1) / 10)):
+    llprint("\nIteration %d/%d" % (i, iterations))
+    # We test now the learned generalization using sequence_max_length examples
+    random_length = np.random.randint(2, sequence_max_length * 10 + 1)
+    input_data, target_output, loss_weights = generate_data(random_length, input_size)
+
+    if rnn.debug:
+      output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
+    else:
+      output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
+
+    output = output[:, -1, :].sum().data.cpu().numpy()[0]
+    target_output = target_output.sum().data.cpu().numpy()
+
+    try:
+      print("\nReal value: ", ' = ' + str(int(target_output[0])))
+      print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
+    except Exception as e:
+      pass