Merge pull request #17 from ixaxaar/scale_interface

Scale interface vectors, dynamic memory pass
Russi Chatterjee 2017-12-01 00:47:24 +05:30 committed by GitHub
commit 11cc2107d7
7 changed files with 246 additions and 39 deletions

View File

@ -118,9 +118,9 @@ The copy task, as descibed in the original paper, is included in the repo.
From the project root:
```bash
-python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (original implementation)
-python ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -mem_slot 32 -batch_size 32 -optim adam # (faster convergence)
+python ./tasks/copy_task.py -cuda 0 -optim rmsprop -batch_size 32 -mem_slot 64 # (like original implementation)
+python3 ./tasks/copy_task.py -cuda 0 -lr 0.001 -rnn_type lstm -nlayer 1 -nhlayer 2 -dropout 0 -mem_slot 32 -batch_size 1000 -optim adam -sequence_max_length 8 # (faster convergence)
```
For the full set of options, see:
@ -150,7 +150,9 @@ The visdom dashboard shows memory as a heatmap for batch 0 every `-summarize_fre
## General noteworthy stuff
-1. DNCs converge with Adam and RMSProp learning rules, SGD generally causes them to diverge.
+1. DNCs converge faster with Adam and RMSProp learning rules; SGD generally converges extremely slowly.
+The copy task, for example, takes 25k iterations with SGD (lr 1) compared to 3.5k with Adam (lr 0.01).
+2. `nan`s in the gradients are common; try different batch sizes.
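A related mitigation, exposed as the `-clip` flag in `tasks/copy_task.py` and used in the updated tests, is clipping the gradient norm before the optimizer step. A minimal sketch, assuming `rnn`, `optimizer` and `loss` are already defined as in the copy task:

```python
import torch as T

optimizer.zero_grad()
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), 50)  # 50 is the -clip default in tasks/copy_task.py
optimizer.step()
```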
Repos referred to for creation of this repo:

View File

@ -74,13 +74,13 @@ class DNC(nn.Module):
for layer in range(self.num_layers):
if self.rnn_type.lower() == 'rnn':
self.rnns.append(nn.RNN((self.nn_input_size if layer == 0 else self.nn_output_size), self.output_size,
-bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout))
+bias=self.bias, nonlinearity=self.nonlinearity, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
elif self.rnn_type.lower() == 'gru':
self.rnns.append(nn.GRU((self.nn_input_size if layer == 0 else self.nn_output_size),
-self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
+self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
if self.rnn_type.lower() == 'lstm':
self.rnns.append(nn.LSTM((self.nn_input_size if layer == 0 else self.nn_output_size),
-self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout))
+self.output_size, bias=self.bias, batch_first=True, dropout=self.dropout, num_layers=self.num_hidden_layers))
setattr(self, self.rnn_type.lower() + '_layer_' + str(layer), self.rnns[layer])
# memories for each layer
@ -191,7 +191,7 @@ class DNC(nn.Module):
else:
read_vectors = None
-return output, read_vectors, (chx, mhx)
+return output, (chx, mhx, read_vectors)
def forward(self, input, hx=(None, None, None), reset_experience=False, pass_through_memory=True):
# handle packed data
@ -229,7 +229,7 @@ class DNC(nn.Module):
chx = controller_hidden[layer]
m = mem_hidden if self.share_memory else mem_hidden[layer]
# pass through controller
-outs[time], read_vectors, (chx, m) = \
+outs[time], (chx, m, read_vectors) = \
self._layer_forward(inputs[time], layer, (chx, m), pass_through_memory)
# debug memory
@ -246,6 +246,8 @@ class DNC(nn.Module):
if read_vectors is not None:
# the controller output + read vectors go into next layer
outs[time] = T.cat([outs[time], read_vectors], 1)
+else:
+outs[time] = T.cat([outs[time], last_read], 1)
inputs[time] = outs[time]
if self.debug:
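With this change the read vectors travel inside the hidden-state tuple instead of being returned separately, so a caller threads a single `(controller_hidden, memory_hidden, read_vectors)` tuple through successive calls. A minimal usage sketch modelled on the new tests (constructor arguments abbreviated; with `debug=True` the forward call also returns a debug dictionary):

```python
import torch as T
from dnc import DNC

rnn = DNC(input_size=100, hidden_size=100, rnn_type='lstm', num_layers=3,
          nr_cells=12, cell_size=17, read_heads=3, gpu_id=-1, debug=True)

x = T.autograd.Variable(T.randn(10, 27, 100))  # hypothetical (batch, time, input_size) batch
(chx, mhx, rv) = (None, None, None)            # controller hidden, memory hidden, read vectors
for _ in range(6):
    # the same tuple is fed straight back in on the next call;
    # pass reset_experience=True to wipe the memory between sequences
    output, (chx, mhx, rv), debug_dict = rnn(x, (chx, mhx, rv), reset_experience=False)
```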

View File

@ -50,11 +50,11 @@ class Memory(nn.Module):
if hidden is None:
return {
-'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.gpu_id),
+'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id),
'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id),
'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id),
-'read_weights': cuda(T.zeros(b, r, m).fill_(δ), gpu_id=self.gpu_id),
-'write_weights': cuda(T.zeros(b, 1, m).fill_(δ), gpu_id=self.gpu_id),
+'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id),
+'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id),
'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id)
}
else:
@ -66,11 +66,11 @@ class Memory(nn.Module):
hidden['usage_vector'] = hidden['usage_vector'].clone()
if erase:
-hidden['memory'].data.fill_(δ)
+hidden['memory'].data.fill_(0)
hidden['link_matrix'].data.zero_()
hidden['precedence'].data.zero_()
-hidden['read_weights'].data.fill_(δ)
-hidden['write_weights'].data.fill_(δ)
+hidden['read_weights'].data.fill_(0)
+hidden['write_weights'].data.fill_(0)
hidden['usage_vector'].data.zero_()
return hidden
@ -116,7 +116,7 @@ class Memory(nn.Module):
new_link_matrix = write_weights_i * precedence
link_matrix = prev_scale * link_matrix + new_link_matrix
-# elaborate trick to delete diag elems
+# trick to delete diag elems
return self.I.expand_as(link_matrix) * link_matrix
def update_precedence(self, precedence, write_weights):
@ -139,7 +139,6 @@ class Memory(nn.Module):
hidden['usage_vector'],
allocation_gate * write_gate
)
-# print((alloc).data.cpu().numpy())
# get write weightings
hidden['write_weights'] = self.write_weighting(
@ -170,8 +169,7 @@ class Memory(nn.Module):
def content_weightings(self, memory, keys, strengths):
d = θ(memory, keys)
-strengths = F.softplus(strengths).unsqueeze(2)
-return σ(d * strengths, 2)
+return σ(d * strengths.unsqueeze(2), 2)
def directional_weightings(self, link_matrix, read_weights):
rw = read_weights.unsqueeze(1)
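This refactor only moves the `softplus` on the read/write strengths out of `content_weightings` and into the interface parsing below; the weighting itself is unchanged. In the code's notation, with θ the batched cosine similarity and σ a softmax over the memory-slot dimension:

$$
C(M, k, \hat{\beta}) = \sigma\big(\theta(M, k)\,\operatorname{softplus}(\hat{\beta})\big)
$$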
@ -215,17 +213,17 @@ class Memory(nn.Module):
if self.independent_linears:
# r read keys (b * r * w)
-read_keys = self.read_keys_transform(ξ).view(b, r, w)
+read_keys = F.tanh(self.read_keys_transform(ξ).view(b, r, w))
# r read strengths (b * r)
-read_strengths = self.read_strengths_transform(ξ).view(b, r)
+read_strengths = F.softplus(self.read_strengths_transform(ξ).view(b, r))
# write key (b * 1 * w)
-write_key = self.write_key_transform(ξ).view(b, 1, w)
+write_key = F.tanh(self.write_key_transform(ξ).view(b, 1, w))
# write strength (b * 1)
-write_strength = self.write_strength_transform(ξ).view(b, 1)
+write_strength = F.softplus(self.write_strength_transform(ξ).view(b, 1))
# erase vector (b * 1 * w)
erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
# write vector (b * 1 * w)
-write_vector = self.write_vector_transform(ξ).view(b, 1, w)
+write_vector = F.tanh(self.write_vector_transform(ξ).view(b, 1, w))
# r free gates (b * r)
free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r))
# allocation gate (b * 1)
@ -237,17 +235,17 @@ class Memory(nn.Module):
else:
ξ = self.interface_weights(ξ)
# r read keys (b * w * r)
-read_keys = ξ[:, :r * w].contiguous().view(b, r, w)
+read_keys = F.tanh(ξ[:, :r * w].contiguous().view(b, r, w))
# r read strengths (b * r)
-read_strengths = ξ[:, r * w:r * w + r].contiguous().view(b, r)
+read_strengths = F.softplus(ξ[:, r * w:r * w + r].contiguous().view(b, r))
# write key (b * w * 1)
-write_key = ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w)
+write_key = F.tanh(ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
# write strength (b * 1)
-write_strength = ξ[:, r * w + r + w].contiguous().view(b, 1)
+write_strength = F.softplus(ξ[:, r * w + r + w].contiguous().view(b, 1))
# erase vector (b * w)
erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
# write vector (b * w)
-write_vector = ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w)
+write_vector = F.tanh(ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
# r free gates (b * r)
free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
# allocation gate (b * 1)
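Taken together, these hunks bound every slice of the interface vector ξ before it touches memory: keys and the write vector go through `tanh`, strengths through `softplus`, and the gates keep their existing `sigmoid` squashing. A standalone sketch of the non-`independent_linears` path, using plain `torch` names rather than the repo's Greek aliases (illustrative only, not the repo's exact API):

```python
import torch
import torch.nn.functional as F

def split_interface(xi, r, w):
    """Split a flat (batch x interface_size) interface vector into bounded pieces,
    mirroring the squashing introduced in this PR."""
    b = xi.size(0)
    read_keys      = torch.tanh(xi[:, :r * w].contiguous().view(b, r, w))
    read_strengths = F.softplus(xi[:, r * w:r * w + r].contiguous().view(b, r))
    write_key      = torch.tanh(xi[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w))
    write_strength = F.softplus(xi[:, r * w + r + w].contiguous().view(b, 1))
    erase_vector   = torch.sigmoid(xi[:, r * w + r + w + 1:r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
    write_vector   = torch.tanh(xi[:, r * w + r + 2 * w + 1:r * w + r + 3 * w + 1].contiguous().view(b, 1, w))
    return read_keys, read_strengths, write_key, write_strength, erase_vector, write_vector
```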

View File

@ -37,8 +37,8 @@ parser.add_argument('-clip', type=float, default=50, help='gradient clipping')
parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
parser.add_argument('-mem_size', type=int, default=16, help='memory dimension')
-parser.add_argument('-mem_slot', type=int, default=10, help='number of memory slots')
-parser.add_argument('-read_heads', type=int, default=1, help='number of read heads')
+parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
+parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')
parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
@ -121,7 +121,8 @@ if __name__ == '__main__':
read_heads=read_heads,
gpu_id=args.cuda,
debug=True,
-batch_first=True
+batch_first=True,
+independent_linears=True
)
print(rnn)
@ -131,9 +132,20 @@ if __name__ == '__main__':
last_save_losses = []
if args.optim == 'adam':
-optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98])
+optimizer = optim.Adam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
+if args.optim == 'sparseadam':
+optimizer = optim.SparseAdam(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
+if args.optim == 'adamax':
+optimizer = optim.Adamax(rnn.parameters(), lr=args.lr, eps=1e-9, betas=[0.9, 0.98]) # 0.0001
elif args.optim == 'rmsprop':
-optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, eps=1e-10)
+optimizer = optim.RMSprop(rnn.parameters(), lr=args.lr, eps=1e-10) # 0.0001
+elif args.optim == 'sgd':
+optimizer = optim.SGD(rnn.parameters(), lr=args.lr) # 0.01
+elif args.optim == 'adagrad':
+optimizer = optim.Adagrad(rnn.parameters(), lr=args.lr)
+elif args.optim == 'adadelta':
+optimizer = optim.Adadelta(rnn.parameters(), lr=args.lr)
for epoch in range(iterations + 1):
llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
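The growing `if`/`elif` chain above could equivalently be written as a lookup table; a hypothetical refactor sketch, assuming the script's `args` and `rnn` are in scope:

```python
import torch.optim as optim

# maps the -optim flag to a constructor; assumes args and rnn from the script above
optimizers = {
    'adam':       lambda p: optim.Adam(p, lr=args.lr, eps=1e-9, betas=[0.9, 0.98]),
    'sparseadam': lambda p: optim.SparseAdam(p, lr=args.lr, eps=1e-9, betas=[0.9, 0.98]),
    'adamax':     lambda p: optim.Adamax(p, lr=args.lr, eps=1e-9, betas=[0.9, 0.98]),
    'rmsprop':    lambda p: optim.RMSprop(p, lr=args.lr, eps=1e-10),
    'sgd':        lambda p: optim.SGD(p, lr=args.lr),
    'adagrad':    lambda p: optim.Adagrad(p, lr=args.lr),
    'adadelta':   lambda p: optim.Adadelta(p, lr=args.lr),
}
optimizer = optimizers[args.optim](rnn.parameters())
```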
@ -183,13 +195,13 @@ if __name__ == '__main__':
)
viz.heatmap(
-v['link_matrix'],
+v['link_matrix'][-1].reshape(args.mem_slot, args.mem_slot),
opts=dict(
xtickstep=10,
ytickstep=2,
title='Link Matrix, t: ' + str(epoch) + ', loss: ' + str(loss),
-ylabel='layer * time',
-xlabel='mem_slot * mem_slot'
+ylabel='mem_slot',
+xlabel='mem_slot'
)
)

View File

@ -17,6 +17,8 @@ import math
import time
sys.path.insert(0, '.')
+import functools
from dnc import DNC
from test_utils import generate_data, criterion
@ -128,6 +130,69 @@ def test_rnn_n():
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
-assert chx[1].size() == T.Size([1,10,100])
+assert chx[1].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])
def test_rnn_no_memory_pass():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'gru'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 12
cell_size = 17
read_heads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = DNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)
output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None
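The new `test_rnn_no_memory_pass` (repeated below for the LSTM and RNN controllers) exercises the `pass_through_memory=False` path added in `dnc.py` above: each layer then concatenates the previous read vectors (`last_read`) instead of querying memory, so the returned read vectors stay `None` while the controller and memory state are still threaded through repeated calls.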

View File

@ -15,6 +15,7 @@ import sys
import os
import math
import time
+import functools
sys.path.insert(0, '.')
from dnc import DNC
@ -128,6 +129,68 @@ def test_rnn_n():
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
-assert chx[0][0].size() == T.Size([1,10,100])
+assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])
def test_rnn_no_memory_pass():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'lstm'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 12
cell_size = 17
read_heads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = DNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)
output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[0][0].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None

View File

@ -17,6 +17,8 @@ import math
import time
sys.path.insert(0, '.')
+import functools
from dnc import DNC
from test_utils import generate_data, criterion
@ -128,6 +130,69 @@ def test_rnn_n():
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
-assert chx[1].size() == T.Size([1,10,100])
+assert chx[1].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv.size() == T.Size([10, 51])
def test_rnn_no_memory_pass():
T.manual_seed(1111)
input_size = 100
hidden_size = 100
rnn_type = 'rnn'
num_layers = 3
num_hidden_layers = 5
dropout = 0.2
nr_cells = 12
cell_size = 17
read_heads = 3
gpu_id = -1
debug = True
lr = 0.001
sequence_max_length = 10
batch_size = 10
cuda = gpu_id
clip = 20
length = 13
rnn = DNC(
input_size=input_size,
hidden_size=hidden_size,
rnn_type=rnn_type,
num_layers=num_layers,
num_hidden_layers=num_hidden_layers,
dropout=dropout,
nr_cells=nr_cells,
cell_size=cell_size,
read_heads=read_heads,
gpu_id=gpu_id,
debug=debug
)
optimizer = optim.Adam(rnn.parameters(), lr=lr)
optimizer.zero_grad()
input_data, target_output = generate_data(batch_size, length, input_size, cuda)
target_output = target_output.transpose(0, 1).contiguous()
(chx, mhx, rv) = (None, None, None)
outputs = []
for x in range(6):
output, (chx, mhx, rv), v = rnn(input_data, (chx, mhx, rv), pass_through_memory=False)
output = output.transpose(0, 1)
outputs.append(output)
output = functools.reduce(lambda x,y: x + y, outputs)
loss = criterion((output), target_output)
loss.backward()
T.nn.utils.clip_grad_norm(rnn.parameters(), clip)
optimizer.step()
assert target_output.size() == T.Size([27, 10, 100])
assert chx[1].size() == T.Size([num_hidden_layers,10,100])
assert mhx['memory'].size() == T.Size([10,12,17])
assert rv == None