debug time
This commit is contained in:
parent
4700ecbac3
commit
af1a77ca7f
28
dnc/dnc.py
28
dnc/dnc.py
@ -170,6 +170,8 @@ class DNC(nn.Module):
|
||||
outs = [0] * max_length
|
||||
read_vectors = [0] * max_length
|
||||
|
||||
mem_debug = []
|
||||
|
||||
for time in range(max_length):
|
||||
# pass through controller
|
||||
layer_input = input[time]
|
||||
@ -189,14 +191,21 @@ class DNC(nn.Module):
|
||||
# pass through memory
|
||||
if self.share_memory:
|
||||
read_vecs, mhx = self.memories[0](ξ, mhx)
|
||||
if self.debug:
|
||||
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
|
||||
else:
|
||||
read_vecs, mhx = self.memories[layer](ξ, mhx)
|
||||
if self.debug:
|
||||
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
|
||||
read_vectors[time] = read_vecs.view(-1, self.w * self.r)
|
||||
|
||||
# get the final output for this time step
|
||||
outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1)))
|
||||
|
||||
return outs, read_vectors, (chx, mhx)
|
||||
if self.debug:
|
||||
return outs, read_vectors, mem_debug, (chx, mhx)
|
||||
else:
|
||||
return outs, read_vectors, (chx, mhx)
|
||||
|
||||
def forward(self, input, hx=(None, None, None), reset_experience=False):
|
||||
# handle packed data
|
||||
@ -220,7 +229,7 @@ class DNC(nn.Module):
|
||||
outputs = None
|
||||
chxs = []
|
||||
if self.debug:
|
||||
viz = [mem_hidden['memory'][0]] if self.share_memory else [mem_hidden[0]['memory'][0]]
|
||||
viz = []
|
||||
|
||||
read_vectors = [last_read] * max_length
|
||||
# outs = [input[:, x, :] for x in range(max_length)]
|
||||
@ -232,15 +241,14 @@ class DNC(nn.Module):
|
||||
|
||||
m = mem_hidden if self.share_memory else mem_hidden[layer]
|
||||
# pass through controller
|
||||
outs, _, (chx, m) = self._layer_forward(
|
||||
outs,
|
||||
layer,
|
||||
(chx, m)
|
||||
)
|
||||
if self.debug:
|
||||
outs, _, mem_debug, (chx, m) = self._layer_forward(outs,layer,(chx, m))
|
||||
else:
|
||||
outs, _, (chx, m) = self._layer_forward(outs, layer, (chx, m))
|
||||
|
||||
# debug memory
|
||||
if self.debug:
|
||||
viz.append(m['memory'][0])
|
||||
viz.append(mem_debug)
|
||||
|
||||
# store the memory back (per layer or shared)
|
||||
if self.share_memory:
|
||||
@ -258,7 +266,9 @@ class DNC(nn.Module):
|
||||
# outs = [o for o in outs]
|
||||
|
||||
if self.debug:
|
||||
viz = T.cat(viz, 0).transpose(0, 1)
|
||||
viz = np.array(viz)
|
||||
s = list(viz.shape)
|
||||
viz = viz.reshape(s[0]*s[1], s[2]*s[3])
|
||||
|
||||
controller_hidden = chxs
|
||||
|
||||
|
@ -141,7 +141,7 @@ if __name__ == '__main__':
|
||||
target_output = target_output.transpose(0, 1).contiguous()
|
||||
|
||||
output, (chx, mhx, rv), v = rnn(input_data, None)
|
||||
|
||||
# dncs operate batch first
|
||||
output = output.transpose(0, 1)
|
||||
|
||||
loss = criterion((output), target_output)
|
||||
@ -166,13 +166,13 @@ if __name__ == '__main__':
|
||||
last_save_losses = []
|
||||
|
||||
viz.heatmap(
|
||||
v.data.cpu().numpy(),
|
||||
v,
|
||||
opts=dict(
|
||||
xtickstep=10,
|
||||
ytickstep=2,
|
||||
title='Timestep: ' + str(epoch) + ', loss: ' + str(loss),
|
||||
xlabel='mem_slot * layer',
|
||||
ylabel='mem_size'
|
||||
ylabel='layer * time',
|
||||
xlabel='cell_size * mem_size'
|
||||
)
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user