verify gradients flowing
commit c0b9b04129
parent c9eb3a5ca7
@@ -141,11 +141,11 @@ class SparseMemory(nn.Module):
     # non-differentiable operations
     pos = positions.data.cpu().numpy()
-    for b in range(positions.size(0)):
+    for batch in range(b):
       # update indexes
-      hidden['indexes'][b].reset()
-      hidden['indexes'][b].add(hidden['memory'][b], last=pos[b][-1])
-      hidden['last_used_mem'][b] = (int(pos[b][-1]) + 1) if (pos[b][-1] + 1) < self.mem_size else 0
+      hidden['indexes'][batch].reset()
+      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
+      hidden['last_used_mem'][batch] = (int(pos[batch][-1]) + 1) if (pos[batch][-1] + 1) < self.mem_size else 0
 
     return hidden
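Note on this hunk: the loop now runs over range(b) (with b the batch size) instead of positions.size(0), and everything under the "# non-differentiable operations" comment works on positions.data.cpu().numpy(), a detached ndarray, so the index bookkeeping cannot pass gradients back. A minimal sketch of that detachment, with toy values assumed:

import torch as T

# Assumed toy tensor: .data.cpu().numpy() yields a plain ndarray with no
# autograd history, so nothing computed from `pos` joins the backward graph.
positions = T.softmax(T.rand(2, 3, requires_grad=True), dim=-1)
pos = positions.data.cpu().numpy()   # detached numpy copy
print(pos[0][-1])                    # fine for slot bookkeeping; no grad flows back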
@@ -185,7 +185,7 @@ class SparseMemory(nn.Module):
     relevant_usages = usage.gather(1, read_positions)
 
     # indicator of words with minimal memory usage
-    minusage = T.min(relevant_usages, -1)[0].unsqueeze(1)
+    minusage = T.min(relevant_usages, -1, keepdim=True)[0]
     minusage = minusage.expand(relevant_usages.size())
     I = (relevant_usages == minusage).float()
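The replacement above is shape-equivalent in the 2-D case but tidier: keepdim=True keeps the reduced dimension in place rather than re-adding it with unsqueeze(1). A quick self-check, shapes assumed:

import torch as T

# Assumed (b, n) usage matrix: both forms produce a (b, 1) minimum.
relevant_usages = T.rand(4, 10)
old = T.min(relevant_usages, -1)[0].unsqueeze(1)     # (b,) -> (b, 1)
new = T.min(relevant_usages, -1, keepdim=True)[0]    # (b, 1) directly
assert T.equal(old, new)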
@@ -217,7 +217,6 @@ class SparseMemory(nn.Module):
     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))
 
     read_weights = F.softmax(θ(visible_memory, keys), dim=2)
-    # visible_memory.register_hook(lambda x: print("========================================", x.squeeze()))
     read_vectors = T.bmm(read_weights, visible_memory)
     read_weights = T.prod(read_weights, 1)
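For context, this hunk is the content-based read: θ scores each read key against the gathered memory rows, softmax turns the scores into read weights, and a batched matmul forms the read vectors. A shape-only sketch with a plain dot product standing in for θ (all sizes here are assumptions):

import torch as T
import torch.nn.functional as F

b, r, k, w = 2, 1, 4, 6                    # batch, read heads, neighbours, word size
visible_memory = T.rand(b, r*k + 1, w)     # rows gathered at read_positions
keys = T.rand(b, r, w)                     # read keys
scores = T.bmm(keys, visible_memory.transpose(1, 2))   # stand-in for θ: (b, r, r*k+1)
read_weights = F.softmax(scores, dim=2)
read_vectors = T.bmm(read_weights, visible_memory)     # (b, r, w) weighted sums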
@@ -264,9 +263,6 @@ class SparseMemory(nn.Module):
     ξ = self.interface_weights(ξ)
     # r read keys (b * r * w)
     read_query = ξ[:, :r*w].contiguous().view(b, r, w)
-    print("999999999999999999999999")
-    print(read_query.squeeze())
-    # read_query.register_hook(lambda x: print("========================================", x.squeeze()))
     # write key (b * 1 * w)
     write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
     # write vector (b * 1 * r)
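The context lines around the removed debug prints carve the interface vector ξ into fixed slices: the first r*w entries reshape into r read keys of width w, and the next w entries into a single write vector. A self-contained sketch, with the input width and sizes made up:

import torch as T
import torch.nn as nn

b, r, w = 2, 4, 6
interface_weights = nn.Linear(10, r*w + w)       # hypothetical controller width 10
ξ = interface_weights(T.rand(b, 10))             # flat interface vector, (b, r*w + w)
read_query = ξ[:, :r*w].contiguous().view(b, r, w)              # r read keys
write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)    # one write vector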
@@ -154,3 +154,7 @@ def ensure_gpu(tensor, gpu_id):
   else:
     return tensor
 
+
+def print_gradient(x, name):
+  s = "Gradient of " + name + " ----------------------------------"
+  x.register_hook(lambda y: print(s, y.squeeze()))
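The new helper is what the commit message refers to: it registers a backward hook on a tensor so the gradient reaching it is printed during backward(), replacing the ad-hoc prints deleted above. A minimal usage sketch (toy tensor assumed; the helper is repeated so the snippet runs standalone):

import torch as T

def print_gradient(x, name):
  s = "Gradient of " + name + " ----------------------------------"
  x.register_hook(lambda y: print(s, y.squeeze()))

x = T.rand(1, 3, requires_grad=True)
loss = (x * 2).sum()
print_gradient(x, 'x')   # hook fires when backward reaches x
loss.backward()          # prints: Gradient of x ---- ... tensor([2., 2., 2.])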