commit
c48d3d9ba4
@ -232,7 +232,7 @@ class Memory(nn.Module):
|
|||||||
# write gate (b * 1)
|
# write gate (b * 1)
|
||||||
write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
|
write_gate = T.sigmoid(self.write_gate_transform(ξ).view(b, 1))
|
||||||
# read modes (b * r * 3)
|
# read modes (b * r * 3)
|
||||||
read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), 1)
|
read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), -1)
|
||||||
else:
|
else:
|
||||||
ξ = self.interface_weights(ξ)
|
ξ = self.interface_weights(ξ)
|
||||||
# r read keys (b * w * r)
|
# r read keys (b * w * r)
|
||||||
@ -254,7 +254,7 @@ class Memory(nn.Module):
|
|||||||
# write gate (b * 1)
|
# write gate (b * 1)
|
||||||
write_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
|
write_gate = T.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
|
||||||
# read modes (b * 3*r)
|
# read modes (b * 3*r)
|
||||||
read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), 1)
|
read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), -1)
|
||||||
|
|
||||||
hidden = self.write(write_key, write_vector, erase_vector, free_gates,
|
hidden = self.write(write_key, write_vector, erase_vector, free_gates,
|
||||||
read_strengths, write_strength, write_gate, allocation_gate, hidden)
|
read_strengths, write_strength, write_gate, allocation_gate, hidden)
|
||||||
|
4
setup.py
4
setup.py
@ -22,7 +22,7 @@ with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
|
|||||||
setup(
|
setup(
|
||||||
name='dnc',
|
name='dnc',
|
||||||
|
|
||||||
version='1.0.1',
|
version='1.0.2',
|
||||||
description='Differentiable Neural Computer, for Pytorch',
|
description='Differentiable Neural Computer, for Pytorch',
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ setup(
|
|||||||
|
|
||||||
packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tasks', 'scripts']),
|
packages=find_packages(exclude=['contrib', 'docs', 'tests', 'tasks', 'scripts']),
|
||||||
|
|
||||||
install_requires=['torch', 'numpy', 'pyflann3'],
|
install_requires=['torch', 'numpy', 'flann'],
|
||||||
|
|
||||||
extras_require={
|
extras_require={
|
||||||
'dev': ['check-manifest'],
|
'dev': ['check-manifest'],
|
||||||
|
Loading…
Reference in New Issue
Block a user