fix porting bugs :D

This commit is contained in:
Russi Chatterjee 2019-04-05 11:52:18 +05:30
commit f528a4c120
3 changed files with 23 additions and 11 deletions

View File

@ -122,7 +122,7 @@ rnn = DNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors) = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
@ -150,7 +150,7 @@ rnn = DNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors), debug_memory = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
Memory vectors returned by forward pass (`np.ndarray`):
@ -223,7 +223,7 @@ rnn = SDNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors) = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
@ -253,7 +253,7 @@ rnn = SDNC(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors), debug_memory = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
Memory vectors returned by forward pass (`np.ndarray`):
@ -327,7 +327,7 @@ rnn = SAM(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors) = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
@ -356,7 +356,7 @@ rnn = SAM(
(controller_hidden, memory, read_vectors) = (None, None, None)
output, (controller_hidden, memory, read_vectors), debug_memory = \
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```
Memory vectors returned by forward pass (`np.ndarray`):

View File

@ -120,6 +120,7 @@ class DNC(nn.Module):
if self.gpu_id != -1:
[x.cuda(self.gpu_id) for x in self.rnns]
[x.cuda(self.gpu_id) for x in self.memories]
self.output.cuda()
def _init_hidden(self, hx, batch_size, reset_experience):
# create empty hidden states if not provided

View File

@ -48,23 +48,34 @@ class SparseMemory(nn.Module):
self.c = (r * self.K) + 1
if self.independent_linears:
self.read_query_transform = nn.Linear(self.input_size, w * r)
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
if self.gpu_id != -1:
self.read_query_transform = nn.Linear(self.input_size, w * r).cuda()
self.write_vector_transform = nn.Linear(self.input_size, w).cuda()
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c).cuda()
self.write_gate_transform = nn.Linear(self.input_size, 1).cuda()
else:
self.read_query_transform = nn.Linear(self.input_size, w * r)
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
T.nn.init.orthogonal_(self.read_query_transform.weight)
T.nn.init.orthogonal_(self.write_vector_transform.weight)
T.nn.init.orthogonal_(self.interpolation_gate_transform.weight)
T.nn.init.orthogonal_(self.write_gate_transform.weight)
else:
self.interface_size = (r * w) + w + self.c + 1
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
if self.gpu_id != -1:
self.interface_weights = nn.Linear(self.input_size, self.interface_size).cuda()
else:
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
T.nn.init.orthogonal_(self.interface_weights.weight)
self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
self.δ = 0.005 # minimum usage
self.timestep = 0
self.mem_limit_reached = False
if self.gpu_id != -1:
self.cuda()
def rebuild_indexes(self, hidden, erase=False):
b = hidden['memory'].size(0)