Make corrections according to the paper

parent 7edf687759
commit 2d66be3013
@@ -46,6 +46,7 @@ class SparseMemory(nn.Module):
     m = self.mem_size
     w = self.cell_size
     r = self.read_heads
+    # The visible memory size: (K * R read heads, forward and backward temporal reads of size KL and least used memory cell)
     self.c = (r * self.K) + (self.KL * 2) + 1
 
     if self.independent_linears:
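As a sanity check on that formula, here is the visible-memory size for some made-up hyperparameters (the r, K, and KL values below are illustrative, not the repo's defaults):

```python
# Illustrative values only; the repo's actual defaults may differ.
r = 4    # read heads
K = 8    # top-K content-based reads per head
KL = 2   # temporal reads in each direction

# K cells for each of the r read heads, KL forward plus KL backward
# temporal reads, and one least-used memory cell:
c = (r * K) + (KL * 2) + 1
print(c)  # 37
```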
@@ -102,8 +103,8 @@ class SparseMemory(nn.Module):
       # warning can be a huge chunk of contiguous memory
       'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
       'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
-      'link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
-      'rev_link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
+      'link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
+      'rev_link_matrix': cuda(T.zeros(b, m, self.KL*2), gpu_id=self.gpu_id),
       'precedence': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
       'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
       'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
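This hunk is the core of the correction: the temporal link matrices shrink from dense (b, c, c) to sparse (b, m, KL*2), so each of the m memory slots keeps links only to the 2*KL temporally read positions. A minimal allocation sketch (plain torch, made-up sizes, with `delta` standing in for the δ fill value):

```python
import torch as T

b, m, w = 2, 16, 10  # illustrative batch size, memory slots, cell width
KL = 2               # temporal reads in each direction
delta = 1e-6         # small fill value, standing in for δ

memory = T.zeros(b, m, w).fill_(delta)
# Sparse links: one entry per (memory slot, temporal read position) pair,
# instead of the b * c * c entries a dense link matrix would need.
link_matrix = T.zeros(b, m, KL * 2)
rev_link_matrix = T.zeros(b, m, KL * 2)
print(link_matrix.shape)  # torch.Size([2, 16, 4])
```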
@@ -161,10 +162,12 @@ class SparseMemory(nn.Module):
 
     return hidden
 
-  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence):
-    link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
+  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence, temporal_read_positions):
+    temporal_read_precedence = precedence.gather(1, temporal_read_positions)
+    link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * temporal_read_precedence.unsqueeze(1)
 
-    rev_link_matrix = (1 - write_weights).unsqueeze(1) * rev_link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
+    temporal_write_weights = write_weights.gather(1, temporal_read_positions)
+    rev_link_matrix = (1 - temporal_write_weights).unsqueeze(1) * rev_link_matrix + (temporal_write_weights.unsqueeze(1) * precedence.unsqueeze(2))
 
     return link_matrix, rev_link_matrix
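Shape-wise, the new update gathers precedence and write weights at the 2*KL temporal positions and broadcasts them against the (b, m, KL*2) link matrices. A self-contained sketch of the updated method (functional form, illustrative sizes only):

```python
import torch as T

def update_link_matrices(link_matrix, rev_link_matrix, write_weights,
                         precedence, temporal_read_positions):
    # precedence gathered at the 2*KL temporal read positions: (b, KL*2)
    temporal_read_precedence = precedence.gather(1, temporal_read_positions)
    # (b, m, 1) * (b, m, KL*2) + (b, m, 1) * (b, 1, KL*2) -> (b, m, KL*2)
    link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + \
        write_weights.unsqueeze(2) * temporal_read_precedence.unsqueeze(1)

    # write weights gathered at the temporal positions: (b, KL*2)
    temporal_write_weights = write_weights.gather(1, temporal_read_positions)
    # (b, 1, KL*2) * (b, m, KL*2) + (b, 1, KL*2) * (b, m, 1) -> (b, m, KL*2)
    rev_link_matrix = (1 - temporal_write_weights).unsqueeze(1) * rev_link_matrix + \
        (temporal_write_weights.unsqueeze(1) * precedence.unsqueeze(2))
    return link_matrix, rev_link_matrix

b, m, KL = 2, 16, 2
lm, rlm = update_link_matrices(
    T.zeros(b, m, KL * 2), T.zeros(b, m, KL * 2),
    T.rand(b, m), T.rand(b, m),
    T.randint(0, m, (b, KL * 2)),
)
print(lm.shape, rlm.shape)  # torch.Size([2, 16, 4]) torch.Size([2, 16, 4])
```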
@@ -201,13 +204,20 @@ class SparseMemory(nn.Module):
 
-    # update link_matrix and precedence
-    (b, c) = write_weights.size()
-    precedence = hidden['precedence'].gather(1, hidden['read_positions'])
-
+    # update link matrix
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
     hidden['link_matrix'], hidden['rev_link_matrix'] = \
-      self.update_link_matrices(hidden['link_matrix'], hidden['rev_link_matrix'], write_weights, precedence)
-    precedence = self.update_precedence(precedence, write_weights)
-    hidden['precedence'].scatter_(1, hidden['read_positions'], precedence)
+      self.update_link_matrices(
+        hidden['link_matrix'],
+        hidden['rev_link_matrix'],
+        hidden['write_weights'],
+        hidden['precedence'],
+        temporal_read_positions
+      )
+
+    # update precedence vector
+    hidden['precedence'] = self.update_precedence(hidden['precedence'], hidden['read_weights'])
 
     return hidden
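update_precedence itself is outside this hunk; the sketch below assumes the standard DNC precedence rule, p ← (1 − Σᵢ wᵢ) p + w, which may differ in detail from the repo's own implementation:

```python
import torch as T

def update_precedence(precedence, write_weights):
    # Assumed standard DNC rule: p <- (1 - sum_i w_i) * p + w.
    return (1 - write_weights.sum(dim=-1, keepdim=True)) * precedence + write_weights

b, m = 2, 16
p = update_precedence(T.zeros(b, m), T.rand(b, m) * 0.1)
print(p.shape)  # torch.Size([2, 16])
```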
@@ -232,10 +242,9 @@ class SparseMemory(nn.Module):
 
     return usage, I
 
-  def directional_weightings(self, link_matrix, rev_link_matrix, read_weights):
-
-    f = T.bmm(link_matrix, read_weights.unsqueeze(2)).squeeze()
-    b = T.bmm(read_weights.unsqueeze(1), rev_link_matrix).squeeze()
+  def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
+    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
+    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
     return f, b
 
   def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
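With the sparse link matrices, both directions now use the same bmm pattern: (b, m, KL*2) against (b, KL*2, 1) yields a (b, m) weighting over all memory slots. A quick shape check with made-up sizes:

```python
import torch as T

b, m, KL = 2, 16, 2
link_matrix = T.rand(b, m, KL * 2)
rev_link_matrix = T.rand(b, m, KL * 2)
temporal_read_weights = T.rand(b, KL * 2)

# (b, m, KL*2) @ (b, KL*2, 1) -> (b, m, 1) -> squeeze -> (b, m)
f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
bk = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
print(f.shape, bk.shape)  # torch.Size([2, 16]) torch.Size([2, 16])
```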
@@ -258,14 +267,12 @@ class SparseMemory(nn.Module):
     # TODO: this results in duplicate reads when the content based positions and temporal ones are same
     (b, m, w) = memory.size()
 
     # get the top KL entries
     max_length = int(least_used_mem[0, 0].data.cpu().numpy())
     _, fp = T.topk(forward, self.KL, largest=True)
     _, bp = T.topk(backward, self.KL, largest=True)
-    # get read positions for those entries
-    fpos = prev_read_positions.gather(1, fp)
-    bpos = prev_read_positions.gather(1, bp)
 
     # append forward and backward read positions, might lead to duplicates
-    read_positions = T.cat([read_positions, fpos, bpos], 1)
+    read_positions = T.cat([read_positions, fp, bp], 1)
     read_positions = T.cat([read_positions, least_used_mem], 1)
 
     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, self.c, w))
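Because fp and bp now index memory directly (the sparse link matrices already span all m slots), the concatenation can repeat a slot already chosen by content-based addressing, which is exactly the TODO above. A tiny made-up illustration:

```python
import torch as T

content_positions = T.tensor([[3, 7]])  # made-up content-based reads
fp = T.tensor([[7, 1]])                 # forward top-KL positions
bp = T.tensor([[0, 3]])                 # backward top-KL positions
least_used_mem = T.tensor([[5]])

read_positions = T.cat([content_positions, fp, bp], 1)
read_positions = T.cat([read_positions, least_used_mem], 1)
print(read_positions)  # tensor([[3, 7, 7, 1, 0, 3, 5]]): slots 3 and 7 repeat
```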
@@ -278,7 +285,8 @@ class SparseMemory(nn.Module):
 
   def read(self, read_query, hidden):
     # get forward and backward weights
-    read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
+    temporal_read_positions = hidden['read_positions'][:, self.read_heads*self.K+1:]
+    read_weights = hidden['read_weights'].gather(1, temporal_read_positions)
     forward, backward = self.directional_weightings(hidden['link_matrix'], hidden['rev_link_matrix'], read_weights)
 
     # sparse read
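Given the layout read_from_sparse_memory builds (r*K content positions, then KL forward, KL backward, and one least-used slot), the slice in read() picks the trailing c - (r*K + 1) = KL*2 positions. A shape check with the same illustrative sizes as before:

```python
import torch as T

b, m = 2, 64
r, K, KL = 4, 8, 2           # illustrative read heads, top-K, temporal reads
c = (r * K) + (KL * 2) + 1   # 37 visible positions

read_positions = T.randint(0, m, (b, c))
full_read_weights = T.rand(b, m)

# Trailing slice as written in the diff: c - (r*K + 1) = KL*2 positions.
temporal_read_positions = read_positions[:, r * K + 1:]
read_weights = full_read_weights.gather(1, temporal_read_positions)
print(temporal_read_positions.shape, read_weights.shape)
# torch.Size([2, 4]) torch.Size([2, 4])
```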
@@ -147,7 +147,7 @@ if __name__ == '__main__':
       temporal_reads=args.temporal_reads,
       read_heads=args.read_heads,
       gpu_id=args.cuda,
-      debug=True,
+      debug=False,
       batch_first=True,
       independent_linears=False
   )
@@ -209,7 +209,7 @@ if __name__ == '__main__':
 
       last_save_losses.append(loss_value)
 
-      if summarize and rnn.debug:
+      if summarize:
         loss = np.mean(last_save_losses)
         # print(input_data)
         # print("1111111111111111111111111111111111111111111111")
@@ -229,16 +229,17 @@ if __name__ == '__main__':
         # print(F.relu6(output))
         last_save_losses = []
 
-        viz.heatmap(
-          v['memory'],
-          opts=dict(
-            xtickstep=10,
-            ytickstep=2,
-            title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
-            ylabel='layer * time',
-            xlabel='mem_slot * mem_size'
-          )
-        )
-        if args.memory_type == 'dnc':
+        if args.memory_type == 'dnc':
+          viz.heatmap(
+            v['memory'],
+            opts=dict(
+              xtickstep=10,
+              ytickstep=2,
+              title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
+              ylabel='layer * time',
+              xlabel='mem_slot * mem_size'
+            )
+          )
+
           viz.heatmap(