Initial commit
parent 7544f49f69
commit 1a581b8aff

README.md | 13
@@ -1,2 +1,11 @@
# dnc_pytorch
PyTorch Port of DeepMind's Differentiable Neural Computer
# Differentiable Neural Computer in PyTorch

This is a PyTorch port of DeepMind's [Differentiable Neural Computer (DNC)](https://github.com/deepmind/dnc).

The original code was written with TensorFlow v1, which is not straightforward to set up on modern tech stacks such as Apple Silicon and Python 3.8+.

This code requires only PyTorch >= 1.10.

The code structure and interfaces are kept almost the same as the original.

You can run the repeat-copy task with `python train.py`.
|
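For readers who want to see the modules below wired together, here is a minimal usage sketch. It is not the repository's `train.py` (that script is not part of this excerpt), and the memory and controller sizes are illustrative assumptions; the controller `input_size` follows from `dnc/dnc.py` below, which feeds the controller the observation concatenated with the flattened read words from the previous step.

```python
import torch

from dnc.dnc import DNC
from dnc.repeat_copy import RepeatCopy

# Illustrative hyper-parameters (assumptions, not values taken from this repo).
num_reads, word_size = 4, 16
dataset = RepeatCopy(num_bits=4, batch_size=16, min_length=1, max_length=2)

access_config = dict(memory_size=16, word_size=word_size,
                     num_reads=num_reads, num_writes=1)
# The LSTMCell controller sees the observation plus the flattened read words:
# (num_bits + 2) observation channels + num_reads * word_size.
controller_config = dict(input_size=dataset.num_bits + 2 + num_reads * word_size,
                         hidden_size=64)
model = DNC(access_config, controller_config, output_size=dataset.target_size)

batch = dataset()                      # time-major DatasetTensors
state, logits = None, []
for obs_t in batch.observations:       # unroll one time step at a time
    out_t, state = model(obs_t, state)
    logits.append(out_t)
logits = torch.stack(logits, dim=0)

loss = dataset.cost(logits, batch.target, batch.mask)
loss.backward()
```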
dnc/__init__.py | 0 (new file)

dnc/access.py | 353 (new file)
@@ -0,0 +1,353 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Tuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .addressing import (
|
||||
CosineWeights,
|
||||
Freeness,
|
||||
TemporalLinkage,
|
||||
TemporalLinkageState,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AccessState:
|
||||
memory: Tensor # [batch_size, memory_size, word_size]
|
||||
read_weights: Tensor # [batch_size, num_reads, memory_size]
|
||||
write_weights: Tensor # [batch_size, num_writes, memory_size]
|
||||
linkage: TemporalLinkageState
|
||||
# link: [batch_size, num_writes, memory_size, memory_size]
|
||||
# precedence_weights: [batch_size, num_writes, memory_size]
|
||||
usage: Tensor # [batch_size, memory_size]
|
||||
|
||||
|
||||
def _erase_and_write(memory, address, reset_weights, values):
|
||||
"""Module to erase and write in the external memory.
|
||||
|
||||
Erase operation:
|
||||
M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)
|
||||
|
||||
Add operation:
|
||||
M_t(i) = M_t'(i) + w_t(i) * a_t
|
||||
|
||||
where e are the reset_weights, w the write weights and a the values.
|
||||
|
||||
Args:
|
||||
memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
|
||||
address: 3-D tensor `[batch_size, num_writes, memory_size]`.
|
||||
reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`.
|
||||
values: 3-D tensor `[batch_size, num_writes, word_size]`.
|
||||
|
||||
Returns:
|
||||
3-D tensor of shape `[batch_size, memory_size, word_size]`.
|
||||
"""
|
||||
weighted_resets = address.unsqueeze(3) * reset_weights.unsqueeze(2)
|
||||
reset_gate = torch.prod(1 - weighted_resets, dim=1)
|
||||
return memory * reset_gate + torch.matmul(address.transpose(-1, -2), values)
|
||||
|
||||
|
||||
class Reshape(torch.nn.Module):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, input):
|
||||
return input.reshape(self.dims)
|
||||
|
||||
|
||||
class MemoryAccess(torch.nn.Module):
|
||||
"""Access module of the Differentiable Neural Computer.
|
||||
|
||||
This memory module supports multiple read and write heads. It makes use of:
|
||||
|
||||
* `addressing.TemporalLinkage` to track the temporal ordering of writes in
|
||||
memory for each write head.
|
||||
* `addressing.Freeness` for keeping track of memory usage, where
|
||||
usage increases when a memory location is written to, and decreases when
|
||||
memory that the controller says can be freed is read from.
|
||||
|
||||
Write-address selection is done by an interpolation between content-based
|
||||
lookup and using unused memory.
|
||||
|
||||
Read-address selection is done by an interpolation of content-based lookup
|
||||
and following the link graph in the forward or backwards read direction.
|
||||
"""
|
||||
|
||||
def __init__(self, memory_size=128, word_size=20, num_reads=1, num_writes=1):
|
||||
"""Creates a MemoryAccess module.
|
||||
|
||||
Args:
|
||||
memory_size: The number of memory slots (N in the DNC paper).
|
||||
word_size: The width of each memory slot (W in the DNC paper)
|
||||
num_reads: The number of read heads (R in the DNC paper).
|
||||
num_writes: The number of write heads (fixed at 1 in the paper).
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self._memory_size = memory_size
|
||||
self._word_size = word_size
|
||||
self._num_reads = num_reads
|
||||
self._num_writes = num_writes
|
||||
|
||||
self._write_content_weights_mod = CosineWeights(num_writes, word_size)
|
||||
self._read_content_weights_mod = CosineWeights(num_reads, word_size)
|
||||
|
||||
self._linkage = TemporalLinkage(memory_size, num_writes)
|
||||
self._freeness = Freeness(memory_size)
|
||||
|
||||
def _linear(first_dim, second_dim, pre_act=None, post_act=None):
|
||||
"""Returns a linear transformation of `inputs`, followed by a reshape."""
|
||||
mods = []
|
||||
mods.append(torch.nn.LazyLinear(first_dim * second_dim))
|
||||
if pre_act is not None:
|
||||
mods.append(pre_act)
|
||||
mods.append(Reshape([-1, first_dim, second_dim]))
|
||||
if post_act is not None:
|
||||
mods.append(post_act)
|
||||
return torch.nn.Sequential(*mods)
|
||||
|
||||
self._write_vectors = _linear(num_writes, word_size)
|
||||
self._erase_vectors = _linear(num_writes, word_size, pre_act=torch.nn.Sigmoid())
|
||||
self._free_gate = torch.nn.Sequential(
|
||||
torch.nn.LazyLinear(num_reads),
|
||||
torch.nn.Sigmoid(),
|
||||
)
|
||||
self._alloc_gate = torch.nn.Sequential(
|
||||
torch.nn.LazyLinear(num_writes),
|
||||
torch.nn.Sigmoid(),
|
||||
)
|
||||
self._write_gate = torch.nn.Sequential(
|
||||
torch.nn.LazyLinear(num_writes),
|
||||
torch.nn.Sigmoid(),
|
||||
)
|
||||
num_read_modes = 1 + 2 * num_writes
|
||||
self._read_mode = _linear(
|
||||
num_reads, num_read_modes, post_act=torch.nn.Softmax(dim=-1)
|
||||
)
|
||||
self._write_keys = _linear(num_writes, word_size)
|
||||
self._write_strengths = torch.nn.LazyLinear(num_writes)
|
||||
|
||||
self._read_keys = _linear(num_reads, word_size)
|
||||
self._read_strengths = torch.nn.LazyLinear(num_reads)
|
||||
|
||||
def _read_inputs(self, inputs: Tensor) -> Dict[str, Tensor]:
|
||||
"""Applies transformations to `inputs` to get control for this module."""
|
||||
# v_t^i - The vectors to write to memory, for each write head `i`.
|
||||
write_vectors = self._write_vectors(inputs)
|
||||
|
||||
# e_t^i - Amount to erase the memory by before writing, for each write head.
|
||||
erase_vectors = self._erase_vectors(inputs)
|
||||
|
||||
# f_t^j - Amount that the memory at the locations read from at the previous
|
||||
# time step can be declared unused, for each read head `j`.
|
||||
free_gate = self._free_gate(inputs)
|
||||
|
||||
# g_t^{a, i} - Interpolation between writing to unallocated memory and
|
||||
# content-based lookup, for each write head `i`. Note: `a` is simply used to
|
||||
# identify this gate with allocation vs writing (as defined below).
|
||||
allocation_gate = self._alloc_gate(inputs)
|
||||
|
||||
# g_t^{w, i} - Overall gating of write amount for each write head.
|
||||
write_gate = self._write_gate(inputs)
|
||||
|
||||
# \pi_t^j - Mixing between "backwards" and "forwards" positions (for
|
||||
# each write head), and content-based lookup, for each read head.
|
||||
read_mode = self._read_mode(inputs)
|
||||
|
||||
# Parameters for the (read / write) "weights by content matching" modules.
|
||||
write_keys = self._write_keys(inputs)
|
||||
write_strengths = self._write_strengths(inputs)
|
||||
|
||||
read_keys = self._read_keys(inputs)
|
||||
read_strengths = self._read_strengths(inputs)
|
||||
|
||||
result = {
|
||||
"read_content_keys": read_keys,
|
||||
"read_content_strengths": read_strengths,
|
||||
"write_content_keys": write_keys,
|
||||
"write_content_strengths": write_strengths,
|
||||
"write_vectors": write_vectors,
|
||||
"erase_vectors": erase_vectors,
|
||||
"free_gate": free_gate,
|
||||
"allocation_gate": allocation_gate,
|
||||
"write_gate": write_gate,
|
||||
"read_mode": read_mode,
|
||||
}
|
||||
return result
|
||||
|
||||
def _write_weights(
|
||||
self,
|
||||
inputs: Tensor,
|
||||
memory: Tensor,
|
||||
usage: Tensor,
|
||||
) -> Tensor:
|
||||
"""Calculates the memory locations to write to.
|
||||
|
||||
This uses a combination of content-based lookup and finding an unused
|
||||
location in memory, for each write head.
|
||||
|
||||
Args:
|
||||
inputs: Collection of inputs to the access module, including controls for
|
||||
how to choose memory writing, such as the content to look up and the
|
||||
weighting between content-based and allocation-based addressing.
|
||||
memory: A tensor of shape `[batch_size, memory_size, word_size]`
|
||||
containing the current memory contents.
|
||||
usage: Current memory usage, which is a tensor of shape `[batch_size,
|
||||
memory_size]`, used for allocation-based addressing.
|
||||
|
||||
Returns:
|
||||
tensor of shape `[batch_size, num_writes, memory_size]` indicating where
|
||||
to write to (if anywhere) for each write head.
|
||||
"""
|
||||
# c_t^{w, i} - The content-based weights for each write head.
|
||||
write_content_weights = self._write_content_weights_mod(
|
||||
memory, inputs["write_content_keys"], inputs["write_content_strengths"]
|
||||
)
|
||||
|
||||
# a_t^i - The allocation weights for each write head.
|
||||
write_allocation_weights = self._freeness.write_allocation_weights(
|
||||
usage=usage,
|
||||
write_gates=(inputs["allocation_gate"] * inputs["write_gate"]),
|
||||
num_writes=self._num_writes,
|
||||
)
|
||||
|
||||
# Expands gates over memory locations.
|
||||
allocation_gate = inputs["allocation_gate"].unsqueeze(-1)
|
||||
write_gate = inputs["write_gate"].unsqueeze(-1)
|
||||
|
||||
# w_t^{w, i} - The write weightings for each write head.
|
||||
return write_gate * (
|
||||
allocation_gate * write_allocation_weights
|
||||
+ (1 - allocation_gate) * write_content_weights
|
||||
)
|
||||
|
||||
def _read_weights(
|
||||
self,
|
||||
inputs: Tensor,
|
||||
memory: Tensor,
|
||||
prev_read_weights: Tensor,
|
||||
link: Tensor,
|
||||
) -> Tensor:
|
||||
"""Calculates read weights for each read head.
|
||||
|
||||
The read weights are a combination of following the link graphs in the
|
||||
forward or backward directions from the previous read position, and doing
|
||||
content-based lookup. The interpolation between these different modes is
|
||||
done by `inputs['read_mode']`.
|
||||
|
||||
Args:
|
||||
inputs: Controls for this access module. This contains the content-based
|
||||
keys to lookup, and the weightings for the different read modes.
|
||||
memory: A tensor of shape `[batch_size, memory_size, word_size]`
|
||||
containing the current memory contents to do content-based lookup.
|
||||
prev_read_weights: A tensor of shape `[batch_size, num_reads,
|
||||
memory_size]` containing the previous read locations.
|
||||
link: A tensor of shape `[batch_size, num_writes, memory_size,
|
||||
memory_size]` containing the temporal write transition graphs.
|
||||
|
||||
Returns:
|
||||
A tensor of shape `[batch_size, num_reads, memory_size]` containing the
|
||||
read weights for each read head.
|
||||
"""
|
||||
# c_t^{r, i} - The content weightings for each read head.
|
||||
content_weights = self._read_content_weights_mod(
|
||||
memory, inputs["read_content_keys"], inputs["read_content_strengths"]
|
||||
)
|
||||
|
||||
# Calculates f_t^i and b_t^i.
|
||||
forward_weights = self._linkage.directional_read_weights(
|
||||
link, prev_read_weights, is_forward=True
|
||||
)
|
||||
backward_weights = self._linkage.directional_read_weights(
|
||||
link, prev_read_weights, is_forward=False
|
||||
)
|
||||
|
||||
m = self._num_writes
|
||||
backward_mode = inputs["read_mode"][:, :, :m, None]
|
||||
forward_mode = inputs["read_mode"][:, :, m : 2 * m, None]
|
||||
content_mode = inputs["read_mode"][:, :, None, 2 * m]
|
||||
|
||||
read_weights = (
|
||||
content_mode * content_weights
|
||||
+ (forward_mode * forward_weights).sum(dim=2)
|
||||
+ (backward_mode * backward_weights).sum(dim=2)
|
||||
)
|
||||
return read_weights
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: Tensor,
|
||||
prev_state: Optional[AccessState] = None,
|
||||
) -> Tuple[Tensor, AccessState]:
|
||||
"""Connects the MemoryAccess module into the graph.
|
||||
|
||||
Args:
|
||||
inputs: tensor of shape `[batch_size, input_size]`. This is used to
|
||||
control this access module.
|
||||
prev_state: Instance of `AccessState` containing the previous state.
|
||||
|
||||
Returns:
|
||||
A tuple `(output, next_state)`, where `output` is a tensor of shape
|
||||
`[batch_size, num_reads, word_size]`, and `next_state` is the new
|
||||
`AccessState` dataclass at the current time t.
|
||||
"""
|
||||
if inputs.ndim != 2:
|
||||
raise ValueError("Expected `inputs` to be 2D. Found: {inputs.ndim}.")
|
||||
if prev_state is None:
|
||||
B, device = inputs.shape[0], inputs.device
|
||||
prev_state = AccessState(
|
||||
memory=torch.zeros(
|
||||
(B, self._memory_size, self._word_size), device=device
|
||||
),
|
||||
read_weights=torch.zeros(
|
||||
(B, self._num_reads, self._memory_size), device=device
|
||||
),
|
||||
write_weights=torch.zeros(
|
||||
(B, self._num_writes, self._memory_size), device=device
|
||||
),
|
||||
linkage=None,
|
||||
usage=torch.zeros((B, self._memory_size), device=device),
|
||||
)
|
||||
|
||||
inputs = self._read_inputs(inputs)
|
||||
|
||||
# Update usage using inputs['free_gate'] and previous read & write weights.
|
||||
usage = self._freeness(
|
||||
write_weights=prev_state.write_weights,
|
||||
free_gate=inputs["free_gate"],
|
||||
read_weights=prev_state.read_weights,
|
||||
prev_usage=prev_state.usage,
|
||||
)
|
||||
|
||||
# Write to memory.
|
||||
write_weights = self._write_weights(inputs, prev_state.memory, usage)
|
||||
memory = _erase_and_write(
|
||||
prev_state.memory,
|
||||
address=write_weights,
|
||||
reset_weights=inputs["erase_vectors"],
|
||||
values=inputs["write_vectors"],
|
||||
)
|
||||
|
||||
linkage_state = self._linkage(write_weights, prev_state.linkage)
|
||||
|
||||
# Read from memory.
|
||||
read_weights = self._read_weights(
|
||||
inputs,
|
||||
memory=memory,
|
||||
prev_read_weights=prev_state.read_weights,
|
||||
link=linkage_state.link,
|
||||
)
|
||||
read_words = torch.matmul(read_weights, memory)
|
||||
|
||||
return (
|
||||
read_words,
|
||||
AccessState(
|
||||
memory=memory,
|
||||
read_weights=read_weights,
|
||||
write_weights=write_weights,
|
||||
linkage=linkage_state,
|
||||
usage=usage,
|
||||
),
|
||||
)
|
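As a quick sanity check of the erase/add equations documented in `_erase_and_write` above, the sketch below (illustrative, not part of this commit) writes to a single row with a one-hot address and an all-ones erase vector: that row is wiped and replaced by the write vector, while every other row is left untouched.

```python
import torch

from dnc.access import _erase_and_write

B, N, W = 1, 4, 3                        # batch, memory_size, word_size
memory = torch.ones(B, N, W)
address = torch.zeros(B, 1, N)           # a single write head...
address[0, 0, 2] = 1.0                   # ...writing only to row 2
reset_weights = torch.ones(B, 1, W)      # e_t = 1: erase fully where we write
values = torch.full((B, 1, W), 5.0)      # a_t = 5

new_memory = _erase_and_write(memory, address, reset_weights, values)
assert torch.allclose(new_memory[0, 2], torch.full((W,), 5.0))  # row 2 rewritten
assert torch.allclose(new_memory[0, 0], torch.ones(W))          # other rows intact
```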
dnc/addressing.py | 368 (new file)
@@ -0,0 +1,368 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemporalLinkageState:
|
||||
link: Tensor # [batch_size, num_writes, memory_size, memory_size]
|
||||
precedence_weights: Tensor # [batch_size, num_writes, memory_size]
|
||||
|
||||
|
||||
_EPSILON = 1e-6
|
||||
|
||||
|
||||
def _vector_norms(m: Tensor) -> Tensor:
|
||||
norm = torch.sum(m * m, dim=2, keepdim=True)
|
||||
return torch.sqrt(norm + _EPSILON)
|
||||
|
||||
|
||||
def weighted_softmax(activations: Tensor, strengths: Tensor, strengths_op=F.softplus):
|
||||
"""Returns softmax over activations multiplied by positive strengths.
|
||||
|
||||
Args:
|
||||
activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of
|
||||
activations to be transformed. Softmax is taken over the last dimension.
|
||||
strengths: A tensor of shape `[batch_size, num_heads]` containing strengths to
|
||||
multiply by the activations prior to the softmax.
|
||||
strengths_op: An operation to transform strengths before softmax.
|
||||
|
||||
Returns:
|
||||
A tensor of same shape as `activations` with weighted softmax applied.
|
||||
"""
|
||||
transformed_strengths = strengths_op(strengths).unsqueeze(-1)
|
||||
sharp_activations = activations * transformed_strengths
|
||||
softmax = F.softmax(sharp_activations, dim=-1)
|
||||
return softmax
|
||||
|
||||
|
||||
class CosineWeights(torch.nn.Module):
|
||||
def __init__(self, num_heads, word_size, strength_op=F.softplus):
|
||||
"""
|
||||
Args:
|
||||
num_heads: number of memory heads.
|
||||
word_size: memory word size.
|
||||
strength_op: operation to apply to strengths (default softplus).
|
||||
"""
|
||||
super().__init__()
|
||||
self._num_heads = num_heads
|
||||
self._word_size = word_size
|
||||
self._strength_op = strength_op
|
||||
|
||||
def forward(self, memory: Tensor, keys: Tensor, strengths: Tensor) -> Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
memory: A 3-D tensor of shape `[batch_size, memory_size, word_size]`.
|
||||
keys: A 3-D tensor of shape `[batch_size, num_heads, word_size]`.
|
||||
strengths: A 2-D tensor of shape `[batch_size, num_heads]`.
|
||||
|
||||
Returns:
|
||||
Weights tensor of shape `[batch_size, num_heads, memory_size]`.
|
||||
"""
|
||||
dot = torch.matmul(keys, memory.transpose(-1, -2)) # <B, H, M>
|
||||
memory_norms = _vector_norms(memory) # <B, M, 1>
|
||||
key_norms = _vector_norms(keys) # <B, H, 1>
|
||||
norm = torch.matmul(key_norms, memory_norms.transpose(-1, -2)) # <B, H, M>
|
||||
|
||||
similarity = dot / (norm + _EPSILON)
|
||||
|
||||
return weighted_softmax(similarity, strengths, self._strength_op)
|
||||
|
||||
|
||||
class TemporalLinkage(torch.nn.Module):
|
||||
def __init__(self, memory_size, num_writes):
|
||||
super().__init__()
|
||||
self._memory_size = memory_size
|
||||
self._num_writes = num_writes
|
||||
|
||||
def _link(
|
||||
self,
|
||||
prev_link: Tensor,
|
||||
prev_precedence_weights: Tensor,
|
||||
write_weights: Tensor,
|
||||
) -> Tensor:
|
||||
"""Calculates the new link graphs.
|
||||
|
||||
For each write head, the link is a directed graph (represented by a matrix
|
||||
with entries in range [0, 1]) whose vertices are the memory locations, and
|
||||
an edge indicates temporal ordering of writes.
|
||||
|
||||
Args:
|
||||
prev_link: A tensor of shape `[batch_size, num_writes, memory_size,
|
||||
memory_size]` representing the previous link graphs for each write
|
||||
head.
|
||||
prev_precedence_weights: A tensor of shape `[batch_size, num_writes,
|
||||
memory_size]` which is the previous "aggregated" write weights for
|
||||
each write head.
|
||||
write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
|
||||
containing the new locations in memory written to.
|
||||
|
||||
Returns:
|
||||
A tensor of shape `[batch_size, num_writes, memory_size, memory_size]`
|
||||
containing the new link graphs for each write head.
|
||||
"""
|
||||
batch_size = prev_link.size(0)
|
||||
write_weights_i = write_weights.unsqueeze(3) # <B, W, M, 1>
|
||||
write_weights_j = write_weights.unsqueeze(2) # <B, W, 1, M>
|
||||
prev_precedence_weights_j = prev_precedence_weights.unsqueeze(2) # <B, W, 1, M>
|
||||
|
||||
prev_link_scale = 1 - write_weights_i - write_weights_j # <B, W, M, M>
|
||||
new_link = write_weights_i * prev_precedence_weights_j # <B, W, M, M>
|
||||
link = prev_link_scale * prev_link + new_link # <B, W, M, M>
|
||||
|
||||
# Return the link with the diagonal set to zero, to remove self-looping
|
||||
# edges.
|
||||
mask = (
|
||||
torch.eye(self._memory_size, device=prev_link.device)
|
||||
.repeat(batch_size, self._num_writes, 1, 1)
|
||||
.bool()
|
||||
)
|
||||
link[mask] = 0
|
||||
return link
|
||||
|
||||
def _precedence_weights(
|
||||
self,
|
||||
prev_precedence_weights: Tensor,
|
||||
write_weights: Tensor,
|
||||
) -> Tensor:
|
||||
"""Calculates the new precedence weights given the current write weights.
|
||||
|
||||
The precedence weights are the "aggregated write weights" for each write
|
||||
head, where write weights with sum close to zero will leave the precedence
|
||||
weights unchanged, but with sum close to one will replace the precedence
|
||||
weights.
|
||||
|
||||
Args:
|
||||
prev_precedence_weights: A tensor of shape `[batch_size, num_writes,
|
||||
memory_size]` containing the previous precedence weights.
|
||||
write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
|
||||
containing the new write weights.
|
||||
|
||||
Returns:
|
||||
A tensor of shape `[batch_size, num_writes, memory_size]` containing the
|
||||
new precedence weights.
|
||||
"""
|
||||
write_sum = write_weights.sum(dim=2, keepdim=True)
|
||||
return (1 - write_sum) * prev_precedence_weights + write_weights
|
||||
|
||||
def forward(
|
||||
self,
|
||||
write_weights: Tensor,
|
||||
prev_state: Optional[TemporalLinkageState] = None,
|
||||
) -> TemporalLinkageState:
|
||||
"""Calculate the updated linkage state given the write weights.
|
||||
|
||||
Args:
|
||||
write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
|
||||
containing the memory addresses of the different write heads.
|
||||
prev_state: `TemporalLinkageState` containing a tensor `link` of
|
||||
shape `[batch_size, num_writes, memory_size, memory_size]`, and a
|
||||
tensor `precedence_weights` of shape `[batch_size, num_writes,
|
||||
memory_size]` containing the aggregated history of recent writes.
|
||||
|
||||
Returns:
|
||||
A `TemporalLinkageState` tuple `next_state`, which contains the updated
|
||||
link and precedence weights.
|
||||
"""
|
||||
if write_weights.ndim != 3:
|
||||
raise ValueError(
|
||||
f"Expected `write_weights` to be 3D. Found: {write_weights.ndim}"
|
||||
)
|
||||
if (
|
||||
write_weights.size(1) != self._num_writes
|
||||
or write_weights.size(2) != self._memory_size
|
||||
):
|
||||
raise ValueError(
|
||||
"Expected the shape of `write_weights` to be "
|
||||
f"[batch, {self._num_writes}, {self._memory_size}]. "
|
||||
f"Found: {write_weights.shape}."
|
||||
)
|
||||
|
||||
if prev_state is None:
|
||||
B, W, M = write_weights.shape
|
||||
prev_state = TemporalLinkageState(
|
||||
link=torch.zeros((B, W, M, M), device=write_weights.device),
|
||||
precedence_weights=torch.zeros((B, W, M), device=write_weights.device),
|
||||
)
|
||||
|
||||
link = self._link(prev_state.link, prev_state.precedence_weights, write_weights)
|
||||
precedence_weights = self._precedence_weights(
|
||||
prev_state.precedence_weights, write_weights
|
||||
)
|
||||
return TemporalLinkageState(link=link, precedence_weights=precedence_weights)
|
||||
|
||||
def directional_read_weights(
|
||||
self,
|
||||
link: Tensor,
|
||||
prev_read_weights: Tensor,
|
||||
is_forward: bool,
|
||||
) -> Tensor:
|
||||
"""Calculates the forward or the backward read weights.
|
||||
|
||||
For each read head (at a given address), there are `num_writes` link graphs
|
||||
to follow. Thus this function computes a read address for each of the
|
||||
`num_reads * num_writes` pairs of read and write heads.
|
||||
|
||||
Args:
|
||||
link: tensor of shape `[batch_size, num_writes, memory_size,
|
||||
memory_size]` representing the link graphs L_t.
|
||||
prev_read_weights: tensor of shape `[batch_size, num_reads,
|
||||
memory_size]` containing the previous read weights w_{t-1}^r.
|
||||
is_forward: Boolean indicating whether to follow the "future" direction in
|
||||
the link graph (True) or the "past" direction (False).
|
||||
|
||||
Returns:
|
||||
tensor of shape `[batch_size, num_reads, num_writes, memory_size]`
|
||||
"""
|
||||
# <B, W, R, M>
|
||||
expanded_read_weights = torch.stack(
|
||||
[prev_read_weights for _ in range(self._num_writes)], dim=1
|
||||
)
|
||||
if is_forward:
|
||||
link = link.transpose(-1, -2)
|
||||
result = torch.matmul(expanded_read_weights, link) # <B, W, R, M>
|
||||
return result.permute((0, 2, 1, 3)) # <B, R, W, M>
|
||||
|
||||
|
||||
class Freeness(torch.nn.Module):
|
||||
def __init__(self, memory_size):
|
||||
super().__init__()
|
||||
self._memory_size = memory_size
|
||||
|
||||
def _usage_after_write(
|
||||
self,
|
||||
prev_usage: Tensor,
|
||||
write_weights: Tensor,
|
||||
) -> Tensor:
|
||||
"""Calcualtes the new usage after writing to memory.
|
||||
|
||||
Args:
|
||||
prev_usage: tensor of shape `[batch_size, memory_size]`.
|
||||
write_weights: tensor of shape `[batch_size, num_writes, memory_size]`.
|
||||
|
||||
Returns:
|
||||
New usage, a tensor of shape `[batch_size, memory_size]`.
|
||||
"""
|
||||
write_weights = 1 - torch.prod(1 - write_weights, 1)
|
||||
return prev_usage + (1 - prev_usage) * write_weights
|
||||
|
||||
def _usage_after_read(
|
||||
self, prev_usage: Tensor, free_gate: Tensor, read_weights: Tensor
|
||||
) -> Tensor:
|
||||
"""Calcualtes the new usage after reading and freeing from memory.
|
||||
|
||||
Args:
|
||||
prev_usage: tensor of shape `[batch_size, memory_size]`.
|
||||
free_gate: tensor of shape `[batch_size, num_reads]` with entries in the
|
||||
range [0, 1] indicating the amount that locations read from can be
|
||||
freed.
|
||||
read_weights: tensor of shape `[batch_size, num_reads, memory_size]`.
|
||||
|
||||
Returns:
|
||||
New usage, a tensor of shape `[batch_size, memory_size]`.
|
||||
"""
|
||||
free_gate = free_gate.unsqueeze(-1)
|
||||
free_read_weights = free_gate * read_weights
|
||||
phi = torch.prod(1 - free_read_weights, 1)
|
||||
return prev_usage * phi
|
||||
|
||||
def forward(
|
||||
self,
|
||||
write_weights: Tensor,
|
||||
free_gate: Tensor,
|
||||
read_weights: Tensor,
|
||||
prev_usage: Tensor,
|
||||
) -> Tensor:
|
||||
"""Calculates the new memory usage u_t.
|
||||
|
||||
Memory that was written to in the previous time step will have its usage
|
||||
increased; memory that was read from and the controller says can be "freed"
|
||||
will have its usage decreased.
|
||||
|
||||
Args:
|
||||
write_weights: tensor of shape `[batch_size, num_writes,
|
||||
memory_size]` giving write weights at previous time step.
|
||||
free_gate: tensor of shape `[batch_size, num_reads]` which indicates
|
||||
which read heads read memory that can now be freed.
|
||||
read_weights: tensor of shape `[batch_size, num_reads,
|
||||
memory_size]` giving read weights at previous time step.
|
||||
prev_usage: tensor of shape `[batch_size, memory_size]` giving
|
||||
usage u_{t - 1} at the previous time step, with entries in range
|
||||
[0, 1].
|
||||
|
||||
Returns:
|
||||
tensor of shape `[batch_size, memory_size]` representing updated memory
|
||||
usage.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
usage = self._usage_after_write(prev_usage, write_weights)
|
||||
usage = self._usage_after_read(usage, free_gate, read_weights)
|
||||
return usage
|
||||
|
||||
def _allocation(self, usage: Tensor) -> Tensor:
|
||||
"""Computes allocation by sorting `usage`.
|
||||
|
||||
This corresponds to the value a = a_t[\phi_t[j]] in the paper.
|
||||
|
||||
Args:
|
||||
usage: tensor of shape `[batch_size, memory_size]` indicating current
|
||||
memory usage. This is equal to u_t in the paper when we only have one
|
||||
write head, but for multiple write heads, one should update the usage
|
||||
while iterating through the write heads to take into account the
|
||||
allocation returned by this function.
|
||||
|
||||
Returns:
|
||||
Tensor of shape `[batch_size, memory_size]` corresponding to allocation.
|
||||
"""
|
||||
usage = _EPSILON + (1 - _EPSILON) * usage
|
||||
|
||||
nonusage = 1 - usage
|
||||
sorted_nonusage, indices = torch.topk(nonusage, k=self._memory_size)
|
||||
sorted_usage = 1 - sorted_nonusage
|
||||
|
||||
# emulate tf.cumprod(exclusive=True)
|
||||
sorted_usage = F.pad(sorted_usage, (1, 0), mode="constant", value=1)
|
||||
prod_sorted_usage = torch.cumprod(sorted_usage, dim=1)
|
||||
prod_sorted_usage = prod_sorted_usage[:, :-1]
|
||||
|
||||
sorted_allocation = sorted_nonusage * prod_sorted_usage
|
||||
inverse_indices = torch.argsort(indices)
|
||||
return torch.gather(sorted_allocation, 1, inverse_indices)
|
||||
|
||||
def write_allocation_weights(
|
||||
self,
|
||||
usage: Tensor,
|
||||
write_gates: Tensor,
|
||||
num_writes: int,
|
||||
) -> Tensor:
|
||||
"""Calculates freeness-based locations for writing to.
|
||||
|
||||
This finds unused memory by ranking the memory locations by usage, for each
|
||||
write head. (For more than one write head, we use a "simulated new usage"
|
||||
which takes into account the fact that the previous write head will increase
|
||||
the usage in that area of the memory.)
|
||||
|
||||
Args:
|
||||
usage: A tensor of shape `[batch_size, memory_size]` representing
|
||||
current memory usage.
|
||||
write_gates: A tensor of shape `[batch_size, num_writes]` with values in
|
||||
the range [0, 1] indicating how much each write head does writing
|
||||
based on the address returned here (and hence how much usage
|
||||
increases).
|
||||
num_writes: The number of write heads to calculate write weights for.
|
||||
|
||||
Returns:
|
||||
tensor of shape `[batch_size, num_writes, memory_size]` containing the
|
||||
freeness-based write locations. Note that this isn't scaled by
|
||||
`write_gate`; this scaling must be applied externally.
|
||||
"""
|
||||
write_gates = write_gates.unsqueeze(-1)
|
||||
allocation_weights = []
|
||||
for i in range(num_writes):
|
||||
allocation_weights.append(self._allocation(usage))
|
||||
usage = usage + (1 - usage) * write_gates[:, i, :] * allocation_weights[i]
|
||||
return torch.stack(allocation_weights, dim=1)
|
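To illustrate the freeness-based allocation above (the exclusive-cumprod trick in `Freeness._allocation`), the following sketch is not part of this commit and calls a private helper purely for demonstration: with one nearly free slot, almost all of the allocation weight lands on that slot.

```python
import torch

from dnc.addressing import Freeness

freeness = Freeness(memory_size=4)
usage = torch.tensor([[0.9, 0.9, 0.1, 0.9]])    # slot 2 is mostly unused

alloc = freeness._allocation(usage)             # shape [1, 4]

assert alloc.argmax(dim=1).item() == 2          # weight concentrates on slot 2
assert torch.all(alloc >= 0) and float(alloc.sum()) <= 1.0 + 1e-6
```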
dnc/dnc.py | 129 (new file)
@@ -0,0 +1,129 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .access import MemoryAccess, AccessState
|
||||
|
||||
|
||||
@dataclass
|
||||
class DNCState:
|
||||
access_output: Tensor # [batch_size, num_reads, word_size]
|
||||
access_state: AccessState
|
||||
# memory: Tensor [batch_size, memory_size, word_size]
|
||||
# read_weights: Tensor # [batch_size, num_reads, memory_size]
|
||||
# write_weights: Tensor # [batch_size, num_writes, memory_size]
|
||||
# linkage: TemporalLinkageState
|
||||
# link: [batch_size, num_writes, memory_size, memory_size]
|
||||
# precedence_weights: [batch_size, num_writes, memory_size]
|
||||
# usage: Tensor # [batch_size, memory_size]
|
||||
controller_state: Tuple[Tensor, Tensor]
|
||||
# h_n: [num_layers, batch_size, projection_size]
|
||||
# c_n: [num_layers, batch_size, hidden_size]
|
||||
|
||||
|
||||
class DNC(torch.nn.Module):
|
||||
"""DNC core module.
|
||||
|
||||
Contains controller and memory access module.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_config,
|
||||
controller_config,
|
||||
output_size,
|
||||
clip_value=None,
|
||||
):
|
||||
"""Initializes the DNC core.
|
||||
|
||||
Args:
|
||||
access_config: dictionary of access module configurations.
|
||||
controller_config: dictionary of controller (LSTM) module configurations.
|
||||
output_size: output dimension size of core.
|
||||
clip_value: clips controller and core output values to between
|
||||
`[-clip_value, clip_value]` if specified.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._controller = torch.nn.LSTMCell(**controller_config)
|
||||
self._access = MemoryAccess(**access_config)
|
||||
self._output = torch.nn.LazyLinear(output_size)
|
||||
if clip_value is None:
|
||||
self._clip = lambda x: x
|
||||
else:
|
||||
self._clip = lambda x: torch.clamp(x, min=-clip_value, max=clip_value)
|
||||
|
||||
def forward(self, inputs: Tensor, prev_state: Optional[DNCState] = None):
|
||||
"""Connects the DNC core into the graph.
|
||||
|
||||
Args:
|
||||
inputs: Tensor input.
|
||||
prev_state: A `DNCState` tuple containing the fields `access_output`,
|
||||
`access_state` and `controller_state`. `access_output` is a 3-D Tensor
|
||||
of shape `[batch_size, num_reads, word_size]` containing read words.
|
||||
`access_state` is a tuple of the access module's state, and
|
||||
`controller_state` is a tuple of controller module's state.
|
||||
|
||||
Returns:
|
||||
A tuple `(output, next_state)` where `output` is a tensor and `next_state`
|
||||
is a `DNCState` tuple containing the fields `access_output`,
|
||||
`access_state`, and `controller_state`.
|
||||
"""
|
||||
if inputs.ndim != 2:
|
||||
raise ValueError(f"Expected `inputs` to be 2D: Found {inputs.ndim}.")
|
||||
if prev_state is None:
|
||||
B, device = inputs.shape[0], inputs.device
|
||||
num_reads = self._access._num_reads
|
||||
word_size = self._access._word_size
|
||||
prev_state = DNCState(
|
||||
access_output=torch.zeros((B, num_reads, word_size), device=device),
|
||||
access_state=None,
|
||||
controller_state=None,
|
||||
)
|
||||
|
||||
def batch_flatten(x):
|
||||
return torch.reshape(x, [x.size(0), -1])
|
||||
|
||||
controller_input = torch.concat(
|
||||
[
|
||||
batch_flatten(inputs), # [batch_size, num_input_feats]
|
||||
batch_flatten(
|
||||
prev_state.access_output
|
||||
), # [batch_size, num_reads*word_size]
|
||||
],
|
||||
dim=1,
|
||||
) # [batch_size, num_input_feats + num_reads * word_size]
|
||||
|
||||
controller_state = self._controller(
|
||||
controller_input, prev_state.controller_state
|
||||
)
|
||||
controller_state = tuple(self._clip(t) for t in controller_state)
|
||||
controller_output = controller_state[0]
|
||||
|
||||
access_output, access_state = self._access(
|
||||
controller_output, prev_state.access_state
|
||||
)
|
||||
|
||||
output = torch.concat(
|
||||
[
|
||||
controller_output, # [batch_size, num_ctrl_feats]
|
||||
batch_flatten(access_output), # [batch_size, num_reads*word_size]
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
output = self._output(output)
|
||||
output = self._clip(output)
|
||||
|
||||
return (
|
||||
output,
|
||||
DNCState(
|
||||
access_output=access_output,
|
||||
access_state=access_state,
|
||||
controller_state=controller_state,
|
||||
),
|
||||
)
|
dnc/repeat_copy.py | 383 (new file)
@@ -0,0 +1,383 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetTensors:
|
||||
observations: Tensor
|
||||
target: Tensor
|
||||
mask: Tensor
|
||||
|
||||
|
||||
def masked_sigmoid_cross_entropy(
|
||||
logits,
|
||||
target,
|
||||
mask,
|
||||
time_average=False,
|
||||
log_prob_in_bits=False,
|
||||
):
|
||||
"""Adds ops to graph which compute the (scalar) NLL of the target sequence.
|
||||
|
||||
The logits parametrize independent bernoulli distributions per time-step and
|
||||
per batch element, and irrelevant time/batch elements are masked out by the
|
||||
mask tensor.
|
||||
|
||||
Args:
|
||||
logits: `Tensor` of activations for which sigmoid(`logits`) gives the
|
||||
bernoulli parameter.
|
||||
target: time-major `Tensor` of target.
|
||||
mask: time-major `Tensor` to be multiplied elementwise with cost T x B cost
|
||||
masking out irrelevant time-steps.
|
||||
time_average: optionally average over the time dimension (sum by default).
|
||||
log_prob_in_bits: iff True express log-probabilities in bits (default nats).
|
||||
|
||||
Returns:
|
||||
A `Tensor` representing the log-probability of the target.
|
||||
"""
|
||||
batch_size = logits.shape[1]
|
||||
|
||||
xent = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
|
||||
loss_time_batch = xent.sum(dim=2)
|
||||
loss_batch = (loss_time_batch * mask).sum(dim=0)
|
||||
|
||||
if time_average:
|
||||
mask_count = mask.sum(dim=0)
|
||||
loss_batch /= mask_count + torch.finfo(loss_batch.dtype).eps
|
||||
|
||||
loss = loss_batch.sum() / batch_size
|
||||
if log_prob_in_bits:
|
||||
loss /= torch.log(torch.tensor(2.0))
|
||||
return loss
|
||||
|
||||
|
||||
def bitstring_readable(data, batch_size, model_output=None, whole_batch=False):
|
||||
"""Produce a human readable representation of the sequences in data.
|
||||
|
||||
Args:
|
||||
data: data to be visualised
|
||||
batch_size: size of batch
|
||||
model_output: optional model output tensor to visualize alongside data.
|
||||
whole_batch: whether to visualise the whole batch. Only the first sample
|
||||
will be visualized if False
|
||||
|
||||
Returns:
|
||||
A string used to visualise the data batch
|
||||
"""
|
||||
|
||||
def _readable(datum):
|
||||
return "+" + " ".join([f"{int(x):d}" if x else "-" for x in datum]) + "+"
|
||||
|
||||
obs_batch = data.observations
|
||||
targ_batch = data.target
|
||||
|
||||
batch_strings = []
|
||||
for batch_index in range(batch_size if whole_batch else 1):
|
||||
obs = obs_batch[:, batch_index, :]
|
||||
targ = targ_batch[:, batch_index, :]
|
||||
|
||||
obs_channels = range(obs.shape[1])
|
||||
targ_channels = range(targ.shape[1])
|
||||
obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels]
|
||||
targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels]
|
||||
|
||||
readable_obs = "Observations:\n" + "\n".join(obs_channel_strings)
|
||||
readable_targ = "Targets:\n" + "\n".join(targ_channel_strings)
|
||||
strings = [readable_obs, readable_targ]
|
||||
|
||||
if model_output is not None:
|
||||
output = model_output[:, batch_index, :]
|
||||
output_strings = [_readable(output[:, i]) for i in targ_channels]
|
||||
strings.append("Model Output:\n" + "\n".join(output_strings))
|
||||
|
||||
batch_strings.append("\n\n".join(strings))
|
||||
|
||||
return "\n" + "\n\n\n\n".join(batch_strings)
|
||||
|
||||
|
||||
class RepeatCopy:
|
||||
"""Sequence data generator for the task of repeating a random binary pattern.
|
||||
|
||||
When called, an instance of this class will return a `DatasetTensors` tuple of tensors
|
||||
(obs, targ, mask), representing an input sequence, target sequence, and
|
||||
binary mask. The first two dimensions of each of these tensors
|
||||
represent sequence position and batch index respectively. The value in
|
||||
mask[t, b] is equal to 1 iff a prediction about targ[t, b, :] should be
|
||||
penalized and 0 otherwise.
|
||||
|
||||
For each realisation from this generator, the observation sequence is
|
||||
comprised of I.I.D. uniform-random binary vectors (and some flags).
|
||||
|
||||
The target sequence is comprised of this binary pattern repeated
|
||||
some number of times (and some flags). Before explaining in more detail,
|
||||
let's examine the setup pictorially for a single batch element:
|
||||
|
||||
```none
|
||||
Note: blank space represents 0.
|
||||
|
||||
time ------------------------------------------>
|
||||
|
||||
+-------------------------------+
|
||||
mask: |0000000001111111111111111111111|
|
||||
+-------------------------------+
|
||||
|
||||
+-------------------------------+
|
||||
target: | 1| 'end-marker' channel.
|
||||
| 101100110110011011001 |
|
||||
| 010101001010100101010 |
|
||||
+-------------------------------+
|
||||
|
||||
+-------------------------------+
|
||||
observation: | 1011001 |
|
||||
| 0101010 |
|
||||
|1 | 'start-marker' channel
|
||||
| 3 | 'num-repeats' channel.
|
||||
+-------------------------------+
|
||||
```
|
||||
|
||||
The length of the random pattern and the number of times it is repeated
|
||||
in the target are both discrete random variables distributed according to
|
||||
uniform distributions whose parameters are configured at construction time.
|
||||
|
||||
The obs sequence has two extra channels (components in the trailing dimension)
|
||||
which are used for flags. One channel is marked with a 1 at the first time
|
||||
step and is otherwise equal to 0. The other extra channel is zero until the
|
||||
binary pattern to be repeated ends. At this point, it contains an encoding of
|
||||
the number of times the observation pattern should be repeated. Rather than
|
||||
simply providing this integer number directly, it is normalised so that
|
||||
a neural network may have an easier time representing the number of
|
||||
repetitions internally. To allow a network to be readily evaluated on
|
||||
instances of this task with greater numbers of repetitions, the range with
|
||||
respect to which this encoding is normalised is also configurable by the user.
|
||||
|
||||
As in the diagram, the target sequence is offset to begin directly after the
|
||||
observation sequence; both sequences are padded with zeros to accomplish this,
|
||||
resulting in their lengths being equal. Additional padding is done at the end
|
||||
so that all sequences in a minibatch represent tensors with the same shape.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_bits=6,
|
||||
batch_size=1,
|
||||
min_length=1,
|
||||
max_length=1,
|
||||
min_repeats=1,
|
||||
max_repeats=2,
|
||||
norm_max=10,
|
||||
log_prob_in_bits=False,
|
||||
time_average_cost=False,
|
||||
):
|
||||
"""Creates an instance of RepeatCopy task.
|
||||
|
||||
Args:
|
||||
|
||||
num_bits: The dimensionality of each random binary vector.
|
||||
batch_size: Minibatch size per realization.
|
||||
min_length: Lower limit on number of random binary vectors in the
|
||||
observation pattern.
|
||||
max_length: Upper limit on number of random binary vectors in the
|
||||
observation pattern.
|
||||
min_repeats: Lower limit on number of times the observation pattern
|
||||
is repeated in targ.
|
||||
max_repeats: Upper limit on number of times the observation pattern
|
||||
is repeated in targ.
|
||||
norm_max: Upper limit on uniform distribution w.r.t which the encoding
|
||||
of the number of repetitions presented in the observation sequence
|
||||
is normalised.
|
||||
log_prob_in_bits: By default, log probabilities are expressed in units of
|
||||
nats. If true, express log probabilities in bits.
|
||||
time_average_cost: If true, the cost at each time step will be
|
||||
divided by the true sequence length (the number of non-masked time
|
||||
steps in each sequence) before any subsequent reduction over the time
|
||||
and batch dimensions.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._batch_size = batch_size
|
||||
self._num_bits = num_bits
|
||||
self._min_length = min_length
|
||||
self._max_length = max_length
|
||||
self._min_repeats = min_repeats
|
||||
self._max_repeats = max_repeats
|
||||
self._norm_max = norm_max
|
||||
self._log_prob_in_bits = log_prob_in_bits
|
||||
self._time_average_cost = time_average_cost
|
||||
|
||||
def _normalise(self, val):
|
||||
return val / self._norm_max
|
||||
|
||||
def _unnormalise(self, val):
|
||||
return val * self._norm_max
|
||||
|
||||
@property
|
||||
def time_average_cost(self):
|
||||
return self._time_average_cost
|
||||
|
||||
@property
|
||||
def log_prob_in_bits(self):
|
||||
return self._log_prob_in_bits
|
||||
|
||||
@property
|
||||
def num_bits(self):
|
||||
"""The dimensionality of each random binary vector in a pattern."""
|
||||
return self._num_bits
|
||||
|
||||
@property
|
||||
def target_size(self):
|
||||
"""The dimensionality of the target tensor."""
|
||||
return self._num_bits + 1
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def __call__(self):
|
||||
"""Implements build method which adds ops to graph."""
|
||||
|
||||
# short-hand for private fields.
|
||||
min_length, max_length = self._min_length, self._max_length
|
||||
min_reps, max_reps = self._min_repeats, self._max_repeats
|
||||
num_bits = self.num_bits
|
||||
batch_size = self.batch_size
|
||||
|
||||
# We reserve one dimension for the num-repeats and one for the start-marker.
|
||||
full_obs_size = num_bits + 2
|
||||
# We reserve one target dimension for the end-marker.
|
||||
full_targ_size = num_bits + 1
|
||||
start_end_flag_idx = full_obs_size - 2
|
||||
num_repeats_channel_idx = full_obs_size - 1
|
||||
|
||||
# Samples each batch index's sequence length and the number of repeats.
|
||||
sub_seq_length_batch = torch.randint(
|
||||
low=min_length, high=max_length + 1, size=[batch_size], dtype=torch.int32
|
||||
)
|
||||
num_repeats_batch = torch.randint(
|
||||
low=min_reps, high=max_reps + 1, size=[batch_size], dtype=torch.int32
|
||||
)
|
||||
|
||||
# Pads all the batches to have the same total sequence length.
|
||||
total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3
|
||||
max_length_batch = total_length_batch.max()
|
||||
residual_length_batch = max_length_batch - total_length_batch
|
||||
|
||||
obs_batch_shape = [max_length_batch, batch_size, full_obs_size]
|
||||
targ_batch_shape = [max_length_batch, batch_size, full_targ_size]
|
||||
mask_batch_trans_shape = [batch_size, max_length_batch]
|
||||
|
||||
obs_tensors = []
|
||||
targ_tensors = []
|
||||
mask_tensors = []
|
||||
|
||||
# Generates patterns for each batch element independently.
|
||||
for batch_index in range(batch_size):
|
||||
sub_seq_len = sub_seq_length_batch[batch_index]
|
||||
num_reps = num_repeats_batch[batch_index]
|
||||
|
||||
# The observation pattern is a sequence of random binary vectors.
|
||||
obs_pattern_shape = [sub_seq_len, num_bits]
|
||||
obs_pattern = torch.randint(low=0, high=2, size=obs_pattern_shape).to(
|
||||
torch.float32
|
||||
)
|
||||
|
||||
# The target pattern is the observation pattern repeated n times.
|
||||
# Some reshaping is required to accomplish the tiling.
|
||||
targ_pattern_shape = [sub_seq_len * num_reps, num_bits]
|
||||
flat_obs_pattern = obs_pattern.reshape([-1])
|
||||
flat_targ_pattern = torch.tile(flat_obs_pattern, [num_reps])
|
||||
targ_pattern = flat_targ_pattern.reshape(targ_pattern_shape)
|
||||
|
||||
# Expand the obs_pattern to have two extra channels for flags.
|
||||
# Concatenate start flag and num_reps flag to the sequence.
|
||||
obs_flag_channel_pad = torch.zeros([sub_seq_len, 2])
|
||||
obs_start_flag = F.one_hot(
|
||||
torch.tensor([start_end_flag_idx]), num_classes=full_obs_size
|
||||
).to(torch.float32)
|
||||
num_reps_flag = F.one_hot(
|
||||
torch.tensor([num_repeats_channel_idx]), num_classes=full_obs_size
|
||||
).to(torch.float32)
|
||||
num_reps_flag *= self._normalise(num_reps)
|
||||
|
||||
# note the concatenation dimensions.
|
||||
obs = torch.concat([obs_pattern, obs_flag_channel_pad], 1)
|
||||
obs = torch.concat([obs_start_flag, obs], 0)
|
||||
obs = torch.concat([obs, num_reps_flag], 0)
|
||||
|
||||
# Now do the same for the targ_pattern (it only has one extra channel).
|
||||
targ_flag_channel_pad = torch.zeros([sub_seq_len * num_reps, 1])
|
||||
targ_end_flag = F.one_hot(
|
||||
torch.tensor([start_end_flag_idx]), num_classes=full_targ_size
|
||||
).to(torch.float32)
|
||||
targ = torch.concat([targ_pattern, targ_flag_channel_pad], 1)
|
||||
targ = torch.concat([targ, targ_end_flag], 0)
|
||||
|
||||
# Concatenate zeros at the end of obs and the beginning of targ.
|
||||
# This aligns them s.t. the target begins as soon as the obs ends.
|
||||
obs_end_pad = torch.zeros([sub_seq_len * num_reps + 1, full_obs_size])
|
||||
targ_start_pad = torch.zeros([sub_seq_len + 2, full_targ_size])
|
||||
|
||||
# The mask is zero during the obs and one during the targ.
|
||||
mask_off = torch.zeros([sub_seq_len + 2])
|
||||
mask_on = torch.ones([sub_seq_len * num_reps + 1])
|
||||
|
||||
obs = torch.concat([obs, obs_end_pad], 0)
|
||||
targ = torch.concat([targ_start_pad, targ], 0)
|
||||
mask = torch.concat([mask_off, mask_on], 0)
|
||||
|
||||
obs_tensors.append(obs)
|
||||
targ_tensors.append(targ)
|
||||
mask_tensors.append(mask)
|
||||
|
||||
# End the loop over batch index.
|
||||
# Compute how much zero padding is needed to make the sequences
|
||||
# the same length for all batch elements.
|
||||
residual_obs_pad = [
|
||||
torch.zeros([residual_length_batch[i], full_obs_size])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
residual_targ_pad = [
|
||||
torch.zeros([residual_length_batch[i], full_targ_size])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
residual_mask_pad = [
|
||||
torch.zeros([residual_length_batch[i]]) for i in range(batch_size)
|
||||
]
|
||||
|
||||
# Concatenate the pad to each batch element.
|
||||
obs_tensors = [
|
||||
torch.concat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad)
|
||||
]
|
||||
targ_tensors = [
|
||||
torch.concat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad)
|
||||
]
|
||||
mask_tensors = [
|
||||
torch.concat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad)
|
||||
]
|
||||
|
||||
# Concatenate each batch element into a single tensor.
|
||||
obs = torch.concat(obs_tensors, 1).reshape(obs_batch_shape)
|
||||
targ = torch.concat(targ_tensors, 1).reshape(targ_batch_shape)
|
||||
mask = torch.concat(mask_tensors, 0).reshape(mask_batch_trans_shape).T
|
||||
return DatasetTensors(obs, targ, mask)
|
||||
|
||||
def cost(self, logits, targ, mask):
|
||||
return masked_sigmoid_cross_entropy(
|
||||
logits,
|
||||
targ,
|
||||
mask,
|
||||
time_average=self.time_average_cost,
|
||||
log_prob_in_bits=self.log_prob_in_bits,
|
||||
)
|
||||
|
||||
def to_human_readable(self, data, model_output=None, whole_batch=False):
|
||||
obs = data.observations
|
||||
unnormalised_num_reps_flag = self._unnormalise(obs[:, :, -1:]).round()
|
||||
obs = torch.cat([obs[:, :, :-1], unnormalised_num_reps_flag], dim=2)
|
||||
data = DatasetTensors(
|
||||
observations=obs,
|
||||
target=data.target,
|
||||
mask=data.mask,
|
||||
)
|
||||
return bitstring_readable(data, self.batch_size, model_output, whole_batch)
|
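To see the task layout described in the `RepeatCopy` docstring in practice, the sketch below (illustrative only, not part of this commit) draws a single batch and prints it in the ASCII form produced by `to_human_readable`; the constructor arguments are arbitrary small values.

```python
from dnc.repeat_copy import RepeatCopy

dataset = RepeatCopy(num_bits=4, batch_size=2, min_length=2, max_length=3,
                     min_repeats=1, max_repeats=2)
batch = dataset()                        # DatasetTensors(observations, target, mask)

# Time-major tensors: observations are [max_time, batch_size, num_bits + 2].
print(batch.observations.shape)
print(dataset.to_human_readable(batch))
```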
requirements.txt | 1 (new file)
@@ -0,0 +1 @@
|
||||
torch>=1.10
|
tests/__init__.py | 0 (new file)

tests/access_test.py | 155 (new file)
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
|
||||
from dnc.access import MemoryAccess, AccessState
|
||||
from dnc.addressing import TemporalLinkageState
|
||||
|
||||
from .util import one_hot
|
||||
|
||||
BATCH_SIZE = 2
|
||||
MEMORY_SIZE = 20
|
||||
WORD_SIZE = 6
|
||||
NUM_READS = 2
|
||||
NUM_WRITES = 3
|
||||
TIME_STEPS = 4
|
||||
INPUT_SIZE = 10
|
||||
|
||||
|
||||
def test_memory_access_build_and_train():
|
||||
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
|
||||
inputs = torch.randn((TIME_STEPS, BATCH_SIZE, INPUT_SIZE))
|
||||
|
||||
outputs = []
|
||||
state = None
|
||||
for input in inputs:
|
||||
output, state = module(input, state)
|
||||
outputs.append(output)
|
||||
outputs = torch.stack(outputs, dim=0)
|
||||
|
||||
targets = torch.rand((TIME_STEPS, BATCH_SIZE, NUM_READS, WORD_SIZE))
|
||||
loss = (outputs - targets).square().mean()
|
||||
|
||||
optim = torch.optim.SGD(module.parameters(), lr=1)
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
|
||||
|
||||
def test_memory_access_valid_read_mode():
|
||||
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
|
||||
inputs = module._read_inputs(torch.randn((BATCH_SIZE, INPUT_SIZE)))
|
||||
|
||||
# Check that the read modes for each read head constitute a probability
|
||||
# distribution.
|
||||
torch.testing.assert_close(
|
||||
inputs["read_mode"].sum(2), torch.ones([BATCH_SIZE, NUM_READS])
|
||||
)
|
||||
assert torch.all(inputs["read_mode"] >= 0)
|
||||
|
||||
|
||||
def test_memory_access_write_weights():
|
||||
memory = 10 * (torch.rand((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE)) - 0.5)
|
||||
usage = torch.rand((BATCH_SIZE, MEMORY_SIZE))
|
||||
|
||||
allocation_gate = torch.rand((BATCH_SIZE, NUM_WRITES))
|
||||
write_gate = torch.rand((BATCH_SIZE, NUM_WRITES))
|
||||
|
||||
write_gate = torch.rand((BATCH_SIZE, NUM_WRITES))
|
||||
write_content_keys = torch.rand((BATCH_SIZE, NUM_WRITES, WORD_SIZE))
|
||||
write_content_strengths = torch.rand((BATCH_SIZE, NUM_WRITES))
|
||||
|
||||
# Check that turning on allocation gate fully brings the write gate to
|
||||
# the allocation weighting (which we will control by controlling the usage).
|
||||
usage[:, 3] = 0
|
||||
allocation_gate[:, 0] = 1
|
||||
write_gate[:, 0] = 1
|
||||
|
||||
inputs = {
|
||||
"allocation_gate": allocation_gate,
|
||||
"write_gate": write_gate,
|
||||
"write_content_keys": write_content_keys,
|
||||
"write_content_strengths": write_content_strengths,
|
||||
}
|
||||
|
||||
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
|
||||
weights = module._write_weights(inputs, memory, usage)
|
||||
|
||||
# Check the weights sum to their target gating.
|
||||
torch.testing.assert_close(weights.sum(dim=2), write_gate, atol=5e-2, rtol=0)
|
||||
|
||||
# Check that we fully allocated to the third row.
|
||||
weights_0_0_target = one_hot(MEMORY_SIZE, 3, dtype=torch.float32)
|
||||
torch.testing.assert_close(weights[0, 0], weights_0_0_target, atol=1e-3, rtol=0)
|
||||
|
||||
|
||||
def test_memory_access_read_weights():
|
||||
memory = 10 * (torch.rand((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE)) - 0.5)
|
||||
prev_read_weights = torch.rand((BATCH_SIZE, NUM_READS, MEMORY_SIZE))
|
||||
prev_read_weights /= prev_read_weights.sum(2, keepdim=True) + 1
|
||||
|
||||
link = torch.rand((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE))
|
||||
# Row and column sums should be at most 1:
|
||||
link /= torch.maximum(link.sum(2, keepdim=True), torch.tensor(1))
|
||||
link /= torch.maximum(link.sum(3, keepdim=True), torch.tensor(1))
|
||||
|
||||
# We query the memory on the third location in memory, and select a large
|
||||
# strength on the query. Then we select a content-based read-mode.
|
||||
read_content_keys = torch.rand((BATCH_SIZE, NUM_READS, WORD_SIZE))
|
||||
read_content_keys[0, 0] = memory[0, 3]
|
||||
read_content_strengths = torch.full(
|
||||
(BATCH_SIZE, NUM_READS), 100.0, dtype=torch.float64
|
||||
)
|
||||
read_mode = torch.rand((BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES))
|
||||
read_mode[0, 0, :] = one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES)
|
||||
inputs = {
|
||||
"read_content_keys": read_content_keys,
|
||||
"read_content_strengths": read_content_strengths,
|
||||
"read_mode": read_mode,
|
||||
}
|
||||
|
||||
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
|
||||
read_weights = module._read_weights(inputs, memory, prev_read_weights, link)
|
||||
|
||||
# read_weights for batch 0, read head 0 should be memory location 3
|
||||
ref = one_hot(MEMORY_SIZE, 3, dtype=torch.float64)
|
||||
torch.testing.assert_close(read_weights[0, 0, :], ref, atol=1e-3, rtol=0)
|
||||
|
||||
|
||||
def test_memory_access_gradients():
|
||||
kwargs = {"dtype": torch.float64, "requires_grad": True}
|
||||
|
||||
inputs = torch.randn((BATCH_SIZE, INPUT_SIZE), **kwargs)
|
||||
memory = torch.zeros((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE), **kwargs)
|
||||
read_weights = torch.zeros((BATCH_SIZE, NUM_READS, MEMORY_SIZE), **kwargs)
|
||||
precedence_weights = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE), **kwargs)
|
||||
link = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE), **kwargs)
|
||||
|
||||
module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES)
|
||||
|
||||
class Wrapper(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self, inputs, memory, read_weights, link, precedence_weights):
|
||||
write_weights = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE))
|
||||
usage = torch.zeros((BATCH_SIZE, MEMORY_SIZE))
|
||||
|
||||
state = AccessState(
|
||||
memory=memory,
|
||||
read_weights=read_weights,
|
||||
write_weights=write_weights,
|
||||
linkage=TemporalLinkageState(
|
||||
link=link,
|
||||
precedence_weights=precedence_weights,
|
||||
),
|
||||
usage=usage,
|
||||
)
|
||||
output, _ = self.module(inputs, state)
|
||||
return output.sum()
|
||||
|
||||
module = Wrapper(module).to(torch.float64)
|
||||
|
||||
torch.autograd.gradcheck(
|
||||
module,
|
||||
(inputs, memory, read_weights, link, precedence_weights),
|
||||
)
|
tests/addressing_test.py | 338 (new file)
@@ -0,0 +1,338 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from dnc.addressing import weighted_softmax, CosineWeights, TemporalLinkage, Freeness
|
||||
|
||||
from .util import one_hot
|
||||
|
||||
|
||||
def _test_weighted_softmax(strength_op):
|
||||
batch_size, num_heads, memory_size = 5, 3, 7
|
||||
activations = torch.randn(batch_size, num_heads, memory_size)
|
||||
weights = torch.ones((batch_size, num_heads))
|
||||
|
||||
observed = weighted_softmax(activations, weights, strength_op)
|
||||
expected = torch.stack(
|
||||
[
|
||||
F.softmax(a * strength_op(w).unsqueeze(-1), dim=-1)
|
||||
for a, w in zip(activations, weights)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(observed, expected)
|
||||
|
||||
|
||||
def test_weighted_softmax_identity():
|
||||
_test_weighted_softmax(lambda x: x)
|
||||
|
||||
|
||||
def test_weighted_softmax_softplus():
|
||||
_test_weighted_softmax(F.softplus)
|
||||
|
||||
|
||||
def test_cosine_weights_shape():
    batch_size, num_heads, memory_size, word_size = 5, 3, 7, 2

    module = CosineWeights(num_heads, word_size)
    mem = torch.randn([batch_size, memory_size, word_size])
    keys = torch.randn([batch_size, num_heads, word_size])
    strengths = torch.randn([batch_size, num_heads])
    weights = module(mem, keys, strengths)

    assert weights.shape == torch.Size([batch_size, num_heads, memory_size])


def test_cosine_weights_values():
    batch_size, num_heads, memory_size, word_size = 5, 4, 10, 2

    mem = torch.randn((batch_size, memory_size, word_size))
    mem[0, 0, 0] = 1
    mem[0, 0, 1] = 2
    mem[0, 1, 0] = 3
    mem[0, 1, 1] = 4
    mem[0, 2, 0] = 5
    mem[0, 2, 1] = 6

    keys = torch.randn((batch_size, num_heads, word_size))
    keys[0, 0, 0] = 5
    keys[0, 0, 1] = 6
    keys[0, 1, 0] = 1
    keys[0, 1, 1] = 2
    keys[0, 2, 0] = 5
    keys[0, 2, 1] = 6
    keys[0, 3, 0] = 3
    keys[0, 3, 1] = 4

    strengths = torch.randn((batch_size, num_heads))

    module = CosineWeights(num_heads, word_size)
    result = module(mem, keys, strengths)

    # Manually check the results.
    strengths_softplus = np.log(1 + np.exp(strengths.numpy()))
    similarity = np.zeros((memory_size))

    for b in range(batch_size):
        for h in range(num_heads):
            key = keys[b, h]
            key_norm = np.linalg.norm(key)

            for m in range(memory_size):
                row = mem[b, m]
                similarity[m] = np.dot(key, row) / (key_norm * np.linalg.norm(row))

            similarity = np.exp(similarity * strengths_softplus[b, h])
            similarity /= similarity.sum()
            ref = torch.from_numpy(similarity).to(dtype=torch.float32)
            torch.testing.assert_close(result[b, h], ref, atol=1e-4, rtol=1e-4)


def test_cosine_weights_divide_by_zero():
    batch_size, num_heads, memory_size, word_size = 5, 4, 10, 2

    module = CosineWeights(num_heads, word_size)
    keys = torch.randn([batch_size, num_heads, word_size], requires_grad=True)
    strengths = torch.randn([batch_size, num_heads], requires_grad=True)

    # First row of memory is non-zero to concentrate attention on this location.
    # Remaining rows are all zero.
    mem = torch.zeros([batch_size, memory_size, word_size])
    mem[:, 0, :] = 1
    mem.requires_grad = True

    output = module(mem, keys, strengths)
    output.sum().backward()

    assert torch.all(~output.isnan())
    assert torch.all(~mem.grad.isnan())
    assert torch.all(~keys.grad.isnan())
    assert torch.all(~strengths.grad.isnan())


def test_temporal_linkage():
    batch_size, memory_size, num_reads, num_writes = 7, 4, 11, 5

    module = TemporalLinkage(memory_size=memory_size, num_writes=num_writes)

    state = None
    num_steps = 5
    for i in range(num_steps):
        write_weights = torch.rand([batch_size, num_writes, memory_size])
        write_weights /= write_weights.sum(2, keepdim=True) + 1

        # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1
        if i == num_steps - 2:
            write_weights[0, 0, :] = one_hot(memory_size, 0)
            write_weights[0, 1, :] = one_hot(memory_size, 3)
        elif i == num_steps - 1:
            write_weights[0, 0, :] = one_hot(memory_size, 1)
            write_weights[0, 1, :] = one_hot(memory_size, 2)

        state = module(write_weights, state)

    # link should be bounded in range [0, 1]
    assert torch.all(0 <= state.link)
    assert torch.all(state.link <= 1)

    # link diagonal should be zero
    torch.testing.assert_close(
        state.link[:, :, range(memory_size), range(memory_size)],
        torch.zeros([batch_size, num_writes, memory_size]),
    )

    # link rows and columns should sum to at most 1
    assert torch.all(state.link.sum(2) <= 1)
    assert torch.all(state.link.sum(3) <= 1)

    # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2
    torch.testing.assert_close(
        state.link[0, 0, :, 0], one_hot(memory_size, 1, dtype=torch.float32)
    )
    torch.testing.assert_close(
        state.link[0, 1, :, 3], one_hot(memory_size, 2, dtype=torch.float32)
    )

    # Now test calculation of forward and backward read weights
    prev_read_weights = torch.randn((batch_size, num_reads, memory_size))
    prev_read_weights[0, 5, :] = one_hot(memory_size, 0)  # read 5, posn 0
    prev_read_weights[0, 6, :] = one_hot(memory_size, 2)  # read 6, posn 2

    forward_read_weights = module.directional_read_weights(
        state.link, prev_read_weights, is_forward=True
    )
    backward_read_weights = module.directional_read_weights(
        state.link, prev_read_weights, is_forward=False
    )

    # Check directional weights calculated correctly.
    torch.testing.assert_close(
        forward_read_weights[0, 5, 0, :],  # read=5, write=0
        one_hot(memory_size, 1, dtype=torch.float32),
    )
    torch.testing.assert_close(
        backward_read_weights[0, 6, 1, :],  # read=6, write=1
        one_hot(memory_size, 3, dtype=torch.float32),
    )


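For reference, the directional read weights asserted above follow the linkage rule from the DNC paper, which `directional_read_weights` is assumed to implement: the forward weights push each previous read distribution one step along the link matrix, and the backward weights pull it one step back. A minimal sketch using the tensors from the test above, where `link[b, w, i, j]` means slot `i` was written after slot `j`:

forward_ref = torch.einsum("bwij,brj->brwi", state.link, prev_read_weights)   # f_i = L @ w_read_i
backward_ref = torch.einsum("bwij,bri->brwj", state.link, prev_read_weights)  # b_i = L.T @ w_read_i
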
def test_temporal_linkage_precedence_weights():
    batch_size, memory_size, num_writes = 7, 3, 5

    module = TemporalLinkage(memory_size=memory_size, num_writes=num_writes)

    prev_precedence_weights = torch.rand(batch_size, num_writes, memory_size)
    write_weights = torch.rand(batch_size, num_writes, memory_size)

    # These should sum to at most 1 for each write head in each batch.
    write_weights /= write_weights.sum(2, keepdim=True) + 1
    prev_precedence_weights /= prev_precedence_weights.sum(2, keepdim=True) + 1

    write_weights[0, 1, :] = 0  # batch 0 head 1: no writing
    write_weights[1, 2, :] /= write_weights[1, 2, :].sum()  # b1 h2: all writing

    precedence_weights = module._precedence_weights(
        prev_precedence_weights=prev_precedence_weights, write_weights=write_weights
    )

    # precedence weights should be bounded in range [0, 1]
    assert torch.all(0 <= precedence_weights)
    assert torch.all(precedence_weights <= 1)

    # no writing in batch 0, head 1
    torch.testing.assert_close(
        precedence_weights[0, 1, :], prev_precedence_weights[0, 1, :]
    )

    # all writing in batch 1, head 2
    torch.testing.assert_close(precedence_weights[1, 2, :], write_weights[1, 2, :])


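The two edge cases checked above come straight from the precedence update in the DNC paper, which `_precedence_weights` is assumed to follow: the old precedence is kept in proportion to how little was written, and replaced by the write weights in proportion to how much was. A minimal sketch with the tensors from the test above:

# p_t = (1 - sum_j w_t[j]) * p_{t-1} + w_t
write_sum = write_weights.sum(dim=2, keepdim=True)
precedence_ref = (1 - write_sum) * prev_precedence_weights + write_weights
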
def test_freeness():
    batch_size, memory_size, num_reads, num_writes = 5, 11, 3, 7

    module = Freeness(memory_size)

    free_gate = torch.rand((batch_size, num_reads))

    # Produce read weights that sum to 1 for each batch and head.
    prev_read_weights = torch.rand((batch_size, num_reads, memory_size))
    prev_read_weights[1, :, 3] = 0
    prev_read_weights /= prev_read_weights.sum(2, keepdim=True)
    prev_write_weights = torch.rand((batch_size, num_writes, memory_size))
    prev_write_weights /= prev_write_weights.sum(2, keepdim=True)
    prev_usage = torch.rand((batch_size, memory_size))

    # Add some special values that allow us to test the behaviour:
    prev_write_weights[1, 2, 3] = 1
    prev_read_weights[2, 0, 4] = 1
    free_gate[2, 0] = 1

    usage = module(prev_write_weights, free_gate, prev_read_weights, prev_usage)

    # Check all usages are between 0 and 1.
    assert torch.all(0 <= usage)
    assert torch.all(usage <= 1)

    # Check that the full write at batch 1, position 3 makes it fully used.
    assert usage[1][3] == 1

    # Check that the full free at batch 2, position 4 makes it fully free.
    assert usage[2][4] == 0


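The two spot checks above (a saturating write makes a slot fully used, a fully freed read makes it fully unused) match the usage update from the DNC paper, which `Freeness` is assumed to implement: writes raise usage first, then a retention vector built from the free gates scales it back down. A rough sketch with the tensors from the test above:

# usage after write: u' = u + (1 - u) * w, with w the combined write weights
w = 1 - torch.prod(1 - prev_write_weights, dim=1)
u_after_write = prev_usage + (1 - prev_usage) * w
# retention: psi = product over read heads of (1 - free_gate * read_weights)
psi = torch.prod(1 - free_gate.unsqueeze(-1) * prev_read_weights, dim=1)
usage_ref = u_after_write * psi
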
def test_freeness_write_allocation_weights():
    batch_size, memory_size, num_writes = 7, 23, 5

    module = Freeness(memory_size)

    usage = torch.rand((batch_size, memory_size))
    write_gates = torch.rand((batch_size, num_writes))

    # Turn off gates for heads 1 and 3 in batch 0. This does not scale down the
    # weighting, but it means that the usage does not change, so we should get
    # the same allocation weightings for the pairs (1, 2) and (3, 4) (while all
    # other pairs differ).
    write_gates[0, 1] = 0
    write_gates[0, 3] = 0
    # and turn heads 0 and 2 on for full effect.
    write_gates[0, 0] = 1
    write_gates[0, 2] = 1

    # In batch 1, make one of the usages 0 and another almost 0, so that these
    # entries get most of the allocation weights for the first and second heads.
    usage[1] = usage[1] * 0.9 + 0.1  # make sure all entries are in [0.1, 1]
    usage[1][4] = 0  # write head 0 should get allocated to position 4
    usage[1][3] = 1e-4  # write head 1 should get allocated to position 3
    write_gates[1, 0] = 1  # write head 0 fully on
    write_gates[1, 1] = 1  # write head 1 fully on

    weights = module.write_allocation_weights(
        usage=usage, write_gates=write_gates, num_writes=num_writes
    )

    assert torch.all(0 <= weights)
    assert torch.all(weights <= 1)

    # Check that weights sum to close to 1
    torch.testing.assert_close(
        weights.sum(dim=2), torch.ones((batch_size, num_writes)), atol=1e-3, rtol=0
    )

    # Check the same / different allocation weight pairs as described above.
    assert (weights[0, 0, :] - weights[0, 1, :]).abs().max() > 0.1
    torch.testing.assert_close(weights[0, 1, :], weights[0, 2, :])
    assert (weights[0, 2, :] - weights[0, 3, :]).abs().max() > 0.1
    torch.testing.assert_close(weights[0, 3, :], weights[0, 4, :])

    torch.testing.assert_close(
        weights[1][0], one_hot(memory_size, 4, dtype=torch.float32), atol=1e-3, rtol=0
    )
    torch.testing.assert_close(
        weights[1][1], one_hot(memory_size, 3, dtype=torch.float32), atol=1e-3, rtol=0
    )


def test_freeness_write_allocation_weights_gradient():
    batch_size, memory_size, num_writes = 7, 5, 3

    module = Freeness(memory_size).to(torch.float64)

    usage = torch.rand(
        (batch_size, memory_size), dtype=torch.float64, requires_grad=True
    )
    write_gates = torch.rand(
        (batch_size, num_writes), dtype=torch.float64, requires_grad=True
    )

    def func(usage, write_gates):
        return module.write_allocation_weights(usage, write_gates, num_writes)

    torch.autograd.gradcheck(func, (usage, write_gates))


def test_freeness_allocation():
    batch_size, memory_size = 7, 13

    usage = torch.rand((batch_size, memory_size))
    module = Freeness(memory_size)
    allocation = module._allocation(usage)

    # 1. Test that max allocation goes to min usage, and vice versa.
    assert torch.all(usage.argmin(dim=1) == allocation.argmax(dim=1))
    assert torch.all(usage.argmax(dim=1) == allocation.argmin(dim=1))

    # 2. Test that allocations sum to almost 1.
    torch.testing.assert_close(
        allocation.sum(dim=1), torch.ones(batch_size), atol=0.01, rtol=0
    )


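For context, `_allocation` is expected to implement the paper's allocation weighting: each slot gets its free capacity `(1 - u)` discounted by the usage of every less-used slot, which is exactly why the assertions above pair minimum usage with maximum allocation. A small self-contained sketch of that rule (illustrative, not the module's actual code):

u = torch.tensor([[0.1, 0.7, 0.4]])
phi = u.argsort(dim=1)  # slots ordered from least to most used
sorted_u = u.gather(1, phi)
discount = torch.cat([torch.ones(1, 1), sorted_u[:, :-1]], dim=1).cumprod(dim=1)
alloc = torch.empty_like(u).scatter_(1, phi, (1 - sorted_u) * discount)
# alloc -> tensor([[0.9000, 0.0120, 0.0600]]): largest where usage is smallest
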
def test_freeness_allocation_gradient():
    batch_size, memory_size = 1, 5

    usage = torch.rand(
        (batch_size, memory_size), dtype=torch.float64, requires_grad=True
    )
    module = Freeness(memory_size).to(torch.float64)

    torch.autograd.gradcheck(module._allocation, (usage,))
42
tests/dnc_test.py
Normal file
42
tests/dnc_test.py
Normal file
@ -0,0 +1,42 @@
import torch

from dnc.dnc import DNC


def test_dnc():
    """smoke test"""
    memory_size = 16
    word_size = 16
    num_reads = 4
    num_writes = 1

    clip_value = 20

    input_size = 4
    hidden_size = 64
    output_size = input_size

    batch_size = 16
    time_steps = 64

    access_config = {
        "memory_size": memory_size,
        "word_size": word_size,
        "num_reads": num_reads,
        "num_writes": num_writes,
    }
    controller_config = {
        "input_size": input_size + num_reads * word_size,
        "hidden_size": hidden_size,
    }

    dnc = DNC(
        access_config=access_config,
        controller_config=controller_config,
        output_size=output_size,
        clip_value=clip_value,
    )
    inputs = torch.randn((time_steps, batch_size, input_size))
    state = None
    for input in inputs:
        output, state = dnc(input, state)
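
One detail worth noting in the smoke test above: following the original DNC, the controller at each step is assumed to receive the task input concatenated with the previous read vectors, which is why `controller_config["input_size"]` is

input_size + num_reads * word_size  # 4 + 4 * 16 = 68 with the values used here
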
9
tests/util.py
Normal file
9
tests/util.py
Normal file
@ -0,0 +1,9 @@
import torch
import torch.nn.functional as F


def one_hot(length, index, dtype=None):
    val = F.one_hot(torch.tensor(index), num_classes=length)
    if dtype is not None:
        val = val.to(dtype=dtype)
    return val
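
The helper above is a thin wrapper around `F.one_hot` that lets tests pick the dtype of their reference vectors, e.g.:

one_hot(4, 2)                       # tensor([0, 0, 1, 0]), int64 by default
one_hot(4, 2, dtype=torch.float32)  # tensor([0., 0., 1., 0.])
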
205
train.py
Normal file
205
train.py
Normal file
@ -0,0 +1,205 @@
"""Example script to train the DNC on a repeated copy task."""
import os
import argparse
import logging

import torch
from dnc.repeat_copy import RepeatCopy
from dnc.dnc import DNC

_LG = logging.getLogger(__name__)


def _main():
    args = _parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s: %(message)s")

    dataset = RepeatCopy(
        args.num_bits,
        args.batch_size,
        args.min_length,
        args.max_length,
        args.min_repeats,
        args.max_repeats,
    )

    dnc = DNC(
        access_config={
            "memory_size": args.memory_size,
            "word_size": args.word_size,
            "num_reads": args.num_read_heads,
            "num_writes": args.num_write_heads,
        },
        controller_config={
            "input_size": args.num_bits + 2 + args.num_read_heads * args.word_size,
            "hidden_size": args.hidden_size,
        },
        output_size=dataset.target_size,
        clip_value=args.clip_value,
    )

    optimizer = torch.optim.RMSprop(dnc.parameters(), lr=args.lr, eps=args.eps)

    _run_train_loop(
        dnc,
        dataset,
        optimizer,
        args.num_training_iterations,
        args.report_interval,
        args.checkpoint_interval,
        args.checkpoint_dir,
    )


def _run_train_loop(
    dnc,
    dataset,
    optimizer,
    num_training,
    report_interval,
    checkpoint_interval,
    checkpoint_dir,
):
    total_loss = 0
    for i in range(num_training):
        batch = dataset()
        state = None
        outputs = []
        for inputs in batch.observations:
            output, state = dnc(inputs, state)
            outputs.append(output)
        outputs = torch.stack(outputs, 0)
        loss = dataset.cost(outputs, batch.target, batch.mask)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if (i + 1) % report_interval == 0:
            outputs = torch.round(batch.mask.unsqueeze(-1) * torch.sigmoid(outputs))
            dataset_string = dataset.to_human_readable(batch, outputs)
            _LG.info(f"{i}: Avg training loss {total_loss / report_interval}")
            _LG.info(dataset_string)
            total_loss = 0
        if checkpoint_interval is not None and (i + 1) % checkpoint_interval == 0:
            path = os.path.join(checkpoint_dir, "model.pt")
            torch.save(dnc.state_dict(), path)


def _parse_args():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=__doc__,
    )
    model_opts = parser.add_argument_group("Model Parameters")
    model_opts.add_argument(
        "--hidden-size", type=int, default=64, help="Size of LSTM hidden layer."
    )
    model_opts.add_argument(
        "--memory-size", type=int, default=16, help="The number of memory slots."
    )
    model_opts.add_argument(
        "--word-size", type=int, default=16, help="The width of each memory slot."
    )
    model_opts.add_argument(
        "--num-write-heads", type=int, default=1, help="Number of memory write heads."
    )
    model_opts.add_argument(
        "--num-read-heads", type=int, default=4, help="Number of memory read heads."
    )
    model_opts.add_argument(
        "--clip-value",
        type=float,
        default=20,
        help="Maximum absolute value of controller and dnc outputs.",
    )

    optim_opts = parser.add_argument_group("Optimizer Parameters")
    optim_opts.add_argument(
        "--max-grad-norm", type=float, default=50, help="Gradient clipping norm limit."
    )
    optim_opts.add_argument(
        "--learning-rate",
        "--lr",
        type=float,
        default=1e-4,
        dest="lr",
        help="Optimizer learning rate.",
    )
    optim_opts.add_argument(
        "--optimizer-epsilon",
        type=float,
        default=1e-10,
        dest="eps",
        help="Epsilon used for RMSProp optimizer.",
    )

    task_opts = parser.add_argument_group("Task Parameters")
    task_opts.add_argument(
        "--batch-size", type=int, default=16, help="Batch size for training."
    )
    task_opts.add_argument(
        "--num-bits", type=int, default=4, help="Dimensionality of each vector to copy."
    )
    task_opts.add_argument(
        "--min-length",
        type=int,
        default=1,
        help="Lower limit on number of vectors in the observation pattern to copy.",
    )
    task_opts.add_argument(
        "--max-length",
        type=int,
        default=2,
        help="Upper limit on number of vectors in the observation pattern to copy.",
    )
    task_opts.add_argument(
        "--min-repeats",
        type=int,
        default=1,
        help="Lower limit on number of copy repeats.",
    )
    task_opts.add_argument(
        "--max-repeats",
        type=int,
        default=2,
        help="Upper limit on number of copy repeats.",
    )

    train_opts = parser.add_argument_group("Training Options")
    train_opts.add_argument(
        "--num-training-iterations",
        type=int,
        default=100_000,
        help="Number of iterations to train for.",
    )
    train_opts.add_argument(
        "--report-interval",
        type=int,
        default=100,
        help="Iterations between status reports (average loss and a sample output).",
    )
    train_opts.add_argument(
        "--checkpoint-dir", default=None, help="Checkpointing directory."
    )
    train_opts.add_argument(
        "--checkpoint-interval",
        type=int,
        default=None,
        help="Checkpointing step interval.",
    )
    args = parser.parse_args()

    if args.checkpoint_dir is None and args.checkpoint_interval is not None:
        raise RuntimeError(
            "`--checkpoint-interval` is provided but `--checkpoint-dir` is not provided."
        )
    if args.checkpoint_dir is not None and args.checkpoint_interval is None:
        raise RuntimeError(
            "`--checkpoint-dir` is provided but `--checkpoint-interval` is not provided."
        )
    return args


if __name__ == "__main__":
    _main()
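
A typical invocation with the defaults defined above (note that the two checkpoint flags must be passed together, as enforced at the end of `_parse_args`):

python train.py --num-training-iterations 10000 --report-interval 100 \
    --checkpoint-dir ./checkpoints --checkpoint-interval 1000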