commit 4115e69155
@@ -64,6 +64,7 @@ class SparseMemory(nn.Module):
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
+    self.mem_limit_reached = False
 
   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
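Note: the new `self.mem_limit_reached` flag is a one-way latch recording whether the least-used pointer has ever run off the end of memory; the hunks below clear it on erase and consult it when rebuilding indexes and bounding reads. A condensed sketch of the lifecycle implied by this diff (class and method names here are illustrative, not from the repository):

```python
class MemLimitLatch:
    """Illustrative condensation of the flag's lifecycle in this commit."""

    def __init__(self, mem_size):
        self.mem_size = mem_size
        self.mem_limit_reached = False   # starts False in __init__

    def on_erase(self):
        self.mem_limit_reached = False   # cleared alongside self.timestep

    def on_index_rebuild(self, least_used_mem):
        # sticky OR: once the least-used slot reaches the end of memory,
        # the flag flips on and stays on until the next erase
        reached = least_used_mem >= self.mem_size - 1
        self.mem_limit_reached = reached or self.mem_limit_reached
```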
@@ -95,6 +96,7 @@ class SparseMemory(nn.Module):
         i.add(hidden['memory'][n], last=pos[n][-1])
     else:
       self.timestep = 0
+      self.mem_limit_reached = False
 
     return hidden
 
@@ -114,7 +116,7 @@ class SparseMemory(nn.Module):
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
         'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
-        'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
+        'usage': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
       hidden = self.rebuild_indexes(hidden, erase=True)
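Note: `usage` now starts at zero instead of being pre-filled with the δ floor, so only cells that are actually read or written accumulate usage. The contrast, as a sketch (`b` and `m` as in the hunk):

```python
import torch as T

b, m = 1, 4
old_usage = T.zeros(b, m).fill_(0.005)  # before: every slot pre-charged with δ
new_usage = T.zeros(b, m)               # after: slots start genuinely unused
```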
@@ -135,10 +137,10 @@ class SparseMemory(nn.Module):
       hidden['read_weights'].data.fill_(δ)
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
-      hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
-      hidden['usage'].data.fill_(δ)
+      hidden['least_used_mem'].data.fill_(c + 1)
+      hidden['usage'].data.fill_(0)
       hidden['read_positions'] = cuda(
-          T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+          T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
 
       return hidden
 
@@ -155,17 +157,18 @@ class SparseMemory(nn.Module):
       for batch in range(b):
         # update indexes
         hidden['indexes'][batch].reset()
-        hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+        hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
 
     mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
-    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
-                                1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
 
     return hidden
 
   def write(self, interpolation_gate, write_vector, write_gate, hidden):
 
     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
+    # encourage read and write in the first timestep
+    if self.timestep == 1: read_weights = read_weights + 1
     write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
 
     hidden['usage'], I = self.update_usage(
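Note: two behavioral changes land here. First, `least_used_mem` is no longer incremented (or wrapped) in this method at all; it is recomputed from usage in `write()` below, and the sticky OR keeps `self.mem_limit_reached` set across timesteps once the pointer has hit the end of memory. Second, the `if self.timestep == 1` line boosts the gathered read weights on the very first step so that early usage statistics do not starve reads and writes. A runnable illustration of the latch (observation values are made up):

```python
# the latch never falls back to False once set, no matter what
# later timesteps observe
flag = False
for observed in [False, True, False, False]:
    flag = observed or flag
print(flag)  # True
```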
@@ -192,6 +195,9 @@ class SparseMemory(nn.Module):
         (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)
 
+    # update least used memory cell
+    hidden['least_used_mem'] = T.topk(hidden['usage'], 1, dim=-1, largest=False)[1]
+
     return hidden
 
   def update_usage(self, read_positions, read_weights, write_weights, usage):
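Note: `T.topk(..., 1, largest=False)[1]` is an argmin over slots, so `least_used_mem` becomes a quantity derived from usage on every write rather than a counter. A runnable sketch with illustrative values:

```python
import torch as T

usage = T.tensor([[0.3, 0.0, 0.7, 0.1]])  # batch of 1, four slots
least_used = T.topk(usage, 1, dim=-1, largest=False)[1]
print(least_used)  # tensor([[1]]) -- the slot with the smallest usage
```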
@@ -233,7 +239,7 @@ class SparseMemory(nn.Module):
     # temporal reads
     (b, m, w) = memory.size()
     # get the top KL entries
-    max_length = int(least_used_mem[0, 0].data.cpu().numpy())
+    max_length = int(least_used_mem[0, 0].data.cpu().numpy()) if not self.mem_limit_reached else (m-1)
 
     # differentiable ops
     # append forward and backward read positions, might lead to duplicates
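Note: once the latch is set, `max_length` falls back to `m - 1`, so the nearest-neighbor read considers the whole memory rather than only the prefix written so far. A sketch of the guard with made-up numbers:

```python
m = 8                       # total memory slots
least_used = 3              # next free slot while memory is still filling
for mem_limit_reached in (False, True):
    max_length = least_used if not mem_limit_reached else (m - 1)
    print(max_length)       # 3 while filling, then 7 once memory has wrapped
```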
@@ -67,6 +67,7 @@ class SparseTemporalMemory(nn.Module):
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
+    self.mem_limit_reached = False
 
   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
@@ -98,6 +99,7 @@ class SparseTemporalMemory(nn.Module):
         i.add(hidden['memory'][n], last=pos[n][-1])
     else:
       self.timestep = 0
+      self.mem_limit_reached = False
 
     return hidden
 
@@ -120,7 +122,7 @@ class SparseTemporalMemory(nn.Module):
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
         'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
-        'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
+        'usage': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
       hidden = self.rebuild_indexes(hidden, erase=True)
@@ -148,7 +150,7 @@ class SparseTemporalMemory(nn.Module):
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
       hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
-      hidden['usage'].data.fill_(δ)
+      hidden['usage'].data.fill_(0)
       hidden['read_positions'] = cuda(
           T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
 
@@ -167,11 +169,10 @@ class SparseTemporalMemory(nn.Module):
       for batch in range(b):
         # update indexes
         hidden['indexes'][batch].reset()
-        hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+        hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))
 
     mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
-    hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
-                                1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached
 
     return hidden
 
@@ -202,6 +203,8 @@ class SparseTemporalMemory(nn.Module):
   def write(self, interpolation_gate, write_vector, write_gate, hidden):
 
     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
+    # encourage read and write in the first timestep
+    if self.timestep == 1: read_weights = read_weights + 1
     write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
 
     hidden['usage'], I = self.update_usage(
@@ -246,6 +249,9 @@ class SparseTemporalMemory(nn.Module):
     read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
     hidden['precedence'] = self.update_precedence(hidden['precedence'], read_weights)
 
+    # update least used memory cell
+    hidden['least_used_mem'] = T.topk(hidden['usage'], 1, dim=-1, largest=False)[1]
+
    return hidden
 
   def update_usage(self, read_positions, read_weights, write_weights, usage):
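Note: this is the same usage-derived `least_used_mem` update as in `SparseMemory`, placed after the precedence bookkeeping. `update_precedence` itself is unchanged by this diff; for reference, the DNC paper (Graves et al., 2016) defines the precedence update as p ← (1 − Σᵢ wᵢ) p + w, which a sketch of such a helper might implement as follows (assumed, not taken from this repository):

```python
import torch as T

def update_precedence(precedence, weights):
    # p <- (1 - sum_i w_i) * p + w  (DNC paper); sketch only --
    # the repository's actual implementation is not shown in this diff
    return (1 - weights.sum(dim=-1, keepdim=True)) * precedence + weights
```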
@@ -292,7 +298,7 @@ class SparseTemporalMemory(nn.Module):
     # temporal reads
     (b, m, w) = memory.size()
     # get the top KL entries
-    max_length = int(least_used_mem[0, 0].data.cpu().numpy())
+    max_length = int(least_used_mem[0, 0].data.cpu().numpy()) if not self.mem_limit_reached else (m-1)
 
     _, fp = T.topk(forward, self.KL, largest=True)
     _, bp = T.topk(backward, self.KL, largest=True)
setup.py
@@ -22,7 +22,7 @@ with open(path.join(here, 'README.md'), encoding='utf-8') as f:
 setup(
     name='dnc',
 
-    version='0.0.7',
+    version='0.0.8',
 
     description='Differentiable Neural Computer, for Pytorch',
     long_description=long_description,