From 2c359e9a86df4eb2e68ed474c1d8639941a05c1d Mon Sep 17 00:00:00 2001 From: ixaxaar Date: Wed, 20 Dec 2017 02:08:34 +0530 Subject: [PATCH] Make FAISS work properly, fall back to flann when not available, fixes #23 --- README.md | 15 +++++++ dnc/faiss_index.py | 24 ++++++------ dnc/sparse_memory.py | 56 +++++++++++++++----------- dnc/sparse_temporal_memory.py | 74 +++++++++++++++++++++-------------- dnc/util.py | 2 +- 5 files changed, 106 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 42884de..a6bd058 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,12 @@ pip install -r ./requirements.txt pip install -e . ``` +For using fully GPU based SDNCs or SAMs, install FAISS: + +```bash +conda install faiss-gpu -c pytorch +``` + `pytest` is required to run the test ## Architecure @@ -465,6 +471,15 @@ make -j 4 sudo make install ``` +FAISS can be installed using: + +```bash +conda install faiss-gpu -c pytorch +``` + +FAISS is much faster, has a GPU implementation and is interoperable with pytorch tensors. +We try to use FAISS by default, in absence of which we fall back to FLANN. + 2. An alternative to FLANN is [FAISS](https://github.com/facebookresearch/faiss), which is much faster and interoperable with torch cuda tensors (but is difficult to distribute, see [dnc/faiss_index.py](dnc/faiss_index.py)). 3. `nan`s in the gradients are common, try with different batch sizes diff --git a/dnc/faiss_index.py b/dnc/faiss_index.py index 00e5b38..a7dc516 100644 --- a/dnc/faiss_index.py +++ b/dnc/faiss_index.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from faiss import faiss +import faiss -from faiss.faiss import cast_integer_to_float_ptr as cast_float -from faiss.faiss import cast_integer_to_int_ptr as cast_int -from faiss.faiss import cast_integer_to_long_ptr as cast_long +from faiss import cast_integer_to_float_ptr as cast_float +from faiss import cast_integer_to_int_ptr as cast_int +from faiss import cast_integer_to_long_ptr as cast_long from .util import * @@ -21,16 +21,16 @@ class FAISSIndex(object): self.num_lists = num_lists self.gpu_id = gpu_id - res = res if res else faiss.StandardGpuResources() - res.setTempMemoryFraction(0.01) + # BEWARE: if this variable gets deallocated, FAISS crashes + self.res = res if res else faiss.StandardGpuResources() + self.res.setTempMemoryFraction(0.01) if self.gpu_id != -1: - res.initializeForDevice(self.gpu_id) + self.res.initializeForDevice(self.gpu_id) nr_samples = self.nr_cells * 100 * self.cell_size - train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10 - # train = T.randn(self.nr_cells * 100, self.cell_size) + train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) - self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT) + self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_L2) self.index.setNumProbes(self.probes) self.train(train) @@ -48,7 +48,7 @@ class FAISSIndex(object): self.index.reset() T.cuda.synchronize() - def add(self, other, positions=None, last=-1): + def add(self, other, positions=None, last=None): other = ensure_gpu(other, self.gpu_id) T.cuda.synchronize() @@ -57,7 +57,7 @@ class FAISSIndex(object): assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors" self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1))) else: - other = other[:last, :] + other = other[:last, :] if last is not None else other self.index.add_c(other.size(0), cast_float(ptr(other))) T.cuda.synchronize() diff --git a/dnc/sparse_memory.py b/dnc/sparse_memory.py index f8d5155..d06e975 100644 --- a/dnc/sparse_memory.py +++ b/dnc/sparse_memory.py @@ -8,7 +8,6 @@ import torch.nn.functional as F import numpy as np import math -from .flann_index import FLANNIndex from .util import * import time @@ -44,11 +43,12 @@ class SparseMemory(nn.Module): m = self.mem_size w = self.cell_size r = self.read_heads - # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell) + # The visible memory size: (K * R read heads, forward and backward + # temporal reads of size KL and least used memory cell) self.c = (r * self.K) + 1 if self.independent_linears: - self.read_query_transform = nn.Linear(self.input_size, w*r) + self.read_query_transform = nn.Linear(self.input_size, w * r) self.write_vector_transform = nn.Linear(self.input_size, w) self.interpolation_gate_transform = nn.Linear(self.input_size, self.c) self.write_gate_transform = nn.Linear(self.input_size, 1) @@ -72,11 +72,20 @@ class SparseMemory(nn.Module): if 'indexes' in hidden: [x.reset() for x in hidden['indexes']] else: - # create new indexes - hidden['indexes'] = \ - [FLANNIndex(cell_size=self.cell_size, - nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists, - probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)] + # create new indexes, try to use FAISS, fall back to FLANN + try: + from .faiss_index import FAISSIndex + hidden['indexes'] = \ + [FAISSIndex(cell_size=self.cell_size, + nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists, + probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)] + except Exception as e: + print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`") + from .flann_index import FLANNIndex + hidden['indexes'] = \ + [FLANNIndex(cell_size=self.cell_size, + nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists, + probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)] # add existing memory into indexes pos = hidden['read_positions'].squeeze().data.cpu().numpy() @@ -104,7 +113,7 @@ class SparseMemory(nn.Module): 'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id), 'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id), 'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id), - 'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(), + 'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(), 'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id), 'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long() } @@ -126,9 +135,10 @@ class SparseMemory(nn.Module): hidden['read_weights'].data.fill_(δ) hidden['write_weights'].data.fill_(δ) hidden['read_vectors'].data.fill_(δ) - hidden['least_used_mem'].data.fill_(c+1+self.timestep) + hidden['least_used_mem'].data.fill_(c + 1 + self.timestep) hidden['usage'].data.fill_(δ) - hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long() + hidden['read_positions'] = cuda( + T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long() return hidden @@ -147,8 +157,9 @@ class SparseMemory(nn.Module): hidden['indexes'][batch].reset() hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1]) - mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1 - hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1 + mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1 + hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + + 1) if mem_limit_reached else hidden['least_used_mem'] + 1 return hidden @@ -177,7 +188,8 @@ class SparseMemory(nn.Module): erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size()) # write into memory - hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector) + hidden['visible_memory'] = hidden['visible_memory'] * \ + (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector) hidden = self.write_into_sparse_memory(hidden) return hidden @@ -240,11 +252,11 @@ class SparseMemory(nn.Module): # sparse read read_vectors, positions, read_weights, visible_memory = \ self.read_from_sparse_memory( - hidden['memory'], - hidden['indexes'], - read_query, - hidden['least_used_mem'], - hidden['usage'] + hidden['memory'], + hidden['indexes'], + read_query, + hidden['least_used_mem'], + hidden['usage'] ) hidden['read_positions'] = positions @@ -276,11 +288,11 @@ class SparseMemory(nn.Module): else: ξ = self.interface_weights(ξ) # r read keys (b * r * w) - read_query = ξ[:, :r*w].contiguous().view(b, r, w) + read_query = ξ[:, :r * w].contiguous().view(b, r, w) # write key (b * 1 * w) - write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w) + write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w) # write vector (b * 1 * r) - interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c) + interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c) # write gate (b * 1) write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1) diff --git a/dnc/sparse_temporal_memory.py b/dnc/sparse_temporal_memory.py index 1154a40..2ddac8d 100644 --- a/dnc/sparse_temporal_memory.py +++ b/dnc/sparse_temporal_memory.py @@ -46,11 +46,12 @@ class SparseTemporalMemory(nn.Module): m = self.mem_size w = self.cell_size r = self.read_heads - # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell) + # The visible memory size: (K * R read heads, forward and backward + # temporal reads of size KL and least used memory cell) self.c = (r * self.K) + (self.KL * 2) + 1 if self.independent_linears: - self.read_query_transform = nn.Linear(self.input_size, w*r) + self.read_query_transform = nn.Linear(self.input_size, w * r) self.write_vector_transform = nn.Linear(self.input_size, w) self.interpolation_gate_transform = nn.Linear(self.input_size, self.c) self.write_gate_transform = nn.Linear(self.input_size, 1) @@ -75,10 +76,19 @@ class SparseTemporalMemory(nn.Module): [x.reset() for x in hidden['indexes']] else: # create new indexes - hidden['indexes'] = \ - [FLANNIndex(cell_size=self.cell_size, - nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists, - probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)] + try: + from .faiss_index import FAISSIndex + hidden['indexes'] = \ + [FAISSIndex(cell_size=self.cell_size, + nr_cells=self.mem_size, K=self.K, num_lists=self.num_lists, + probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)] + except Exception as e: + print("\nFalling back to FLANN (CPU). \nFor using faster, GPU based indexes, install FAISS: `conda install faiss-gpu -c pytorch`") + from .flann_index import FLANNIndex + hidden['indexes'] = \ + [FLANNIndex(cell_size=self.cell_size, + nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists, + probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)] # add existing memory into indexes pos = hidden['read_positions'].squeeze().data.cpu().numpy() @@ -103,13 +113,13 @@ class SparseTemporalMemory(nn.Module): # warning can be a huge chunk of contiguous memory 'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id), 'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id), - 'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id), - 'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id), - 'precedence': cuda(T.zeros(b, self.KL*2).fill_(δ), gpu_id=self.gpu_id), + 'link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id), + 'rev_link_matrix': cuda(T.zeros(b, m, self.KL * 2), gpu_id=self.gpu_id), + 'precedence': cuda(T.zeros(b, self.KL * 2).fill_(δ), gpu_id=self.gpu_id), 'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id), 'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id), 'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id), - 'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(), + 'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(), 'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id), 'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long() } @@ -137,9 +147,10 @@ class SparseTemporalMemory(nn.Module): hidden['read_weights'].data.fill_(δ) hidden['write_weights'].data.fill_(δ) hidden['read_vectors'].data.fill_(δ) - hidden['least_used_mem'].data.fill_(c+1+self.timestep) + hidden['least_used_mem'].data.fill_(c + 1 + self.timestep) hidden['usage'].data.fill_(δ) - hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long() + hidden['read_positions'] = cuda( + T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long() return hidden @@ -158,8 +169,9 @@ class SparseTemporalMemory(nn.Module): hidden['indexes'][batch].reset() hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1]) - mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size-1 - hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1 + mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1 + hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + + 1) if mem_limit_reached else hidden['least_used_mem'] + 1 return hidden @@ -179,7 +191,8 @@ class SparseTemporalMemory(nn.Module): link_matrix = (1 - write_weights_i) * link_matrix + write_weights_i * precedence_j - rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i) + rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + \ + (temporal_write_weights_j * precedence_dense_i) return link_matrix * I, rev_link_matrix * I @@ -211,22 +224,23 @@ class SparseTemporalMemory(nn.Module): erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size()) # write into memory - hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector) + hidden['visible_memory'] = hidden['visible_memory'] * \ + (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector) hidden = self.write_into_sparse_memory(hidden) # update link_matrix and precedence (b, c) = write_weights.size() # update link matrix - temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:] + temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:] hidden['link_matrix'], hidden['rev_link_matrix'] = \ - self.update_link_matrices( + self.update_link_matrices( hidden['link_matrix'], hidden['rev_link_matrix'], hidden['write_weights'], hidden['precedence'], temporal_read_positions - ) + ) # update precedence vector read_weights = hidden['read_weights'].gather(1, temporal_read_positions) @@ -299,20 +313,20 @@ class SparseTemporalMemory(nn.Module): def read(self, read_query, hidden): # get forward and backward weights - temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:] + temporal_read_positions = hidden['read_positions'][:, self.read_heads * self.K + 1:] read_weights = hidden['read_weights'].gather(1, temporal_read_positions) forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights) # sparse read read_vectors, positions, read_weights, visible_memory = \ self.read_from_sparse_memory( - hidden['memory'], - hidden['indexes'], - read_query, - hidden['least_used_mem'], - hidden['usage'], - forward, backward, - hidden['read_positions'] + hidden['memory'], + hidden['indexes'], + read_query, + hidden['least_used_mem'], + hidden['usage'], + forward, backward, + hidden['read_positions'] ) hidden['read_positions'] = positions @@ -344,11 +358,11 @@ class SparseTemporalMemory(nn.Module): else: ξ = self.interface_weights(ξ) # r read keys (b * r * w) - read_query = ξ[:, :r*w].contiguous().view(b, r, w) + read_query = ξ[:, :r * w].contiguous().view(b, r, w) # write key (b * 1 * w) - write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w) + write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w) # write vector (b * 1 * r) - interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c) + interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c) # write gate (b * 1) write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1) diff --git a/dnc/util.py b/dnc/util.py index feaeead..5602ceb 100644 --- a/dnc/util.py +++ b/dnc/util.py @@ -138,7 +138,7 @@ def ptr(tensor): if T.is_tensor(tensor): return tensor.storage().data_ptr() elif hasattr(tensor, 'data'): - return tensor.data.storage().data_ptr() + return tensor.clone().data.storage().data_ptr() else: return tensor