add curriculum for copy task
This commit is contained in:
parent
e4eb9a53e6
commit
106d362e17
15
tasks/copy_task.py
Normal file → Executable file
15
tasks/copy_task.py
Normal file → Executable file
@ -31,7 +31,7 @@ parser.add_argument('-input_size', type=int, default=6, help='dimension of input
|
||||
parser.add_argument('-rnn_type', type=str, default='lstm', help='type of recurrent cells to use for the controller')
|
||||
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')
|
||||
parser.add_argument('-dropout', type=float, default=0, help='controller dropout')
|
||||
parser.add_argument('-memory_type', type=str, default='dense', help='dense or sparse memory')
|
||||
parser.add_argument('-memory_type', type=str, default='dnc', help='dense or sparse memory')
|
||||
|
||||
parser.add_argument('-nlayer', type=int, default=1, help='number of layers')
|
||||
parser.add_argument('-nhlayer', type=int, default=2, help='number of hidden layers')
|
||||
@ -46,6 +46,8 @@ parser.add_argument('-read_heads', type=int, default=4, help='number of read hea
|
||||
parser.add_argument('-sparse_reads', type=int, default=10, help='number of sparse reads per read head')
|
||||
|
||||
parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
|
||||
parser.add_argument('-curriculum_increment', type=int, default=0, metavar='N', help='sequence_max_length incrementor per 1K iterations')
|
||||
parser.add_argument('-curriculum_freq', type=int, default=1000, metavar='N', help='sequence_max_length incrementor per 1K iterations')
|
||||
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
|
||||
parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval')
|
||||
|
||||
@ -198,6 +200,7 @@ if __name__ == '__main__':
|
||||
|
||||
summarize = (epoch % summarize_freq == 0)
|
||||
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
|
||||
increment_curriculum = (epoch != 0) and (epoch % args.curriculum_freq == 0)
|
||||
|
||||
# detach memory from graph
|
||||
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
||||
@ -230,7 +233,7 @@ if __name__ == '__main__':
|
||||
)
|
||||
)
|
||||
|
||||
if args.memory_type == 'DNC':
|
||||
if args.memory_type == 'dnc':
|
||||
viz.heatmap(
|
||||
v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
|
||||
opts=dict(
|
||||
@ -253,7 +256,7 @@ if __name__ == '__main__':
|
||||
)
|
||||
)
|
||||
|
||||
if args.memory_type == 'SDNC':
|
||||
if args.memory_type == 'sdnc':
|
||||
viz.heatmap(
|
||||
v['read_positions'],
|
||||
opts=dict(
|
||||
@ -288,7 +291,7 @@ if __name__ == '__main__':
|
||||
)
|
||||
|
||||
viz.heatmap(
|
||||
v['usage_vector'] if args.memory_type == 'DNC' else v['usage'],
|
||||
v['usage_vector'] if args.memory_type == 'dnc' else v['usage'],
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
@ -298,6 +301,10 @@ if __name__ == '__main__':
|
||||
)
|
||||
)
|
||||
|
||||
if increment_curriculum:
|
||||
sequence_max_length = sequence_max_length + args.curriculum_increment
|
||||
print("Increasing max length to " + str(sequence_max_length))
|
||||
|
||||
if take_checkpoint:
|
||||
llprint("\nSaving Checkpoint ... "),
|
||||
check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))
|
||||
|
Loading…
Reference in New Issue
Block a user