Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
2d35eab3a6 |
@ -363,18 +363,18 @@ if __name__ == '__main__':
|
||||
llprint("\nIteration %d/%d" % (i, iterations))
|
||||
# We test now the learned generalization using sequence_max_length examples
|
||||
random_length = np.random.randint(2, sequence_max_length * 10 + 1)
|
||||
input_data, target_output, loss_weights = generate_data(random_length, input_size)
|
||||
input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
|
||||
|
||||
if rnn.debug:
|
||||
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||
else:
|
||||
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||
|
||||
output = output[:, -1, :].sum().data.cpu().numpy()[0]
|
||||
output = output[:, -1, :].sum().data.cpu().numpy()
|
||||
target_output = target_output.sum().data.cpu().numpy()
|
||||
|
||||
try:
|
||||
print("\nReal value: ", ' = ' + str(int(target_output[0])))
|
||||
print("\nReal value: ", ' = ' + str(int(target_output)))
|
||||
print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user