Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
2d35eab3a6 |
@ -363,18 +363,18 @@ if __name__ == '__main__':
|
|||||||
llprint("\nIteration %d/%d" % (i, iterations))
|
llprint("\nIteration %d/%d" % (i, iterations))
|
||||||
# We test now the learned generalization using sequence_max_length examples
|
# We test now the learned generalization using sequence_max_length examples
|
||||||
random_length = np.random.randint(2, sequence_max_length * 10 + 1)
|
random_length = np.random.randint(2, sequence_max_length * 10 + 1)
|
||||||
input_data, target_output, loss_weights = generate_data(random_length, input_size)
|
input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
|
||||||
|
|
||||||
if rnn.debug:
|
if rnn.debug:
|
||||||
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
output, (chx, mhx, rv), v = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
else:
|
else:
|
||||||
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
|
|
||||||
output = output[:, -1, :].sum().data.cpu().numpy()[0]
|
output = output[:, -1, :].sum().data.cpu().numpy()
|
||||||
target_output = target_output.sum().data.cpu().numpy()
|
target_output = target_output.sum().data.cpu().numpy()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("\nReal value: ", ' = ' + str(int(target_output[0])))
|
print("\nReal value: ", ' = ' + str(int(target_output)))
|
||||||
print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
|
print("Predicted: ", ' = ' + str(int(output // 1)) + " [" + str(output) + "]")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user