Make FAISS work properly, fall back to flann when not available, fixes #23

This commit is contained in:
ixaxaar 2017-12-20 02:08:34 +05:30
parent 78ac06a332
commit 2c359e9a86
5 changed files with 106 additions and 65 deletions

View File

@ -51,6 +51,12 @@ pip install -r ./requirements.txt
pip install -e .
```
For using fully GPU based SDNCs or SAMs, install FAISS:
```bash
conda install faiss-gpu -c pytorch
```
`pytest` is required to run the test
## Architecure
@ -465,6 +471,15 @@ make -j 4
sudo make install
```
FAISS can be installed using:
```bash
conda install faiss-gpu -c pytorch
```
FAISS is much faster, has a GPU implementation and is interoperable with pytorch tensors.
We try to use FAISS by default, in absence of which we fall back to FLANN.
2. An alternative to FLANN is [FAISS](https://github.com/facebookresearch/faiss), which is much faster and interoperable with torch cuda tensors (but is difficult to distribute, see [dnc/faiss_index.py](dnc/faiss_index.py)).
3. `nan`s in the gradients are common, try with different batch sizes

View File

@ -1,11 +1,11 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from faiss import faiss
import faiss
from faiss.faiss import cast_integer_to_float_ptr as cast_float
from faiss.faiss import cast_integer_to_int_ptr as cast_int
from faiss.faiss import cast_integer_to_long_ptr as cast_long
from faiss import cast_integer_to_float_ptr as cast_float
from faiss import cast_integer_to_int_ptr as cast_int
from faiss import cast_integer_to_long_ptr as cast_long
from .util import *
@ -21,16 +21,16 @@ class FAISSIndex(object):
self.num_lists = num_lists
self.gpu_id = gpu_id
res = res if res else faiss.StandardGpuResources()
res.setTempMemoryFraction(0.01)
# BEWARE: if this variable gets deallocated, FAISS crashes
self.res = res if res else faiss.StandardGpuResources()
self.res.setTempMemoryFraction(0.01)
if self.gpu_id != -1:
res.initializeForDevice(self.gpu_id)
self.res.initializeForDevice(self.gpu_id)
nr_samples = self.nr_cells * 100 * self.cell_size
train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10
# train = T.randn(self.nr_cells * 100, self.cell_size)
train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size)
self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_L2)
self.index.setNumProbes(self.probes)
self.train(train)
@ -48,7 +48,7 @@ class FAISSIndex(object):
self.index.reset()
T.cuda.synchronize()
def add(self, other, positions=None, last=-1):
def add(self, other, positions=None, last=None):
other = ensure_gpu(other, self.gpu_id)
T.cuda.synchronize()
@ -57,7 +57,7 @@ class FAISSIndex(object):
assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
else:
other = other[:last, :]
other = other[:last, :] if last is not None else other
self.index.add_c(other.size(0), cast_float(ptr(other)))
T.cuda.synchronize()

View File

