diff --git a/tasks/copy_task.py b/tasks/copy_task.py old mode 100644 new mode 100755 index 15cced1..9a93479 --- a/tasks/copy_task.py +++ b/tasks/copy_task.py @@ -31,7 +31,7 @@ parser.add_argument('-input_size', type=int, default=6, help='dimension of input parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller') parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn') parser.add_argument('-dropout', type=float, default=0, help='controller dropout') -parser.add_argument('-memory_type', type=str, default='dense', help='dense or sparse memory') +parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory') parser.add_argument('-nlayer', type=int, default=1, help='number of layers') parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers') @@ -46,6 +46,8 @@ parser.add_argument('-read_heads', type=int, default=4, help='number of read hea parser.add_argument('-sparse_reads', type=int, default=10, help='number of sparse reads per read head') parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length') +parser.add_argument('-curriculum_increment', type=int, default=0, metavar='N', help='sequence_max_length incrementor per 1K iterations') +parser.add_argument('-curriculum_freq', type=int, default=1000, metavar='N', help='sequence_max_length incrementor per 1K iterations') parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU') parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval') @@ -198,6 +200,7 @@ if __name__ == '__main__': summarize = (epoch % summarize_freq == 0) take_checkpoint = (epoch != 0) and (epoch % check_freq == 0) + increment_curriculum = (epoch != 0) and (epoch % args.curriculum_freq == 0) # detach memory from graph mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() } @@ -230,7 +233,7 @@ if __name__ == '__main__': ) ) - if args.memory_type == 'DNC': + if args.memory_type == 'dnc': viz.heatmap( v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot), opts=dict( @@ -253,7 +256,7 @@ if __name__ == '__main__': ) ) - if args.memory_type == 'SDNC': + if args.memory_type == 'sdnc': viz.heatmap( v['read_positions'], opts=dict( @@ -288,7 +291,7 @@ if __name__ == '__main__': ) viz.heatmap( - v['usage_vector'] if args.memory_type == 'DNC' else v['usage'], + v['usage_vector'] if args.memory_type == 'dnc' else v['usage'], opts=dict( xtickstep=10, ytickstep=2, @@ -298,6 +301,10 @@ if __name__ == '__main__': ) ) + if increment_curriculum: + sequence_max_length = sequence_max_length + args.curriculum_increment + print("Increasing max length to " + str(sequence_max_length)) + if take_checkpoint: llprint("\nSaving Checkpoint ... "), check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))