supports only one read key

This commit is contained in:
ixaxaar 2017-11-27 16:45:49 +05:30
parent 00b561e4da
commit 14f0f67b2d

View File

@ -109,21 +109,15 @@ class SparseMemory(nn.Module):
read_weights = [] read_weights = []
for batch in range(keys.shape[0]): for batch in range(keys.shape[0]):
d = []; rv = []; p = []
# search nearest neighbor for each key positions, distances = dict[batch].nn_index(keys[batch, 0, :], num_neighbors=self.K)
for key in range(keys.shape[1]): distances = distances / max(distances)
positions, distances = dict[batch].nn_index(keys[batch, key, :], num_neighbors=self.K) positions = positions[0] if self.K > 1 else positions
distances = distances / max(distances) read_vector = [sparse[batch, p] for p in list(positions)]
positions = positions[0] if self.K > 1 else positions
read_vector = [sparse[batch, p] for p in list(positions)]
d.append(distances) read_weights.append(distances)
rv.append(read_vector) read_vectors.append(read_vector)
p.append(positions) read_positions.append(positions)
read_weights.append(d)
read_vectors.append(rv)
read_positions.append(p)
read_vectors = cudavec(np.array(read_vectors), gpu_id=self.gpu_id) read_vectors = cudavec(np.array(read_vectors), gpu_id=self.gpu_id)
read_weights = cudavec(np.array(read_weights), gpu_id=self.gpu_id) read_weights = cudavec(np.array(read_weights), gpu_id=self.gpu_id)