47 lines
872 B
Python
47 lines
872 B
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import pytest
|
|
import numpy as np
|
|
|
|
import torch.nn as nn
|
|
import torch as T
|
|
from torch.autograd import Variable as var
|
|
import torch.nn.functional as F
|
|
from torch.nn.utils import clip_grad_norm_
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
|
|
import sys
|
|
import os
|
|
import math
|
|
import time
|
|
import functools
|
|
sys.path.insert(0, '.')
|
|
|
|
from pyflann import *
|
|
|
|
from dnc.flann_index import FLANNIndex
|
|
|
|
def test_indexes():
    """Check that FLANNIndex.search yields (1, K)-sized distances and labels.

    For every entry in the device tuple a fresh index is built, filled with
    an all-ones matrix of ``n`` rows, and then queried with a single row
    whose entries are all 7.
    """
    n, cell_size, nr_cells = 30, 20, 1024
    K, probes = 10, 32

    data = T.ones(n, cell_size)
    query = T.ones(1, cell_size)
    expected_size = T.Size([1, K])

    # NOTE(review): both entries are -1, so the .cuda() branch below is never
    # taken and the same configuration is simply exercised twice; presumably
    # the second entry was once a real GPU id -- confirm before changing.
    for gpu_id in (-1, -1):
        index = FLANNIndex(
            cell_size=cell_size,
            nr_cells=nr_cells,
            K=K,
            probes=probes,
            gpu_id=gpu_id,
        )

        # Only move the data when a device other than -1 is requested.
        if gpu_id != -1:
            data = data.cuda(gpu_id)

        index.add(data)
        dist, labels = index.search(query * 7)

        assert dist.size() == expected_size
        assert labels.size() == expected_size
|
|
|