fix bug in function \theta for batchwise cosine similarity

This commit is contained in:
rfeinman 2020-11-23 08:22:24 -05:00
parent 00bfa63bc5
commit a660434d21

View File

@ -56,29 +56,23 @@ def cudalong(x, grad=False, gpu_id=-1):
return t return t
def θ(a, b, dimA=2, dimB=2, normBy=2): def θ(a, b, normBy=2):
"""Batchwise Cosine distance """Batchwise Cosine similarity
Cosine distance Cosine similarity
Arguments: Arguments:
a {Tensor} -- A 3D Tensor (b * m * w) a {Tensor} -- A 3D Tensor (b * m * w)
b {Tensor} -- A 3D Tensor (b * r * w) b {Tensor} -- A 3D Tensor (b * r * w)
Keyword Arguments:
dimA {number} -- exponent value of the norm for `a` (default: {2})
dimB {number} -- exponent value of the norm for `b` (default: {1})
Returns: Returns:
Tensor -- Batchwise cosine distance (b * r * m) Tensor -- Batchwise cosine similarity (b * r * m)
""" """
a_norm = T.norm(a, normBy, dimA, keepdim=True).expand_as(a) + δ dot = T.bmm(a, b.transpose(1,2))
b_norm = T.norm(b, normBy, dimB, keepdim=True).expand_as(b) + δ a_norm = T.norm(a, normBy, dim=2).unsqueeze(2)
b_norm = T.norm(b, normBy, dim=2).unsqueeze(1)
x = T.bmm(a, b.transpose(1, 2)).transpose(1, 2) / ( cos = dot / (a_norm * b_norm + δ)
T.bmm(a_norm, b_norm.transpose(1, 2)).transpose(1, 2) + δ) return cos.transpose(1,2).contiguous()
# apply_dict(locals())
return x
def σ(input, axis=1): def σ(input, axis=1):