diff --git a/dnc/sparse_memory.py b/dnc/sparse_memory.py index d06e975..76cb55e 100644 --- a/dnc/sparse_memory.py +++ b/dnc/sparse_memory.py @@ -64,6 +64,7 @@ class SparseMemory(nn.Module): self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n) self.δ = 0.005 # minimum usage self.timestep = 0 + self.mem_limit_reached = False def rebuild_indexes(self, hidden, erase=False): b = hidden['memory'].size(0) @@ -95,6 +96,7 @@ class SparseMemory(nn.Module): i.add(hidden['memory'][n], last=pos[n][-1]) else: self.timestep = 0 + self.mem_limit_reached = False return hidden @@ -155,11 +157,12 @@ class SparseMemory(nn.Module): for batch in range(b): # update indexes hidden['indexes'][batch].reset() - hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1]) + hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None)) mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1 hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1 + self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached return hidden diff --git a/dnc/sparse_temporal_memory.py b/dnc/sparse_temporal_memory.py index 2ddac8d..91555f0 100644 --- a/dnc/sparse_temporal_memory.py +++ b/dnc/sparse_temporal_memory.py @@ -67,6 +67,7 @@ class SparseTemporalMemory(nn.Module): self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n) self.δ = 0.005 # minimum usage self.timestep = 0 + self.mem_limit_reached = False def rebuild_indexes(self, hidden, erase=False): b = hidden['memory'].size(0) @@ -98,6 +99,7 @@ class SparseTemporalMemory(nn.Module): i.add(hidden['memory'][n], last=pos[n][-1]) else: self.timestep = 0 + self.mem_limit_reached = False return hidden @@ -167,11 +169,12 @@ class SparseTemporalMemory(nn.Module): for batch in range(b): # update indexes hidden['indexes'][batch].reset() - hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1]) + hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None)) mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1 hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c + 1) if mem_limit_reached else hidden['least_used_mem'] + 1 + self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached return hidden