Add curriculum learning for copy task

This commit is contained in:
ixaxaar 2017-12-10 12:06:22 +05:30
parent e4eb9a53e6
commit 106d362e17

15
tasks/copy_task.py Normal file → Executable file
View File

@ -31,7 +31,7 @@ parser.add_argument('-input_size', type=int, default=6, help='dimension of input
parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
parser.add_argument('-memory_type', type=str, default='dense', help='dense or sparse memory')
parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory')
parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
@ -46,6 +46,8 @@ parser.add_argument('-read_heads', type=int, default=4, help='number of read hea
parser.add_argument('-sparse_reads', type=int, default=10, help='number of sparse reads per read head')
parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
parser.add_argument('-curriculum_increment', type=int, default=0, metavar='N', help='sequence_max_length incrementor per 1K iterations')
parser.add_argument('-curriculum_freq', type=int, default=1000, metavar='N', help='sequence_max_length incrementor per 1K iterations')
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval')
@ -198,6 +200,7 @@ if __name__ == '__main__':
summarize = (epoch % summarize_freq == 0)
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
increment_curriculum = (epoch != 0) and (epoch % args.curriculum_freq == 0)
# detach memory from graph
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
@ -230,7 +233,7 @@ if __name__ == '__main__':
)
)
if args.memory_type == 'DNC':
if args.memory_type == 'dnc':
viz.heatmap(
v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
opts=dict(
@ -253,7 +256,7 @@ if __name__ == '__main__':
)
)
if args.memory_type == 'SDNC':
if args.memory_type == 'sdnc':
viz.heatmap(
v['read_positions'],
opts=dict(
@ -288,7 +291,7 @@ if __name__ == '__main__':
)
viz.heatmap(
v['usage_vector'] if args.memory_type == 'DNC' else v['usage'],
v['usage_vector'] if args.memory_type == 'dnc' else v['usage'],
opts=dict(
xtickstep=10,
ytickstep=2,
@ -298,6 +301,10 @@ if __name__ == '__main__':
)
)
if increment_curriculum:
sequence_max_length = sequence_max_length + args.curriculum_increment
print("Increasing max length to " + str(sequence_max_length))
if take_checkpoint:
llprint("\nSaving Checkpoint ... "),
check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))