From 092bdb8f9359d510bbeb3c33a411a2c2688e8dae Mon Sep 17 00:00:00 2001 From: Gavin Sellers Date: Fri, 29 Jun 2018 10:34:10 -0500 Subject: [PATCH 1/2] fix parens in example usage --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index ba5bba0..615a521 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ rnn = DNC( (controller_hidden, memory, read_vectors) = (None, None, None) output, (controller_hidden, memory, read_vectors) = \ - rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True)) + rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True) ``` @@ -150,7 +150,7 @@ rnn = DNC( (controller_hidden, memory, read_vectors) = (None, None, None) output, (controller_hidden, memory, read_vectors), debug_memory = \ - rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True)) + rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True) ``` Memory vectors returned by forward pass (`np.ndarray`): @@ -223,7 +223,7 @@ rnn = SDNC( (controller_hidden, memory, read_vectors) = (None, None, None) output, (controller_hidden, memory, read_vectors) = \ - rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True)) + rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True) ``` @@ -253,7 +253,7 @@ rnn = SDNC( (controller_hidden, memory, read_vectors) = (None, None, None) output, (controller_hidden, memory, read_vectors), debug_memory = \ - rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True)) + rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True) ``` Memory vectors returned by forward pass (`np.ndarray`): @@ -327,7 +327,7 @@ rnn = SAM( (controller_hidden, memory, read_vectors) = (None, None, None) output, (controller_hidden, memory, read_vectors) = \ - rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True)) + rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True) ``` @@ -356,7 +356,7 @@ rnn = SAM( (controller_hidden, memory, read_vectors) = (None, None, None) output, (controller_hidden, memory, read_vectors), debug_memory = \ - rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True)) + rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True) ``` Memory vectors returned by forward pass (`np.ndarray`): From cc2c3bcebc47005dce1764684c6da054419cd706 Mon Sep 17 00:00:00 2001 From: Gavin Sellers Date: Sun, 8 Jul 2018 10:50:57 -0500 Subject: [PATCH 2/2] fix gpu usage --- dnc/dnc.py | 1 + dnc/sparse_memory.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/dnc/dnc.py b/dnc/dnc.py index 8c22c2f..55e3a85 100644 --- a/dnc/dnc.py +++ b/dnc/dnc.py @@ -120,6 +120,7 @@ class DNC(nn.Module): if self.gpu_id != -1: [x.cuda(self.gpu_id) for x in self.rnns] [x.cuda(self.gpu_id) for x in self.memories] + self.output.cuda() def _init_hidden(self, hx, batch_size, reset_experience): # create empty hidden states if not provided diff --git a/dnc/sparse_memory.py b/dnc/sparse_memory.py index 9f48c8f..a778f0f 100644 --- a/dnc/sparse_memory.py +++ b/dnc/sparse_memory.py @@ -48,23 +48,34 @@ class SparseMemory(nn.Module): self.c = (r * self.K) + 1 if self.independent_linears: - self.read_query_transform = nn.Linear(self.input_size, w * r) - self.write_vector_transform = nn.Linear(self.input_size, w) - self.interpolation_gate_transform = nn.Linear(self.input_size, self.c) - self.write_gate_transform = nn.Linear(self.input_size, 1) + if self.gpu_id != -1: + self.read_query_transform = nn.Linear(self.input_size, w * r).cuda() + self.write_vector_transform = nn.Linear(self.input_size, w).cuda() + self.interpolation_gate_transform = nn.Linear(self.input_size, self.c).cuda() + self.write_gate_transform = nn.Linear(self.input_size, 1).cuda() + else: + self.read_query_transform = nn.Linear(self.input_size, w * r) + self.write_vector_transform = nn.Linear(self.input_size, w) + self.interpolation_gate_transform = nn.Linear(self.input_size, self.c) + self.write_gate_transform = nn.Linear(self.input_size, 1) T.nn.init.orthogonal(self.read_query_transform.weight) T.nn.init.orthogonal(self.write_vector_transform.weight) T.nn.init.orthogonal(self.interpolation_gate_transform.weight) T.nn.init.orthogonal(self.write_gate_transform.weight) else: self.interface_size = (r * w) + w + self.c + 1 - self.interface_weights = nn.Linear(self.input_size, self.interface_size) + if self.gpu_id != -1: + self.interface_weights = nn.Linear(self.input_size, self.interface_size).cuda() + else: + self.interface_weights = nn.Linear(self.input_size, self.interface_size) T.nn.init.orthogonal(self.interface_weights.weight) self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n) self.δ = 0.005 # minimum usage self.timestep = 0 self.mem_limit_reached = False + if self.gpu_id != -1: + self.cuda() def rebuild_indexes(self, hidden, erase=False): b = hidden['memory'].size(0)