Fixes #24: store entire memory after memory limit is reached

This commit is contained in:
ixaxaar 2017-12-21 13:46:38 +05:30
parent 08bd220852
commit 3db618edea
2 changed files with 8 additions and 2 deletions

View File

@ -64,6 +64,7 @@ class SparseMemory(nn.Module):
self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
self.δ = 0.005 # minimum usage
self.timestep = 0
self.mem_limit_reached = False
def rebuild_indexes(self, hidden, erase=False):
b = hidden['memory'].size(0)
@ -95,6 +96,7 @@ class SparseMemory(nn.Module):
i.add(hidden['memory'][n], last=pos[n][-1])
else:
self.timestep = 0
self.mem_limit_reached = False
return hidden
@ -155,11 +157,12 @@ class SparseMemory(nn.Module):
for batch in range(b):
# update indexes
hidden['indexes'][batch].reset()
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
1) if mem_limit_reached else hidden['least_used_mem'] + 1
self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
return hidden

View File

@ -67,6 +67,7 @@ class SparseTemporalMemory(nn.Module):
self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
self.δ = 0.005 # minimum usage
self.timestep = 0
self.mem_limit_reached = False
def rebuild_indexes(self, hidden, erase=False):
b = hidden['memory'].size(0)
@ -98,6 +99,7 @@ class SparseTemporalMemory(nn.Module):
i.add(hidden['memory'][n], last=pos[n][-1])
else:
self.timestep = 0
self.mem_limit_reached = False
return hidden
@ -167,11 +169,12 @@ class SparseTemporalMemory(nn.Module):
for batch in range(b):
# update indexes
hidden['indexes'][batch].reset()
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
1) if mem_limit_reached else hidden['least_used_mem'] + 1
self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
return hidden