Initial commit

This commit is contained in:
moto 2022-08-23 21:58:43 +09:00
parent 7544f49f69
commit 1a581b8aff
13 changed files with 1994 additions and 2 deletions


@ -1,2 +1,11 @@
# dnc_pytorch
PyTorch Port of DeepMind's Differentiable Neural Computer
# Differential Neural Computer in PyTorch
This is a PyTorch port of DeepMind's [Differentiable Neural Computer (DNC)](https://github.com/deepmind/dnc).
The original code was written for TensorFlow v1, which is not straightforward to set up on modern tech stacks such as Apple Silicon and Python 3.8+.
This code requires only PyTorch >= 1.10.
The code structure and interfaces are kept almost the same as the original.
You can run the repeat-copy task with `python train.py`.
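
A minimal usage sketch of the core module (mirroring `tests/dnc_test.py`; the sizes below are illustrative, not required values):

```python
import torch
from dnc.dnc import DNC

dnc = DNC(
    access_config={"memory_size": 16, "word_size": 16, "num_reads": 4, "num_writes": 1},
    controller_config={"input_size": 4 + 4 * 16, "hidden_size": 64},  # input_size + num_reads * word_size
    output_size=4,
    clip_value=20,
)

inputs = torch.randn(64, 16, 4)  # [time_steps, batch_size, input_size]
state = None
for step in inputs:
    output, state = dnc(step, state)  # output: [batch_size, output_size]
```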

0
dnc/__init__.py Normal file

353
dnc/access.py Normal file

@ -0,0 +1,353 @@
from dataclasses import dataclass
from typing import Dict, Tuple, Optional
import torch
from torch import Tensor
from .addressing import (
CosineWeights,
Freeness,
TemporalLinkage,
TemporalLinkageState,
)
@dataclass
class AccessState:
memory: Tensor # [batch_size, memory_size, word_size]
read_weights: Tensor # [batch_size, num_reads, memory_size]
write_weights: Tensor # [batch_size, num_writes, memory_size]
linkage: TemporalLinkageState
# link: [batch_size, num_writes, memory_size, memory_size]
# precedence_weights: [batch_size, num_writes, memory_size]
usage: Tensor # [batch_size, memory_size]
def _erase_and_write(memory, address, reset_weights, values):
"""Module to erase and write in the external memory.
Erase operation:
M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)
Add operation:
M_t(i) = M_t'(i) + w_t(i) * a_t
where e are the reset_weights, w the write weights and a the values.
Args:
memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
address: 3-D tensor `[batch_size, num_writes, memory_size]`.
reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`.
values: 3-D tensor `[batch_size, num_writes, word_size]`.
Returns:
3-D tensor of shape `[batch_size, memory_size, word_size]` containing the updated memory.
"""
weighted_resets = address.unsqueeze(3) * reset_weights.unsqueeze(2)
reset_gate = torch.prod(1 - weighted_resets, dim=1)
return memory * reset_gate + torch.matmul(address.transpose(-1, -2), values)
class Reshape(torch.nn.Module):
def __init__(self, dims):
super().__init__()
self.dims = dims
def forward(self, input):
return input.reshape(self.dims)
class MemoryAccess(torch.nn.Module):
"""Access module of the Differentiable Neural Computer.
This memory module supports multiple read and write heads. It makes use of:
* `addressing.TemporalLinkage` to track the temporal ordering of writes in
memory for each write head.
* `addressing.Freeness` for keeping track of memory usage, where usage
increases when a memory location is written to, and decreases when memory
that the controller says can be freed is read from.
Write-address selection is done by an interpolation between content-based
lookup and using unused memory.
Read-address selection is done by an interpolation of content-based lookup
and following the link graph in the forward or backwards read direction.
"""
def __init__(self, memory_size=128, word_size=20, num_reads=1, num_writes=1):
"""Creates a MemoryAccess module.
Args:
memory_size: The number of memory slots (N in the DNC paper).
word_size: The width of each memory slot (W in the DNC paper)
num_reads: The number of read heads (R in the DNC paper).
num_writes: The number of write heads (fixed at 1 in the paper).
"""
super().__init__()
self._memory_size = memory_size
self._word_size = word_size
self._num_reads = num_reads
self._num_writes = num_writes
self._write_content_weights_mod = CosineWeights(num_writes, word_size)
self._read_content_weights_mod = CosineWeights(num_reads, word_size)
self._linkage = TemporalLinkage(memory_size, num_writes)
self._freeness = Freeness(memory_size)
def _linear(first_dim, second_dim, pre_act=None, post_act=None):
"""Returns a linear transformation of `inputs`, followed by a reshape."""
mods = []
mods.append(torch.nn.LazyLinear(first_dim * second_dim))
if pre_act is not None:
mods.append(pre_act)
mods.append(Reshape([-1, first_dim, second_dim]))
if post_act is not None:
mods.append(post_act)
return torch.nn.Sequential(*mods)
self._write_vectors = _linear(num_writes, word_size)
self._erase_vectors = _linear(num_writes, word_size, pre_act=torch.nn.Sigmoid())
self._free_gate = torch.nn.Sequential(
torch.nn.LazyLinear(num_reads),
torch.nn.Sigmoid(),
)
self._alloc_gate = torch.nn.Sequential(
torch.nn.LazyLinear(num_writes),
torch.nn.Sigmoid(),
)
self._write_gate = torch.nn.Sequential(
torch.nn.LazyLinear(num_writes),
torch.nn.Sigmoid(),
)
num_read_modes = 1 + 2 * num_writes
self._read_mode = _linear(
num_reads, num_read_modes, post_act=torch.nn.Softmax(dim=-1)
)
self._write_keys = _linear(num_writes, word_size)
self._write_strengths = torch.nn.LazyLinear(num_writes)
self._read_keys = _linear(num_reads, word_size)
self._read_strengths = torch.nn.LazyLinear(num_reads)
def _read_inputs(self, inputs: Tensor) -> Dict[str, Tensor]:
"""Applies transformations to `inputs` to get control for this module."""
# v_t^i - The vectors to write to memory, for each write head `i`.
write_vectors = self._write_vectors(inputs)
# e_t^i - Amount to erase the memory by before writing, for each write head.
erase_vectors = self._erase_vectors(inputs)
# f_t^j - Amount that the memory at the locations read from at the previous
# time step can be declared unused, for each read head `j`.
free_gate = self._free_gate(inputs)
# g_t^{a, i} - Interpolation between writing to unallocated memory and
# content-based lookup, for each write head `i`. Note: `a` is simply used to
# identify this gate with allocation vs writing (as defined below).
allocation_gate = self._alloc_gate(inputs)
# g_t^{w, i} - Overall gating of write amount for each write head.
write_gate = self._write_gate(inputs)
# \pi_t^j - Mixing between "backwards" and "forwards" positions (for
# each write head), and content-based lookup, for each read head.
read_mode = self._read_mode(inputs)
# Parameters for the (read / write) "weights by content matching" modules.
write_keys = self._write_keys(inputs)
write_strengths = self._write_strengths(inputs)
read_keys = self._read_keys(inputs)
read_strengths = self._read_strengths(inputs)
result = {
"read_content_keys": read_keys,
"read_content_strengths": read_strengths,
"write_content_keys": write_keys,
"write_content_strengths": write_strengths,
"write_vectors": write_vectors,
"erase_vectors": erase_vectors,
"free_gate": free_gate,
"allocation_gate": allocation_gate,
"write_gate": write_gate,
"read_mode": read_mode,
}
return result
def _write_weights(
self,
inputs: Tensor,
memory: Tensor,
usage: Tensor,
) -> Tensor:
"""Calculates the memory locations to write to.
This uses a combination of content-based lookup and finding an unused
location in memory, for each write head.
Args:
inputs: Collection of inputs to the access module, including controls for
how to choose memory writing, such as the content to look-up and the
weighting between content-based and allocation-based addressing.
memory: A tensor of shape `[batch_size, memory_size, word_size]`
containing the current memory contents.
usage: Current memory usage, which is a tensor of shape `[batch_size,
memory_size]`, used for allocation-based addressing.
Returns:
tensor of shape `[batch_size, num_writes, memory_size]` indicating where
to write to (if anywhere) for each write head.
"""
# c_t^{w, i} - The content-based weights for each write head.
write_content_weights = self._write_content_weights_mod(
memory, inputs["write_content_keys"], inputs["write_content_strengths"]
)
# a_t^i - The allocation weights for each write head.
write_allocation_weights = self._freeness.write_allocation_weights(
usage=usage,
write_gates=(inputs["allocation_gate"] * inputs["write_gate"]),
num_writes=self._num_writes,
)
# Expands gates over memory locations.
allocation_gate = inputs["allocation_gate"].unsqueeze(-1)
write_gate = inputs["write_gate"].unsqueeze(-1)
# w_t^{w, i} - The write weightings for each write head.
return write_gate * (
allocation_gate * write_allocation_weights
+ (1 - allocation_gate) * write_content_weights
)
def _read_weights(
self,
inputs: Tensor,
memory: Tensor,
prev_read_weights: Tensor,
link: Tensor,
) -> Tensor:
"""Calculates read weights for each read head.
The read weights are a combination of following the link graphs in the
forward or backward directions from the previous read position, and doing
content-based lookup. The interpolation between these different modes is
done by `inputs['read_mode']`.
Args:
inputs: Controls for this access module. This contains the content-based
keys to lookup, and the weightings for the different read modes.
memory: A tensor of shape `[batch_size, memory_size, word_size]`
containing the current memory contents to do content-based lookup.
prev_read_weights: A tensor of shape `[batch_size, num_reads,
memory_size]` containing the previous read locations.
link: A tensor of shape `[batch_size, num_writes, memory_size,
memory_size]` containing the temporal write transition graphs.
Returns:
A tensor of shape `[batch_size, num_reads, memory_size]` containing the
read weights for each read head.
"""
# c_t^{r, i} - The content weightings for each read head.
content_weights = self._read_content_weights_mod(
memory, inputs["read_content_keys"], inputs["read_content_strengths"]
)
# Calculates f_t^i and b_t^i.
forward_weights = self._linkage.directional_read_weights(
link, prev_read_weights, is_forward=True
)
backward_weights = self._linkage.directional_read_weights(
link, prev_read_weights, is_forward=False
)
m = self._num_writes
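# read_mode layout along the last dim: [backward (num_writes), forward (num_writes), content (1)]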
backward_mode = inputs["read_mode"][:, :, :m, None]
forward_mode = inputs["read_mode"][:, :, m : 2 * m, None]
content_mode = inputs["read_mode"][:, :, None, 2 * m]
read_weights = (
content_mode * content_weights
+ (forward_mode * forward_weights).sum(dim=2)
+ (backward_mode * backward_weights).sum(dim=2)
)
return read_weights
def forward(
self,
inputs: Tensor,
prev_state: Optional[AccessState] = None,
) -> Tuple[Tensor, AccessState]:
"""Connects the MemoryAccess module into the graph.
Args:
inputs: tensor of shape `[batch_size, input_size]`. This is used to
control this access module.
prev_state: Instance of `AccessState` containing the previous state.
Returns:
A tuple `(output, next_state)`, where `output` is a tensor of shape
`[batch_size, num_reads, word_size]`, and `next_state` is the new
`AccessState` dataclass at the current time t.
"""
if inputs.ndim != 2:
raise ValueError("Expected `inputs` to be 2D. Found: {inputs.ndim}.")
if prev_state is None:
B, device = inputs.shape[0], inputs.device
prev_state = AccessState(
memory=torch.zeros(
(B, self._memory_size, self._word_size), device=device
),
read_weights=torch.zeros(
(B, self._num_reads, self._memory_size), device=device
),
write_weights=torch.zeros(
(B, self._num_writes, self._memory_size), device=device
),
linkage=None,
usage=torch.zeros((B, self._memory_size), device=device),
)
inputs = self._read_inputs(inputs)
# Update usage using inputs['free_gate'] and previous read & write weights.
usage = self._freeness(
write_weights=prev_state.write_weights,
free_gate=inputs["free_gate"],
read_weights=prev_state.read_weights,
prev_usage=prev_state.usage,
)
# Write to memory.
write_weights = self._write_weights(inputs, prev_state.memory, usage)
memory = _erase_and_write(
prev_state.memory,
address=write_weights,
reset_weights=inputs["erase_vectors"],
values=inputs["write_vectors"],
)
linkage_state = self._linkage(write_weights, prev_state.linkage)
# Read from memory.
read_weights = self._read_weights(
inputs,
memory=memory,
prev_read_weights=prev_state.read_weights,
link=linkage_state.link,
)
read_words = torch.matmul(read_weights, memory)
return (
read_words,
AccessState(
memory=memory,
read_weights=read_weights,
write_weights=write_weights,
linkage=linkage_state,
usage=usage,
),
)

368
dnc/addressing.py Normal file

@ -0,0 +1,368 @@
from dataclasses import dataclass
from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
@dataclass
class TemporalLinkageState:
link: Tensor # [batch_size, num_writes, memory_size, memory_size]
precedence_weights: Tensor # [batch_size, num_writes, memory_size]
_EPSILON = 1e-6
def _vector_norms(m: Tensor) -> Tensor:
norm = torch.sum(m * m, dim=2, keepdim=True)
return torch.sqrt(norm + _EPSILON)
def weighted_softmax(activations: Tensor, strengths: Tensor, strengths_op=F.softplus):
"""Returns softmax over activations multiplied by positive strengths.
Args:
activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of
activations to be transformed. Softmax is taken over the last dimension.
strengths: A tensor of shape `[batch_size, num_heads]` containing strengths to
multiply by the activations prior to the softmax.
strengths_op: An operation to transform strengths before softmax.
Returns:
A tensor of same shape as `activations` with weighted softmax applied.
"""
transformed_strengths = strengths_op(strengths).unsqueeze(-1)
sharp_activations = activations * transformed_strengths
softmax = F.softmax(sharp_activations, dim=-1)
return softmax
class CosineWeights(torch.nn.Module):
def __init__(self, num_heads, word_size, strength_op=F.softplus):
"""
Args:
num_heads: number of memory heads.
word_size: memory word size.
strength_op: operation to apply to strengths (default softplus).
"""
super().__init__()
self._num_heads = num_heads
self._word_size = word_size
self._strength_op = strength_op
def forward(self, memory: Tensor, keys: Tensor, strengths: Tensor) -> Tensor:
"""
Args:
memory: A 3-D tensor of shape `[batch_size, memory_size, word_size]`.
keys: A 3-D tensor of shape `[batch_size, num_heads, word_size]`.
strengths: A 2-D tensor of shape `[batch_size, num_heads]`.
Returns:
Weights tensor of shape `[batch_size, num_heads, memory_size]`.
"""
dot = torch.matmul(keys, memory.transpose(-1, -2)) # <B, H, M>
memory_norms = _vector_norms(memory) # <B, M, 1>
key_norms = _vector_norms(keys) # <B, H, 1>
norm = torch.matmul(key_norms, memory_norms.transpose(-1, -2)) # <B, H, M>
similarity = dot / (norm + _EPSILON)
return weighted_softmax(similarity, strengths, self._strength_op)
class TemporalLinkage(torch.nn.Module):
def __init__(self, memory_size, num_writes):
super().__init__()
self._memory_size = memory_size
self._num_writes = num_writes
def _link(
self,
prev_link: Tensor,
prev_precedence_weights: Tensor,
write_weights: Tensor,
) -> Tensor:
"""Calculates the new link graphs.
For each write head, the link is a directed graph (represented by a matrix
with entries in range [0, 1]) whose vertices are the memory locations, and
an edge indicates temporal ordering of writes.
Args:
prev_link: A tensor of shape `[batch_size, num_writes, memory_size,
memory_size]` representing the previous link graphs for each write
head.
prev_precedence_weights: A tensor of shape `[batch_size, num_writes,
memory_size]` which is the previous "aggregated" write weights for
each write head.
write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
containing the new locations in memory written to.
Returns:
A tensor of shape `[batch_size, num_writes, memory_size, memory_size]`
containing the new link graphs for each write head.
"""
batch_size = prev_link.size(0)
write_weights_i = write_weights.unsqueeze(3) # <B, W, M, 1>
write_weights_j = write_weights.unsqueeze(2) # <B, W, 1, M>
prev_precedence_weights_j = prev_precedence_weights.unsqueeze(2) # <B, W, 1, M>
prev_link_scale = 1 - write_weights_i - write_weights_j # <B, W, M, M>
new_link = write_weights_i * prev_precedence_weights_j # <B, W, M, M>
link = prev_link_scale * prev_link + new_link # <B, W, M, M>
# Return the link with the diagonal set to zero, to remove self-looping
# edges.
mask = (
torch.eye(self._memory_size, device=link.device)
.repeat(batch_size, self._num_writes, 1, 1)
.bool()
)
link[mask] = 0
return link
def _precedence_weights(
self,
prev_precedence_weights: Tensor,
write_weights: Tensor,
) -> Tensor:
"""Calculates the new precedence weights given the current write weights.
The precedence weights are the "aggregated write weights" for each write
head, where write weights with sum close to zero will leave the precedence
weights unchanged, but with sum close to one will replace the precedence
weights.
Args:
prev_precedence_weights: A tensor of shape `[batch_size, num_writes,
memory_size]` containing the previous precedence weights.
write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
containing the new write weights.
Returns:
A tensor of shape `[batch_size, num_writes, memory_size]` containing the
new precedence weights.
"""
write_sum = write_weights.sum(dim=2, keepdim=True)
return (1 - write_sum) * prev_precedence_weights + write_weights
def forward(
self,
write_weights: Tensor,
prev_state: Optional[TemporalLinkageState] = None,
) -> TemporalLinkageState:
"""Calculate the updated linkage state given the write weights.
Args:
write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
containing the memory addresses of the different write heads.
prev_state: `TemporalLinkageState` containing a tensor `link` of
shape `[batch_size, num_writes, memory_size, memory_size]`, and a
tensor `precedence_weights` of shape `[batch_size, num_writes,
memory_size]` containing the aggregated history of recent writes.
Returns:
A `TemporalLinkageState` tuple `next_state`, which contains the updated
link and precedence weights.
"""
if write_weights.ndim != 3:
raise ValueError(
f"Expected `write_weights` to be 3D. Found: {write_weights.ndim}"
)
if (
write_weights.size(1) != self._num_writes
or write_weights.size(2) != self._memory_size
):
raise ValueError(
"Expected the shape of `write_weights` to be "
f"[batch, {self._num_writes}, {self._memory_size}]. "
f"Found: {write_weights.shape}."
)
if prev_state is None:
B, W, M = write_weights.shape
prev_state = TemporalLinkageState(
link=torch.zeros((B, W, M, M), device=write_weights.device),
precedence_weights=torch.zeros((B, W, M), device=write_weights.device),
)
link = self._link(prev_state.link, prev_state.precedence_weights, write_weights)
precedence_weights = self._precedence_weights(
prev_state.precedence_weights, write_weights
)
return TemporalLinkageState(link=link, precedence_weights=precedence_weights)
def directional_read_weights(
self,
link: Tensor,
prev_read_weights: Tensor,
is_forward: bool,
) -> Tensor:
"""Calculates the forward or the backward read weights.
For each read head (at a given address), there are `num_writes` link graphs
to follow. Thus this function computes a read address for each of the
`num_reads * num_writes` pairs of read and write heads.
Args:
link: tensor of shape `[batch_size, num_writes, memory_size,
memory_size]` representing the link graphs L_t.
prev_read_weights: tensor of shape `[batch_size, num_reads,
memory_size]` containing the previous read weights w_{t-1}^r.
is_forward: Boolean indicating whether to follow the "future" direction in
the link graph (True) or the "past" direction (False).
Returns:
tensor of shape `[batch_size, num_reads, num_writes, memory_size]`
"""
# <B, W, R, M>
expanded_read_weights = torch.stack(
[prev_read_weights for _ in range(self._num_writes)], dim=1
)
if is_forward:
link = link.transpose(-1, -2)
result = torch.matmul(expanded_read_weights, link) # <B, W, R, M>
return result.permute((0, 2, 1, 3)) # <B, R, W, M>
class Freeness(torch.nn.Module):
def __init__(self, memory_size):
super().__init__()
self._memory_size = memory_size
def _usage_after_write(
self,
prev_usage: Tensor,
write_weights: Tensor,
) -> Tensor:
"""Calcualtes the new usage after writing to memory.
Args:
prev_usage: tensor of shape `[batch_size, memory_size]`.
write_weights: tensor of shape `[batch_size, num_writes, memory_size]`.
Returns:
New usage, a tensor of shape `[batch_size, memory_size]`.
"""
write_weights = 1 - torch.prod(1 - write_weights, 1)
return prev_usage + (1 - prev_usage) * write_weights
def _usage_after_read(
self, prev_usage: Tensor, free_gate: Tensor, read_weights: Tensor
) -> Tensor:
"""Calcualtes the new usage after reading and freeing from memory.
Args:
prev_usage: tensor of shape `[batch_size, memory_size]`.
free_gate: tensor of shape `[batch_size, num_reads]` with entries in the
range [0, 1] indicating the amount that locations read from can be
freed.
read_weights: tensor of shape `[batch_size, num_reads, memory_size]`.
Returns:
New usage, a tensor of shape `[batch_size, memory_size]`.
"""
free_gate = free_gate.unsqueeze(-1)
free_read_weights = free_gate * read_weights
phi = torch.prod(1 - free_read_weights, 1)
return prev_usage * phi
def forward(
self,
write_weights: Tensor,
free_gate: Tensor,
read_weights: Tensor,
prev_usage: Tensor,
) -> Tensor:
"""Calculates the new memory usage u_t.
Memory that was written to in the previous time step will have its usage
increased; memory that was read from and the controller says can be "freed"
will have its usage decreased.
Args:
write_weights: tensor of shape `[batch_size, num_writes,
memory_size]` giving write weights at previous time step.
free_gate: tensor of shape `[batch_size, num_reads]` which indicates
which read heads read memory that can now be freed.
read_weights: tensor of shape `[batch_size, num_reads,
memory_size]` giving read weights at previous time step.
prev_usage: tensor of shape `[batch_size, memory_size]` giving
usage u_{t - 1} at the previous time step, with entries in range
[0, 1].
Returns:
tensor of shape `[batch_size, memory_size]` representing updated memory
usage.
"""
with torch.no_grad():
usage = self._usage_after_write(prev_usage, write_weights)
usage = self._usage_after_read(usage, free_gate, read_weights)
return usage
def _allocation(self, usage: Tensor) -> Tensor:
"""Computes allocation by sorting `usage`.
This corresponds to the value a = a_t[\phi_t[j]] in the paper.
Args:
usage: tensor of shape `[batch_size, memory_size]` indicating current
memory usage. This is equal to u_t in the paper when we only have one
write head, but for multiple write heads, one should update the usage
while iterating through the write heads to take into account the
allocation returned by this function.
Returns:
Tensor of shape `[batch_size, memory_size]` corresponding to allocation.
"""
usage = _EPSILON + (1 - _EPSILON) * usage
nonusage = 1 - usage
sorted_nonusage, indices = torch.topk(nonusage, k=self._memory_size)
sorted_usage = 1 - sorted_nonusage
# emulate tf.cumprod(exclusive=True)
sorted_usage = F.pad(sorted_usage, (1, 0), mode="constant", value=1)
prod_sorted_usage = torch.cumprod(sorted_usage, dim=1)
prod_sorted_usage = prod_sorted_usage[:, :-1]
sorted_allocation = sorted_nonusage * prod_sorted_usage
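# Undo the sort so allocations line up with the original memory ordering.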
inverse_indices = torch.argsort(indices)
return torch.gather(sorted_allocation, 1, inverse_indices)
def write_allocation_weights(
self,
usage: Tensor,
write_gates: Tensor,
num_writes: int,
) -> Tensor:
"""Calculates freeness-based locations for writing to.
This finds unused memory by ranking the memory locations by usage, for each
write head. (For more than one write head, we use a "simulated new usage"
which takes into account the fact that the previous write head will increase
the usage in that area of the memory.)
Args:
usage: A tensor of shape `[batch_size, memory_size]` representing
current memory usage.
write_gates: A tensor of shape `[batch_size, num_writes]` with values in
the range [0, 1] indicating how much each write head does writing
based on the address returned here (and hence how much usage
increases).
num_writes: The number of write heads to calculate write weights for.
Returns:
tensor of shape `[batch_size, num_writes, memory_size]` containing the
freeness-based write locations. Note that this isn't scaled by
`write_gate`; this scaling must be applied externally.
"""
write_gates = write_gates.unsqueeze(-1)
allocation_weights = []
for i in range(num_writes):
allocation_weights.append(self._allocation(usage))
usage = usage + (1 - usage) * write_gates[:, i, :] * allocation_weights[i]
return torch.stack(allocation_weights, dim=1)

129
dnc/dnc.py Normal file

@ -0,0 +1,129 @@
from dataclasses import dataclass
from typing import Tuple, Optional
import torch
from torch import Tensor
from .access import MemoryAccess, AccessState
@dataclass
class DNCState:
access_output: Tensor # [batch_size, num_reads, word_size]
access_state: AccessState
# memory: Tensor [batch_size, memory_size, word_size]
# read_weights: Tensor # [batch_size, num_reads, memory_size]
# write_weights: Tensor # [batch_size, num_writes, memory_size]
# linkage: TemporalLinkageState
# link: [batch_size, num_writes, memory_size, memory_size]
# precedence_weights: [batch_size, num_writes, memory_size]
# usage: Tensor # [batch_size, memory_size]
controller_state: Tuple[Tensor, Tensor]
# h_n: [num_layers, batch_size, projection_size]
# c_n: [num_layers, batch_size, hidden_size]
class DNC(torch.nn.Module):
"""DNC core module.
Contains controller and memory access module.
"""
def __init__(
self,
access_config,
controller_config,
output_size,
clip_value=None,
):
"""Initializes the DNC core.
Args:
access_config: dictionary of access module configurations.
controller_config: dictionary of controller (LSTM) module configurations.
output_size: output dimension size of core.
clip_value: clips controller and core output values to between
`[-clip_value, clip_value]` if specified.
"""
super().__init__()
self._controller = torch.nn.LSTMCell(**controller_config)
self._access = MemoryAccess(**access_config)
self._output = torch.nn.LazyLinear(output_size)
if clip_value is None:
self._clip = lambda x: x
else:
self._clip = lambda x: torch.clamp(x, min=-clip_value, max=clip_value)
def forward(self, inputs: Tensor, prev_state: Optional[DNCState] = None):
"""Connects the DNC core into the graph.
Args:
inputs: Tensor input.
prev_state: A `DNCState` tuple containing the fields `access_output`,
`access_state` and `controller_state`. `access_output` is a 3-D Tensor
of shape `[batch_size, num_reads, word_size]` containing read words,
`access_state` is a tuple of the access module's state, and
`controller_state` is a tuple of controller module's state.
Returns:
A tuple `(output, next_state)` where `output` is a tensor and `next_state`
is a `DNCState` tuple containing the fields `access_output`,
`access_state`, and `controller_state`.
"""
if inputs.ndim != 2:
raise ValueError(f"Expected `inputs` to be 2D: Found {inputs.ndim}.")
if prev_state is None:
B, device = inputs.shape[0], inputs.device
num_reads = self._access._num_reads
word_size = self._access._word_size
prev_state = DNCState(
access_output=torch.zeros((B, num_reads, word_size), device=device),
access_state=None,
controller_state=None,
)
def batch_flatten(x):
return torch.reshape(x, [x.size(0), -1])
controller_input = torch.concat(
[
batch_flatten(inputs), # [batch_size, num_input_feats]
batch_flatten(
prev_state.access_output
), # [batch_size, num_reads*word_size]
],
dim=1,
) # [batch_size, num_input_feats + num_reads * word_size]
controller_state = self._controller(
controller_input, prev_state.controller_state
)
controller_state = tuple(self._clip(t) for t in controller_state)
controller_output = controller_state[0]
access_output, access_state = self._access(
controller_output, prev_state.access_state
)
output = torch.concat(
[
controller_output, # [batch_size, num_ctrl_feats]
batch_flatten(access_output), # [batch_size, num_reads*word_size]
],
dim=1,
)
output = self._output(output)
output = self._clip(output)
return (
output,
DNCState(
access_output=access_output,
access_state=access_state,
controller_state=controller_state,
),
)

383
dnc/repeat_copy.py Normal file

@ -0,0 +1,383 @@
from dataclasses import dataclass
import torch
from torch import Tensor
import torch.nn.functional as F
@dataclass
class DatasetTensors:
observations: Tensor
target: Tensor
mask: Tensor
def masked_sigmoid_cross_entropy(
logits,
target,
mask,
time_average=False,
log_prob_in_bits=False,
):
"""Adds ops to graph which compute the (scalar) NLL of the target sequence.
The logits parametrize independent bernoulli distributions per time-step and
per batch element, and irrelevant time/batch elements are masked out by the
mask tensor.
Args:
logits: `Tensor` of activations for which sigmoid(`logits`) gives the
bernoulli parameter.
target: time-major `Tensor` of targets.
mask: time-major `Tensor` of shape T x B, multiplied elementwise with the
per-time-step cost to mask out irrelevant time-steps.
time_average: optionally average over the time dimension (sum by default).
log_prob_in_bits: iff True express log-probabilities in bits (default nats).
Returns:
A scalar `Tensor` with the masked NLL of the target, averaged over the batch.
"""
batch_size = logits.shape[1]
xent = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
loss_time_batch = xent.sum(dim=2)
loss_batch = (loss_time_batch * mask).sum(dim=0)
if time_average:
mask_count = mask.sum(dim=0)
loss_batch /= mask_count + torch.finfo(loss_batch.dtype).eps
loss = loss_batch.sum() / batch_size
if log_prob_in_bits:
loss /= torch.log(torch.tensor(2.0))
return loss
def bitstring_readable(data, batch_size, model_output=None, whole_batch=False):
"""Produce a human readable representation of the sequences in data.
Args:
data: data to be visualised
batch_size: size of batch
model_output: optional model output tensor to visualize alongside data.
whole_batch: whether to visualise the whole batch. Only the first sample
will be visualized if False
Returns:
A string used to visualise the data batch
"""
def _readable(datum):
return "+" + " ".join([f"{int(x):d}" if x else "-" for x in datum]) + "+"
obs_batch = data.observations
targ_batch = data.target
batch_strings = []
for batch_index in range(batch_size if whole_batch else 1):
obs = obs_batch[:, batch_index, :]
targ = targ_batch[:, batch_index, :]
obs_channels = range(obs.shape[1])
targ_channels = range(targ.shape[1])
obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels]
targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels]
readable_obs = "Observations:\n" + "\n".join(obs_channel_strings)
readable_targ = "Targets:\n" + "\n".join(targ_channel_strings)
strings = [readable_obs, readable_targ]
if model_output is not None:
output = model_output[:, batch_index, :]
output_strings = [_readable(output[:, i]) for i in targ_channels]
strings.append("Model Output:\n" + "\n".join(output_strings))
batch_strings.append("\n\n".join(strings))
return "\n" + "\n\n\n\n".join(batch_strings)
class RepeatCopy:
"""Sequence data generator for the task of repeating a random binary pattern.
When called, an instance of this class will return a `DatasetTensors` tuple
(obs, targ, mask), representing an input sequence, target sequence, and
binary mask. The first two dimensions of each tensor represent sequence
position and batch index respectively. The value in
mask[t, b] is equal to 1 iff a prediction about targ[t, b, :] should be
penalized and 0 otherwise.
For each realisation from this generator, the observation sequence is
comprised of I.I.D. uniform-random binary vectors (and some flags).
The target sequence is comprised of this binary pattern repeated
some number of times (and some flags). Before explaining in more detail,
let's examine the setup pictorially for a single batch element:
```none
Note: blank space represents 0.
time ------------------------------------------>
+-------------------------------+
mask: |0000000001111111111111111111111|
+-------------------------------+
+-------------------------------+
target: | 1| 'end-marker' channel.
| 101100110110011011001 |
| 010101001010100101010 |
+-------------------------------+
+-------------------------------+
observation: | 1011001 |
| 0101010 |
|1 | 'start-marker' channel
| 3 | 'num-repeats' channel.
+-------------------------------+
```
The length of the random pattern and the number of times it is repeated
in the target are both discrete random variables distributed according to
uniform distributions whose parameters are configured at construction time.
The obs sequence has two extra channels (components in the trailing dimension)
which are used for flags. One channel is marked with a 1 at the first time
step and is otherwise equal to 0. The other extra channel is zero until the
binary pattern to be repeated ends. At this point, it contains an encoding of
the number of times the observation pattern should be repeated. Rather than
simply providing this integer number directly, it is normalised so that
a neural network may have an easier time representing the number of
repetitions internally. To allow a network to be readily evaluated on
instances of this task with greater numbers of repetitions, the range with
respect to which this encoding is normalised is also configurable by the user.
As in the diagram, the target sequence is offset to begin directly after the
observation sequence; both sequences are padded with zeros to accomplish this,
resulting in their lengths being equal. Additional padding is done at the end
so that all sequences in a minibatch represent tensors with the same shape.
"""
def __init__(
self,
num_bits=6,
batch_size=1,
min_length=1,
max_length=1,
min_repeats=1,
max_repeats=2,
norm_max=10,
log_prob_in_bits=False,
time_average_cost=False,
):
"""Creates an instance of RepeatCopy task.
Args:
num_bits: The dimensionality of each random binary vector.
batch_size: Minibatch size per realization.
min_length: Lower limit on number of random binary vectors in the
observation pattern.
max_length: Upper limit on number of random binary vectors in the
observation pattern.
min_repeats: Lower limit on number of times the observation pattern
is repeated in targ.
max_repeats: Upper limit on number of times the observation pattern
is repeated in targ.
norm_max: Upper limit on uniform distribution w.r.t which the encoding
of the number of repetitions presented in the observation sequence
is normalised.
log_prob_in_bits: By default, log probabilities are expressed in units of
nats. If true, express log probabilities in bits.
time_average_cost: If true, the cost at each time step will be
divided by the `true` sequence length (the number of non-masked time
steps in each sequence) before any subsequent reduction over the time
and batch dimensions.
"""
super().__init__()
self._batch_size = batch_size
self._num_bits = num_bits
self._min_length = min_length
self._max_length = max_length
self._min_repeats = min_repeats
self._max_repeats = max_repeats
self._norm_max = norm_max
self._log_prob_in_bits = log_prob_in_bits
self._time_average_cost = time_average_cost
def _normalise(self, val):
return val / self._norm_max
def _unnormalise(self, val):
return val * self._norm_max
@property
def time_average_cost(self):
return self._time_average_cost
@property
def log_prob_in_bits(self):
return self._log_prob_in_bits
@property
def num_bits(self):
"""The dimensionality of each random binary vector in a pattern."""
return self._num_bits
@property
def target_size(self):
"""The dimensionality of the target tensor."""
return self._num_bits + 1
@property
def batch_size(self):
return self._batch_size
def __call__(self):
"""Implements build method which adds ops to graph."""
# short-hand for private fields.
min_length, max_length = self._min_length, self._max_length
min_reps, max_reps = self._min_repeats, self._max_repeats
num_bits = self.num_bits
batch_size = self.batch_size
# We reserve one dimension for the num-repeats and one for the start-marker.
full_obs_size = num_bits + 2
# We reserve one target dimension for the end-marker.
full_targ_size = num_bits + 1
start_end_flag_idx = full_obs_size - 2
num_repeats_channel_idx = full_obs_size - 1
# Samples each batch index's sequence length and the number of repeats.
sub_seq_length_batch = torch.randint(
low=min_length, high=max_length + 1, size=[batch_size], dtype=torch.int32
)
num_repeats_batch = torch.randint(
low=min_reps, high=max_reps + 1, size=[batch_size], dtype=torch.int32
)
# Pads all the batches to have the same total sequence length.
total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3
max_length_batch = total_length_batch.max()
residual_length_batch = max_length_batch - total_length_batch
obs_batch_shape = [max_length_batch, batch_size, full_obs_size]
targ_batch_shape = [max_length_batch, batch_size, full_targ_size]
mask_batch_trans_shape = [batch_size, max_length_batch]
obs_tensors = []
targ_tensors = []
mask_tensors = []
# Generates patterns for each batch element independently.
for batch_index in range(batch_size):
sub_seq_len = sub_seq_length_batch[batch_index]
num_reps = num_repeats_batch[batch_index]
# The observation pattern is a sequence of random binary vectors.
obs_pattern_shape = [sub_seq_len, num_bits]
obs_pattern = torch.randint(low=0, high=2, size=obs_pattern_shape).to(
torch.float32
)
# The target pattern is the observation pattern repeated n times.
# Some reshaping is required to accomplish the tiling.
targ_pattern_shape = [sub_seq_len * num_reps, num_bits]
flat_obs_pattern = obs_pattern.reshape([-1])
flat_targ_pattern = torch.tile(flat_obs_pattern, [num_reps])
targ_pattern = flat_targ_pattern.reshape(targ_pattern_shape)
# Expand the obs_pattern to have two extra channels for flags.
# Concatenate start flag and num_reps flag to the sequence.
obs_flag_channel_pad = torch.zeros([sub_seq_len, 2])
obs_start_flag = F.one_hot(
torch.tensor([start_end_flag_idx]), num_classes=full_obs_size
).to(torch.float32)
num_reps_flag = F.one_hot(
torch.tensor([num_repeats_channel_idx]), num_classes=full_obs_size
).to(torch.float32)
num_reps_flag *= self._normalise(num_reps)
# note the concatenation dimensions.
obs = torch.concat([obs_pattern, obs_flag_channel_pad], 1)
obs = torch.concat([obs_start_flag, obs], 0)
obs = torch.concat([obs, num_reps_flag], 0)
# Now do the same for the targ_pattern (it only has one extra channel).
targ_flag_channel_pad = torch.zeros([sub_seq_len * num_reps, 1])
targ_end_flag = F.one_hot(
torch.tensor([start_end_flag_idx]), num_classes=full_targ_size
).to(torch.float32)
targ = torch.concat([targ_pattern, targ_flag_channel_pad], 1)
targ = torch.concat([targ, targ_end_flag], 0)
# Concatenate zeros at end of obs and beginning of targ.
# This aligns them s.t. the target begins as soon as the obs ends.
obs_end_pad = torch.zeros([sub_seq_len * num_reps + 1, full_obs_size])
targ_start_pad = torch.zeros([sub_seq_len + 2, full_targ_size])
# The mask is zero during the obs and one during the targ.
mask_off = torch.zeros([sub_seq_len + 2])
mask_on = torch.ones([sub_seq_len * num_reps + 1])
obs = torch.concat([obs, obs_end_pad], 0)
targ = torch.concat([targ_start_pad, targ], 0)
mask = torch.concat([mask_off, mask_on], 0)
obs_tensors.append(obs)
targ_tensors.append(targ)
mask_tensors.append(mask)
# End the loop over batch index.
# Compute how much zero padding is needed to make tensors sequences
# the same length for all batch elements.
residual_obs_pad = [
torch.zeros([residual_length_batch[i], full_obs_size])
for i in range(batch_size)
]
residual_targ_pad = [
torch.zeros([residual_length_batch[i], full_targ_size])
for i in range(batch_size)
]
residual_mask_pad = [
torch.zeros([residual_length_batch[i]]) for i in range(batch_size)
]
# Concatenate the pad to each batch element.
obs_tensors = [
torch.concat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad)
]
targ_tensors = [
torch.concat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad)
]
mask_tensors = [
torch.concat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad)
]
# Concatenate each batch element into a single tensor.
obs = torch.concat(obs_tensors, 1).reshape(obs_batch_shape)
targ = torch.concat(targ_tensors, 1).reshape(targ_batch_shape)
mask = torch.concat(mask_tensors, 0).reshape(mask_batch_trans_shape).T
return DatasetTensors(obs, targ, mask)
def cost(self, logits, targ, mask):
return masked_sigmoid_cross_entropy(
logits,
targ,
mask,
time_average=self.time_average_cost,
log_prob_in_bits=self.log_prob_in_bits,
)
def to_human_readable(self, data, model_output=None, whole_batch=False):
obs = data.observations
unnormalised_num_reps_flag = self._unnormalise(obs[:, :, -1:]).round()
obs = torch.cat([obs[:, :, :-1], unnormalised_num_reps_flag], dim=2)
data = DatasetTensors(
observations=obs,
target=data.target,
mask=data.mask,
)
return bitstring_readable(data, self.batch_size, model_output, whole_batch)

1
requirements.txt Normal file

@ -0,0 +1 @@
torch >= 1.10

0
tests/__init__.py Normal file

155
tests/access_test.py Normal file

@ -0,0 +1,155 @@
import torch
from dnc.access import MemoryAccess, AccessState
from dnc.addressing import TemporalLinkageState
from .util import one_hot
BATCH_SIZE = 2
MEMORY_SIZE = 20
WORD_SIZE = 6
NUM_READS = 2
NUM_WRITES = 3
TIME_STEPS = 4
INPUT_SIZE = 10
def test_memory_access_build_and_train():
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
inputs = torch.randn((TIME_STEPS, BATCH_SIZE, INPUT_SIZE))
outputs = []
state = None
for input in inputs:
output, state = module(input, state)
outputs.append(output)
outputs = torch.stack(outputs, dim=0)
targets = torch.rand((TIME_STEPS, BATCH_SIZE, NUM_READS, WORD_SIZE))
loss = (outputs - targets).square().mean()
optim = torch.optim.SGD(module.parameters(), lr=1)
optim.zero_grad()
loss.backward()
optim.step()
def test_memory_access_valid_read_mode():
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
inputs = module._read_inputs(torch.randn((BATCH_SIZE, INPUT_SIZE)))
# Check that the read modes for each read head constitute a probability
# distribution.
torch.testing.assert_close(
inputs["read_mode"].sum(2), torch.ones([BATCH_SIZE, NUM_READS])
)
assert torch.all(inputs["read_mode"] >= 0)
def test_memory_access_write_weights():
memory = 10 * (torch.rand((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE)) - 0.5)
usage = torch.rand((BATCH_SIZE, MEMORY_SIZE))
allocation_gate = torch.rand((BATCH_SIZE, NUM_WRITES))
write_gate = torch.rand((BATCH_SIZE, NUM_WRITES))
write_content_keys = torch.rand((BATCH_SIZE, NUM_WRITES, WORD_SIZE))
write_content_strengths = torch.rand((BATCH_SIZE, NUM_WRITES))
# Check that turning on allocation gate fully brings the write gate to
# the allocation weighting (which we will control by controlling the usage).
usage[:, 3] = 0
allocation_gate[:, 0] = 1
write_gate[:, 0] = 1
inputs = {
"allocation_gate": allocation_gate,
"write_gate": write_gate,
"write_content_keys": write_content_keys,
"write_content_strengths": write_content_strengths,
}
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
weights = module._write_weights(inputs, memory, usage)
# Check the weights sum to their target gating.
torch.testing.assert_close(weights.sum(dim=2), write_gate, atol=5e-2, rtol=0)
# Check that we fully allocated to the third row.
weights_0_0_target = one_hot(MEMORY_SIZE, 3, dtype=torch.float32)
torch.testing.assert_close(weights[0, 0], weights_0_0_target, atol=1e-3, rtol=0)
def test_memory_access_read_weights():
memory = 10 * (torch.rand((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE)) - 0.5)
prev_read_weights = torch.rand((BATCH_SIZE, NUM_READS, MEMORY_SIZE))
prev_read_weights /= prev_read_weights.sum(2, keepdim=True) + 1
link = torch.rand((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE))
# Row and column sums should be at most 1:
link /= torch.maximum(link.sum(2, keepdim=True), torch.tensor(1))
link /= torch.maximum(link.sum(3, keepdim=True), torch.tensor(1))
# We query the memory on the third location in memory, and select a large
# strength on the query. Then we select a content-based read-mode.
read_content_keys = torch.rand((BATCH_SIZE, NUM_READS, WORD_SIZE))
read_content_keys[0, 0] = memory[0, 3]
read_content_strengths = torch.full(
(BATCH_SIZE, NUM_READS), 100.0, dtype=torch.float64
)
read_mode = torch.rand((BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES))
read_mode[0, 0, :] = one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES)
inputs = {
"read_content_keys": read_content_keys,
"read_content_strengths": read_content_strengths,
"read_mode": read_mode,
}
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
read_weights = module._read_weights(inputs, memory, prev_read_weights, link)
# read_weights for batch 0, read head 0 should be memory location 3
ref = one_hot(MEMORY_SIZE, 3, dtype=torch.float64)
torch.testing.assert_close(read_weights[0, 0, :], ref, atol=1e-3, rtol=0)
def test_memory_access_gradients():
kwargs = {"dtype": torch.float64, "requires_grad": True}
inputs = torch.randn((BATCH_SIZE, INPUT_SIZE), **kwargs)
memory = torch.zeros((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE), **kwargs)
read_weights = torch.zeros((BATCH_SIZE, NUM_READS, MEMORY_SIZE), **kwargs)
precedence_weights = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE), **kwargs)
link = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE), **kwargs)
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
class Wrapper(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inputs, memory, read_weights, link, precedence_weights):
write_weights = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE))
usage = torch.zeros((BATCH_SIZE, MEMORY_SIZE))
state = AccessState(
memory=memory,
read_weights=read_weights,
write_weights=write_weights,
linkage=TemporalLinkageState(
link=link,
precedence_weights=precedence_weights,
),
usage=usage,
)
output, _ = self.module(inputs, state)
return output.sum()
module = Wrapper(module).to(torch.float64)
torch.autograd.gradcheck(
module,
(inputs, memory, read_weights, link, precedence_weights),
)

338
tests/addressing_test.py Normal file

@ -0,0 +1,338 @@
import numpy as np
import torch
import torch.nn.functional as F
from dnc.addressing import weighted_softmax, CosineWeights, TemporalLinkage, Freeness
from .util import one_hot
def _test_weighted_softmax(strength_op):
batch_size, num_heads, memory_size = 5, 3, 7
activations = torch.randn(batch_size, num_heads, memory_size)
weights = torch.ones((batch_size, num_heads))
observed = weighted_softmax(activations, weights, strength_op)
expected = torch.stack(
[
F.softmax(a * strength_op(w).unsqueeze(-1), dim=-1)
for a, w in zip(activations, weights)
],
dim=0,
)
torch.testing.assert_close(observed, expected)
def test_weighted_softmax_identity():
_test_weighted_softmax(lambda x: x)
def test_weighted_softmax_softplus():
_test_weighted_softmax(F.softplus)
def test_cosine_weights_shape():
batch_size, num_heads, memory_size, word_size = 5, 3, 7, 2
module = CosineWeights(num_heads, word_size)
mem = torch.randn([batch_size, memory_size, word_size])
keys = torch.randn([batch_size, num_heads, word_size])
strengths = torch.randn([batch_size, num_heads])
weights = module(mem, keys, strengths)
assert weights.shape == torch.Size([batch_size, num_heads, memory_size])
def test_cosine_weights_values():
batch_size, num_heads, memory_size, word_size = 5, 4, 10, 2
mem = torch.randn((batch_size, memory_size, word_size))
mem[0, 0, 0] = 1
mem[0, 0, 1] = 2
mem[0, 1, 0] = 3
mem[0, 1, 1] = 4
mem[0, 2, 0] = 5
mem[0, 2, 1] = 6
keys = torch.randn((batch_size, num_heads, word_size))
keys[0, 0, 0] = 5
keys[0, 0, 1] = 6
keys[0, 1, 0] = 1
keys[0, 1, 1] = 2
keys[0, 2, 0] = 5
keys[0, 2, 1] = 6
keys[0, 3, 0] = 3
keys[0, 3, 1] = 4
strengths = torch.randn((batch_size, num_heads))
module = CosineWeights(num_heads, word_size)
result = module(mem, keys, strengths)
# Manually checks results.
strengths_softplus = np.log(1 + np.exp(strengths.numpy()))
similarity = np.zeros((memory_size))
for b in range(batch_size):
for h in range(num_heads):
key = keys[b, h]
key_norm = np.linalg.norm(key)
for m in range(memory_size):
row = mem[b, m]
similarity[m] = np.dot(key, row) / (key_norm * np.linalg.norm(row))
similarity = np.exp(similarity * strengths_softplus[b, h])
similarity /= similarity.sum()
ref = torch.from_numpy(similarity).to(dtype=torch.float32)
torch.testing.assert_close(result[b, h], ref, atol=1e-4, rtol=1e-4)
def test_cosine_weights_divide_by_zero():
batch_size, num_heads, memory_size, word_size = 5, 4, 10, 2
module = CosineWeights(num_heads, word_size)
keys = torch.randn([batch_size, num_heads, word_size], requires_grad=True)
strengths = torch.randn([batch_size, num_heads], requires_grad=True)
# First row of memory is non-zero to concentrate attention on this location.
# Remaining rows are all zero.
mem = torch.zeros([batch_size, memory_size, word_size])
mem[:, 0, :] = 1
mem.requires_grad = True
output = module(mem, keys, strengths)
output.sum().backward()
assert torch.all(~output.isnan())
assert torch.all(~mem.grad.isnan())
assert torch.all(~keys.grad.isnan())
assert torch.all(~strengths.grad.isnan())
def test_temporal_linkage():
batch_size, memory_size, num_reads, num_writes = 7, 4, 11, 5
module = TemporalLinkage(memory_size=memory_size, num_writes=num_writes)
state = None
num_steps = 5
for i in range(num_steps):
write_weights = torch.rand([batch_size, num_writes, memory_size])
write_weights /= write_weights.sum(2, keepdim=True) + 1
# Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1
if i == num_steps - 2:
write_weights[0, 0, :] = one_hot(memory_size, 0)
write_weights[0, 1, :] = one_hot(memory_size, 3)
elif i == num_steps - 1:
write_weights[0, 0, :] = one_hot(memory_size, 1)
write_weights[0, 1, :] = one_hot(memory_size, 2)
state = module(write_weights, state)
# link should be bounded in range [0, 1]
assert torch.all(state.link >= 0) and torch.all(state.link <= 1)
# link diagonal should be zero
torch.testing.assert_close(
state.link[:, :, range(memory_size), range(memory_size)],
torch.zeros([batch_size, num_writes, memory_size]),
)
# link rows and columns should sum to at most 1
assert torch.all(state.link.sum(2) <= 1)
assert torch.all(state.link.sum(3) <= 1)
# records our transitions in batch 0: head 0: 0->1, and head 1: 3->2
torch.testing.assert_close(
state.link[0, 0, :, 0], one_hot(memory_size, 1, dtype=torch.float32)
)
torch.testing.assert_close(
state.link[0, 1, :, 3], one_hot(memory_size, 2, dtype=torch.float32)
)
# Now test calculation of forward and backward read weights
prev_read_weights = torch.randn((batch_size, num_reads, memory_size))
prev_read_weights[0, 5, :] = one_hot(memory_size, 0) # read 5, posn 0
prev_read_weights[0, 6, :] = one_hot(memory_size, 2) # read 6, posn 2
forward_read_weights = module.directional_read_weights(
state.link, prev_read_weights, is_forward=True
)
backward_read_weights = module.directional_read_weights(
state.link, prev_read_weights, is_forward=False
)
# Check directional weights calculated correctly.
torch.testing.assert_close(
forward_read_weights[0, 5, 0, :], # read=5, write=0
one_hot(memory_size, 1, dtype=torch.float32),
)
torch.testing.assert_close(
backward_read_weights[0, 6, 1, :], # read=6, write=1
one_hot(memory_size, 3, dtype=torch.float32),
)
def test_temporal_linkage_precedence_weights():
batch_size, memory_size, num_writes = 7, 3, 5
module = TemporalLinkage(memory_size=memory_size, num_writes=num_writes)
prev_precedence_weights = torch.rand(batch_size, num_writes, memory_size)
write_weights = torch.rand(batch_size, num_writes, memory_size)
# These should sum to at most 1 for each write head in each batch.
write_weights /= write_weights.sum(2, keepdim=True) + 1
prev_precedence_weights /= prev_precedence_weights.sum(2, keepdim=True) + 1
write_weights[0, 1, :] = 0 # batch 0 head 1: no writing
write_weights[1, 2, :] /= write_weights[1, 2, :].sum() # b1 h2: all writing
precedence_weights = module._precedence_weights(
prev_precedence_weights=prev_precedence_weights, write_weights=write_weights
)
# precedence weights should be bounded in range [0, 1]
assert torch.all(0 <= precedence_weights)
assert torch.all(precedence_weights <= 1)
# no writing in batch 0, head 1
torch.testing.assert_close(
precedence_weights[0, 1, :], prev_precedence_weights[0, 1, :]
)
# all writing in batch 1, head 2
torch.testing.assert_close(precedence_weights[1, 2, :], write_weights[1, 2, :])
def test_freeness():
batch_size, memory_size, num_reads, num_writes = 5, 11, 3, 7
module = Freeness(memory_size)
free_gate = torch.rand((batch_size, num_reads))
# Produce read weights that sum to 1 for each batch and head.
prev_read_weights = torch.rand((batch_size, num_reads, memory_size))
prev_read_weights[1, :, 3] = 0
prev_read_weights /= prev_read_weights.sum(2, keepdim=True)
prev_write_weights = torch.rand((batch_size, num_writes, memory_size))
prev_write_weights /= prev_write_weights.sum(2, keepdim=True)
prev_usage = torch.rand((batch_size, memory_size))
# Add some special values that allows us to test the behaviour:
prev_write_weights[1, 2, 3] = 1
prev_read_weights[2, 0, 4] = 1
free_gate[2, 0] = 1
usage = module(prev_write_weights, free_gate, prev_read_weights, prev_usage)
# Check all usages are between 0 and 1.
assert torch.all(0 <= usage)
assert torch.all(usage <= 1)
# Check that the full write at batch 1, position 3 makes it fully used.
assert usage[1][3] == 1
# Check that the full free at batch 2, position 4 makes it fully free.
assert usage[2][4] == 0
def test_freeness_write_allocation_weights():
batch_size, memory_size, num_writes = 7, 23, 5
module = Freeness(memory_size)
usage = torch.rand((batch_size, memory_size))
write_gates = torch.rand((batch_size, num_writes))
# Turn off gates for heads 1 and 3 in batch 0. This doesn't scale down the
# weighting, but it means that the usage doesn't change, so we should get
# the same allocation weightings for: (1, 2) and (3, 4) (but all others
# being different).
write_gates[0, 1] = 0
write_gates[0, 3] = 0
# and turn heads 0 and 2 on for full effect.
write_gates[0, 0] = 1
write_gates[0, 2] = 1
# In batch 1, make one of the usages 0 and another almost 0, so that these
# entries get most of the allocation weights for the first and second heads.
usage[1] = usage[1] * 0.9 + 0.1 # make sure all entries are in [0.1, 1]
usage[1][4] = 0 # write head 0 should get allocated to position 4
usage[1][3] = 1e-4 # write head 1 should get allocated to position 3
write_gates[1, 0] = 1 # write head 0 fully on
write_gates[1, 1] = 1 # write head 1 fully on
weights = module.write_allocation_weights(
usage=usage, write_gates=write_gates, num_writes=num_writes
)
assert torch.all(0 <= weights)
assert torch.all(weights <= 1)
# Check that weights sum to close to 1
torch.testing.assert_close(
weights.sum(dim=2), torch.ones((batch_size, num_writes)), atol=1e-3, rtol=0
)
# Check the same / different allocation weight pairs as described above.
assert (weights[0, 0, :] - weights[0, 1, :]).abs().max() > 0.1
torch.testing.assert_close(weights[0, 1, :], weights[0, 2, :])
assert (weights[0, 2, :] - weights[0, 3, :]).abs().max() > 0.1
torch.testing.assert_close(weights[0, 3, :], weights[0, 4, :])
torch.testing.assert_close(
weights[1][0], one_hot(memory_size, 4, dtype=torch.float32), atol=1e-3, rtol=0
)
torch.testing.assert_close(
weights[1][1], one_hot(memory_size, 3, dtype=torch.float32), atol=1e-3, rtol=0
)
def test_freeness_write_allocation_weights_gradient():
batch_size, memory_size, num_writes = 7, 5, 3
module = Freeness(memory_size).to(torch.float64)
usage = torch.rand(
(batch_size, memory_size), dtype=torch.float64, requires_grad=True
)
write_gates = torch.rand(
(batch_size, num_writes), dtype=torch.float64, requires_grad=True
)
def func(usage, write_gates):
return module.write_allocation_weights(usage, write_gates, num_writes)
torch.autograd.gradcheck(func, (usage, write_gates))
def test_freeness_allocation():
batch_size, memory_size = 7, 13
usage = torch.rand((batch_size, memory_size))
module = Freeness(memory_size)
allocation = module._allocation(usage)
# 1. Test that max allocation goes to min usage, and vice versa.
assert torch.all(usage.argmin(dim=1) == allocation.argmax(dim=1))
assert torch.all(usage.argmax(dim=1) == allocation.argmin(dim=1))
# 2. Test that allocations sum to almost 1.
torch.testing.assert_close(
allocation.sum(dim=1), torch.ones(batch_size), atol=0.01, rtol=0
)
def test_freeness_allocation_gradient():
batch_size, memory_size = 1, 5
usage = torch.rand(
(batch_size, memory_size), dtype=torch.float64, requires_grad=True
)
module = Freeness(memory_size).to(torch.float64)
torch.autograd.gradcheck(module._allocation, (usage,))

42
tests/dnc_test.py Normal file

@ -0,0 +1,42 @@
import torch
from dnc.dnc import DNC
def test_dnc():
"""smoke test"""
memory_size = 16
word_size = 16
num_reads = 4
num_writes = 1
clip_value = 20
input_size = 4
hidden_size = 64
output_size = input_size
batch_size = 16
time_steps = 64
access_config = {
"memory_size": memory_size,
"word_size": word_size,
"num_reads": num_reads,
"num_writes": num_writes,
}
controller_config = {
"input_size": input_size + num_reads * word_size,
"hidden_size": hidden_size,
}
dnc = DNC(
access_config=access_config,
controller_config=controller_config,
output_size=output_size,
clip_value=clip_value,
)
inputs = torch.randn((time_steps, batch_size, input_size))
state = None
for input in inputs:
output, state = dnc(input, state)

9
tests/util.py Normal file

@ -0,0 +1,9 @@
import torch
import torch.nn.functional as F
def one_hot(length, index, dtype=None):
val = F.one_hot(torch.tensor(index), num_classes=length)
if dtype is not None:
val = val.to(dtype=dtype)
return val

205
train.py Normal file

@ -0,0 +1,205 @@
"""Example script to train the DNC on a repeated copy task."""
import os
import argparse
import logging
import torch
from dnc.repeat_copy import RepeatCopy
from dnc.dnc import DNC
_LG = logging.getLogger(__name__)
def _main():
args = _parse_args()
logging.basicConfig(level=logging.INFO, format="%(asctime)s: %(message)s")
dataset = RepeatCopy(
args.num_bits,
args.batch_size,
args.min_length,
args.max_length,
args.min_repeats,
args.max_repeats,
)
dnc = DNC(
access_config={
"memory_size": args.memory_size,
"word_size": args.word_size,
"num_reads": args.num_read_heads,
"num_writes": args.num_write_heads,
},
controller_config={
"input_size": args.num_bits + 2 + args.num_read_heads * args.word_size,
"hidden_size": args.hidden_size,
},
output_size=dataset.target_size,
clip_value=args.clip_value,
)
optimizer = torch.optim.RMSprop(dnc.parameters(), lr=args.lr, eps=args.eps)
_run_train_loop(
dnc,
dataset,
optimizer,
args.num_training_iterations,
args.report_interval,
args.checkpoint_interval,
args.checkpoint_dir,
)
def _run_train_loop(
dnc,
dataset,
optimizer,
num_training,
report_interval,
checkpoint_interval,
checkpoint_dir,
):
total_loss = 0
for i in range(num_training):
batch = dataset()
state = None
outputs = []
for inputs in batch.observations:
output, state = dnc(inputs, state)
outputs.append(output)
outputs = torch.stack(outputs, 0)
loss = dataset.cost(outputs, batch.target, batch.mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
if (i + 1) % report_interval == 0:
outputs = torch.round(batch.mask.unsqueeze(-1) * torch.sigmoid(outputs))
dataset_string = dataset.to_human_readable(batch, outputs)
_LG.info(f"{i}: Avg training loss {total_loss / report_interval}")
_LG.info(dataset_string)
total_loss = 0
if checkpoint_interval is not None and (i + 1) % checkpoint_interval == 0:
path = os.path.join(checkpoint_dir, "model.pt")
torch.save(dnc.state_dict(), path)
def _parse_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
description=__doc__,
)
model_opts = parser.add_argument_group("Model Parameters")
model_opts.add_argument(
"--hidden-size", type=int, default=64, help="Size of LSTM hidden layer."
)
model_opts.add_argument(
"--memory-size", type=int, default=16, help="The number of memory slots."
)
model_opts.add_argument(
"--word-size", type=int, default=16, help="The width of each memory slot."
)
model_opts.add_argument(
"--num-write-heads", type=int, default=1, help="Number of memory write heads."
)
model_opts.add_argument(
"--num-read-heads", type=int, default=4, help="Number of memory read heads."
)
model_opts.add_argument(
"--clip-value",
type=float,
default=20,
help="Maximum absolute value of controller and dnc outputs.",
)
optim_opts = parser.add_argument_group("Optimizer Parameters")
optim_opts.add_argument(
"--max-grad-norm", type=float, default=50, help="Gradient clipping norm limit."
)
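# Note: `--max-grad-norm` is parsed but gradient-norm clipping is not applied in `_run_train_loop`.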
optim_opts.add_argument(
"--learning-rate",
"--lr",
type=float,
default=1e-4,
dest="lr",
help="Optimizer learning rate.",
)
optim_opts.add_argument(
"--optimizer-epsilon",
type=float,
default=1e-10,
dest="eps",
help="Epsilon used for RMSProp optimizer.",
)
task_opts = parser.add_argument_group("Task Parameters")
task_opts.add_argument(
"--batch-size", type=int, default=16, help="Batch size for training"
)
task_opts.add_argument(
"--num-bits", type=int, default=4, help="Dimensionality of each vector to copy"
)
task_opts.add_argument(
"--min-length",
type=int,
default=1,
help="Lower limit on number of vectors in the observation pattern to copy",
)
task_opts.add_argument(
"--max-length",
type=int,
default=2,
help="Upper limit on number of vectors in the observation pattern to copy",
)
task_opts.add_argument(
"--min-repeats",
type=int,
default=1,
help="Lower limit on number of copy repeats.",
)
task_opts.add_argument(
"--max-repeats",
type=int,
default=2,
help="Upper limit on number of copy repeats.",
)
train_opts = parser.add_argument_group("Training Options")
train_opts.add_argument(
"--num-training-iterations",
type=int,
default=100_000,
help="Number of iterations to train for.",
)
train_opts.add_argument(
"--report-interval",
type=int,
default=100,
help="Iterations between reports (samples, valid loss).",
)
train_opts.add_argument(
"--checkpoint-dir", default=None, help="Checkpointing directory."
)
train_opts.add_argument(
"--checkpoint-interval",
type=int,
default=None,
help="Checkpointing step interval.",
)
args = parser.parse_args()
if args.checkpoint_dir is None and args.checkpoint_interval is not None:
raise RuntimeError(
"`--checkpoint-interval` is provided but `--checkpoint-dir` is not provided."
)
if args.checkpoint_dir is not None and args.checkpoint_interval is None:
raise RuntimeError(
"`--checkpoint-dir` is provided but `--checkpoint-interval` is not provided."
)
return args
if __name__ == "__main__":
_main()