Stack DNCs properly
parent 5158e85f5d
commit f7dc1b5aab

dnc/dnc.py +144 −144
@@ -123,8 +123,8 @@ class DNC(nn.Module):
       setattr(self, 'rnn_layer_memory_shared', self.memories[0])

     # final output layer
-    self.output_weights = nn.Linear(self.output_size, self.output_size)
-    self.mem_out = nn.Linear(self.nn_output_size, self.input_size)
+    self.output_weights = nn.Linear(self.nn_input_size, self.input_size)
+    self.mem_out = nn.Linear(self.hidden_size, self.input_size)
     self.dropout_layer = nn.Dropout(self.dropout)

     if self.gpu_id != -1:
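
The rewired output path can be sanity-checked in isolation: `mem_out` now projects the controller's hidden state down to `input_size`, and `output_weights` maps that projection concatenated with the flattened read vectors (so `nn_input_size = input_size + r * w`) back to `input_size`. A minimal shape-check sketch, with all sizes made up for illustration:

    import torch as T
    import torch.nn as nn

    # illustrative sizes only, not taken from this commit
    input_size, hidden_size, r, w = 10, 64, 2, 8
    nn_input_size = input_size + r * w

    mem_out = nn.Linear(hidden_size, input_size)           # hidden state -> input-sized output
    output_weights = nn.Linear(nn_input_size, input_size)  # (output ++ reads) -> final output

    hidden = T.randn(5, hidden_size)   # a batch of 5 controller hidden states
    reads = T.randn(5, r * w)          # 5 flattened read vectors
    out = output_weights(T.cat([mem_out(hidden), reads], 1))
    assert out.size() == T.Size([5, input_size])
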
@@ -165,18 +165,26 @@ class DNC(nn.Module):
     return chx, mhx, last_read

   def _debug(self, mhx, debug_obj):
-    debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
-    debug_obj['link_matrix'].append(mhx['link_matrix'][0][0].data.cpu().numpy())
-    debug_obj['precedence'].append(mhx['precedence'][0].data.cpu().numpy())
-    debug_obj['read_weights'].append(mhx['read_weights'][0].data.cpu().numpy())
-    debug_obj['write_weights'].append(mhx['write_weights'][0].data.cpu().numpy())
-    debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
+    if not debug_obj:
+      debug_obj = {
+        'memory': [],
+        'link_matrix': [],
+        'precedence': [],
+        'read_weights': [],
+        'write_weights': [],
+        'usage_vector': [],
+      }
+
+    debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
+    debug_obj['link_matrix'].append(mhx['link_matrix'][0][0].data.cpu().numpy())
+    debug_obj['precedence'].append(mhx['precedence'][0].data.cpu().numpy())
+    debug_obj['read_weights'].append(mhx['read_weights'][0].data.cpu().numpy())
+    debug_obj['write_weights'].append(mhx['write_weights'][0].data.cpu().numpy())
+    debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
+    return debug_obj

   def _layer_forward(self, input, layer, hx=(None, None)):
     (chx, mhx) = hx
     max_length = len(input)
     outs = [0] * max_length
     read_vectors = [0] * max_length

     if self.debug:
       mem_debug = {
@@ -188,40 +196,32 @@ class DNC(nn.Module):
         "usage_vector": [],
       }

-    for time in range(max_length):
-      # pass through controller
-      layer_input = input[time]
-      hchx = []
-      for hlayer in range(self.num_hidden_layers):
-        h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
-        layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
-        hchx.append(h)
-      chx = hchx
+    layer_input = input
+    hchx = []
+    for hlayer in range(self.num_hidden_layers):
+      h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
+      layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
+      hchx.append(h)
+    chx = hchx

-      # the interface vector
-      ξ = layer_input
-      # the output
-      out = self.output_weights(layer_input)
+    # the interface vector
+    ξ = layer_input
+    # the output
+    output = self.dropout_layer(self.mem_out(layer_input))

-      # pass through memory
-      if self.share_memory:
-        read_vecs, mhx = self.memories[0](ξ, mhx)
-        if self.debug:
-          self._debug(mhx, mem_debug)
-      else:
-        read_vecs, mhx = self.memories[layer](ξ, mhx)
-        if self.debug:
-          self._debug(mhx, mem_debug)
-      read_vectors[time] = read_vecs.view(-1, self.w * self.r)
-
-      # get the final output for this time step
-      outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1)))
+    # pass through memory
+    if self.share_memory:
+      read_vecs, mhx = self.memories[0](ξ, mhx)
+    else:
+      read_vecs, mhx = self.memories[layer](ξ, mhx)
+    # the read vectors
+    read_vectors = read_vecs.view(-1, self.w * self.r)

     if self.debug:
-      return outs, read_vectors, mem_debug, (chx, mhx)
+      return output, read_vectors, mem_debug, (chx, mhx)
     else:
-      return outs, read_vectors, (chx, mhx)
+      return output, read_vectors, (chx, mhx)

   def forward(self, input, hx=(None, None, None), reset_experience=False):
     # handle packed data
@@ -242,58 +242,46 @@ class DNC(nn.Module):
     controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)

     # batched forward pass per element / word / etc
     outputs = None
     chxs = []
     if self.debug:
-      viz = {}
+      viz = None

-    read_vectors = [last_read] * max_length
+    # read_vectors = [last_read] * max_length
     # outs = [input[:, x, :] for x in range(max_length)]
-    outs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
+    inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
+    outs = [None] * max_length

-    for layer in range(self.num_layers):
-      # this layer's hidden states
-      chx = controller_hidden[layer]
-
-      m = mem_hidden if self.share_memory else mem_hidden[layer]
-      # pass through controller
-      if self.debug:
-        outs, _, mem_debug, (chx, m) = self._layer_forward(outs,layer,(chx, m))
-      else:
-        outs, _, (chx, m) = self._layer_forward(outs, layer, (chx, m))
-
-      # debug memory
-      if self.debug:
-        if viz == {}:
-          viz = mem_debug
-        else:
-          viz["memory"] += mem_debug["memory"]
-          viz["link_matrix"] += mem_debug["link_matrix"]
-          viz["precedence"] += mem_debug["precedence"]
-          viz["read_weights"] += mem_debug["read_weights"]
-          viz["write_weights"] += mem_debug["write_weights"]
-          viz["usage_vector"] += mem_debug["usage_vector"]
-
-      # store the memory back (per layer or shared)
-      if self.share_memory:
-        mem_hidden = m
-      else:
-        mem_hidden[layer] = m
-      chxs.append(chx)
-
-      if layer == self.num_layers - 1:
-        # final outputs
-        outputs = T.stack(outs, 1)
-      else:
-        # the controller output + read vectors go into next layer
-        outs = [T.cat([o, r], 1) for o, r in zip(outs, read_vectors)]
-        # outs = [o for o in outs]
+    for time in range(max_length):
+      for layer in range(self.num_layers):
+        # this layer's hidden states
+        chx = controller_hidden[layer]
+        m = mem_hidden if self.share_memory else mem_hidden[layer]
+        # pass through controller
+        if self.debug:
+          outs[time], read_vectors, mem_debug, (chx, m) = self._layer_forward(inputs[time],layer,(chx, m))
+        else:
+          outs[time], read_vectors, (chx, m) = self._layer_forward(inputs[time],layer,(chx, m))
+
+        # debug memory
+        if self.debug:
+          viz = self._debug(m, viz)
+
+        # store the memory back (per layer or shared)
+        if self.share_memory:
+          mem_hidden = m
+        else:
+          mem_hidden[layer] = m
+        controller_hidden[layer] = chx
+
+        outs[time] = (T.cat([outs[time], read_vectors], 1))
+        inputs[time] = outs[time]

     if self.debug:
       viz = { k: np.array(v) for k,v in viz.items() }
       viz = { k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k,v in viz.items() }

-    controller_hidden = chxs
+    inputs = [ self.output_weights(i) for i in inputs ]
+    outputs = T.stack(inputs, 1)

     if not self.batch_first:
       outputs = outputs.transpose(0, 1)
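
Net effect of the dnc/dnc.py changes: the forward pass is now time-major, so at each time step the input (plus the previous read vectors) flows through all stacked layers before the next step is processed, and each layer's output concatenated with its fresh read vectors feeds the layer above. A minimal sketch of that control flow, where `layer_forward` is a hypothetical stand-in for `DNC._layer_forward`:

    import torch as T

    def stack_layers(inputs, num_layers, layer_forward):
      # inputs: list of (batch, features) tensors, one per time step
      outs = [None] * len(inputs)
      for time in range(len(inputs)):
        for layer in range(num_layers):
          # each layer consumes the previous layer's output + read vectors
          output, read_vectors = layer_forward(inputs[time], layer)
          outs[time] = T.cat([output, read_vectors], 1)
          inputs[time] = outs[time]
      return T.stack(inputs, 1)  # (batch, time, features)

    # toy check: a "layer" that splits its input back into output and reads
    f = lambda x, layer: (x[:, :4], x[:, 4:])
    seq = [T.randn(2, 6) for _ in range(3)]
    assert stack_layers(seq, 2, f).size() == T.Size([2, 3, 6])
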
@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([10,1])
+  assert rv.size() == T.Size([1])


 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([10,51])
+  assert rv.size() == T.Size([51])
@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([10,1])
+  assert rv.size() == T.Size([1])


 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[0][0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([10,51])
+  assert rv.size() == T.Size([51])
@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([10,1])
+  assert rv.size() == T.Size([1])


 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([10,51])
+  assert rv.size() == T.Size([51])
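
For reference, a usage sketch consistent with the shapes these tests assert; the constructor arguments below (`num_layers=2`, `num_hidden_layers=3`, `nr_cells=12`, `cell_size=17`, `read_heads=3`) are inferred from the asserted sizes and the library's README, not taken from this diff:

    import torch as T
    from dnc import DNC

    # memory asserted as (10, 12, 17) -> nr_cells=12, cell_size=17;
    # read vectors asserted as 51 = 3 * 17 wide -> read_heads=3
    rnn = DNC(
      input_size=100,
      hidden_size=100,
      rnn_type='lstm',
      num_layers=2,
      num_hidden_layers=3,
      nr_cells=12,
      cell_size=17,
      read_heads=3,
    )

    input = T.randn(27, 10, 100)  # (seq_len, batch, input_size), batch_first=False
    output, (chx, mhx, rv) = rnn(input, (None, None, None))
    print(mhx['memory'].size())   # T.Size([10, 12, 17])
    print(rv.size())              # T.Size([51]) after this change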