Merge pull request #44 from ixaxaar/43

Fixes for #43
This commit is contained in:
ixaxaar 2019-05-20 19:22:13 +05:30 committed by GitHub
commit c48d3d9ba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions

View File

@ -232,7 +232,7 @@ class Memory(nn.Module):
# write gate (b * 1) # write gate (b * 1)
write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1)) write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
# read modes (b * r * 3) # read modes (b * r * 3)
read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), 1) read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), -1)
else: else:
ξ = self.interface_weights(ξ) ξ = self.interface_weights(ξ)
# r read keys (b * w * r) # r read keys (b * w * r)
@ -254,7 +254,7 @@ class Memory(nn.Module):
# write gate (b * 1) # write gate (b * 1)
write_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1) write_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
# read modes (b * 3*r) # read modes (b * 3*r)
read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), 1) read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), -1)
hidden = self.write(write_key, write_vector, erase_vector, free_gates, hidden = self.write(write_key, write_vector, erase_vector, free_gates,
read_strengths, write_strength, write_gate, allocation_gate, hidden) read_strengths, write_strength, write_gate, allocation_gate, hidden)

View File

@ -22,7 +22,7 @@ with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
setup( setup(
name='dnc', name='dnc',
version='1.0.1', version='1.0.2',
description='Differentiable Neural Computer, for Pytorch', description='Differentiable Neural Computer, for Pytorch',
long_description=long_description, long_description=long_description,
@ -56,7 +56,7 @@ setup(
packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tasks', 'scripts']), packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tasks', 'scripts']),
install_requires=['torch', 'numpy', 'pyflann3'], install_requires=['torch', 'numpy', 'flann'],
extras_require={ extras_require={
'dev': ['check-manifest'], 'dev': ['check-manifest'],