@ -8,7 +8,6 @@ import torch.nn.functional as F
import numpy as np
import math
from .flann_index import FLANNIndex
from .util import *
import time
@ -44,11 +43,12 @@ class SparseMemory(nn.Module):
m = self.mem_size
w = self.cell_size
r = self.read_heads
# The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
# The visible memory size: (K * R read heads, forward and backward
# temporal reads of size KL and least used memory cell)
self.c = (r * self.K) + 1
if self.independent_linears:
self.read_query_transform = nn.Linear(self.input_size, w*r)
self.read_query_transform = nn.Linear(self.input_size, w * r)
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
@ -72,11 +72,20 @@ class SparseMemory(nn.Module):
if 'indexes' in hidden:
[x.reset() for x in hidden['indexes']]
else:
# create new indexes
hidden['indexes'] = \
[FLANNIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
# create new indexes, try to use FAISS, fall back to FLANN
try:
from .faiss_index import FAISSIndex
hidden['indexes'] = \
[FAISSIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
except Exception as e:
print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`")
from .flann_index import FLANNIndex
hidden['indexes'] = \
[FLANNIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
# add existing memory into indexes
pos = hidden['read_positions'].squeeze().data.cpu().numpy()
@ -104,7 +113,7 @@ class SparseMemory(nn.Module):
'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
}
@ -126,9 +135,10 @@ class SparseMemory(nn.Module):
hidden['read_weights'].data.fill_(δ)
hidden['write_weights'].data.fill_(δ)
hidden['read_vectors'].data.fill_(δ)
hidden['least_used_mem'].data.fill_(c+1+self.timestep)
hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
hidden['usage'].data.fill_(δ)
hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
hidden['read_positions'] = cuda(
T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
return hidden
@ -147,8 +157,9 @@ class SparseMemory(nn.Module):
hidden['indexes'][batch].reset()
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
1) if mem_limit_reached else hidden['least_used_mem'] + 1
return hidden
@ -177,7 +188,8 @@ class SparseMemory(nn.Module):
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
# write into memory
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
hidden['visible_memory'] = hidden['visible_memory'] * \
(1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
hidden = self.write_into_sparse_memory(hidden)
return hidden
@ -240,11 +252,11 @@ class SparseMemory(nn.Module):
# sparse read
read_vectors, positions, read_weights, visible_memory = \
self.read_from_sparse_memory(
hidden['memory'],
hidden['indexes'],
read_query,
hidden['least_used_mem'],
hidden['usage']
hidden['memory'],
hidden['indexes'],
read_query,
hidden['least_used_mem'],
hidden['usage']
)
hidden['read_positions'] = positions
@ -276,11 +288,11 @@ class SparseMemory(nn.Module):
else:
ξ = self.interface_weights(ξ)
# r read keys (b * r * w)
read_query = ξ[:, :r*w].contiguous().view(b, r, w)
read_query = ξ[:, :r * w].contiguous().view(b, r, w)
# write key (b * 1 * w)
write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

View File

@ -46,11 +46,12 @@ class SparseTemporalMemory(nn.Module):
m = self.mem_size
w = self.cell_size
r = self.read_heads
# The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
# The visible memory size: (K * R read heads, forward and backward
# temporal reads of size KL and least used memory cell)
self.c = (r * self.K) + (self.KL * 2) + 1
if self.independent_linears:
self.read_query_transform = nn.Linear(self.input_size, w*r)
self.read_query_transform = nn.Linear(self.input_size, w * r)
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
@ -75,10 +76,19 @@ class SparseTemporalMemory(nn.Module):
[x.reset() for x in hidden['indexes']]
else:
# create new indexes
hidden['indexes'] = \
[FLANNIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
try:
from .faiss_index import FAISSIndex
hidden['indexes'] = \
[FAISSIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
except Exception as e:
print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`")
from .flann_index import FLANNIndex
hidden['indexes'] = \
[FLANNIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
# add existing memory into indexes
pos = hidden['read_positions'].squeeze().data.cpu().numpy()
@ -103,13 +113,13 @@ class SparseTemporalMemory(nn.Module):
# warning can be a huge chunk of contiguous memory
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
'precedence': cuda(T.zeros(b, self.KL*2).fill_(δ), gpu_id=self.gpu_id),
'link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id),
'rev_link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id),
'precedence': cuda(T.zeros(b, self.KL * 2).fill_(δ), gpu_id=self.gpu_id),
'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
}
@ -137,9 +147,10 @@ class SparseTemporalMemory(nn.Module):
hidden['read_weights'].data.fill_(δ)
hidden['write_weights'].data.fill_(δ)
hidden['read_vectors'].data.fill_(δ)
hidden['least_used_mem'].data.fill_(c+1+self.timestep)
hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
hidden['usage'].data.fill_(δ)
hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
hidden['read_positions'] = cuda(
T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
return hidden
@ -158,8 +169,9 @@ class SparseTemporalMemory(nn.Module):
hidden['indexes'][batch].reset()
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
1) if mem_limit_reached else hidden['least_used_mem'] + 1
return hidden
@ -179,7 +191,8 @@ class SparseTemporalMemory(nn.Module):
link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j
rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + \
(temporal_write_weights_j * precedence_dense_i)
return link_matrix * I, rev_link_matrix * I
@ -211,22 +224,23 @@ class SparseTemporalMemory(nn.Module):
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
# write into memory
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
hidden['visible_memory'] = hidden['visible_memory'] * \
(1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
hidden = self.write_into_sparse_memory(hidden)
# update link_matrix and precedence
(b, c) = write_weights.size()
# update link matrix
temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:]
hidden['link_matrix'], hidden['rev_link_matrix'] = \
self.update_link_matrices(
self.update_link_matrices(
hidden['link_matrix'],
hidden['rev_link_matrix'],
hidden['write_weights'],
hidden['precedence'],
temporal_read_positions
)
)
# update precedence vector
read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
@ -299,20 +313,20 @@ class SparseTemporalMemory(nn.Module):
def read(self, read_query, hidden):
# get forward and backward weights
temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:]
read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights)
# sparse read
read_vectors, positions, read_weights, visible_memory = \
self.read_from_sparse_memory(
hidden['memory'],
hidden['indexes'],
read_query,
hidden['least_used_mem'],
hidden['usage'],
forward, backward,
hidden['read_positions']
hidden['memory'],
hidden['indexes'],
read_query,
hidden['least_used_mem'],
hidden['usage'],
forward, backward,
hidden['read_positions']
)
hidden['read_positions'] = positions
@ -344,11 +358,11 @@ class SparseTemporalMemory(nn.Module):
else:
ξ = self.interface_weights(ξ)
# r read keys (b * r * w)
read_query = ξ[:, :r*w].contiguous().view(b, r, w)
read_query = ξ[:, :r * w].contiguous().view(b, r, w)
# write key (b * 1 * w)
write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
# write vector (b * 1 * r)
interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

View File

@ -138,7 +138,7 @@ def ptr(tensor):
if T.is_tensor(tensor):
return tensor.storage().data_ptr()
elif hasattr(tensor, 'data'):
return tensor.data.storage().data_ptr()
return tensor.clone().data.storage().data_ptr()
else:
return tensor