diff --git a/README.md b/README.md
index 1630155..48ba444 100644
--- a/README.md
+++ b/README.md
@@ -296,9 +296,25 @@ The visdom dashboard shows memory as a heatmap for batch 0 every `-summarize_fre
 
 ## General noteworthy stuff
 
-1. DNCs converge faster with Adam and RMSProp learning rules, SGD generally converges extremely slowly.
+1. SDNCs use the [FLANN approximate nearest neighbour library](https://www.cs.ubc.ca/research/flann/), with its python binding [pyflann3](https://github.com/primetang/pyflann).
+
+FLANN can be installed either from pip (automatically, as a dependency) or from source (e.g. for multithreading via OpenMP):
+
+```bash
+# install OpenMP first, e.g. `sudo pacman -S openmp` on Arch.
+git clone https://github.com/mariusmuja/flann.git
+cd flann
+mkdir build
+cd build
+cmake ..
+make -j 4
+sudo make install
+```
+
+2. An alternative to FLANN is [FAISS](https://github.com/facebookresearch/faiss), which is much faster and interoperable with torch CUDA tensors (but is difficult to distribute; see [dnc/faiss_index.py](dnc/faiss_index.py)).
+3. DNCs converge faster with the Adam and RMSProp learning rules; SGD generally converges extremely slowly.
 The copy task, for example, takes 25k iterations on SGD with lr 1 compared to 3.5k for adam with lr 0.01.
-2. `nan`s in the gradients are common, try with different batch sizes
+4. `nan`s in the gradients are common; try different batch sizes.
 
 Repos referred to for creation of this repo:
 
diff --git a/dnc/sparse_memory.py b/dnc/sparse_memory.py
index 5b5b6e0..a9607f9 100644
--- a/dnc/sparse_memory.py
+++ b/dnc/sparse_memory.py
@@ -255,7 +255,7 @@ class SparseMemory(nn.Module):
     # we search for k cells per read head
     for batch in range(b):
       distances, positions = indexes[batch].search(keys[batch])
-      read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
+      read_positions.append(positions)
     read_positions = T.stack(read_positions, 0)
 
     # add least used mem to read positions
@@ -275,6 +275,7 @@ class SparseMemory(nn.Module):
     # append forward and backward read positions, might lead to duplicates
     read_positions = T.cat([read_positions, fp, bp], 1)
     read_positions = T.cat([read_positions, least_used_mem], 1)
+    read_positions = T.clamp(read_positions, 0, max_length)
 
     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w))
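
For context on the FLANN item added in the README hunk above, here is a minimal sketch of the kind of k-NN query the pyflann binding exposes. The array shapes and index parameters are illustrative assumptions, not values from this repo, and the repo's own index wrapper (`indexes[batch].search(...)` in the second diff) may order its return values differently:

```python
# Minimal pyflann sketch (illustrative shapes and parameters, not this repo's wrapper).
import numpy as np
from pyflann import FLANN

dataset = np.random.rand(10000, 128).astype(np.float32)  # e.g. 10k memory cells, 128-dim keys
queries = np.random.rand(4, 128).astype(np.float32)      # e.g. one key per read head

flann = FLANN()
flann.build_index(dataset, algorithm="kdtree", trees=4)  # build once, query many times
positions, distances = flann.nn_index(queries, num_neighbors=8)  # (4, 8) index and distance arrays
```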
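
The FAISS alternative mentioned in item 2 answers the same query shape. A minimal sketch with assumed sizes; `IndexFlatL2` is the simplest exact index, not necessarily the one dnc/faiss_index.py wraps:

```python
# Minimal FAISS sketch (illustrative; not the dnc/faiss_index.py wrapper itself).
import numpy as np
import faiss

d = 128                                   # assumed key width
cells = np.random.rand(10000, d).astype(np.float32)
keys = np.random.rand(4, d).astype(np.float32)

index = faiss.IndexFlatL2(d)              # exact L2; IVF/HNSW variants trade accuracy for speed
index.add(cells)
distances, positions = index.search(keys, 8)  # returns (distances, indices), both shaped (4, 8)
```

Note that FAISS pads `positions` with -1 when fewer than k neighbours are available, which is one reason bounding the positions before they index memory (as the second diff does) matters.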
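
The sparse_memory.py hunks move the `T.clamp` from the per-batch search results to the fully concatenated position list, so the forward/backward positions (`fp`, `bp`) and `least_used_mem` are also bounded before `gather` indexes memory. A standalone sketch of that ordering, with made-up sizes (`b`, `m`, `w` and the position widths are assumptions; the repo clamps to its own `max_length`):

```python
# Why the clamp comes after the cat: gather indexes memory directly, so every
# position (search results, fp/bp, least used) must be in range, not just the
# per-batch ANN results.
import torch as T

b, m, w = 2, 16, 8                          # assumed batch size, mem cells, cell width
memory = T.randn(b, m, w)

search_pos = T.randint(0, m + 4, (b, 6))    # ANN results may fall out of range (e.g. -1 padding)
least_used = T.randint(0, m, (b, 1))

read_positions = T.cat([search_pos, least_used], 1)
read_positions = T.clamp(read_positions, 0, m - 1)  # bound *after* concatenating everything

c = read_positions.size(1)
visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, c, w))  # (b, c, w)
```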