separate FLANN and FAISS indexes
This commit is contained in:
parent
106d362e17
commit
bdb763c28f
@ -9,26 +9,28 @@ from faiss.faiss import cast_integer_to_long_ptr as cast_long
|
||||
|
||||
from .util import *
|
||||
|
||||
class Index(object):
|
||||
|
||||
class FAISSIndex(object):
|
||||
|
||||
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
|
||||
super(Index, self).__init__()
|
||||
super(FAISSIndex, self).__init__()
|
||||
self.cell_size = cell_size
|
||||
self.nr_cells = nr_cells
|
||||
self.probes = probes
|
||||
self.K = K
|
||||
self.num_lists = num_lists
|
||||
self.gpu_id = gpu_id
|
||||
self.res = res if res else faiss.StandardGpuResources()
|
||||
self.res.setTempMemoryFraction(0.01)
|
||||
|
||||
res = res if res else faiss.StandardGpuResources()
|
||||
res.setTempMemoryFraction(0.01)
|
||||
if self.gpu_id != -1:
|
||||
self.res.initializeForDevice(self.gpu_id)
|
||||
res.initializeForDevice(self.gpu_id)
|
||||
|
||||
nr_samples = self.nr_cells * 100 * self.cell_size
|
||||
train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10
|
||||
# train = T.randn(self.nr_cells * 100, self.cell_size)
|
||||
|
||||
self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
|
||||
self.index = faiss.GpuIndexIVFFlat(res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
|
||||
self.index.setNumProbes(self.probes)
|
||||
self.train(train)
|
||||
|
||||
@ -81,3 +83,4 @@ class Index(object):
|
||||
)
|
||||
T.cuda.synchronize()
|
||||
return (distances, (labels-1))
|
||||
|
53
dnc/flann_index.py
Normal file
53
dnc/flann_index.py
Normal file
@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import torch.nn as nn
|
||||
import torch as T
|
||||
from torch.autograd import Variable as var
|
||||
import numpy as np
|
||||
|
||||
from pyflann import *
|
||||
|
||||
from .util import *
|
||||
|
||||
class FLANNIndex(object):
|
||||
|
||||
def __init__(self, cell_size=20, nr_cells=1024, K=4, num_kdtrees=32, probes=32, gpu_id=-1):
|
||||
super(FLANNIndex, self).__init__()
|
||||
self.cell_size = cell_size
|
||||
self.nr_cells = nr_cells
|
||||
self.probes = probes
|
||||
self.K = K
|
||||
self.num_kdtrees = num_kdtrees
|
||||
self.gpu_id = gpu_id
|
||||
|
||||
self.index = FLANN()
|
||||
|
||||
def add(self, other, positions=None, last=-1):
|
||||
if isinstance(other, var):
|
||||
other = other[:last, :].data.cpu().numpy()
|
||||
elif isinstance(other, T.Tensor):
|
||||
other = other[:last, :].cpu().numpy()
|
||||
|
||||
self.index.build_index(other, algorithm='kdtree', trees=self.num_kdtrees, checks=self.probes)
|
||||
|
||||
def search(self, query, k=None):
|
||||
if isinstance(query, var):
|
||||
query = query.data.cpu().numpy()
|
||||
elif isinstance(query, T.Tensor):
|
||||
query = query.cpu().numpy()
|
||||
|
||||
l, d = self.index.nn_index(query, num_neighbors=self.K if k is None else k)
|
||||
|
||||
distances = T.from_numpy(d).float()
|
||||
labels = T.from_numpy(l).long()
|
||||
|
||||
if self.gpu_id != -1: distances = distances.cuda(self.gpu_id)
|
||||
if self.gpu_id != -1: labels = labels.cuda(self.gpu_id)
|
||||
|
||||
return (distances, labels)
|
||||
|
||||
|
||||
def reset(self):
|
||||
self.index.delete_index()
|
||||
|
Loading…
Reference in New Issue
Block a user