verify gradients flowing

ixaxaar 2017-12-09 15:43:11 +05:30
parent c9eb3a5ca7
commit c0b9b04129
2 changed files with 9 additions and 9 deletions

@@ -141,11 +141,11 @@ class SparseMemory(nn.Module):
     # non-differentiable operations
     pos = positions.data.cpu().numpy()
-    for b in range(positions.size(0)):
+    for batch in range(b):
       # update indexes
-      hidden['indexes'][b].reset()
-      hidden['indexes'][b].add(hidden['memory'][b], last=pos[b][-1])
-      hidden['last_used_mem'][b] = (int(pos[b][-1]) + 1) if (pos[b][-1] + 1) < self.mem_size else 0
+      hidden['indexes'][batch].reset()
+      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+      hidden['last_used_mem'][batch] = (int(pos[batch][-1]) + 1) if (pos[batch][-1] + 1) < self.mem_size else 0
     return hidden
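
The loop above is the part that keeps the approximate nearest-neighbour index in sync with the (differentiable) memory tensor, which is why it detaches to numpy first. A hedged, self-contained sketch of the same idea follows; the repo wraps a FLANN/FAISS-style index, but here a brute-force FlatIndex stand-in is used so the snippet runs on its own, and every name in it is illustrative rather than the repo's API.

import torch as T

class FlatIndex:
  """Toy nearest-neighbour index with the reset()/add() shape used above."""

  def __init__(self):
    self.vectors = None

  def reset(self):
    # drop stale contents before re-adding the current memory
    self.vectors = None

  def add(self, vectors, last=None):
    # `last` mirrors the last-written-slot hint; only cells up to it are searchable
    self.vectors = vectors if last is None else vectors[:int(last) + 1]

  def search(self, query, k=1):
    # brute-force nearest neighbours by Euclidean distance
    dists = T.cdist(query.unsqueeze(0), self.vectors.unsqueeze(0)).squeeze(0)
    return dists.topk(k, largest=False).indices

memory = T.randn(16, 8)      # one batch item's memory: cells x word size
index = FlatIndex()
index.reset()
index.add(memory, last=10)   # cells 0..10 are treated as written
print(index.search(T.randn(1, 8), k=3))
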
@@ -185,7 +185,7 @@ class SparseMemory(nn.Module):
     relevant_usages = usage.gather(1, read_positions)
     # indicator of words with minimal memory usage
-    minusage = T.min(relevant_usages, -1)[0].unsqueeze(1)
+    minusage = T.min(relevant_usages, -1, keepdim=True)[0]
     minusage = minusage.expand(relevant_usages.size())
     I = (relevant_usages == minusage).float()
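
The change above swaps .unsqueeze(1) for keepdim=True. A quick illustrative check (not from the repo) that the two forms agree, and of how the minimal-usage indicator I falls out of it:

import torch as T

relevant_usages = T.tensor([[3., 1., 2.],
                            [5., 5., 0.]])

old = T.min(relevant_usages, -1)[0].unsqueeze(1)    # (b, 1)
new = T.min(relevant_usages, -1, keepdim=True)[0]   # (b, 1)
assert T.equal(old, new)

minusage = new.expand(relevant_usages.size())
I = (relevant_usages == minusage).float()           # 1.0 at the least-used slots
print(I)                                            # [[0., 1., 0.], [0., 0., 1.]]
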
@@ -217,7 +217,6 @@ class SparseMemory(nn.Module):
     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))
     read_weights = F.softmax(θ(visible_memory, keys), dim=2)
-    # visible_memory.register_hook(lambda x: print("========================================", x.squeeze()))
     read_vectors = T.bmm(read_weights, visible_memory)
     read_weights = T.prod(read_weights, 1)
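
For orientation, a minimal sketch of the read step around the deleted hook: a similarity score between the read keys and the gathered (visible) memory rows is softmaxed and then used to blend those rows into read vectors. θ is assumed to be a batched cosine similarity and is approximated with F.cosine_similarity here; all sizes are made up.

import torch as T
import torch.nn.functional as F

b, r, k, w = 2, 1, 4, 8                    # batch, read heads, neighbours, word size
visible_memory = T.randn(b, r * k + 1, w)  # gathered rows, (b, r*k+1, w)
keys = T.randn(b, r, w)                    # read keys, (b, r, w)

# similarity of every key against every visible cell, (b, r, r*k+1)
sim = F.cosine_similarity(keys.unsqueeze(2), visible_memory.unsqueeze(1), dim=-1)

read_weights = F.softmax(sim, dim=2)                # attention over the visible cells
read_vectors = T.bmm(read_weights, visible_memory)  # (b, r, w)
print(read_vectors.shape)                           # torch.Size([2, 1, 8])
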
@@ -264,9 +263,6 @@ class SparseMemory(nn.Module):
     ξ = self.interface_weights(ξ)
     # r read keys (b * r * w)
     read_query = ξ[:, :r*w].contiguous().view(b, r, w)
-    print("999999999999999999999999")
-    print(read_query.squeeze())
-    # read_query.register_hook(lambda x: print("========================================", x.squeeze()))
     # write key (b * 1 * w)
     write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
     # write vector (b * 1 * r)
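
A self-contained sketch of how the flat interface vector ξ is carved into those named pieces. The hidden size, the Linear stub, and the extra slice size are assumptions made for illustration; only the read-key and write-key slices mirror the lines above.

import torch as T
import torch.nn as nn

b, r, w = 2, 4, 16
interface_size = r * w + w + r          # read keys + write key + a final (b * 1 * r) slice
interface_weights = nn.Linear(64, interface_size)

ξ = interface_weights(T.randn(b, 64))

read_query = ξ[:, :r * w].contiguous().view(b, r, w)              # (b, r, w)
write_vector = ξ[:, r * w:r * w + w].contiguous().view(b, 1, w)   # (b, 1, w)
print(read_query.shape, write_vector.shape)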

@@ -154,3 +154,7 @@ def ensure_gpu(tensor, gpu_id):
   else:
     return tensor
+
+
+def print_gradient(x, name):
+  s = "Gradient of " + name + " ----------------------------------"
+  x.register_hook(lambda y: print(s, y.squeeze()))
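
The new helper is what the commit title is about: hooking an intermediate tensor so its incoming gradient is printed during backward(), confirming that gradients actually flow through it. A small usage sketch (the surrounding tensors are made up):

import torch as T

def print_gradient(x, name):
  s = "Gradient of " + name + " ----------------------------------"
  x.register_hook(lambda y: print(s, y.squeeze()))

a = T.randn(3, requires_grad=True)
h = a * 2                        # some intermediate, non-leaf tensor
print_gradient(h, "read_query")  # hook fires when backward() reaches h
h.sum().backward()               # prints: Gradient of read_query ---- tensor([1., 1., 1.])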