debug time

This commit is contained in:
ixaxaar 2017-11-01 15:04:30 +05:30
parent 4700ecbac3
commit af1a77ca7f
2 changed files with 23 additions and 13 deletions

View File

@ -170,6 +170,8 @@ class DNC(nn.Module):
outs = [0] * max_length
read_vectors = [0] * max_length
mem_debug = []
for time in range(max_length):
# pass through controller
layer_input = input[time]
@ -189,14 +191,21 @@ class DNC(nn.Module):
# pass through memory
if self.share_memory:
read_vecs, mhx = self.memories[0](ξ, mhx)
if self.debug:
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
else:
read_vecs, mhx = self.memories[layer](ξ, mhx)
if self.debug:
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
read_vectors[time] = read_vecs.view(-1, self.w * self.r)
# get the final output for this time step
outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1)))
return outs, read_vectors, (chx, mhx)
if self.debug:
return outs, read_vectors, mem_debug, (chx, mhx)
else:
return outs, read_vectors, (chx, mhx)
def forward(self, input, hx=(None, None, None), reset_experience=False):
# handle packed data
@ -220,7 +229,7 @@ class DNC(nn.Module):
outputs = None
chxs = []
if self.debug:
viz = [mem_hidden['memory'][0]] if self.share_memory else [mem_hidden[0]['memory'][0]]
viz = []
read_vectors = [last_read] * max_length
# outs = [input[:, x, :] for x in range(max_length)]
@ -232,15 +241,14 @@ class DNC(nn.Module):
m = mem_hidden if self.share_memory else mem_hidden[layer]
# pass through controller
outs, _, (chx, m) = self._layer_forward(
outs,
layer,
(chx, m)
)
if self.debug:
outs, _, mem_debug, (chx, m) = self._layer_forward(outs,layer,(chx, m))
else:
outs, _, (chx, m) = self._layer_forward(outs, layer, (chx, m))
# debug memory
if self.debug:
viz.append(m['memory'][0])
viz.append(mem_debug)
# store the memory back (per layer or shared)
if self.share_memory:
@ -258,7 +266,9 @@ class DNC(nn.Module):
# outs = [o for o in outs]
if self.debug:
viz = T.cat(viz, 0).transpose(0, 1)
viz = np.array(viz)
s = list(viz.shape)
viz = viz.reshape(s[0]*s[1], s[2]*s[3])
controller_hidden = chxs

View File

@ -141,7 +141,7 @@ if __name__ == '__main__':
target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None)
# DNCs operate batch-first
output = output.transpose(0, 1)
loss = criterion((output), target_output)
@ -166,13 +166,13 @@ if __name__ == '__main__':
last_save_losses = []
viz.heatmap(
v.data.cpu().numpy(),
v,
opts=dict(
xtickstep=10,
ytickstep=2,
title='Timestep: ' + str(epoch) + ', loss: ' + str(loss),
xlabel='mem_slot * layer',
ylabel='mem_size'
ylabel='layer * time',
xlabel='cell_size * mem_size'
)
)