Clamp read positions after concatenating them, add to README

ixaxaar 2017-12-15 18:28:11 +05:30
parent 116432d2c5
commit e686d240ee
2 changed files with 20 additions and 3 deletions

@@ -296,9 +296,25 @@ The visdom dashboard shows memory as a heatmap for batch 0 every `-summarize_fre
 ## General noteworthy stuff
-1. DNCs converge faster with Adam and RMSProp learning rules, SGD generally converges extremely slowly.
+1. SDNCs use the [FLANN approximate nearest neighbour library](https://www.cs.ubc.ca/research/flann/), with its python binding [pyflann3](https://github.com/primetang/pyflann).
+FLANN can be installed either from pip (automatically as a dependency), or from source (e.g. for multithreading via OpenMP):
+```bash
+# install OpenMP first: e.g. `sudo pacman -S openmp` for Arch.
+git clone git://github.com/mariusmuja/flann.git
+cd flann
+mkdir build
+cd build
+cmake ..
+make -j 4
+sudo make install
+```
+2. An alternative to FLANN is [FAISS](https://github.com/facebookresearch/faiss), which is much faster and interoperable with torch CUDA tensors (but is difficult to distribute, see [dnc/faiss_index.py](dnc/faiss_index.py)).
+3. DNCs converge faster with Adam and RMSProp learning rules; SGD generally converges extremely slowly.
+The copy task, for example, takes 25k iterations on SGD with lr 1, compared to 3.5k for Adam with lr 0.01.
-2. `nan`s in the gradients are common, try with different batch sizes
+4. `nan`s in the gradients are common; try different batch sizes
 Repos referred to for creation of this repo:

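The FLANN/FAISS lookup mentioned in the README notes returns, for each read key, the positions of the k nearest memory cells. A minimal brute-force numpy sketch of that search (illustrative only: the function name and shapes are assumptions; the real code delegates to FLANN/FAISS, which approximate this much faster at scale):

```python
import numpy as np

def knn_positions(memory, keys, k):
    """Exact k-nearest-cell search by Euclidean distance.

    memory: (mem_size, cell_width) array of memory cells
    keys:   (n_heads, cell_width) array of read keys
    Returns (distances, positions), each of shape (n_heads, k).
    """
    # pairwise distances between every key and every memory cell
    dists = np.linalg.norm(keys[:, None, :] - memory[None, :, :], axis=-1)
    # indices of the k closest cells per key, closest first
    positions = np.argsort(dists, axis=1)[:, :k]
    distances = np.take_along_axis(dists, positions, axis=1)
    return distances, positions

np.random.seed(0)
memory = np.random.randn(16, 6)
keys = memory[[2, 9]] + 0.01          # keys very close to cells 2 and 9
dists, pos = knn_positions(memory, keys, k=3)
print(pos[:, 0])                      # nearest cell per key: [2 9]
```

Like `indexes[batch].search(keys[batch])` in the diff below, this yields one `(k,)` row of positions per read key, which are then gathered from memory.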
@@ -255,7 +255,7 @@ class SparseMemory(nn.Module):
     # we search for k cells per read head
     for batch in range(b):
       distances, positions = indexes[batch].search(keys[batch])
-      read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
+      read_positions.append(positions)
     read_positions = T.stack(read_positions, 0)
     # add least used mem to read positions
@@ -275,6 +275,7 @@ class SparseMemory(nn.Module):
     # append forward and backward read positions, might lead to duplicates
     read_positions = T.cat([read_positions, fp, bp], 1)
     read_positions = T.cat([read_positions, least_used_mem], 1)
+    read_positions = T.clamp(read_positions, 0, max_length)
     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w))
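The change moves the clamp from the per-batch ANN results to the concatenated position list, so positions contributed afterwards (forward/backward reads, the least-used slot) are also bounded before the `gather`. A hedged numpy sketch of that pattern, with `np.clip`/`np.take_along_axis` standing in for `T.clamp`/`gather` and all values invented for illustration:

```python
import numpy as np

mem_size, w = 8, 4
memory = np.arange(mem_size * w, dtype=float).reshape(1, mem_size, w)  # batch of 1

# positions from the ANN search, already in range
ann_positions = np.array([[2, 7]])
# positions derived by shifting can step past the last cell
forward = ann_positions + 1                       # [[3, 8]] -- 8 is out of range
least_used = np.array([[mem_size - 1]])

# concatenate every source first, as the commit does ...
read_positions = np.concatenate([ann_positions, forward, least_used], axis=1)
# ... then clamp once, so every contributed position is a valid cell index
read_positions = np.clip(read_positions, 0, mem_size - 1)

# gather the visible memory rows (torch: memory.gather(1, ...expand(b, c, w)))
visible = np.take_along_axis(memory, read_positions[..., None], axis=1)
print(read_positions)   # [[2 7 3 7 7]]
print(visible.shape)    # (1, 5, 4)
```

Clamping only inside the search loop, as the old code did, would leave the later-appended positions unchecked; one clamp after the final `cat` covers them all.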