Add num_hidden_layers to nn layers
This commit is contained in:
parent
cc0308ba68
commit
e3b4513730
12
dnc/dnc.py
12
dnc/dnc.py
@ -74,13 +74,13 @@ class DNC(nn.Module):
|
||||
for layer in range(self.num_layers):
|
||||
if self.rnn_type.lower() == 'rnn':
|
||||
self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
|
||||
bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout))
|
||||
bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
|
||||
elif self.rnn_type.lower() == 'gru':
|
||||
self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
|
||||
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
|
||||
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
|
||||
if self.rnn_type.lower() == 'lstm':
|
||||
self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
|
||||
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
|
||||
self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
|
||||
setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])
|
||||
|
||||
# memories for each layer
|
||||
@ -191,7 +191,7 @@ class DNC(nn.Module):
|
||||
else:
|
||||
read_vectors = None
|
||||
|
||||
return output, read_vectors, (chx, mhx)
|
||||
return output, (chx, mhx, read_vectors)
|
||||
|
||||
def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
|
||||
# handle packed data
|
||||
@ -229,7 +229,7 @@ class DNC(nn.Module):
|
||||
chx = controller_hidden[layer]
|
||||
m = mem_hidden if self.share_memory else mem_hidden[layer]
|
||||
# pass through controller
|
||||
outs[time], read_vectors, (chx, m) = \
|
||||
outs[time], (chx, m, read_vectors) = \
|
||||
self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)
|
||||
|
||||
# debug memory
|
||||
@ -246,6 +246,8 @@ class DNC(nn.Module):
|
||||
if read_vectors is not None:
|
||||
# the controller output + read vectors go into next layer
|
||||
outs[time] = T.cat([outs[time], read_vectors], 1)
|
||||
else:
|
||||
outs[time] = T.cat([outs[time], last_read], 1)
|
||||
inputs[time] = outs[time]
|
||||
|
||||
if self.debug:
|
||||
|
Loading…
Reference in New Issue
Block a user