#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch.nn as nn
import torch as T
from torch.autograd import Variable as var
import torch.nn.functional as F
import numpy as np
import math

from .flann_index import FLANNIndex
from .util import *
import time


class SparseMemory(nn.Module):
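  """Sparse external memory with approximate nearest-neighbor (FLANN) addressing.

  Each read head retrieves its K most similar cells from a per-batch FLANN
  index instead of attending over all `mem_size` cells, so each step only
  touches `read_heads * K + 1` cells (the +1 is the least recently used
  cell, which lets writes target a fresh location).
  """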

  def __init__(
      self,
      input_size,
      mem_size=512,
      cell_size=32,
      independent_linears=True,
      read_heads=4,
      sparse_reads=10,
      num_lists=None,
      index_checks=32,
      gpu_id=-1,
      mem_gpu_id=-1
  ):
    super(SparseMemory, self).__init__()

    self.mem_size = mem_size
    self.cell_size = cell_size
    self.gpu_id = gpu_id
    self.mem_gpu_id = mem_gpu_id
    self.input_size = input_size
    self.independent_linears = independent_linears
    # never read more cells than the memory holds
    self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
    self.read_heads = read_heads
    # number of index lists defaults to one per 100 memory cells
    self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100)
    self.index_checks = index_checks

    m = self.mem_size
    w = self.cell_size
    r = self.read_heads
    # visible cells per step: K sparse reads per head, plus the least used cell
    c = r * self.K + 1

    if self.independent_linears:
      self.read_query_transform = nn.Linear(self.input_size, w * r)
      self.write_vector_transform = nn.Linear(self.input_size, w)
      self.interpolation_gate_transform = nn.Linear(self.input_size, c)
      self.write_gate_transform = nn.Linear(self.input_size, 1)
      T.nn.init.orthogonal(self.read_query_transform.weight)
      T.nn.init.orthogonal(self.write_vector_transform.weight)
      T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
      T.nn.init.orthogonal(self.write_gate_transform.weight)
    else:
      self.interface_size = (r * w) + w + c + 1
      self.interface_weights = nn.Linear(self.input_size, self.interface_size)
      T.nn.init.orthogonal(self.interface_weights.weight)

    # (1 * c * c) all-ones matrix with a zero diagonal
    self.I = cuda(1 - T.eye(c).unsqueeze(0), gpu_id=self.gpu_id)
    self.δ = 0.005  # minimum usage
    self.timestep = 0
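
  # With the defaults above (read_heads=4, sparse_reads=10, cell_size=32) the
  # flattened interface is r*w + w + c + 1 = 128 + 32 + 41 + 1 = 202 units,
  # split in `forward` into read queries, a write vector, interpolation gates
  # and a write gate.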

  def rebuild_indexes(self, hidden, erase=False):
    b = hidden['memory'].size(0)

    # if indexes already exist, we reset them
    if 'indexes' in hidden:
      [x.reset() for x in hidden['indexes']]
    else:
      # create new indexes, one per batch element
      hidden['indexes'] = \
          [FLANNIndex(cell_size=self.cell_size,
                      nr_cells=self.mem_size, K=self.K, num_kdtrees=self.num_lists,
                      probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]

    # add existing memory into indexes
    # (no squeeze here: with batch size 1 it would drop the batch dim)
    pos = hidden['read_positions'].data.cpu().numpy()
    if not erase:
      for n, i in enumerate(hidden['indexes']):
        i.reset()
        i.add(hidden['memory'][n], last=pos[n][-1])
    else:
      self.timestep = 0

    return hidden

  def reset(self, batch_size=1, hidden=None, erase=True):
    m = self.mem_size
    w = self.cell_size
    b = batch_size
    r = self.read_heads
    c = r * self.K + 1

    if hidden is None:
      # δ here is the small constant from .util, not self.δ
      hidden = {
          # warning: can be a huge chunk of contiguous memory
          'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
          'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
          'link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
          'rev_link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
          'precedence': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
          'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
          'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
          'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
          'least_used_mem': cuda(T.zeros(b, 1).fill_(c + 1), gpu_id=self.gpu_id).long(),
          'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
          'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
      }
      hidden = self.rebuild_indexes(hidden, erase=True)
    else:
      hidden['memory'] = hidden['memory'].clone()
      hidden['visible_memory'] = hidden['visible_memory'].clone()
      hidden['link_matrix'] = hidden['link_matrix'].clone()
      hidden['rev_link_matrix'] = hidden['rev_link_matrix'].clone()
      hidden['precedence'] = hidden['precedence'].clone()
      hidden['read_weights'] = hidden['read_weights'].clone()
      hidden['write_weights'] = hidden['write_weights'].clone()
      hidden['read_vectors'] = hidden['read_vectors'].clone()
      hidden['least_used_mem'] = hidden['least_used_mem'].clone()
      hidden['usage'] = hidden['usage'].clone()
      hidden['read_positions'] = hidden['read_positions'].clone()
      hidden = self.rebuild_indexes(hidden, erase)

      if erase:
        hidden['memory'].data.fill_(δ)
        hidden['visible_memory'].data.fill_(δ)
        hidden['link_matrix'].data.zero_()
        hidden['rev_link_matrix'].data.zero_()
        hidden['precedence'].data.zero_()
        hidden['read_weights'].data.fill_(δ)
        hidden['write_weights'].data.fill_(δ)
        hidden['read_vectors'].data.fill_(δ)
        hidden['least_used_mem'].data.fill_(c + 1 + self.timestep)
        hidden['usage'].data.fill_(δ)
        hidden['read_positions'] = cuda(
            T.arange(self.timestep, c + self.timestep).expand(b, c), gpu_id=self.gpu_id).long()

    return hidden

  def write_into_sparse_memory(self, hidden):
    visible_memory = hidden['visible_memory']
    positions = hidden['read_positions']

    (b, m, w) = hidden['memory'].size()
    # scatter the visible cells back into full memory
    hidden['memory'].scatter_(1, positions.unsqueeze(2).expand(b, self.read_heads * self.K + 1, w), visible_memory)

    # non-differentiable operations
    pos = positions.data.cpu().numpy()
    for batch in range(b):
      # update indexes
      hidden['indexes'][batch].reset()
      hidden['indexes'][batch].add(hidden['memory'][batch], last=pos[batch][-1])

    hidden['least_used_mem'] = hidden['least_used_mem'] + 1 if self.timestep < self.mem_size else hidden['least_used_mem'] * 0

    return hidden
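
  # A toy illustration of the scatter above (a sketch, not from the source):
  # with m=4 cells, w=2 and read positions [0, 2], `scatter_` copies row 0 of
  # `visible_memory` into memory cell 0 and row 1 into cell 2, leaving cells
  # 1 and 3 untouched; only the currently visible cells are ever rewritten.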

  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence):
    # forward links: cell i, written now, was preceded by cell j
    link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)

    # reverse links use the transposed outer product of the same two terms
    rev_link_matrix = (1 - write_weights).unsqueeze(1) * rev_link_matrix + precedence.unsqueeze(2) * write_weights.unsqueeze(1)

    return link_matrix, rev_link_matrix

  def update_precedence(self, precedence, write_weights):
    return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
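
  # In DNC notation the two updates above are, per batch:
  #   L[i, j]  <- (1 - w[i]) * L[i, j]  + w[i] * p[j]
  #   RL[i, j] <- (1 - w[j]) * RL[i, j] + p[i] * w[j]
  #   p        <- (1 - sum(w)) * p + w
  # where w are the current (sparse) write weights and p is the precedence,
  # i.e. how recently each visible cell was last written.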

  def write(self, interpolation_gate, write_vector, write_gate, hidden):

    read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
    write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])

    hidden['usage'], I = self.update_usage(
        hidden['read_positions'],
        read_weights,
        write_weights,
        hidden['usage']
    )

    # either we write to previous read locations
    x = interpolation_gate * read_weights
    # or to a new location
    y = (1 - interpolation_gate) * I
    write_weights = write_gate * (x + y)

    # store the write weights
    hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)

    # erase matrix
    erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())

    # write into memory
    hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
    hidden = self.write_into_sparse_memory(hidden)

    # update link_matrix and precedence over the visible (sparse) cells only,
    # so shapes match the read positions when scattering back
    (b, c) = write_weights.size()
    precedence = hidden['precedence'].gather(1, hidden['read_positions'])

    hidden['link_matrix'], hidden['rev_link_matrix'] = \
        self.update_link_matrices(hidden['link_matrix'], hidden['rev_link_matrix'], write_weights, precedence)
    precedence = self.update_precedence(precedence, write_weights)

    hidden['precedence'].scatter_(1, hidden['read_positions'], precedence)

    return hidden

  def update_usage(self, read_positions, read_weights, write_weights, usage):
    (b, _) = read_positions.size()
    # usage tracks the timestep of the last non-negligible memory access
    # TODO: store write weights of all mem and gather from that
    u = (read_weights + write_weights > self.δ).float()

    # usage before write
    relevant_usages = usage.gather(1, read_positions)

    # indicator of cells with minimal usage among the visible ones
    minusage = T.min(relevant_usages, -1, keepdim=True)[0]
    minusage = minusage.expand(relevant_usages.size())
    I = (relevant_usages == minusage).float()

    # usage after write
    relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)

    usage.scatter_(1, read_positions, relevant_usages)

    return usage, I
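
  # For example (a sketch): if the gathered usages are [7, 3, 3, 9], then
  # minusage = 3 and I = [0, 1, 1, 0] -- the interpolation gate in `write`
  # uses I to steer writes towards the least recently used visible cells.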

  def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage):
    b = keys.size(0)
    read_positions = []

    # we search for the K nearest cells per read head
    for batch in range(b):
      distances, positions = indexes[batch].search(keys[batch])
      read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
    read_positions = T.stack(read_positions, 0)

    # append the least used cell to the read positions
    # TODO: explore possibility of reading co-locations or ranges and such
    (b, r, k) = read_positions.size()
    read_positions = var(read_positions)
    read_positions = T.cat([read_positions.view(b, -1), least_used_mem], 1)

    # differentiable ops
    (b, m, w) = memory.size()
    visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r * k + 1, w))

    # σ (softmax) and θ (batched cosine similarity) come from .util
    read_weights = σ(θ(visible_memory, keys), 2)
    read_vectors = T.bmm(read_weights, visible_memory)
    read_weights = T.prod(read_weights, 1)

    return read_vectors, read_positions, read_weights, visible_memory
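
  # Shape walk-through (a sketch, using the defaults r=4, K=10, w=32):
  # `keys` is (b, 4, 32); each FLANN search returns (4, 10) positions, so the
  # concatenated read positions are (b, 41) including the least used cell;
  # `visible_memory` is (b, 41, 32), the per-head weights are (b, 4, 41), and
  # `read_vectors` comes out (b, 4, 32).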

  def read(self, read_query, hidden):
    # sparse read
    read_vectors, positions, read_weights, visible_memory = \
        self.read_from_sparse_memory(
            hidden['memory'],
            hidden['indexes'],
            read_query,
            hidden['least_used_mem'],
            hidden['usage']
        )

    hidden['read_positions'] = positions
    hidden['read_weights'] = hidden['read_weights'].scatter_(1, positions, read_weights)
    hidden['read_vectors'] = read_vectors
    hidden['visible_memory'] = visible_memory

    return hidden['read_vectors'], hidden

  def forward(self, ξ, hidden):
    t = time.time()

    # ξ = ξ.detach()
    m = self.mem_size
    w = self.cell_size
    r = self.read_heads
    c = r * self.K + 1
    b = ξ.size()[0]

    if self.independent_linears:
      # r read queries (b * r * w)
      read_query = self.read_query_transform(ξ).view(b, r, w)
      # write vector (b * 1 * w)
      write_vector = self.write_vector_transform(ξ).view(b, 1, w)
      # interpolation gate (b * c)
      interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
      # write gate (b * 1)
      write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
    else:
      ξ = self.interface_weights(ξ)
      # r read queries (b * r * w)
      read_query = ξ[:, :r * w].contiguous().view(b, r, w)
      # write vector (b * 1 * w)
      write_vector = ξ[:, r * w: r * w + w].contiguous().view(b, 1, w)
      # interpolation gate (b * c)
      interpolation_gate = F.sigmoid(ξ[:, r * w + w: r * w + w + c]).contiguous().view(b, c)
      # write gate (b * 1)
      write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)

    self.timestep += 1
    hidden = self.write(interpolation_gate, write_vector, write_gate, hidden)
    return self.read(read_query, hidden)
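
# Minimal smoke-test sketch (hypothetical, assuming the FLANN bindings behind
# FLANNIndex are installed; run from the package root so relative imports
# resolve):
#
#   from dnc.sparse_memory import SparseMemory
#   from torch.autograd import Variable
#   import torch as T
#
#   mem = SparseMemory(input_size=64, mem_size=512, cell_size=32)
#   hidden = mem.reset(batch_size=2)
#   ξ = Variable(T.randn(2, 64))
#   read_vectors, hidden = mem(ξ, hidden)
#   assert read_vectors.size() == (2, 4, 32)  # (batch, read_heads, cell_size)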