Stack DNCs properly

ixaxaar 2017-11-09 00:33:24 +05:30
parent 5158e85f5d
commit f7dc1b5aab
4 changed files with 72 additions and 84 deletions

View File

@@ -123,8 +123,8 @@ class DNC(nn.Module):
       setattr(self, 'rnn_layer_memory_shared', self.memories[0])
     # final output layer
-    self.output_weights = nn.Linear(self.output_size, self.output_size)
-    self.mem_out = nn.Linear(self.nn_output_size, self.input_size)
+    self.output_weights = nn.Linear(self.nn_input_size, self.input_size)
+    self.mem_out = nn.Linear(self.hidden_size, self.input_size)
     self.dropout_layer = nn.Dropout(self.dropout)
     if self.gpu_id != -1:
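
With this change each DNC layer produces an output the same size as its input: mem_out now maps the controller's hidden state (hidden_size) down to input_size, and output_weights maps the concatenation of that output with the flattened read vectors (nn_input_size) back to input_size, which is what allows the layers to be stacked. A rough standalone sketch of the shape bookkeeping (hypothetical sizes, current PyTorch; nn_input_size = input_size + read_heads * cell_size is an assumption based on how the layers are wired below):

    import torch as T
    import torch.nn as nn

    input_size, hidden_size = 64, 100
    r, w = 4, 16                                   # read heads, cell width (assumed values)
    nn_input_size = input_size + r * w

    mem_out = nn.Linear(hidden_size, input_size)           # controller hidden -> input-sized output
    output_weights = nn.Linear(nn_input_size, input_size)  # [layer output | read vectors] -> input-sized

    h = T.randn(10, hidden_size)                 # controller hidden state for a batch of 10
    read_vectors = T.randn(10, r * w)            # flattened read vectors
    next_layer_input = T.cat([mem_out(h), read_vectors], 1)   # (10, nn_input_size)
    final = output_weights(next_layer_input)                   # (10, input_size)
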
@@ -165,18 +165,26 @@ class DNC(nn.Module):
     return chx, mhx, last_read
   def _debug(self, mhx, debug_obj):
-    debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
-    debug_obj['link_matrix'].append(mhx['link_matrix'][0][0].data.cpu().numpy())
-    debug_obj['precedence'].append(mhx['precedence'][0].data.cpu().numpy())
-    debug_obj['read_weights'].append(mhx['read_weights'][0].data.cpu().numpy())
-    debug_obj['write_weights'].append(mhx['write_weights'][0].data.cpu().numpy())
-    debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
+    if not debug_obj:
+      debug_obj = {
+        'memory': [],
+        'link_matrix': [],
+        'precedence': [],
+        'read_weights': [],
+        'write_weights': [],
+        'usage_vector': [],
+      }
+    debug_obj['memory'].append(mhx['memory'][0].data.cpu().numpy())
+    debug_obj['link_matrix'].append(mhx['link_matrix'][0][0].data.cpu().numpy())
+    debug_obj['precedence'].append(mhx['precedence'][0].data.cpu().numpy())
+    debug_obj['read_weights'].append(mhx['read_weights'][0].data.cpu().numpy())
+    debug_obj['write_weights'].append(mhx['write_weights'][0].data.cpu().numpy())
+    debug_obj['usage_vector'].append(mhx['usage_vector'][0].unsqueeze(0).data.cpu().numpy())
+    return debug_obj
   def _layer_forward(self, input, layer, hx=(None, None)):
     (chx, mhx) = hx
-    max_length = len(input)
-    outs = [0] * max_length
-    read_vectors = [0] * max_length
     if self.debug:
       mem_debug = {
@@ -188,40 +196,32 @@ class DNC(nn.Module):
         "usage_vector": [],
       }
-    for time in range(max_length):
-      # pass through controller
-      layer_input = input[time]
-      hchx = []
+    layer_input = input
+    hchx = []
-      for hlayer in range(self.num_hidden_layers):
-        h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
-        layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
-        hchx.append(h)
-      chx = hchx
+    for hlayer in range(self.num_hidden_layers):
+      h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
+      layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
+      hchx.append(h)
+    chx = hchx
+    # the interface vector
+    ξ = layer_input
+    # the output
+    output = self.dropout_layer(self.mem_out(layer_input))
-      # the interface vector
-      ξ = layer_input
-      # the output
-      out = self.output_weights(layer_input)
+    # pass through memory
+    if self.share_memory:
+      read_vecs, mhx = self.memories[0](ξ, mhx)
+    else:
+      read_vecs, mhx = self.memories[layer](ξ, mhx)
-      # pass through memory
-      if self.share_memory:
-        read_vecs, mhx = self.memories[0](ξ, mhx)
-        if self.debug:
-          self._debug(mhx, mem_debug)
-      else:
-        read_vecs, mhx = self.memories[layer](ξ, mhx)
-        if self.debug:
-          self._debug(mhx, mem_debug)
-      read_vectors[time] = read_vecs.view(-1, self.w * self.r)
-      # get the final output for this time step
-      outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1)))
+    # the read vectors
+    read_vectors = read_vecs.view(-1, self.w * self.r)
     if self.debug:
-      return outs, read_vectors, mem_debug, (chx, mhx)
+      return output, read_vectors, mem_debug, (chx, mhx)
     else:
-      return outs, read_vectors, (chx, mhx)
+      return output, read_vectors, (chx, mhx)
   def forward(self, input, hx=(None, None, None), reset_experience=False):
     # handle packed data
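
After this rework, _layer_forward no longer loops over time: it takes one timestep's input for one layer, runs it through that layer's stack of controller cells, uses the resulting activation both as the memory interface vector ξ and (through mem_out) as the layer output, and returns the flattened read vectors alongside the new hidden states. A simplified, hypothetical restatement of that per-step contract (a standalone helper, not the file's exact code):

    def layer_forward_step(x, layer, chx, mhx, rnns, memory, mem_out, dropout):
        """One DNC layer, one timestep: returns (output, read_vectors, new hidden states)."""
        layer_input, new_chx = x, []
        for hlayer, cell in enumerate(rnns[layer]):            # stacked controller cells
            h = cell(layer_input, chx[hlayer])
            layer_input = h[0] if isinstance(h, tuple) else h  # LSTMCell returns (h, c)
            new_chx.append(h)
        xi = layer_input                                       # interface vector for the memory
        output = dropout(mem_out(layer_input))                 # layer output, sized like the input
        read_vecs, mhx = memory(xi, mhx)                       # one memory access
        read_vectors = read_vecs.view(read_vecs.size(0), -1)   # flatten the read vectors
        return output, read_vectors, (new_chx, mhx)
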
@@ -242,58 +242,46 @@ class DNC(nn.Module):
     controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)
     # batched forward pass per element / word / etc
     outputs = None
-    chxs = []
     if self.debug:
-      viz = {}
+      viz = None
-    read_vectors = [last_read] * max_length
+    # read_vectors = [last_read] * max_length
     # outs = [input[:, x, :] for x in range(max_length)]
-    outs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
+    inputs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]
+    outs = [None] * max_length
-    for layer in range(self.num_layers):
-      # this layer's hidden states
-      chx = controller_hidden[layer]
-      m = mem_hidden if self.share_memory else mem_hidden[layer]
-      # pass through controller
-      if self.debug:
-        outs, _, mem_debug, (chx, m) = self._layer_forward(outs,layer,(chx, m))
-      else:
-        outs, _, (chx, m) = self._layer_forward(outs, layer, (chx, m))
-      # debug memory
-      if self.debug:
-        if viz == {}:
-          viz = mem_debug
+    for time in range(max_length):
+      for layer in range(self.num_layers):
+        # this layer's hidden states
+        chx = controller_hidden[layer]
+        m = mem_hidden if self.share_memory else mem_hidden[layer]
+        # pass through controller
+        if self.debug:
+          outs[time], read_vectors, mem_debug, (chx, m) = self._layer_forward(inputs[time],layer,(chx, m))
         else:
-          viz["memory"] += mem_debug["memory"]
-          viz["link_matrix"] += mem_debug["link_matrix"]
-          viz["precedence"] += mem_debug["precedence"]
-          viz["read_weights"] += mem_debug["read_weights"]
-          viz["write_weights"] += mem_debug["write_weights"]
-          viz["usage_vector"] += mem_debug["usage_vector"]
+          outs[time], read_vectors, (chx, m) = self._layer_forward(inputs[time],layer,(chx, m))
-      # store the memory back (per layer or shared)
-      if self.share_memory:
-        mem_hidden = m
-      else:
-        mem_hidden[layer] = m
-      chxs.append(chx)
+        # debug memory
+        if self.debug:
+          viz = self._debug(m, viz)
+        # store the memory back (per layer or shared)
+        if self.share_memory:
+          mem_hidden = m
+        else:
+          mem_hidden[layer] = m
+        controller_hidden[layer] = chx
-      if layer == self.num_layers - 1:
-        # final outputs
-        outputs = T.stack(outs, 1)
-      else:
-        # the controller output + read vectors go into next layer
-        outs = [T.cat([o, r], 1) for o, r in zip(outs, read_vectors)]
-        # outs = [o for o in outs]
+        outs[time] = (T.cat([outs[time], read_vectors], 1))
+        inputs[time] = outs[time]
     if self.debug:
       viz = { k: np.array(v) for k,v in viz.items() }
       viz = { k: v.reshape(v.shape[0], v.shape[1] * v.shape[2]) for k,v in viz.items() }
-    controller_hidden = chxs
+    inputs = [ self.output_weights(i) for i in inputs ]
+    outputs = T.stack(inputs, 1)
     if not self.batch_first:
       outputs = outputs.transpose(0, 1)
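
The net effect in forward is a time-major pass: for every timestep the input is pushed through all stacked DNC layers in turn, each layer's output is concatenated with its read vectors before being handed to the next layer, the debug accounting now goes through self._debug(m, viz) (which lazily creates the dict on first use), and only after both loops are the per-timestep vectors projected by output_weights and stacked. A schematic sketch of that control flow (hypothetical layer callables standing in for controller + memory, not the file's exact code):

    import torch as T

    def stacked_dnc_forward(inputs, layers, output_weights):
        """inputs: per-timestep tensors, each already concatenated with the last read vectors."""
        outs = [None] * len(inputs)
        for time, x in enumerate(inputs):
            for layer in layers:                   # same timestep through every DNC layer
                out, read_vectors = layer(x)       # controller + one memory access for this layer
                x = T.cat([out, read_vectors], 1)  # output + reads feed the next layer
            outs[time] = x                         # what the deepest layer produced at this step
        return T.stack([output_weights(o) for o in outs], 1)   # (batch, time, input_size)
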

View File

@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([10,1])
+  assert rv.size() == T.Size([1])
 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([10,51])
+  assert rv.size() == T.Size([51])
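
The read-vector value the tests receive is now a flat vector of length r·w instead of a (batch, r·w) matrix: with one read head over width-1 cells that is T.Size([1]), and with 3 heads over width-17 cells it is T.Size([51]). The same update is applied in the two test files below. One way such a shape arises, as a small standalone check (hypothetical values, plain PyTorch):

    import torch as T

    batch, r, w = 10, 3, 17
    read_vectors = T.randn(batch, r * w)            # per-batch flattened read vectors
    assert read_vectors.size() == T.Size([10, 51])  # the old expectation
    assert read_vectors[-1].size() == T.Size([51])  # one row, i.e. the new expectation's shape
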

View File

@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([10,1])
+  assert rv.size() == T.Size([1])
 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[0][0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([10,51])
+  assert rv.size() == T.Size([51])

View File

@@ -74,7 +74,7 @@ def test_rnn_1():
   assert target_output.size() == T.Size([21, 10, 100])
   assert chx[0][0].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,1,1])
-  assert rv.size() == T.Size([10,1])
+  assert rv.size() == T.Size([1])
 def test_rnn_n():
@@ -130,4 +130,4 @@ def test_rnn_n():
   assert target_output.size() == T.Size([27, 10, 100])
   assert chx[1][2].size() == T.Size([10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
-  assert rv.size() == T.Size([10,51])
+  assert rv.size() == T.Size([51])
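
For context, a stacked DNC exercised the way these tests do would be set up roughly as follows. This is a hedged sketch: the constructor keywords are assumptions based on the attribute names used in dnc.py and the shapes asserted above, and the unpacking of forward's return mirrors the test variables target_output, chx, mhx, rv:

    import torch as T
    from dnc import DNC   # assumed import path

    rnn = DNC(
        input_size=100,
        hidden_size=100,
        num_layers=3,        # stacked DNC layers, the subject of this commit
        nr_cells=12,         # assumed keyword for memory slots
        cell_size=17,        # assumed keyword for cell width (w)
        read_heads=3,        # assumed keyword for read heads (r)
        gpu_id=-1,           # CPU
    )

    x = T.randn(10, 27, 100)                  # (batch, time, input_size) with batch_first
    output, (chx, mhx, rv) = rnn(x, reset_experience=True)
    # output: (batch, time, input_size); rv: the flattened read vectors of the final step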