From 146916d6ff45f429443c801a883e44303cb7a65c Mon Sep 17 00:00:00 2001
From: ixaxaar
Date: Sat, 9 Dec 2017 18:28:59 +0530
Subject: [PATCH] Get debugging working and initialize properly

---
 dnc/indexes.py       |   2 +-
 dnc/sdnc.py          |  31 +++++-----
 dnc/sparse_memory.py |  13 +++--
 tasks/copy_task.py   | 135 ++++++++++++++++++++++++-------------------
 4 files changed, 99 insertions(+), 82 deletions(-)

diff --git a/dnc/indexes.py b/dnc/indexes.py
index 487b6cf..55c3679 100644
--- a/dnc/indexes.py
+++ b/dnc/indexes.py
@@ -25,7 +25,7 @@ class Index(object):
       self.res.initializeForDevice(self.gpu_id)
 
     nr_samples = self.nr_cells * 100 * self.cell_size
-    train = train if train is not None else T.arange(-nr_samples, nr_samples, 2).view(self.nr_cells * 100, self.cell_size) / (nr_samples/10)
+    train = train if train is not None else T.randn(self.nr_cells * 100, self.cell_size) * 10
     # train = T.randn(self.nr_cells * 100, self.cell_size)
 
     self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
diff --git a/dnc/sdnc.py b/dnc/sdnc.py
index 8ea2b57..ad34f90 100644
--- a/dnc/sdnc.py
+++ b/dnc/sdnc.py
@@ -158,19 +158,24 @@ class SDNC(nn.Module):
     if not debug_obj:
       debug_obj = {
         'memory': [],
-        'link_matrix': [],
-        'precedence': [],
+        'visible_memory': [],
         'read_weights': [],
         'write_weights': [],
-        'usage_vector': [],
+        'read_vectors': [],
+        'last_used_mem': [],
+        'usage': [],
+        'read_positions': []
       }
 
-    # debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
-    # debug_obj['link_matrix'].append(mhx['link_matrix'][0][0].data.cpu().numpy())
-    # debug_obj['precedence'].append(mhx['precedence'][0].data.cpu().numpy())
-    # debug_obj['read_weights'].append(mhx['read_weights'][0].data.cpu().numpy())
-    # debug_obj['write_weights'].append(mhx['write_weights'][0].data.cpu().numpy())
-    # debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
+    debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
+    debug_obj['visible_memory'].append(mhx['visible_memory'][0].data.cpu().numpy())
+    debug_obj['read_weights'].append(mhx['read_weights'][0].unsqueeze(0).data.cpu().numpy())
+    debug_obj['write_weights'].append(mhx['write_weights'][0].unsqueeze(0).data.cpu().numpy())
+    debug_obj['read_vectors'].append(mhx['read_vectors'][0].data.cpu().numpy())
+    debug_obj['last_used_mem'].append(mhx['last_used_mem'][0].unsqueeze(0).data.cpu().numpy())
+    debug_obj['usage'].append(mhx['usage'][0].unsqueeze(0).data.cpu().numpy())
+    debug_obj['read_positions'].append(mhx['read_positions'][0].unsqueeze(0).data.cpu().numpy())
+
     return debug_obj
 
   def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
@@ -259,16 +264,14 @@ class SDNC(nn.Module):
         outs[time] = T.cat([outs[time], last_read], 1)
         inputs[time] = outs[time]
 
-    # if self.debug:
-    #   viz = {k: np.array(v) for k, v in viz.items()}
-    #   viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items()}
+    if self.debug:
+      viz = {k: np.array(v) for k, v in viz.items()}
+      viz = {k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k, v in viz.items()}
 
     # pass through final output layer
     inputs = [self.output(i) for i in inputs]
 
     outputs = T.stack(inputs, 1 if self.batch_first else 0)
 
-    # outputs.register_hook(lambda x: print("========================================", x.squeeze()))
-
     if is_packed:
       outputs = pack(output, lengths)
diff --git a/dnc/sparse_memory.py b/dnc/sparse_memory.py
index 396e126..a18f6f4 100644
--- a/dnc/sparse_memory.py
+++ b/dnc/sparse_memory.py
@@ -103,7 +103,7 @@ class SparseMemory(nn.Module):
         'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
-        'last_used_mem': cuda(T.zeros(b, 1).fill_(δ), gpu_id=self.gpu_id).long(),
+        'last_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
        'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
        'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
      }
@@ -125,9 +125,10 @@ class SparseMemory(nn.Module):
       hidden['read_weights'].data.fill_(δ)
       hidden['write_weights'].data.fill_(δ)
       hidden['read_vectors'].data.fill_(δ)
-      hidden['last_used_mem'].data.fill_(0)
+      hidden['last_used_mem'].data.fill_(c+1+self.timestep)
       hidden['usage'].data.fill_(δ)
-      hidden['read_positions'] = cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
+      hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
+
     return hidden
 
   def write_into_sparse_memory(self, hidden):
@@ -137,7 +138,6 @@ class SparseMemory(nn.Module):
     (b, m, w) = hidden['memory'].size()
     # update memory
     hidden['memory'].scatter_(1, positions.unsqueeze(2).expand(b, self.read_heads*self.K+1, w), visible_memory)
-    print(positions)
 
     # non-differentiable operations
     pos = positions.data.cpu().numpy()
@@ -145,7 +145,8 @@ class SparseMemory(nn.Module):
       # update indexes
       hidden['indexes'][batch].reset()
       hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
-      hidden['last_used_mem'][batch] = (int(pos[batch][-1]) + 1) if (pos[batch][-1] + 1) < self.mem_size else 0
+
+    hidden['last_used_mem'] = hidden['last_used_mem'] + 1 if self.timestep < self.mem_size else hidden['last_used_mem'] * 0
 
     return hidden
 
@@ -216,7 +217,7 @@ class SparseMemory(nn.Module):
     (b, m, w) = memory.size()
 
     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))
-    read_weights = F.softmax(θ(visible_memory, keys), dim=2)
+    read_weights = σ(θ(visible_memory, keys), 2)
     read_vectors = T.bmm(read_weights, visible_memory)
 
     read_weights = T.prod(read_weights, 1)
diff --git a/tasks/copy_task.py b/tasks/copy_task.py
index 3466312..15cced1 100644
--- a/tasks/copy_task.py
+++ b/tasks/copy_task.py
@@ -143,7 +143,7 @@ if __name__ == '__main__':
       sparse_reads=args.sparse_reads,
       read_heads=args.read_heads,
       gpu_id=args.cuda,
-      debug=False,
+      debug=True,
       batch_first=True,
       independent_linears=False
     )
@@ -219,71 +219,84 @@ if __name__ == '__main__':
       # print(F.relu6(output))
       last_save_losses = []
 
-      # viz.heatmap(
-      #   v['memory'],
-      #   opts=dict(
-      #     xtickstep=10,
-      #     ytickstep=2,
-      #     title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
-      #     ylabel='layer * time',
-      #     xlabel='mem_slot * mem_size'
-      #   )
-      # )
+      viz.heatmap(
+        v['memory'],
+        opts=dict(
+          xtickstep=10,
+          ytickstep=2,
+          title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
+          ylabel='layer * time',
+          xlabel='mem_slot * mem_size'
+        )
+      )
 
-      # viz.heatmap(
-      #   v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
-      #   opts=dict(
-      #     xtickstep=10,
-      #     ytickstep=2,
-      #     title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
-      #     ylabel='mem_slot',
-      #     xlabel='mem_slot'
-      #   )
-      # )
+      if args.memory_type == 'DNC':
+        viz.heatmap(
+          v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
+          opts=dict(
+            xtickstep=10,
+            ytickstep=2,
+            title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
+            ylabel='mem_slot',
+            xlabel='mem_slot'
+          )
+        )
 
-      # viz.heatmap(
-      #   v['precedence'],
-      #   opts=dict(
-      #     xtickstep=10,
-      #     ytickstep=2,
-      #     title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
-      #     ylabel='layer * time',
-      #     xlabel='mem_slot'
-      #   )
-      # )
+        viz.heatmap(
+          v['precedence'],
+          opts=dict(
+            xtickstep=10,
+            ytickstep=2,
+            title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
+            ylabel='layer * time',
+            xlabel='mem_slot'
+          )
+        )
 
-      # viz.heatmap(
-      #   v['read_weights'],
-      #   opts=dict(
-      #     xtickstep=10,
-      #     ytickstep=2,
-      #     title='Read Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
-      #     ylabel='layer * time',
-      #     xlabel='nr_read_heads * mem_slot'
-      #   )
-      # )
+      if args.memory_type == 'SDNC':
+        viz.heatmap(
+          v['read_positions'],
+          opts=dict(
+            xtickstep=10,
+            ytickstep=2,
+            title='Read Positions, t: ' + str(epoch) + ', loss: ' + str(loss),
+            ylabel='layer * time',
+            xlabel='mem_slot'
+          )
+        )
 
-      # viz.heatmap(
-      #   v['write_weights'],
-      #   opts=dict(
-      #     xtickstep=10,
-      #     ytickstep=2,
-      #     title='Write Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
-      #     ylabel='layer * time',
-      #     xlabel='mem_slot'
-      #   )
-      # )
+      viz.heatmap(
+        v['read_weights'],
+        opts=dict(
+          xtickstep=10,
+          ytickstep=2,
+          title='Read Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
+          ylabel='layer * time',
+          xlabel='nr_read_heads * mem_slot'
+        )
+      )
 
-      # viz.heatmap(
-      #   v['usage_vector'],
-      #   opts=dict(
-      #     xtickstep=10,
-      #     ytickstep=2,
-      #     title='Usage Vector, t: ' + str(epoch) + ', loss: ' + str(loss),
-      #     ylabel='layer * time',
-      #     xlabel='mem_slot'
-      #   )
-      # )
+      viz.heatmap(
+        v['write_weights'],
+        opts=dict(
+          xtickstep=10,
+          ytickstep=2,
+          title='Write Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
+          ylabel='layer * time',
+          xlabel='mem_slot'
+        )
+      )
+
+      viz.heatmap(
+        v['usage_vector'] if args.memory_type == 'DNC' else v['usage'],
+        opts=dict(
+          xtickstep=10,
+          ytickstep=2,
+          title='Usage Vector, t: ' + str(epoch) + ', loss: ' + str(loss),
+          ylabel='layer * time',
+          xlabel='mem_slot'
+        )
+      )
 
     if take_checkpoint:
       llprint("\nSaving Checkpoint ... "),
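Note: the short sketch below is not part of the patch; it is a minimal, hypothetical smoke test for the debugging path the patch turns on. It assumes the SDNC constructor and forward signature documented in this repository's README, a CPU run (gpu_id=-1) with the FAISS/pyflann dependencies installed, and purely illustrative sizes; the Variable wrapper reflects the PyTorch 0.3-era API this code targets.

    import torch as T
    from torch.autograd import Variable
    from dnc import SDNC

    # Build a small SDNC with debug=True, as tasks/copy_task.py now does.
    rnn = SDNC(
      input_size=6,
      hidden_size=64,
      nr_cells=100,
      cell_size=10,
      sparse_reads=4,
      read_heads=1,
      gpu_id=-1,
      debug=True,
      batch_first=True
    )

    # (batch, time, input_size) because batch_first=True; sizes are illustrative.
    x = Variable(T.randn(1, 10, 6))

    # With debug=True the forward pass also returns the debug dict that
    # SDNC._debug() accumulates per time step and then reshapes to 2-D
    # arrays (time x flattened features), which is the layout viz.heatmap()
    # consumes in the copy task.
    output, (chx, mhx, rv), v = rnn(x, (None, None, None), reset_experience=True)

    for name, arr in v.items():
      print(name, arr.shape)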