Initial commit, pushed to PyPI
parent 397d7eec7f
commit 90365bd955
2	.gitignore	vendored
@@ -16,3 +16,5 @@ __pycache__/
*.lang
*.log
.cache/
dist/
dnc.egg-info/
69	README.md	Normal file
@@ -0,0 +1,69 @@
# Differentiable Neural Computer, for PyTorch

This is an implementation of [Differentiable Neural Computers](https://people.idsia.ch/~rupesh/rnnsymposium2016/slides/graves.pdf), described in the paper [Hybrid computing using a neural network with dynamic external memory, Graves et al.](https://www.nature.com/articles/nature20101)

## Install

```bash
pip install dnc
```

## Usage

**Parameters**:

| Argument | Default | Description |
| --- | --- | --- |
| input_size | None | Size of the input vectors |
| hidden_size | None | Size of hidden units |
| rnn_type | 'lstm' | Type of recurrent cells used in the controller |
| num_layers | 1 | Number of layers of recurrent units in the controller |
| bias | True | Whether the controller cells use a bias term |
| batch_first | True | Whether data is fed batch first |
| dropout | 0 | Dropout between layers in the controller (Not yet implemented) |
| bidirectional | False | If the controller is bidirectional (Not yet implemented) |
| nr_cells | 5 | Number of memory cells |
| read_heads | 2 | Number of read heads |
| cell_size | 10 | Size of each memory cell |
| nonlinearity | 'tanh' | If using 'rnn' as `rnn_type`, non-linearity of the RNNs |
| gpu_id | -1 | ID of the GPU, -1 for CPU |
| independent_linears | False | Whether to use independent linear units to derive the interface vector |
| share_memory | True | Whether to share memory between controller layers |

Example usage:

```python
import torch
from dnc import DNC

rnn = DNC(
  input_size=64,
  hidden_size=128,
  rnn_type='lstm',
  num_layers=4,
  nr_cells=100,
  cell_size=32,
  read_heads=4,
  batch_first=True,
  gpu_id=0
)

(controller_hidden, memory, read_vectors) = (None, None, None)

output, (controller_hidden, memory, read_vectors) = \
  rnn(torch.randn(10, 4, 64), (controller_hidden, memory, read_vectors))
```
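
The state tuple returned by `rnn(...)` can be passed back in on the next call, so a long sequence can be processed in chunks; a minimal sketch of such a stateful loop (the shapes, the CPU-only defaults, and the three-chunk loop are illustrative assumptions, not part of the library's own examples):

```python
import torch
from dnc import DNC

rnn = DNC(input_size=64, hidden_size=128, nr_cells=100, cell_size=32, read_heads=4)

state = (None, None, None)  # (controller_hidden, memory, read_vectors)
for chunk in range(3):
  inputs = torch.randn(10, 4, 64)     # batch of 10, 4 time steps, 64 features
  output, state = rnn(inputs, state)  # state carries controller, memory and reads forward
```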

## Example copy task

The copy task, as described in the original paper, is included in the repo.

```
python ./copy_task.py -cuda 0
```

## General noteworthy stuff

1. DNCs converge with Adam and RMSProp learning rules; SGD generally causes them to diverge.
2. Using a large batch size (> 100, recommended 1000) prevents gradients from becoming `NaN`.
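
Related to point 2, `dnc/util.py` in this commit ships a `register_nan_checks` helper that attaches backward hooks and prints a warning when a module receives an all-`NaN` gradient; a small sketch of wiring it up (module paths assume the package layout of this commit):

```python
from dnc.util import register_nan_checks
from dnc.dnc import DNC

rnn = DNC(input_size=6, hidden_size=64)
register_nan_checks(rnn)  # prints a message whenever a backward hook sees an all-NaN input gradient
```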
1	dnc/__init__.py	Normal file
@@ -0,0 +1 @@
#!/usr/bin/env python3
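
Note that `from dnc import DNC`, as used in the README and in `copy_task.py`, needs the package `__init__` to re-export the class; presumably something like the following one-liner (an assumption, not shown in this commit):

```python
from .dnc import DNC
```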
166	dnc/copy_task.py	Normal file
@@ -0,0 +1,166 @@
#!/usr/bin/env python3

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import getopt
import sys
import os
import math
import time
import argparse

sys.path.insert(0, os.path.join('..', '..'))

import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import torch.optim as optim

from torch.nn.utils import clip_grad_norm

from dnc import DNC

parser = argparse.ArgumentParser(description='PyTorch Differentiable Neural Computer')
parser.add_argument('-input_size', type=int, default=6, help='dimension of input feature')
parser.add_argument('-nhid', type=int, default=64, help='number of hidden units of the inner nn')

parser.add_argument('-nlayer', type=int, default=2, help='number of layers')
parser.add_argument('-lr', type=float, default=1e-2, help='initial learning rate')
parser.add_argument('-clip', type=float, default=0.5, help='gradient clipping')

parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='batch size')
parser.add_argument('-mem_size', type=int, default=16, help='memory dimension')
parser.add_argument('-mem_slot', type=int, default=15, help='number of memory slots')
parser.add_argument('-read_heads', type=int, default=1, help='number of read heads')

parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
parser.add_argument('-log-interval', type=int, default=200, metavar='N', help='report interval')

parser.add_argument('-iterations', type=int, default=100000, metavar='N', help='total number of iterations')
parser.add_argument('-summarize_freq', type=int, default=100, metavar='N', help='summarize frequency')
parser.add_argument('-check_freq', type=int, default=100, metavar='N', help='checkpoint frequency')

args = parser.parse_args()
print(args)

if args.cuda != -1:
  print('Using CUDA.')
  T.manual_seed(1111)
else:
  print('Using CPU.')


def llprint(message):
  sys.stdout.write(message)
  sys.stdout.flush()


def generate_data(batch_size, length, size, cuda=-1):

  input_data = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)
  target_output = np.zeros((batch_size, 2 * length + 1, size), dtype=np.float32)

  sequence = np.random.binomial(1, 0.5, (batch_size, length, size - 1))

  input_data[:, :length, :size - 1] = sequence
  input_data[:, length, -1] = 1  # the end symbol
  target_output[:, length + 1:, :size - 1] = sequence

  input_data = T.from_numpy(input_data)
  target_output = T.from_numpy(target_output)
  if cuda != -1:
    input_data = input_data.cuda()
    target_output = target_output.cuda()

  return var(input_data), var(target_output)


def criterion(predictions, targets):
  return T.mean(
      -1 * F.logsigmoid(predictions) * (targets) - T.log(1 - F.sigmoid(predictions) + 1e-9) * (1 - targets)
  )


if __name__ == '__main__':

  dirname = os.path.dirname(__file__)
  ckpts_dir = os.path.join(dirname, 'checkpoints')
  if not os.path.isdir(ckpts_dir):
    os.mkdir(ckpts_dir)

  batch_size = args.batch_size
  sequence_max_length = args.sequence_max_length
  iterations = args.iterations
  summarize_freq = args.summarize_freq
  check_freq = args.check_freq

  # input_size = output_size = args.input_size
  mem_slot = args.mem_slot
  mem_size = args.mem_size
  read_heads = args.read_heads

  # options, _ = getopt.getopt(sys.argv[1:], '', ['iterations='])
  # for opt in options:
  #   if opt[0] == '-iterations':
  #     iterations = int(opt[1])

  rnn = DNC(
      input_size=args.input_size,
      hidden_size=args.nhid,
      rnn_type='lstm',
      num_layers=args.nlayer,
      nr_cells=mem_slot,
      cell_size=mem_size,
      read_heads=read_heads,
      gpu_id=args.cuda
  )

  if args.cuda != -1:
    rnn = rnn.cuda(args.cuda)

  last_save_losses = []

  optimizer = optim.Adam(rnn.parameters(), lr=args.lr)

  for epoch in range(iterations + 1):
    llprint("\rIteration {ep}/{tot}".format(ep=epoch, tot=iterations))
    optimizer.zero_grad()

    random_length = np.random.randint(1, sequence_max_length + 1)

    input_data, target_output = generate_data(batch_size, random_length, args.input_size, args.cuda)
    # input_data = input_data.transpose(0, 1).contiguous()
    target_output = target_output.transpose(0, 1).contiguous()

    output, _ = rnn(input_data, None)
    output = output.transpose(0, 1)

    loss = criterion(output, target_output)
    # if np.isnan(loss.data.cpu().numpy()):
    #   llprint('\nGot nan loss; continue and skip the backward pass\n')

    # apply_dict(locals())
    loss.backward()

    optimizer.step()
    loss_value = loss.data[0]

    summarize = (epoch % summarize_freq == 0)
    take_checkpoint = (epoch != 0) and (epoch % check_freq == 0)

    last_save_losses.append(loss_value)

    if summarize:
      llprint("\n\tAvg. Logistic Loss: %.4f\n" % (np.mean(last_save_losses)))
      last_save_losses = []

    if take_checkpoint:
      llprint("\nSaving Checkpoint ... "),
      check_ptr = os.path.join(ckpts_dir, 'step_{}.pth'.format(epoch))
      cur_weights = rnn.state_dict()
      T.save(cur_weights, check_ptr)
      llprint("Done!\n")
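
For intuition, here is the layout `generate_data(batch_size=1, length=2, size=3)` produces for one illustrative draw of the random bits: the bits occupy the first `length` time steps of the input, the last channel carries the end marker at step `length`, and the target repeats the bits in the final `length` steps:

```python
# illustrative draw: sequence bits [[1, 0], [0, 1]]
# input_data[0]                     target_output[0]
#   [[1, 0, 0],   # bit 1             [[0, 0, 0],
#    [0, 1, 0],   # bit 2              [0, 0, 0],
#    [0, 0, 1],   # end marker         [0, 0, 0],
#    [0, 0, 0],                        [1, 0, 0],   # bit 1 repeated
#    [0, 0, 0]]                        [0, 1, 0]]   # bit 2 repeated
```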
255	dnc/dnc.py	Normal file
@@ -0,0 +1,255 @@
#!/usr/bin/env python3

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import numpy as np

from torch.nn.utils.rnn import pad_packed_sequence as pad
from torch.nn.utils.rnn import pack_padded_sequence as pack
from torch.nn.utils.rnn import PackedSequence

from .util import *
from .memory import *


class DNC(nn.Module):

  def __init__(
      self,
      input_size,
      hidden_size,
      rnn_type='lstm',
      num_layers=1,
      bias=True,
      batch_first=True,
      dropout=0,
      bidirectional=False,
      nr_cells=5,
      read_heads=2,
      cell_size=10,
      nonlinearity='tanh',
      gpu_id=-1,
      independent_linears=False,
      share_memory=True
  ):
    super(DNC, self).__init__()
    # TODO: separate weights and RNNs for the interface and output vectors

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.rnn_type = rnn_type
    self.num_layers = num_layers
    self.bias = bias
    self.batch_first = batch_first
    self.dropout = dropout
    self.bidirectional = bidirectional
    self.nr_cells = nr_cells
    self.read_heads = read_heads
    self.cell_size = cell_size
    self.nonlinearity = nonlinearity
    self.gpu_id = gpu_id
    self.independent_linears = independent_linears
    self.share_memory = share_memory

    self.w = self.cell_size
    self.r = self.read_heads

    # input size of layer 0
    self.layer0_input_size = self.r * self.w + self.input_size
    # input size of subsequent layers
    self.layern_input_size = self.r * self.w + self.hidden_size

    self.interface_size = (self.w * self.r) + (3 * self.w) + (5 * self.r) + 3
    self.output_size = self.hidden_size

    self.rnns = []
    self.memories = []

    for layer in range(self.num_layers):
      # controllers for each layer
      if self.rnn_type.lower() == 'rnn':
        if layer == 0:
          self.rnns.append(nn.RNNCell(self.layer0_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
        else:
          self.rnns.append(nn.RNNCell(self.layern_input_size, self.output_size, bias=self.bias, nonlinearity=self.nonlinearity))
      elif self.rnn_type.lower() == 'gru':
        if layer == 0:
          self.rnns.append(nn.GRUCell(self.layer0_input_size, self.output_size, bias=self.bias))
        else:
          self.rnns.append(nn.GRUCell(self.layern_input_size, self.output_size, bias=self.bias))
      elif self.rnn_type.lower() == 'lstm':
        # if layer == 0:
        self.rnns.append(nn.LSTMCell(self.layer0_input_size, self.output_size, bias=self.bias))
        # else:
        #   self.rnns.append(nn.LSTMCell(self.layern_input_size, self.output_size, bias=self.bias))

      # memories for each layer
      if not self.share_memory:
        self.memories.append(
            Memory(
                input_size=self.output_size,
                mem_size=self.nr_cells,
                cell_size=self.w,
                read_heads=self.r,
                gpu_id=self.gpu_id,
                independent_linears=self.independent_linears
            )
        )

    # only one memory shared by all layers
    if self.share_memory:
      self.memories.append(
          Memory(
              input_size=self.output_size,
              mem_size=self.nr_cells,
              cell_size=self.w,
              read_heads=self.r,
              gpu_id=self.gpu_id,
              independent_linears=self.independent_linears
          )
      )

    for layer in range(self.num_layers):
      setattr(self, 'rnn_layer_' + str(layer), self.rnns[layer])
      if not self.share_memory:
        setattr(self, 'rnn_layer_memory_' + str(layer), self.memories[layer])
    if self.share_memory:
      setattr(self, 'rnn_layer_memory_shared', self.memories[0])

    # final output layer
    self.output_weights = nn.Linear(self.output_size, self.output_size)
    self.mem_out = nn.Linear(self.layern_input_size, self.input_size)
    self.dropout_layer = nn.Dropout(self.dropout)

    if self.gpu_id != -1:
      [x.cuda(self.gpu_id) for x in self.rnns]
      [x.cuda(self.gpu_id) for x in self.memories]
      self.mem_out.cuda(self.gpu_id)

  def _init_hidden(self, hx, batch_size, reset_experience):
    # create empty hidden states if not provided
    if hx is None:
      hx = (None, None, None)
    (chx, mhx, last_read) = hx

    # initialize hidden state of the controller RNN
    if chx is None:
      chx = cuda(T.zeros(self.num_layers, batch_size, self.output_size), gpu_id=self.gpu_id)
      if self.rnn_type.lower() == 'lstm':
        chx = (chx, chx)

    # last read vectors
    if last_read is None:
      last_read = cuda(T.zeros(batch_size, self.w * self.r), gpu_id=self.gpu_id)

    # memory states
    if mhx is None:
      if self.share_memory:
        mhx = self.memories[0].reset(batch_size, erase=reset_experience)
      else:
        mhx = [m.reset(batch_size, erase=reset_experience) for m in self.memories]
    else:
      if self.share_memory:
        mhx = self.memories[0].reset(batch_size, mhx, erase=reset_experience)
      else:
        mhx = [m.reset(batch_size, h, erase=reset_experience) for m, h in zip(self.memories, mhx)]

    return chx, mhx, last_read

  def _layer_forward(self, input, layer, hx=(None, None)):
    (chx, mhx) = hx
    max_length = len(input)
    outs = [0] * max_length
    read_vectors = [0] * max_length

    for time in range(max_length):
      # pass through controller
      # print('input[time]', input[time].size(), self.layer0_input_size, self.layern_input_size)
      chx = self.rnns[layer](input[time], chx)
      # the interface vector
      ξ = chx[0] if self.rnn_type.lower() == 'lstm' else chx
      # the output
      out = self.output_weights(ξ)

      # pass through memory
      if self.share_memory:
        read_vecs, mhx = self.memories[0](ξ, mhx)
      else:
        read_vecs, mhx = self.memories[layer](ξ, mhx)
      read_vectors[time] = read_vecs.view(-1, self.w * self.r)

      # get the final output for this time step
      outs[time] = self.mem_out(T.cat([out, read_vectors[time]], 1))

    return outs, read_vectors, (chx, mhx)

  def forward(self, input, hx=(None, None, None), reset_experience=False):
    # handle packed data
    is_packed = type(input) is PackedSequence
    if is_packed:
      input, lengths = pad(input)
      max_length = lengths[0]
    else:
      max_length = input.size(1) if self.batch_first else input.size(0)
      lengths = [input.size(1)] * max_length if self.batch_first else [input.size(0)] * max_length

    batch_size = input.size(0) if self.batch_first else input.size(1)

    # make the data batch-first
    if not self.batch_first:
      input = input.transpose(0, 1)

    controller_hidden, mem_hidden, last_read = self._init_hidden(hx, batch_size, reset_experience)

    # batched forward pass per element / word / etc
    outputs = None
    chxs = []
    read_vectors = [last_read] * max_length
    # outs = [input[:, x, :] for x in range(max_length)]
    outs = [T.cat([input[:, x, :], last_read], 1) for x in range(max_length)]

    # chx = [x[0] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[0]
    for layer in range(self.num_layers):
      # this layer's hidden states
      chx = [x[layer] for x in controller_hidden] if self.rnn_type.lower() == 'lstm' else controller_hidden[layer]

      m = mem_hidden if self.share_memory else mem_hidden[layer]
      # pass through controller
      outs, _, (chx, m) = self._layer_forward(
          outs,
          layer,
          (chx, m)
      )

      # store the memory back (per layer or shared)
      if self.share_memory:
        mem_hidden = m
      else:
        mem_hidden[layer] = m
      chxs.append(chx)

      if layer == self.num_layers - 1:
        # final outputs
        outputs = T.stack(outs, 1)
      else:
        # the controller output + read vectors go into the next layer
        outs = [T.cat([o, r], 1) for o, r in zip(outs, read_vectors)]
        # outs = [o for o in outs]

    # final hidden values
    if self.rnn_type.lower() == 'lstm':
      h = T.stack([x[0] for x in chxs], 0)
      c = T.stack([x[1] for x in chxs], 0)
      controller_hidden = (h, c)
    else:
      controller_hidden = T.stack(chxs, 0)

    if not self.batch_first:
      outputs = outputs.transpose(0, 1)
    if is_packed:
      outputs = pack(outputs, lengths)

    # apply_dict(locals())

    return outputs, (controller_hidden, mem_hidden, read_vectors[-1])
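
For reference, the `interface_size = (w * r) + (3 * w) + (5 * r) + 3` computed above matches how `memory.py` (next file) slices the interface vector ξ: r read keys (r·w), r read strengths (r), a write key (w), a write strength (1), an erase vector (w), a write vector (w), r free gates (r), an allocation gate (1), a write gate (1), and 3·r read modes. With the README defaults of cell_size w = 10 and read_heads r = 2, that is 20 + 30 + 10 + 3 = 63 units.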
256	dnc/memory.py	Normal file
@@ -0,0 +1,256 @@
#!/usr/bin/env python3

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import numpy as np

from .util import *


class Memory(nn.Module):

  def __init__(self, input_size, mem_size=512, cell_size=32, read_heads=4, gpu_id=-1, independent_linears=True):
    super(Memory, self).__init__()

    self.mem_size = mem_size
    self.cell_size = cell_size
    self.read_heads = read_heads
    self.gpu_id = gpu_id
    self.input_size = input_size
    self.independent_linears = independent_linears

    m = self.mem_size
    w = self.cell_size
    r = self.read_heads

    if self.independent_linears:
      self.read_keys_transform = nn.Linear(self.input_size, w * r)
      self.read_strengths_transform = nn.Linear(self.input_size, r)
      self.write_key_transform = nn.Linear(self.input_size, w)
      self.write_strength_transform = nn.Linear(self.input_size, 1)
      self.erase_vector_transform = nn.Linear(self.input_size, w)
      self.write_vector_transform = nn.Linear(self.input_size, w)
      self.free_gates_transform = nn.Linear(self.input_size, r)
      self.allocation_gate_transform = nn.Linear(self.input_size, 1)
      self.write_gate_transform = nn.Linear(self.input_size, 1)
      self.read_modes_transform = nn.Linear(self.input_size, 3 * r)
    else:
      self.interface_size = (w * r) + (3 * w) + (5 * r) + 3
      self.interface_weights = nn.Linear(self.input_size, self.interface_size)

    self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)

  def reset(self, batch_size=1, hidden=None, erase=True):
    m = self.mem_size
    w = self.cell_size
    r = self.read_heads
    b = batch_size

    if hidden is None:
      return {
          'memory': cuda(T.zeros(b, m, w).fill_(0), gpu_id=self.gpu_id),
          'link_matrix': cuda(T.zeros(b, 1, m, m), gpu_id=self.gpu_id),
          'precedence': cuda(T.zeros(b, 1, m), gpu_id=self.gpu_id),
          'read_weights': cuda(T.zeros(b, r, m).fill_(0), gpu_id=self.gpu_id),
          'write_weights': cuda(T.zeros(b, 1, m).fill_(0), gpu_id=self.gpu_id),
          'usage_vector': cuda(T.zeros(b, m), gpu_id=self.gpu_id)
      }
    else:
      hidden['memory'] = hidden['memory'].clone()
      hidden['link_matrix'] = hidden['link_matrix'].clone()
      hidden['precedence'] = hidden['precedence'].clone()
      hidden['read_weights'] = hidden['read_weights'].clone()
      hidden['write_weights'] = hidden['write_weights'].clone()
      hidden['usage_vector'] = hidden['usage_vector'].clone()

      if erase:
        hidden['memory'].data.fill_(δ)
        hidden['link_matrix'].data.zero_()
        hidden['precedence'].data.zero_()
        hidden['read_weights'].data.fill_(δ)
        hidden['write_weights'].data.fill_(δ)
        hidden['usage_vector'].data.zero_()
      return hidden

  def get_usage_vector(self, usage, free_gates, read_weights, write_weights):
    # write_weights = write_weights.detach()  # detach from the computation graph
    usage = usage + (1 - usage) * (1 - T.prod(1 - write_weights, 1))
    ψ = T.prod(1 - free_gates.unsqueeze(2) * read_weights, 1)
    return usage * ψ

  def allocate(self, usage, write_gate):
    # ensure values are not too small prior to cumprod.
    usage = δ + (1 - δ) * usage
    # free list
    sorted_usage, φ = T.topk(usage, self.mem_size, dim=1, largest=False)
    # TODO: these are actually shifted cumprods, tensorflow has exclusive=True
    # fix once the pytorch issue is fixed
    sorted_allocation_weights = (1 - sorted_usage) * fake_cumprod(sorted_usage, self.gpu_id).squeeze()
    # construct the reverse sorting index https://stackoverflow.com/questions/2483696/undo-or-reverse-argsort-python
    _, φ_rev = T.topk(φ, k=self.mem_size, dim=1, largest=False)
    allocation_weights = sorted_allocation_weights.gather(1, φ_rev.long())

    # update usage after allocating
    # usage += ((1 - usage) * write_gate * allocation_weights)
    return allocation_weights.unsqueeze(1), usage

  def write_weighting(self, memory, write_content_weights, allocation_weights, write_gate, allocation_gate):
    ag = allocation_gate.unsqueeze(-1)
    wg = write_gate.unsqueeze(-1)

    return wg * (ag * allocation_weights + (1 - ag) * write_content_weights)

  def get_link_matrix(self, link_matrix, write_weights, precedence):
    precedence = precedence.unsqueeze(2)
    write_weights_i = write_weights.unsqueeze(3)
    write_weights_j = write_weights.unsqueeze(2)

    prev_scale = 1 - write_weights_i - write_weights_j
    new_link_matrix = write_weights_i * precedence

    link_matrix = prev_scale * link_matrix + new_link_matrix
    # elaborate trick to delete diag elems
    return self.I.expand_as(link_matrix) * link_matrix

  def update_precedence(self, precedence, write_weights):
    return (1 - T.sum(write_weights, 2, keepdim=True)) * precedence + write_weights

  def write(self, write_key, write_vector, erase_vector, free_gates, read_strengths, write_strength, write_gate, allocation_gate, hidden):
    # get current usage
    hidden['usage_vector'] = self.get_usage_vector(
        hidden['usage_vector'],
        free_gates,
        hidden['read_weights'],
        hidden['write_weights']
    )

    # lookup memory with write_key and write_strength
    write_content_weights = self.content_weightings(hidden['memory'], write_key, write_strength)

    # get memory allocation
    alloc, _ = self.allocate(
        hidden['usage_vector'],
        allocation_gate * write_gate
    )

    # get write weightings
    hidden['write_weights'] = self.write_weighting(
        hidden['memory'],
        write_content_weights,
        alloc,
        write_gate,
        allocation_gate
    )

    weighted_resets = hidden['write_weights'].unsqueeze(3) * erase_vector.unsqueeze(2)
    reset_gate = T.prod(1 - weighted_resets, 1)
    # update memory
    hidden['memory'] = hidden['memory'] * reset_gate

    hidden['memory'] = hidden['memory'] + \
        T.bmm(hidden['write_weights'].transpose(1, 2), write_vector)

    # update link_matrix
    hidden['link_matrix'] = self.get_link_matrix(
        hidden['link_matrix'],
        hidden['write_weights'],
        hidden['precedence']
    )
    hidden['precedence'] = self.update_precedence(hidden['precedence'], hidden['write_weights'])

    return hidden

  def content_weightings(self, memory, keys, strengths):
    d = θ(memory, keys)
    strengths = F.softplus(strengths).unsqueeze(2)
    return σ(d * strengths, 2)

  def directional_weightings(self, link_matrix, read_weights):
    rw = read_weights.unsqueeze(1)

    f = T.matmul(link_matrix, rw.transpose(2, 3)).transpose(2, 3)
    b = T.matmul(rw, link_matrix)
    return f.transpose(1, 2), b.transpose(1, 2)

  def read_weightings(self, memory, content_weights, link_matrix, read_modes, read_weights):
    forward_weight, backward_weight = self.directional_weightings(link_matrix, read_weights)

    content_mode = read_modes[:, :, 2].contiguous().unsqueeze(2) * content_weights
    backward_mode = T.sum(read_modes[:, :, 0:1].contiguous().unsqueeze(3) * backward_weight, 2)
    forward_mode = T.sum(read_modes[:, :, 1:2].contiguous().unsqueeze(3) * forward_weight, 2)

    return backward_mode + content_mode + forward_mode

  def read_vectors(self, memory, read_weights):
    return T.bmm(read_weights, memory)

  def read(self, read_keys, read_strengths, read_modes, hidden):
    content_weights = self.content_weightings(hidden['memory'], read_keys, read_strengths)

    hidden['read_weights'] = self.read_weightings(
        hidden['memory'],
        content_weights,
        hidden['link_matrix'],
        read_modes,
        hidden['read_weights']
    )
    read_vectors = self.read_vectors(hidden['memory'], hidden['read_weights'])
    return read_vectors, hidden

  def forward(self, ξ, hidden):

    # ξ = ξ.detach()
    m = self.mem_size
    w = self.cell_size
    r = self.read_heads
    b = ξ.size()[0]

    if self.independent_linears:
      # r read keys (b * r * w)
      read_keys = self.read_keys_transform(ξ).view(b, r, w)
      # r read strengths (b * r)
      read_strengths = self.read_strengths_transform(ξ).view(b, r)
      # write key (b * 1 * w)
      write_key = self.write_key_transform(ξ).view(b, 1, w)
      # write strength (b * 1)
      write_strength = self.write_strength_transform(ξ).view(b, 1)
      # erase vector (b * 1 * w)
      erase_vector = F.sigmoid(self.erase_vector_transform(ξ).view(b, 1, w))
      # write vector (b * 1 * w)
      write_vector = self.write_vector_transform(ξ).view(b, 1, w)
      # r free gates (b * r)
      free_gates = F.sigmoid(self.free_gates_transform(ξ).view(b, r))
      # allocation gate (b * 1)
      allocation_gate = F.sigmoid(self.allocation_gate_transform(ξ).view(b, 1))
      # write gate (b * 1)
      write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
      # read modes (b * r * 3)
      read_modes = σ(self.read_modes_transform(ξ).view(b, r, 3), 1)
    else:
      ξ = self.interface_weights(ξ)
      # r read keys (b * r * w)
      read_keys = ξ[:, :r * w].contiguous().view(b, r, w)
      # r read strengths (b * r)
      read_strengths = 1 + F.relu(ξ[:, r * w:r * w + r].contiguous().view(b, r))
      # write key (b * 1 * w)
      write_key = ξ[:, r * w + r:r * w + r + w].contiguous().view(b, 1, w)
      # write strength (b * 1)
      write_strength = 1 + F.relu(ξ[:, r * w + r + w].contiguous()).view(b, 1)
      # erase vector (b * 1 * w)
      erase_vector = F.sigmoid(ξ[:, r * w + r + w + 1: r * w + r + 2 * w + 1].contiguous().view(b, 1, w))
      # write vector (b * 1 * w)
      write_vector = ξ[:, r * w + r + 2 * w + 1: r * w + r + 3 * w + 1].contiguous().view(b, 1, w)
      # r free gates (b * r)
      free_gates = F.sigmoid(ξ[:, r * w + r + 3 * w + 1: r * w + 2 * r + 3 * w + 1].contiguous().view(b, r))
      # allocation gate (b * 1)
      allocation_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 1].contiguous().unsqueeze(1).view(b, 1))
      # write gate (b * 1)
      write_gate = F.sigmoid(ξ[:, r * w + 2 * r + 3 * w + 2].contiguous()).unsqueeze(1).view(b, 1)
      # read modes (b * r * 3)
      read_modes = σ(ξ[:, r * w + 2 * r + 3 * w + 3: r * w + 5 * r + 3 * w + 3].contiguous().view(b, r, 3), 1)

    hidden = self.write(write_key, write_vector, erase_vector, free_gates,
                        read_strengths, write_strength, write_gate, allocation_gate, hidden)
    return self.read(read_keys, read_strengths, read_modes, hidden)
154	dnc/util.py	Normal file
@@ -0,0 +1,154 @@
#!/usr/bin/env python3

import torch.nn as nn
import torch as T
import torch.nn.functional as F
from torch.autograd import Variable as var
import numpy as np
import torch
from torch.autograd import Variable
import re
import string


def recursiveTrace(obj):
  print(type(obj))
  if hasattr(obj, 'grad_fn'):
    print(obj.grad_fn)
    recursiveTrace(obj.grad_fn)
  elif hasattr(obj, 'saved_variables'):
    print(obj.requires_grad, len(obj.saved_tensors), len(obj.saved_variables))
    [print(v) for v in obj.saved_variables]
    [recursiveTrace(v.grad_fn) for v in obj.saved_variables]


def cuda(x, grad=False, gpu_id=-1):
  if gpu_id == -1:
    return var(x, requires_grad=grad)
  else:
    return var(x.pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)


def cudavec(x, grad=False, gpu_id=-1):
  if gpu_id == -1:
    return var(T.from_numpy(x), requires_grad=grad)
  else:
    return var(T.from_numpy(x).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)


def cudalong(x, grad=False, gpu_id=-1):
  if gpu_id == -1:
    return var(T.from_numpy(x.astype(np.long)), requires_grad=grad)
  else:
    return var(T.from_numpy(x.astype(np.long)).pin_memory(), requires_grad=grad).cuda(gpu_id, async=True)


def fake_cumprod(vb, gpu_id):
  """
  args:
      vb: [hei x wid]
  -> NOTE: we are lazy here so now it only supports cumprod along wid
  """
  # real_cumprod = torch.cumprod(vb.data, 1)
  vb = vb.unsqueeze(0)
  mul_mask_vb = Variable(torch.zeros(vb.size(2), vb.size(1), vb.size(2))).type_as(vb)

  if gpu_id != -1:
    mul_mask_vb = mul_mask_vb.cuda(gpu_id)

  for i in range(vb.size(2)):
    mul_mask_vb[i, :, :i + 1] = 1
  add_mask_vb = 1 - mul_mask_vb
  vb = vb.expand_as(mul_mask_vb) * mul_mask_vb + add_mask_vb
  # vb = torch.prod(vb, 2).transpose(0, 2)  # 0.1.12
  vb = torch.prod(vb, 2, keepdim=True).transpose(0, 2)  # 0.2.0
  # print(real_cumprod - vb.data)  # NOTE: checked, ==0
  return vb


def θ(a, b, dimA=2, dimB=2, normBy=2):
  """Batchwise Cosine distance

  Cosine distance

  Arguments:
      a {Tensor} -- A 3D Tensor (b * m * w)
      b {Tensor} -- A 3D Tensor (b * r * w)

  Keyword Arguments:
      dimA {number} -- dimension along which to take the norm of `a` (default: {2})
      dimB {number} -- dimension along which to take the norm of `b` (default: {2})
      normBy {number} -- exponent of the norm (default: {2})

  Returns:
      Tensor -- Batchwise cosine distance (b * r * m)
  """
  a_norm = T.norm(a, normBy, dimA, keepdim=True).expand_as(a) + δ
  b_norm = T.norm(b, normBy, dimB, keepdim=True).expand_as(b) + δ

  x = T.bmm(a, b.transpose(1, 2)).transpose(1, 2) / (
      T.bmm(a_norm, b_norm.transpose(1, 2)).transpose(1, 2) + δ)
  # apply_dict(locals())
  return x


def σ(input, axis=1):
  """Softmax on an axis

  Softmax on an axis

  Arguments:
      input {Tensor} -- input Tensor

  Keyword Arguments:
      axis {number} -- axis on which to take softmax on (default: {1})

  Returns:
      Tensor -- Softmax output Tensor
  """
  input_size = input.size()

  trans_input = input.transpose(axis, len(input_size) - 1)
  trans_size = trans_input.size()

  input_2d = trans_input.contiguous().view(-1, trans_size[-1])
  soft_max_2d = F.softmax(input_2d)
  soft_max_nd = soft_max_2d.view(*trans_size)
  return soft_max_nd.transpose(axis, len(input_size) - 1)


δ = 1e-6


def register_nan_checks(model):
  def check_grad(module, grad_input, grad_output):
    # print(module)  # uncomment to see that the hook is called
    print('hook called for ' + str(type(module)))
    if any(np.all(np.isnan(gi.data.cpu().numpy())) for gi in grad_input if gi is not None):
      print('NaN gradient in grad_input ' + type(module).__name__)

  model.apply(lambda module: module.register_backward_hook(check_grad))


def apply_dict(dic):
  for k, v in dic.items():
    apply_var(v, k)
    if isinstance(v, nn.Module):
      key_list = [a for a in dir(v) if not a.startswith('__')]
      for key in key_list:
        apply_var(getattr(v, key), key)
      for pk, pv in v._parameters.items():
        apply_var(pv, pk)


def apply_var(v, k):
  if isinstance(v, Variable) and v.requires_grad:
    v.register_hook(check_nan_gradient(k))


def check_nan_gradient(name=''):
  def f(tensor):
    if np.isnan(T.mean(tensor).data.cpu().numpy()):
      print('\nnan gradient of {} :'.format(name))
      # print(tensor)
      # assert 0, 'nan gradient'
      return tensor
  return f
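
As a quick sanity check on `fake_cumprod`: it computes an inclusive cumulative product along the last dimension (the in-code note confirms it matches `torch.cumprod`), while the TODO in `memory.allocate` points out the DNC paper calls for the exclusive, shifted variant. In plain numpy terms:

```python
import numpy as np

sorted_usage = np.array([0.5, 0.4, 0.2])     # illustrative values
print(np.cumprod(sorted_usage))              # [0.5, 0.2, 0.04] -- the values fake_cumprod reproduces
```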
67	setup.py	Normal file
@@ -0,0 +1,67 @@
#!/usr/bin/env python3

"""A setuptools based setup module.
See:
https://packaging.python.org/en/latest/distributing.html
https://github.com/pypa/sampleproject
"""

# Always prefer setuptools over distutils
from setuptools import setup, find_packages
# To use a consistent encoding
from codecs import open
from os import path

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setup(
    name='dnc',

    version='0.0.1',

    description='Differentiable Neural Computer, for PyTorch',
    long_description=long_description,

    # The project's main homepage.
    url='https://github.com/pypa/dnc',

    # Author details
    author='Russi Chatterjee',
    author_email='root@ixaxaar.in',

    # Choose your license
    license='MIT',

    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[
        'Development Status :: 3 - Alpha',

        'Intended Audience :: Science/Research',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',

        'License :: OSI Approved :: MIT License',

        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
    ],

    keywords='differentiable neural computer dnc memory network',

    packages=find_packages(exclude=['contrib', 'docs', 'tests']),

    install_requires=['torch', 'numpy'],

    extras_require={
        'dev': ['check-manifest'],
        'test': ['coverage'],
    },

    python_requires='>=3',
)