RNNs with CUDNN implementations, make whether to forward pass through memory controllable, update readme
This commit is contained in:
parent 67d9722231
commit 51caa2e2ce
README.md
@@ -18,25 +18,35 @@ pip install dnc
**Parameters**:

Following are the constructor parameters:

| Argument | Default | Description |
| --- | --- | --- |
| input_size | None | Size of the input vectors |
| hidden_size | None | Size of hidden units |
| rnn_type | 'lstm' | Type of recurrent cells used in the controller |
| num_layers | 1 | Number of layers of recurrent units in the controller |
| num_hidden_layers | 2 | Number of hidden layers per layer of the controller |
| bias | True | Bias |
| batch_first | True | Whether data is fed batch first |
| dropout | 0 | Dropout between layers in the controller |
| bidirectional | False | If the controller is bidirectional (Not yet implemented) |
| nr_cells | 5 | Number of memory cells |
| read_heads | 2 | Number of read heads |
| cell_size | 10 | Size of each memory cell |
| nonlinearity | 'tanh' | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
| gpu_id | -1 | ID of the GPU, -1 for CPU |
| independent_linears | False | Whether to use independent linear units to derive interface vector |
| share_memory | True | Whether to share memory between controller layers |
| reset_experience | False | Whether to reset memory (This is a parameter for the forward pass) |
| input_size | `None` | Size of the input vectors |
| hidden_size | `None` | Size of hidden units |
| rnn_type | `'lstm'` | Type of recurrent cells used in the controller |
| num_layers | `1` | Number of layers of recurrent units in the controller |
| num_hidden_layers | `2` | Number of hidden layers per layer of the controller |
| bias | `True` | Bias |
| batch_first | `True` | Whether data is fed batch first |
| dropout | `0` | Dropout between layers in the controller |
| bidirectional | `False` | If the controller is bidirectional (Not yet implemented) |
| nr_cells | `5` | Number of memory cells |
| read_heads | `2` | Number of read heads |
| cell_size | `10` | Size of each memory cell |
| nonlinearity | `'tanh'` | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
| gpu_id | `-1` | ID of the GPU, -1 for CPU |
| independent_linears | `False` | Whether to use independent linear units to derive interface vector |
| share_memory | `True` | Whether to share memory between controller layers |

Following are the forward pass parameters:

| Argument | Default | Description |
| --- | --- | --- |
| input | - | The input tensor, `(B, T, X)` when `batch_first` else `(T, B, X)` |
| hidden | `(None,None,None)` | Hidden states as a `(controller hidden, memory hidden, last read vectors)` tuple |
| reset_experience | `False` | Whether to reset memory |
| pass_through_memory | `True` | Whether to pass through memory |

Example usage:
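The example block itself is not included in this hunk. As a minimal sketch of how the constructor and forward-pass parameters documented above fit together (tensor sizes, variable names, and the returned hidden tuple are illustrative assumptions based on the tables and tests, not copied from the repository):

```python
import torch as T
from dnc import DNC

# construct a controller using the parameters documented above
rnn = DNC(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=1,
  nr_cells=5,
  cell_size=10,
  read_heads=2,
  batch_first=True,
  gpu_id=-1  # -1 keeps everything on the CPU
)

x = T.randn(4, 20, 64)  # (batch, time, input_size) because batch_first=True

# hidden is the (controller hidden, memory hidden, last read vectors) tuple;
# reset_experience wipes the memory, pass_through_memory routes the
# controller output through the external memory
output, (controller_hidden, memory, read_vectors) = rnn(
  x, (None, None, None), reset_experience=True, pass_through_memory=True
)
```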
dnc/dnc.py
@@ -67,15 +67,13 @@ class DNC(nn.Module):
    self.memories = []

    for layer in range(self.num_layers):
      self.rnns.append([])
      for hlayer in range(self.num_hidden_layers):
        if self.rnn_type.lower() == 'rnn':
          self.rnns[layer].append(nn.RNNCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
        elif self.rnn_type.lower() == 'gru':
          self.rnns[layer].append(nn.GRUCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
        elif self.rnn_type.lower() == 'lstm':
          self.rnns[layer].append(nn.LSTMCell((self.input_size if (hlayer == 0 and layer == 0) else self.output_size), self.output_size, bias=self.bias))
        setattr(self, self.rnn_type.lower() + '_layer_' + str(layer) + '_' + str(hlayer), self.rnns[layer][hlayer])
      if self.rnn_type.lower() == 'rnn':
        self.rnns.append(nn.RNN((self.input_size if layer == 0 else self.output_size), self.output_size, bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True))
      elif self.rnn_type.lower() == 'gru':
        self.rnns.append(nn.GRU((self.input_size if layer == 0 else self.output_size), self.output_size, bias=self.bias, batch_first=True))
      if self.rnn_type.lower() == 'lstm':
        self.rnns.append(nn.LSTM((self.input_size if layer == 0 else self.output_size), self.output_size, bias=self.bias, batch_first=True))
      setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])

      # memories for each layer
      if not self.share_memory:
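This hunk swaps the per-timestep `nn.RNNCell`/`nn.GRUCell`/`nn.LSTMCell` controllers for the sequence-level `nn.RNN`/`nn.GRU`/`nn.LSTM` modules, which dispatch to cuDNN kernels on a GPU. A standalone sketch of the difference (sizes and names are made up, not this repository's code):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 50, 32)            # (batch, time, features), batch_first layout

# cell-based: explicit Python loop over time, one step per call
cell = nn.LSTMCell(32, 64)
h, c = torch.zeros(8, 64), torch.zeros(8, 64)
for t in range(x.size(1)):
    h, c = cell(x[:, t], (h, c))

# module-based: whole sequence in one call; uses cuDNN when run on CUDA tensors
lstm = nn.LSTM(32, 64, batch_first=True)
out, (h_n, c_n) = lstm(x)             # out: (8, 50, 64), h_n: (1, 8, 64)
```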
@@ -107,14 +105,12 @@ class DNC(nn.Module):
    # final output layer
    self.read_vectors_weights = nn.Linear(self.read_vectors_size, self.output_size)
    self.mem_out = nn.Linear(self.hidden_size, self.output_size)
    self.output = nn.Linear(self.output_size, self.input_size)
    self.dropout_layer = nn.Dropout(self.dropout)

    if self.gpu_id != -1:
      [x.cuda(self.gpu_id) for y in self.rnns for x in y]
      [x.cuda(self.gpu_id) for x in self.rnns]
      [x.cuda(self.gpu_id) for x in self.memories]
      self.mem_out.cuda(self.gpu_id)

  def _init_hidden(self, hx, batch_size, reset_experience):
    # create empty hidden states if not provided
@@ -124,11 +120,7 @@ class DNC(nn.Module):
    # initialize hidden state of the controller RNN
    if chx is None:
      chx = cuda(T.zeros(batch_size, self.output_size), gpu_id=self.gpu_id)
      if self.rnn_type.lower() == 'lstm':
        chx = [ [ (chx.clone(), chx.clone()) for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]
      else:
        chx = [ [ chx.clone() for h in range(self.num_hidden_layers) ] for l in range(self.num_layers) ]
      chx = [ None for x in range(self.num_layers) ]

    # Last read vectors
    if last_read is None:
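The new `chx = [ None for x in range(self.num_layers) ]` line leans on the fact that PyTorch's sequence RNN modules treat a `None` hidden argument as a zero initial state, so the explicit per-cell zero tensors above are no longer needed. A standalone illustration (module and sizes are made up):

```python
import torch
import torch.nn as nn

rnn = nn.GRU(input_size=16, hidden_size=32, batch_first=True)
x = torch.randn(4, 7, 16)

out_a, h_a = rnn(x, None)                    # None -> zero initial hidden state
out_b, h_b = rnn(x, torch.zeros(1, 4, 32))   # explicit zeros, same result
print(torch.allclose(out_a, out_b))          # True
```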
@@ -170,18 +162,14 @@ class DNC(nn.Module):
  def _layer_forward(self, input, layer, hx=(None, None), pass_through_memory=True):
    (chx, mhx) = hx

    layer_input = input
    hchx = []
    for hlayer in range(self.num_hidden_layers):
      h = self.rnns[layer][hlayer](layer_input, chx[hlayer])
      layer_input = h[0] if self.rnn_type.lower() == 'lstm' else h
      hchx.append(h)
    chx = hchx
    # pass through the controller layer
    input, chx = self.rnns[layer](input.unsqueeze(1), chx)
    input = input.squeeze(1)

    # the interface vector
    ξ = layer_input
    ξ = input
    # the output
    output = self.dropout_layer(self.mem_out(layer_input))
    output = input

    # pass through memory
    if pass_through_memory:
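The new `input.unsqueeze(1)` / `input.squeeze(1)` pair feeds a single timestep through the sequence-level module by presenting it as a length-1, batch-first sequence. A standalone sketch of that pattern (module and sizes are illustrative, not this repository's code):

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=32, hidden_size=64, batch_first=True)

step = torch.randn(8, 32)            # one timestep: (batch, features)
h = None                             # None lets the module use a zero initial state

out, h = gru(step.unsqueeze(1), h)   # (8, 1, 64): a sequence of length 1
out = out.squeeze(1)                 # back to (8, 64) for the rest of the layer
```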
@@ -223,7 +211,9 @@ class DNC(nn.Module):
    outs = [None] * max_length
    read_vectors = None

    # pass through time
    for time in range(max_length):
      # pass through layers
      for layer in range(self.num_layers):
        # this layer's hidden states
        chx = controller_hidden[layer]
@@ -128,6 +128,6 @@ def test_rnn_n():
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[1][2].size() == T.Size([10,100])
assert chx[1].size() == T.Size([1,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])
@@ -128,6 +128,6 @@ def test_rnn_n():
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[0][0][0].size() == T.Size([10,100])
assert chx[0][0].size() == T.Size([1,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])
@@ -128,6 +128,6 @@ def test_rnn_n():
optimizer.step()

assert target_output.size() == T.Size([27, 10, 100])
assert chx[1][2].size() == T.Size([10,100])
assert chx[1].size() == T.Size([1,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])
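The reworked assertions reflect the hidden-state layout of the sequence-level modules: `nn.RNN`/`nn.GRU`/`nn.LSTM` return hidden states shaped `(num_layers * num_directions, batch, hidden_size)` rather than the per-cell `(batch, hidden_size)` tensors, hence `[1, 10, 100]`. A standalone check (sizes chosen to match the asserts above):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=100, hidden_size=100, batch_first=True)
x = torch.randn(10, 27, 100)   # (batch=10, time=27, features=100)

out, (h, c) = lstm(x)
print(h.size())                # torch.Size([1, 10, 100]), matching the new asserts
```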