Compare commits

...

1 Commits
master ... 45

Author SHA1 Message Date
Russi Chatterjee
2d35eab3a6 fix copy task generalization code 2019-08-16 10:01:11 +05:30

View File

@ -363,18 +363,18 @@ if __name__ == '__main__':
llprint("\nIteration %d/%d" % (i, iterations))
# We test now the learned generalization using sequence_max_length examples
random_length = np.random.randint(2, sequence_max_length * 10 + 1)
input_data, target_output, loss_weights = generate_data(random_length, input_size)
input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
if rnn.debug:
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
else:
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
output = output[:, -1, :].sum().data.cpu().numpy()[0]
output = output[:, -1, :].sum().data.cpu().numpy()
target_output = target_output.sum().data.cpu().numpy()
try:
print("\nReal value: ", ' = ' + str(int(target_output[0])))
print("\nReal value: ", ' = ' + str(int(target_output)))
print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
except Exception as e:
pass