Remove redundant code
This commit is contained in:
parent
64520e1dcf
commit
63d49afe40
@ -61,7 +61,6 @@ class SparseMemory(nn.Module):
|
||||
|
||||
def rebuild_indexes(self, hidden, erase=False):
|
||||
b = hidden['memory'].size(0)
|
||||
t = time.time()
|
||||
|
||||
# if indexes already exist, we reset them
|
||||
if 'indexes' in hidden:
|
||||
@ -108,10 +107,9 @@ class SparseMemory(nn.Module):
|
||||
hidden['last_used_mem'] = hidden['last_used_mem'].clone()
|
||||
hidden['usage'] = hidden['usage'].clone()
|
||||
hidden['read_positions'] = hidden['read_positions'].clone()
|
||||
hidden = self.rebuild_indexes(hidden)
|
||||
hidden = self.rebuild_indexes(hidden, erase)
|
||||
|
||||
if erase:
|
||||
hidden = self.rebuild_indexes(hidden, erase)
|
||||
hidden['memory'].data.fill_(δ)
|
||||
hidden['read_weights'].data.fill_(δ)
|
||||
hidden['write_weights'].data.fill_(δ)
|
||||
@ -185,16 +183,15 @@ class SparseMemory(nn.Module):
|
||||
# non-differentiable operations
|
||||
for batch in range(b):
|
||||
distances, positions = indexes[batch].search(keys[batch])
|
||||
distances = F.softmax(distances)
|
||||
|
||||
read_weights.append(distances)
|
||||
read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
|
||||
|
||||
# add weight of 0 for least used mem block
|
||||
read_weights = T.stack(read_weights, 0)
|
||||
new_block = read_weights.data.new(b, 1, 1)
|
||||
new_block = read_weights.new(b, 1, 1)
|
||||
new_block.fill_(0)
|
||||
read_weights = T.cat([read_weights, new_block], 2)
|
||||
read_weights = F.softmax(var(read_weights))
|
||||
|
||||
# add least used mem to read positions
|
||||
read_positions = T.stack(read_positions, 0)
|
||||
|
Loading…
Reference in New Issue
Block a user