Add num_hidden_layers to nn layers

ixaxaar 2017-11-21 17:28:27 +05:30
parent cc0308ba68
commit e3b4513730


@@ -74,13 +74,13 @@ class DNC(nn.Module):
     for layer in range(self.num_layers):
       if self.rnn_type.lower() == 'rnn':
         self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
-                                bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout))
+                                bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
       elif self.rnn_type.lower() == 'gru':
         self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
                                 self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
+                                self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
       if self.rnn_type.lower() == 'lstm':
         self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
-                                 self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
+                                 self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
       setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])

     # memories for each layer
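For reference, num_layers in PyTorch's recurrent modules stacks that many recurrent layers inside a single nn.RNN/nn.GRU/nn.LSTM, and dropout is then applied between the stacked layers (it has no effect when num_layers=1). A minimal standalone sketch, not from this repo, with made-up sizes:

    import torch
    import torch.nn as nn

    # One nn.LSTM module holding 2 stacked recurrent layers;
    # dropout=0.5 acts between the stacked layers, not on the final output.
    lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2,
                   bias=True, batch_first=True, dropout=0.5)

    x = torch.randn(8, 10, 64)   # (batch, seq_len, input_size)
    out, (h, c) = lstm(x)        # out: (8, 10, 128); h and c: (2, 8, 128)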
@@ -191,7 +191,7 @@ class DNC(nn.Module):
       else:
         read_vectors = None

-      return output, read_vectors, (chx, mhx)
+      return output, (chx, mhx, read_vectors)

   def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
     # handle packed data
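The hunk above changes the _layer_forward return convention: the read vectors now ride inside the per-layer hidden tuple instead of being a separate return value, so a single tuple threads controller state, memory state, and last reads together. A toy stand-in for the new contract (hypothetical names, not the repo's actual logic):

    def layer_forward_sketch(x, hidden):
        chx, mhx, read_vectors = hidden   # controller state, memory state, last reads
        output = x                        # stand-in for the controller + memory pass
        return output, (chx, mhx, read_vectors)

    out, (chx, mhx, rv) = layer_forward_sketch(0.0, (None, None, None))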
@@ -229,7 +229,7 @@ class DNC(nn.Module):
         chx = controller_hidden[layer]
         m = mem_hidden if self.share_memory else mem_hidden[layer]
         # pass through controller
-        outs[time], read_vectors, (chx, m) = \
+        outs[time], (chx, m, read_vectors) = \
            self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)

         # debug memory
@@ -246,6 +246,8 @@ class DNC(nn.Module):
         if read_vectors is not None:
           # the controller output + read vectors go into next layer
           outs[time] = T.cat([outs[time], read_vectors], 1)
+        else:
+          outs[time] = T.cat([outs[time], last_read], 1)
         inputs[time] = outs[time]

     if self.debug:
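The added else branch covers the case where no fresh read vectors come back (for example when pass_through_memory is False, so _layer_forward returns read_vectors as None): concatenating the previous step's last_read keeps the next layer's input width identical to the pass-through case. A small standalone illustration with hypothetical sizes, using the repo's torch-as-T convention:

    import torch as T

    controller_out = T.randn(8, 128)   # (batch, output_size)
    last_read = T.zeros(8, 64)         # (batch, read_heads * cell_size), e.g. 4 heads x 16 cells

    # Either branch produces the same width, so inputs[time] stays consistent
    # whether or not the memory was consulted this step.
    next_input = T.cat([controller_out, last_read], 1)   # (8, 192)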