130 lines
4.6 KiB
Python
130 lines
4.6 KiB
Python
from dataclasses import dataclass
|
|
from typing import Tuple, Optional
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from .access import MemoryAccess, AccessState
|
|
|
|
|
|
@dataclass
class DNCState:
    """Recurrent state carried from one `DNC.forward` call to the next.

    `access_state` and `controller_state` are `None` in the initial state that
    `DNC.forward` builds when it is called without a previous state; the
    sub-modules are then left to create their own initial state.
    """

    # Read words produced by the memory access module on the previous step.
    access_output: Tensor  # [batch_size, num_reads, word_size]

    # State of the memory access module (None before the first step).
    access_state: Optional[AccessState]
    #   memory: Tensor          # [batch_size, memory_size, word_size]
    #   read_weights: Tensor    # [batch_size, num_reads, memory_size]
    #   write_weights: Tensor   # [batch_size, num_writes, memory_size]
    #   linkage: TemporalLinkageState
    #     link:               [batch_size, num_writes, memory_size, memory_size]
    #     precedence_weights: [batch_size, num_writes, memory_size]
    #   usage: Tensor           # [batch_size, memory_size]

    # `(h, c)` pair of the controller (None before the first step). The
    # controller used by `DNC` is a `torch.nn.LSTMCell`, so both tensors are
    # 2-D (there is no leading `num_layers` axis).
    controller_state: Optional[Tuple[Tensor, Tensor]]
    #   h: [batch_size, hidden_size]
    #   c: [batch_size, hidden_size]
|
|
|
|
|
|
class DNC(torch.nn.Module):
    """DNC core module.

    Contains controller and memory access module.
    """

    def __init__(
        self,
        access_config,
        controller_config,
        output_size,
        clip_value=None,
    ):
        """Initializes the DNC core.

        Args:
          access_config: dictionary of keyword arguments forwarded to
            `MemoryAccess`.
          controller_config: dictionary of keyword arguments forwarded to the
            controller, a `torch.nn.LSTMCell`.
          output_size: output dimension size of core.
          clip_value: clips controller state and core output values to between
            `[-clip_value, clip_value]` if specified; `None` disables clipping.
        """
        super().__init__()

        self._controller = torch.nn.LSTMCell(**controller_config)
        self._access = MemoryAccess(**access_config)
        # Lazy layer: its input width (controller features plus flattened read
        # words) is inferred on the first forward call.
        self._output = torch.nn.LazyLinear(output_size)
        # Keep the bound as plain data and clip in a method instead of storing
        # a lambda on the instance: lambdas cannot be pickled, which would
        # prevent pickling the module as a whole.
        self._clip_value = clip_value

    def _clip(self, x: Tensor) -> Tensor:
        """Clamps `x` to `[-clip_value, clip_value]`; identity if unset."""
        if self._clip_value is None:
            return x
        return torch.clamp(x, min=-self._clip_value, max=self._clip_value)

    @staticmethod
    def _batch_flatten(x: Tensor) -> Tensor:
        """Collapses every dimension after the batch dimension into one."""
        return torch.flatten(x, start_dim=1)

    def forward(
        self, inputs: Tensor, prev_state: Optional[DNCState] = None
    ) -> Tuple[Tensor, DNCState]:
        """Runs one step of the DNC core.

        Args:
          inputs: 2-D Tensor input of shape `[batch_size, num_input_feats]`.
          prev_state: A `DNCState` containing the fields `access_output`,
            `access_state` and `controller_state`. `access_output` is a 3-D
            Tensor of shape `[batch_size, num_reads, word_size]` containing
            read words. `access_state` is the access module's state, and
            `controller_state` is the controller module's state. If `None`,
            an initial state is built with zero read words and with
            `access_state` / `controller_state` set to `None`.

        Returns:
          A tuple `(output, next_state)` where `output` is a tensor and
          `next_state` is a `DNCState` containing the fields `access_output`,
          `access_state`, and `controller_state`.

        Raises:
          ValueError: if `inputs` is not 2-D.
        """
        if inputs.ndim != 2:
            raise ValueError(f"Expected `inputs` to be 2D: Found {inputs.ndim}.")

        if prev_state is None:
            batch_size = inputs.shape[0]
            prev_state = DNCState(
                access_output=torch.zeros(
                    (batch_size, self._access._num_reads, self._access._word_size),
                    device=inputs.device,
                ),
                access_state=None,
                controller_state=None,
            )

        # `inputs` is already 2-D (checked above), so only the previous read
        # words need flattening before they are fed back to the controller.
        controller_input = torch.concat(
            [
                inputs,  # [batch_size, num_input_feats]
                self._batch_flatten(
                    prev_state.access_output
                ),  # [batch_size, num_reads * word_size]
            ],
            dim=1,
        )  # [batch_size, num_input_feats + num_reads * word_size]

        controller_state = self._controller(
            controller_input, prev_state.controller_state
        )
        # Clip both the hidden and the cell state before they are reused.
        controller_state = tuple(self._clip(t) for t in controller_state)
        controller_output = controller_state[0]  # hidden state h

        access_output, access_state = self._access(
            controller_output, prev_state.access_state
        )

        output = torch.concat(
            [
                controller_output,  # [batch_size, num_ctrl_feats]
                self._batch_flatten(
                    access_output
                ),  # [batch_size, num_reads * word_size]
            ],
            dim=1,
        )
        output = self._clip(self._output(output))

        return (
            output,
            DNCState(
                access_output=access_output,
                access_state=access_state,
                controller_state=controller_state,
            ),
        )
|