# dnc_pytorch/dnc/repeat_copy.py
import math
from dataclasses import dataclass

import torch
from torch import Tensor
import torch.nn.functional as F


@dataclass
class DatasetTensors:
    observations: Tensor
    target: Tensor
    mask: Tensor


def masked_sigmoid_cross_entropy(
    logits,
    target,
    mask,
    time_average=False,
    log_prob_in_bits=False,
):
    """Computes the (scalar) NLL of the target sequence.

    The logits parametrize independent bernoulli distributions per time-step
    and per batch element, and irrelevant time/batch elements are masked out
    by the mask tensor.

    Args:
      logits: `Tensor` of activations for which sigmoid(`logits`) gives the
          bernoulli parameter.
      target: time-major `Tensor` of targets.
      mask: time-major `Tensor` multiplied elementwise with the T x B cost,
          masking out irrelevant time-steps.
      time_average: optionally average over the time dimension (sum by default).
      log_prob_in_bits: iff True, express log-probabilities in bits (default nats).

    Returns:
      A `Tensor` representing the negative log-probability of the target.
    """
    batch_size = logits.shape[1]
    # Cross entropy per time-step, batch element, and bit channel.
    xent = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    loss_time_batch = xent.sum(dim=2)
    # Zero out irrelevant time-steps, then sum over time.
    loss_batch = (loss_time_batch * mask).sum(dim=0)
    if time_average:
        mask_count = mask.sum(dim=0)
        loss_batch /= mask_count + torch.finfo(loss_batch.dtype).eps
    loss = loss_batch.sum() / batch_size
    if log_prob_in_bits:
        # Convert nats to bits. `torch.log` does not accept a Python float,
        # so the scalar constant uses `math.log`.
        loss /= math.log(2.0)
    return loss
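

# A minimal sketch (not part of the original module) illustrating the expected
# shapes: logits and target are [time, batch, bits], mask is [time, batch].
# The function name and all values below are hypothetical, for illustration.
def _example_masked_nll():
    T, B, bits = 5, 2, 4
    logits = torch.randn(T, B, bits)
    target = torch.randint(0, 2, (T, B, bits)).float()
    mask = torch.ones(T, B)
    mask[:2] = 0.0  # e.g. exclude the first two time-steps from the cost
    return masked_sigmoid_cross_entropy(logits, target, mask)

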
def bitstring_readable(data, batch_size, model_output=None, whole_batch=False):
    """Produce a human readable representation of the sequences in data.

    Args:
      data: data to be visualised.
      batch_size: size of batch.
      model_output: optional model output tensor to visualize alongside data.
      whole_batch: whether to visualise the whole batch. Only the first sample
          will be visualized if False.

    Returns:
      A string used to visualise the data batch.
    """

    def _readable(datum):
        return "+" + " ".join([f"{int(x):d}" if x else "-" for x in datum]) + "+"

    obs_batch = data.observations
    targ_batch = data.target

    batch_strings = []
    for batch_index in range(batch_size if whole_batch else 1):
        obs = obs_batch[:, batch_index, :]
        targ = targ_batch[:, batch_index, :]

        obs_channels = range(obs.shape[1])
        targ_channels = range(targ.shape[1])
        obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels]
        targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels]

        readable_obs = "Observations:\n" + "\n".join(obs_channel_strings)
        readable_targ = "Targets:\n" + "\n".join(targ_channel_strings)
        strings = [readable_obs, readable_targ]

        if model_output is not None:
            output = model_output[:, batch_index, :]
            output_strings = [_readable(output[:, i]) for i in targ_channels]
            strings.append("Model Output:\n" + "\n".join(output_strings))

        batch_strings.append("\n\n".join(strings))

    return "\n" + "\n\n\n\n".join(batch_strings)
class RepeatCopy:
    """Sequence data generator for the task of repeating a random binary pattern.

    When called, an instance of this class returns a `DatasetTensors` dataclass
    of tensors (observations, target, mask), representing an input sequence,
    target sequence, and binary mask. The first two dimensions of each tensor
    represent sequence position and batch index respectively. The value in
    mask[t, b] is equal to 1 iff a prediction about targ[t, b, :] should be
    penalized and 0 otherwise.

    For each realisation from this generator, the observation sequence is
    comprised of I.I.D. uniform-random binary vectors (and some flags).
    The target sequence is comprised of this binary pattern repeated
    some number of times (and some flags). Before explaining in more detail,
    let's examine the setup pictorially for a single batch element:

    ```none
    Note: blank space represents 0.

    time ------------------------------------------>

                  +-------------------------------+
    mask:         |0000000001111111111111111111111|
                  +-------------------------------+

                  +-------------------------------+
    target:       |                              1| 'end-marker' channel.
                  |         101100110110011011001 |
                  |         010101001010100101010 |
                  +-------------------------------+

                  +-------------------------------+
    observation:  | 1011001                       |
                  | 0101010                       |
                  |1                              | 'start-marker' channel
                  |        3                      | 'num-repeats' channel.
                  +-------------------------------+
    ```

    The length of the random pattern and the number of times it is repeated
    in the target are both discrete random variables distributed according to
    uniform distributions whose parameters are configured at construction time.

    The obs sequence has two extra channels (components in the trailing
    dimension) which are used for flags. One channel is marked with a 1 at
    the first time step and is otherwise equal to 0. The other extra channel
    is zero until the binary pattern to be repeated ends. At this point, it
    contains an encoding of the number of times the observation pattern
    should be repeated. Rather than simply providing this integer number
    directly, it is normalised so that a neural network may have an easier
    time representing the number of repetitions internally. To allow a
    network to be readily evaluated on instances of this task with greater
    numbers of repetitions, the range with respect to which this encoding is
    normalised is also configurable by the user.

    As in the diagram, the target sequence is offset to begin directly after
    the observation sequence; both sequences are padded with zeros to
    accomplish this, resulting in their lengths being equal. Additional
    padding is done at the end so that all sequences in a minibatch represent
    tensors with the same shape.
    """
    def __init__(
        self,
        num_bits=6,
        batch_size=1,
        min_length=1,
        max_length=1,
        min_repeats=1,
        max_repeats=2,
        norm_max=10,
        log_prob_in_bits=False,
        time_average_cost=False,
    ):
        """Creates an instance of RepeatCopy task.

        Args:
          num_bits: The dimensionality of each random binary vector.
          batch_size: Minibatch size per realization.
          min_length: Lower limit on number of random binary vectors in the
              observation pattern.
          max_length: Upper limit on number of random binary vectors in the
              observation pattern.
          min_repeats: Lower limit on number of times the observation pattern
              is repeated in targ.
          max_repeats: Upper limit on number of times the observation pattern
              is repeated in targ.
          norm_max: Upper limit on uniform distribution w.r.t which the
              encoding of the number of repetitions presented in the
              observation sequence is normalised.
          log_prob_in_bits: By default, log probabilities are expressed in
              units of nats. If true, express log probabilities in bits.
          time_average_cost: If true, the cost at each time step will be
              divided by the `true` sequence length (the number of non-masked
              time steps) in each sequence before any subsequent reduction
              over the time and batch dimensions.
        """
        super().__init__()
        self._batch_size = batch_size
        self._num_bits = num_bits
        self._min_length = min_length
        self._max_length = max_length
        self._min_repeats = min_repeats
        self._max_repeats = max_repeats
        self._norm_max = norm_max
        self._log_prob_in_bits = log_prob_in_bits
        self._time_average_cost = time_average_cost

    def _normalise(self, val):
        return val / self._norm_max

    def _unnormalise(self, val):
        return val * self._norm_max
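
    # Worked example (illustrative): with the default norm_max=10, a pattern
    # that should be repeated num_reps=3 times is encoded on the 'num-repeats'
    # channel as _normalise(3) == 0.3, and recovered via _unnormalise(0.3).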

    @property
    def time_average_cost(self):
        return self._time_average_cost

    @property
    def log_prob_in_bits(self):
        return self._log_prob_in_bits

    @property
    def num_bits(self):
        """The dimensionality of each random binary vector in a pattern."""
        return self._num_bits

    @property
    def target_size(self):
        """The dimensionality of the target tensor."""
        return self._num_bits + 1

    @property
    def batch_size(self):
        return self._batch_size

    def __call__(self, device=None):
        """Generates and returns one batch of the repeat copy task."""
        # Short-hand for private fields.
        min_length, max_length = self._min_length, self._max_length
        min_reps, max_reps = self._min_repeats, self._max_repeats
        num_bits = self.num_bits
        batch_size = self.batch_size

        # We reserve one dimension for the num-repeats and one for the
        # start-marker.
        full_obs_size = num_bits + 2
        # We reserve one target dimension for the end-marker.
        full_targ_size = num_bits + 1
        start_end_flag_idx = full_obs_size - 2
        num_repeats_channel_idx = full_obs_size - 1

        # Samples each batch index's sequence length and the number of repeats.
        sub_seq_length_batch = torch.randint(
            low=min_length, high=max_length + 1, size=[batch_size], dtype=torch.int32
        )
        num_repeats_batch = torch.randint(
            low=min_reps, high=max_reps + 1, size=[batch_size], dtype=torch.int32
        )

        # Pads all the batches to have the same total sequence length.
        total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3
        max_length_batch = total_length_batch.max()
        residual_length_batch = max_length_batch - total_length_batch

        obs_batch_shape = [max_length_batch, batch_size, full_obs_size]
        targ_batch_shape = [max_length_batch, batch_size, full_targ_size]
        mask_batch_trans_shape = [batch_size, max_length_batch]

        obs_tensors = []
        targ_tensors = []
        mask_tensors = []

        # Generates patterns for each batch element independently.
        for batch_index in range(batch_size):
            sub_seq_len = sub_seq_length_batch[batch_index]
            num_reps = num_repeats_batch[batch_index]

            # The observation pattern is a sequence of random binary vectors.
            obs_pattern_shape = [sub_seq_len, num_bits]
            obs_pattern = torch.randint(low=0, high=2, size=obs_pattern_shape).to(
                torch.float32
            )

            # The target pattern is the observation pattern repeated n times.
            # Some reshaping is required to accomplish the tiling.
            targ_pattern_shape = [sub_seq_len * num_reps, num_bits]
            flat_obs_pattern = obs_pattern.reshape([-1])
            flat_targ_pattern = torch.tile(flat_obs_pattern, [num_reps])
            targ_pattern = flat_targ_pattern.reshape(targ_pattern_shape)

            # Expand the obs_pattern to have two extra channels for flags.
            # Concatenate start flag and num_reps flag to the sequence.
            obs_flag_channel_pad = torch.zeros([sub_seq_len, 2])
            obs_start_flag = F.one_hot(
                torch.tensor([start_end_flag_idx]), num_classes=full_obs_size
            ).to(torch.float32)
            num_reps_flag = F.one_hot(
                torch.tensor([num_repeats_channel_idx]), num_classes=full_obs_size
            ).to(torch.float32)
            num_reps_flag *= self._normalise(num_reps)

            # Note the concatenation dimensions.
            obs = torch.concat([obs_pattern, obs_flag_channel_pad], 1)
            obs = torch.concat([obs_start_flag, obs], 0)
            obs = torch.concat([obs, num_reps_flag], 0)

            # Now do the same for the targ_pattern (it only has one extra
            # channel).
            targ_flag_channel_pad = torch.zeros([sub_seq_len * num_reps, 1])
            targ_end_flag = F.one_hot(
                torch.tensor([start_end_flag_idx]), num_classes=full_targ_size
            ).to(torch.float32)
            targ = torch.concat([targ_pattern, targ_flag_channel_pad], 1)
            targ = torch.concat([targ, targ_end_flag], 0)

            # Concatenate zeros at the end of obs and the beginning of targ.
            # This aligns them s.t. the target begins as soon as the obs ends.
            obs_end_pad = torch.zeros([sub_seq_len * num_reps + 1, full_obs_size])
            targ_start_pad = torch.zeros([sub_seq_len + 2, full_targ_size])

            # The mask is zero during the obs and one during the targ.
            mask_off = torch.zeros([sub_seq_len + 2])
            mask_on = torch.ones([sub_seq_len * num_reps + 1])

            obs = torch.concat([obs, obs_end_pad], 0)
            targ = torch.concat([targ_start_pad, targ], 0)
            mask = torch.concat([mask_off, mask_on], 0)

            obs_tensors.append(obs)
            targ_tensors.append(targ)
            mask_tensors.append(mask)
        # End the loop over batch index.

        # Compute how much zero padding is needed to make the sequences the
        # same length for all batch elements.
        residual_obs_pad = [
            torch.zeros([residual_length_batch[i], full_obs_size])
            for i in range(batch_size)
        ]
        residual_targ_pad = [
            torch.zeros([residual_length_batch[i], full_targ_size])
            for i in range(batch_size)
        ]
        residual_mask_pad = [
            torch.zeros([residual_length_batch[i]]) for i in range(batch_size)
        ]

        # Concatenate the pad to each batch element.
        obs_tensors = [
            torch.concat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad)
        ]
        targ_tensors = [
            torch.concat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad)
        ]
        mask_tensors = [
            torch.concat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad)
        ]

        # Concatenate each batch element into a single tensor.
        obs = torch.concat(obs_tensors, 1).reshape(obs_batch_shape)
        targ = torch.concat(targ_tensors, 1).reshape(targ_batch_shape)
        mask = torch.concat(mask_tensors, 0).reshape(mask_batch_trans_shape).T

        if device is not None:
            obs = obs.to(device)
            targ = targ.to(device)
            mask = mask.to(device)
        return DatasetTensors(obs, targ, mask)
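
    # Shape sanity check (illustrative): with sub_seq_len=7 and num_reps=3,
    # total_length = 7 * (3 + 1) + 3 = 31 time-steps, matching the class
    # docstring diagram: 1 start marker + 7 pattern steps + 1 num-repeats
    # step + 21 repeated-target steps + 1 end marker.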

    def cost(self, logits, targ, mask):
        return masked_sigmoid_cross_entropy(
            logits,
            targ,
            mask,
            time_average=self.time_average_cost,
            log_prob_in_bits=self.log_prob_in_bits,
        )

    def to_human_readable(self, data, model_output=None, whole_batch=False):
        obs = data.observations
        # The num-repeats channel is normalised; undo that for readability.
        unnormalised_num_reps_flag = self._unnormalise(obs[:, :, -1:]).round()
        obs = torch.cat([obs[:, :, :-1], unnormalised_num_reps_flag], dim=2)
        data = DatasetTensors(
            observations=obs,
            target=data.target,
            mask=data.mask,
        )
        return bitstring_readable(data, self.batch_size, model_output, whole_batch)
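

# A minimal usage sketch (assumed entry point, not part of the original
# module): draw one batch, score random logits against it, and print the
# human-readable rendering. The hyperparameters below are arbitrary.
if __name__ == "__main__":
    dataset = RepeatCopy(num_bits=4, batch_size=2, min_length=2, max_length=4)
    data = dataset()
    # Random "predictions" with the same shape as the target, for illustration.
    logits = torch.randn_like(data.target)
    loss = dataset.cost(logits, data.target, data.mask)
    print(dataset.to_human_readable(data, whole_batch=True))
    print(f"loss on random logits: {loss.item():.3f}")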