supports only one read key
This commit is contained in:
parent
00b561e4da
commit
14f0f67b2d
@ -109,21 +109,15 @@ class SparseMemory(nn.Module):
|
|||||||
read_weights = []
|
read_weights = []
|
||||||
|
|
||||||
for batch in range(keys.shape[0]):
|
for batch in range(keys.shape[0]):
|
||||||
d = []; rv = []; p = []
|
|
||||||
|
|
||||||
# search nearest neighbor for each key
|
positions, distances = dict[batch].nn_index(keys[batch, 0, :], num_neighbors=self.K)
|
||||||
for key in range(keys.shape[1]):
|
distances = distances / max(distances)
|
||||||
positions, distances = dict[batch].nn_index(keys[batch, key, :], num_neighbors=self.K)
|
positions = positions[0] if self.K > 1 else positions
|
||||||
distances = distances / max(distances)
|
read_vector = [sparse[batch, p] for p in list(positions)]
|
||||||
positions = positions[0] if self.K > 1 else positions
|
|
||||||
read_vector = [sparse[batch, p] for p in list(positions)]
|
|
||||||
|
|
||||||
d.append(distances)
|
read_weights.append(distances)
|
||||||
rv.append(read_vector)
|
read_vectors.append(read_vector)
|
||||||
p.append(positions)
|
read_positions.append(positions)
|
||||||
read_weights.append(d)
|
|
||||||
read_vectors.append(rv)
|
|
||||||
read_positions.append(p)
|
|
||||||
|
|
||||||
read_vectors = cudavec(np.array(read_vectors), gpu_id=self.gpu_id)
|
read_vectors = cudavec(np.array(read_vectors), gpu_id=self.gpu_id)
|
||||||
read_weights = cudavec(np.array(read_weights), gpu_id=self.gpu_id)
|
read_weights = cudavec(np.array(read_weights), gpu_id=self.gpu_id)
|
||||||
|
Loading…
Reference in New Issue
Block a user