133 lines
3.9 KiB
Python
133 lines
3.9 KiB
Python
|
# Copyright 2017 Robert Csordas. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
#
|
||
|
# ==============================================================================
|
||
|
|
||
|
import numpy as np
|
||
|
import sys
|
||
|
import traceback
|
||
|
import torch
|
||
|
|
||
|
# Module-wide switch for the debugging helpers in this file. While it is False,
# nan_check, assert_range, assert_dist, print_stat and dbg_print return
# immediately; nan_check and assert_equal can still be run by passing force=True.
enableDebug = False
|
||
|
|
||
|
def _sum_is_finite(tensor):
    """Return True when the sum over all elements of `tensor` is neither NaN nor infinite."""
    return bool(np.isfinite(tensor.sum().cpu().data.numpy()))


def nan_check(arg, name=None, force=False):
    """Abort the process if `arg` contains a non-finite value.

    The check runs only when the module-level `enableDebug` flag is set or
    `force` is True; otherwise `arg` is returned untouched.

    Args:
        arg: a torch tensor/Variable, a torch Parameter (its gradient is
            checked as well when one is present), a float, or an arbitrarily
            nested list/tuple of these.
        name: optional label used in the diagnostic message.
        force: run the check even when `enableDebug` is False.

    Returns:
        `arg` itself, so the call can be inserted inline in an expression.

    Raises:
        AssertionError: if `arg` (or a nested element) has an unsupported type.
        SystemExit: via sys.exit(-1) after printing diagnostics when a
            non-finite value is found.
    """
    # `force` is tested first so a forced check does not depend on the flag.
    if not (force or enableDebug):
        return arg

    if isinstance(arg, (list, tuple)):
        # Propagate `force` and `name`; otherwise nested elements would be
        # skipped whenever the check is forced while enableDebug is off.
        for item in arg:
            nan_check(item, name=name, force=force)
        return arg

    # Parameter is a Tensor subclass, so it has to be tested before the
    # generic tensor branch or its gradient would never be inspected.
    if isinstance(arg, torch.nn.parameter.Parameter):
        found = not _sum_is_finite(arg)
        # A parameter that has not been through backward() has no gradient.
        if not found and arg.grad is not None:
            found = not _sum_is_finite(arg.grad)
    elif isinstance(arg, torch.autograd.Variable):
        found = not _sum_is_finite(arg)
    elif isinstance(arg, float):
        found = not np.isfinite(arg)
    else:
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError("Unsupported type %s" % type(arg))

    if not found:
        return arg

    # Prefer the traceback of an exception currently being handled; otherwise
    # report the call stack that led here.
    if sys.exc_info()[0] is not None:
        trace = str(traceback.format_exc())
    else:
        trace = "".join(traceback.format_stack())

    print(arg)
    if name is not None:
        print("NaN found in %s." % name)
    else:
        print("NaN found.")
    if isinstance(arg, torch.autograd.Variable):
        print("   Argument is a torch tensor. Shape: %s" % list(arg.size()))

    print(trace)
    sys.exit(-1)
|
||
|
|
||
|
|
||
|
def assert_range(t, min=0.0, max=1.0):
    """Check that every element of `t` lies within [min, max].

    Does nothing unless the module-level `enableDebug` flag is set.

    Args:
        t: tensor to inspect (its global minimum and maximum are compared).
        min: smallest allowed value, inclusive.
        max: largest allowed value, inclusive.

    Raises:
        AssertionError: if an element falls outside the range; the tensor is
            printed first.
    """
    if not enableDebug:
        return

    lowest = t.min().cpu().data.numpy()
    highest = t.max().cpu().data.numpy()
    if lowest < min or highest > max:
        print(t)
        # Raised explicitly (instead of `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(
            "Tensor values outside [%s, %s]: min %s, max %s" % (min, max, lowest, highest)
        )
