Debug: capture memory at every time step

This commit is contained in:
ixaxaar 2017-11-01 15:04:30 +05:30
parent 4700ecbac3
commit af1a77ca7f
2 changed files with 23 additions and 13 deletions

View File

@ -170,6 +170,8 @@ class DNC(nn.Module):
outs = [0] * max_length outs = [0] * max_length
read_vectors = [0] * max_length read_vectors = [0] * max_length
mem_debug = []
for time in range(max_length): for time in range(max_length):
# pass through controller # pass through controller
layer_input = input[time] layer_input = input[time]
@ -189,14 +191,21 @@ class DNC(nn.Module):
# pass through memory # pass through memory
if self.share_memory: if self.share_memory:
read_vecs, mhx = self.memories[0](ξ, mhx) read_vecs, mhx = self.memories[0](ξ, mhx)
if self.debug:
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
else: else:
read_vecs, mhx = self.memories[layer](ξ, mhx) read_vecs, mhx = self.memories[layer](ξ, mhx)
if self.debug:
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
read_vectors[time] = read_vecs.view(-1, self.w * self.r) read_vectors[time] = read_vecs.view(-1, self.w * self.r)
# get the final output for this time step # get the final output for this time step
outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1))) outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1)))
return outs, read_vectors, (chx, mhx) if self.debug:
return outs, read_vectors, mem_debug, (chx, mhx)
else:
return outs, read_vectors, (chx, mhx)
def forward(self, input, hx=(None, None, None), reset_experience=False): def forward(self, input, hx=(None, None, None), reset_experience=False):
# handle packed data # handle packed data
@ -220,7 +229,7 @@ class DNC(nn.Module):
outputs = None outputs = None
chxs = [] chxs = []
if self.debug: if self.debug:
viz = [mem_hidden['memory'][0]] if self.share_memory else [mem_hidden[0]['memory'][0]] viz = []
read_vectors = [last_read] * max_length read_vectors = [last_read] * max_length
# outs = [input[:, x, :] for x in range(max_length)] # outs = [input[:, x, :] for x in range(max_length)]
@ -232,15 +241,14 @@ class DNC(nn.Module):
m = mem_hidden if self.share_memory else mem_hidden[layer] m = mem_hidden if self.share_memory else mem_hidden[layer]
# pass through controller # pass through controller
outs, _, (chx, m) = self._layer_forward( if self.debug:
outs, outs, _, mem_debug, (chx, m) = self._layer_forward(outs,layer,(chx, m))
layer, else:
(chx, m) outs, _, (chx, m) = self._layer_forward(outs, layer, (chx, m))
)
# debug memory # debug memory
if self.debug: if self.debug:
viz.append(m['memory'][0]) viz.append(mem_debug)
# store the memory back (per layer or shared) # store the memory back (per layer or shared)
if self.share_memory: if self.share_memory:
@ -258,7 +266,9 @@ class DNC(nn.Module):
# outs = [o for o in outs] # outs = [o for o in outs]
if self.debug: if self.debug:
viz = T.cat(viz, 0).transpose(0, 1) viz = np.array(viz)
s = list(viz.shape)
viz = viz.reshape(s[0]*s[1], s[2]*s[3])
controller_hidden = chxs controller_hidden = chxs

View File

@ -141,7 +141,7 @@ if __name__ == '__main__':
target_output = target_output.transpose(0, 1).contiguous() target_output = target_output.transpose(0, 1).contiguous()
output, (chx, mhx, rv), v = rnn(input_data, None) output, (chx, mhx, rv), v = rnn(input_data, None)
# DNCs operate batch-first
output = output.transpose(0, 1) output = output.transpose(0, 1)
loss = criterion((output), target_output) loss = criterion((output), target_output)
@ -166,13 +166,13 @@ if __name__ == '__main__':
last_save_losses = [] last_save_losses = []
viz.heatmap( viz.heatmap(
v.data.cpu().numpy(), v,
opts=dict( opts=dict(
xtickstep=10, xtickstep=10,
ytickstep=2, ytickstep=2,
title='Timestep: ' + str(epoch) + ', loss: ' + str(loss), title='Timestep: ' + str(epoch) + ', loss: ' + str(loss),
xlabel='mem_slot * layer', ylabel='layer * time',
ylabel='mem_size' xlabel='cell_size * mem_size'
) )
) )