Merge pull request #27 from ixaxaar/bugfix

Bugfixes: latch a new `mem_limit_reached` flag once the least-used-memory pointer reaches the end of memory, so index rebuilds and read-position lookups stop assuming an append-only memory and wrap around instead; initialize and reset `usage` to 0 rather than the δ floor; bump the package version to 0.0.8.
Russi Chatterjee, 2017-12-31 11:43:55 +05:30, committed by GitHub
commit 4115e69155
3 changed files with 27 additions and 15 deletions

dnc/sparse_memory.py

@@ -64,6 +64,7 @@ class SparseMemory(nn.Module):
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
+    self.mem_limit_reached = False

   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
@@ -95,6 +96,7 @@ class SparseMemory(nn.Module):
       i.add(hidden['memory'][n], last=pos[n][-1])
     else:
       self.timestep = 0
+      self.mem_limit_reached = False

     return hidden
@@ -114,7 +116,7 @@ class SparseMemory(nn.Module):
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
         'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
-        'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
+        'usage': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
         'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
       hidden = self.rebuild_indexes(hidden, erase=True)
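For context on the init change above: `usage` now starts as exact zeros instead of being pre-filled with the δ floor (δ = 0.005 in this module), so untouched cells report no usage at all. A minimal standalone sketch of the two initializations; the batch and memory sizes are hypothetical, not the library's defaults:

```python
import torch as T

δ = 0.005
b, m = 2, 8  # hypothetical batch size and number of memory cells

usage_before = T.zeros(b, m).fill_(δ)  # old: every cell starts slightly "used"
usage_after = T.zeros(b, m)            # new: untouched cells have exactly zero usage

print(usage_before[0, :3])  # tensor([0.0050, 0.0050, 0.0050])
print(usage_after[0, :3])   # tensor([0., 0., 0.])
```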
@@ -135,10 +137,10 @@ class SparseMemory(nn.Module):
       hidden['read_weights'].data.fill_(δ)
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
-      hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
-      hidden['usage'].data.fill_(δ)
+      hidden['least_used_mem'].data.fill_(c + 1)
+      hidden['usage'].data.fill_(0)
       hidden['read_positions'] = cuda(
-          T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+          T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
     return hidden
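The reset hunk drops the `self.timestep` offset, so a reset now restores the same pointer and read positions a fresh `init` produces, no matter how many steps have elapsed. A sketch of the new post-reset values, with hypothetical sizes:

```python
import torch as T

b, c = 2, 4  # hypothetical batch size and cells read per step

# after this change, reset() yields the same values as a fresh hidden state,
# independent of the current timestep
least_used_mem = T.zeros(b, 1).fill_(c + 1).long()
read_positions = T.arange(0, c).expand(b, c).long()

print(least_used_mem[0, 0].item())  # 5
print(read_positions[0])            # tensor([0, 1, 2, 3])
```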
@@ -155,17 +157,18 @@ class SparseMemory(nn.Module):
     for batch in range(b):
       # update indexes
       hidden['indexes'][batch].reset()
-      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+      hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))

     mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
     hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
                                 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached

     return hidden

   def write(self, interpolation_gate, write_vector, write_gate, hidden):

     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
     # encourage read and write in the first timestep
     if self.timestep == 1: read_weights = read_weights + 1
     write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
     hidden['usage'], I = self.update_usage(
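The hunk above is the core of the fix: the least-used pointer used to grow without bound, and the index rebuild always passed a `last` position even after memory had wrapped. A minimal standalone sketch of the latched wrap-around, with hypothetical sizes; it imitates only the pointer arithmetic, not the library's index structures:

```python
import torch as T

mem_size, c = 8, 3  # hypothetical memory size and cells-per-step
least_used_mem = T.zeros(1, 1).long()
mem_limit_reached = False

for step in range(12):
    limit_hit = least_used_mem[0, 0].item() >= mem_size - 1
    # wrap the pointer back to c + 1 once it hits the last cell
    least_used_mem = (least_used_mem * 0 + c + 1) if limit_hit else least_used_mem + 1
    # the flag latches on: once memory has filled, it stays True
    mem_limit_reached = limit_hit or mem_limit_reached

print(least_used_mem[0, 0].item(), mem_limit_reached)
```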
@@ -192,6 +195,9 @@ class SparseMemory(nn.Module):
         (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)

+    # update least used memory cell
+    hidden['least_used_mem'] = T.topk(hidden['usage'], 1, dim=-1, largest=False)[1]
+
     return hidden

   def update_usage(self, read_positions, read_weights, write_weights, usage):
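After each write, the diff now recomputes `least_used_mem` directly from the usage vector. `T.topk(..., largest=False)` with k=1 is a per-row argmin, and indexing the result with `[1]` keeps the indices rather than the values. A standalone illustration with a hypothetical usage tensor:

```python
import torch as T

# hypothetical usage for a batch of 2 over 6 memory cells
usage = T.tensor([[0.30, 0.05, 0.90, 0.10, 0.50, 0.20],
                  [0.70, 0.40, 0.02, 0.60, 0.10, 0.80]])

# topk with largest=False returns the k smallest values and their indices;
# [1] selects the indices, i.e. the least used cell per batch entry
least_used_mem = T.topk(usage, 1, dim=-1, largest=False)[1]
print(least_used_mem)  # tensor([[1], [2]])
```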
@@ -233,7 +239,7 @@ class SparseMemory(nn.Module):
     # temporal reads
     (b, m, w) = memory.size()
     # get the top KL entries
-    max_length = int(least_used_mem[0, 0].data.cpu().numpy())
+    max_length = int(least_used_mem[0, 0].data.cpu().numpy()) if not self.mem_limit_reached else (m - 1)

     # differentiable ops
     # append forward and backward read positions, might lead to duplicates
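The guard above changes how many cells are read candidates: while memory is still filling, only the cells written so far (up to the least-used pointer) are searched; once `mem_limit_reached` latches, all `m` cells are. A sketch of the two branches, with hypothetical values:

```python
import torch as T

m = 16                            # hypothetical number of memory cells
least_used_mem = T.tensor([[5]])  # pointer after a few writes

# while the memory is still filling, only cells written so far are searched
mem_limit_reached = False
max_length = int(least_used_mem[0, 0].item()) if not mem_limit_reached else (m - 1)
assert max_length == 5

# once the limit flag latches, the whole memory is eligible
mem_limit_reached = True
max_length = int(least_used_mem[0, 0].item()) if not mem_limit_reached else (m - 1)
assert max_length == 15
```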

dnc/sparse_temporal_memory.py

@@ -67,6 +67,7 @@ class SparseTemporalMemory(nn.Module):
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
+    self.mem_limit_reached = False

   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
@@ -98,6 +99,7 @@ class SparseTemporalMemory(nn.Module):
       i.add(hidden['memory'][n], last=pos[n][-1])
     else:
       self.timestep = 0
+      self.mem_limit_reached = False

     return hidden
@@ -120,7 +122,7 @@ class SparseTemporalMemory(nn.Module):
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
         'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
-        'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
+        'usage': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
        'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
       }
       hidden = self.rebuild_indexes(hidden, erase=True)
@@ -148,7 +150,7 @@ class SparseTemporalMemory(nn.Module):
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
       hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
-      hidden['usage'].data.fill_(δ)
+      hidden['usage'].data.fill_(0)
       hidden['read_positions'] = cuda(
           T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
@@ -167,11 +169,10 @@ class SparseTemporalMemory(nn.Module):
     for batch in range(b):
       # update indexes
       hidden['indexes'][batch].reset()
-      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+      hidden['indexes'][batch].add(hidden['memory'][batch], last=(pos[batch][-1] if not self.mem_limit_reached else None))

     mem_limit_reached = hidden['least_used_mem'][0].data.cpu().numpy()[0] >= self.mem_size - 1
     hidden['least_used_mem'] = (hidden['least_used_mem'] * 0 + self.c +
                                 1) if mem_limit_reached else hidden['least_used_mem'] + 1
+    self.mem_limit_reached = mem_limit_reached or self.mem_limit_reached

     return hidden
@@ -202,6 +203,8 @@ class SparseTemporalMemory(nn.Module):
   def write(self, interpolation_gate, write_vector, write_gate, hidden):

     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
+    # encourage read and write in the first timestep
+    if self.timestep == 1: read_weights = read_weights + 1
     write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
     hidden['usage'], I = self.update_usage(
@@ -246,6 +249,9 @@ class SparseTemporalMemory(nn.Module):
     read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
     hidden['precedence'] = self.update_precedence(hidden['precedence'], read_weights)

+    # update least used memory cell
+    hidden['least_used_mem'] = T.topk(hidden['usage'], 1, dim=-1, largest=False)[1]
+
     return hidden

   def update_usage(self, read_positions, read_weights, write_weights, usage):
@@ -292,7 +298,7 @@ class SparseTemporalMemory(nn.Module):
     # temporal reads
     (b, m, w) = memory.size()
     # get the top KL entries
-    max_length = int(least_used_mem[0, 0].data.cpu().numpy())
+    max_length = int(least_used_mem[0, 0].data.cpu().numpy()) if not self.mem_limit_reached else (m - 1)

     _, fp = T.topk(forward, self.KL, largest=True)
     _, bp = T.topk(backward, self.KL, largest=True)
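In the temporal variant the same `max_length` guard feeds link-based addressing: per read head, the KL strongest forward and backward links are kept as extra read candidates (the file-1 hunk's comment notes these positions are appended and may duplicate). A standalone sketch of that selection; shapes and values are hypothetical:

```python
import torch as T

KL = 2                       # hypothetical number of temporal links to follow
forward = T.rand(1, 4, 8)    # (batch, read heads, cells): forward link weights
backward = T.rand(1, 4, 8)   # backward link weights

# keep only the indices of the KL strongest links in each direction
_, fp = T.topk(forward, KL, largest=True)
_, bp = T.topk(backward, KL, largest=True)
print(fp.shape, bp.shape)    # torch.Size([1, 4, 2]) torch.Size([1, 4, 2])
```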

setup.py

@@ -22,7 +22,7 @@ with open(path.join(here, 'README.md'), encoding='utf-8') as f:
 setup(
     name='dnc',
-    version='0.0.7',
+    version='0.0.8',
     description='Differentiable Neural Computer, for Pytorch',
     long_description=long_description,