erase before write

This commit is contained in:
ixaxaar 2017-12-10 12:05:13 +05:30
parent 146916d6ff
commit 9734e9014e

View File

@ -171,7 +171,9 @@ class SparseMemory(nn.Module):
# store the write weights
hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)
hidden['visible_memory'] = hidden['visible_memory'] + T.bmm(write_weights.unsqueeze(2), write_vector)
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
hidden = self.write_into_sparse_memory(hidden)
return hidden