various tweaks, influence distinct write positions, condition read weights with usage
This commit is contained in:
parent
63d49afe40
commit
e588f6b398
66
README.md
66
README.md
@ -17,66 +17,12 @@ For using sparse DNCs, additional libraries are required:
|
||||
|
||||
SDNCs require an additional library: [facebookresearch/faiss](https://github.com/facebookresearch/faiss).
|
||||
A compiled version of faiss with Intel SSE + CUDA 8 support ships with this package.
|
||||
If that does not work, one might need to manually compile faiss, as detailed below:
|
||||
|
||||
#### Installing FAISS
|
||||
|
||||
Needs `libopenblas.so` in `/usr/lib/`.
|
||||
|
||||
This has been tested on Arch Linux. Other distributions might use a different libopenblas path, CUDA root directory, or numpy include directory.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/facebookresearch/faiss.git
|
||||
cd faiss
|
||||
cp ./example_makefiles/makefile.inc.Linux ./makefile.inc
|
||||
# change libopenblas path
|
||||
sed -i "s/lib64\/libopenblas\.so\.0/lib\/libopenblas\.so/g" ./makefile.inc
|
||||
# add option for nvcc to work properly with g++ > 5
|
||||
sed -i "s/std c++11 \-lineinfo/std c++11 \-lineinfo \-Xcompiler \-D__CORRECT_ISO_CPP11_MATH_H_PROTO/g" ./makefile.inc
|
||||
# change CUDA ROOT
|
||||
sed -i "s/CUDAROOT=\/usr\/local\/cuda-8.0\//CUDAROOT=\/opt\/cuda\//g" ./makefile.inc
|
||||
# change numpy include files (for v3.6)
|
||||
sed -i "s/PYTHONCFLAGS=\-I\/usr\/include\/python2.7\/ \-I\/usr\/lib64\/python2.7\/site\-packages\/numpy\/core\/include\//PYTHONCFLAGS=\-I\/usr\/include\/python3.6m\/ \-I\/usr\/lib\/python3.6\/site\-packages\/numpy\/core\/include/g"
|
||||
|
||||
# build
|
||||
make
|
||||
cd gpu
|
||||
make
|
||||
cd ..
|
||||
make py
|
||||
cd gpu
|
||||
make py
|
||||
cd ..
|
||||
|
||||
mkdir /tmp/faiss
|
||||
find -name "*.so" -exec cp {} /tmp/faiss \;
|
||||
find -name "*.a" -exec cp {} /tmp/faiss \;
|
||||
find -name "*.py" -exec cp {} /tmp/faiss \;
|
||||
mv /tmp/faiss .
|
||||
cd faiss
|
||||
|
||||
# convert to python3
|
||||
2to3 -w ./*.py
|
||||
rm -rf *.bak
|
||||
|
||||
# Fix relative imports
|
||||
for i in *.py; do
|
||||
filename=`echo $i | cut -d "." -f 1`
|
||||
echo $filename
|
||||
find -name "*.py" -exec sed -i "s/import $filename/import \.$filename/g" {} \;
|
||||
find -name "*.py" -exec sed -i "s/from $filename import/from \.$filename import/g" {} \;
|
||||
done
|
||||
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/ixaxaar/pytorch-dnc
|
||||
mv faiss pytorch-dnc
|
||||
cd pytorch-dnc
|
||||
sudo pip install -e .
|
||||
```
|
||||
|
||||
If that does not work, one might need to install from source, as detailed below:
|
||||
|
||||
#### Installing from source
|
||||
|
||||
A script for building and installing this lib from source can be found at [scripts/install.sh](./scripts/install.sh).
|
||||
Tested on `ubuntu 16.04`, `Arch / Manjaro` and `Fedora 27`.
|
||||
|
||||
## Architecture
|
||||
|
||||
@ -113,8 +59,8 @@ Following are the forward pass parameters:
|
||||
| --- | --- | --- |
|
||||
| input | - | The input vector `(B*T*X)` or `(T*B*X)` |
|
||||
| hidden | `(None,None,None)` | Hidden states `(controller hidden, memory hidden, read vectors)` |
|
||||
| reset_experience | `False` | Whether to reset memory (This is a parameter for the forward pass |
|
||||
| pass_through_memory | `True` | Whether to pass through memory (This is a parameter for the forward pass |
|
||||
| reset_experience | `False` | Whether to reset memory |
|
||||
| pass_through_memory | `True` | Whether to pass through memory |
|
||||
|
||||
|
||||
### Example usage:
|
||||
|
@ -11,12 +11,13 @@ from .util import *
|
||||
|
||||
class Index(object):
|
||||
|
||||
def __init__(self, cell_size=20, nr_cells=1024, K=4, probes=32, res=None, train=None, gpu_id=-1):
|
||||
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=30, probes=32, res=None, train=None, gpu_id=-1):
|
||||
super(Index, self).__init__()
|
||||
self.cell_size = cell_size
|
||||
self.nr_cells = nr_cells
|
||||
self.probes = probes
|
||||
self.K = K
|
||||
self.num_lists = num_lists
|
||||
self.gpu_id = gpu_id
|
||||
self.res = res if res else faiss.StandardGpuResources()
|
||||
self.res.setTempMemoryFraction(0.01)
|
||||
@ -25,9 +26,10 @@ class Index(object):
|
||||
|
||||
nr_samples = self.nr_cells * 100 * self.cell_size
|
||||
train = train if train is not None else T.arange(-nr_samples, nr_samples, 2).view(self.nr_cells * 100, self.cell_size) / nr_samples
|
||||
# train = T.randn(self.nr_cells * 100, self.cell_size)
|
||||
|
||||
self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.K, faiss.METRIC_INNER_PRODUCT)
|
||||
self.index.setNumProbes(self.probes)
|
||||
self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
|
||||
self.index.setNumProbes(self.num_lists)
|
||||
self.train(train)
|
||||
|
||||
def cuda(self, gpu_id):
|
||||
@ -51,7 +53,7 @@ class Index(object):
|
||||
if positions is not None:
|
||||
positions = ensure_gpu(positions, self.gpu_id)
|
||||
assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
|
||||
self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions)))
|
||||
self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
|
||||
else:
|
||||
self.index.add_c(other.size(0), cast_float(ptr(other)))
|
||||
T.cuda.synchronize()
|
||||
@ -77,4 +79,4 @@ class Index(object):
|
||||
cast_long(ptr(labels))
|
||||
)
|
||||
T.cuda.synchronize()
|
||||
return (distances, labels)
|
||||
return (distances, (labels-1))
|
||||
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
import torch as T
|
||||
|
@ -174,14 +174,15 @@ class SDNC(nn.Module):
|
||||
input, chx = self.rnns[layer](input.unsqueeze(1), chx)
|
||||
input = input.squeeze(1)
|
||||
|
||||
# the interface vector
|
||||
ξ = input
|
||||
# clip the controller output
|
||||
if self.clip != 0:
|
||||
output = T.clamp(input, -self.clip, self.clip)
|
||||
else:
|
||||
output = input
|
||||
|
||||
# the interface vector
|
||||
ξ = output
|
||||
|
||||
# pass through memory
|
||||
if pass_through_memory:
|
||||
if self.share_memory:
|
||||
|
@ -6,6 +6,7 @@ import torch as T
|
||||
from torch.autograd import Variable as var
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
from .indexes import Index
|
||||
from .util import *
|
||||
@ -21,7 +22,7 @@ class SparseMemory(nn.Module):
|
||||
cell_size=32,
|
||||
independent_linears=True,
|
||||
sparse_reads=4,
|
||||
num_kdtrees=4,
|
||||
num_lists=None,
|
||||
index_checks=32,
|
||||
rebuild_indexes_after=10,
|
||||
gpu_id=-1,
|
||||
@ -36,7 +37,7 @@ class SparseMemory(nn.Module):
|
||||
self.input_size = input_size
|
||||
self.independent_linears = independent_linears
|
||||
self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
|
||||
self.num_kdtrees = num_kdtrees
|
||||
self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100)
|
||||
self.index_checks = index_checks
|
||||
# self.rebuild_indexes_after = rebuild_indexes_after
|
||||
|
||||
@ -69,7 +70,7 @@ class SparseMemory(nn.Module):
|
||||
# create new indexes
|
||||
hidden['indexes'] = \
|
||||
[Index(cell_size=self.cell_size,
|
||||
nr_cells=self.mem_size, K=self.K,
|
||||
nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists,
|
||||
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
|
||||
|
||||
# add existing memory into indexes
|
||||
@ -94,9 +95,9 @@ class SparseMemory(nn.Module):
|
||||
'read_weights': cuda(T.zeros(b, 1, r).fill_(δ), gpu_id=self.gpu_id),
|
||||
'write_weights': cuda(T.zeros(b, 1, r).fill_(δ), gpu_id=self.gpu_id),
|
||||
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
|
||||
'last_used_mem': cuda(T.zeros(b, 1), gpu_id=self.gpu_id).long(),
|
||||
'usage': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
|
||||
'read_positions': cuda(T.zeros(b, 1, r).fill_(0), gpu_id=self.gpu_id).long()
|
||||
'last_used_mem': cuda(T.zeros(b, 1).fill_(δ), gpu_id=self.gpu_id).long(),
|
||||
'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
||||
'read_positions': cuda(T.arange(0, r).expand(b, 1, r), gpu_id=self.gpu_id).long()
|
||||
}
|
||||
hidden = self.rebuild_indexes(hidden, erase=True)
|
||||
else:
|
||||
@ -115,8 +116,8 @@ class SparseMemory(nn.Module):
|
||||
hidden['write_weights'].data.fill_(δ)
|
||||
hidden['read_vectors'].data.fill_(δ)
|
||||
hidden['last_used_mem'].data.fill_(0)
|
||||
hidden['usage'].data.fill_(0)
|
||||
hidden['read_positions'].data.fill_(0)
|
||||
hidden['usage'].data.fill_(δ)
|
||||
hidden['read_positions'] = cuda(T.arange(0, r).expand(b, 1, r), gpu_id=self.gpu_id).long()
|
||||
return hidden
|
||||
|
||||
def write_into_sparse_memory(self, hidden):
|
||||
@ -175,7 +176,7 @@ class SparseMemory(nn.Module):
|
||||
|
||||
return usage, I
|
||||
|
||||
def read_from_sparse_memory(self, memory, indexes, keys, last_used_mem):
|
||||
def read_from_sparse_memory(self, memory, indexes, keys, last_used_mem, usage):
|
||||
b = keys.size(0)
|
||||
read_positions = []
|
||||
read_weights = []
|
||||
@ -186,17 +187,26 @@ class SparseMemory(nn.Module):
|
||||
read_weights.append(distances)
|
||||
read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
|
||||
|
||||
# add least used mem to read positions
|
||||
read_positions = T.stack(read_positions, 0)
|
||||
|
||||
# TODO: explore possibility of reading co-locations and such
|
||||
# if read_collocations:
|
||||
# read the previous and the next memory locations
|
||||
# read_positions = T.cat([read_positions, read_positions-1, read_positions+1], -1)
|
||||
|
||||
read_positions = var(read_positions)
|
||||
read_positions = T.cat([read_positions, last_used_mem.unsqueeze(1)], 2)
|
||||
|
||||
# add weight of 0 for least used mem block
|
||||
read_weights = T.stack(read_weights, 0)
|
||||
new_block = read_weights.new(b, 1, 1)
|
||||
new_block.fill_(0)
|
||||
new_block.fill_(δ)
|
||||
read_weights = T.cat([read_weights, new_block], 2)
|
||||
read_weights = F.softmax(var(read_weights))
|
||||
|
||||
# add least used mem to read positions
|
||||
read_positions = T.stack(read_positions, 0)
|
||||
read_positions = var(read_positions)
|
||||
read_positions = T.cat([read_positions, last_used_mem.unsqueeze(1)], 2)
|
||||
read_weights = var(read_weights)
|
||||
# condition read weights by their usages
|
||||
relevant_usages = usage.gather(1, read_positions.squeeze())
|
||||
read_weights = (read_weights.squeeze(1) * relevant_usages).unsqueeze(1)
|
||||
|
||||
(b, m, w) = memory.size()
|
||||
read_vectors = memory.gather(1, read_positions.squeeze().unsqueeze(2).expand(b, self.K+1, w))
|
||||
@ -206,7 +216,13 @@ class SparseMemory(nn.Module):
|
||||
def read(self, read_query, hidden):
|
||||
# sparse read
|
||||
read_vectors, positions, read_weights = \
|
||||
self.read_from_sparse_memory(hidden['memory'], hidden['indexes'], read_query, hidden['last_used_mem'])
|
||||
self.read_from_sparse_memory(
|
||||
hidden['memory'],
|
||||
hidden['indexes'],
|
||||
read_query,
|
||||
hidden['last_used_mem'],
|
||||
hidden['usage']
|
||||
)
|
||||
hidden['read_positions'] = positions
|
||||
hidden['read_weights'] = read_weights
|
||||
hidden['read_vectors'] = read_vectors
|
||||
|
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
import torch as T
|
||||
@ -99,7 +99,7 @@ def σ(input, axis=1):
|
||||
def register_nan_checks(model):
|
||||
def check_grad(module, grad_input, grad_output):
|
||||
# print(module) you can add this to see that the hook is called
|
||||
print('hook called for ' + str(type(module)))
|
||||
# print('hook called for ' + str(type(module)))
|
||||
if any(np.all(np.isnan(gi.data.cpu().numpy())) for gi in grad_input if gi is not None):
|
||||
print('NaN gradient in grad_input ' + type(module).__name__)
|
||||
|
||||
|
@ -1,156 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
install_openblas() {
|
||||
# Get and build OpenBlas (Torch is much better with a decent Blas)
|
||||
cd /tmp/
|
||||
rm -rf OpenBLAS
|
||||
git clone https://github.com/xianyi/OpenBLAS.git
|
||||
cd OpenBLAS
|
||||
if [ $(getconf _NPROCESSORS_ONLN) == 1 ]; then
|
||||
make NO_AFFINITY=1 USE_OPENMP=0 USE_THREAD=0
|
||||
else
|
||||
make NO_AFFINITY=1 USE_OPENMP=1
|
||||
fi
|
||||
RET=$?;
|
||||
if [ $RET -ne 0 ]; then
|
||||
echo "Error. OpenBLAS could not be compiled";
|
||||
exit $RET;
|
||||
fi
|
||||
sudo make install
|
||||
RET=$?;
|
||||
if [ $RET -ne 0 ]; then
|
||||
echo "Error. OpenBLAS could not be installed";
|
||||
exit $RET;
|
||||
fi
|
||||
}
|
||||
|
||||
|
||||
# pre-requisites
|
||||
if [[ -r /usr/bin/pacman ]]; then
|
||||
pacman -Syy
|
||||
pacman -S --noconfirm git wget python-pip
|
||||
|
||||
elif [[ -r /usr/bin/apt-get ]]; then
|
||||
apt-get update
|
||||
apt-get install -y git wget python3-examples python3-pip
|
||||
|
||||
elif [[ -r /usr/bin/yum ]]; then
|
||||
# cause install-deps supports only fedora v21 and v22
|
||||
yum install -y wget cmake curl readline-devel ncurses-devel \
|
||||
gcc-c++ gcc-gfortran git gnuplot unzip \
|
||||
nodejs npm libjpeg-turbo-devel libpng-devel \
|
||||
ImageMagick GraphicsMagick-devel fftw-devel \
|
||||
sox-devel sox qt-devel qtwebkit-devel \
|
||||
python-ipython czmq czmq-devel python3-tools findutils which
|
||||
|
||||
install_openblas
|
||||
|
||||
else
|
||||
echo "Does not support your distribution, top kek"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
wget https://raw.githubusercontent.com/torch/ezinstall/master/install-deps -O ./install-deps
|
||||
chmod +x ./install-deps
|
||||
./install-deps
|
||||
|
||||
pip3 install --upgrade pip
|
||||
|
||||
# clone repo
|
||||
git clone https://github.com/facebookresearch/faiss.git /faiss
|
||||
cd /faiss
|
||||
|
||||
git checkout e652a6648f52d20d95426d1814737e0e8601f348
|
||||
|
||||
# use example makefile
|
||||
cp ./example_makefiles/makefile.inc.Linux ./makefile.inc
|
||||
|
||||
# change stuff in example makefile, for ubuntu 16.04
|
||||
# arch
|
||||
if [[ -r /usr/bin/pacman ]]; then
|
||||
echo "Arch Linux found"
|
||||
pacman -S --noconfirm swig
|
||||
|
||||
# openblas dir
|
||||
sed -i "s/lib64\/libopenblas\.so\.0/lib\/libopenblas\.so/g" ./makefile.inc
|
||||
# nvcc compiler options
|
||||
sed -i "s/std c++11 \-lineinfo/std c++11 \-lineinfo \-Xcompiler \-D__CORRECT_ISO_CPP11_MATH_H_PROTO/g" ./makefile.inc
|
||||
# cuda installation root
|
||||
sed -i "s/CUDAROOT=\/usr\/local\/cuda-8.0\//CUDAROOT=\/usr\/local\/cuda/g" ./makefile.inc
|
||||
# python include directories
|
||||
sed -i "s/PYTHONCFLAGS=\-I\/usr\/include\/python2.7\/ \-I\/usr\/lib64\/python2.7\/site\-packages\/numpy\/core\/include\//PYTHONCFLAGS=\-I \/usr\/include\/python3\.6m \-I \/usr\/lib\/python3\.6\/site\-packages\/numpy\/core\/include/g" ./makefile.inc
|
||||
|
||||
# ubuntu
|
||||
elif [[ -r /usr/bin/apt-get ]]; then
|
||||
echo "Ubuntu found"
|
||||
apt-get -qq update
|
||||
apt-get install -y swig
|
||||
|
||||
# openblas dir
|
||||
sed -i "s/lib64\/libopenblas\.so\.0/lib\/libopenblas\.so/g" ./makefile.inc
|
||||
# nvcc compiler options
|
||||
sed -i "s/std c++11 \-lineinfo/std c++11 \-lineinfo \-Xcompiler \-D__CORRECT_ISO_CPP11_MATH_H_PROTO/g" ./makefile.inc
|
||||
# cuda installation root
|
||||
sed -i "s/CUDAROOT=\/usr\/local\/cuda-8.0\//CUDAROOT=\/usr\/local\/cuda/g" ./makefile.inc
|
||||
# python include directories
|
||||
sed -i "s/PYTHONCFLAGS=\-I\/usr\/include\/python2.7\/ \-I\/usr\/lib64\/python2.7\/site\-packages\/numpy\/core\/include\//PYTHONCFLAGS=\-I \/usr\/include\/python3\.5 \-I \/usr\/local\/lib\/python3\.5\/dist\-packages\/numpy\/core\/include/g" ./makefile.inc
|
||||
|
||||
# fedora
|
||||
elif [[ -r /usr/bin/yum ]]; then
|
||||
echo "Fedora found"
|
||||
yum install -y swig
|
||||
|
||||
# openblas dir
|
||||
sed -i "s/lib64\/libopenblas\.so\.0/lib\/libopenblas\.so/g" ./makefile.inc
|
||||
cp /tmp/OpenBLAS/libopenblas.so /usr/lib/
|
||||
# nvcc compiler options
|
||||
sed -i "s/std c++11 \-lineinfo/std c++11 \-lineinfo \-Xcompiler \-D__CORRECT_ISO_CPP11_MATH_H_PROTO/g" ./makefile.inc
|
||||
# cuda installation root
|
||||
sed -i "s/CUDAROOT=\/usr\/local\/cuda-8.0\//CUDAROOT=\/usr\/local\/cuda/g" ./makefile.inc
|
||||
# python include directories
|
||||
sed -i "s/PYTHONCFLAGS=\-I\/usr\/include\/python2.7\/ \-I\/usr\/lib64\/python2.7\/site\-packages\/numpy\/core\/include\//PYTHONCFLAGS=\-I \/usr\/include\/python3\.6m \-I \/usr\/local\/lib64\/python3\.6\/site\-packages\/numpy\/core\/include/g" ./makefile.inc
|
||||
|
||||
# unsupported distribution
|
||||
else
|
||||
echo "Does not support your distribution, top kek"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
# build
|
||||
cd /faiss
|
||||
make
|
||||
cd gpu
|
||||
make
|
||||
cd ..
|
||||
make py
|
||||
cd gpu
|
||||
make py
|
||||
cd ..
|
||||
|
||||
mkdir /tmp/faiss
|
||||
find -name "*.so" -exec cp {} /tmp/faiss \;
|
||||
find -name "*.a" -exec cp {} /tmp/faiss \;
|
||||
find -name "*.py" -exec cp {} /tmp/faiss \;
|
||||
mv /tmp/faiss .
|
||||
cd faiss
|
||||
|
||||
# convert to python3
|
||||
2to3 -w ./*.py
|
||||
rm -rf *.bak
|
||||
|
||||
# Fix relative imports
|
||||
for i in *.py; do
|
||||
filename=`echo $i | cut -d "." -f 1`
|
||||
echo $filename
|
||||
find -name "*.py" -exec sed -i "s/import $filename/import \.$filename/g" {} \;
|
||||
find -name "*.py" -exec sed -i "s/from $filename import/from \.$filename import/g" {} \;
|
||||
done
|
||||
|
||||
cd ..
|
||||
|
||||
git clone https://github.com/ixaxaar/pytorch-dnc
|
||||
rm -rf pytorch-dnc/faiss
|
||||
mv faiss pytorch-dnc
|
||||
cd pytorch-dnc
|
||||
pip3 install -e .
|
||||
|
Loading…
Reference in New Issue
Block a user