# dnc-with-demon/Utils/universal.py
# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
import torch
import torch.nn.functional as F
import numpy as np

# Paired [numpy, torch] dtypes, so the same symbolic type works with either backend.
float32 = [np.float32, torch.float32]
float64 = [np.float64, torch.float64]
uint8 = [np.uint8, torch.uint8]

_all_types = [float32, float64, uint8]
_dtype_numpy_map = {v[0]().dtype.name: v for v in _all_types}
_dtype_pytorch_map = {v[1]: v for v in _all_types}

def dtype(t):
    # Return the [numpy, torch] dtype pair of a tensor or ndarray.
    if torch.is_tensor(t):
        return _dtype_pytorch_map[t.dtype]
    else:
        return _dtype_numpy_map[t.dtype.name]

def cast(t, type):
    # 'type' is one of the [numpy, torch] pairs defined above (e.g. float32).
    if torch.is_tensor(t):
        return t.type(type[1])
    else:
        return t.astype(type[0])
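
# Usage sketch (not in the original file): the same symbolic dtype drives both
# backends, e.g. the 'float32' pair defined above.
#   cast(np.zeros((2, 3)), float32).dtype      # -> dtype('float32')
#   cast(torch.zeros(2, 3), float32).dtype     # -> torch.float32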

def to_numpy(t):
    if torch.is_tensor(t):
        return t.detach().cpu().numpy()
    else:
        return t

def to_list(t):
    t = to_numpy(t)
    if isinstance(t, np.ndarray):
        t = t.tolist()
    return t

def is_tensor(t):
    return torch.is_tensor(t) or isinstance(t, np.ndarray)

def first_batch(t):
    if is_tensor(t):
        return t[0]
    else:
        return t

def ndim(t):
    if torch.is_tensor(t):
        return t.dim()
    else:
        return t.ndim

def shape(t):
    return list(t.shape)

def apply_recursive(d, fn, filter=None):
    # Apply fn to every leaf of a nested list/tuple/dict structure, optionally
    # only to leaves accepted by 'filter'. (The original dropped 'filter' in
    # the recursive calls; it is propagated here.)
    if isinstance(d, list):
        return [apply_recursive(da, fn, filter) for da in d]
    elif isinstance(d, tuple):
        return tuple(apply_recursive(list(d), fn, filter))
    elif isinstance(d, dict):
        return {k: apply_recursive(v, fn, filter) for k, v in d.items()}
    else:
        if filter is None or filter(d):
            return fn(d)
        else:
            return d

def apply_to_tensors(d, fn):
    return apply_recursive(d, fn, torch.is_tensor)
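
# Usage sketch (not in the original file): apply_to_tensors walks an
# arbitrarily nested structure and transforms only the torch tensors in it.
#   batch = {"img": torch.zeros(4, 3), "meta": [1, (torch.ones(2),)]}
#   floats = apply_to_tensors(batch, lambda t: t.float())  # same nesting back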

def recursive_decorator(apply_this_fn):
    # Decorator factory: transform all (nested) arguments of the wrapped
    # function with apply_this_fn before calling it.
    def decorator(func):
        def wrapped_funct(*args, **kwargs):
            args = apply_recursive(args, apply_this_fn)
            kwargs = apply_recursive(kwargs, apply_this_fn)
            return func(*args, **kwargs)
        return wrapped_funct
    return decorator

# Ready-made decorators: convert tensor arguments to numpy / to plain lists.
untensor = recursive_decorator(to_numpy)
unnumpy = recursive_decorator(to_list)
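
# Usage sketch (not in the original file): wrapped functions receive numpy
# arrays regardless of whether the caller passed torch tensors.
#   @untensor
#   def summarize(x):
#       return x.mean(axis=0)   # x is already a numpy array here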

def unbatch(only_if_dim_equal=None):
    # Decorator factory: strip the batch dimension (take element 0) from every
    # tensor argument, optionally only from tensors with the given ndim(s).
    if only_if_dim_equal is not None and not isinstance(only_if_dim_equal, list):
        only_if_dim_equal = [only_if_dim_equal]

    def get_first_batch(t):
        if is_tensor(t) and (only_if_dim_equal is None or ndim(t) in only_if_dim_equal):
            return t[0]
        else:
            return t

    return recursive_decorator(get_first_batch)
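
# Usage sketch (not in the original file): only 4D (e.g. NCHW) tensors lose
# their batch dimension; everything else passes through untouched.
#   @unbatch(only_if_dim_equal=4)
#   def show(img):              # img arrives as C x H x W
#       ...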

def sigmoid(t):
    if torch.is_tensor(t):
        return torch.sigmoid(t)
    else:
        return 1.0 / (1.0 + np.exp(-t))

def argmax(t, dim):
    if torch.is_tensor(t):
        _, res = t.max(dim)
    else:
        res = np.argmax(t, axis=dim)
    return res

def flip(t, axis):
    if torch.is_tensor(t):
        return t.flip(axis)
    else:
        return np.flip(t, axis)

def transpose(t, axes):
    if torch.is_tensor(t):
        return t.permute(*axes)
    else:
        return np.transpose(t, axes)

def split_n(t, axis):
    # Split into size-1 slices along 'axis' (each slice keeps the singleton dim).
    if torch.is_tensor(t):
        return t.split(1, dim=axis)
    else:
        return np.split(t, t.shape[axis], axis=axis)

def cat(array_of_tensors, axis):
    if torch.is_tensor(array_of_tensors[0]):
        return torch.cat(array_of_tensors, axis)
    else:
        return np.concatenate(array_of_tensors, axis)

def clamp(t, min=None, max=None):
    if torch.is_tensor(t):
        return t.clamp(min, max)
    else:
        if min is not None:
            t = np.maximum(t, min)
        if max is not None:
            t = np.minimum(t, max)
        return t

def power(t, p):
    if torch.is_tensor(t) or torch.is_tensor(p):
        return torch.pow(t, p)
    else:
        return np.power(t, p)

def random_normal_as(a, mean, std, seed=None):
    # 'seed' is an optional np.random.RandomState; it is ignored for torch tensors.
    if torch.is_tensor(a):
        return torch.randn_like(a) * std + mean
    else:
        if seed is None:
            seed = np.random
        return seed.normal(loc=mean, scale=std, size=shape(a))

def pad(t, pad):
    # 'pad' follows the F.pad convention for 4D tensors: (left, right, top, bottom).
    # The numpy branch mirrors it: F.pad pads the *last* dimension with the first
    # pair. (The original asserted on np.pad instead of returning its result.)
    assert ndim(t) == 4
    if torch.is_tensor(t):
        return F.pad(t, pad)
    else:
        return np.pad(t, ((0, 0), (0, 0), (pad[2], pad[3]), (pad[0], pad[1])))

def dx(img):
    # Horizontal central difference 0.5 * (img[x+1] - img[x-1]), padded back to input width.
    lsh = img[:, :, :, 2:]
    orig = img[:, :, :, :-2]
    return pad(0.5 * (lsh - orig), (1, 1, 0, 0))

def dy(img):
    # Vertical central difference 0.5 * (img[y+1] - img[y-1]), padded back to input height.
    ush = img[:, :, 2:, :]
    orig = img[:, :, :-2, :]
    return pad(0.5 * (ush - orig), (0, 0, 1, 1))
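
# Usage sketch (not in the original file): central-difference image gradients
# for an NCHW batch; the padding restores the input shape.
#   img = torch.rand(1, 3, 32, 32)
#   gx, gy = dx(img), dy(img)   # both 1 x 3 x 32 x 32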

def reshape(t, shape):
    if torch.is_tensor(t):
        return t.view(*shape)
    else:
        return t.reshape(*shape)

def broadcast_to_beginning(t, target):
    # Append singleton dimensions to t so it broadcasts against the *leading*
    # dimensions of target (numpy/torch broadcasting aligns trailing dims).
    if torch.is_tensor(t):
        nd_target = target.dim()
        t_shape = list(t.shape)
        return t.view(*t_shape, *([1] * (nd_target - len(t_shape))))
    else:
        nd_target = target.ndim
        t_shape = list(t.shape)
        return t.reshape(*t_shape, *([1] * (nd_target - len(t_shape))))
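
# Usage sketch (not in the original file): a per-batch weight of shape (B,)
# becomes (B, 1, 1, 1) so it can scale a (B, C, H, W) tensor elementwise.
#   w = torch.rand(8)
#   x = torch.rand(8, 3, 16, 16)
#   y = broadcast_to_beginning(w, x) * x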

def lin_combine(d1, w1, d2, w2, bcast_begin=False):
    # Recursively compute w1*d1 + w2*d2 over nested lists/tuples/dicts.
    # (The original returned tuple(d1) instead of tuple(res) and dropped
    # 'bcast_begin' in the recursive calls; both are fixed here.)
    if isinstance(d1, (list, tuple)):
        assert len(d1) == len(d2)
        res = [lin_combine(d1[i], w1, d2[i], w2, bcast_begin) for i in range(len(d1))]
        if isinstance(d1, tuple):
            res = tuple(res)
    elif isinstance(d1, dict):
        res = {k: lin_combine(v, w1, d2[k], w2, bcast_begin) for k, v in d1.items()}
    else:
        if bcast_begin:
            w1 = broadcast_to_beginning(w1, d1)
            w2 = broadcast_to_beginning(w2, d2)
        res = d1 * w1 + d2 * w2
    return res
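
# Usage sketch (not in the original file): elementwise interpolation of two
# nested states with per-batch weights.
#   s1 = {"h": torch.ones(4, 8)}
#   s2 = {"h": torch.zeros(4, 8)}
#   w = torch.full((4,), 0.25)
#   mix = lin_combine(s1, w, s2, 1.0 - w, bcast_begin=True)  # 0.25 everywhere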