Make FAISS work properly, fall back to flann when not available, fixes #23
parent 78ac06a332
commit 2c359e9a86

README.md
@@ -51,6 +51,12 @@ pip install -r ./requirements.txt
 pip install -e .
 ```
 
+For using fully GPU based SDNCs or SAMs, install FAISS:
+
+```bash
+conda install faiss-gpu -c pytorch
+```
+
 `pytest` is required to run the test
 
 ## Architecure
@@ -465,6 +471,15 @@ make -j 4
 sudo make install
 ```
 
+FAISS can be installed using:
+
+```bash
+conda install faiss-gpu -c pytorch
+```
+
+FAISS is much faster, has a GPU implementation and is interoperable with pytorch tensors.
+We try to use FAISS by default, in absence of which we fall back to FLANN.
+
 2. An alternative to FLANN is [FAISS](https://github.com/facebookresearch/faiss), which is much faster and interoperable with torch cuda tensors (but is difficult to distribute, see [dnc/faiss_index.py](dnc/faiss_index.py)).
 3. `nan`s in the gradients are common, try with different batch sizes
 
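To sanity-check the conda install before running anything GPU-heavy, something like the following works (a minimal sketch, not part of this commit; the sizes are arbitrary and `IndexFlatL2` is the CPU exact index):

```python
# Minimal FAISS smoke test (illustrative, not from this commit).
import numpy as np
import faiss

d = 64                                    # vector dimensionality
xb = np.random.random((1000, d)).astype('float32')
index = faiss.IndexFlatL2(d)              # exact L2 search on CPU
index.add(xb)                             # index all 1000 vectors
distances, ids = index.search(xb[:5], 4)  # 4 nearest neighbours of 5 queries
print(ids.shape)                          # (5, 4)
```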
dnc/faiss_index.py

@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-from faiss import faiss
+import faiss
 
-from faiss.faiss import cast_integer_to_float_ptr as cast_float
-from faiss.faiss import cast_integer_to_int_ptr as cast_int
-from faiss.faiss import cast_integer_to_long_ptr as cast_long
+from faiss import cast_integer_to_float_ptr as cast_float
+from faiss import cast_integer_to_int_ptr as cast_int
+from faiss import cast_integer_to_long_ptr as cast_long
 
 from .util import *
 
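The import change switches to the flat `faiss` namespace used by the conda `faiss-gpu` packages, replacing the older nested `faiss.faiss` module path. The `cast_integer_to_*_ptr` helpers turn a raw pointer (a Python integer) into the typed pointers FAISS's C++ API accepts. A sketch of the zero-copy pattern these imports enable (assumes a CUDA build of both torch and faiss; the sizes are illustrative):

```python
# Zero-copy add into a GPU index via raw pointers (illustrative sketch).
import torch as T
import faiss
from faiss import cast_integer_to_float_ptr as cast_float

d = 64
x = T.randn(128, d).cuda()             # float32 vectors on the GPU
res = faiss.StandardGpuResources()     # must outlive the index
index = faiss.GpuIndexFlatL2(res, d)   # exact L2 index on the GPU
T.cuda.synchronize()                   # make sure x is fully materialized
# add_c is the raw SWIG method: it takes a count and a float* instead of numpy
index.add_c(x.size(0), cast_float(x.storage().data_ptr()))
print(index.ntotal)                    # 128
```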
@@ -21,16 +21,16 @@ class FAISSIndex(object):
     self.num_lists = num_lists
     self.gpu_id = gpu_id
 
-    res = res if res else faiss.StandardGpuResources()
-    res.setTempMemoryFraction(0.01)
+    # BEWARE: if this variable gets deallocated, FAISS crashes
+    self.res = res if res else faiss.StandardGpuResources()
+    self.res.setTempMemoryFraction(0.01)
     if self.gpu_id != -1:
-      res.initializeForDevice(self.gpu_id)
+      self.res.initializeForDevice(self.gpu_id)
 
     nr_samples = self.nr_cells * 100 * self.cell_size
-    train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10
+    # train = T.randn(self.nr_cells * 100, self.cell_size)
+    train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size)
 
-    self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
+    self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_L2)
     self.index.setNumProbes(self.probes)
     self.train(train)
 
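The substantive fix in this hunk is a lifetime bug: `StandardGpuResources` was held in a local variable, so Python could garbage-collect it while the `GpuIndexIVFFlat` built on it was still alive, hence the BEWARE comment. Keeping it on `self` pins it for the life of the index (the hunk also switches the metric from inner product to L2). The safe shape of the pattern, reduced to a sketch (hypothetical holder class):

```python
import faiss

class GpuIndexHolder(object):
  def __init__(self, dims, num_lists):
    # The index only borrows these resources; Python must keep a
    # reference or FAISS will use freed memory and crash.
    self.res = faiss.StandardGpuResources()
    self.index = faiss.GpuIndexIVFFlat(self.res, dims, num_lists,
                                       faiss.METRIC_L2)
```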
@@ -48,7 +48,7 @@ class FAISSIndex(object):
     self.index.reset()
     T.cuda.synchronize()
 
-  def add(self, other, positions=None, last=-1):
+  def add(self, other, positions=None, last=None):
     other = ensure_gpu(other, self.gpu_id)
 
     T.cuda.synchronize()
@@ -57,7 +57,7 @@ class FAISSIndex(object):
       assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
       self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
     else:
-      other = other[:last, :]
+      other = other[:last, :] if last is not None else other
       self.index.add_c(other.size(0), cast_float(ptr(other)))
     T.cuda.synchronize()
 
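This is the other real bug behind "make FAISS work properly": with the old default `last=-1`, the `other[:last, :]` slice silently dropped the final vector on every unqualified `add`. With `last=None` the guard skips slicing entirely. The off-by-one, in plain torch:

```python
import torch as T

mem = T.arange(12.).view(4, 3)   # 4 memory cells of width 3
print(mem[:-1, :].size(0))       # 3 -- the old default last=-1 lost a row
last = None
other = mem[:last, :] if last is not None else mem
print(other.size(0))             # 4 -- the new default keeps every row
```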
dnc/sparse_memory.py

@@ -8,7 +8,6 @@ import torch.nn.functional as F
 import numpy as np
 import math
 
-from .flann_index import FLANNIndex
 from .util import *
 import time
 
@@ -44,11 +43,12 @@ class SparseMemory(nn.Module):
     m = self.mem_size
     w = self.cell_size
     r = self.read_heads
-    # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
+    # The visible memory size: (K * R read heads, forward and backward
+    # temporal reads of size KL and least used memory cell)
     self.c = (r * self.K) + 1
 
     if self.independent_linears:
-      self.read_query_transform = nn.Linear(self.input_size, w*r)
+      self.read_query_transform = nn.Linear(self.input_size, w * r)
       self.write_vector_transform = nn.Linear(self.input_size, w)
       self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
       self.write_gate_transform = nn.Linear(self.input_size, 1)
@@ -72,11 +72,20 @@ class SparseMemory(nn.Module):
     if 'indexes' in hidden:
       [x.reset() for x in hidden['indexes']]
     else:
-      # create new indexes
-      hidden['indexes'] = \
-          [FLANNIndex(cell_size=self.cell_size,
-                      nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
-                      probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      # create new indexes, try to use FAISS, fall back to FLANN
+      try:
+        from .faiss_index import FAISSIndex
+        hidden['indexes'] = \
+            [FAISSIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      except Exception as e:
+        print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`")
+        from .flann_index import FLANNIndex
+        hidden['indexes'] = \
+            [FLANNIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
 
     # add existing memory into indexes
     pos = hidden['read_positions'].squeeze().data.cpu().numpy()
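Deferring `from .faiss_index import FAISSIndex` to inside the `try` is what makes the dependency optional: nothing touches the `faiss` package until an index is actually built. The same pattern, factored into a standalone helper (hypothetical function, assuming the same package layout as the commit):

```python
def make_index(cell_size, nr_cells, K, num_lists, probes, gpu_id):
  # Hypothetical helper mirroring the try/except above.
  try:
    from .faiss_index import FAISSIndex      # needs faiss-gpu installed
    return FAISSIndex(cell_size=cell_size, nr_cells=nr_cells, K=K,
                      num_lists=num_lists, probes=probes, gpu_id=gpu_id)
  except Exception:
    # faiss missing or unusable: fall back to the CPU FLANN index.
    # Note the renamed kwarg: FLANNIndex takes num_kdtrees, not num_lists.
    from .flann_index import FLANNIndex
    return FLANNIndex(cell_size=cell_size, nr_cells=nr_cells, K=K,
                      num_kdtrees=num_lists, probes=probes, gpu_id=gpu_id)
```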
@@ -104,7 +113,7 @@ class SparseMemory(nn.Module):
         'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
-        'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
+        'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
         'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
@@ -126,9 +135,10 @@ class SparseMemory(nn.Module):
     hidden['read_weights'].data.fill_(δ)
     hidden['write_weights'].data.fill_(δ)
     hidden['read_vectors'].data.fill_(δ)
-    hidden['least_used_mem'].data.fill_(c+1+self.timestep)
+    hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
     hidden['usage'].data.fill_(δ)
-    hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+    hidden['read_positions'] = cuda(
+        T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
 
     return hidden
@@ -147,8 +157,9 @@ class SparseMemory(nn.Module):
       hidden['indexes'][batch].reset()
       hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
 
-    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1
-    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
+    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
+                                1) if mem_limit_reached else hidden['least_used_mem'] + 1
 
     return hidden
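The reformatted `least_used_mem` update is a circular pointer: it advances by one per write and, once it reaches `mem_size - 1`, wraps back to `c + 1`, i.e. just past the visible cells. A worked example with illustrative sizes:

```python
import torch as T

c, mem_size = 33, 100                                      # illustrative sizes
least_used_mem = T.zeros(1, 1).fill_(mem_size - 1).long()  # pointer at the end
mem_limit_reached = least_used_mem[0].numpy()[0] >= mem_size - 1
least_used_mem = (least_used_mem * 0 + c + 1) if mem_limit_reached else least_used_mem + 1
print(least_used_mem.item())                               # 34 == c + 1: wrapped around
```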
@@ -177,7 +188,8 @@ class SparseMemory(nn.Module):
     erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
 
     # write into memory
-    hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
+    hidden['visible_memory'] = hidden['visible_memory'] * \
+        (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)
 
     return hidden
@@ -240,11 +252,11 @@ class SparseMemory(nn.Module):
     # sparse read
     read_vectors, positions, read_weights, visible_memory = \
         self.read_from_sparse_memory(
-        hidden['memory'],
-        hidden['indexes'],
-        read_query,
-        hidden['least_used_mem'],
-        hidden['usage']
+            hidden['memory'],
+            hidden['indexes'],
+            read_query,
+            hidden['least_used_mem'],
+            hidden['usage']
         )
 
     hidden['read_positions'] = positions
@@ -276,11 +288,11 @@ class SparseMemory(nn.Module):
     else:
       ξ = self.interface_weights(ξ)
       # r read keys (b * r * w)
-      read_query = ξ[:, :r*w].contiguous().view(b, r, w)
+      read_query = ξ[:, :r * w].contiguous().view(b, r, w)
       # write key (b * 1 * w)
-      write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
+      write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
       # write vector (b * 1 * r)
-      interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
+      interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
       # write gate (b * 1)
       write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
 
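For reference, ξ here is a flat interface vector of width `r*w + w + c + 1` that the slices above carve into read queries, a write vector, an interpolation gate and a write gate. The bookkeeping with small illustrative sizes (r = 4 heads, w = 16, K = 8, so c = r*K + 1 = 33 in the `SparseMemory` case):

```python
import torch as T

b, r, w, K = 2, 4, 16, 8                 # illustrative sizes
c = r * K + 1                            # visible cells (SparseMemory case)
xi = T.randn(b, r * w + w + c + 1)       # flat interface vector, width 114

read_query = xi[:, :r * w].contiguous().view(b, r, w)
write_vector = xi[:, r * w: r * w + w].contiguous().view(b, 1, w)
interpolation_gate = T.sigmoid(xi[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
write_gate = T.sigmoid(xi[:, -1].contiguous()).unsqueeze(1).view(b, 1)

print(read_query.shape)                  # torch.Size([2, 4, 16])
print(write_vector.shape)                # torch.Size([2, 1, 16])
print(interpolation_gate.shape)          # torch.Size([2, 33])
print(write_gate.shape)                  # torch.Size([2, 1])
```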
dnc/sparse_temporal_memory.py

@@ -46,11 +46,12 @@ class SparseTemporalMemory(nn.Module):
     m = self.mem_size
     w = self.cell_size
     r = self.read_heads
-    # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
+    # The visible memory size: (K * R read heads, forward and backward
+    # temporal reads of size KL and least used memory cell)
     self.c = (r * self.K) + (self.KL * 2) + 1
 
     if self.independent_linears:
-      self.read_query_transform = nn.Linear(self.input_size, w*r)
+      self.read_query_transform = nn.Linear(self.input_size, w * r)
       self.write_vector_transform = nn.Linear(self.input_size, w)
       self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
       self.write_gate_transform = nn.Linear(self.input_size, 1)
@@ -75,10 +76,19 @@ class SparseTemporalMemory(nn.Module):
       [x.reset() for x in hidden['indexes']]
     else:
       # create new indexes
-      hidden['indexes'] = \
-          [FLANNIndex(cell_size=self.cell_size,
-                      nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
-                      probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      try:
+        from .faiss_index import FAISSIndex
+        hidden['indexes'] = \
+            [FAISSIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
+      except Exception as e:
+        print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`")
+        from .flann_index import FLANNIndex
+        hidden['indexes'] = \
+            [FLANNIndex(cell_size=self.cell_size,
+                        nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
+                        probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
 
     # add existing memory into indexes
     pos = hidden['read_positions'].squeeze().data.cpu().numpy()
@@ -103,13 +113,13 @@ class SparseTemporalMemory(nn.Module):
         # warning can be a huge chunk of contiguous memory
         'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
         'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
-        'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
-        'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
-        'precedence': cuda(T.zeros(b, self.KL*2).fill_(δ), gpu_id=self.gpu_id),
+        'link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id),
+        'rev_link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id),
+        'precedence': cuda(T.zeros(b, self.KL * 2).fill_(δ), gpu_id=self.gpu_id),
         'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
-        'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
+        'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
         'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
@@ -137,9 +147,10 @@ class SparseTemporalMemory(nn.Module):
     hidden['read_weights'].data.fill_(δ)
     hidden['write_weights'].data.fill_(δ)
     hidden['read_vectors'].data.fill_(δ)
-    hidden['least_used_mem'].data.fill_(c+1+self.timestep)
+    hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
     hidden['usage'].data.fill_(δ)
-    hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+    hidden['read_positions'] = cuda(
+        T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
 
     return hidden
@@ -158,8 +169,9 @@ class SparseTemporalMemory(nn.Module):
       hidden['indexes'][batch].reset()
      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
 
-    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1
-    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
+    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
+                                1) if mem_limit_reached else hidden['least_used_mem'] + 1
 
     return hidden
@@ -179,7 +191,8 @@ class SparseTemporalMemory(nn.Module):
 
     link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j
 
-    rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
+    rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + \
+        (temporal_write_weights_j * precedence_dense_i)
 
     return link_matrix * I, rev_link_matrix * I
@@ -211,22 +224,23 @@ class SparseTemporalMemory(nn.Module):
     erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
 
     # write into memory
-    hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
+    hidden['visible_memory'] = hidden['visible_memory'] * \
+        (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)
 
     # update link_matrix and precedence
     (b, c) = write_weights.size()
 
     # update link matrix
-    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:]
     hidden['link_matrix'], hidden['rev_link_matrix'] = \
-      self.update_link_matrices(
+        self.update_link_matrices(
         hidden['link_matrix'],
         hidden['rev_link_matrix'],
         hidden['write_weights'],
         hidden['precedence'],
         temporal_read_positions
-      )
+        )
 
     # update precedence vector
     read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
@@ -299,20 +313,20 @@ class SparseTemporalMemory(nn.Module):
 
   def read(self, read_query, hidden):
     # get forward and backward weights
-    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:]
     read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
     forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights)
 
     # sparse read
     read_vectors, positions, read_weights, visible_memory = \
         self.read_from_sparse_memory(
-        hidden['memory'],
-        hidden['indexes'],
-        read_query,
-        hidden['least_used_mem'],
-        hidden['usage'],
-        forward, backward,
-        hidden['read_positions']
+            hidden['memory'],
+            hidden['indexes'],
+            read_query,
+            hidden['least_used_mem'],
+            hidden['usage'],
+            forward, backward,
+            hidden['read_positions']
         )
 
     hidden['read_positions'] = positions
@@ -344,11 +358,11 @@ class SparseTemporalMemory(nn.Module):
     else:
       ξ = self.interface_weights(ξ)
       # r read keys (b * r * w)
-      read_query = ξ[:, :r*w].contiguous().view(b, r, w)
+      read_query = ξ[:, :r * w].contiguous().view(b, r, w)
       # write key (b * 1 * w)
-      write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
+      write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
       # write vector (b * 1 * r)
-      interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
+      interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
       # write gate (b * 1)
      write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
 
dnc/util.py

@@ -138,7 +138,7 @@ def ptr(tensor):
   if T.is_tensor(tensor):
     return tensor.storage().data_ptr()
   elif hasattr(tensor, 'data'):
-    return tensor.data.storage().data_ptr()
+    return tensor.clone().data.storage().data_ptr()
   else:
     return tensor
 
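`ptr` is what feeds the `cast_*` helpers in `dnc/faiss_index.py`: a tensor resolves to its storage's raw address, an autograd `Variable` (the pre-0.4 torch wrapper) is now cloned before its `.data` storage pointer is taken, and anything else passes through unchanged. A self-contained demo with a plain tensor:

```python
import torch as T

def ptr(tensor):
  # Copy of the updated helper, for illustration only.
  if T.is_tensor(tensor):
    return tensor.storage().data_ptr()
  elif hasattr(tensor, 'data'):
    return tensor.clone().data.storage().data_ptr()
  else:
    return tensor

x = T.zeros(4, 4)
print(hex(ptr(x)))   # raw address of x's storage, as handed to cast_float
```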