fix #45
This commit is contained in:
parent
47140303e9
commit
79dc405f37
@ -99,7 +99,9 @@ def generate_data(length, size):
|
||||
sums = np.array(sums)
|
||||
sums = sums.reshape(1, 1, 1)
|
||||
|
||||
return cudavec(x_seq_list, gpu_id=args.cuda).float(), cudavec(sums, gpu_id=args.cuda).float(), sums_text
|
||||
return cudavec(x_seq_list.astype(np.float32), gpu_id=args.cuda).float(), \
|
||||
cudavec(sums.astype(np.float32), gpu_id=args.cuda).float(), \
|
||||
sums_text
|
||||
|
||||
|
||||
def cross_entropy(prediction, target):
|
||||
@ -221,7 +223,7 @@ if __name__ == '__main__':
|
||||
|
||||
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
||||
optimizer.step()
|
||||
loss_value = loss.data[0]
|
||||
loss_value = loss.item()
|
||||
|
||||
# detach memory from graph
|
||||
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
||||
|
@ -227,7 +227,7 @@ if __name__ == '__main__':
|
||||
|
||||
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
||||
optimizer.step()
|
||||
loss_value = loss.data[0]
|
||||
loss_value = loss.item()
|
||||
|
||||
# detach memory from graph
|
||||
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
||||
|
@ -214,7 +214,7 @@ if __name__ == '__main__':
|
||||
|
||||
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
||||
optimizer.step()
|
||||
loss_value = loss.data[0]
|
||||
loss_value = loss.item()
|
||||
|
||||
summarize = (epoch % summarize_freq == 0)
|
||||
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
|
||||
|
Loading…
Reference in New Issue
Block a user