separate FLANN and FAISS indexes

This commit is contained in:
ixaxaar 2017-12-11 00:20:43 +05:30
parent 106d362e17
commit bdb763c28f
2 changed files with 62 additions and 6 deletions

View File

@ -9,26 +9,28 @@ from faiss.faiss import cast_integer_to_long_ptr as cast_long
from .util import *
class Index(object):
class FAISSIndex(object):
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
super(Index, self).__init__()
super(FAISSIndex, self).__init__()
self.cell_size = cell_size
self.nr_cells = nr_cells
self.probes = probes
self.K = K
self.num_lists = num_lists
self.gpu_id = gpu_id
self.res = res if res else faiss.StandardGpuResources()
self.res.setTempMemoryFraction(0.01)
res = res if res else faiss.StandardGpuResources()
res.setTempMemoryFraction(0.01)
if self.gpu_id != -1:
self.res.initializeForDevice(self.gpu_id)
res.initializeForDevice(self.gpu_id)
nr_samples = self.nr_cells * 100 * self.cell_size
train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10
# train = T.randn(self.nr_cells * 100, self.cell_size)
self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
self.index.setNumProbes(self.probes)
self.train(train)
@ -81,3 +83,4 @@ class Index(object):
)
T.cuda.synchronize()
return (distances, (labels-1))

53
dnc/flann_index.py Normal file
View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import numpy as np
from pyflann import *
from .util import *
class FLANNIndex(object):
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_kdtrees=32, probes=32, gpu_id=-1):
super(FLANNIndex, self).__init__()
self.cell_size = cell_size
self.nr_cells = nr_cells
self.probes = probes
self.K = K
self.num_kdtrees = num_kdtrees
self.gpu_id = gpu_id
self.index = FLANN()
def add(self, other, positions=None, last=-1):
if isinstance(other, var):
other = other[:last, :].data.cpu().numpy()
elif isinstance(other, T.Tensor):
other = other[:last, :].cpu().numpy()
self.index.build_index(other, algorithm='kdtree', trees=self.num_kdtrees, checks=self.probes)
def search(self, query, k=None):
if isinstance(query, var):
query = query.data.cpu().numpy()
elif isinstance(query, T.Tensor):
query = query.cpu().numpy()
l, d = self.index.nn_index(query, num_neighbors=self.K if k is None else k)
distances = T.from_numpy(d).float()
labels = T.from_numpy(l).long()
if self.gpu_id != -1: distances = distances.cuda(self.gpu_id)
if self.gpu_id != -1: labels = labels.cuda(self.gpu_id)
return (distances, labels)
def reset(self):
self.index.delete_index()