Remove redundant code

This commit is contained in:
ixaxaar 2017-12-04 21:11:30 +05:30
parent 64520e1dcf
commit 63d49afe40

View File

@ -61,7 +61,6 @@ class SparseMemory(nn.Module):
def rebuild_indexes(self, hidden, erase=False):
b = hidden['memory'].size(0)
t = time.time()
# if indexes already exist, we reset them
if 'indexes' in hidden:
@ -108,10 +107,9 @@ class SparseMemory(nn.Module):
hidden['last_used_mem'] = hidden['last_used_mem'].clone()
hidden['usage'] = hidden['usage'].clone()
hidden['read_positions'] = hidden['read_positions'].clone()
hidden = self.rebuild_indexes(hidden)
hidden = self.rebuild_indexes(hidden, erase)
if erase:
hidden = self.rebuild_indexes(hidden, erase)
hidden['memory'].data.fill_(δ)
hidden['read_weights'].data.fill_(δ)
hidden['write_weights'].data.fill_(δ)
@ -185,16 +183,15 @@ class SparseMemory(nn.Module):
# non-differentiable operations
for batch in range(b):
distances, positions = indexes[batch].search(keys[batch])
distances = F.softmax(distances)
read_weights.append(distances)
read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
# add weight of 0 for least used mem block
read_weights = T.stack(read_weights, 0)
new_block = read_weights.data.new(b, 1, 1)
new_block = read_weights.new(b, 1, 1)
new_block.fill_(0)
read_weights = T.cat([read_weights, new_block], 2)
read_weights = F.softmax(var(read_weights))
# add least used mem to read positions
read_positions = T.stack(read_positions, 0)