Use cuDNN-backed RNN implementations, make whether to pass through memory a controllable forward-pass option, update README

ixaxaar 2017-11-10 21:29:48 +05:30
parent 67d9722231
commit 51caa2e2ce
5 changed files with 46 additions and 46 deletions


@@ -18,25 +18,35 @@ pip install dnc

 **Parameters**:

+Following are the constructor parameters:
+
 | Argument | Default | Description |
 | --- | --- | --- |
-| input_size | None | Size of the input vectors |
-| hidden_size | None | Size of hidden units |
-| rnn_type | 'lstm' | Type of recurrent cells used in the controller |
-| num_layers | 1 | Number of layers of recurrent units in the controller |
-| num_hidden_layers | 2 | Number of hidden layers per layer of the controller |
-| bias | True | Bias |
-| batch_first | True | Whether data is fed batch first |
-| dropout | 0 | Dropout between layers in the controller |
-| bidirectional | False | If the controller is bidirectional (Not yet implemented) |
-| nr_cells | 5 | Number of memory cells |
-| read_heads | 2 | Number of read heads |
-| cell_size | 10 | Size of each memory cell |
-| nonlinearity | 'tanh' | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
-| gpu_id | -1 | ID of the GPU, -1 for CPU |
-| independent_linears | False | Whether to use independent linear units to derive interface vector |
-| share_memory | True | Whether to share memory between controller layers |
-| reset_experience | False | Whether to reset memory (This is a parameter for the forward pass) |
+| input_size | `None` | Size of the input vectors |
+| hidden_size | `None` | Size of hidden units |
+| rnn_type | `'lstm'` | Type of recurrent cells used in the controller |
+| num_layers | `1` | Number of layers of recurrent units in the controller |
+| num_hidden_layers | `2` | Number of hidden layers per layer of the controller |
+| bias | `True` | Whether the units use a bias term |
+| batch_first | `True` | Whether data is fed batch first |
+| dropout | `0` | Dropout between layers in the controller |
+| bidirectional | `False` | Whether the controller is bidirectional (not yet implemented) |
+| nr_cells | `5` | Number of memory cells |
+| read_heads | `2` | Number of read heads |
+| cell_size | `10` | Size of each memory cell |
+| nonlinearity | `'tanh'` | If using `'rnn'` as `rnn_type`, non-linearity of the RNNs |
+| gpu_id | `-1` | ID of the GPU, -1 for CPU |
+| independent_linears | `False` | Whether to use independent linear units to derive the interface vector |
+| share_memory | `True` | Whether to share memory between controller layers |
+
+Following are the forward pass parameters:
+
+| Argument | Default | Description |
+| --- | --- | --- |
+| input | - | The input vector, `(B*T*X)` or `(T*B*X)` depending on `batch_first` |
+| hidden | `(None,None,None)` | Hidden states `(controller hidden, memory hidden, read vectors)` |
+| reset_experience | `False` | Whether to reset memory |
+| pass_through_memory | `True` | Whether to pass through memory |

 Example usage:
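The README's usage example itself is truncated by the diff view; below is a minimal, illustrative sketch of how the constructor and forward-pass parameters documented above fit together (all sizes and hyperparameter values here are placeholders, not the README's own example):

```python
import torch as T
from dnc import DNC

# illustrative sizes; any consistent values work
rnn = DNC(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=1,
  nr_cells=5,
  cell_size=10,
  read_heads=2,
  batch_first=True,
  gpu_id=-1
)

x = T.randn(10, 20, 64)  # (batch, time, input_size) because batch_first=True

# hidden is the (controller hidden, memory hidden, read vectors) triple;
# reset_experience wipes memory state, pass_through_memory toggles memory use
output, (chx, mhx, read_vectors) = rnn(
  x, (None, None, None),
  reset_experience=False,
  pass_through_memory=True
)
```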


@@ -67,15 +67,13 @@ class DNC(nn.Module):
     self.memories = []

     for layer in range(self.num_layers):
-      self.rnns.append([])
-      for hlayer in range(self.num_hidden_layers):
-        if self.rnn_type.lower() == 'rnn':
-          self.rnns[layer].append(nn.RNNCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
-        elif self.rnn_type.lower() == 'gru':
-          self.rnns[layer].append(nn.GRUCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
-        elif self.rnn_type.lower() == 'lstm':
-          self.rnns[layer].append(nn.LSTMCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
-        setattr(self, self.rnn_type.lower() + '_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])
+      if self.rnn_type.lower() == 'rnn':
+        self.rnns.append(nn.RNN((self.input_size if layer == 0 else self.output_size), self.output_size, bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True))
+      elif self.rnn_type.lower() == 'gru':
+        self.rnns.append(nn.GRU((self.input_size if layer == 0 else self.output_size), self.output_size, bias=self.bias, batch_first=True))
+      elif self.rnn_type.lower() == 'lstm':
+        self.rnns.append(nn.LSTM((self.input_size if layer == 0 else self.output_size), self.output_size, bias=self.bias, batch_first=True))
+      setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])

       # memories for each layer
       if not self.share_memory:
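The new construction above replaces the per-timestep `nn.LSTMCell` stacks with whole `nn.LSTM`/`nn.GRU`/`nn.RNN` modules, which PyTorch dispatches to cuDNN kernels on the GPU. A stripped-down sketch of the per-layer pattern (sizes illustrative):

```python
import torch.nn as nn

input_size, output_size, num_layers = 64, 128, 2

rnns = []
for layer in range(num_layers):
  # the first layer consumes the raw input;
  # deeper layers consume the previous layer's output
  rnns.append(nn.LSTM(input_size if layer == 0 else output_size,
                      output_size, bias=True, batch_first=True))
```

The `setattr` call matters because modules kept in a plain Python list are invisible to `nn.Module`'s registration machinery; binding each one as an attribute ensures its weights appear in `parameters()` and get trained.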
@@ -107,14 +105,12 @@ class DNC(nn.Module):
     # final output layer
     self.read_vectors_weights = nn.Linear(self.read_vectors_size, self.output_size)
     self.mem_out = nn.Linear(self.hidden_size, self.output_size)
-    self.output = nn.Linear(self.output_size, self.input_size)
     self.dropout_layer = nn.Dropout(self.dropout)

     if self.gpu_id != -1:
-      [x.cuda(self.gpu_id) for y in self.rnns for x in y]
+      [x.cuda(self.gpu_id) for x in self.rnns]
       [x.cuda(self.gpu_id) for x in self.memories]
       self.mem_out.cuda(self.gpu_id)

   def _init_hidden(self, hx, batch_size, reset_experience):
     # create empty hidden states if not provided
@@ -124,11 +120,7 @@ class DNC(nn.Module):
     # initialize hidden state of the controller RNN
     if chx is None:
-      chx = cuda(T.zeros(batch_size, self.output_size), gpu_id=self.gpu_id)
-      if self.rnn_type.lower() == 'lstm':
-        chx = [ [ (chx.clone(), chx.clone()) for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]
-      else:
-        chx = [ [ chx.clone() for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]
+      chx = [ None for x in range(self.num_layers) ]

     # Last read vectors
     if last_read is None:
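Initializing the controller state to a list of `None`s works because PyTorch's recurrent modules accept `hx=None` and allocate zero-filled hidden (and cell) states of the right shape internally, for example:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(64, 128, batch_first=True)
x = torch.randn(10, 20, 64)    # (batch, time, features)
out, (h, c) = lstm(x, None)    # hx=None -> zero-initialized h and c
```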
@@ -170,18 +162,14 @@ class DNC(nn.Module):
   def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
     (chx, mhx) = hx

-    layer_input = input
-    hchx = []
-
-    for hlayer in range(self.num_hidden_layers):
-      h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
-      layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
-      hchx.append(h)
-    chx = hchx
+    # pass through the controller layer
+    input, chx = self.rnns[layer](input.unsqueeze(1), chx)
+    input = input.squeeze(1)

     # the interface vector
-    ξ = layer_input
+    ξ = input

     # the output
-    output = self.dropout_layer(self.mem_out(layer_input))
+    output = input

     # pass through memory
     if pass_through_memory:
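Because memory is read and written between timesteps, the DNC still advances the controller one step at a time even though `nn.LSTM` is a full-sequence module; each step is wrapped as a length-1 sequence. A sketch of that `unsqueeze(1)`/`squeeze(1)` stepping pattern for a `batch_first` module:

```python
import torch
import torch.nn as nn

layer = nn.LSTM(64, 128, batch_first=True)

x_t = torch.randn(10, 64)   # one timestep: (batch, features)
hx = None                   # hidden state carried across timesteps

out, hx = layer(x_t.unsqueeze(1), hx)  # (batch, 1, features) in
out = out.squeeze(1)                   # (batch, hidden) out
```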
@@ -223,7 +211,9 @@ class DNC(nn.Module):
     outs = [None] * max_length
     read_vectors = None

+    # pass through time
     for time in range(max_length):
+      # pass through layers
       for layer in range(self.num_layers):
         # this layer's hidden states
         chx = controller_hidden[layer]


@@ -128,6 +128,6 @@ def test_rnn_n():
     optimizer.step()

   assert target_output.size() == T.Size([27, 10, 100])
-  assert chx[1][2].size() == T.Size([10,100])
+  assert chx[1].size() == T.Size([1,10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
   assert rv.size() == T.Size([10, 51])
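The reshaped assertion follows from the return convention of the cuDNN-backed modules: they report hidden state as `(num_layers * num_directions, batch, hidden_size)`, hence the new leading 1 for a single-layer unidirectional controller. A quick illustration with the test's sizes (100 hidden units, batch of 10):

```python
import torch as T
import torch.nn as nn

gru = nn.GRU(100, 100, batch_first=True)
out, h = gru(T.randn(10, 27, 100))        # batch 10, 27 timesteps
assert h.size() == T.Size([1, 10, 100])   # (num_layers * num_directions, batch, hidden)
```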


@@ -128,6 +128,6 @@ def test_rnn_n():
     optimizer.step()

   assert target_output.size() == T.Size([27, 10, 100])
-  assert chx[0][0][0].size() == T.Size([10,100])
+  assert chx[0][0].size() == T.Size([1,10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
   assert rv.size() == T.Size([10, 51])


@@ -128,6 +128,6 @@ def test_rnn_n():
     optimizer.step()

   assert target_output.size() == T.Size([27, 10, 100])
-  assert chx[1][2].size() == T.Size([10,100])
+  assert chx[1].size() == T.Size([1,10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
   assert rv.size() == T.Size([10, 51])