supports only one read key
This commit is contained in:
parent
00b561e4da
commit
14f0f67b2d
@ -109,21 +109,15 @@ class SparseMemory(nn.Module):
|
||||
read_weights = []
|
||||
|
||||
for batch in range(keys.shape[0]):
|
||||
d = []; rv = []; p = []
|
||||
|
||||
# search nearest neighbor for each key
|
||||
for key in range(keys.shape[1]):
|
||||
positions, distances = dict[batch].nn_index(keys[batch, key, :], num_neighbors=self.K)
|
||||
positions, distances = dict[batch].nn_index(keys[batch, 0, :], num_neighbors=self.K)
|
||||
distances = distances / max(distances)
|
||||
positions = positions[0] if self.K > 1 else positions
|
||||
read_vector = [sparse[batch, p] for p in list(positions)]
|
||||
|
||||
d.append(distances)
|
||||
rv.append(read_vector)
|
||||
p.append(positions)
|
||||
read_weights.append(d)
|
||||
read_vectors.append(rv)
|
||||
read_positions.append(p)
|
||||
read_weights.append(distances)
|
||||
read_vectors.append(read_vector)
|
||||
read_positions.append(positions)
|
||||
|
||||
read_vectors = cudavec(np.array(read_vectors), gpu_id=self.gpu_id)
|
||||
read_weights = cudavec(np.array(read_weights), gpu_id=self.gpu_id)
|
||||
|
Loading…
Reference in New Issue
Block a user