fix porting bugs :D
commit f528a4c120
README.md: 12 changed lines
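All six README hunks fix the same bug in the usage examples: `reset_experience=True` was written inside the hidden-state tuple, where a keyword argument is not valid Python; it belongs to the forward call itself. The corrected DNC examples: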
````diff
@@ -122,7 +122,7 @@ rnn = DNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 
````
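Worth spelling out why the old line was a bug and not just style: a keyword argument may appear only in a function call, so the tuple display `(controller_hidden, memory, read_vectors, reset_experience=True)` fails to parse at all. A minimal sketch of the corrected usage; the constructor arguments are illustrative, modeled on the README, and `gpu_id=-1` keeps everything on the CPU so the sketch runs without CUDA:

```python
import torch
from dnc import DNC

# Illustrative hyperparameters in the style of the README's examples.
rnn = DNC(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=4,
  batch_first=True,
  gpu_id=-1  # -1 means CPU in this repo's convention
)

# No prior state: each component of the hidden-state tuple starts as None.
(controller_hidden, memory, read_vectors) = (None, None, None)

# reset_experience belongs to the forward call, outside the tuple.
output, (controller_hidden, memory, read_vectors) = \
  rnn(torch.randn(10, 4, 64),
      (controller_hidden, memory, read_vectors),
      reset_experience=True)
```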
````diff
@@ -150,7 +150,7 @@ rnn = DNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 Memory vectors returned by forward pass (`np.ndarray`):
````
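The four hunks that follow apply the identical correction to the SDNC and SAM usage examples, in both their plain and debug variants.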
````diff
@@ -223,7 +223,7 @@ rnn = SDNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 
````
````diff
@@ -253,7 +253,7 @@ rnn = SDNC(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 Memory vectors returned by forward pass (`np.ndarray`):
````
````diff
@@ -327,7 +327,7 @@ rnn = SAM(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors) = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 
````
````diff
@@ -356,7 +356,7 @@ rnn = SAM(
 (controller_hidden, memory, read_vectors) = (None, None, None)
 
 output, (controller_hidden, memory, read_vectors), debug_memory = \
-  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors, reset_experience=True))
+  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors), reset_experience=True)
 ```
 
 Memory vectors returned by forward pass (`np.ndarray`):
````
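The remaining two hunks fix porting bugs in the code itself. In `class DNC`, the GPU branch moved the controller RNNs and the memories but never the output layer, leaving `self.output` on the CPU and causing a device mismatch on the first forward pass when `gpu_id != -1`: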
```diff
@@ -120,6 +120,7 @@ class DNC(nn.Module):
     if self.gpu_id != -1:
       [x.cuda(self.gpu_id) for x in self.rnns]
       [x.cuda(self.gpu_id) for x in self.memories]
+      self.output.cuda()
 
   def _init_hidden(self, hx, batch_size, reset_experience):
     # create empty hidden states if not provided
```
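The root cause is that `self.rnns` and `self.memories` appear to be plain Python lists: `nn.Module` only tracks registered children, so a bare `.cuda()` on the parent moves nothing inside those lists, and every member, plus any stray attribute like `self.output`, must be moved by hand. A minimal illustration of that pitfall, with hypothetical names not taken from this repo:

```python
import torch.nn as nn

class Stack(nn.Module):
  def __init__(self, depth=3, width=16):
    super().__init__()
    # Plain list: these layers are NOT registered, so Stack().cuda() and
    # Stack().parameters() both skip them -- the bug class fixed above.
    self.plain = [nn.Linear(width, width) for _ in range(depth)]
    # nn.ModuleList registers each layer; .cuda()/.to() then move them all.
    self.registered = nn.ModuleList(nn.Linear(width, width) for _ in range(depth))

model = Stack()
# Only the registered layers contribute parameters: 3 weights + 3 biases.
print(sum(1 for _ in model.parameters()))  # -> 6
```

One small caveat: `self.output.cuda()` takes no device argument, so it lands on the current default GPU rather than on `self.gpu_id`; harmless on single-GPU machines, surprising on multi-GPU ones.

The last hunk gives `SparseMemory` the same treatment: the interface transforms are placed on the GPU at construction time, with a CPU fallback, and the whole module is moved once more via `self.cuda()` at the end of the constructor: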
```diff
@@ -48,23 +48,34 @@ class SparseMemory(nn.Module):
     self.c = (r * self.K) + 1
 
     if self.independent_linears:
-      self.read_query_transform = nn.Linear(self.input_size, w * r)
-      self.write_vector_transform = nn.Linear(self.input_size, w)
-      self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
-      self.write_gate_transform = nn.Linear(self.input_size, 1)
+      if self.gpu_id != -1:
+        self.read_query_transform = nn.Linear(self.input_size, w * r).cuda()
+        self.write_vector_transform = nn.Linear(self.input_size, w).cuda()
+        self.interpolation_gate_transform = nn.Linear(self.input_size, self.c).cuda()
+        self.write_gate_transform = nn.Linear(self.input_size, 1).cuda()
+      else:
+        self.read_query_transform = nn.Linear(self.input_size, w * r)
+        self.write_vector_transform = nn.Linear(self.input_size, w)
+        self.interpolation_gate_transform = nn.Linear(self.input_size, self.c)
+        self.write_gate_transform = nn.Linear(self.input_size, 1)
       T.nn.init.orthogonal_(self.read_query_transform.weight)
       T.nn.init.orthogonal_(self.write_vector_transform.weight)
       T.nn.init.orthogonal_(self.interpolation_gate_transform.weight)
       T.nn.init.orthogonal_(self.write_gate_transform.weight)
     else:
       self.interface_size = (r * w) + w + self.c + 1
-      self.interface_weights = nn.Linear(self.input_size, self.interface_size)
+      if self.gpu_id != -1:
+        self.interface_weights = nn.Linear(self.input_size, self.interface_size).cuda()
+      else:
+        self.interface_weights = nn.Linear(self.input_size, self.interface_size)
       T.nn.init.orthogonal_(self.interface_weights.weight)
 
     self.I = cuda(1 - T.eye(self.c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
     self.mem_limit_reached = False
+    if self.gpu_id != -1:
+      self.cuda()
 
   def rebuild_indexes(self, hidden, erase=False):
     b = hidden['memory'].size(0)
```
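Two things worth noting, judging only from the diff: since `nn.Linear` attributes are registered submodules, the trailing `self.cuda()` should move all of them anyway, which makes the per-layer `.cuda()` calls belt-and-braces; and the duplicated if/else branches could be collapsed with the device-agnostic construct-then-move pattern, sketched here with illustrative sizes:

```python
import torch
import torch.nn as nn

gpu_id = -1  # -1 means CPU, matching the repo's convention
device = torch.device('cpu' if gpu_id == -1 else 'cuda:%d' % gpu_id)

# One code path for CPU and GPU: construct the layer, then move it.
read_query_transform = nn.Linear(64, 16 * 4).to(device)
write_gate_transform = nn.Linear(64, 1).to(device)
```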