Fixes #24: store entire memory after memory limit is reached
This commit is contained in:
parent
08bd220852
commit
3db618edea
@ -64,6 +64,7 @@ class SparseMemory(nn.Module):
|
||||
self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
|
||||
self.δ = 0.005 # minimum usage
|
||||
self.timestep = 0
|
||||
self.mem_limit_reached = False
|
||||
|
||||
def rebuild_indexes(self, hidden, erase=False):
|
||||
b = hidden['memory'].size(0)
|
||||
@ -95,6 +96,7 @@ class SparseMemory(nn.Module):
|
||||
i.add(hidden['memory'][n], last=pos[n][-1])
|
||||
else:
|
||||
self.timestep = 0
|
||||
self.mem_limit_reached = False
|
||||
|
||||
return hidden
|
||||
|
||||
@ -155,11 +157,12 @@ class SparseMemory(nn.Module):
|
||||
for batch in range(b):
|
||||
# update indexes
|
||||
hidden['indexes'][batch].reset()
|
||||
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
|
||||
hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
|
||||
|
||||
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
|
||||
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
|
||||
1) if mem_limit_reached else hidden['least_used_mem'] + 1
|
||||
self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
|
||||
|
||||
return hidden
|
||||
|
||||
|
@ -67,6 +67,7 @@ class SparseTemporalMemory(nn.Module):
|
||||
self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
|
||||
self.δ = 0.005 # minimum usage
|
||||
self.timestep = 0
|
||||
self.mem_limit_reached = False
|
||||
|
||||
def rebuild_indexes(self, hidden, erase=False):
|
||||
b = hidden['memory'].size(0)
|
||||
@ -98,6 +99,7 @@ class SparseTemporalMemory(nn.Module):
|
||||
i.add(hidden['memory'][n], last=pos[n][-1])
|
||||
else:
|
||||
self.timestep = 0
|
||||
self.mem_limit_reached = False
|
||||
|
||||
return hidden
|
||||
|
||||
@ -167,11 +169,12 @@ class SparseTemporalMemory(nn.Module):
|
||||
for batch in range(b):
|
||||
# update indexes
|
||||
hidden['indexes'][batch].reset()
|
||||
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
|
||||
hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
|
||||
|
||||
mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
|
||||
hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
|
||||
1) if mem_limit_reached else hidden['least_used_mem'] + 1
|
||||
self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
|
||||
|
||||
return hidden
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user