diff --git a/tasks/copy_task.py b/tasks/copy_task.py index d75f9f4..6997da7 100644 --- a/tasks/copy_task.py +++ b/tasks/copy_task.py @@ -37,8 +37,8 @@ parser.add_argument('-clip', type=float, default=50, help='gradient clipping') parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size') parser.add_argument('-mem_size', type=int, default=16, help='memory dimension') -parser.add_argument('-mem_slot', type=int, default=10, help='number of memory slots') -parser.add_argument('-read_heads', type=int, default=1, help='number of read heads') +parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots') +parser.add_argument('-read_heads', type=int, default=4, help='number of read heads') parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length') parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU') @@ -121,7 +121,8 @@ if __name__ == '__main__': read_heads=read_heads, gpu_id=args.cuda, debug=True, - batch_first=True + batch_first=True, + independent_linears=True ) print(rnn) @@ -183,13 +184,13 @@ if __name__ == '__main__': ) viz.heatmap( - v['link_matrix'], + v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot), opts=dict( xtickstep=10, ytickstep=2, title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss), - ylabel='layer * time', - xlabel='mem_slot * mem_slot' + ylabel='mem_slot', + xlabel='mem_slot' ) )