#!/usr/bin/env python3
# -*- coding: utf-8 -*-
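
"""Sparse memory for the DNC.

Content-based addressing is restricted to a small set of cells retrieved from a
FLANN approximate nearest-neighbour index (plus the least recently used cell),
and temporal link matrices are kept only over these visible cells.
"""
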
import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import numpy as np
import math
from .flann_index import FLANNIndex
from .util import *
import time
class SparseMemory(nn.Module):
def __init__(
self,
input_size,
mem_size=512,
cell_size=32,
independent_linears=True,
read_heads=4,
sparse_reads=10,
num_lists=None,
index_checks=32,
gpu_id=-1,
mem_gpu_id=-1
):
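    """Create a sparse memory module.

    Args:
      input_size: size of the controller's interface vector.
      mem_size: number of memory cells.
      cell_size: width of each memory cell.
      independent_linears: use a separate linear layer per interface component.
      read_heads: number of read heads.
      sparse_reads: number of cells retrieved from the index per read head (K).
      num_lists: number of k-d trees used by the FLANN index (roughly mem_size / 100 by default).
      index_checks: number of checks FLANN performs per query.
      gpu_id: GPU used for the working tensors (-1 for CPU).
      mem_gpu_id: GPU for the memory and its indexes (-1 for CPU).
    """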
super(SparseMemory, self).__init__()
self.mem_size = mem_size
self.cell_size = cell_size
self.gpu_id = gpu_id
self.mem_gpu_id = mem_gpu_id
self.input_size = input_size
self.independent_linears = independent_linears
self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
self.read_heads = read_heads
    self.num_lists = num_lists if num_lists is not None else max(self.mem_size // 100, 1)
self.index_checks = index_checks
m = self.mem_size
w = self.cell_size
r = self.read_heads
c = r * self.K + 1
if self.independent_linears:
self.read_query_transform = nn.Linear(self.input_size, w*r)
self.write_vector_transform = nn.Linear(self.input_size, w)
self.interpolation_gate_transform = nn.Linear(self.input_size, c)
self.write_gate_transform = nn.Linear(self.input_size, 1)
T.nn.init.orthogonal(self.read_query_transform.weight)
T.nn.init.orthogonal(self.write_vector_transform.weight)
T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
T.nn.init.orthogonal(self.write_gate_transform.weight)
else:
self.interface_size = (r * w) + w + c + 1
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
T.nn.init.orthogonal(self.interface_weights.weight)
    self.I = cuda(1 - T.eye(c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * c * c)
self.δ = 0.005 # minimum usage
self.timestep = 0
def rebuild_indexes(self, hidden, erase=False):
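    """Build or reset the per-batch FLANN indexes over memory.

    Existing indexes are reset in place, otherwise one index per batch entry is
    created; unless `erase` is set, the current memory contents are re-added.
    """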
b = hidden['memory'].size(0)
# if indexes already exist, we reset them
if 'indexes' in hidden:
      for x in hidden['indexes']:
        x.reset()
else:
# create new indexes
hidden['indexes'] = \
[FLANNIndex(cell_size=self.cell_size,
nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
# add existing memory into indexes
    # keep the batch dimension, even when the batch size is 1
    pos = hidden['read_positions'].data.cpu().numpy()
if not erase:
for n, i in enumerate(hidden['indexes']):
i.reset()
i.add(hidden['memory'][n], last=pos[n][-1])
else:
self.timestep = 0
return hidden
def reset(self, batch_size=1, hidden=None, erase=True):
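    """Return a fresh hidden state for `batch_size` sequences, or clone (and optionally erase) an existing one."""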
m = self.mem_size
w = self.cell_size
b = batch_size
r = self.read_heads
c = r * self.K + 1
if hidden is None:
hidden = {
        # warning: this can be a huge chunk of contiguous memory
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
'link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
'rev_link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
'precedence': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
'least_used_mem': cuda(T.zeros(b, 1).fill_(c+1), gpu_id=self.gpu_id).long(),
'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
}
hidden = self.rebuild_indexes(hidden, erase=True)
else:
hidden['memory'] = hidden['memory'].clone()
hidden['visible_memory'] = hidden['visible_memory'].clone()
hidden['link_matrix'] = hidden['link_matrix'].clone()
      hidden['rev_link_matrix'] = hidden['rev_link_matrix'].clone()
hidden['precedence'] = hidden['precedence'].clone()
hidden['read_weights'] = hidden['read_weights'].clone()
hidden['write_weights'] = hidden['write_weights'].clone()
hidden['read_vectors'] = hidden['read_vectors'].clone()
hidden['least_used_mem'] = hidden['least_used_mem'].clone()
hidden['usage'] = hidden['usage'].clone()
hidden['read_positions'] = hidden['read_positions'].clone()
hidden = self.rebuild_indexes(hidden, erase)
if erase:
hidden['memory'].data.fill_(δ)
hidden['visible_memory'].data.fill_(δ)
hidden['link_matrix'].data.zero_()
hidden['rev_link_matrix'].data.zero_()
hidden['precedence'].data.zero_()
hidden['read_weights'].data.fill_(δ)
hidden['write_weights'].data.fill_(δ)
hidden['read_vectors'].data.fill_(δ)
hidden['least_used_mem'].data.fill_(c+1+self.timestep)
hidden['usage'].data.fill_(δ)
hidden['read_positions'] = cuda(T.arange(self.timestep, c+self.timestep).expand(b, c), gpu_id=self.gpu_id).long()
return hidden
def write_into_sparse_memory(self, hidden):
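    """Scatter the visible (recently read) cells back into memory and refresh the FLANN indexes."""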
visible_memory = hidden['visible_memory']
    positions = hidden['read_positions']
(b, m, w) = hidden['memory'].size()
# update memory
hidden['memory'].scatter_(1, positions.unsqueeze(2).expand(b, self.read_heads*self.K+1, w), visible_memory)
# non-differentiable operations
pos = positions.data.cpu().numpy()
for batch in range(b):
# update indexes
hidden['indexes'][batch].reset()
hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])
    hidden['least_used_mem'] = (hidden['least_used_mem'] + 1) if self.timestep < self.mem_size else (hidden['least_used_mem'] * 0)
return hidden
def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence):
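    """Update the forward and reverse temporal link matrices over the visible cells."""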
link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
rev_link_matrix = (1 - write_weights).unsqueeze(1) * rev_link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
return link_matrix, rev_link_matrix
def update_precedence(self, precedence, write_weights):
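    """Decay precedence by the total amount written and add the new write weights."""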
return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
def write(self, interpolation_gate, write_vector, write_gate, hidden):
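    """Write `write_vector` into the visible cells.

    The write locations interpolate between the previously read locations and the
    least-used visible cell, scaled by the write gate; usage, the temporal link
    matrices and precedence are updated accordingly.
    """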
read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
hidden['usage'], I = self.update_usage(
hidden['read_positions'],
read_weights,
write_weights,
hidden['usage']
)
# either we write to previous read locations
x = interpolation_gate * read_weights
# or to a new location
y = (1 - interpolation_gate) * I
write_weights = write_gate * (x + y)
# store the write weights
hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)
# erase matrix
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
# write into memory
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
hidden = self.write_into_sparse_memory(hidden)
# update link_matrix and precedence
(b, c) = write_weights.size()
precedence = hidden['precedence'].gather(1, hidden['read_positions'])
hidden['link_matrix'], hidden['rev_link_matrix'] = \
self.update_link_matrices(hidden['link_matrix'], hidden['rev_link_matrix'], write_weights, precedence)
    precedence = self.update_precedence(precedence, write_weights)
    hidden['precedence'].scatter_(1, hidden['read_positions'], precedence)
return hidden
def update_usage(self, read_positions, read_weights, write_weights, usage):
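    """Update usage at the read positions and return an indicator of the least-used cells among them."""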
(b, _) = read_positions.size()
# usage is timesteps since a non-negligible memory access
    # TODO: store write weights for all of memory and gather from that
u = (read_weights + write_weights > self.δ).float()
# usage before write
relevant_usages = usage.gather(1, read_positions)
# indicator of words with minimal memory usage
minusage = T.min(relevant_usages, -1, keepdim=True)[0]
minusage = minusage.expand(relevant_usages.size())
I = (relevant_usages == minusage).float()
# usage after write
relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)
usage.scatter_(1, read_positions, relevant_usages)
return usage, I
def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage):
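    """Query the FLANN indexes with the read keys and read only the retrieved cells.

    Returns the read vectors, the positions that were read (including the least
    used cell), the read weights and the gathered visible memory.
    """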
b = keys.size(0)
read_positions = []
# we search for k cells per read head
for batch in range(b):
distances, positions = indexes[batch].search(keys[batch])
read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
read_positions = T.stack(read_positions, 0)
# add least used mem to read positions
# TODO: explore possibility of reading co-locations or ranges and such
(b, r, k) = read_positions.size()
read_positions = var(read_positions)
read_positions = T.cat([read_positions.view(b, -1), least_used_mem], 1)
# differentiable ops
(b, m, w) = memory.size()
visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))
read_weights = σ(θ(visible_memory, keys), 2)
read_vectors = T.bmm(read_weights, visible_memory)
read_weights = T.prod(read_weights, 1)
return read_vectors, read_positions, read_weights, visible_memory
def read(self, read_query, hidden):
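    """Content-based read restricted to the cells returned by the sparse index lookup."""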
# sparse read
read_vectors, positions, read_weights, visible_memory = \
self.read_from_sparse_memory(
hidden['memory'],
hidden['indexes'],
read_query,
hidden['least_used_mem'],
hidden['usage']
)
hidden['read_positions'] = positions
hidden['read_weights'] = hidden['read_weights'].scatter_(1, positions, read_weights)
hidden['read_vectors'] = read_vectors
hidden['visible_memory'] = visible_memory
return hidden['read_vectors'], hidden
def forward(self, ξ, hidden):
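    """One memory step: compute the interface vectors from ξ, write, then read.

    Returns the read vectors and the updated hidden state.
    """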
t = time.time()
# ξ = ξ.detach()
m = self.mem_size
w = self.cell_size
r = self.read_heads
c = r * self.K + 1
b = ξ.size()[0]
if self.independent_linears:
# r read keys (b * r * w)
read_query = self.read_query_transform(ξ).view(b, r, w)
      # write vector (b * 1 * w)
write_vector = self.write_vector_transform(ξ).view(b, 1, w)
      # interpolation gate (b * c)
interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
else:
ξ = self.interface_weights(ξ)
# r read keys (b * r * w)
read_query = ξ[:, :r*w].contiguous().view(b, r, w)
      # write vector (b * 1 * w)
write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
      # interpolation gate (b * c)
interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
# write gate (b * 1)
write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
self.timestep += 1
2017-11-28 02:14:21 +08:00
hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
return self.read(read_query, hidden)
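

# Minimal usage sketch (illustrative, not part of the module proper): drives one
# SparseMemory step the way a controller would. Sizes below are arbitrary, and the
# FLANN bindings required by FLANNIndex must be installed for this to run.
if __name__ == '__main__':
  mem = SparseMemory(input_size=64, mem_size=512, cell_size=32, read_heads=4, sparse_reads=10)
  hidden = mem.reset(batch_size=2)
  x = var(T.randn(2, 64))
  read_vectors, hidden = mem(x, hidden)
  # read_vectors: (batch, read_heads, cell_size) -> (2, 4, 32)
  print(read_vectors.size())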