Make corrections according to paper

ixaxaar 2017-12-11 20:09:19 +05:30
parent 7edf687759
commit 2d66be3013
2 changed files with 39 additions and 30 deletions

View File

@@ -46,6 +46,7 @@ class SparseMemory(nn.Module):
     m = self.mem_size
     w = self.cell_size
     r = self.read_heads
+    # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
     self.c = (r * self.K) + (self.KL * 2) + 1
     if self.independent_linears:
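For reference, the visible-memory size set in this hunk follows directly from the comment above it. A minimal sketch of the arithmetic (the parameter values are illustrative, not taken from the commit):

    # c = r*K + 2*KL + 1: r read heads with K content reads each,
    # KL forward + KL backward temporal reads, plus 1 least-used cell.
    r, K, KL = 4, 8, 2            # illustrative values
    c = (r * K) + (KL * 2) + 1
    print(c)                      # 32 + 4 + 1 = 37 visible cells per step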
@@ -102,8 +103,8 @@ class SparseMemory(nn.Module):
       # warning can be a huge chunk of contiguous memory
       'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
       'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
-      'link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
-      'rev_link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
+      'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
+      'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
       'precedence': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
       'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
       'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
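This shape change is the substance of the fix: instead of dense (b, c, c) links among only the cells visible in the current step, each of the m memory cells now keeps links to just the 2*KL temporal read slots. A shape-only sketch with made-up sizes:

    import torch as T

    b, m, c, KL = 1, 500, 37, 2
    old_link = T.zeros(b, c, c)       # old: links among visible cells only
    new_link = T.zeros(b, m, KL * 2)  # new: every cell links to the 2*KL temporal reads
    print(old_link.shape, new_link.shape)  # (1, 37, 37) vs (1, 500, 4)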
@@ -161,10 +162,12 @@ class SparseMemory(nn.Module):
     return hidden

-  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence):
-    link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
+  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence, temporal_read_positions):
+    temporal_read_precedence = precedence.gather(1, temporal_read_positions)
+    link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * temporal_read_precedence.unsqueeze(1)

-    rev_link_matrix = (1 - write_weights).unsqueeze(1) * rev_link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
+    temporal_write_weights = write_weights.gather(1, temporal_read_positions)
+    rev_link_matrix = (1 - temporal_write_weights).unsqueeze(1) * rev_link_matrix + (temporal_write_weights.unsqueeze(1) * precedence.unsqueeze(2))

     return link_matrix, rev_link_matrix
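The update keeps the usual link-matrix form, decay by write weight then imprint precedence, but both matrices now only have columns for the temporal read positions, so the precedence (for the forward links) and the write weights (for the reverse links) are first gathered at those positions. A self-contained sketch of the same update on dummy tensors (toy sizes, random values; T.randint needs PyTorch 0.4+, newer than this 2017 commit):

    import torch as T

    b, m, KL = 2, 6, 1                           # toy sizes: 2*KL temporal columns
    link = T.zeros(b, m, KL * 2)
    rev_link = T.zeros(b, m, KL * 2)
    write_w = T.rand(b, m)                       # write weights over all m cells
    prec = T.rand(b, m)                          # precedence over all m cells
    temporal_pos = T.randint(0, m, (b, KL * 2))  # temporal read positions

    temporal_prec = prec.gather(1, temporal_pos)              # (b, 2*KL)
    link = (1 - write_w).unsqueeze(2) * link \
        + write_w.unsqueeze(2) * temporal_prec.unsqueeze(1)   # (b, m, 2*KL)

    temporal_write_w = write_w.gather(1, temporal_pos)        # (b, 2*KL)
    rev_link = (1 - temporal_write_w).unsqueeze(1) * rev_link \
        + temporal_write_w.unsqueeze(1) * prec.unsqueeze(2)   # (b, m, 2*KL)
    print(link.shape, rev_link.shape)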
@@ -201,13 +204,20 @@ class SparseMemory(nn.Module):
     # update link_matrix and precedence
     (b, c) = write_weights.size()
-    precedence = hidden['precedence'].gather(1, hidden['read_positions'])
+    # update link matrix
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
     hidden['link_matrix'], hidden['rev_link_matrix'] = \
-      self.update_link_matrices(hidden['link_matrix'], hidden['rev_link_matrix'], write_weights, precedence)
-    precedence = self.update_precedence(precedence, write_weights)
+      self.update_link_matrices(
+        hidden['link_matrix'],
+        hidden['rev_link_matrix'],
+        hidden['write_weights'],
+        hidden['precedence'],
+        temporal_read_positions
+      )
-    hidden['precedence'].scatter_(1, hidden['read_positions'], precedence)
+    # update precedence vector
+    hidden['precedence'] = self.update_precedence(hidden['precedence'], hidden['read_weights'])

     return hidden
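The slice introduced here drops the first self.read_heads*self.K + 1 of the c read positions and keeps the rest, and since c = r*K + 2*KL + 1 that leaves exactly 2*KL entries, matching the KL*2 columns of the link matrices above. A toy illustration of the bookkeeping (sizes are illustrative):

    import torch as T

    r, K, KL = 2, 3, 2                                       # c = r*K + 2*KL + 1 = 11
    read_positions = T.randint(0, 50, (1, r * K + 2 * KL + 1))
    temporal_read_positions = read_positions[:, r * K + 1:]  # the slice from the diff
    print(temporal_read_positions.size(1))                   # 4 == 2*KL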
@@ -232,10 +242,9 @@ class SparseMemory(nn.Module):
     return usage, I

-  def directional_weightings(self, link_matrix, rev_link_matrix, read_weights):
-    f = T.bmm(link_matrix, read_weights.unsqueeze(2)).squeeze()
-    b = T.bmm(read_weights.unsqueeze(1), rev_link_matrix).squeeze()
+  def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
+    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
+    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()

     return f, b

   def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
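With the reverse links stored in the same (b, m, KL*2) layout as the forward links, both directions in the hunk above reduce to the same batched matrix-vector product against the 2*KL temporal read weights, giving one score per memory cell. A standalone sketch with dummy shapes; note it uses squeeze(2) rather than the bare squeeze() from the diff, to avoid the corner case where a batch dimension of 1 would also be dropped:

    import torch as T

    b, m, KL = 2, 8, 2
    link = T.rand(b, m, KL * 2)
    rev_link = T.rand(b, m, KL * 2)
    temporal_read_w = T.rand(b, KL * 2)      # weights of the 2*KL temporal reads

    f = T.bmm(link, temporal_read_w.unsqueeze(2)).squeeze(2)       # (b, m) forward
    bw = T.bmm(rev_link, temporal_read_w.unsqueeze(2)).squeeze(2)  # (b, m) backward
    print(f.shape, bw.shape)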
@@ -258,14 +267,12 @@ class SparseMemory(nn.Module):
+    # TODO: this results in duplicate reads when the content based positions and temporal ones are same
     (b, m, w) = memory.size()

     # get the top KL entries
-    max_length = int(least_used_mem[0, 0].data.cpu().numpy())
     _, fp = T.topk(forward, self.KL, largest=True)
     _, bp = T.topk(backward, self.KL, largest=True)

-    # get read positions for those entries
-    fpos = prev_read_positions.gather(1, fp)
-    bpos = prev_read_positions.gather(1, bp)
     # append forward and backward read positions, might lead to duplicates
-    read_positions = T.cat([read_positions, fpos, bpos], 1)
+    read_positions = T.cat([read_positions, fp, bp], 1)
     read_positions = T.cat([read_positions, least_used_mem], 1)

     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w))
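Because the directional weightings now span all m cells rather than only the c visible ones, the indices returned by T.topk are already valid memory positions, which appears to be why the old remapping through prev_read_positions (the fpos/bpos intermediates) could be dropped. A toy sketch of the position assembly:

    import torch as T

    b, m, r, K, KL = 1, 16, 1, 2, 2
    forward = T.rand(b, m)                       # per-cell forward scores
    backward = T.rand(b, m)                      # per-cell backward scores
    content_pos = T.randint(0, m, (b, r * K))    # stand-in for content-based reads
    least_used = T.randint(0, m, (b, 1))         # stand-in for the least-used cell

    _, fp = T.topk(forward, KL, largest=True)    # already valid memory indices
    _, bp = T.topk(backward, KL, largest=True)
    read_positions = T.cat([content_pos, fp, bp, least_used], 1)
    print(read_positions.shape)                  # (b, r*K + 2*KL + 1)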
@@ -278,7 +285,8 @@ class SparseMemory(nn.Module):
   def read(self, read_query, hidden):
     # get forward and backward weights
-    read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
+    read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
     forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights)

     # sparse read
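The read path mirrors the write path: the same temporal slice of the read positions selects the matching read weights, which then feed the directional weightings above. A minimal sketch of the gather, shapes only:

    import torch as T

    b, m, r, K, KL = 2, 32, 1, 4, 2
    read_weights = T.rand(b, m)                              # dense weights over memory
    read_positions = T.randint(0, m, (b, r * K + 2 * KL + 1))
    temporal_read_positions = read_positions[:, r * K + 1:]  # last 2*KL entries
    temporal_read_weights = read_weights.gather(1, temporal_read_positions)
    print(temporal_read_weights.shape)                       # (b, 2*KL)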

View File

@@ -147,7 +147,7 @@ if __name__ == '__main__':
       temporal_reads=args.temporal_reads,
       read_heads=args.read_heads,
       gpu_id=args.cuda,
-      debug=True,
+      debug=False,
       batch_first=True,
       independent_linears=False
   )
@@ -209,7 +209,7 @@ if __name__ == '__main__':
       last_save_losses.append(loss_value)

-      if summarize and rnn.debug:
+      if summarize:
         loss = np.mean(last_save_losses)
         # print(input_data)
         # print("1111111111111111111111111111111111111111111111")
@@ -229,16 +229,17 @@ if __name__ == '__main__':
         # print(F.relu6(output))
         last_save_losses = []

-        viz.heatmap(
-          v['memory'],
-          opts=dict(
-            xtickstep=10,
-            ytickstep=2,
-            title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
-            ylabel='layer * time',
-            xlabel='mem_slot * mem_size'
-          )
-        )
+        if args.memory_type == 'dnc':
+          viz.heatmap(
+            v['memory'],
+            opts=dict(
+              xtickstep=10,
+              ytickstep=2,
+              title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
+              ylabel='layer * time',
+              xlabel='mem_slot * mem_size'
+            )
+          )
         if args.memory_type == 'dnc':
           viz.heatmap(