from dataclasses import dataclass
from typing import Optional

import torch
from torch import Tensor
import torch.nn.functional as F


@dataclass
class TemporalLinkageState:
    """Recurrent state carried between steps by `TemporalLinkage`."""

    link: Tensor  # [batch_size, num_writes, memory_size, memory_size]
    precedence_weights: Tensor  # [batch_size, num_writes, memory_size]


# Small constant used for numerical stability: added under square roots and
# to denominators, and used to keep usage strictly above 0 before a cumprod.
_EPSILON = 1e-6


def _vector_norms(m: Tensor) -> Tensor:
    """Returns the epsilon-stabilized L2 norm of `m` along its last axis.

    Args:
        m: A tensor whose last dimension holds the vectors, e.g.
            `[batch_size, n, word_size]`.

    Returns:
        A tensor with the same rank as `m` and a last dimension of size 1
        (`keepdim=True`), so it broadcasts against `m`.
    """
    squared_norm = torch.sum(m * m, dim=-1, keepdim=True)
    return torch.sqrt(squared_norm + _EPSILON)


def weighted_softmax(activations: Tensor, strengths: Tensor, strengths_op=F.softplus):
    """Returns softmax over activations multiplied by positive strengths.

    Args:
        activations: A tensor of shape `[batch_size, num_heads, memory_size]`, of
            activations to be transformed. Softmax is taken over the last dimension.
        strengths: A tensor of shape `[batch_size, num_heads]` containing strengths to
            multiply by the activations prior to the softmax.
        strengths_op: An operation to transform strengths before softmax.

    Returns:
        A tensor of same shape as `activations` with weighted softmax applied.
    """
    # Add a trailing axis so each head's strength scales all of its memory slots.
    transformed_strengths = strengths_op(strengths).unsqueeze(-1)
    sharp_activations = activations * transformed_strengths
    return F.softmax(sharp_activations, dim=-1)


class CosineWeights(torch.nn.Module):
    """Content-based addressing: weighted softmax over cosine similarity."""

    def __init__(self, num_heads, word_size, strength_op=F.softplus):
        """
        Args:
            num_heads: number of memory heads.
            word_size: memory word size.
            strength_op: operation to apply to strengths (default softplus).
        """
        super().__init__()
        self._num_heads = num_heads
        self._word_size = word_size
        self._strength_op = strength_op

    def forward(self, memory: Tensor, keys: Tensor, strengths: Tensor) -> Tensor:
        """
        Args:
            memory: A 3-D tensor of shape `[batch_size, memory_size, word_size]`.
            keys: A 3-D tensor of shape `[batch_size, num_heads, word_size]`.
            strengths: A 2-D tensor of shape `[batch_size, num_heads]`.

        Returns:
            Weights tensor of shape `[batch_size, num_heads, memory_size]`.
        """
        # Inner product between every key and every memory word:
        # [B, H, W] @ [B, W, M] -> [B, H, M].
        dot = torch.matmul(keys, memory.transpose(-1, -2))

        # Outer product of the norms gives the cosine-similarity denominator.
        memory_norms = _vector_norms(memory)  # [B, M, 1]
        key_norms = _vector_norms(keys)  # [B, H, 1]
        norm = torch.matmul(key_norms, memory_norms.transpose(-1, -2))  # [B, H, M]

        similarity = dot / (norm + _EPSILON)
        return weighted_softmax(similarity, strengths, self._strength_op)


