fix #45
This commit is contained in:
parent
47140303e9
commit
79dc405f37
@ -99,7 +99,9 @@ def generate_data(length, size):
|
|||||||
sums = np.array(sums)
|
sums = np.array(sums)
|
||||||
sums = sums.reshape(1, 1, 1)
|
sums = sums.reshape(1, 1, 1)
|
||||||
|
|
||||||
return cudavec(x_seq_list, gpu_id=args.cuda).float(), cudavec(sums, gpu_id=args.cuda).float(), sums_text
|
return cudavec(x_seq_list.astype(np.float32), gpu_id=args.cuda).float(), \
|
||||||
|
cudavec(sums.astype(np.float32), gpu_id=args.cuda).float(), \
|
||||||
|
sums_text
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy(prediction, target):
|
def cross_entropy(prediction, target):
|
||||||
@ -221,7 +223,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
loss_value = loss.data[0]
|
loss_value = loss.item()
|
||||||
|
|
||||||
# detach memory from graph
|
# detach memory from graph
|
||||||
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
||||||
|
@ -227,7 +227,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
loss_value = loss.data[0]
|
loss_value = loss.item()
|
||||||
|
|
||||||
# detach memory from graph
|
# detach memory from graph
|
||||||
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
mhx = { k : (v.detach() if isinstance(v, var) else v) for k, v in mhx.items() }
|
||||||
|
@ -214,7 +214,7 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
T.nn.utils.clip_grad_norm_(rnn.parameters(), args.clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
loss_value = loss.data[0]
|
loss_value = loss.item()
|
||||||
|
|
||||||
summarize = (epoch % summarize_freq == 0)
|
summarize = (epoch % summarize_freq == 0)
|
||||||
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
|
take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)
|
||||||
|
Loading…
Reference in New Issue
Block a user