Merge branch 'master' of github.com:ixaxaar/pytorch-dnc
commit 4178130e8f

README.md: 12 changed lines
@@ -122,7 +122,7 @@ rnn = DNC(
 
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
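For context: the fix moves `reset_experience` out of the hidden-state tuple and passes it as a keyword argument of the forward call itself. Below is a minimal, runnable sketch of the corrected usage; the constructor arguments are illustrative (they mirror the README's DNC example, not this hunk), and `gpu_id=-1` keeps everything on the CPU.

```python
import torch
from dnc import DNC

# Illustrative constructor arguments, mirroring the README's DNC example.
rnn = DNC(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=4,
  nr_cells=100,
  cell_size=32,
  read_heads=4,
  batch_first=True,
  gpu_id=-1
)

# The hidden state starts empty and is threaded through successive calls.
(controller_hidden, memory, read_vectors) = (None, None, None)

# reset_experience is a keyword argument of the call, not a member of the tuple.
output, (controller_hidden, memory, read_vectors) = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```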
@@ -150,7 +150,7 @@ rnn = DNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 Memory vectors returned by forward pass (`np.ndarray`):
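The debug variant changes only the construction and the return value: per the surrounding README, the model is built with `debug=True`, and the forward pass then returns an extra dictionary of memory snapshots. A hedged sketch of inspecting that dictionary (the key names are examples taken from the README's table, not guaranteed by this hunk):

```python
# Assuming rnn was constructed as above but with debug=True:
output, (controller_hidden, memory, read_vectors), debug_memory = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)

# Each value is an np.ndarray snapshot per time step, with keys such as
# 'memory', 'read_weights', 'write_weights', 'usage_vector'.
for name, array in debug_memory.items():
  print(name, array.shape)
```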
@@ -223,7 +223,7 @@ rnn = SDNC(
 
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
@@ -253,7 +253,7 @@ rnn = SDNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 Memory vectors returned by forward pass (`np.ndarray`):
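The SDNC hunks apply the identical signature fix. For completeness, a sketch of the corrected call with an SDNC; the constructor arguments, including `sparse_reads`, are illustrative and follow the README's SDNC example rather than this hunk:

```python
import torch
from dnc import SDNC

rnn = SDNC(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=4,
  nr_cells=100,
  cell_size=32,
  read_heads=4,
  sparse_reads=4,   # number of cells fetched per read from the sparse index
  batch_first=True,
  gpu_id=-1
)

(controller_hidden, memory, read_vectors) = (None, None, None)

# Same corrected signature as the DNC example above.
output, (controller_hidden, memory, read_vectors) = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
```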
@@ -327,7 +327,7 @@ rnn = SAM(
 
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
@@ -356,7 +356,7 @@ rnn = SAM(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 Memory vectors returned by forward pass (`np.ndarray`):
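The SAM hunks are the same fix once more. Because `reset_experience` is now an ordinary keyword argument, it is straightforward to clear the memory only at episode boundaries while threading the hidden state across calls. A hypothetical loop sketch; `rnn` is assumed to be a SAM built as in the README, and the data is a placeholder:

```python
import torch

# Hypothetical inference loop: clear the memory on the first step only.
(controller_hidden, memory, read_vectors) = (None, None, None)

for step in range(5):
  x = torch.randn(10, 4, 64)  # placeholder batch, (batch, seq, input) with batch_first=True
  output, (controller_hidden, memory, read_vectors) = \
    rnn(x, (controller_hidden, memory, read_vectors), reset_experience=(step == 0))
  print(step, output.shape)
```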
@@ -120,6 +120,7 @@ class DNC(nn.Module):
     if self.gpu_id != -1:
       [x.cuda(self.gpu_id) for x in self.rnns]
       [x.cuda(self.gpu_id) for x in self.memories]
+      self.output.cuda()
 
   def _init_hidden(self, hx, batch_size, reset_experience):
     # create empty hidden states if not provided
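In this constructor the GPU move is done piecewise rather than with a single `self.cuda()`: `self.rnns` and `self.memories` are plain Python lists, which `nn.Module` does not traverse, and `self.output` had previously been left on the CPU; the added line moves it too. For background, a toy sketch of which attributes `nn.Module.cuda()` does and does not reach (not the repo's code):

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
  def __init__(self):
    super().__init__()
    self.layers = [nn.Linear(4, 4)]                     # plain list: invisible to .cuda()
    self.registered = nn.ModuleList([nn.Linear(4, 4)])  # registered: moved automatically
    self.output = nn.Linear(4, 4)                       # module attribute: also registered

m = Toy()
if torch.cuda.is_available():
  m.cuda()
  print(m.layers[0].weight.device)      # cpu  -> must be moved by hand, as in the diff
  print(m.registered[0].weight.device)  # cuda:0
  print(m.output.weight.device)         # cuda:0
```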
@@ -48,23 +48,34 @@ class SparseMemory(nn.Module):
     self.c = (r * self.K) + 1
 
     if self.independent_linears:
-      self.read_query_transform = nn.Linear(self.input_size, w * r)
-      self.write_vector_transform = nn.Linear(self.input_size, w)
-      self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
-      self.write_gate_transform = nn.Linear(self.input_size, 1)
+      if self.gpu_id != -1:
+        self.read_query_transform = nn.Linear(self.input_size, w * r).cuda()
+        self.write_vector_transform = nn.Linear(self.input_size, w).cuda()
+        self.interpolation_gate_transform = nn.Linear(self.input_size, self.c).cuda()
+        self.write_gate_transform = nn.Linear(self.input_size, 1).cuda()
+      else:
+        self.read_query_transform = nn.Linear(self.input_size, w * r)
+        self.write_vector_transform = nn.Linear(self.input_size, w)
+        self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
+        self.write_gate_transform = nn.Linear(self.input_size, 1)
       T.nn.init.orthogonal(self.read_query_transform.weight)
       T.nn.init.orthogonal(self.write_vector_transform.weight)
       T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
       T.nn.init.orthogonal(self.write_gate_transform.weight)
     else:
       self.interface_size = (r * w) + w + self.c + 1
-      self.interface_weights = nn.Linear(self.input_size, self.interface_size)
+      if self.gpu_id != -1:
+        self.interface_weights = nn.Linear(self.input_size, self.interface_size).cuda()
+      else:
+        self.interface_weights = nn.Linear(self.input_size, self.interface_size)
       T.nn.init.orthogonal(self.interface_weights.weight)
 
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
     self.mem_limit_reached = False
+    if self.gpu_id != -1:
+      self.cuda()
 
   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
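The new branches construct each linear transform directly on the GPU when `gpu_id != -1` and finish by calling `self.cuda()` on the whole module. In current PyTorch the same effect is usually obtained by building everything on the CPU and doing a single device move afterwards; a hedged sketch of that alternative with an illustrative module (names and sizes are placeholders, not the repo's API):

```python
import torch
import torch.nn as nn

class TinySparseMemory(nn.Module):
  def __init__(self, input_size=64, w=16, r=4, gpu_id=-1):
    super().__init__()
    # Build on CPU; registered submodules are all reachable by a single .cuda()/.to() call.
    self.read_query_transform = nn.Linear(input_size, w * r)
    self.write_vector_transform = nn.Linear(input_size, w)
    nn.init.orthogonal_(self.read_query_transform.weight)
    nn.init.orthogonal_(self.write_vector_transform.weight)
    if gpu_id != -1:
      self.cuda(gpu_id)  # one move replaces the per-layer .cuda() calls in the diff

mem = TinySparseMemory(gpu_id=0 if torch.cuda.is_available() else -1)
print(next(mem.parameters()).device)
```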