Resetting memory is cheaper than recreating it
This commit is contained in:
parent
2026a8939d
commit
64520e1dcf
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
@ -172,6 +172,7 @@ if __name__ == '__main__':
|
||||
optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
|
||||
|
||||
|
||||
(chx, mhx, rv) = (None, None, None)
|
||||
for epoch in range(iterations + 1):
|
||||
llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
|
||||
optimizer.zero_grad()
|
||||
@ -181,12 +182,10 @@ if __name__ == '__main__':
|
||||
input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
|
||||
|
||||
if rnn.debug:
|
||||
output, (chx, mhx, rv), v = rnn(input_data, None, pass_through_memory=True)
|
||||
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||
else:
|
||||
output, (chx, mhx, rv) = rnn(input_data, None, pass_through_memory=True)
|
||||
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||
|
||||
# print(output)
|
||||
# print("-----------------------------------------------------------------------------------------------------")
|
||||
loss = criterion((output), target_output)
|
||||
|
||||
loss.backward()
|
||||
@ -198,6 +197,9 @@ if __name__ == '__main__':
|
||||
summarize = (epoch % summarize_freq == 0)
|
||||
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
|
||||
|
||||
# detach memory from graph
|
||||
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
||||
|
||||
last_save_losses.append(loss_value)
|
||||
|
||||
if summarize:
|
||||
@ -213,74 +215,73 @@ if __name__ == '__main__':
|
||||
# print(target_output)
|
||||
# print('2222222222222222222222222222222222222222222222')
|
||||
# print(F.relu6(output))
|
||||
llprint("\n\tAvg. Logistic Loss: %.4f\n" % (loss))
|
||||
last_save_losses = []
|
||||
|
||||
viz.heatmap(
|
||||
v['memory'],
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
ylabel='layer * time',
|
||||
xlabel='mem_slot * mem_size'
|
||||
)
|
||||
)
|
||||
# viz.heatmap(
|
||||
# v['memory'],
|
||||
# opts=dict(
|
||||
# xtickstep=10,
|
||||
# ytickstep=2,
|
||||
# title='Memory, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
# ylabel='layer * time',
|
||||
# xlabel='mem_slot * mem_size'
|
||||
# )
|
||||
# )
|
||||
|
||||
viz.heatmap(
|
||||
v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
ylabel='mem_slot',
|
||||
xlabel='mem_slot'
|
||||
)
|
||||
)
|
||||
# viz.heatmap(
|
||||
# v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
|
||||
# opts=dict(
|
||||
# xtickstep=10,
|
||||
# ytickstep=2,
|
||||
# title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
# ylabel='mem_slot',
|
||||
# xlabel='mem_slot'
|
||||
# )
|
||||
# )
|
||||
|
||||
viz.heatmap(
|
||||
v['precedence'],
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
ylabel='layer * time',
|
||||
xlabel='mem_slot'
|
||||
)
|
||||
)
|
||||
# viz.heatmap(
|
||||
# v['precedence'],
|
||||
# opts=dict(
|
||||
# xtickstep=10,
|
||||
# ytickstep=2,
|
||||
# title='Precedence, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
# ylabel='layer * time',
|
||||
# xlabel='mem_slot'
|
||||
# )
|
||||
# )
|
||||
|
||||
viz.heatmap(
|
||||
v['read_weights'],
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
title='Read Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
ylabel='layer * time',
|
||||
xlabel='nr_read_heads * mem_slot'
|
||||
)
|
||||
)
|
||||
# viz.heatmap(
|
||||
# v['read_weights'],
|
||||
# opts=dict(
|
||||
# xtickstep=10,
|
||||
# ytickstep=2,
|
||||
# title='Read Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
# ylabel='layer * time',
|
||||
# xlabel='nr_read_heads * mem_slot'
|
||||
# )
|
||||
# )
|
||||
|
||||
viz.heatmap(
|
||||
v['write_weights'],
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
title='Write Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
ylabel='layer * time',
|
||||
xlabel='mem_slot'
|
||||
)
|
||||
)
|
||||
# viz.heatmap(
|
||||
# v['write_weights'],
|
||||
# opts=dict(
|
||||
# xtickstep=10,
|
||||
# ytickstep=2,
|
||||
# title='Write Weights, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
# ylabel='layer * time',
|
||||
# xlabel='mem_slot'
|
||||
# )
|
||||
# )
|
||||
|
||||
viz.heatmap(
|
||||
v['usage_vector'],
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
title='Usage Vector, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
ylabel='layer * time',
|
||||
xlabel='mem_slot'
|
||||
)
|
||||
)
|
||||
# viz.heatmap(
|
||||
# v['usage_vector'],
|
||||
# opts=dict(
|
||||
# xtickstep=10,
|
||||
# ytickstep=2,
|
||||
# title='Usage Vector, t: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
# ylabel='layer * time',
|
||||
# xlabel='mem_slot'
|
||||
# )
|
||||
# )
|
||||
|
||||
if take_checkpoint:
|
||||
llprint("\nSaving Checkpoint ... "),
|
||||
|
Loading…
Reference in New Issue
Block a user