debug time
This commit is contained in:
parent
4700ecbac3
commit
af1a77ca7f
28
dnc/dnc.py
28
dnc/dnc.py
@ -170,6 +170,8 @@ class DNC(nn.Module):
|
|||||||
outs = [0] * max_length
|
outs = [0] * max_length
|
||||||
read_vectors = [0] * max_length
|
read_vectors = [0] * max_length
|
||||||
|
|
||||||
|
mem_debug = []
|
||||||
|
|
||||||
for time in range(max_length):
|
for time in range(max_length):
|
||||||
# pass through controller
|
# pass through controller
|
||||||
layer_input = input[time]
|
layer_input = input[time]
|
||||||
@ -189,14 +191,21 @@ class DNC(nn.Module):
|
|||||||
# pass through memory
|
# pass through memory
|
||||||
if self.share_memory:
|
if self.share_memory:
|
||||||
read_vecs, mhx = self.memories[0](ξ, mhx)
|
read_vecs, mhx = self.memories[0](ξ, mhx)
|
||||||
|
if self.debug:
|
||||||
|
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
|
||||||
else:
|
else:
|
||||||
read_vecs, mhx = self.memories[layer](ξ, mhx)
|
read_vecs, mhx = self.memories[layer](ξ, mhx)
|
||||||
|
if self.debug:
|
||||||
|
mem_debug.append(mhx['memory'][0].data.cpu().numpy())
|
||||||
read_vectors[time] = read_vecs.view(-1, self.w * self.r)
|
read_vectors[time] = read_vecs.view(-1, self.w * self.r)
|
||||||
|
|
||||||
# get the final output for this time step
|
# get the final output for this time step
|
||||||
outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1)))
|
outs[time] = self.dropout_layer(self.mem_out(T.cat([out, read_vectors[time]], 1)))
|
||||||
|
|
||||||
return outs, read_vectors, (chx, mhx)
|
if self.debug:
|
||||||
|
return outs, read_vectors, mem_debug, (chx, mhx)
|
||||||
|
else:
|
||||||
|
return outs, read_vectors, (chx, mhx)
|
||||||
|
|
||||||
def forward(self, input, hx=(None, None, None), reset_experience=False):
|
def forward(self, input, hx=(None, None, None), reset_experience=False):
|
||||||
# handle packed data
|
# handle packed data
|
||||||
@ -220,7 +229,7 @@ class DNC(nn.Module):
|
|||||||
outputs = None
|
outputs = None
|
||||||
chxs = []
|
chxs = []
|
||||||
if self.debug:
|
if self.debug:
|
||||||
viz = [mem_hidden['memory'][0]] if self.share_memory else [mem_hidden[0]['memory'][0]]
|
viz = []
|
||||||
|
|
||||||
read_vectors = [last_read] * max_length
|
read_vectors = [last_read] * max_length
|
||||||
# outs = [input[:, x, :] for x in range(max_length)]
|
# outs = [input[:, x, :] for x in range(max_length)]
|
||||||
@ -232,15 +241,14 @@ class DNC(nn.Module):
|
|||||||
|
|
||||||
m = mem_hidden if self.share_memory else mem_hidden[layer]
|
m = mem_hidden if self.share_memory else mem_hidden[layer]
|
||||||
# pass through controller
|
# pass through controller
|
||||||
outs, _, (chx, m) = self._layer_forward(
|
if self.debug:
|
||||||
outs,
|
outs, _, mem_debug, (chx, m) = self._layer_forward(outs,layer,(chx, m))
|
||||||
layer,
|
else:
|
||||||
(chx, m)
|
outs, _, (chx, m) = self._layer_forward(outs, layer, (chx, m))
|
||||||
)
|
|
||||||
|
|
||||||
# debug memory
|
# debug memory
|
||||||
if self.debug:
|
if self.debug:
|
||||||
viz.append(m['memory'][0])
|
viz.append(mem_debug)
|
||||||
|
|
||||||
# store the memory back (per layer or shared)
|
# store the memory back (per layer or shared)
|
||||||
if self.share_memory:
|
if self.share_memory:
|
||||||
@ -258,7 +266,9 @@ class DNC(nn.Module):
|
|||||||
# outs = [o for o in outs]
|
# outs = [o for o in outs]
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
viz = T.cat(viz, 0).transpose(0, 1)
|
viz = np.array(viz)
|
||||||
|
s = list(viz.shape)
|
||||||
|
viz = viz.reshape(s[0]*s[1], s[2]*s[3])
|
||||||
|
|
||||||
controller_hidden = chxs
|
controller_hidden = chxs
|
||||||
|
|
||||||
|
@ -141,7 +141,7 @@ if __name__ == '__main__':
|
|||||||
target_output = target_output.transpose(0, 1).contiguous()
|
target_output = target_output.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
output, (chx, mhx, rv), v = rnn(input_data, None)
|
output, (chx, mhx, rv), v = rnn(input_data, None)
|
||||||
|
# dncs operate batch first
|
||||||
output = output.transpose(0, 1)
|
output = output.transpose(0, 1)
|
||||||
|
|
||||||
loss = criterion((output), target_output)
|
loss = criterion((output), target_output)
|
||||||
@ -166,13 +166,13 @@ if __name__ == '__main__':
|
|||||||
last_save_losses = []
|
last_save_losses = []
|
||||||
|
|
||||||
viz.heatmap(
|
viz.heatmap(
|
||||||
v.data.cpu().numpy(),
|
v,
|
||||||
opts=dict(
|
opts=dict(
|
||||||
xtickstep=10,
|
xtickstep=10,
|
||||||
ytickstep=2,
|
ytickstep=2,
|
||||||
title='Timestep: ' + str(epoch) + ', loss: ' + str(loss),
|
title='Timestep: ' + str(epoch) + ', loss: ' + str(loss),
|
||||||
xlabel='mem_slot * layer',
|
ylabel='layer * time',
|
||||||
ylabel='mem_size'
|
xlabel='cell_size * mem_size'
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user