def θ(a, b, normBy=2):
    """Batchwise cosine similarity.

    For every batch element, each row of `a` is compared with each row of
    `b`: their dot product is divided by the product of the rows'
    `normBy`-norms (plus the module-level constant `δ`).

    Arguments:
        a {Tensor} -- A 3D Tensor (b * m * w)
        b {Tensor} -- A 3D Tensor (b * r * w)

    Keyword Arguments:
        normBy {number} -- exponent value of the norm taken over the last
            axis of `a` and `b` (default: {2})

    Returns:
        Tensor -- Batchwise cosine similarity (b * r * m), contiguous
    """
    # Norms over the last (w) axis, reshaped so that their product
    # broadcasts to the (b * m * r) layout of the dot products below.
    row_norms = T.norm(a, normBy, dim=2).unsqueeze(2)  # (b * m * 1)
    col_norms = T.norm(b, normBy, dim=2).unsqueeze(1)  # (b * 1 * r)

    # NOTE(review): δ is defined elsewhere in this module; presumably a small
    # epsilon guarding against division by zero -- confirm.
    denominator = row_norms * col_norms + δ

    # Pairwise dot products between rows of `a` and rows of `b`: (b * m * r).
    numerator = T.bmm(a, b.transpose(1, 2))

    similarity = numerator / denominator

    # Swap the last two axes to get (b * r * m) and return a contiguous copy.
    return similarity.transpose(1, 2).contiguous()