dnc_pytorch/tests/util.py
2022-08-23 21:58:43 +09:00

10 lines
223 B
Python

import torch
import torch.nn.functional as F
def one_hot(length, index, dtype=None):
    """Return a one-hot encoding of ``index`` over ``length`` classes.

    Args:
        length: Number of classes; forwarded to ``F.one_hot`` as
            ``num_classes``.
        index: An integer class index, a (possibly nested) sequence of
            integer indices, or an integer tensor / array of indices.
            Existing tensors are used directly rather than copy-constructed.
        dtype: Optional dtype for the result. When ``None`` the result keeps
            the dtype produced by ``F.one_hot`` (``torch.int64``).

    Returns:
        A tensor whose shape is the shape of ``index`` with a trailing
        dimension of size ``length`` appended.
    """
    # as_tensor avoids the copy-construct warning/copy that torch.tensor()
    # triggers when ``index`` is already a tensor.
    idx = torch.as_tensor(index)
    # F.one_hot only accepts int64 index tensors; widen narrower integer
    # dtypes (e.g. int32 from numpy) but leave floats/bools alone so they
    # still raise as before.
    if idx.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
        idx = idx.long()
    val = F.one_hot(idx, num_classes=length)
    if dtype is not None:
        val = val.to(dtype=dtype)
    return val