# Copyright 2017 Robert Csordas. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================

import torch
import torch.utils.data
import torch.nn.functional as F
import torch.nn.init as init
import functools
import math

def oneplus(t):
    return F.softplus(t, 1, 20) + 1.0


def get_next_tensor_part(src, dims, prev_pos=0):
    if not isinstance(dims, list):
        dims = [dims]
    n = functools.reduce(lambda x, y: x * y, dims)
    data = src.narrow(-1, prev_pos, n)
    return data.contiguous().view(list(data.size())[:-1] + dims) if len(dims) > 1 else data, prev_pos + n


def split_tensor(src, shapes):
    pos = 0
    res = []
    for s in shapes:
        d, pos = get_next_tensor_part(src, s, pos)
        res.append(d)
    return res

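# Example of how the helpers above are used throughout this file (illustrative values only):
# a controller output of width 2*W + 1 can be carved into named parts along the last dimension:
#
#     key, beta, vec = split_tensor(controls, [[W], [1], [W]])
#
# where `controls` has shape [batch, 2*W + 1], `key` and `vec` get shape [batch, W] and `beta`
# gets shape [batch, 1]. Multi-dimensional parts such as [[n_heads, W]] are reshaped to
# [batch, n_heads, W].
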
def dict_get(d, name):
    return d.get(name) if d is not None else None


def dict_append(d, name, val):
    if d is not None:
        l = d.get(name)
        if not l:
            l = []
            d[name] = l
        l.append(val)


def init_debug(debug, initial):
    # Populate the debug dict with the given structure, but only if it was passed in and is still empty.
    if debug is not None and not debug:
        debug.update(initial)


def merge_debug_tensors(d, dim):
    # Recursively stack the per-timestep tensor lists collected in the debug dict along `dim`.
    if d is not None:
        for k, v in d.items():
            if isinstance(v, dict):
                merge_debug_tensors(v, dim)
            elif isinstance(v, list):
                d[k] = torch.stack(v, dim)


def linear_reset(module, gain=1.0):
    assert isinstance(module, torch.nn.Linear)
    init.xavier_uniform_(module.weight, gain=gain)
    if module.bias is not None:
        module.bias.data.zero_()


_EPS = 1e-6

class AllocationManager(torch.nn.Module):
    def __init__(self):
        super(AllocationManager, self).__init__()
        self.usages = None
        self.zero_usages = None
        self.debug_sequ_init = False
        self.one = None

    def _init_sequence(self, prev_read_distributions):
        # prev_read_distributions size is [batch, n_heads, cell count]
        s = prev_read_distributions.size()
        if self.zero_usages is None or list(self.zero_usages.size()) != [s[0], s[-1]]:
            self.zero_usages = torch.zeros(s[0], s[-1], device=prev_read_distributions.device)
            if self.debug_sequ_init:
                self.zero_usages += torch.arange(0, s[-1], device=self.zero_usages.device).unsqueeze(0) * 1e-10

        self.usages = self.zero_usages

    def _init_consts(self, device):
        if self.one is None:
            self.one = torch.ones(1, device=device)

    def new_sequence(self):
        self.usages = None

    def update_usages(self, prev_write_distribution, prev_read_distributions, free_gates):
        # Read distributions shape: [batch, n_heads, cell count]
        # Free gates shape: [batch, n_heads]

        self._init_consts(prev_read_distributions.device)
        phi = torch.addcmul(self.one, free_gates.unsqueeze(-1), prev_read_distributions, value=-1).prod(-2)
        # Phi is the memory retention vector, sized [batch, cell count]

        # If the memory usage counter doesn't exist yet, initialize it.
        if self.usages is None:
            self._init_sequence(prev_read_distributions)
            # In the first timestep nothing has been written or read yet, so no further processing is needed.
        else:
            self.usages = torch.addcmul(self.usages, prev_write_distribution.detach(), 1 - self.usages) * phi

        return phi

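    # The forward pass below turns the usage vector into the DNC allocation weighting: with
    # free_list holding the cell indices sorted by increasing usage u, the j-th least-used cell
    # receives the score
    #     a[free_list[j]] = (1 - u[free_list[j]]) * prod_{i < j} u[free_list[i]],
    # so nearly-free cells receive almost all of the allocation mass.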
    def forward(self, prev_write_distribution, prev_read_distributions, free_gates):
        phi = self.update_usages(prev_write_distribution, prev_read_distributions, free_gates)
        sorted_usage, free_list = (self.usages * (1.0 - _EPS) + _EPS).sort(-1)

        u_prod = sorted_usage.cumprod(-1)
        one_minus_usage = 1.0 - sorted_usage
        sorted_scores = torch.cat([one_minus_usage[..., 0:1], one_minus_usage[..., 1:] * u_prod[..., :-1]], dim=-1)

        return sorted_scores.clone().scatter_(-1, free_list, sorted_scores), phi


class ContentAddressGenerator(torch.nn.Module):
    def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(ContentAddressGenerator, self).__init__()
        self.disable_content_norm = disable_content_norm
        self.mask_min = mask_min
        self.disable_key_masking = disable_key_masking

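    # Content-based addressing: the score of cell i is beta * cos(key, M[i]), normalized with a
    # softmax over the cells. The optional mask multiplies both the key and the memory rows before
    # the similarity is computed (masked addressing); mask_min sets a lower bound on the mask.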
    def forward(self, memory, keys, betas, mask=None):
        # Memory shape [batch, cell count, word length]
        # Key shape [batch, n heads, word length], or [batch, word length] for a single head
        # Betas shape [batch, n heads]
        if mask is not None and self.mask_min != 0:
            mask = mask * (1.0 - self.mask_min) + self.mask_min

        single_head = keys.dim() == 2
        if single_head:
            # Single head
            keys = keys.unsqueeze(1)
            if mask is not None:
                mask = mask.unsqueeze(1)

        memory = memory.unsqueeze(1)
        keys = keys.unsqueeze(-2)

        if mask is not None:
            mask = mask.unsqueeze(-2)
            memory = memory * mask
            if not self.disable_key_masking:
                keys = keys * mask

        # Shape [batch, n heads, cell count]
        norm = keys.norm(dim=-1)
        if not self.disable_content_norm:
            norm = norm * memory.norm(dim=-1)

        scores = (memory * keys).sum(-1) / (norm + _EPS)
        scores *= betas.unsqueeze(-1)

        res = F.softmax(scores, scores.dim() - 1)
        return res.squeeze(1) if single_head else res


class WriteHead(torch.nn.Module):
    @staticmethod
    def create_write_archive(write_dist, erase_vector, write_vector, phi):
        return dict(write_dist=write_dist, erase_vector=erase_vector, write_vector=write_vector, phi=phi)

    def __init__(self, dealloc_content=True, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(WriteHead, self).__init__()
        self.write_content_generator = ContentAddressGenerator(disable_content_norm, mask_min=mask_min,
                                                               disable_key_masking=disable_key_masking)
        self.allocation_manager = AllocationManager()
        self.last_write = None
        self.dealloc_content = dealloc_content
        self.new_sequence()

    def new_sequence(self):
        self.last_write = None
        self.allocation_manager.new_sequence()

    @staticmethod
    def mem_update(memory, write_dist, erase_vector, write_vector, phi):
        # In the original paper the memory content is NOT deallocated, which makes content-based addressing
        # basically unusable when multiple similar steps should be done. The reason for this is that the old
        # memory contents are still there, so the lookup will keep finding them, unless an allocation happens
        # to clear the cell before the next search, which is completely random. So I argue that the erase
        # matrix should also take the free gates into account (multiply it with phi).
        write_dist = write_dist.unsqueeze(-1)

        erase_matrix = 1.0 - write_dist * erase_vector.unsqueeze(-2)
        if phi is not None:
            erase_matrix = erase_matrix * phi.unsqueeze(-1)

        update_matrix = write_dist * write_vector.unsqueeze(-2)
        return memory * erase_matrix + update_matrix

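    # mem_update applies the write rule
    #     M_t = M_{t-1} * (1 - w_t e_t^T) + w_t v_t^T
    # elementwise over the [cell count, word length] matrix. When phi is given, the erase term is
    # additionally multiplied by the retention vector phi, wiping deallocated cells (de-allocation).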
    def forward(self, memory, write_content_key, write_beta, erase_vector, write_vector, alloc_gate, write_gate,
                free_gates, prev_read_dist, write_mask=None, debug=None):
        last_w_dist = self.last_write["write_dist"] if self.last_write is not None else None

        content_dist = self.write_content_generator(memory, write_content_key, write_beta, mask=write_mask)
        alloc_dist, phi = self.allocation_manager(last_w_dist, prev_read_dist, free_gates)

        # Shape [batch, cell count]
        write_dist = write_gate * (alloc_gate * alloc_dist + (1 - alloc_gate) * content_dist)
        self.last_write = WriteHead.create_write_archive(write_dist, erase_vector, write_vector,
                                                         phi if self.dealloc_content else None)

        dict_append(debug, "alloc_dist", alloc_dist)
        dict_append(debug, "write_dist", write_dist)
        dict_append(debug, "mem_usages", self.allocation_manager.usages)
        dict_append(debug, "free_gates", free_gates)
        dict_append(debug, "write_betas", write_beta)
        dict_append(debug, "write_gate", write_gate)
        dict_append(debug, "write_vector", write_vector)
        dict_append(debug, "alloc_gate", alloc_gate)
        dict_append(debug, "erase_vector", erase_vector)
        if write_mask is not None:
            dict_append(debug, "write_mask", write_mask)

        return WriteHead.mem_update(memory, **self.last_write)


class RawWriteHead(torch.nn.Module):
    def __init__(self, n_read_heads, word_length, use_mask=False, dealloc_content=True, disable_content_norm=False,
                 mask_min=0.0, disable_key_masking=False):
        super(RawWriteHead, self).__init__()
        self.write_head = WriteHead(dealloc_content=dealloc_content, disable_content_norm=disable_content_norm,
                                    mask_min=mask_min, disable_key_masking=disable_key_masking)
        self.word_length = word_length
        self.n_read_heads = n_read_heads
        self.use_mask = use_mask
        # Control channels: write key, erase vector and write vector (word_length each), one free gate per
        # read head, write beta, allocation gate and write gate (1 each), plus the optional write mask.
        self.input_size = 3 * self.word_length + self.n_read_heads + 3 + (self.word_length if use_mask else 0)

    def new_sequence(self):
        self.write_head.new_sequence()

    def get_prev_write(self):
        return self.write_head.last_write

    def forward(self, memory, nn_output, prev_read_dist, debug):
        shapes = [[self.word_length]] * (4 if self.use_mask else 3) + [[self.n_read_heads]] + [[1]] * 3
        tensors = split_tensor(nn_output, shapes)

        if self.use_mask:
            write_mask = torch.sigmoid(tensors[0])
            tensors = tensors[1:]
        else:
            write_mask = None

        write_content_key, erase_vector, write_vector, free_gates, write_beta, alloc_gate, write_gate = tensors

        erase_vector = torch.sigmoid(erase_vector)
        free_gates = torch.sigmoid(free_gates)
        write_beta = oneplus(write_beta)
        alloc_gate = torch.sigmoid(alloc_gate)
        write_gate = torch.sigmoid(write_gate)

        return self.write_head(memory, write_content_key, write_beta, erase_vector, write_vector,
                               alloc_gate, write_gate, free_gates, prev_read_dist, debug=debug,
                               write_mask=write_mask)

    def get_neural_input_size(self):
        return self.input_size


class TemporalMemoryLinkage(torch.nn.Module):
    def __init__(self):
        super(TemporalMemoryLinkage, self).__init__()
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

        self.initial_temp_link_mat = None
        self.initial_precedence_weighting = None
        self.initial_diag_mask = None
        self.initial_shape = None

    def new_sequence(self):
        self.temp_link_mat = None
        self.precedence_weighting = None
        self.diag_mask = None

    def _init_link(self, w_dist):
        s = list(w_dist.size())
        if self.initial_shape is None or s != self.initial_shape:
            self.initial_shape = s
            self.initial_temp_link_mat = torch.zeros(s[0], s[-1], s[-1]).to(w_dist.device)
            self.initial_precedence_weighting = torch.zeros(s[0], s[-1]).to(w_dist.device)
            self.initial_diag_mask = (1.0 - torch.eye(s[-1]).unsqueeze(0).to(w_dist)).detach()

        self.temp_link_mat = self.initial_temp_link_mat
        self.precedence_weighting = self.initial_precedence_weighting
        self.diag_mask = self.initial_diag_mask

    def _update_precedence(self, w_dist):
        # w_dist shape: [ batch, cell count ]
        self.precedence_weighting = (1.0 - w_dist.sum(-1, keepdim=True)) * self.precedence_weighting + w_dist

    def _update_links(self, w_dist):
        if self.temp_link_mat is None:
            self._init_link(w_dist)

        wt_i = w_dist.unsqueeze(-1)
        wt_j = w_dist.unsqueeze(-2)
        pt_j = self.precedence_weighting.unsqueeze(-2)

        self.temp_link_mat = ((1 - wt_i - wt_j) * self.temp_link_mat + wt_i * pt_j) * self.diag_mask

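    # Link matrix and precedence updates follow the DNC paper:
    #     L_t[i, j] = (1 - w_t[i] - w_t[j]) * L_{t-1}[i, j] + w_t[i] * p_{t-1}[j],   L_t[i, i] = 0
    #     p_t = (1 - sum_i w_t[i]) * p_{t-1} + w_t
    # where w_t is the write distribution and p is the precedence weighting. The links are updated
    # with the previous precedence, which is why _update_links runs before _update_precedence.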
    def forward(self, w_dist, prev_r_dists, debug=None):
        self._update_links(w_dist)
        self._update_precedence(w_dist)

        # prev_r_dists shape: [ batch, n heads, cell count ]
        # Emulate matrix-vector multiplication by broadcast and sum. This way we don't need to transpose the matrix.
        tlm_multi_head = self.temp_link_mat.unsqueeze(1)

        forward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-2)).sum(-1)
        backward_dist = (tlm_multi_head * prev_r_dists.unsqueeze(-1)).sum(-2)

        dict_append(debug, "forward_dists", forward_dist)
        dict_append(debug, "backward_dists", backward_dist)
        dict_append(debug, "precedence_weights", self.precedence_weighting)

        # output shapes [ batch, n_heads, cell_count ]
        return forward_dist, backward_dist


class ReadHead(torch.nn.Module):
    def __init__(self, disable_content_norm=False, mask_min=0.0, disable_key_masking=False):
        super(ReadHead, self).__init__()
        self.content_addr_generator = ContentAddressGenerator(disable_content_norm=disable_content_norm,
                                                              mask_min=mask_min,
                                                              disable_key_masking=disable_key_masking)
        self.read_dist = None
        self.read_data = None
        self.new_sequence()

    def new_sequence(self):
        self.read_dist = None
        self.read_data = None

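    # The read weighting mixes three addressing modes per head:
    #     w_r = pi_1 * backward + pi_2 * content + pi_3 * forward
    # where (pi_1, pi_2, pi_3) are the softmax-normalized mode gates produced by RawReadHead.forward.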
    def forward(self, memory, read_content_keys, read_betas, forward_dist, backward_dist, gates, read_mask=None,
                debug=None):
        content_dist = self.content_addr_generator(memory, read_content_keys, read_betas, mask=read_mask)

        self.read_dist = backward_dist * gates[..., 0:1] + content_dist * gates[..., 1:2] + forward_dist * gates[..., 2:]

        # memory shape: [ batch, cell count, word_length ]
        # read_dist shape: [ batch, n heads, cell count ]
        # result shape: [ batch, n_heads, word_length ]
        self.read_data = (memory.unsqueeze(1) * self.read_dist.unsqueeze(-1)).sum(-2)

        dict_append(debug, "content_dist", content_dist)
        dict_append(debug, "balance", gates)
        dict_append(debug, "read_dist", self.read_dist)
        dict_append(debug, "read_content_keys", read_content_keys)
        dict_append(debug, "read_betas", read_betas.unsqueeze(-2))
        if read_mask is not None:
            dict_append(debug, "read_mask", read_mask)

        return self.read_data


class RawReadHead(torch.nn.Module):
    def __init__(self, n_heads, word_length, use_mask=False, disable_content_norm=False, mask_min=0.0,
                 disable_key_masking=False):
        super(RawReadHead, self).__init__()
        self.read_head = ReadHead(disable_content_norm=disable_content_norm, mask_min=mask_min,
                                  disable_key_masking=disable_key_masking)
        self.n_heads = n_heads
        self.word_length = word_length
        self.use_mask = use_mask
        # Per head: a key (and optionally a mask) of word_length channels, 3 mode gates and one beta.
        self.input_size = self.n_heads * (self.word_length * (2 if use_mask else 1) + 3 + 1)

    def get_prev_dist(self, memory):
        if self.read_head.read_dist is not None:
            return self.read_head.read_dist
        else:
            m_shape = memory.size()
            return torch.zeros(m_shape[0], self.n_heads, m_shape[1]).to(memory)

    def get_prev_data(self, memory):
        if self.read_head.read_data is not None:
            return self.read_head.read_data
        else:
            m_shape = memory.size()
            return torch.zeros(m_shape[0], self.n_heads, m_shape[-1]).to(memory)

    def new_sequence(self):
        self.read_head.new_sequence()

    def forward(self, memory, nn_output, forward_dist, backward_dist, debug):
        shapes = [[self.n_heads, self.word_length]] * (2 if self.use_mask else 1) + [[self.n_heads], [self.n_heads, 3]]
        tensors = split_tensor(nn_output, shapes)

        if self.use_mask:
            read_mask = torch.sigmoid(tensors[0])
            tensors = tensors[1:]
        else:
            read_mask = None

        keys, betas, gates = tensors

        betas = oneplus(betas)
        gates = F.softmax(gates, gates.dim() - 1)

        return self.read_head(memory, keys, betas, forward_dist, backward_dist, gates, debug=debug,
                              read_mask=read_mask)

    def get_neural_input_size(self):
        return self.input_size


class DistSharpnessEnhancer(torch.nn.Module):
    def __init__(self, n_heads):
        super(DistSharpnessEnhancer, self).__init__()
        self.n_heads = n_heads if isinstance(n_heads, list) else [n_heads]
        self.n_data = sum(self.n_heads)

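    # Link sharpness control: each distribution is shifted by a small epsilon, rescaled by its maximum,
    # raised element-wise to the power s = oneplus(controller output) >= 1 and renormalized. s > 1 makes
    # the forward/backward link distributions more peaked, countering the blurring caused by repeated
    # multiplications with the link matrix.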
    def forward(self, nn_input, *dists):
        assert len(dists) == len(self.n_heads)
        nn_input = oneplus(nn_input[..., :self.n_data])
        factors = split_tensor(nn_input, self.n_heads)

        res = []
        for i, d in enumerate(dists):
            ndim = d.dim()
            f = factors[i]
            if ndim == 2:
                assert self.n_heads[i] == 1
            elif ndim == 3:
                f = f.unsqueeze(-1)
            else:
                assert False

            d = d + _EPS
            d = d / d.max(dim=-1, keepdim=True)[0]
            d = d.pow(f)
            d = d / d.sum(dim=-1, keepdim=True)

            res.append(d)
        return res

    def get_neural_input_size(self):
        return self.n_data


class DNC(torch.nn.Module):
    def __init__(self, input_size, output_size, word_length, cell_count, n_read_heads, controller,
                 batch_first=False, clip_controller=20, bias=True, mask=False, dealloc_content=True,
                 link_sharpness_control=True, disable_content_norm=False, mask_min=0.0,
                 disable_key_masking=False):
        super(DNC, self).__init__()

        self.clip_controller = clip_controller

        self.read_head = RawReadHead(n_read_heads, word_length, use_mask=mask,
                                     disable_content_norm=disable_content_norm,
                                     mask_min=mask_min, disable_key_masking=disable_key_masking)
        self.write_head = RawWriteHead(n_read_heads, word_length, use_mask=mask, dealloc_content=dealloc_content,
                                       disable_content_norm=disable_content_norm, mask_min=mask_min,
                                       disable_key_masking=disable_key_masking)
        self.temporal_link = TemporalMemoryLinkage()
        self.sharpness_control = DistSharpnessEnhancer([n_read_heads, n_read_heads]) if link_sharpness_control else None

        in_size = input_size + n_read_heads * word_length
        control_channels = self.read_head.get_neural_input_size() + self.write_head.get_neural_input_size() + \
            (self.sharpness_control.get_neural_input_size() if self.sharpness_control is not None else 0)

        self.controller = controller
        controller.init(in_size)
        self.controller_to_controls = torch.nn.Linear(controller.get_output_size(), control_channels, bias=bias)
        self.controller_to_out = torch.nn.Linear(controller.get_output_size(), output_size, bias=bias)
        self.read_to_out = torch.nn.Linear(word_length * n_read_heads, output_size, bias=bias)

        self.cell_count = cell_count
        self.word_length = word_length

        self.memory = None
        self.reset_parameters()

        self.batch_first = batch_first
        self.zero_mem_tensor = None

    def reset_parameters(self):
        linear_reset(self.controller_to_controls)
        linear_reset(self.controller_to_out)
        linear_reset(self.read_to_out)
        self.controller.reset_parameters()

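    # One step of the DNC (see _step below): the controller receives the external input concatenated with
    # the data read in the previous step; its output is mapped to the write-head, read-head and sharpness
    # control channels (optionally clipped). The write head then updates the memory, the temporal link
    # matrix is updated, and the read heads read from the new memory. The final output is the sum of a
    # linear map of the controller output and a linear map of the freshly read words.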
    def _step(self, in_data, debug):
        init_debug(debug, {
            "read_head": {},
            "write_head": {},
            "temporal_links": {}
        })

        # input shape: [ batch, channels ]
        batch_size = in_data.size(0)

        # run the controller
        prev_read_data = self.read_head.get_prev_data(self.memory).view([batch_size, -1])

        control_data = self.controller(torch.cat([in_data, prev_read_data], -1))

        # memory ops
        controls = self.controller_to_controls(control_data).contiguous()
        controls = controls.clamp(-self.clip_controller, self.clip_controller) if self.clip_controller is not None else controls

        shapes = [[self.write_head.get_neural_input_size()], [self.read_head.get_neural_input_size()]]
        if self.sharpness_control is not None:
            shapes.append(self.sharpness_control.get_neural_input_size())

        tensors = split_tensor(controls, shapes)

        write_head_control, read_head_control = tensors[:2]
        tensors = tensors[2:]

        prev_read_dist = self.read_head.get_prev_dist(self.memory)

        self.memory = self.write_head(self.memory, write_head_control, prev_read_dist,
                                      debug=dict_get(debug, "write_head"))

        prev_write = self.write_head.get_prev_write()
        forward_dist, backward_dist = self.temporal_link(prev_write["write_dist"] if prev_write is not None else None,
                                                         prev_read_dist, debug=dict_get(debug, "temporal_links"))

        if self.sharpness_control is not None:
            forward_dist, backward_dist = self.sharpness_control(tensors[0], forward_dist, backward_dist)

        read_data = self.read_head(self.memory, read_head_control, forward_dist, backward_dist,
                                   debug=dict_get(debug, "read_head"))

        # output:
        return self.controller_to_out(control_data) + self.read_to_out(read_data.view(batch_size, -1))

    def _mem_init(self, batch_size, device):
        if self.zero_mem_tensor is None or self.zero_mem_tensor.size(0) != batch_size:
            self.zero_mem_tensor = torch.zeros(batch_size, self.cell_count, self.word_length).to(device)

        self.memory = self.zero_mem_tensor

    def forward(self, in_data, debug=None):
        self.write_head.new_sequence()
        self.read_head.new_sequence()
        self.temporal_link.new_sequence()
        self.controller.new_sequence()

        self._mem_init(in_data.size(0 if self.batch_first else 1), in_data.device)

        out_tsteps = []

        if self.batch_first:
            # input format: batch, time, channels
            for t in range(in_data.size(1)):
                out_tsteps.append(self._step(in_data[:, t], debug))
        else:
            # input format: time, batch, channels
            for t in range(in_data.size(0)):
                out_tsteps.append(self._step(in_data[t], debug))

        merge_debug_tensors(debug, dim=1 if self.batch_first else 0)
        return torch.stack(out_tsteps, dim=1 if self.batch_first else 0)


class LSTMController(torch.nn.Module):
    def __init__(self, layer_sizes, out_from_all_layers=True):
        super(LSTMController, self).__init__()
        self.out_from_all_layers = out_from_all_layers
        self.layer_sizes = layer_sizes
        self.states = None
        self.outputs = None

    def new_sequence(self):
        self.states = [None] * len(self.layer_sizes)
        self.outputs = [None] * len(self.layer_sizes)

    def reset_parameters(self):
        def init_layer(l, index):
            size = self.layer_sizes[index]
            # Initialize the gate weights for sigmoid activations; only the data-input block gets the tanh gain.
            a = math.sqrt(3.0) * self.stdevs[index]
            l.weight.data[0:-size].uniform_(-a, a)
            a *= init.calculate_gain("tanh")
            l.weight.data[-size:].uniform_(-a, a)
            if l.bias is not None:
                l.bias.data[size:].fill_(0)
                # init the forget gate bias to a large value.
                l.bias.data[:size].fill_(1)

        # Xavier-style init of the merged input weights
        for i in range(len(self.layer_sizes)):
            init_layer(self.in_to_all[i], i)
            init_layer(self.out_to_all[i], i)
            if i > 0:
                init_layer(self.prev_to_all[i - 1], i)

    def _add_modules(self, name, m_list):
        for i, m in enumerate(m_list):
            self.add_module("%s_%d" % (name, i), m)

    def init(self, input_size):
        # Xavier init: the input to all gates of layer i is layer_sizes[i-1] + layer_sizes[i] + input_size
        # channels wide and layer_sizes[i] channels deep, so compute the init scale accordingly.
        self.input_sizes = [(self.layer_sizes[i - 1] if i > 0 else 0) + self.layer_sizes[i] + input_size
                            for i in range(len(self.layer_sizes))]
        self.stdevs = [math.sqrt(2.0 / (self.layer_sizes[i] + self.input_sizes[i]))
                       for i in range(len(self.layer_sizes))]
        self.in_to_all = [torch.nn.Linear(input_size, 4 * self.layer_sizes[i]) for i in range(len(self.layer_sizes))]
        self.out_to_all = [torch.nn.Linear(self.layer_sizes[i], 4 * self.layer_sizes[i], bias=False)
                           for i in range(len(self.layer_sizes))]
        self.prev_to_all = [torch.nn.Linear(self.layer_sizes[i - 1], 4 * self.layer_sizes[i], bias=False)
                            for i in range(1, len(self.layer_sizes))]

        self._add_modules("in_to_all", self.in_to_all)
        self._add_modules("out_to_all", self.out_to_all)
        self._add_modules("prev_to_all", self.prev_to_all)

        self.reset_parameters()

    def get_output_size(self):
        return sum(self.layer_sizes) if self.out_from_all_layers else self.layer_sizes[-1]

    def forward(self, data):
        for i, size in enumerate(self.layer_sizes):
            d = self.in_to_all[i](data)
            if self.outputs[i] is not None:
                d += self.out_to_all[i](self.outputs[i])
            if i > 0:
                d += self.prev_to_all[i - 1](self.outputs[i - 1])

            input_data = torch.tanh(d[..., -size:])
            forget_gate, input_gate, output_gate = torch.sigmoid(d[..., :-size]).chunk(3, dim=-1)

            state_update = input_gate * input_data

            if self.states[i] is not None:
                self.states[i] = self.states[i] * forget_gate + state_update
            else:
                self.states[i] = state_update

            self.outputs[i] = output_gate * torch.tanh(self.states[i])

        return torch.cat(self.outputs, -1) if self.out_from_all_layers else self.outputs[-1]


class FeedforwardController(torch.nn.Module):
    def __init__(self, layer_sizes=[]):
        super(FeedforwardController, self).__init__()
        self.layer_sizes = layer_sizes

    def new_sequence(self):
        pass

    def reset_parameters(self):
        for module in self.model:
            if isinstance(module, torch.nn.Linear):
                linear_reset(module, gain=init.calculate_gain("relu"))

    def get_output_size(self):
        return self.layer_sizes[-1]

    def init(self, input_size):
        # Each layer sees the previous layer's output (or the external input for the first layer).
        self.input_sizes = [input_size] + self.layer_sizes[:-1]

        layers = []
        for i, size in enumerate(self.layer_sizes):
            layers.append(torch.nn.Linear(self.input_sizes[i], self.layer_sizes[i]))
            layers.append(torch.nn.ReLU())
        self.model = torch.nn.Sequential(*layers)
        self.reset_parameters()

    def forward(self, data):
        return self.model(data)
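

# Minimal usage sketch (not part of the original module; the sizes below are arbitrary examples):
# build a DNC with a two-layer LSTM controller and run it on random data.
if __name__ == "__main__":
    controller = LSTMController([64, 64])
    dnc = DNC(input_size=8, output_size=8, word_length=16, cell_count=32, n_read_heads=2,
              controller=controller, batch_first=True)

    x = torch.randn(4, 10, 8)   # [batch, time, channels], because batch_first=True
    debug = {}                  # optional; collects per-step head and memory statistics
    y = dnc(x, debug=debug)

    print(y.shape)                                   # torch.Size([4, 10, 8])
    print(sorted(debug.keys()))                      # ['read_head', 'temporal_links', 'write_head']
    print(debug["write_head"]["write_dist"].shape)   # torch.Size([4, 10, 32]): [batch, time, cell count]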