|
||
|
|
||
|
|
||
|
def assert_dist(t, use_lower_limit=True):
    """Check that the last dimension of `t` looks like a probability distribution.

    Does nothing unless the module-level `enableDebug` flag is set. Elements
    are range-checked with assert_range (default bounds), then the sums over
    the last dimension are checked against a 0.001 tolerance around 1.

    Args:
        t: tensor whose last dimension is summed.
        use_lower_limit: when True, every sum must also be at least 0.999;
            when False only the upper bound (1.001) is enforced.

    Raises:
        AssertionError: if a bound is violated; the tensor and the sums are
            printed first.
    """
    if not enableDebug:
        return

    assert_range(t)

    sums = t.sum(-1)

    if sums.max().cpu().data.numpy() > 1.001:
        print("MAT:", t)
        print("SUM:", sums)
        raise AssertionError("Sum over the last dimension exceeds 1.001")

    # The smallest sum decides the lower bound: testing the largest one would
    # only fire when *every* slice is below the limit.
    if use_lower_limit and sums.min().cpu().data.numpy() < 0.999:
        print(t)
        print("SUM:", sums)
        raise AssertionError("Sum over the last dimension is below 0.999")
|
||
|
|
||
|
|
||
|
def print_stat(name, t):
    """Print the minimum, mean and maximum of tensor `t`, labelled with `name`.

    Silent unless the module-level `enableDebug` flag is set.
    """
    if enableDebug:
        # Each reduction is moved to the CPU and converted to numpy for %g formatting.
        lowest, average, highest = (
            reduce().cpu().data.numpy() for reduce in (t.min, t.mean, t.max)
        )
        print("%s: min: %g, mean: %g, max: %g" % (name, lowest, average, highest))
|
||
|
|
||
|
|
||
|
def dbg_print(*things):
    """Forward all arguments to print() when `enableDebug` is set; otherwise do nothing."""
    if enableDebug:
        print(*things)
|
||
|
|
||
|
class GradPrinter(torch.autograd.Function):
    """Identity autograd function that prints the gradient passing through it.

    The forward pass returns its input unchanged; the backward pass prints the
    incoming gradient and hands it on unmodified.
    """

    @staticmethod
    def forward(ctx, a):
        """Return `a` unchanged."""
        return a

    @staticmethod
    def backward(ctx, g):
        """Print the incoming gradient `g` and return it unchanged."""
        # Only the first slice along dim 0 is shown. A 0-dim gradient (or one
        # with an empty first dimension) has no such slice and indexing it
        # would raise inside the backward pass, so it is printed whole.
        shown = g[0] if g.dim() > 0 and g.size(0) > 0 else g
        print("Grad (print_grad): ", shown)
        return g
|
||
|
|
||
|
def print_grad(t):
    """Route `t` through GradPrinter and return the result.

    The returned value is whatever GradPrinter.apply produces for `t`.
    """
    routed = GradPrinter.apply(t)
    return routed
|
||
|
|
||
|
def assert_equal(t1, ref, limit=1e-5, force=True):
    """Check that tensor `t1` matches the reference tensor `ref` element-wise.

    The tolerance is relative: `limit` times the mean absolute value of the
    non-zero elements of `ref`. The check runs when `force` is True or the
    module-level `enableDebug` flag is set.

    Args:
        t1: tensor under test.
        ref: reference tensor of the same shape.
        limit: allowed deviation as a fraction of the reference norm.
        force: run the check even when `enableDebug` is False.

    Raises:
        AssertionError: if the shapes differ or any element deviates by more
            than the threshold; both tensors are printed in the latter case.
    """
    # `force` is tested first so a forced check does not depend on the flag.
    if not (force or enableDebug):
        return

    if t1.shape != ref.shape:
        raise AssertionError("Tensor shapes differ: got %s, ref %s" % (t1.shape, ref.shape))

    # Count the non-zero reference elements. (Summing ref.nonzero() would add
    # up their *indices*, which yields a meaningless - possibly zero - divisor.)
    nonzero_count = int((ref != 0).sum().item())
    if nonzero_count > 0:
        norm = float(ref.abs().sum().item()) / nonzero_count
    else:
        # All-zero (or empty) reference: no scale to normalise by, so `limit`
        # is applied as an absolute tolerance instead of producing NaN.
        norm = 1.0
    threshold = norm * limit

    diff = (t1 - ref).abs()
    errcnt = int((diff > threshold).sum().item())
    if errcnt > 0:
        print("Tensors differ. (max difference: %g, norm %f). No of errors: %d of %d" %
              (diff.max().item(), norm, errcnt, t1.numel()))
        print("---------------------------------------------Out-----------------------------------------------")
        print(t1)
        print("---------------------------------------------Ref-----------------------------------------------")
        print(ref)
        print("-----------------------------------------------------------------------------------------------")
        # Raised explicitly so the check survives `python -O`.
        raise AssertionError("Tensors differ: %d of %d elements exceed the threshold" % (errcnt, t1.numel()))
|