Prepare for API deprecation

This commit is contained in:
ixaxaar 2017-12-18 12:29:02 +05:30
parent 973b51b36a
commit 264bdfb2f0

View File

@ -89,7 +89,10 @@ def σ(input, axis=1):
trans_size = trans_input.size()
input_2d = trans_input.contiguous().view(-1, trans_size[-1])
soft_max_2d = F.softmax(input_2d, -1)
if '0.3' in T.__version__:
soft_max_2d = F.softmax(input_2d, -1)
else:
soft_max_2d = F.softmax(input_2d)
soft_max_nd = soft_max_2d.view(*trans_size)
return soft_max_nd.transpose(axis, len(input_size) - 1)