This commit is contained in:
Russi Chatterjee 2019-07-25 11:27:54 +05:30
parent 47140303e9
commit 79dc405f37
3 changed files with 6 additions and 4 deletions

View File

@ -99,7 +99,9 @@ def generate_data(length, size):
sums = np.array(sums)
sums = sums.reshape(1, 1, 1)
return cudavec(x_seq_list, gpu_id=args.cuda).float(), cudavec(sums, gpu_id=args.cuda).float(), sums_text
return cudavec(x_seq_list.astype(np.float32), gpu_id=args.cuda).float(), \
cudavec(sums.astype(np.float32), gpu_id=args.cuda).float(), \
sums_text
def cross_entropy(prediction, target):
@ -221,7 +223,7 @@ if __name__ == '__main__':
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
optimizer.step()
loss_value = loss.data[0]
loss_value = loss.item()
# detach memory from graph
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }

View File

@ -227,7 +227,7 @@ if __name__ == '__main__':
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
optimizer.step()
loss_value = loss.data[0]
loss_value = loss.item()
# detach memory from graph
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }

View File

@ -214,7 +214,7 @@ if __name__ == '__main__':
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
optimizer.step()
loss_value = loss.data[0]
loss_value = loss.item()
summarize = (epoch % summarize_freq == 0)
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)