commit 4115e69155
@@ -64,6 +64,7 @@ class SparseMemory(nn.Module):
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
+    self.mem_limit_reached = False
 
   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
@@ -95,6 +96,7 @@ class SparseMemory(nn.Module):
         i.add(hidden['memory'][n], last=pos[n][-1])
     else:
       self.timestep = 0
+      self.mem_limit_reached = False
 
     return hidden
 
@@ -114,7 +116,7 @@ class SparseMemory(nn.Module):
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
         'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
-        'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
+        'usage': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
       hidden = self.rebuild_indexes(hidden, erase=True)
@@ -135,10 +137,10 @@ class SparseMemory(nn.Module):
       hidden['read_weights'].data.fill_(δ)
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
-      hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
-      hidden['usage'].data.fill_(δ)
+      hidden['least_used_mem'].data.fill_(c + 1)
+      hidden['usage'].data.fill_(0)
       hidden['read_positions'] = cuda(
-          T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+          T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
 
     return hidden
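Note: `usage` is now initialised to 0 instead of the δ floor, and the timestep offset is dropped from `least_used_mem` and `read_positions` on reset. Presumably this keeps never-written cells at the very bottom of the usage ranking for the least-used selection added further down. A minimal illustration (hypothetical shapes b=2, m=4, not part of the commit):

    import torch as T
    δ = 0.005
    old_usage = T.zeros(2, 4).fill_(δ)  # every cell starts at the δ floor
    new_usage = T.zeros(2, 4)           # untouched cells now rank strictly lowest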
@@ -155,17 +157,18 @@ class SparseMemory(nn.Module):
     for batch in range(b):
       # update indexes
       hidden['indexes'][batch].reset()
-      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+      hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
 
     mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
     hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
                                 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
 
     return hidden
 
   def write(self, interpolation_gate, write_vector, write_gate, hidden):
 
     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
     # encourage read and write in the first timestep
     if self.timestep == 1: read_weights = read_weights + 1
     write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
 
     hidden['usage'], I = self.update_usage(
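Note: `mem_limit_reached` behaves as a sticky or-latch: it flips to True the first time the least-used pointer walks off the end of memory and stays True until rebuild_indexes(erase=True) clears it, after which the index is rebuilt over all cells (last=None). A simplified standalone sketch of the pointer-and-latch logic, with assumed values mem_size=4, c=2 (not from the commit):

    mem_limit_reached = False
    least_used_mem = 0
    mem_size, c = 4, 2
    for step in range(6):
        hit = least_used_mem >= mem_size - 1
        # wrap the pointer back to c + 1 once the limit is hit, else advance
        least_used_mem = (c + 1) if hit else least_used_mem + 1
        # or-latch: once True, it stays True across timesteps
        mem_limit_reached = hit or mem_limit_reached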
@@ -192,6 +195,9 @@ class SparseMemory(nn.Module):
         (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)
 
+    # update least used memory cell
+    hidden['least_used_mem'] = T.topk(hidden['usage'], 1, dim=-1, largest=False)[1]
+
     return hidden
 
   def update_usage(self, read_positions, read_weights, write_weights, usage):
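Note: T.topk with largest=False returns the k smallest entries, so taking element [1] of the (values, indices) tuple yields the position of the least-used cell per batch row. A minimal sketch with made-up usage values:

    import torch as T
    usage = T.tensor([[0.3, 0.1, 0.7],
                      [0.2, 0.9, 0.05]])
    # indices of the single smallest usage value along the last dim
    least_used_mem = T.topk(usage, 1, dim=-1, largest=False)[1]
    print(least_used_mem)  # tensor([[1], [2]]) -- shape (b, 1)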
@@ -233,7 +239,7 @@ class SparseMemory(nn.Module):
     # temporal reads
     (b, m, w) = memory.size()
     # get the top KL entries
-    max_length = int(least_used_mem[0, 0].data.cpu().numpy())
+    max_length = int(least_used_mem[0, 0].data.cpu().numpy()) if not self.mem_limit_reached else (m-1)
 
     # differentiable ops
     # append forward and backward read positions, might lead to duplicates
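Note: until the limit is hit, cells beyond the least-used pointer have never been written, so reads are clamped at the pointer; once memory has wrapped, every cell is live and the full range (m-1) is searched. A standalone sketch of the guard, with hypothetical values:

    m = 8                     # number of memory cells
    mem_limit_reached = False
    least_used = 5            # int(least_used_mem[0, 0]) in the real code
    max_length = least_used if not mem_limit_reached else (m - 1)
    assert max_length == 5    # becomes 7 once the flag latches True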
@@ -67,6 +67,7 @@ class SparseTemporalMemory(nn.Module):
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
+    self.mem_limit_reached = False
 
   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
@@ -98,6 +99,7 @@ class SparseTemporalMemory(nn.Module):
         i.add(hidden['memory'][n], last=pos[n][-1])
     else:
       self.timestep = 0
+      self.mem_limit_reached = False
 
     return hidden
 
@@ -120,7 +122,7 @@ class SparseTemporalMemory(nn.Module):
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
         'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
-        'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
+        'usage': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
       hidden = self.rebuild_indexes(hidden, erase=True)
@@ -148,7 +150,7 @@ class SparseTemporalMemory(nn.Module):
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
       hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
-      hidden['usage'].data.fill_(δ)
+      hidden['usage'].data.fill_(0)
       hidden['read_positions'] = cuda(
           T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
 
@@ -167,11 +169,10 @@ class SparseTemporalMemory(nn.Module):
     for batch in range(b):
       # update indexes
       hidden['indexes'][batch].reset()
-      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+      hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
 
     mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
     hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
                                 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
 
     return hidden
 
@@ -202,6 +203,8 @@ class SparseTemporalMemory(nn.Module):
   def write(self, interpolation_gate, write_vector, write_gate, hidden):
 
     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
+    # encourage read and write in the first timestep
+    if self.timestep == 1: read_weights = read_weights + 1
     write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
 
     hidden['usage'], I = self.update_usage(
@@ -246,6 +249,9 @@ class SparseTemporalMemory(nn.Module):
     read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
     hidden['precedence'] = self.update_precedence(hidden['precedence'], read_weights)
 
+    # update least used memory cell
+    hidden['least_used_mem'] = T.topk(hidden['usage'], 1, dim=-1, largest=False)[1]
+
    return hidden
 
   def update_usage(self, read_positions, read_weights, write_weights, usage):
@@ -292,7 +298,7 @@ class SparseTemporalMemory(nn.Module):
     # temporal reads
     (b, m, w) = memory.size()
     # get the top KL entries
-    max_length = int(least_used_mem[0, 0].data.cpu().numpy())
+    max_length = int(least_used_mem[0, 0].data.cpu().numpy()) if not self.mem_limit_reached else (m-1)
 
     _, fp = T.topk(forward, self.KL, largest=True)
     _, bp = T.topk(backward, self.KL, largest=True)