diff --git a/README.md b/README.md
index 50183e7..e071952 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,11 @@
-# dnc_pytorch
-PyTorch Port of DeepMind's Differentiable Neural Computer
+# Differentiable Neural Computer in PyTorch
+
+This is a PyTorch port of DeepMind's [Differentiable Neural Computer (DNC)](https://github.com/deepmind/dnc).
+
+The original code was written for TensorFlow v1, which is not straightforward to set up on modern stacks such as Apple Silicon and Python 3.8+.
+
+This port requires only PyTorch >= 1.10.
+
+The code structure and interfaces are kept almost the same as the original.
+
+You can run the repeat-copy task with `python train.py`.
\ No newline at end of file
diff --git a/dnc/__init__.py b/dnc/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/dnc/access.py b/dnc/access.py
new file mode 100644
index 0000000..b2f1f05
--- /dev/null
+++ b/dnc/access.py
@@ -0,0 +1,353 @@
+from dataclasses import dataclass
+from typing import Dict, Tuple, Optional
+
+import torch
+from torch import Tensor
+
+from .addressing import (
+    CosineWeights,
+    Freeness,
+    TemporalLinkage,
+    TemporalLinkageState,
+)
+
+
+@dataclass
+class AccessState:
+    memory: Tensor  # [batch_size, memory_size, word_size]
+    read_weights: Tensor  # [batch_size, num_reads, memory_size]
+    write_weights: Tensor  # [batch_size, num_writes, memory_size]
+    linkage: TemporalLinkageState
+    # link: [batch_size, num_writes, memory_size, memory_size]
+    # precedence_weights: [batch_size, num_writes, memory_size]
+    usage: Tensor  # [batch_size, memory_size]
+
+
+def _erase_and_write(memory, address, reset_weights, values):
+    """Module to erase and write in the external memory.
+
+    Erase operation:
+        M_t'(i) = M_{t-1}(i) * (1 - w_t(i) * e_t)
+
+    Add operation:
+        M_t(i) = M_t'(i) + w_t(i) * a_t
+
+    where e are the reset_weights, w the write weights and a the values.
+
+    Args:
+      memory: 3-D tensor of shape `[batch_size, memory_size, word_size]`.
+      address: 3-D tensor `[batch_size, num_writes, memory_size]`.
+      reset_weights: 3-D tensor `[batch_size, num_writes, word_size]`.
+      values: 3-D tensor `[batch_size, num_writes, word_size]`.
+
+    Returns:
+      3-D tensor of shape `[batch_size, memory_size, word_size]` containing the
+      updated memory.
+    """
+    weighted_resets = address.unsqueeze(3) * reset_weights.unsqueeze(2)
+    reset_gate = torch.prod(1 - weighted_resets, dim=1)
+    return memory * reset_gate + torch.matmul(address.transpose(-1, -2), values)
+
+
+class Reshape(torch.nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        self.dims = dims
+
+    def forward(self, input):
+        return input.reshape(self.dims)
+
+
+class MemoryAccess(torch.nn.Module):
+    """Access module of the Differentiable Neural Computer.
+
+    This memory module supports multiple read and write heads. It makes use of:
+
+    * `addressing.TemporalLinkage` to track the temporal ordering of writes in
+      memory for each write head.
+    * `addressing.Freeness` to keep track of memory usage, where usage
+      increases when a memory location is written to, and decreases when memory
+      is read from locations that the controller says can be freed.
+
+    Write-address selection is done by an interpolation between content-based
+    lookup and using unused memory.
+
+    Read-address selection is done by an interpolation of content-based lookup
+    and following the link graph in the forward or backwards read direction.
+    """
+
+    def __init__(self, memory_size=128, word_size=20, num_reads=1, num_writes=1):
+        """Creates a MemoryAccess module.
+
+        Args:
+          memory_size: The number of memory slots (N in the DNC paper).
+ word_size: The width of each memory slot (W in the DNC paper) + num_reads: The number of read heads (R in the DNC paper). + num_writes: The number of write heads (fixed at 1 in the paper). + name: The name of the module. + """ + super().__init__() + self._memory_size = memory_size + self._word_size = word_size + self._num_reads = num_reads + self._num_writes = num_writes + + self._write_content_weights_mod = CosineWeights(num_writes, word_size) + self._read_content_weights_mod = CosineWeights(num_reads, word_size) + + self._linkage = TemporalLinkage(memory_size, num_writes) + self._freeness = Freeness(memory_size) + + def _linear(first_dim, second_dim, pre_act=None, post_act=None): + """Returns a linear transformation of `inputs`, followed by a reshape.""" + mods = [] + mods.append(torch.nn.LazyLinear(first_dim * second_dim)) + if pre_act is not None: + mods.append(pre_act) + mods.append(Reshape([-1, first_dim, second_dim])) + if post_act is not None: + mods.append(post_act) + return torch.nn.Sequential(*mods) + + self._write_vectors = _linear(num_writes, word_size) + self._erase_vectors = _linear(num_writes, word_size, pre_act=torch.nn.Sigmoid()) + self._free_gate = torch.nn.Sequential( + torch.nn.LazyLinear(num_reads), + torch.nn.Sigmoid(), + ) + self._alloc_gate = torch.nn.Sequential( + torch.nn.LazyLinear(num_writes), + torch.nn.Sigmoid(), + ) + self._write_gate = torch.nn.Sequential( + torch.nn.LazyLinear(num_writes), + torch.nn.Sigmoid(), + ) + num_read_modes = 1 + 2 * num_writes + self._read_mode = _linear( + num_reads, num_read_modes, post_act=torch.nn.Softmax(dim=-1) + ) + self._write_keys = _linear(num_writes, word_size) + self._write_strengths = torch.nn.LazyLinear(num_writes) + + self._read_keys = _linear(num_reads, word_size) + self._read_strengths = torch.nn.LazyLinear(num_reads) + + def _read_inputs(self, inputs: Tensor) -> Dict[str, Tensor]: + """Applies transformations to `inputs` to get control for this module.""" + # v_t^i - The vectors to write to memory, for each write head `i`. + write_vectors = self._write_vectors(inputs) + + # e_t^i - Amount to erase the memory by before writing, for each write head. + erase_vectors = self._erase_vectors(inputs) + + # f_t^j - Amount that the memory at the locations read from at the previous + # time step can be declared unused, for each read head `j`. + free_gate = self._free_gate(inputs) + + # g_t^{a, i} - Interpolation between writing to unallocated memory and + # content-based lookup, for each write head `i`. Note: `a` is simply used to + # identify this gate with allocation vs writing (as defined below). + allocation_gate = self._alloc_gate(inputs) + + # g_t^{w, i} - Overall gating of write amount for each write head. + write_gate = self._write_gate(inputs) + + # \pi_t^j - Mixing between "backwards" and "forwards" positions (for + # each write head), and content-based lookup, for each read head. + read_mode = self._read_mode(inputs) + + # Parameters for the (read / write) "weights by content matching" modules. 
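+        # Illustrative shape note: after the Linear + Reshape below, the keys
+        # come out as [batch_size, num_writes, word_size] (write) and
+        # [batch_size, num_reads, word_size] (read), while the strengths are
+        # plain linear outputs of shape [batch_size, num_writes] and
+        # [batch_size, num_reads] respectively.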
+ write_keys = self._write_keys(inputs) + write_strengths = self._write_strengths(inputs) + + read_keys = self._read_keys(inputs) + read_strengths = self._read_strengths(inputs) + + result = { + "read_content_keys": read_keys, + "read_content_strengths": read_strengths, + "write_content_keys": write_keys, + "write_content_strengths": write_strengths, + "write_vectors": write_vectors, + "erase_vectors": erase_vectors, + "free_gate": free_gate, + "allocation_gate": allocation_gate, + "write_gate": write_gate, + "read_mode": read_mode, + } + return result + + def _write_weights( + self, + inputs: Tensor, + memory: Tensor, + usage: Tensor, + ) -> Tensor: + """Calculates the memory locations to write to. + + This uses a combination of content-based lookup and finding an unused + location in memory, for each write head. + + Args: + inputs: Collection of inputs to the access module, including controls for + how to chose memory writing, such as the content to look-up and the + weighting between content-based and allocation-based addressing. + memory: A tensor of shape `[batch_size, memory_size, word_size]` + containing the current memory contents. + usage: Current memory usage, which is a tensor of shape `[batch_size, + memory_size]`, used for allocation-based addressing. + + Returns: + tensor of shape `[batch_size, num_writes, memory_size]` indicating where + to write to (if anywhere) for each write head. + """ + # c_t^{w, i} - The content-based weights for each write head. + write_content_weights = self._write_content_weights_mod( + memory, inputs["write_content_keys"], inputs["write_content_strengths"] + ) + + # a_t^i - The allocation weights for each write head. + write_allocation_weights = self._freeness.write_allocation_weights( + usage=usage, + write_gates=(inputs["allocation_gate"] * inputs["write_gate"]), + num_writes=self._num_writes, + ) + + # Expands gates over memory locations. + allocation_gate = inputs["allocation_gate"].unsqueeze(-1) + write_gate = inputs["write_gate"].unsqueeze(-1) + + # w_t^{w, i} - The write weightings for each write head. + return write_gate * ( + allocation_gate * write_allocation_weights + + (1 - allocation_gate) * write_content_weights + ) + + def _read_weights( + self, + inputs: Tensor, + memory: Tensor, + prev_read_weights: Tensor, + link: Tensor, + ) -> Tensor: + """Calculates read weights for each read head. + + The read weights are a combination of following the link graphs in the + forward or backward directions from the previous read position, and doing + content-based lookup. The interpolation between these different modes is + done by `inputs['read_mode']`. + + Args: + inputs: Controls for this access module. This contains the content-based + keys to lookup, and the weightings for the different read modes. + memory: A tensor of shape `[batch_size, memory_size, word_size]` + containing the current memory contents to do content-based lookup. + prev_read_weights: A tensor of shape `[batch_size, num_reads, + memory_size]` containing the previous read locations. + link: A tensor of shape `[batch_size, num_writes, memory_size, + memory_size]` containing the temporal write transition graphs. + + Returns: + A tensor of shape `[batch_size, num_reads, memory_size]` containing the + read weights for each read head. + """ + # c_t^{r, i} - The content weightings for each read head. + content_weights = self._read_content_weights_mod( + memory, inputs["read_content_keys"], inputs["read_content_strengths"] + ) + + # Calculates f_t^i and b_t^i. 
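+        # Both directional weightings have shape
+        # [batch_size, num_reads, num_writes, memory_size]: one distribution
+        # per (read head, write head) pair, obtained by pushing the previous
+        # read weights one step through the corresponding link graph.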
+ forward_weights = self._linkage.directional_read_weights( + link, prev_read_weights, is_forward=True + ) + backward_weights = self._linkage.directional_read_weights( + link, prev_read_weights, is_forward=False + ) + + m = self._num_writes + backward_mode = inputs["read_mode"][:, :, :m, None] + forward_mode = inputs["read_mode"][:, :, m : 2 * m, None] + content_mode = inputs["read_mode"][:, :, None, 2 * m] + + read_weights = ( + content_mode * content_weights + + (forward_mode * forward_weights).sum(dim=2) + + (backward_mode * backward_weights).sum(dim=2) + ) + return read_weights + + def forward( + self, + inputs: Tensor, + prev_state: Optional[AccessState] = None, + ) -> Tuple[Tensor, AccessState]: + """Connects the MemoryAccess module into the graph. + + Args: + inputs: tensor of shape `[batch_size, input_size]`. This is used to + control this access module. + prev_state: Instance of `AccessState` containing the previous state. + + Returns: + A tuple `(output, next_state)`, where `output` is a tensor of shape + `[batch_size, num_reads, word_size]`, and `next_state` is the new + `AccessState` named tuple at the current time t. + """ + if inputs.ndim != 2: + raise ValueError("Expected `inputs` to be 2D. Found: {inputs.ndim}.") + if prev_state is None: + B, device = inputs.shape[0], inputs.device + prev_state = AccessState( + memory=torch.zeros( + (B, self._memory_size, self._word_size), device=device + ), + read_weights=torch.zeros( + (B, self._num_reads, self._memory_size), device=device + ), + write_weights=torch.zeros( + (B, self._num_writes, self._memory_size), device=device + ), + linkage=None, + usage=torch.zeros((B, self._memory_size), device=device), + ) + + inputs = self._read_inputs(inputs) + + # Update usage using inputs['free_gate'] and previous read & write weights. + usage = self._freeness( + write_weights=prev_state.write_weights, + free_gate=inputs["free_gate"], + read_weights=prev_state.read_weights, + prev_usage=prev_state.usage, + ) + + # Write to memory. + write_weights = self._write_weights(inputs, prev_state.memory, usage) + memory = _erase_and_write( + prev_state.memory, + address=write_weights, + reset_weights=inputs["erase_vectors"], + values=inputs["write_vectors"], + ) + + linkage_state = self._linkage(write_weights, prev_state.linkage) + + # Read from memory. + read_weights = self._read_weights( + inputs, + memory=memory, + prev_read_weights=prev_state.read_weights, + link=linkage_state.link, + ) + read_words = torch.matmul(read_weights, memory) + + return ( + read_words, + AccessState( + memory=memory, + read_weights=read_weights, + write_weights=write_weights, + linkage=linkage_state, + usage=usage, + ), + ) diff --git a/dnc/addressing.py b/dnc/addressing.py new file mode 100644 index 0000000..de7a262 --- /dev/null +++ b/dnc/addressing.py @@ -0,0 +1,368 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import Tensor +import torch.nn.functional as F + + +@dataclass +class TemporalLinkageState: + link: Tensor # [batch_size, num_writes, memory_size, memory_size] + precedence_weights: Tensor # [batch_size, num_writes, memory_size] + + +_EPSILON = 1e-6 + + +def _vector_norms(m: Tensor) -> Tensor: + norm = torch.sum(m * m, axis=2, keepdim=True) + return torch.sqrt(norm + _EPSILON) + + +def weighted_softmax(activations: Tensor, strengths: Tensor, strengths_op=F.softplus): + """Returns softmax over activations multiplied by positive strengths. 
+ + Args: + activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of + activations to be transformed. Softmax is taken over the last dimension. + strengths: A tensor of shape `[batch_size, num_heads]` containing strengths to + multiply by the activations prior to the softmax. + strengths_op: An operation to transform strengths before softmax. + + Returns: + A tensor of same shape as `activations` with weighted softmax applied. + """ + transformed_strengths = strengths_op(strengths).unsqueeze(-1) + sharp_activations = activations * transformed_strengths + softmax = F.softmax(sharp_activations, dim=-1) + return softmax + + +class CosineWeights(torch.nn.Module): + def __init__(self, num_heads, word_size, strength_op=F.softplus): + """ + Args: + num_heads: number of memory heads. + word_size: memory word size. + strength_op: operation to apply to strengths (default softplus). + """ + super().__init__() + self._num_heads = num_heads + self._word_size = word_size + self._strength_op = strength_op + + def forward(self, memory: Tensor, keys: Tensor, strengths: Tensor) -> Tensor: + """ + + Args: + memory: A 3-D tensor of shape `[batch_size, memory_size, word_size]`. + keys: A 3-D tensor of shape `[batch_size, num_heads, word_size]`. + strengths: A 2-D tensor of shape `[batch_size, num_heads]`. + + Returns: + Weights tensor of shape `[batch_size, num_heads, memory_size]`. + """ + dot = torch.matmul(keys, memory.transpose(-1, -2)) # + memory_norms = _vector_norms(memory) # + key_norms = _vector_norms(keys) # + norm = torch.matmul(key_norms, memory_norms.transpose(-1, -2)) # + + similarity = dot / (norm + _EPSILON) + + return weighted_softmax(similarity, strengths, self._strength_op) + + +class TemporalLinkage(torch.nn.Module): + def __init__(self, memory_size, num_writes): + super().__init__() + self._memory_size = memory_size + self._num_writes = num_writes + + def _link( + self, + prev_link: Tensor, + prev_precedence_weights: Tensor, + write_weights: Tensor, + ) -> Tensor: + """Calculates the new link graphs. + + For each write head, the link is a directed graph (represented by a matrix + with entries in range [0, 1]) whose vertices are the memory locations, and + an edge indicates temporal ordering of writes. + + Args: + prev_link: A tensor of shape `[batch_size, num_writes, memory_size, + memory_size]` representing the previous link graphs for each write + head. + prev_precedence_weights: A tensor of shape `[batch_size, num_writes, + memory_size]` which is the previous "aggregated" write weights for + each write head. + write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` + containing the new locations in memory written to. + + Returns: + A tensor of shape `[batch_size, num_writes, memory_size, memory_size]` + containing the new link graphs for each write head. + """ + batch_size = prev_link.size(0) + write_weights_i = write_weights.unsqueeze(3) # + write_weights_j = write_weights.unsqueeze(2) # + prev_precedence_weights_j = prev_precedence_weights.unsqueeze(2) # + + prev_link_scale = 1 - write_weights_i - write_weights_j # + new_link = write_weights_i * prev_precedence_weights_j # + link = prev_link_scale * prev_link + new_link # + + # Return the link with the diagonal set to zero, to remove self-looping + # edges. 
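+        # An identity mask, repeated over the batch and write-head dimensions,
+        # selects the diagonal entries of each link matrix so they can be
+        # zeroed in place below.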
+ mask = ( + torch.eye(self._memory_size) + .repeat(batch_size, self._num_writes, 1, 1) + .bool() + ) + link[mask] = 0 + return link + + def _precedence_weights( + self, + prev_precedence_weights: Tensor, + write_weights: Tensor, + ) -> Tensor: + """Calculates the new precedence weights given the current write weights. + + The precedence weights are the "aggregated write weights" for each write + head, where write weights with sum close to zero will leave the precedence + weights unchanged, but with sum close to one will replace the precedence + weights. + + Args: + prev_precedence_weights: A tensor of shape `[batch_size, num_writes, + memory_size]` containing the previous precedence weights. + write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` + containing the new write weights. + + Returns: + A tensor of shape `[batch_size, num_writes, memory_size]` containing the + new precedence weights. + """ + write_sum = write_weights.sum(dim=2, keepdim=True) + return (1 - write_sum) * prev_precedence_weights + write_weights + + def forward( + self, + write_weights: Tensor, + prev_state: Optional[TemporalLinkageState] = None, + ) -> TemporalLinkageState: + """Calculate the updated linkage state given the write weights. + + Args: + write_weights: A tensor of shape `[batch_size, num_writes, memory_size]` + containing the memory addresses of the different write heads. + prev_state: `TemporalLinkageState` tuple containg a tensor `link` of + shape `[batch_size, num_writes, memory_size, memory_size]`, and a + tensor `precedence_weights` of shape `[batch_size, num_writes, + memory_size]` containing the aggregated history of recent writes. + + Returns: + A `TemporalLinkageState` tuple `next_state`, which contains the updated + link and precedence weights. + """ + if write_weights.ndim != 3: + raise ValueError( + f"Expected `write_weights` to be 3D. Found: {write_weights.ndim}" + ) + if ( + write_weights.size(1) != self._num_writes + or write_weights.size(2) != self._memory_size + ): + raise ValueError( + "Expected the shape of `write_weights` to be " + f"[batch, {self._num_writes}, {self._memory_size}]. " + f"Found: {write_weights.shape}." + ) + + if prev_state is None: + B, W, M = write_weights.shape + prev_state = TemporalLinkageState( + link=torch.zeros((B, W, M, M), device=write_weights.device), + precedence_weights=torch.zeros((B, W, M), device=write_weights.device), + ) + + link = self._link(prev_state.link, prev_state.precedence_weights, write_weights) + precedence_weights = self._precedence_weights( + prev_state.precedence_weights, write_weights + ) + return TemporalLinkageState(link=link, precedence_weights=precedence_weights) + + def directional_read_weights( + self, + link: Tensor, + prev_read_weights: Tensor, + is_forward: bool, + ) -> Tensor: + """Calculates the forward or the backward read weights. + + For each read head (at a given address), there are `num_writes` link graphs + to follow. Thus this function computes a read address for each of the + `num_reads * num_writes` pairs of read and write heads. + + Args: + link: tensor of shape `[batch_size, num_writes, memory_size, + memory_size]` representing the link graphs L_t. + prev_read_weights: tensor of shape `[batch_size, num_reads, + memory_size]` containing the previous read weights w_{t-1}^r. + forward: Boolean indicating whether to follow the "future" direction in + the link graph (True) or the "past" direction (False). 
+ + Returns: + tensor of shape `[batch_size, num_reads, num_writes, memory_size]` + """ + # + expanded_read_weights = torch.stack( + [prev_read_weights for _ in range(self._num_writes)], dim=1 + ) + if is_forward: + link = link.transpose(-1, -2) + result = torch.matmul(expanded_read_weights, link) # + return result.permute((0, 2, 1, 3)) # + + +class Freeness(torch.nn.Module): + def __init__(self, memory_size): + super().__init__() + self._memory_size = memory_size + + def _usage_after_write( + self, + prev_usage: Tensor, + write_weights: Tensor, + ) -> Tensor: + """Calcualtes the new usage after writing to memory. + + Args: + prev_usage: tensor of shape `[batch_size, memory_size]`. + write_weights: tensor of shape `[batch_size, num_writes, memory_size]`. + + Returns: + New usage, a tensor of shape `[batch_size, memory_size]`. + """ + write_weights = 1 - torch.prod(1 - write_weights, 1) + return prev_usage + (1 - prev_usage) * write_weights + + def _usage_after_read( + self, prev_usage: Tensor, free_gate: Tensor, read_weights: Tensor + ) -> Tensor: + """Calcualtes the new usage after reading and freeing from memory. + + Args: + prev_usage: tensor of shape `[batch_size, memory_size]`. + free_gate: tensor of shape `[batch_size, num_reads]` with entries in the + range [0, 1] indicating the amount that locations read from can be + freed. + read_weights: tensor of shape `[batch_size, num_reads, memory_size]`. + + Returns: + New usage, a tensor of shape `[batch_size, memory_size]`. + """ + free_gate = free_gate.unsqueeze(-1) + free_read_weights = free_gate * read_weights + phi = torch.prod(1 - free_read_weights, 1) + return prev_usage * phi + + def forward( + self, + write_weights: Tensor, + free_gate: Tensor, + read_weights: Tensor, + prev_usage: Tensor, + ) -> Tensor: + """Calculates the new memory usage u_t. + + Memory that was written to in the previous time step will have its usage + increased; memory that was read from and the controller says can be "freed" + will have its usage decreased. + + Args: + write_weights: tensor of shape `[batch_size, num_writes, + memory_size]` giving write weights at previous time step. + free_gate: tensor of shape `[batch_size, num_reads]` which indicates + which read heads read memory that can now be freed. + read_weights: tensor of shape `[batch_size, num_reads, + memory_size]` giving read weights at previous time step. + prev_usage: tensor of shape `[batch_size, memory_size]` giving + usage u_{t - 1} at the previous time step, with entries in range + [0, 1]. + + Returns: + tensor of shape `[batch_size, memory_size]` representing updated memory + usage. + """ + with torch.no_grad(): + usage = self._usage_after_write(prev_usage, write_weights) + usage = self._usage_after_read(usage, free_gate, read_weights) + return usage + + def _allocation(self, usage: Tensor) -> Tensor: + """Computes allocation by sorting `usage`. + + This corresponds to the value a = a_t[\phi_t[j]] in the paper. + + Args: + usage: tensor of shape `[batch_size, memory_size]` indicating current + memory usage. This is equal to u_t in the paper when we only have one + write head, but for multiple write heads, one should update the usage + while iterating through the write heads to take into account the + allocation returned by this function. + + Returns: + Tensor of shape `[batch_size, memory_size]` corresponding to allocation. 
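+
+        Example (illustrative numbers, ignoring the epsilon stabilisation): for
+        a single batch element with usage [0.9, 0.1, 0.5], the slots ordered by
+        non-usage are [1, 2, 0]; each slot gets (1 - usage) times the product
+        of the usages of all freer slots, i.e. a[1] = 0.9,
+        a[2] = 0.5 * 0.1 = 0.05 and a[0] = 0.1 * 0.1 * 0.5 = 0.005, before
+        being scattered back to the original ordering.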
+ """ + usage = _EPSILON + (1 - _EPSILON) * usage + + nonusage = 1 - usage + sorted_nonusage, indices = torch.topk(nonusage, k=self._memory_size) + sorted_usage = 1 - sorted_nonusage + + # emulate tf.cumprod(exclusive=True) + sorted_usage = F.pad(sorted_usage, (1, 0), mode="constant", value=1) + prod_sorted_usage = torch.cumprod(sorted_usage, dim=1) + prod_sorted_usage = prod_sorted_usage[:, :-1] + + sorted_allocation = sorted_nonusage * prod_sorted_usage + inverse_indices = torch.argsort(indices) + return torch.gather(sorted_allocation, 1, inverse_indices) + + def write_allocation_weights( + self, + usage: Tensor, + write_gates: Tensor, + num_writes: Tensor, + ) -> Tensor: + """Calculates freeness-based locations for writing to. + + This finds unused memory by ranking the memory locations by usage, for each + write head. (For more than one write head, we use a "simulated new usage" + which takes into account the fact that the previous write head will increase + the usage in that area of the memory.) + + Args: + usage: A tensor of shape `[batch_size, memory_size]` representing + current memory usage. + write_gates: A tensor of shape `[batch_size, num_writes]` with values in + the range [0, 1] indicating how much each write head does writing + based on the address returned here (and hence how much usage + increases). + num_writes: The number of write heads to calculate write weights for. + + Returns: + tensor of shape `[batch_size, num_writes, memory_size]` containing the + freeness-based write locations. Note that this isn't scaled by + `write_gate`; this scaling must be applied externally. + """ + write_gates = write_gates.unsqueeze(-1) + allocation_weights = [] + for i in range(num_writes): + allocation_weights.append(self._allocation(usage)) + usage = usage + (1 - usage) * write_gates[:, i, :] * allocation_weights[i] + return torch.stack(allocation_weights, dim=1) diff --git a/dnc/dnc.py b/dnc/dnc.py new file mode 100644 index 0000000..fa0abb8 --- /dev/null +++ b/dnc/dnc.py @@ -0,0 +1,129 @@ +from dataclasses import dataclass +from typing import Tuple, Optional + +import torch +from torch import Tensor + +from .access import MemoryAccess, AccessState + + +@dataclass +class DNCState: + access_output: Tensor # [batch_size, num_reads, word_size] + access_state: AccessState + # memory: Tensor [batch_size, memory_size, word_size] + # read_weights: Tensor # [batch_size, num_reads, memory_size] + # write_weights: Tensor # [batch_size, num_writes, memory_size] + # linkage: TemporalLinkageState + # link: [batch_size, num_writes, memory_size, memory_size] + # precedence_weights: [batch_size, num_writes, memory_size] + # usage: Tensor # [batch_size, memory_size] + controller_state: Tuple[Tensor, Tensor] + # h_n: [num_layers, batch_size, projection_size] + # c_n: [num_layers, batch_size, hidden_size] + + +class DNC(torch.nn.Module): + """DNC core module. + + Contains controller and memory access module. + """ + + def __init__( + self, + access_config, + controller_config, + output_size, + clip_value=None, + ): + """Initializes the DNC core. + + Args: + access_config: dictionary of access module configurations. + controller_config: dictionary of controller (LSTM) module configurations. + output_size: output dimension size of core. + clip_value: clips controller and core output values to between + `[-clip_value, clip_value]` if specified. + Raises: + TypeError: if direct_input_size is not None for any access module other + than KeyValueMemory. 
+ """ + super().__init__() + + self._controller = torch.nn.LSTMCell(**controller_config) + self._access = MemoryAccess(**access_config) + self._output = torch.nn.LazyLinear(output_size) + if clip_value is None: + self._clip = lambda x: x + else: + self._clip = lambda x: torch.clamp(x, min=-clip_value, max=clip_value) + + def forward(self, inputs: Tensor, prev_state: Optional[DNCState] = None): + """Connects the DNC core into the graph. + + Args: + inputs: Tensor input. + prev_state: A `DNCState` tuple containing the fields `access_output`, + `access_state` and `controller_state`. `access_state` is a 3-D Tensor + of shape `[batch_size, num_reads, word_size]` containing read words. + `access_state` is a tuple of the access module's state, and + `controller_state` is a tuple of controller module's state. + + Returns: + A tuple `(output, next_state)` where `output` is a tensor and `next_state` + is a `DNCState` tuple containing the fields `access_output`, + `access_state`, and `controller_state`. + """ + if inputs.ndim != 2: + raise ValueError(f"Expected `inputs` to be 2D: Found {inputs.ndim}.") + if prev_state is None: + B, device = inputs.shape[0], inputs.device + num_reads = self._access._num_reads + word_size = self._access._word_size + prev_state = DNCState( + access_output=torch.zeros((B, num_reads, word_size), device=device), + access_state=None, + controller_state=None, + ) + + def batch_flatten(x): + return torch.reshape(x, [x.size(0), -1]) + + controller_input = torch.concat( + [ + batch_flatten(inputs), # [batch_size, num_input_feats] + batch_flatten( + prev_state.access_output + ), # [batch_size, num_reads*word_size] + ], + dim=1, + ) # [batch_size, num_input_feats + num_reads * word_size] + + controller_state = self._controller( + controller_input, prev_state.controller_state + ) + controller_state = tuple(self._clip(t) for t in controller_state) + controller_output = controller_state[0] + + access_output, access_state = self._access( + controller_output, prev_state.access_state + ) + + output = torch.concat( + [ + controller_output, # [batch_size, num_ctrl_feats] + batch_flatten(access_output), # [batch_size, num_reads*word_size] + ], + dim=1, + ) + output = self._output(output) + output = self._clip(output) + + return ( + output, + DNCState( + access_output=access_output, + access_state=access_state, + controller_state=controller_state, + ), + ) diff --git a/dnc/repeat_copy.py b/dnc/repeat_copy.py new file mode 100644 index 0000000..a571eee --- /dev/null +++ b/dnc/repeat_copy.py @@ -0,0 +1,383 @@ +from dataclasses import dataclass + +import torch +from torch import Tensor +import torch.nn.functional as F + + +@dataclass +class DatasetTensors: + observations: Tensor + target: Tensor + mask: Tensor + + +def masked_sigmoid_cross_entropy( + logits, + target, + mask, + time_average=False, + log_prob_in_bits=False, +): + """Adds ops to graph which compute the (scalar) NLL of the target sequence. + + The logits parametrize independent bernoulli distributions per time-step and + per batch element, and irrelevant time/batch elements are masked out by the + mask tensor. + + Args: + logits: `Tensor` of activations for which sigmoid(`logits`) gives the + bernoulli parameter. + target: time-major `Tensor` of target. + mask: time-major `Tensor` to be multiplied elementwise with cost T x B cost + masking out irrelevant time-steps. + time_average: optionally average over the time dimension (sum by default). + log_prob_in_bits: iff True express log-probabilities in bits (default nats). 
+
+    Returns:
+      A scalar `Tensor` with the masked negative log-likelihood of the target
+      sequence (the training loss).
+    """
+    batch_size = logits.shape[1]
+
+    xent = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
+    loss_time_batch = xent.sum(dim=2)
+    loss_batch = (loss_time_batch * mask).sum(dim=0)
+
+    if time_average:
+        mask_count = mask.sum(dim=0)
+        loss_batch /= mask_count + torch.finfo(loss_batch.dtype).eps
+
+    loss = loss_batch.sum() / batch_size
+    if log_prob_in_bits:
+        # `torch.log` expects a tensor argument; this converts nats to bits.
+        loss /= torch.log(torch.tensor(2.0))
+    return loss
+
+
+def bitstring_readable(data, batch_size, model_output=None, whole_batch=False):
+    """Produce a human readable representation of the sequences in data.
+
+    Args:
+      data: data to be visualised
+      batch_size: size of batch
+      model_output: optional model output tensor to visualize alongside data.
+      whole_batch: whether to visualise the whole batch. Only the first sample
+          will be visualized if False
+
+    Returns:
+      A string used to visualise the data batch
+    """
+
+    def _readable(datum):
+        return "+" + " ".join([f"{int(x):d}" if x else "-" for x in datum]) + "+"
+
+    obs_batch = data.observations
+    targ_batch = data.target
+
+    batch_strings = []
+    for batch_index in range(batch_size if whole_batch else 1):
+        obs = obs_batch[:, batch_index, :]
+        targ = targ_batch[:, batch_index, :]
+
+        obs_channels = range(obs.shape[1])
+        targ_channels = range(targ.shape[1])
+        obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels]
+        targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels]
+
+        readable_obs = "Observations:\n" + "\n".join(obs_channel_strings)
+        readable_targ = "Targets:\n" + "\n".join(targ_channel_strings)
+        strings = [readable_obs, readable_targ]
+
+        if model_output is not None:
+            output = model_output[:, batch_index, :]
+            output_strings = [_readable(output[:, i]) for i in targ_channels]
+            strings.append("Model Output:\n" + "\n".join(output_strings))
+
+        batch_strings.append("\n\n".join(strings))
+
+    return "\n" + "\n\n\n\n".join(batch_strings)
+
+
+class RepeatCopy:
+    """Sequence data generator for the task of repeating a random binary pattern.
+
+    When called, an instance of this class returns a `DatasetTensors` tuple
+    `(observations, target, mask)`, representing an input sequence, target
+    sequence, and binary mask. The first two dimensions of each tensor
+    represent sequence position and batch index respectively. The value in
+    mask[t, b] is equal to 1 iff a prediction about targ[t, b, :] should be
+    penalized and 0 otherwise.
+
+    For each realisation from this generator, the observation sequence is
+    comprised of I.I.D. uniform-random binary vectors (and some flags).
+
+    The target sequence is comprised of this binary pattern repeated
+    some number of times (and some flags). Before explaining in more detail,
+    let's examine the setup pictorially for a single batch element:
+
+    ```none
+    Note: blank space represents 0.
+
+    time ------------------------------------------>
+
+                  +-------------------------------+
+    mask:         |0000000001111111111111111111111|
+                  +-------------------------------+
+
+                  +-------------------------------+
+    target:       |                              1| 'end-marker' channel.
+                  |         101100110110011011001 |
+                  |         010101001010100101010 |
+                  +-------------------------------+
+
+                  +-------------------------------+
+    observation:  | 1011001                       |
+                  | 0101010                       |
+                  |1                              | 'start-marker' channel
+                  |        3                      | 'num-repeats' channel.
+ +-------------------------------+ + ``` + + The length of the random pattern and the number of times it is repeated + in the target are both discrete random variables distributed according to + uniform distributions whose parameters are configured at construction time. + + The obs sequence has two extra channels (components in the trailing dimension) + which are used for flags. One channel is marked with a 1 at the first time + step and is otherwise equal to 0. The other extra channel is zero until the + binary pattern to be repeated ends. At this point, it contains an encoding of + the number of times the observation pattern should be repeated. Rather than + simply providing this integer number directly, it is normalised so that + a neural network may have an easier time representing the number of + repetitions internally. To allow a network to be readily evaluated on + instances of this task with greater numbers of repetitions, the range with + respect to which this encoding is normalised is also configurable by the user. + + As in the diagram, the target sequence is offset to begin directly after the + observation sequence; both sequences are padded with zeros to accomplish this, + resulting in their lengths being equal. Additional padding is done at the end + so that all sequences in a minibatch represent tensors with the same shape. + """ + + def __init__( + self, + num_bits=6, + batch_size=1, + min_length=1, + max_length=1, + min_repeats=1, + max_repeats=2, + norm_max=10, + log_prob_in_bits=False, + time_average_cost=False, + ): + """Creates an instance of RepeatCopy task. + + Args: + name: A name for the generator instance (for name scope purposes). + num_bits: The dimensionality of each random binary vector. + batch_size: Minibatch size per realization. + min_length: Lower limit on number of random binary vectors in the + observation pattern. + max_length: Upper limit on number of random binary vectors in the + observation pattern. + min_repeats: Lower limit on number of times the obervation pattern + is repeated in targ. + max_repeats: Upper limit on number of times the observation pattern + is repeated in targ. + norm_max: Upper limit on uniform distribution w.r.t which the encoding + of the number of repetitions presented in the observation sequence + is normalised. + log_prob_in_bits: By default, log probabilities are expressed in units of + nats. If true, express log probabilities in bits. + time_average_cost: If true, the cost at each time step will be + divided by the `true`, sequence length, the number of non-masked time + steps, in each sequence before any subsequent reduction over the time + and batch dimensions. 
+ """ + super().__init__() + + self._batch_size = batch_size + self._num_bits = num_bits + self._min_length = min_length + self._max_length = max_length + self._min_repeats = min_repeats + self._max_repeats = max_repeats + self._norm_max = norm_max + self._log_prob_in_bits = log_prob_in_bits + self._time_average_cost = time_average_cost + + def _normalise(self, val): + return val / self._norm_max + + def _unnormalise(self, val): + return val * self._norm_max + + @property + def time_average_cost(self): + return self._time_average_cost + + @property + def log_prob_in_bits(self): + return self._log_prob_in_bits + + @property + def num_bits(self): + """The dimensionality of each random binary vector in a pattern.""" + return self._num_bits + + @property + def target_size(self): + """The dimensionality of the target tensor.""" + return self._num_bits + 1 + + @property + def batch_size(self): + return self._batch_size + + def __call__(self): + """Implements build method which adds ops to graph.""" + + # short-hand for private fields. + min_length, max_length = self._min_length, self._max_length + min_reps, max_reps = self._min_repeats, self._max_repeats + num_bits = self.num_bits + batch_size = self.batch_size + + # We reserve one dimension for the num-repeats and one for the start-marker. + full_obs_size = num_bits + 2 + # We reserve one target dimension for the end-marker. + full_targ_size = num_bits + 1 + start_end_flag_idx = full_obs_size - 2 + num_repeats_channel_idx = full_obs_size - 1 + + # Samples each batch index's sequence length and the number of repeats. + sub_seq_length_batch = torch.randint( + low=min_length, high=max_length + 1, size=[batch_size], dtype=torch.int32 + ) + num_repeats_batch = torch.randint( + low=min_reps, high=max_reps + 1, size=[batch_size], dtype=torch.int32 + ) + + # Pads all the batches to have the same total sequence length. + total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3 + max_length_batch = total_length_batch.max() + residual_length_batch = max_length_batch - total_length_batch + + obs_batch_shape = [max_length_batch, batch_size, full_obs_size] + targ_batch_shape = [max_length_batch, batch_size, full_targ_size] + mask_batch_trans_shape = [batch_size, max_length_batch] + + obs_tensors = [] + targ_tensors = [] + mask_tensors = [] + + # Generates patterns for each batch element independently. + for batch_index in range(batch_size): + sub_seq_len = sub_seq_length_batch[batch_index] + num_reps = num_repeats_batch[batch_index] + + # The observation pattern is a sequence of random binary vectors. + obs_pattern_shape = [sub_seq_len, num_bits] + obs_pattern = torch.randint(low=0, high=2, size=obs_pattern_shape).to( + torch.float32 + ) + + # The target pattern is the observation pattern repeated n times. + # Some reshaping is required to accomplish the tiling. + targ_pattern_shape = [sub_seq_len * num_reps, num_bits] + flat_obs_pattern = obs_pattern.reshape([-1]) + flat_targ_pattern = torch.tile(flat_obs_pattern, [num_reps]) + targ_pattern = flat_targ_pattern.reshape(targ_pattern_shape) + + # Expand the obs_pattern to have two extra channels for flags. + # Concatenate start flag and num_reps flag to the sequence. 
+ obs_flag_channel_pad = torch.zeros([sub_seq_len, 2]) + obs_start_flag = F.one_hot( + torch.tensor([start_end_flag_idx]), num_classes=full_obs_size + ).to(torch.float32) + num_reps_flag = F.one_hot( + torch.tensor([num_repeats_channel_idx]), num_classes=full_obs_size + ).to(torch.float32) + num_reps_flag *= self._normalise(num_reps) + + # note the concatenation dimensions. + obs = torch.concat([obs_pattern, obs_flag_channel_pad], 1) + obs = torch.concat([obs_start_flag, obs], 0) + obs = torch.concat([obs, num_reps_flag], 0) + + # Now do the same for the targ_pattern (it only has one extra channel). + targ_flag_channel_pad = torch.zeros([sub_seq_len * num_reps, 1]) + targ_end_flag = F.one_hot( + torch.tensor([start_end_flag_idx]), num_classes=full_targ_size + ).to(torch.float32) + targ = torch.concat([targ_pattern, targ_flag_channel_pad], 1) + targ = torch.concat([targ, targ_end_flag], 0) + + # Concatenate zeros at end of obs and begining of targ. + # This aligns them s.t. the target begins as soon as the obs ends. + obs_end_pad = torch.zeros([sub_seq_len * num_reps + 1, full_obs_size]) + targ_start_pad = torch.zeros([sub_seq_len + 2, full_targ_size]) + + # The mask is zero during the obs and one during the targ. + mask_off = torch.zeros([sub_seq_len + 2]) + mask_on = torch.ones([sub_seq_len * num_reps + 1]) + + obs = torch.concat([obs, obs_end_pad], 0) + targ = torch.concat([targ_start_pad, targ], 0) + mask = torch.concat([mask_off, mask_on], 0) + + obs_tensors.append(obs) + targ_tensors.append(targ) + mask_tensors.append(mask) + + # End the loop over batch index. + # Compute how much zero padding is needed to make tensors sequences + # the same length for all batch elements. + residual_obs_pad = [ + torch.zeros([residual_length_batch[i], full_obs_size]) + for i in range(batch_size) + ] + residual_targ_pad = [ + torch.zeros([residual_length_batch[i], full_targ_size]) + for i in range(batch_size) + ] + residual_mask_pad = [ + torch.zeros([residual_length_batch[i]]) for i in range(batch_size) + ] + + # Concatenate the pad to each batch element. + obs_tensors = [ + torch.concat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad) + ] + targ_tensors = [ + torch.concat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad) + ] + mask_tensors = [ + torch.concat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad) + ] + + # Concatenate each batch element into a single tensor. 
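+        # Each padded per-batch tensor is [max_length_batch, feature_size];
+        # concatenating along dim 1 and reshaping to
+        # [max_length_batch, batch_size, feature_size] yields the time-major
+        # layout used throughout, and the mask is built batch-major and then
+        # transposed.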
+ obs = torch.concat(obs_tensors, 1).reshape(obs_batch_shape) + targ = torch.concat(targ_tensors, 1).reshape(targ_batch_shape) + mask = torch.concat(mask_tensors, 0).reshape(mask_batch_trans_shape).T + return DatasetTensors(obs, targ, mask) + + def cost(self, logits, targ, mask): + return masked_sigmoid_cross_entropy( + logits, + targ, + mask, + time_average=self.time_average_cost, + log_prob_in_bits=self.log_prob_in_bits, + ) + + def to_human_readable(self, data, model_output=None, whole_batch=False): + obs = data.observations + unnormalised_num_reps_flag = self._unnormalise(obs[:, :, -1:]).round() + obs = torch.cat([obs[:, :, :-1], unnormalised_num_reps_flag], dim=2) + data = DatasetTensors( + observations=obs, + target=data.target, + mask=data.mask, + ) + return bitstring_readable(data, self.batch_size, model_output, whole_batch) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7a2e4b6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +torch > 1.9 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/access_test.py b/tests/access_test.py new file mode 100644 index 0000000..043b3c1 --- /dev/null +++ b/tests/access_test.py @@ -0,0 +1,155 @@ +import torch + +from dnc.access import MemoryAccess, AccessState +from dnc.addressing import TemporalLinkageState + +from .util import one_hot + +BATCH_SIZE = 2 +MEMORY_SIZE = 20 +WORD_SIZE = 6 +NUM_READS = 2 +NUM_WRITES = 3 +TIME_STEPS = 4 +INPUT_SIZE = 10 + + +def test_memory_access_build_and_train(): + module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES) + inputs = torch.randn((TIME_STEPS, BATCH_SIZE, INPUT_SIZE)) + + outputs = [] + state = None + for input in inputs: + output, state = module(input, state) + outputs.append(output) + outputs = torch.stack(outputs, dim=-1) + + targets = torch.rand((TIME_STEPS, BATCH_SIZE, NUM_READS, WORD_SIZE)) + loss = (output - targets).square().mean() + + optim = torch.optim.SGD(module.parameters(), lr=1) + optim.zero_grad() + loss.backward() + optim.step() + + +def test_memory_access_valid_read_mode(): + module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES) + inputs = module._read_inputs(torch.randn((BATCH_SIZE, INPUT_SIZE))) + + # Check that the read modes for each read head constitute a probability + # distribution. + torch.testing.assert_close( + inputs["read_mode"].sum(2), torch.ones([BATCH_SIZE, NUM_READS]) + ) + assert torch.all(inputs["read_mode"] >= 0) + + +def test_memory_access_write_weights(): + memory = 10 * (torch.rand((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE)) - 0.5) + usage = torch.rand((BATCH_SIZE, MEMORY_SIZE)) + + allocation_gate = torch.rand((BATCH_SIZE, NUM_WRITES)) + write_gate = torch.rand((BATCH_SIZE, NUM_WRITES)) + + write_gate = torch.rand((BATCH_SIZE, NUM_WRITES)) + write_content_keys = torch.rand((BATCH_SIZE, NUM_WRITES, WORD_SIZE)) + write_content_strengths = torch.rand((BATCH_SIZE, NUM_WRITES)) + + # Check that turning on allocation gate fully brings the write gate to + # the allocation weighting (which we will control by controlling the usage). 
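+    # Slot 3 is forced to zero usage so it becomes the most allocatable
+    # location; with allocation_gate = write_gate = 1 for head 0, the write
+    # weights should collapse onto that slot (verified below).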
+ usage[:, 3] = 0 + allocation_gate[:, 0] = 1 + write_gate[:, 0] = 1 + + inputs = { + "allocation_gate": allocation_gate, + "write_gate": write_gate, + "write_content_keys": write_content_keys, + "write_content_strengths": write_content_strengths, + } + + module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES) + weights = module._write_weights(inputs, memory, usage) + + # Check the weights sum to their target gating. + torch.testing.assert_close(weights.sum(dim=2), write_gate, atol=5e-2, rtol=0) + + # Check that we fully allocated to the third row. + weights_0_0_target = one_hot(MEMORY_SIZE, 3, dtype=torch.float32) + torch.testing.assert_close(weights[0, 0], weights_0_0_target, atol=1e-3, rtol=0) + + +def test_memory_access_read_weights(): + memory = 10 * (torch.rand((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE)) - 0.5) + prev_read_weights = torch.rand((BATCH_SIZE, NUM_READS, MEMORY_SIZE)) + prev_read_weights /= prev_read_weights.sum(2, keepdim=True) + 1 + + link = torch.rand((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE)) + # Row and column sums should be at most 1: + link /= torch.maximum(link.sum(2, keepdim=True), torch.tensor(1)) + link /= torch.maximum(link.sum(3, keepdim=True), torch.tensor(1)) + + # We query the memory on the third location in memory, and select a large + # strength on the query. Then we select a content-based read-mode. + read_content_keys = torch.rand((BATCH_SIZE, NUM_READS, WORD_SIZE)) + read_content_keys[0, 0] = memory[0, 3] + read_content_strengths = torch.full( + (BATCH_SIZE, NUM_READS), 100.0, dtype=torch.float64 + ) + read_mode = torch.rand((BATCH_SIZE, NUM_READS, 1 + 2 * NUM_WRITES)) + read_mode[0, 0, :] = one_hot(1 + 2 * NUM_WRITES, 2 * NUM_WRITES) + inputs = { + "read_content_keys": read_content_keys, + "read_content_strengths": read_content_strengths, + "read_mode": read_mode, + } + + module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES) + read_weights = module._read_weights(inputs, memory, prev_read_weights, link) + + # read_weights for batch 0, read head 0 should be memory location 3 + ref = one_hot(MEMORY_SIZE, 3, dtype=torch.float64) + torch.testing.assert_close(read_weights[0, 0, :], ref, atol=1e-3, rtol=0) + + +def test_memory_access_gradients(): + kwargs = {"dtype": torch.float64, "requires_grad": True} + + inputs = torch.randn((BATCH_SIZE, INPUT_SIZE), **kwargs) + memory = torch.zeros((BATCH_SIZE, MEMORY_SIZE, WORD_SIZE), **kwargs) + read_weights = torch.zeros((BATCH_SIZE, NUM_READS, MEMORY_SIZE), **kwargs) + precedence_weights = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE), **kwargs) + link = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE, MEMORY_SIZE), **kwargs) + + module = MemoryAccess(MEMORY_SIZE, WORD_SIZE, NUM_READS, NUM_WRITES) + + class Wrapper(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inputs, memory, read_weights, link, precedence_weights): + write_weights = torch.zeros((BATCH_SIZE, NUM_WRITES, MEMORY_SIZE)) + usage = torch.zeros((BATCH_SIZE, MEMORY_SIZE)) + + state = AccessState( + memory=memory, + read_weights=read_weights, + write_weights=write_weights, + linkage=TemporalLinkageState( + link=link, + precedence_weights=precedence_weights, + ), + usage=usage, + ) + output, _ = self.module(inputs, state) + return output.sum() + + module = Wrapper(module).to(torch.float64) + + torch.autograd.gradcheck( + module, + (inputs, memory, read_weights, link, precedence_weights), + ) diff --git a/tests/addressing_test.py 
b/tests/addressing_test.py new file mode 100644 index 0000000..f9c6d26 --- /dev/null +++ b/tests/addressing_test.py @@ -0,0 +1,338 @@ +import numpy as np +import torch +import torch.nn.functional as F +from dnc.addressing import weighted_softmax, CosineWeights, TemporalLinkage, Freeness + +from .util import one_hot + + +def _test_weighted_softmax(strength_op): + batch_size, num_heads, memory_size = 5, 3, 7 + activations = torch.randn(batch_size, num_heads, memory_size) + weights = torch.ones((batch_size, num_heads)) + + observed = weighted_softmax(activations, weights, strength_op) + expected = torch.stack( + [ + F.softmax(a * strength_op(w).unsqueeze(-1)) + for a, w in zip(activations, weights) + ], + dim=0, + ) + + torch.testing.assert_close(observed, expected) + + +def test_weighted_softmax_identity(): + _test_weighted_softmax(lambda x: x) + + +def test_weighted_softmax_softplus(): + _test_weighted_softmax(F.softplus) + + +def test_cosine_weights_shape(): + batch_size, num_heads, memory_size, word_size = 5, 3, 7, 2 + + module = CosineWeights(num_heads, word_size) + mem = torch.randn([batch_size, memory_size, word_size]) + keys = torch.randn([batch_size, num_heads, word_size]) + strengths = torch.randn([batch_size, num_heads]) + weights = module(mem, keys, strengths) + + assert weights.shape == torch.Size([batch_size, num_heads, memory_size]) + + +def test_cosine_weights_values(): + batch_size, num_heads, memory_size, word_size = 5, 4, 10, 2 + + mem = torch.randn((batch_size, memory_size, word_size)) + mem[0, 0, 0] = 1 + mem[0, 0, 1] = 2 + mem[0, 1, 0] = 3 + mem[0, 1, 1] = 4 + mem[0, 2, 0] = 5 + mem[0, 2, 1] = 6 + + keys = torch.randn((batch_size, num_heads, word_size)) + keys[0, 0, 0] = 5 + keys[0, 0, 1] = 6 + keys[0, 1, 0] = 1 + keys[0, 1, 1] = 2 + keys[0, 2, 0] = 5 + keys[0, 2, 1] = 6 + keys[0, 3, 0] = 3 + keys[0, 3, 1] = 4 + + strengths = torch.randn((batch_size, num_heads)) + + module = CosineWeights(num_heads, word_size) + result = module(mem, keys, strengths) + + # Manually checks results. + strengths_softplus = np.log(1 + np.exp(strengths.numpy())) + similarity = np.zeros((memory_size)) + + for b in range(batch_size): + for h in range(num_heads): + key = keys[b, h] + key_norm = np.linalg.norm(key) + + for m in range(memory_size): + row = mem[b, m] + similarity[m] = np.dot(key, row) / (key_norm * np.linalg.norm(row)) + + similarity = np.exp(similarity * strengths_softplus[b, h]) + similarity /= similarity.sum() + ref = torch.from_numpy(similarity).to(dtype=torch.float32) + torch.testing.assert_close(result[b, h], ref, atol=1e-4, rtol=1e-4) + + +def test_cosine_weights_divide_by_zero(): + batch_size, num_heads, memory_size, word_size = 5, 4, 10, 2 + + module = CosineWeights(num_heads, word_size) + keys = torch.randn([batch_size, num_heads, word_size], requires_grad=True) + strengths = torch.randn([batch_size, num_heads], requires_grad=True) + + # First row of memory is non-zero to concentrate attention on this location. + # Remaining rows are all zero. 
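+    # The all-zero rows have zero norm, exercising the _EPSILON term in the
+    # cosine similarity; the backward pass below checks that neither the
+    # output nor any gradient picks up NaNs from the (near) division by zero.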
+ mem = torch.zeros([batch_size, memory_size, word_size]) + mem[:, 0, :] = 1 + mem.requires_grad = True + + output = module(mem, keys, strengths) + output.sum().backward() + + assert torch.all(~output.isnan()) + assert torch.all(~mem.grad.isnan()) + assert torch.all(~keys.grad.isnan()) + assert torch.all(~strengths.grad.isnan()) + + +def test_temporal_linkage(): + batch_size, memory_size, num_reads, num_writes = 7, 4, 11, 5 + + module = TemporalLinkage(memory_size=memory_size, num_writes=num_writes) + + state = None + num_steps = 5 + for i in range(num_steps): + write_weights = torch.rand([batch_size, num_writes, memory_size]) + write_weights /= write_weights.sum(2, keepdim=True) + 1 + + # Simulate (in final steps) link 0-->1 in head 0 and 3-->2 in head 1 + if i == num_steps - 2: + write_weights[0, 0, :] = one_hot(memory_size, 0) + write_weights[0, 1, :] = one_hot(memory_size, 3) + elif i == num_steps - 1: + write_weights[0, 0, :] = one_hot(memory_size, 1) + write_weights[0, 1, :] = one_hot(memory_size, 2) + + state = module(write_weights, state) + + # link should be bounded in range [0, 1] + assert torch.all(0 <= state.link.min() <= 1) + + # link diagonal should be zero + torch.testing.assert_close( + state.link[:, :, range(memory_size), range(memory_size)], + torch.zeros([batch_size, num_writes, memory_size]), + ) + + # link rows and columns should sum to at most 1 + assert torch.all(state.link.sum(2) <= 1) + assert torch.all(state.link.sum(3) <= 1) + + # records our transitions in batch 0: head 0: 0->1, and head 1: 3->2 + torch.testing.assert_close( + state.link[0, 0, :, 0], one_hot(memory_size, 1, dtype=torch.float32) + ) + torch.testing.assert_close( + state.link[0, 1, :, 3], one_hot(memory_size, 2, dtype=torch.float32) + ) + + # Now test calculation of forward and backward read weights + prev_read_weights = torch.randn((batch_size, num_reads, memory_size)) + prev_read_weights[0, 5, :] = one_hot(memory_size, 0) # read 5, posn 0 + prev_read_weights[0, 6, :] = one_hot(memory_size, 2) # read 6, posn 2 + + forward_read_weights = module.directional_read_weights( + state.link, prev_read_weights, is_forward=True + ) + backward_read_weights = module.directional_read_weights( + state.link, prev_read_weights, is_forward=False + ) + + # Check directional weights calculated correctly. + torch.testing.assert_close( + forward_read_weights[0, 5, 0, :], # read=5, write=0 + one_hot(memory_size, 1, dtype=torch.float32), + ) + torch.testing.assert_close( + backward_read_weights[0, 6, 1, :], # read=6, write=1 + one_hot(memory_size, 3, dtype=torch.float32), + ) + + +def test_temporal_linkage_precedence_weights(): + batch_size, memory_size, num_writes = 7, 3, 5 + + module = TemporalLinkage(memory_size=memory_size, num_writes=num_writes) + + prev_precedence_weights = torch.rand(batch_size, num_writes, memory_size) + write_weights = torch.rand(batch_size, num_writes, memory_size) + + # These should sum to at most 1 for each write head in each batch. 
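+    # Dividing by (sum + 1) keeps every head's weights summing to strictly
+    # less than 1, so the update (1 - sum(w)) * prev + w stays within [0, 1].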
+ write_weights /= write_weights.sum(2, keepdim=True) + 1 + prev_precedence_weights /= prev_precedence_weights.sum(2, keepdim=True) + 1 + + write_weights[0, 1, :] = 0 # batch 0 head 1: no writing + write_weights[1, 2, :] /= write_weights[1, 2, :].sum() # b1 h2: all writing + + precedence_weights = module._precedence_weights( + prev_precedence_weights=prev_precedence_weights, write_weights=write_weights + ) + + # precedence weights should be bounded in range [0, 1] + assert torch.all(0 <= precedence_weights) + assert torch.all(precedence_weights <= 1) + + # no writing in batch 0, head 1 + torch.testing.assert_close( + precedence_weights[0, 1, :], prev_precedence_weights[0, 1, :] + ) + + # all writing in batch 1, head 2 + torch.testing.assert_close(precedence_weights[1, 2, :], write_weights[1, 2, :]) + + +def test_freeness(): + batch_size, memory_size, num_reads, num_writes = 5, 11, 3, 7 + + module = Freeness(memory_size) + + free_gate = torch.rand((batch_size, num_reads)) + + # Produce read weights that sum to 1 for each batch and head. + prev_read_weights = torch.rand((batch_size, num_reads, memory_size)) + prev_read_weights[1, :, 3] = 0 + prev_read_weights /= prev_read_weights.sum(2, keepdim=True) + prev_write_weights = torch.rand((batch_size, num_writes, memory_size)) + prev_write_weights /= prev_write_weights.sum(2, keepdim=True) + prev_usage = torch.rand((batch_size, memory_size)) + + # Add some special values that allows us to test the behaviour: + prev_write_weights[1, 2, 3] = 1 + prev_read_weights[2, 0, 4] = 1 + free_gate[2, 0] = 1 + + usage = module(prev_write_weights, free_gate, prev_read_weights, prev_usage) + + # Check all usages are between 0 and 1. + assert torch.all(0 <= usage) + assert torch.all(usage <= 1) + + # Check that the full write at batch 1, position 3 makes it fully used. + assert usage[1][3] == 1 + + # Check that the full free at batch 2, position 4 makes it fully free. + assert usage[2][4] == 0 + + +def test_freeness_write_allocation_weights(): + batch_size, memory_size, num_writes = 7, 23, 5 + + module = Freeness(memory_size) + + usage = torch.rand((batch_size, memory_size)) + write_gates = torch.rand((batch_size, num_writes)) + + # Turn off gates for heads 1 and 3 in batch 0. This doesn't scaling down the + # weighting, but it means that the usage doesn't change, so we should get + # the same allocation weightings for: (1, 2) and (3, 4) (but all others + # being different). + write_gates[0, 1] = 0 + write_gates[0, 3] = 0 + # and turn heads 0 and 2 on for full effect. + write_gates[0, 0] = 1 + write_gates[0, 2] = 1 + + # In batch 1, make one of the usages 0 and another almost 0, so that these + # entries get most of the allocation weights for the first and second heads. + usage[1] = usage[1] * 0.9 + 0.1 # make sure all entries are in [0.1, 1] + usage[1][4] = 0 # write head 0 should get allocated to position 4 + usage[1][3] = 1e-4 # write head 1 should get allocated to position 3 + write_gates[1, 0] = 1 # write head 0 fully on + write_gates[1, 1] = 1 # write head 1 fully on + + weights = module.write_allocation_weights( + usage=usage, write_gates=write_gates, num_writes=num_writes + ) + + assert torch.all(0 <= weights) + assert torch.all(weights <= 1) + + # Check that weights sum to close to 1 + torch.testing.assert_close( + weights.sum(dim=2), torch.ones((batch_size, num_writes)), atol=1e-3, rtol=0 + ) + + # Check the same / different allocation weight pairs as described above. 
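+    # Heads 1 and 3 have zero write gates, so the simulated usage does not
+    # change after them; heads 2 and 4 therefore see the same usage and should
+    # produce identical allocation weightings, giving the pairs (1, 2) and
+    # (3, 4) below.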
+ assert (weights[0, 0, :] - weights[0, 1, :]).abs().max() > 0.1 + torch.testing.assert_close(weights[0, 1, :], weights[0, 2, :]) + assert (weights[0, 2, :] - weights[0, 3, :]).abs().max() > 0.1 + torch.testing.assert_close(weights[0, 3, :], weights[0, 4, :]) + + torch.testing.assert_close( + weights[1][0], one_hot(memory_size, 4, dtype=torch.float32), atol=1e-3, rtol=0 + ) + torch.testing.assert_close( + weights[1][1], one_hot(memory_size, 3, dtype=torch.float32), atol=1e-3, rtol=0 + ) + + +def test_freeness_write_allocation_weights_gradient(): + batch_size, memory_size, num_writes = 7, 5, 3 + + module = Freeness(memory_size).to(torch.float64) + + usage = torch.rand( + (batch_size, memory_size), dtype=torch.float64, requires_grad=True + ) + write_gates = torch.rand( + (batch_size, num_writes), dtype=torch.float64, requires_grad=True + ) + + def func(usage, write_gates): + return module.write_allocation_weights(usage, write_gates, num_writes) + + torch.autograd.gradcheck(func, (usage, write_gates)) + + +def test_freeness_allocation(): + batch_size, memory_size = 7, 13 + + usage = torch.rand((batch_size, memory_size)) + module = Freeness(memory_size) + allocation = module._allocation(usage) + + # 1. Test that max allocation goes to min usage, and vice versa. + assert torch.all(usage.argmin(dim=1) == allocation.argmax(dim=1)) + assert torch.all(usage.argmax(dim=1) == allocation.argmin(dim=1)) + + # 2. Test that allocations sum to almost 1. + torch.testing.assert_close( + allocation.sum(dim=1), torch.ones(batch_size), atol=0.01, rtol=0 + ) + + +def test_freeness_allocation_gradient(): + batch_size, memory_size = 1, 5 + + usage = torch.rand( + (batch_size, memory_size), dtype=torch.float64, requires_grad=True + ) + module = Freeness(memory_size).to(torch.float64) + + torch.autograd.gradcheck(module._allocation, (usage,)) diff --git a/tests/dnc_test.py b/tests/dnc_test.py new file mode 100644 index 0000000..2897347 --- /dev/null +++ b/tests/dnc_test.py @@ -0,0 +1,42 @@ +import torch + +from dnc.dnc import DNC + + +def test_dnc(): + """smoke test""" + memory_size = 16 + word_size = 16 + num_reads = 4 + num_writes = 1 + + clip_value = 20 + + input_size = 4 + hidden_size = 64 + output_size = input_size + + batch_size = 16 + time_steps = 64 + + access_config = { + "memory_size": memory_size, + "word_size": word_size, + "num_reads": num_reads, + "num_writes": num_writes, + } + controller_config = { + "input_size": input_size + num_reads * word_size, + "hidden_size": hidden_size, + } + + dnc = DNC( + access_config=access_config, + controller_config=controller_config, + output_size=output_size, + clip_value=clip_value, + ) + inputs = torch.randn((time_steps, batch_size, input_size)) + state = None + for input in inputs: + output, state = dnc(input, state) diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 0000000..2951b43 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,9 @@ +import torch +import torch.nn.functional as F + + +def one_hot(length, index, dtype=None): + val = F.one_hot(torch.tensor(index), num_classes=length) + if dtype is not None: + val = val.to(dtype=dtype) + return val diff --git a/train.py b/train.py new file mode 100644 index 0000000..1b601df --- /dev/null +++ b/train.py @@ -0,0 +1,205 @@ +"""Example script to train the DNC on a repeated copy task.""" +import os +import argparse +import logging + +import torch +from dnc.repeat_copy import RepeatCopy +from dnc.dnc import DNC + +_LG = logging.getLogger(__name__) + + +def _main(): + args = _parse_args() + 
logging.basicConfig(level=logging.INFO, format="%(asctime)s: %(message)s") + + dataset = RepeatCopy( + args.num_bits, + args.batch_size, + args.min_length, + args.max_length, + args.min_repeats, + args.max_repeats, + ) + + dnc = DNC( + access_config={ + "memory_size": args.memory_size, + "word_size": args.word_size, + "num_reads": args.num_read_heads, + "num_writes": args.num_write_heads, + }, + controller_config={ + "input_size": args.num_bits + 2 + args.num_read_heads * args.word_size, + "hidden_size": args.hidden_size, + }, + output_size=dataset.target_size, + clip_value=args.clip_value, + ) + + optimizer = torch.optim.RMSprop(dnc.parameters(), lr=args.lr, eps=args.eps) + + _run_train_loop( + dnc, + dataset, + optimizer, + args.num_training_iterations, + args.report_interval, + args.checkpoint_interval, + args.checkpoint_dir, + ) + + +def _run_train_loop( + dnc, + dataset, + optimizer, + num_training, + report_interval, + checkpoint_interval, + checkpoint_dir, +): + total_loss = 0 + for i in range(num_training): + batch = dataset() + state = None + outputs = [] + for inputs in batch.observations: + output, state = dnc(inputs, state) + outputs.append(output) + outputs = torch.stack(outputs, 0) + loss = dataset.cost(outputs, batch.target, batch.mask) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + if (i + 1) % report_interval == 0: + outputs = torch.round(batch.mask.unsqueeze(-1) * torch.sigmoid(outputs)) + dataset_string = dataset.to_human_readable(batch, outputs) + _LG.info(f"{i}: Avg training loss {total_loss / report_interval}") + _LG.info(dataset_string) + total_loss = 0 + if checkpoint_interval is not None and (i + 1) % checkpoint_interval == 0: + path = os.path.join(checkpoint_dir, "model.pt") + torch.save(dnc.state_dict(), path) + + +def _parse_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + description=__doc__, + ) + model_opts = parser.add_argument_group("Model Parameters") + model_opts.add_argument( + "--hidden-size", type=int, default=64, help="Size of LSTM hidden layer." + ) + model_opts.add_argument( + "--memory-size", type=int, default=16, help="The number of memory slots." + ) + model_opts.add_argument( + "--word-size", type=int, default=16, help="The width of each memory slot." + ) + model_opts.add_argument( + "--num-write-heads", type=int, default=1, help="Number of memory write heads." + ) + model_opts.add_argument( + "--num-read-heads", type=int, default=4, help="Number of memory read heads." + ) + model_opts.add_argument( + "--clip-value", + type=float, + default=20, + help="Maximum absolute value of controller and dnc outputs.", + ) + + optim_opts = parser.add_argument_group("Optimizer Parameters") + optim_opts.add_argument( + "--max-grad-norm", type=float, default=50, help="Gradient clipping norm limit." 
+    )
+    optim_opts.add_argument(
+        "--learning-rate",
+        "--lr",
+        type=float,
+        default=1e-4,
+        dest="lr",
+        help="Optimizer learning rate.",
+    )
+    optim_opts.add_argument(
+        "--optimizer-epsilon",
+        type=float,
+        default=1e-10,
+        dest="eps",
+        help="Epsilon used for RMSProp optimizer.",
+    )
+
+    task_opts = parser.add_argument_group("Task Parameters")
+    task_opts.add_argument(
+        "--batch-size", type=int, default=16, help="Batch size for training."
+    )
+    task_opts.add_argument(
+        "--num-bits",
+        type=int,
+        default=4,
+        help="Dimensionality of each vector to copy.",
+    )
+    task_opts.add_argument(
+        "--min-length",
+        type=int,
+        default=1,
+        help="Lower limit on number of vectors in the observation pattern to copy.",
+    )
+    task_opts.add_argument(
+        "--max-length",
+        type=int,
+        default=2,
+        help="Upper limit on number of vectors in the observation pattern to copy.",
+    )
+    task_opts.add_argument(
+        "--min-repeats",
+        type=int,
+        default=1,
+        help="Lower limit on number of copy repeats.",
+    )
+    task_opts.add_argument(
+        "--max-repeats",
+        type=int,
+        default=2,
+        help="Upper limit on number of copy repeats.",
+    )
+
+    train_opts = parser.add_argument_group("Training Options")
+    train_opts.add_argument(
+        "--num-training-iterations",
+        type=int,
+        default=100_000,
+        help="Number of iterations to train for.",
+    )
+    train_opts.add_argument(
+        "--report-interval",
+        type=int,
+        default=100,
+        help="Iterations between reports (samples, valid loss).",
+    )
+    train_opts.add_argument(
+        "--checkpoint-dir", default=None, help="Checkpointing directory."
+    )
+    train_opts.add_argument(
+        "--checkpoint-interval",
+        type=int,
+        default=None,
+        help="Checkpointing step interval.",
+    )
+    args = parser.parse_args()
+
+    if args.checkpoint_dir is None and args.checkpoint_interval is not None:
+        raise RuntimeError(
+            "`--checkpoint-interval` is provided but `--checkpoint-dir` is not."
+        )
+    if args.checkpoint_dir is not None and args.checkpoint_interval is None:
+        raise RuntimeError(
+            "`--checkpoint-dir` is provided but `--checkpoint-interval` is not."
+        )
+    return args
+
+
+if __name__ == "__main__":
+    _main()
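
Note on `--max-grad-norm`: the parser above accepts this flag, but `_run_train_loop` never applies it, so gradients are currently not clipped. Below is a minimal, self-contained sketch (not part of the patch) of how global-norm clipping could be wired into the training step with `torch.nn.utils.clip_grad_norm_`; a toy linear model and random data stand in for the DNC and the repeat-copy dataset, so only the position of the clipping call is the point.

```python
import torch

# Stand-ins for the DNC and RepeatCopy, so this sketch runs on its own.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, eps=1e-10)
max_grad_norm = 50.0  # would come from args.max_grad_norm

for _ in range(10):
    inputs = torch.randn(16, 8)
    targets = torch.randn(16, 1)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

    optimizer.zero_grad()
    loss.backward()
    # Clip the global gradient norm before the parameter update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
```

In `_run_train_loop`, the equivalent call would sit between `loss.backward()` and `optimizer.step()`.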
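
The training loop saves checkpoints with `torch.save(dnc.state_dict(), path)`, but nothing in the patch shows how to restore one. The following is a hedged sketch of loading a saved model: the configuration must mirror whatever was used at training time (the values below are the `train.py` defaults), the `checkpoints` directory name is hypothetical, and one dummy forward pass is run first so the lazy linear layers inside the access module materialize their parameters before `load_state_dict`.

```python
import os

import torch

from dnc.dnc import DNC
from dnc.repeat_copy import RepeatCopy

# Must mirror the configuration used when the checkpoint was written
# (these are the train.py defaults).
num_bits, word_size, num_reads = 4, 16, 4
dataset = RepeatCopy(num_bits, 16, 1, 2, 1, 2)
dnc = DNC(
    access_config={
        "memory_size": 16,
        "word_size": word_size,
        "num_reads": num_reads,
        "num_writes": 1,
    },
    controller_config={
        "input_size": num_bits + 2 + num_reads * word_size,
        "hidden_size": 64,
    },
    output_size=dataset.target_size,
    clip_value=20,
)

# One dummy step so the lazy layers create concrete weights,
# then overwrite them with the saved parameters.
batch = dataset()
dnc(batch.observations[0], None)
dnc.load_state_dict(torch.load(os.path.join("checkpoints", "model.pt")))
dnc.eval()
```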