class TemporalLinkage(torch.nn.Module):
    """Tracks the temporal order in which memory locations were written."""

    def __init__(self, memory_size, num_writes):
        super().__init__()
        self._memory_size = memory_size
        self._num_writes = num_writes

    def _link(
        self,
        prev_link: Tensor,
        prev_precedence_weights: Tensor,
        write_weights: Tensor,
    ) -> Tensor:
        """Calculates the new link graphs.

        For each write head, the link is a directed graph (represented by a matrix
        with entries in range [0, 1]) whose vertices are the memory locations, and
        an edge indicates temporal ordering of writes.

        Args:
            prev_link: A tensor of shape
                `[batch_size, num_writes, memory_size, memory_size]` representing the
                previous link graphs for each write head.
            prev_precedence_weights: A tensor of shape
                `[batch_size, num_writes, memory_size]` which is the previous
                "aggregated" write weights for each write head.
            write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
                containing the new locations in memory written to.

        Returns:
            A tensor of shape `[batch_size, num_writes, memory_size, memory_size]`
            containing the new link graphs for each write head.
        """
        # Broadcast write weights along rows (i) and columns (j) of the link.
        write_weights_i = write_weights.unsqueeze(3)
        write_weights_j = write_weights.unsqueeze(2)
        prev_precedence_weights_j = prev_precedence_weights.unsqueeze(2)

        # Edges touching freshly written locations decay; new edges point from
        # the location written now (i) to the previously written ones (j).
        prev_link_scale = 1 - write_weights_i - write_weights_j
        new_link = write_weights_i * prev_precedence_weights_j
        link = prev_link_scale * prev_link + new_link

        # Return the link with the diagonal set to zero, to remove self-looping
        # edges. A single [M, M] mask on link's device broadcasts over the
        # batch and write-head dimensions.
        memory_size = link.size(-1)
        diagonal = torch.eye(memory_size, device=link.device).bool()
        return link.masked_fill(diagonal, 0.0)

    def _precedence_weights(
        self,
        prev_precedence_weights: Tensor,
        write_weights: Tensor,
    ) -> Tensor:
        """Calculates the new precedence weights given the current write weights.

        The precedence weights are the "aggregated write weights" for each write
        head, where write weights with sum close to zero will leave the precedence
        weights unchanged, but with sum close to one will replace the precedence
        weights.

        Args:
            prev_precedence_weights: A tensor of shape
                `[batch_size, num_writes, memory_size]` containing the previous
                precedence weights.
            write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
                containing the new write weights.

        Returns:
            A tensor of shape `[batch_size, num_writes, memory_size]` containing the
            new precedence weights.
        """
        write_sum = write_weights.sum(dim=2, keepdim=True)
        return (1 - write_sum) * prev_precedence_weights + write_weights

    def forward(
        self,
        write_weights: Tensor,
        prev_state: Optional[TemporalLinkageState] = None,
    ) -> TemporalLinkageState:
        """Calculate the updated linkage state given the write weights.

        Args:
            write_weights: A tensor of shape `[batch_size, num_writes, memory_size]`
                containing the memory addresses of the different write heads.
            prev_state: `TemporalLinkageState` tuple containing a tensor `link` of
                shape `[batch_size, num_writes, memory_size, memory_size]`, and a
                tensor `precedence_weights` of shape
                `[batch_size, num_writes, memory_size]` containing the aggregated
                history of recent writes. If `None`, an all-zero state is used.

        Returns:
            A `TemporalLinkageState` tuple `next_state`, which contains the updated
            link and precedence weights.

        Raises:
            ValueError: if `write_weights` is not 3-D or its trailing dimensions do
                not match `[num_writes, memory_size]`.
        """
        if write_weights.ndim != 3:
            raise ValueError(
                f"Expected `write_weights` to be 3D. Found: {write_weights.ndim}"
            )
        if (
            write_weights.size(1) != self._num_writes
            or write_weights.size(2) != self._memory_size
        ):
            raise ValueError(
                "Expected the shape of `write_weights` to be "
                f"[batch, {self._num_writes}, {self._memory_size}]. "
                f"Found: {write_weights.shape}."
            )

        if prev_state is None:
            # Zero initial state with the same dtype and device as the input.
            B, W, M = write_weights.shape
            prev_state = TemporalLinkageState(
                link=write_weights.new_zeros((B, W, M, M)),
                precedence_weights=write_weights.new_zeros((B, W, M)),
            )

        link = self._link(prev_state.link, prev_state.precedence_weights, write_weights)
        precedence_weights = self._precedence_weights(
            prev_state.precedence_weights, write_weights
        )
        return TemporalLinkageState(link=link, precedence_weights=precedence_weights)

    def directional_read_weights(
        self,
        link: Tensor,
        prev_read_weights: Tensor,
        is_forward: bool,
    ) -> Tensor:
        """Calculates the forward or the backward read weights.

        For each read head (at a given address), there are `num_writes` link graphs
        to follow. Thus this function computes a read address for each of the
        `num_reads * num_writes` pairs of read and write heads.

        Args:
            link: tensor of shape `[batch_size, num_writes, memory_size, memory_size]`
                representing the link graphs L_t.
            prev_read_weights: tensor of shape `[batch_size, num_reads, memory_size]`
                containing the previous read weights w_{t-1}^r.
            is_forward: Boolean indicating whether to follow the "future" direction
                in the link graph (True) or the "past" direction (False).

        Returns:
            tensor of shape `[batch_size, num_reads, num_writes, memory_size]`
        """
        if is_forward:
            link = link.transpose(-1, -2)
        # [B, 1, R, M] @ [B, W, M, M] broadcasts over the write-head axis,
        # yielding [B, W, R, M] without materializing per-head copies.
        result = torch.matmul(prev_read_weights.unsqueeze(1), link)
        # Swap dimensions 1 and 2 so the order is [batch, reads, writes, memory].
        return result.permute((0, 2, 1, 3))


