from dataclasses import dataclass
from typing import Dict, Tuple, Optional

import torch
from torch import Tensor

from .addressing import (
    CosineWeights,
    Freeness,
    TemporalLinkage,
    TemporalLinkageState,
)


@dataclass
class AccessState:
    """Container for the recurrent state of the `MemoryAccess` module."""

    memory: Tensor  # [batch_size, memory_size, word_size]
    read_weights: Tensor  # [batch_size, num_reads, memory_size]
    write_weights: Tensor  # [batch_size, num_writes, memory_size]
    linkage: TemporalLinkageState
    # link: [batch_size, num_writes, memory_size, memory_size]
    # precedence_weights: [batch_size, num_writes, memory_size]
    usage: Tensor  # [batch_size, memory_size]


def _erase_and_write(memory, address, reset_weights, values):
    """Erases and then writes to the external memory.

    Erase operation:
        M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)

    Add operation:
        M_t(i) = M_t'(i) + w_t(i) * a_t

    where e are the reset_weights, w the write weights and a the values.

    Args:
      memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
      address: 3-D tensor `[batch_size, num_writes, memory_size]`.
      reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`.
      values: 3-D tensor `[batch_size, num_writes, word_size]`.

    Returns:
      The updated memory, a 3-D tensor of shape
      `[batch_size, memory_size, word_size]`.
    """
    weighted_resets = address.unsqueeze(3) * reset_weights.unsqueeze(2)
    reset_gate = torch.prod(1 - weighted_resets, dim=1)
    return memory * reset_gate + torch.matmul(address.transpose(-1, -2), values)


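# Shape walk-through for `_erase_and_write` (illustrative comment only; B is the
# batch size, W the number of write heads, N the memory size, D the word size):
#   address.unsqueeze(3)        -> [B, W, N, 1]
#   reset_weights.unsqueeze(2)  -> [B, W, 1, D]
#   weighted_resets             -> [B, W, N, D]
#   reset_gate (prod over W)    -> [B, N, D]
#   address^T @ values          -> [B, N, W] @ [B, W, D] = [B, N, D]

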
class Reshape(torch.nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, input):
        return input.reshape(self.dims)


class MemoryAccess(torch.nn.Module):
    """Access module of the Differentiable Neural Computer.

    This memory module supports multiple read and write heads. It makes use of:

    *   `addressing.TemporalLinkage` to track the temporal ordering of writes in
        memory for each write head.
    *   `addressing.Freeness` to keep track of memory usage, where usage
        increases when a memory location is written to and decreases when a
        memory location that the controller says can be freed is read from.

    Write-address selection is done by an interpolation between content-based
    lookup and using unused memory.

    Read-address selection is done by an interpolation of content-based lookup
    and following the link graph in the forward or backwards read direction.
    """

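    # Addressing summary (restating the docstring above; symbols follow the DNC
    # paper and the per-head comments in `_read_inputs`):
    #   write weights: w^w = g^w * (g^a * a + (1 - g^a) * c^w)
    #   read weights:  w^r = pi[content] * c^r
    #                        + sum_i pi[forward_i] * f_i
    #                        + sum_i pi[backward_i] * b_i
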
    def __init__(self, memory_size=128, word_size=20, num_reads=1, num_writes=1):
        """Creates a MemoryAccess module.

        Args:
          memory_size: The number of memory slots (N in the DNC paper).
          word_size: The width of each memory slot (W in the DNC paper).
          num_reads: The number of read heads (R in the DNC paper).
          num_writes: The number of write heads (fixed at 1 in the paper).
        """
        super().__init__()
        self._memory_size = memory_size
        self._word_size = word_size
        self._num_reads = num_reads
        self._num_writes = num_writes

        self._write_content_weights_mod = CosineWeights(num_writes, word_size)
        self._read_content_weights_mod = CosineWeights(num_reads, word_size)

        self._linkage = TemporalLinkage(memory_size, num_writes)
        self._freeness = Freeness(memory_size)

        def _linear(first_dim, second_dim, pre_act=None, post_act=None):
            """Builds a lazy linear layer whose output is reshaped to
            `[-1, first_dim, second_dim]`, with optional activations applied
            before and after the reshape."""
            mods = []
            mods.append(torch.nn.LazyLinear(first_dim * second_dim))
            if pre_act is not None:
                mods.append(pre_act)
            mods.append(Reshape([-1, first_dim, second_dim]))
            if post_act is not None:
                mods.append(post_act)
            return torch.nn.Sequential(*mods)

        self._write_vectors = _linear(num_writes, word_size)
        self._erase_vectors = _linear(num_writes, word_size, pre_act=torch.nn.Sigmoid())
        self._free_gate = torch.nn.Sequential(
            torch.nn.LazyLinear(num_reads),
            torch.nn.Sigmoid(),
        )
        self._alloc_gate = torch.nn.Sequential(
            torch.nn.LazyLinear(num_writes),
            torch.nn.Sigmoid(),
        )
        self._write_gate = torch.nn.Sequential(
            torch.nn.LazyLinear(num_writes),
            torch.nn.Sigmoid(),
        )
        num_read_modes = 1 + 2 * num_writes
        self._read_mode = _linear(
            num_reads, num_read_modes, post_act=torch.nn.Softmax(dim=-1)
        )
        self._write_keys = _linear(num_writes, word_size)
        self._write_strengths = torch.nn.LazyLinear(num_writes)

        self._read_keys = _linear(num_reads, word_size)
        self._read_strengths = torch.nn.LazyLinear(num_reads)

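    # Note: every projection above uses `torch.nn.LazyLinear`, so the controller
    # output (input) size is inferred on the first forward pass rather than fixed
    # at construction time. Per time step the projections produce the DNC
    # interface parameters with these shapes (derived from the layers above):
    #   write_vectors     [B, num_writes, word_size]
    #   erase_vectors     [B, num_writes, word_size]         (sigmoid)
    #   free_gate         [B, num_reads]                     (sigmoid)
    #   allocation_gate   [B, num_writes]                    (sigmoid)
    #   write_gate        [B, num_writes]                    (sigmoid)
    #   read_mode         [B, num_reads, 1 + 2 * num_writes] (softmax)
    #   write_keys        [B, num_writes, word_size]
    #   write_strengths   [B, num_writes]
    #   read_keys         [B, num_reads, word_size]
    #   read_strengths    [B, num_reads]
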
    def _read_inputs(self, inputs: Tensor) -> Dict[str, Tensor]:
        """Applies transformations to `inputs` to get control for this module."""
        # v_t^i - The vectors to write to memory, for each write head `i`.
        write_vectors = self._write_vectors(inputs)

        # e_t^i - Amount to erase the memory by before writing, for each write head.
        erase_vectors = self._erase_vectors(inputs)

        # f_t^j - Amount that the memory at the locations read from at the previous
        # time step can be declared unused, for each read head `j`.
        free_gate = self._free_gate(inputs)

        # g_t^{a, i} - Interpolation between writing to unallocated memory and
        # content-based lookup, for each write head `i`. Note: `a` is simply used to
        # identify this gate with allocation vs writing (as defined below).
        allocation_gate = self._alloc_gate(inputs)

        # g_t^{w, i} - Overall gating of write amount for each write head.
        write_gate = self._write_gate(inputs)

        # \pi_t^j - Mixing between "backwards" and "forwards" positions (for
        # each write head), and content-based lookup, for each read head.
        read_mode = self._read_mode(inputs)

        # Parameters for the (read / write) "weights by content matching" modules.
        write_keys = self._write_keys(inputs)
        write_strengths = self._write_strengths(inputs)

        read_keys = self._read_keys(inputs)
        read_strengths = self._read_strengths(inputs)

        result = {
            "read_content_keys": read_keys,
            "read_content_strengths": read_strengths,
            "write_content_keys": write_keys,
            "write_content_strengths": write_strengths,
            "write_vectors": write_vectors,
            "erase_vectors": erase_vectors,
            "free_gate": free_gate,
            "allocation_gate": allocation_gate,
            "write_gate": write_gate,
            "read_mode": read_mode,
        }
        return result

    def _write_weights(
        self,
        inputs: Tensor,
        memory: Tensor,
        usage: Tensor,
    ) -> Tensor:
        """Calculates the memory locations to write to.

        This uses a combination of content-based lookup and finding an unused
        location in memory, for each write head.

        Args:
          inputs: Collection of inputs to the access module, including controls
            for how to choose memory writing, such as the content to look up and
            the weighting between content-based and allocation-based addressing.
          memory: A tensor of shape `[batch_size, memory_size, word_size]`
            containing the current memory contents.
          usage: Current memory usage, which is a tensor of shape `[batch_size,
            memory_size]`, used for allocation-based addressing.

        Returns:
          A tensor of shape `[batch_size, num_writes, memory_size]` indicating
          where to write to (if anywhere) for each write head.
        """
        # c_t^{w, i} - The content-based weights for each write head.
        write_content_weights = self._write_content_weights_mod(
            memory, inputs["write_content_keys"], inputs["write_content_strengths"]
        )

        # a_t^i - The allocation weights for each write head.
        write_allocation_weights = self._freeness.write_allocation_weights(
            usage=usage,
            write_gates=(inputs["allocation_gate"] * inputs["write_gate"]),
            num_writes=self._num_writes,
        )

        # Expands gates over memory locations.
        allocation_gate = inputs["allocation_gate"].unsqueeze(-1)
        write_gate = inputs["write_gate"].unsqueeze(-1)

        # w_t^{w, i} - The write weightings for each write head.
        return write_gate * (
            allocation_gate * write_allocation_weights
            + (1 - allocation_gate) * write_content_weights
        )

    def _read_weights(
        self,
        inputs: Tensor,
        memory: Tensor,
        prev_read_weights: Tensor,
        link: Tensor,
    ) -> Tensor:
        """Calculates read weights for each read head.

        The read weights are a combination of following the link graphs in the
        forward or backward directions from the previous read position, and doing
        content-based lookup. The interpolation between these different modes is
        done by `inputs['read_mode']`.

        Args:
          inputs: Controls for this access module. This contains the content-based
            keys to look up, and the weightings for the different read modes.
          memory: A tensor of shape `[batch_size, memory_size, word_size]`
            containing the current memory contents to do content-based lookup.
          prev_read_weights: A tensor of shape `[batch_size, num_reads,
            memory_size]` containing the previous read locations.
          link: A tensor of shape `[batch_size, num_writes, memory_size,
            memory_size]` containing the temporal write transition graphs.

        Returns:
          A tensor of shape `[batch_size, num_reads, memory_size]` containing the
          read weights for each read head.
        """
        # c_t^{r, i} - The content weightings for each read head.
        content_weights = self._read_content_weights_mod(
            memory, inputs["read_content_keys"], inputs["read_content_strengths"]
        )

        # Calculates f_t^i and b_t^i.
        forward_weights = self._linkage.directional_read_weights(
            link, prev_read_weights, is_forward=True
        )
        backward_weights = self._linkage.directional_read_weights(
            link, prev_read_weights, is_forward=False
        )

        # `read_mode` is laid out along its last dimension as
        # [backward modes (num_writes), forward modes (num_writes), content mode (1)].
        m = self._num_writes
        backward_mode = inputs["read_mode"][:, :, :m, None]
        forward_mode = inputs["read_mode"][:, :, m : 2 * m, None]
        content_mode = inputs["read_mode"][:, :, None, 2 * m]

        read_weights = (
            content_mode * content_weights
            + (forward_mode * forward_weights).sum(dim=2)
            + (backward_mode * backward_weights).sum(dim=2)
        )
        return read_weights

    def forward(
        self,
        inputs: Tensor,
        prev_state: Optional[AccessState] = None,
    ) -> Tuple[Tensor, AccessState]:
        """Performs one time step of the memory access module.

        Args:
          inputs: Tensor of shape `[batch_size, input_size]`. This is used to
            control this access module.
          prev_state: Instance of `AccessState` containing the previous state,
            or `None` to start from an all-zero state.

        Returns:
          A tuple `(output, next_state)`, where `output` is a tensor of shape
          `[batch_size, num_reads, word_size]`, and `next_state` is the new
          `AccessState` at the current time t.
        """
        if inputs.ndim != 2:
            raise ValueError(f"Expected `inputs` to be 2D. Found: {inputs.ndim}.")
        if prev_state is None:
            B, device = inputs.shape[0], inputs.device
            prev_state = AccessState(
                memory=torch.zeros(
                    (B, self._memory_size, self._word_size), device=device
                ),
                read_weights=torch.zeros(
                    (B, self._num_reads, self._memory_size), device=device
                ),
                write_weights=torch.zeros(
                    (B, self._num_writes, self._memory_size), device=device
                ),
                linkage=None,
                usage=torch.zeros((B, self._memory_size), device=device),
            )

        inputs = self._read_inputs(inputs)

        # Update usage using inputs['free_gate'] and previous read & write weights.
        usage = self._freeness(
            write_weights=prev_state.write_weights,
            free_gate=inputs["free_gate"],
            read_weights=prev_state.read_weights,
            prev_usage=prev_state.usage,
        )

        # Write to memory.
        write_weights = self._write_weights(inputs, prev_state.memory, usage)
        memory = _erase_and_write(
            prev_state.memory,
            address=write_weights,
            reset_weights=inputs["erase_vectors"],
            values=inputs["write_vectors"],
        )

        linkage_state = self._linkage(write_weights, prev_state.linkage)

        # Read from memory.
        read_weights = self._read_weights(
            inputs,
            memory=memory,
            prev_read_weights=prev_state.read_weights,
            link=linkage_state.link,
        )
        read_words = torch.matmul(read_weights, memory)

        return (
            read_words,
            AccessState(
                memory=memory,
                read_weights=read_weights,
                write_weights=write_weights,
                linkage=linkage_state,
                usage=usage,
            ),
        )
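

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module's API. It assumes the
    # sibling `addressing` module provides the classes imported above; because
    # of that relative import, run it from the package (e.g.
    # `python -m <package>.<this_module>` -- names here are placeholders).
    access = MemoryAccess(memory_size=16, word_size=8, num_reads=2, num_writes=1)
    controller_output = torch.randn(4, 64)  # [batch_size, input_size]
    state = None
    for _ in range(3):  # step the module a few times, feeding back its state
        read_words, state = access(controller_output, state)
    print(read_words.shape)  # expected: torch.Size([4, 2, 8])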