Merge pull request #17 from ixaxaar/scale_interface
Scale interface vectors, dynamic memory pass
Commit: 11cc2107d7
@@ -118,9 +118,9 @@ The copy task, as descibed in the original paper, is included in the repo.

 From the project root:

 ```bash
-python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (original implementation)
+python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (like original implementation)

-python ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -mem_slot 32 -batch_size 32 -optim adam # (faster convergence)
+python3 ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 32 -batch_size 1000 -optim adam -sequence_max_length 8 # (faster convergence)
 ```

 For the full set of options, see:

@@ -150,7 +150,9 @@ The visdom dashboard shows memory as a heatmap for batch 0 every `-summarize_fre

 ## General noteworthy stuff

-1. DNCs converge with Adam and RMSProp learning rules, SGD generally causes them to diverge.
+1. DNCs converge faster with Adam and RMSProp learning rules, SGD generally converges extremely slowly.
+The copy task, for example, takes 25k iterations on SGD with lr 1 compared to 3.5k for adam with lr 0.01.
+2. `nan`s in the gradients are common, try with different batch sizes

 Repos referred to for creation of this repo:

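A minimal sketch (illustrative, not part of this diff) of one way to handle the `nan` gradients mentioned above: clip, then skip the update whenever any gradient went `nan`. The `model`, `optimizer` and `loss` names are placeholders, and `clip_grad_norm_` assumes a reasonably recent PyTorch.

```python
import torch as T
from torch.nn.utils import clip_grad_norm_

def safe_step(model, optimizer, loss, clip=50.0):
  # Backprop and clip; only apply the update if no gradient contains nan.
  optimizer.zero_grad()
  loss.backward()
  clip_grad_norm_(model.parameters(), clip)
  bad = any(T.isnan(p.grad).any().item() for p in model.parameters() if p.grad is not None)
  if not bad:
    optimizer.step()
  return not bad
```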
dnc/dnc.py (12 changes)

@@ -74,13 +74,13 @@ class DNC(nn.Module):
     for layer in range(self.num_layers):
       if self.rnn_type.lower() == 'rnn':
         self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
-            bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout))
+            bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
       elif self.rnn_type.lower() == 'gru':
         self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
-            self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
+            self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
       if self.rnn_type.lower() == 'lstm':
         self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
-            self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
+            self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
       setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])

     # memories for each layer

@@ -191,7 +191,7 @@ class DNC(nn.Module):
     else:
       read_vectors = None

-    return output, read_vectors, (chx, mhx)
+    return output, (chx, mhx, read_vectors)

   def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
     # handle packed data

@@ -229,7 +229,7 @@ class DNC(nn.Module):
       chx = controller_hidden[layer]
       m = mem_hidden if self.share_memory else mem_hidden[layer]
       # pass through controller
-      outs[time], read_vectors, (chx, m) = \
+      outs[time], (chx, m, read_vectors) = \
          self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)

       # debug memory

@@ -246,6 +246,8 @@ class DNC(nn.Module):
       if read_vectors is not None:
         # the controller output + read vectors go into next layer
         outs[time] = T.cat([outs[time], read_vectors], 1)
+      else:
+        outs[time] = T.cat([outs[time], last_read], 1)
       inputs[time] = outs[time]

       if self.debug:
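With this change the controller hidden state, memory hidden state and read vectors travel together as a single tuple. A rough usage sketch mirroring the updated tests (sizes here are arbitrary; with `debug=True` the forward pass also returns a debug dict, and older PyTorch versions would additionally need `Variable` wrapping):

```python
import torch as T
from dnc import DNC

# arbitrary sizes; debug=True makes the forward pass also return a debug dict `v`
rnn = DNC(input_size=64, hidden_size=128, rnn_type='lstm', num_layers=2,
          num_hidden_layers=2, nr_cells=16, cell_size=20, read_heads=4,
          gpu_id=-1, batch_first=True, debug=True)

x = T.randn(4, 13, 64)               # (batch, seq_len, input_size)
(chx, mhx, rv) = (None, None, None)  # controller state, memory state, read vectors in one tuple
output, (chx, mhx, rv), v = rnn(x, (chx, mhx, rv), reset_experience=True)
output, (chx, mhx, rv), v = rnn(x, (chx, mhx, rv))  # reuse the tuple to carry state forward
```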
@@ -50,11 +50,11 @@ class Memory(nn.Module):

     if hidden is None:
       return {
-        'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.gpu_id),
+        'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id),
         'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id),
         'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id),
-        'read_weights': cuda(T.zeros(b, r, m).fill_(δ), gpu_id=self.gpu_id),
-        'write_weights': cuda(T.zeros(b, 1, m).fill_(δ), gpu_id=self.gpu_id),
+        'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id),
+        'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id),
         'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id)
       }
     else:

@@ -66,11 +66,11 @@ class Memory(nn.Module):
       hidden['usage_vector'] = hidden['usage_vector'].clone()

       if erase:
-        hidden['memory'].data.fill_(δ)
+        hidden['memory'].data.fill_(0)
         hidden['link_matrix'].data.zero_()
         hidden['precedence'].data.zero_()
-        hidden['read_weights'].data.fill_(δ)
-        hidden['write_weights'].data.fill_(δ)
+        hidden['read_weights'].data.fill_(0)
+        hidden['write_weights'].data.fill_(0)
         hidden['usage_vector'].data.zero_()
     return hidden

@@ -116,7 +116,7 @@ class Memory(nn.Module):
     new_link_matrix = write_weights_i * precedence

     link_matrix = prev_scale * link_matrix + new_link_matrix
-    # elaborate trick to delete diag elems
+    # trick to delete diag elems
     return self.I.expand_as(link_matrix) * link_matrix

   def update_precedence(self, precedence, write_weights):

@@ -139,7 +139,6 @@ class Memory(nn.Module):
         hidden['usage_vector'],
         allocation_gate * write_gate
     )
-    # print((alloc).data.cpu().numpy())

     # get write weightings
     hidden['write_weights'] = self.write_weighting(

@@ -170,8 +169,7 @@ class Memory(nn.Module):

   def content_weightings(self, memory, keys, strengths):
     d = θ(memory, keys)
-    strengths = F.softplus(strengths).unsqueeze(2)
-    return σ(d * strengths, 2)
+    return σ(d * strengths.unsqueeze(2), 2)

   def directional_weightings(self, link_matrix, read_weights):
     rw = read_weights.unsqueeze(1)

@@ -215,17 +213,17 @@ class Memory(nn.Module):

     if self.independent_linears:
       # r read keys (b * r * w)
-      read_keys = self.read_keys_transform(ξ).view(b, r, w)
+      read_keys = F.tanh(self.read_keys_transform(ξ).view(b, r, w))
       # r read strengths (b * r)
-      read_strengths = self.read_strengths_transform(ξ).view(b, r)
+      read_strengths = F.softplus(self.read_strengths_transform(ξ).view(b, r))
       # write key (b * 1 * w)
-      write_key = self.write_key_transform(ξ).view(b, 1, w)
+      write_key = F.tanh(self.write_key_transform(ξ).view(b, 1, w))
       # write strength (b * 1)
-      write_strength = self.write_strength_transform(ξ).view(b, 1)
+      write_strength = F.softplus(self.write_strength_transform(ξ).view(b, 1))
       # erase vector (b * 1 * w)
       erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
       # write vector (b * 1 * w)
-      write_vector = self.write_vector_transform(ξ).view(b, 1, w)
+      write_vector = F.tanh(self.write_vector_transform(ξ).view(b, 1, w))
       # r free gates (b * r)
       free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r))
       # allocation gate (b * 1)

@@ -237,17 +235,17 @@ class Memory(nn.Module):
     else:
       ξ = self.interface_weights(ξ)
       # r read keys (b * w * r)
-      read_keys = ξ[:, :r * w].contiguous().view(b, r, w)
+      read_keys = F.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
       # r read strengths (b * r)
-      read_strengths = ξ[:, r * w:r * w + r].contiguous().view(b, r)
+      read_strengths = F.softplus(ξ[:, r * w:r * w + r].contiguous().view(b, r))
       # write key (b * w * 1)
-      write_key = ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w)
+      write_key = F.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
       # write strength (b * 1)
-      write_strength = ξ[:, r * w + r + w].contiguous().view(b, 1)
+      write_strength = F.softplus(ξ[:, r * w + r + w].contiguous().view(b, 1))
       # erase vector (b * w)
       erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
       # write vector (b * w)
-      write_vector = ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w)
+      write_vector = F.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
       # r free gates (b * r)
       free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
       # allocation gate (b * 1)
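The interface scaling above applies the bounded activations from the DNC paper to each slice of the controller's interface vector: `tanh` for keys and the write vector, `softplus` for strengths (so the extra softplus in `content_weightings` is dropped), and `sigmoid` for gates and the erase vector. A standalone sketch of the idea, with made-up sizes and without the repo's `θ`/`σ` helpers:

```python
import torch as T
import torch.nn.functional as F

b, r, w, m = 4, 2, 8, 16                 # batch, read heads, cell width, memory slots (made up)
raw_keys      = T.randn(b, r, w) * 5     # unbounded controller outputs
raw_strengths = T.randn(b, r)
raw_gates     = T.randn(b, r)

keys      = T.tanh(raw_keys)             # keys / write vector squashed into (-1, 1)
strengths = F.softplus(raw_strengths)    # strengths strictly positive
gates     = T.sigmoid(raw_gates)         # gates / erase vector squashed into (0, 1)

# content addressing stays numerically tame because keys are bounded and strengths positive
memory = T.randn(b, m, w)
sim = F.cosine_similarity(memory.unsqueeze(2).expand(b, m, r, w),
                          keys.unsqueeze(1).expand(b, m, r, w), dim=-1)   # (b, m, r)
content_weights = F.softmax(sim * strengths.unsqueeze(1), dim=1)          # weights over memory slots
```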
@@ -37,8 +37,8 @@ parser.add_argument('-clip', type=float, default=50, help='gradient clipping')

 parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
 parser.add_argument('-mem_size', type=int, default=16, help='memory dimension')
-parser.add_argument('-mem_slot', type=int, default=10, help='number of memory slots')
-parser.add_argument('-read_heads', type=int, default=1, help='number of read heads')
+parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
+parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')

 parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
 parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')

@@ -121,7 +121,8 @@ if __name__ == '__main__':
       read_heads=read_heads,
       gpu_id=args.cuda,
       debug=True,
-      batch_first=True
+      batch_first=True,
+      independent_linears=True
   )
   print(rnn)

@@ -131,9 +132,20 @@ if __name__ == '__main__':
   last_save_losses = []

   if args.optim == 'adam':
-    optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
+    optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
+  if args.optim == 'sparseadam':
+    optimizer = optim.SparseAdam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
+  if args.optim == 'adamax':
+    optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
   elif args.optim == 'rmsprop':
-    optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, eps=1e-10)
+    optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, eps=1e-10) # 0.0001
+  elif args.optim == 'sgd':
+    optimizer = optim.SGD(rnn.parameters(), lr=args.lr) # 0.01
+  elif args.optim == 'adagrad':
+    optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
+  elif args.optim == 'adadelta':
+    optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)

   for epoch in range(iterations + 1):
     llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))

@@ -183,13 +195,13 @@ if __name__ == '__main__':
      )

      viz.heatmap(
-         v['link_matrix'],
+         v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
          opts=dict(
              xtickstep=10,
              ytickstep=2,
              title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
-             ylabel='layer * time',
-             xlabel='mem_slot * mem_slot'
+             ylabel='mem_slot',
+             xlabel='mem_slot'
          )
      )

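If visdom is not running, roughly the same view of the link matrix can be produced with matplotlib; a small helper sketch, assuming `v` is the debug dict returned with `debug=True` and that `v['link_matrix']` is array-like, as the visdom call above implies:

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_link_matrix(v, mem_slot):
  # Take the last temporal link matrix from the debug dict and view it as mem_slot x mem_slot,
  # matching the reshape used in the visdom heatmap above.
  link = np.asarray(v['link_matrix'][-1]).reshape(mem_slot, mem_slot)
  plt.imshow(link, cmap='viridis')
  plt.xlabel('mem_slot')
  plt.ylabel('mem_slot')
  plt.title('Link Matrix')
  plt.colorbar()
  plt.show()
```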
@@ -17,6 +17,8 @@ import math
 import time
 sys.path.insert(0, '.')

+import functools
+
 from dnc import DNC
 from test_utils import generate_data, criterion

@@ -128,6 +130,69 @@ def test_rnn_n():
   optimizer.step()

   assert target_output.size() == T.Size([27, 10, 100])
-  assert chx[1].size() == T.Size([1,10,100])
+  assert chx[1].size() == T.Size([num_hidden_layers,10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
   assert rv.size() == T.Size([10, 51])
+
+
+def test_rnn_no_memory_pass():
+  T.manual_seed(1111)
+
+  input_size = 100
+  hidden_size = 100
+  rnn_type = 'gru'
+  num_layers = 3
+  num_hidden_layers = 5
+  dropout = 0.2
+  nr_cells = 12
+  cell_size = 17
+  read_heads = 3
+  gpu_id = -1
+  debug = True
+  lr = 0.001
+  sequence_max_length = 10
+  batch_size = 10
+  cuda = gpu_id
+  clip = 20
+  length = 13
+
+  rnn = DNC(
+      input_size=input_size,
+      hidden_size=hidden_size,
+      rnn_type=rnn_type,
+      num_layers=num_layers,
+      num_hidden_layers=num_hidden_layers,
+      dropout=dropout,
+      nr_cells=nr_cells,
+      cell_size=cell_size,
+      read_heads=read_heads,
+      gpu_id=gpu_id,
+      debug=debug
+  )
+
+  optimizer = optim.Adam(rnn.parameters(), lr=lr)
+  optimizer.zero_grad()
+
+  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
+  target_output = target_output.transpose(0, 1).contiguous()
+
+  (chx, mhx, rv) = (None, None, None)
+  outputs = []
+  for x in range(6):
+    output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
+    output = output.transpose(0, 1)
+    outputs.append(output)
+
+  output = functools.reduce(lambda x,y: x + y, outputs)
+  loss = criterion((output), target_output)
+  loss.backward()
+
+  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+  optimizer.step()
+
+  assert target_output.size() == T.Size([27, 10, 100])
+  assert chx[0].size() == T.Size([num_hidden_layers,10,100])
+  assert mhx['memory'].size() == T.Size([10,12,17])
+  assert rv == None
@@ -15,6 +15,7 @@ import sys
 import os
 import math
 import time
+import functools
 sys.path.insert(0, '.')

 from dnc import DNC

@@ -128,6 +129,68 @@ def test_rnn_n():
   optimizer.step()

   assert target_output.size() == T.Size([27, 10, 100])
-  assert chx[0][0].size() == T.Size([1,10,100])
+  assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
   assert rv.size() == T.Size([10, 51])
+
+
+def test_rnn_no_memory_pass():
+  T.manual_seed(1111)
+
+  input_size = 100
+  hidden_size = 100
+  rnn_type = 'lstm'
+  num_layers = 3
+  num_hidden_layers = 5
+  dropout = 0.2
+  nr_cells = 12
+  cell_size = 17
+  read_heads = 3
+  gpu_id = -1
+  debug = True
+  lr = 0.001
+  sequence_max_length = 10
+  batch_size = 10
+  cuda = gpu_id
+  clip = 20
+  length = 13
+
+  rnn = DNC(
+      input_size=input_size,
+      hidden_size=hidden_size,
+      rnn_type=rnn_type,
+      num_layers=num_layers,
+      num_hidden_layers=num_hidden_layers,
+      dropout=dropout,
+      nr_cells=nr_cells,
+      cell_size=cell_size,
+      read_heads=read_heads,
+      gpu_id=gpu_id,
+      debug=debug
+  )
+
+  optimizer = optim.Adam(rnn.parameters(), lr=lr)
+  optimizer.zero_grad()
+
+  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
+  target_output = target_output.transpose(0, 1).contiguous()
+
+  (chx, mhx, rv) = (None, None, None)
+  outputs = []
+  for x in range(6):
+    output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
+    output = output.transpose(0, 1)
+    outputs.append(output)
+
+  output = functools.reduce(lambda x,y: x + y, outputs)
+  loss = criterion((output), target_output)
+  loss.backward()
+
+  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+  optimizer.step()
+
+  assert target_output.size() == T.Size([27, 10, 100])
+  assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
+  assert mhx['memory'].size() == T.Size([10,12,17])
+  assert rv == None
@@ -17,6 +17,8 @@ import math
 import time
 sys.path.insert(0, '.')

+import functools
+
 from dnc import DNC
 from test_utils import generate_data, criterion

@@ -128,6 +130,69 @@ def test_rnn_n():
   optimizer.step()

   assert target_output.size() == T.Size([27, 10, 100])
-  assert chx[1].size() == T.Size([1,10,100])
+  assert chx[1].size() == T.Size([num_hidden_layers,10,100])
   assert mhx['memory'].size() == T.Size([10,12,17])
   assert rv.size() == T.Size([10, 51])
+
+
+def test_rnn_no_memory_pass():
+  T.manual_seed(1111)
+
+  input_size = 100
+  hidden_size = 100
+  rnn_type = 'rnn'
+  num_layers = 3
+  num_hidden_layers = 5
+  dropout = 0.2
+  nr_cells = 12
+  cell_size = 17
+  read_heads = 3
+  gpu_id = -1
+  debug = True
+  lr = 0.001
+  sequence_max_length = 10
+  batch_size = 10
+  cuda = gpu_id
+  clip = 20
+  length = 13
+
+  rnn = DNC(
+      input_size=input_size,
+      hidden_size=hidden_size,
+      rnn_type=rnn_type,
+      num_layers=num_layers,
+      num_hidden_layers=num_hidden_layers,
+      dropout=dropout,
+      nr_cells=nr_cells,
+      cell_size=cell_size,
+      read_heads=read_heads,
+      gpu_id=gpu_id,
+      debug=debug
+  )
+
+  optimizer = optim.Adam(rnn.parameters(), lr=lr)
+  optimizer.zero_grad()
+
+  input_data, target_output = generate_data(batch_size, length, input_size, cuda)
+  target_output = target_output.transpose(0, 1).contiguous()
+
+  (chx, mhx, rv) = (None, None, None)
+  outputs = []
+  for x in range(6):
+    output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
+    output = output.transpose(0, 1)
+    outputs.append(output)
+
+  output = functools.reduce(lambda x,y: x + y, outputs)
+  loss = criterion((output), target_output)
+  loss.backward()
+
+  T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
+  optimizer.step()
+
+  assert target_output.size() == T.Size([27, 10, 100])
+  assert chx[1].size() == T.Size([num_hidden_layers,10,100])
+  assert mhx['memory'].size() == T.Size([10,12,17])
+  assert rv == None