class Freeness(torch.nn.Module):
    """Tracks memory usage and produces allocation (free-space) write weights."""

    def __init__(self, memory_size):
        super().__init__()
        self._memory_size = memory_size

    def _usage_after_write(
        self,
        prev_usage: Tensor,
        write_weights: Tensor,
    ) -> Tensor:
        """Calculates the new usage after writing to memory.

        Args:
            prev_usage: tensor of shape `[batch_size, memory_size]`.
            write_weights: tensor of shape `[batch_size, num_writes, memory_size]`.

        Returns:
            New usage, a tensor of shape `[batch_size, memory_size]`.
        """
        # Combine all write heads: probability that at least one head wrote.
        write_weights = 1 - torch.prod(1 - write_weights, 1)
        return prev_usage + (1 - prev_usage) * write_weights

    def _usage_after_read(
        self, prev_usage: Tensor, free_gate: Tensor, read_weights: Tensor
    ) -> Tensor:
        """Calculates the new usage after reading and freeing from memory.

        Args:
            prev_usage: tensor of shape `[batch_size, memory_size]`.
            free_gate: tensor of shape `[batch_size, num_reads]` with entries in the
                range [0, 1] indicating the amount that locations read from can be
                freed.
            read_weights: tensor of shape `[batch_size, num_reads, memory_size]`.

        Returns:
            New usage, a tensor of shape `[batch_size, memory_size]`.
        """
        free_gate = free_gate.unsqueeze(-1)
        free_read_weights = free_gate * read_weights
        # Retention vector: fraction of each location NOT freed by any read head.
        phi = torch.prod(1 - free_read_weights, 1)
        return prev_usage * phi

    def forward(
        self,
        write_weights: Tensor,
        free_gate: Tensor,
        read_weights: Tensor,
        prev_usage: Tensor,
    ) -> Tensor:
        """Calculates the new memory usage u_t.

        Memory that was written to in the previous time step will have its usage
        increased; memory that was read from and the controller says can be "freed"
        will have its usage decreased.

        Args:
            write_weights: tensor of shape `[batch_size, num_writes,
                memory_size]` giving write weights at previous time step.
            free_gate: tensor of shape `[batch_size, num_reads]` which indicates
                which read heads read memory that can now be freed.
            read_weights: tensor of shape `[batch_size, num_reads,
                memory_size]` giving read weights at previous time step.
            prev_usage: tensor of shape `[batch_size, memory_size]` giving
                usage u_{t - 1} at the previous time step, with entries in range
                [0, 1].

        Returns:
            tensor of shape `[batch_size, memory_size]` representing updated memory
            usage.
        """
        # Usage is not differentiated with respect to the write weights only.
        # Gradients must still reach `free_gate`, `read_weights` and `prev_usage`;
        # wrapping everything in `torch.no_grad()` would leave the free gate
        # without any learning signal.
        write_weights = write_weights.detach()
        usage = self._usage_after_write(prev_usage, write_weights)
        usage = self._usage_after_read(usage, free_gate, read_weights)
        return usage

    def _allocation(self, usage: Tensor) -> Tensor:
        r"""Computes allocation by sorting `usage`.

        This corresponds to the value a = a_t[\phi_t[j]] in the paper.

        Args:
            usage: tensor of shape `[batch_size, memory_size]` indicating current
                memory usage. This is equal to u_t in the paper when we only have one
                write head, but for multiple write heads, one should update the usage
                while iterating through the write heads to take into account the
                allocation returned by this function.

        Returns:
            Tensor of shape `[batch_size, memory_size]` corresponding to allocation.
        """
        # Ensure values are not too small prior to the cumulative product.
        usage = _EPSILON + (1 - _EPSILON) * usage

        nonusage = 1 - usage
        # topk with k == memory_size is a full descending sort (least used first).
        sorted_nonusage, indices = torch.topk(nonusage, k=self._memory_size)
        sorted_usage = 1 - sorted_nonusage

        # Emulate tf.cumprod(exclusive=True): prepend 1, cumprod, drop the last.
        sorted_usage = F.pad(sorted_usage, (1, 0), mode="constant", value=1)
        prod_sorted_usage = torch.cumprod(sorted_usage, dim=1)
        prod_sorted_usage = prod_sorted_usage[:, :-1]
        sorted_allocation = sorted_nonusage * prod_sorted_usage

        # "Unsort" so indexing matches the original ordering of `usage`.
        inverse_indices = torch.argsort(indices)
        return torch.gather(sorted_allocation, 1, inverse_indices)

    def write_allocation_weights(
        self,
        usage: Tensor,
        write_gates: Tensor,
        num_writes: int,
    ) -> Tensor:
        """Calculates freeness-based locations for writing to.

        This finds unused memory by ranking the memory locations by usage, for each
        write head. (For more than one write head, we use a "simulated new usage"
        which takes into account the fact that the previous write head will increase
        the usage in that area of the memory.)

        Args:
            usage: A tensor of shape `[batch_size, memory_size]` representing
                current memory usage.
            write_gates: A tensor of shape `[batch_size, num_writes]` with values in
                the range [0, 1] indicating how much each write head does writing
                based on the address returned here (and hence how much usage
                increases).
            num_writes: The number of write heads to calculate write weights for.

        Returns:
            tensor of shape `[batch_size, num_writes, memory_size]` containing the
            freeness-based write locations. Note that this isn't scaled by
            `write_gate`; this scaling must be applied externally.
        """
        write_gates = write_gates.unsqueeze(-1)
        allocation_weights = []
        for i in range(num_writes):
            allocation = self._allocation(usage)
            allocation_weights.append(allocation)
            # Simulate this head's write so the next head avoids the same slots.
            usage = usage + (1 - usage) * write_gates[:, i, :] * allocation
        return torch.stack(allocation_weights, dim=1)