Reseting memory is cheaper than recreating it

This commit is contained in:
ixaxaar 2017-12-04 21:11:05 +05:30
parent 2026a8939d
commit 64520e1dcf

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import warnings
warnings.filterwarnings('ignore')
@ -172,6 +172,7 @@ if __name__ == '__main__':
optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
(chx, mhx, rv) = (None, None, None)
for epoch in range(iterations + 1):
llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
optimizer.zero_grad()
@ -181,12 +182,10 @@ if __name__ == '__main__':
input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
if rnn.debug:
output, (chx, mhx, rv), v = rnn(input_data, None, pass_through_memory=True)
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
else:
output, (chx, mhx, rv) = rnn(input_data, None, pass_through_memory=True)
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
# print(output)
# print("-----------------------------------------------------------------------------------------------------")
loss = criterion((output), target_output)
loss.backward()
@ -198,6 +197,9 @@ if __name__ == '__main__':
summarize = (epoch % summarize_freq == 0)
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
# detach memory from graph
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
last_save_losses.append(loss_value)
if summarize:
@ -213,74 +215,73 @@ if __name__ == '__main__':
# print(target_output)
# print('2222222222222222222222222222222222222222222222')
# print(F.relu6(output))
llprint("\n\tAvg. Logistic Loss: %.4f\n" % (loss))
last_save_losses = []
viz.heatmap(
v['memory'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot * mem_size'
)
)
# viz.heatmap(
# v['memory'],
# opts=dict(
# xtickstep=10,
# ytickstep=2,
# title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
# ylabel='layer * time',
# xlabel='mem_slot * mem_size'
# )
# )
viz.heatmap(
v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
opts=dict(
xtickstep=10,
ytickstep=2,
title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='mem_slot',
xlabel='mem_slot'
)
)
# viz.heatmap(
# v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
# opts=dict(
# xtickstep=10,
# ytickstep=2,
# title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
# ylabel='mem_slot',
# xlabel='mem_slot'
# )
# )
viz.heatmap(
v['precedence'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot'
)
)
# viz.heatmap(
# v['precedence'],
# opts=dict(
# xtickstep=10,
# ytickstep=2,
# title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
# ylabel='layer * time',
# xlabel='mem_slot'
# )
# )
viz.heatmap(
v['read_weights'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Read Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='nr_read_heads * mem_slot'
)
)
# viz.heatmap(
# v['read_weights'],
# opts=dict(
# xtickstep=10,
# ytickstep=2,
# title='Read Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
# ylabel='layer * time',
# xlabel='nr_read_heads * mem_slot'
# )
# )
viz.heatmap(
v['write_weights'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Write Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot'
)
)
# viz.heatmap(
# v['write_weights'],
# opts=dict(
# xtickstep=10,
# ytickstep=2,
# title='Write Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
# ylabel='layer * time',
# xlabel='mem_slot'
# )
# )
viz.heatmap(
v['usage_vector'],
opts=dict(
xtickstep=10,
ytickstep=2,
title='Usage Vector, t: ' + str(epoch) + ', loss: ' + str(loss),
ylabel='layer * time',
xlabel='mem_slot'
)
)
# viz.heatmap(
# v['usage_vector'],
# opts=dict(
# xtickstep=10,
# ytickstep=2,
# title='Usage Vector, t: ' + str(epoch) + ', loss: ' + str(loss),
# ylabel='layer * time',
# xlabel='mem_slot'
# )
# )
if take_checkpoint:
llprint("\nSaving Checkpoint ... "),