# pytorch-dnc/dnc/faiss_index.py
#
# 87 lines
# 2.5 KiB
# Python
#
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import faiss
from faiss import cast_integer_to_float_ptr as cast_float
from faiss import cast_integer_to_int_ptr as cast_int
from faiss import cast_integer_to_long_ptr as cast_long
from .util import *
class FAISSIndex(object):
    """Approximate nearest-neighbour index over DNC memory cells.

    Thin wrapper around a FAISS ``GpuIndexIVFFlat`` (L2 metric) that moves
    PyTorch tensors to the GPU and hands their raw device pointers to the
    FAISS C API (``train_c`` / ``add_c`` / ``search_c``).

    NOTE(review): ids are stored shifted by +1 in ``add`` and shifted back
    by -1 in ``search`` — presumably to keep position 0 distinguishable
    from FAISS's "no result" marker; confirm against the calling code.
    """

    def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
        """Build and train the index.

        Args:
            cell_size: dimensionality of each memory cell vector.
            nr_cells: number of memory cells (capacity hint).
            K: default number of neighbours returned by ``search``.
            num_lists: number of IVF inverted lists (coarse clusters).
            probes: number of inverted lists probed per query.
            res: optional pre-built ``faiss.StandardGpuResources``.
            train: optional (n, cell_size) training tensor; defaults to
                ``nr_cells * 100`` random Gaussian samples.
            gpu_id: CUDA device id, or -1 for the default device.
        """
        super(FAISSIndex, self).__init__()
        self.cell_size = cell_size
        self.nr_cells = nr_cells
        self.probes = probes
        self.K = K
        self.num_lists = num_lists
        self.gpu_id = gpu_id

        # BEWARE: if this variable gets deallocated, FAISS crashes.
        # Keep a reference on self for the index's whole lifetime.
        self.res = res if res else faiss.StandardGpuResources()
        self.res.setTempMemoryFraction(0.01)
        if self.gpu_id != -1:
            self.res.initializeForDevice(self.gpu_id)

        # Default training set: 100 random samples per cell.
        train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size)

        self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_L2)
        self.index.setNumProbes(self.probes)
        self.train(train)

    def cuda(self, gpu_id):
        """Record the CUDA device id used for subsequent tensor transfers.

        NOTE(review): this only updates ``self.gpu_id``; it does not move
        the already-constructed FAISS index to another device.
        """
        self.gpu_id = gpu_id

    def train(self, train):
        """Train the IVF quantizer on the (n, cell_size) tensor ``train``."""
        train = ensure_gpu(train, -1)
        T.cuda.synchronize()
        # Pass the actual number of training rows; the previous hard-coded
        # ``self.nr_cells`` disagreed with the default training tensor,
        # which has ``nr_cells * 100`` rows.
        self.index.train_c(train.size(0), cast_float(ptr(train)))
        T.cuda.synchronize()

    def reset(self):
        """Remove all stored vectors (the trained quantizer is kept)."""
        T.cuda.synchronize()
        self.index.reset()
        T.cuda.synchronize()

    def add(self, other, positions=None, last=None):
        """Add the rows of ``other`` to the index.

        Args:
            other: (n, cell_size) tensor of vectors to store.
            positions: optional (n,) tensor of ids; stored shifted by +1
                (``search`` undoes the shift).
            last: when ``positions`` is None, only rows ``[:last]`` are added.
        """
        other = ensure_gpu(other, self.gpu_id)

        T.cuda.synchronize()
        if positions is not None:
            positions = ensure_gpu(positions, self.gpu_id)
            assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
            self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
        else:
            other = other[:last, :] if last is not None else other
            self.index.add_c(other.size(0), cast_float(ptr(other)))
        T.cuda.synchronize()

    def search(self, query, k=None):
        """Find the ``k`` nearest stored vectors for each row of ``query``.

        Args:
            query: (batch, cell_size) tensor of query vectors.
            k: neighbours per query; falls back to ``self.K`` when falsy.

        Returns:
            ``(distances, labels)`` — (batch, k) tensors; labels are the
            stored ids shifted back by -1 (FAISS's "no result" -1 becomes -2).
        """
        query = ensure_gpu(query, self.gpu_id)

        k = k if k else self.K
        (b, n) = query.size()
        # Guard the raw-pointer call: FAISS would silently misread memory
        # if the query width differed from the index dimensionality.
        assert n == self.cell_size, "Query dimensionality mismatch with index"

        distances = T.FloatTensor(b, k)
        labels = T.LongTensor(b, k)

        if self.gpu_id != -1:
            distances = distances.cuda(self.gpu_id)
        if self.gpu_id != -1:
            labels = labels.cuda(self.gpu_id)

        T.cuda.synchronize()
        self.index.search_c(
            b,
            cast_float(ptr(query)),
            k,
            cast_float(ptr(distances)),
            cast_long(ptr(labels))
        )
        T.cuda.synchronize()
        return (distances, (labels - 1))