Merge pull request #28 from jbinas/patch-1

memory.py: fix indexing for read_modes transform
This commit is contained in:
Russi Chatterjee 2018-04-23 11:56:23 +05:30 committed by GitHub
commit 2e24452dfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -254,7 +254,7 @@ class Memory(nn.Module):
# write gate (b * 1) # write gate (b * 1)
write_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1) write_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
# read modes (b * 3*r) # read modes (b * 3*r)
read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 2: r * w + 5 * r + 3 * w + 2].contiguous().view(b, r, 3), 1) read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), 1)
hidden = self.write(write_key, write_vector, erase_vector, free_gates, hidden = self.write(write_key, write_vector, erase_vector, free_gates,
read_strengths, write_strength, write_gate, allocation_gate, hidden) read_strengths, write_strength, write_gate, allocation_gate, hidden)