Merge branch 'master' of github.com:ixaxaar/pytorch-dnc

This commit is contained in:
Russi Chatterjee 2019-04-05 11:47:09 +05:30
commit 4178130e8f
3 changed files with 23 additions and 11 deletions

View File

@@ -122,7 +122,7 @@ rnn = DNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors) = \
-rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
@@ -150,7 +150,7 @@ rnn = DNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors), debug_memory = \
-rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
Memory vectors returned by forward pass (`np.ndarray`):
@@ -223,7 +223,7 @@ rnn = SDNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors) = \
-rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
@@ -253,7 +253,7 @@ rnn = SDNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors), debug_memory = \
-rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
Memory vectors returned by forward pass (`np.ndarray`):
@@ -327,7 +327,7 @@ rnn = SAM(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors) = \
-rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
@@ -356,7 +356,7 @@ rnn = SAM(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors), debug_memory = \
-rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
Memory vectors returned by forward pass (`np.ndarray`):

View File

@@ -120,6 +120,7 @@ class DNC(nn.Module):
if self.gpu_id != -1:
[x.cuda(self.gpu_id) for x in self.rnns]
[x.cuda(self.gpu_id) for x in self.memories]
+self.output.cuda()
def _init_hidden(self, hx, batch_size, reset_experience):
# create empty hidden states if not provided

View File

@@ -48,6 +48,12 @@ class SparseMemory(nn.Module):
self.c = (r * self.K) + 1
if self.independent_linears:
+if self.gpu_id != -1:
+self.read_query_transform = nn.Linear(self.input_size, w * r).cuda()
+self.write_vector_transform = nn.Linear(self.input_size, w).cuda()
+self.interpolation_gate_transform = nn.Linear(self.input_size, self.c).cuda()
+self.write_gate_transform = nn.Linear(self.input_size, 1).cuda()
+else:
self.read_query_transform = nn.Linear(self.input_size, w * r)
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
@@ -58,6 +64,9 @@ class SparseMemory(nn.Module):
T.nn.init.orthogonal(self.write_gate_transform.weight)
else:
self.interface_size = (r * w) + w + self.c + 1
+if self.gpu_id != -1:
+self.interface_weights = nn.Linear(self.input_size, self.interface_size).cuda()
+else:
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
T.nn.init.orthogonal(self.interface_weights.weight)
@@ -65,6 +74,8 @@ class SparseMemory(nn.Module):
self.δ = 0.005 # minimum usage
self.timestep = 0
self.mem_limit_reached = False
+if self.gpu_id != -1:
+self.cuda()
def rebuild_indexes(self, hidden, erase=False):
b = hidden['memory'].size(0)