# dnc_pytorch/dnc/access.py

from dataclasses import dataclass
from typing import Dict, Tuple, Optional

import torch
from torch import Tensor

from .addressing import (
    CosineWeights,
    Freeness,
    TemporalLinkage,
    TemporalLinkageState,
)


@dataclass
class AccessState:
    memory: Tensor  # [batch_size, memory_size, word_size]
    read_weights: Tensor  # [batch_size, num_reads, memory_size]
    write_weights: Tensor  # [batch_size, num_writes, memory_size]
    linkage: TemporalLinkageState
    # link: [batch_size, num_writes, memory_size, memory_size]
    # precedence_weights: [batch_size, num_writes, memory_size]
    usage: Tensor  # [batch_size, memory_size]


def _erase_and_write(memory, address, reset_weights, values):
    """Erases and writes to the external memory.

    Erase operation:
        M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)

    Add operation:
        M_t(i) = M_t'(i) + w_t(i) * a_t

    where e_t are the reset weights, w_t the write weights and a_t the values.

    Args:
        memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
        address: 3-D tensor of shape `[batch_size, num_writes, memory_size]`.
        reset_weights: 3-D tensor of shape `[batch_size, num_writes, word_size]`.
        values: 3-D tensor of shape `[batch_size, num_writes, word_size]`.

    Returns:
        3-D tensor of shape `[batch_size, memory_size, word_size]` containing
        the updated memory.
    """
    weighted_resets = address.unsqueeze(3) * reset_weights.unsqueeze(2)
    reset_gate = torch.prod(1 - weighted_resets, dim=1)
    return memory * reset_gate + torch.matmul(address.transpose(-1, -2), values)
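

# A minimal shape walk-through for `_erase_and_write` (annotation added here
# for illustration; the sizes are assumptions, not part of the original
# module). With a single write head, the erase term scales each memory row by
# the erase gate, and the add term is the product of the transposed write
# weights and the write values:
#
#     memory:        [B, N, W] = [2, 4, 3]
#     address:       [B, 1, N] --unsqueeze(3)--> [2, 1, 4, 1]
#     reset_weights: [B, 1, W] --unsqueeze(2)--> [2, 1, 1, 3]
#     weighted_resets:                           [2, 1, 4, 3]
#     reset_gate = prod(1 - ., dim=1):           [2, 4, 3]
#     matmul(address^T, values): [2, 4, 1] @ [2, 1, 3] -> [2, 4, 3]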


class Reshape(torch.nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, input):
        return input.reshape(self.dims)


class MemoryAccess(torch.nn.Module):
    """Access module of the Differentiable Neural Computer.

    This memory module supports multiple read and write heads. It makes use of:

    * `addressing.TemporalLinkage` to track the temporal ordering of writes in
      memory for each write head.
    * `addressing.Freeness` to keep track of memory usage, where usage
      increases when a memory location is written to, and decreases when
      memory is read from and the controller says it can be freed.

    Write-address selection is done by an interpolation between content-based
    lookup and using unused memory.

    Read-address selection is done by an interpolation of content-based lookup
    and following the link graph in the forward or backward read direction.
    """

    def __init__(self, memory_size=128, word_size=20, num_reads=1, num_writes=1):
        """Creates a MemoryAccess module.

        Args:
            memory_size: The number of memory slots (N in the DNC paper).
            word_size: The width of each memory slot (W in the DNC paper).
            num_reads: The number of read heads (R in the DNC paper).
            num_writes: The number of write heads (fixed at 1 in the paper).
        """
        super().__init__()
        self._memory_size = memory_size
        self._word_size = word_size
        self._num_reads = num_reads
        self._num_writes = num_writes

        self._write_content_weights_mod = CosineWeights(num_writes, word_size)
        self._read_content_weights_mod = CosineWeights(num_reads, word_size)
        self._linkage = TemporalLinkage(memory_size, num_writes)
        self._freeness = Freeness(memory_size)

        def _linear(first_dim, second_dim, pre_act=None, post_act=None):
            """Returns a module applying a lazy linear layer, then a reshape
            to `[-1, first_dim, second_dim]`, with optional activations."""
            mods = [torch.nn.LazyLinear(first_dim * second_dim)]
            if pre_act is not None:
                mods.append(pre_act)
            mods.append(Reshape([-1, first_dim, second_dim]))
            if post_act is not None:
                mods.append(post_act)
            return torch.nn.Sequential(*mods)

        self._write_vectors = _linear(num_writes, word_size)
        self._erase_vectors = _linear(num_writes, word_size, pre_act=torch.nn.Sigmoid())
        self._free_gate = torch.nn.Sequential(
            torch.nn.LazyLinear(num_reads),
            torch.nn.Sigmoid(),
        )
        self._alloc_gate = torch.nn.Sequential(
            torch.nn.LazyLinear(num_writes),
            torch.nn.Sigmoid(),
        )
        self._write_gate = torch.nn.Sequential(
            torch.nn.LazyLinear(num_writes),
            torch.nn.Sigmoid(),
        )
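
        # Each read head mixes one content-based mode with 2 * num_writes
        # directional modes (one forward and one backward per write head),
        # so the read-mode head below emits 1 + 2 * num_writes values per
        # read head, normalised with a softmax.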
        num_read_modes = 1 + 2 * num_writes
        self._read_mode = _linear(
            num_reads, num_read_modes, post_act=torch.nn.Softmax(dim=-1)
        )

        self._write_keys = _linear(num_writes, word_size)
        self._write_strengths = torch.nn.LazyLinear(num_writes)
        self._read_keys = _linear(num_reads, word_size)
        self._read_strengths = torch.nn.LazyLinear(num_reads)

    def _read_inputs(self, inputs: Tensor) -> Dict[str, Tensor]:
        """Applies transformations to `inputs` to get control for this module."""
        # v_t^i - The vectors to write to memory, for each write head `i`.
        write_vectors = self._write_vectors(inputs)

        # e_t^i - Amount to erase the memory by before writing, for each write head.
        erase_vectors = self._erase_vectors(inputs)

        # f_t^j - Amount that the memory at the locations read from at the previous
        # time step can be declared unused, for each read head `j`.
        free_gate = self._free_gate(inputs)

        # g_t^{a, i} - Interpolation between writing to unallocated memory and
        # content-based lookup, for each write head `i`. Note: `a` is simply used to
        # identify this gate with allocation vs writing (as defined below).
        allocation_gate = self._alloc_gate(inputs)

        # g_t^{w, i} - Overall gating of write amount for each write head.
        write_gate = self._write_gate(inputs)

        # \pi_t^j - Mixing between "backwards" and "forwards" positions (for
        # each write head), and content-based lookup, for each read head.
        read_mode = self._read_mode(inputs)

        # Parameters for the (read / write) "weights by content matching" modules.
        write_keys = self._write_keys(inputs)
        write_strengths = self._write_strengths(inputs)
        read_keys = self._read_keys(inputs)
        read_strengths = self._read_strengths(inputs)

        return {
            "read_content_keys": read_keys,
            "read_content_strengths": read_strengths,
            "write_content_keys": write_keys,
            "write_content_strengths": write_strengths,
            "write_vectors": write_vectors,
            "erase_vectors": erase_vectors,
            "free_gate": free_gate,
            "allocation_gate": allocation_gate,
            "write_gate": write_gate,
            "read_mode": read_mode,
        }

    def _write_weights(
        self,
        inputs: Tensor,
        memory: Tensor,
        usage: Tensor,
    ) -> Tensor:
        """Calculates the memory locations to write to.

        This uses a combination of content-based lookup and finding an unused
        location in memory, for each write head.

        Args:
            inputs: Collection of inputs to the access module, including controls
                for how to choose memory writing, such as the content to look up
                and the weighting between content-based and allocation-based
                addressing.
            memory: A tensor of shape `[batch_size, memory_size, word_size]`
                containing the current memory contents.
            usage: Current memory usage, which is a tensor of shape `[batch_size,
                memory_size]`, used for allocation-based addressing.

        Returns:
            A tensor of shape `[batch_size, num_writes, memory_size]` indicating
            where to write to (if anywhere) for each write head.
        """
        # c_t^{w, i} - The content-based weights for each write head.
        write_content_weights = self._write_content_weights_mod(
            memory, inputs["write_content_keys"], inputs["write_content_strengths"]
        )

        # a_t^i - The allocation weights for each write head.
        write_allocation_weights = self._freeness.write_allocation_weights(
            usage=usage,
            write_gates=(inputs["allocation_gate"] * inputs["write_gate"]),
            num_writes=self._num_writes,
        )

        # Expands gates over memory locations.
        allocation_gate = inputs["allocation_gate"].unsqueeze(-1)
        write_gate = inputs["write_gate"].unsqueeze(-1)

        # w_t^{w, i} - The write weightings for each write head.
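        # In DNC-paper notation (annotation added here for reference):
        #     w_t^{w, i} = g_t^{w, i} * (g_t^{a, i} * a_t^i
        #                                + (1 - g_t^{a, i}) * c_t^{w, i})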
        return write_gate * (
            allocation_gate * write_allocation_weights
            + (1 - allocation_gate) * write_content_weights
        )

    def _read_weights(
        self,
        inputs: Tensor,
        memory: Tensor,
        prev_read_weights: Tensor,
        link: Tensor,
    ) -> Tensor:
        """Calculates read weights for each read head.

        The read weights are a combination of following the link graphs in the
        forward or backward direction from the previous read position, and doing
        content-based lookup. The interpolation between these different modes is
        done by `inputs['read_mode']`.

        Args:
            inputs: Controls for this access module. This contains the content-based
                keys to look up, and the weightings for the different read modes.
            memory: A tensor of shape `[batch_size, memory_size, word_size]`
                containing the current memory contents to do content-based lookup.
            prev_read_weights: A tensor of shape `[batch_size, num_reads,
                memory_size]` containing the previous read locations.
            link: A tensor of shape `[batch_size, num_writes, memory_size,
                memory_size]` containing the temporal write transition graphs.

        Returns:
            A tensor of shape `[batch_size, num_reads, memory_size]` containing the
            read weights for each read head.
        """
        # c_t^{r, i} - The content weightings for each read head.
        content_weights = self._read_content_weights_mod(
            memory, inputs["read_content_keys"], inputs["read_content_strengths"]
        )

        # Calculates f_t^i and b_t^i.
        forward_weights = self._linkage.directional_read_weights(
            link, prev_read_weights, is_forward=True
        )
        backward_weights = self._linkage.directional_read_weights(
            link, prev_read_weights, is_forward=False
        )
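
        # Layout of the read-mode vector, per read head (m = num_writes):
        # entries [0, m) weight the backward modes b_t^i, entries [m, 2m) the
        # forward modes f_t^i, and entry 2m the content-based mode c_t^{r, i}.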
        m = self._num_writes
        backward_mode = inputs["read_mode"][:, :, :m, None]
        forward_mode = inputs["read_mode"][:, :, m : 2 * m, None]
        content_mode = inputs["read_mode"][:, :, 2 * m, None]

        read_weights = (
            content_mode * content_weights
            + (forward_mode * forward_weights).sum(dim=2)
            + (backward_mode * backward_weights).sum(dim=2)
        )
        return read_weights

    def forward(
        self,
        inputs: Tensor,
        prev_state: Optional[AccessState] = None,
    ) -> Tuple[Tensor, AccessState]:
        """Runs the MemoryAccess module for one time step.

        Args:
            inputs: A tensor of shape `[batch_size, input_size]`. This is used to
                control this access module.
            prev_state: Instance of `AccessState` containing the previous state,
                or `None` to start from a zero-initialised state.

        Returns:
            A tuple `(output, next_state)`, where `output` is a tensor of shape
            `[batch_size, num_reads, word_size]`, and `next_state` is the new
            `AccessState` at the current time step.
        """
        if inputs.ndim != 2:
            raise ValueError(f"Expected `inputs` to be 2D. Found: {inputs.ndim}.")

        if prev_state is None:
            B, device = inputs.shape[0], inputs.device
            prev_state = AccessState(
                memory=torch.zeros(
                    (B, self._memory_size, self._word_size), device=device
                ),
                read_weights=torch.zeros(
                    (B, self._num_reads, self._memory_size), device=device
                ),
                write_weights=torch.zeros(
                    (B, self._num_writes, self._memory_size), device=device
                ),
                linkage=None,
                usage=torch.zeros((B, self._memory_size), device=device),
            )

        inputs = self._read_inputs(inputs)

        # Update usage using inputs['free_gate'] and previous read & write weights.
        usage = self._freeness(
            write_weights=prev_state.write_weights,
            free_gate=inputs["free_gate"],
            read_weights=prev_state.read_weights,
            prev_usage=prev_state.usage,
        )

        # Write to memory.
        write_weights = self._write_weights(inputs, prev_state.memory, usage)
        memory = _erase_and_write(
            prev_state.memory,
            address=write_weights,
            reset_weights=inputs["erase_vectors"],
            values=inputs["write_vectors"],
        )

        linkage_state = self._linkage(write_weights, prev_state.linkage)

        # Read from memory.
        read_weights = self._read_weights(
            inputs,
            memory=memory,
            prev_read_weights=prev_state.read_weights,
            link=linkage_state.link,
        )
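        # r_t^j - The read words: each read head returns a weighted average of
        # memory rows, read_weights [B, R, N] @ memory [B, N, W] -> [B, R, W].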
        read_words = torch.matmul(read_weights, memory)

        return (
            read_words,
            AccessState(
                memory=memory,
                read_weights=read_weights,
                write_weights=write_weights,
                linkage=linkage_state,
                usage=usage,
            ),
        )
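

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: roll the access
    # module over a few random controller embeddings. The sizes below are
    # illustrative assumptions; the LazyLinear layers infer their input width
    # on the first forward call.
    torch.manual_seed(0)
    access = MemoryAccess(memory_size=16, word_size=8, num_reads=2, num_writes=1)
    inputs = torch.randn(6, 4, 32)  # [time_steps, batch_size, input_size]
    state = None
    for t in range(inputs.shape[0]):
        read_words, state = access(inputs[t], state)
    print(read_words.shape)  # expected: torch.Size([4, 2, 8])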