#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch.nn as nn
|
|
import torch as T
|
|
from torch.autograd import Variable as var
|
|
import numpy as np
|
|
|
|
from pyflann import *
|
|
|
|
from .util import *
|
|
|
|
class FLANNIndex(object):
    """Nearest-neighbour index backed by a ``FLANN`` object, with torch I/O.

    Inputs may be torch Variables/Tensors (converted to NumPy on the CPU) or
    anything else the underlying index accepts; search results are returned
    as torch tensors, optionally moved to a CUDA device.

    NOTE(review): ``FLANN`` is expected to come from the file-level
    ``from pyflann import *`` -- confirm.
    """

    def __init__(self, cell_size=20, nr_cells=1024, K=4, num_kdtrees=32, probes=32, gpu_id=-1):
        """Store the configuration and create an empty ``FLANN`` index.

        Args:
            cell_size: Stored on the instance; not used by any method here.
            nr_cells: Stored on the instance; not used by any method here.
            K: Default number of neighbours requested by ``search``.
            num_kdtrees: Passed as ``trees`` to ``build_index`` in ``add``.
            probes: Passed as ``checks`` to ``build_index`` in ``add``.
            gpu_id: CUDA device for search results; ``-1`` leaves them on
                the CPU.
        """
        super().__init__()
        self.cell_size = cell_size
        self.nr_cells = nr_cells
        self.probes = probes
        self.K = K
        self.num_kdtrees = num_kdtrees
        self.gpu_id = gpu_id

        self.index = FLANN()

    @staticmethod
    def _to_numpy(data):
        """Return torch Variables/Tensors as CPU NumPy arrays; pass others through."""
        # Variable is checked first so that `.data` is used to detach it
        # before the NumPy conversion.
        if isinstance(data, var):
            return data.data.cpu().numpy()
        if isinstance(data, T.Tensor):
            return data.cpu().numpy()
        return data

    def add(self, other, positions=None, last=-1):
        """Build the kd-tree index from the rows of ``other``.

        Args:
            other: 2-D torch Variable/Tensor, or any object that
                ``build_index`` accepts directly.
            positions: Unused; kept for interface compatibility.
            last: Exclusive row bound applied as ``other[:last, :]`` to torch
                inputs only. The default ``-1`` therefore leaves out the
                final row; pass ``None`` to index every row. Non-torch
                inputs are never sliced.

        NOTE(review): this calls ``build_index`` on every invocation, so it
        presumably replaces rather than extends the previously indexed
        points -- confirm against pyflann.
        """
        if isinstance(other, (var, T.Tensor)):
            other = self._to_numpy(other[:last, :])

        self.index.build_index(other, algorithm='kdtree', trees=self.num_kdtrees, checks=self.probes)

    def search(self, query, k=None):
        """Look up the nearest neighbours of ``query`` in the index.

        Args:
            query: torch Variable/Tensor, or any object that ``nn_index``
                accepts directly.
            k: Number of neighbours to request; defaults to ``self.K``.

        Returns:
            A ``(distances, labels)`` tuple: ``distances`` is a float tensor
            built from the second value returned by ``nn_index`` and
            ``labels`` a long tensor built from the first. Both are moved to
            CUDA device ``self.gpu_id`` unless it is ``-1``.
        """
        query = self._to_numpy(query)
        num_neighbors = self.K if k is None else k

        neighbours, dists = self.index.nn_index(query, num_neighbors=num_neighbors, cores=4)

        distances = T.from_numpy(dists).float()
        labels = T.from_numpy(neighbours).long()

        if self.gpu_id != -1:
            distances = distances.cuda(self.gpu_id)
            labels = labels.cuda(self.gpu_id)

        return (distances, labels)

    def reset(self):
        """Delete the current index by calling ``delete_index`` on it."""
        self.index.delete_index()