import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor


@dataclass
class DatasetTensors:
    observations: Tensor
    target: Tensor
    mask: Tensor


def masked_sigmoid_cross_entropy(
    logits,
    target,
    mask,
    time_average=False,
    log_prob_in_bits=False,
):
    """Computes the (scalar) NLL of the target sequence.

    The logits parametrize independent Bernoulli distributions per time-step
    and per batch element, and irrelevant time/batch elements are masked out
    by the mask tensor.

    Args:
      logits: `Tensor` of activations for which sigmoid(`logits`) gives the
          Bernoulli parameter.
      target: time-major `Tensor` of target values.
      mask: time-major `Tensor` of shape `[T, B]`, multiplied elementwise
          with the per-step cost to mask out irrelevant time-steps.
      time_average: optionally average over the time dimension (sum by
          default).
      log_prob_in_bits: iff True, express log-probabilities in bits (default
          nats).

    Returns:
      A scalar `Tensor` representing the negative log-probability of the
      target.
    """
    batch_size = logits.shape[1]

    xent = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    loss_time_batch = xent.sum(dim=2)
    loss_batch = (loss_time_batch * mask).sum(dim=0)

    if time_average:
        mask_count = mask.sum(dim=0)
        loss_batch /= mask_count + torch.finfo(loss_batch.dtype).eps

    loss = loss_batch.sum() / batch_size
    if log_prob_in_bits:
        # `torch.log` expects a tensor argument, so use `math.log` for the
        # scalar change-of-base constant.
        loss /= math.log(2.0)
    return loss


def bitstring_readable(data, batch_size, model_output=None, whole_batch=False):
    """Produces a human-readable representation of the sequences in data.

    Args:
      data: data to be visualised.
      batch_size: size of the batch.
      model_output: optional model output tensor to visualize alongside data.
      whole_batch: whether to visualise the whole batch. Only the first
          sample will be visualized if False.

    Returns:
      A string used to visualise the data batch.
    """

    def _readable(datum):
        return "+" + " ".join([f"{int(x):d}" if x else "-" for x in datum]) + "+"

    obs_batch = data.observations
    targ_batch = data.target

    batch_strings = []
    for batch_index in range(batch_size if whole_batch else 1):
        obs = obs_batch[:, batch_index, :]
        targ = targ_batch[:, batch_index, :]

        obs_channels = range(obs.shape[1])
        targ_channels = range(targ.shape[1])
        obs_channel_strings = [_readable(obs[:, i]) for i in obs_channels]
        targ_channel_strings = [_readable(targ[:, i]) for i in targ_channels]

        readable_obs = "Observations:\n" + "\n".join(obs_channel_strings)
        readable_targ = "Targets:\n" + "\n".join(targ_channel_strings)
        strings = [readable_obs, readable_targ]

        if model_output is not None:
            output = model_output[:, batch_index, :]
            output_strings = [_readable(output[:, i]) for i in targ_channels]
            strings.append("Model Output:\n" + "\n".join(output_strings))

        batch_strings.append("\n\n".join(strings))

    return "\n" + "\n\n\n\n".join(batch_strings)


class RepeatCopy:
    """Sequence data generator for the task of repeating a random binary pattern.

    When called, an instance of this class returns a `DatasetTensors` tuple
    of torch tensors (observations, target, mask), representing an input
    sequence, target sequence, and binary mask. The first two dimensions of
    each tensor represent sequence position and batch index respectively.
    The value in mask[t, b] is equal to 1 iff a prediction about
    targ[t, b, :] should be penalized and 0 otherwise.

    For each realisation from this generator, the observation sequence is
    comprised of I.I.D. uniform-random binary vectors (and some flags). The
    target sequence is comprised of this binary pattern repeated some number
    of times (and some flags).
    Before explaining in more detail, let's examine the setup pictorially
    for a single batch element:

    ```none
    Note: blank space represents 0.

    time ------------------------------------------>

                  +-------------------------------+
    mask:         |0000000001111111111111111111111|
                  +-------------------------------+

                  +-------------------------------+
    target:       |                              1| 'end-marker' channel.
                  |         101100110110011011001 |
                  |         010101001010100101010 |
                  +-------------------------------+

                  +-------------------------------+
    observation:  | 1011001                       |
                  | 0101010                       |
                  |1                              | 'start-marker' channel
                  |                              3| 'num-repeats' channel.
                  +-------------------------------+
    ```

    The length of the random pattern and the number of times it is repeated
    in the target are both discrete random variables distributed according
    to uniform distributions whose parameters are configured at construction
    time.

    The obs sequence has two extra channels (components in the trailing
    dimension) which are used for flags. One channel is marked with a 1 at
    the first time step and is otherwise equal to 0. The other extra channel
    is zero until the binary pattern to be repeated ends. At this point, it
    contains an encoding of the number of times the observation pattern
    should be repeated. Rather than simply providing this integer number
    directly, it is normalised so that a neural network may have an easier
    time representing the number of repetitions internally. To allow a
    network to be readily evaluated on instances of this task with greater
    numbers of repetitions, the range with respect to which this encoding is
    normalised is also configurable by the user.

    As in the diagram, the target sequence is offset to begin directly after
    the observation sequence; both sequences are padded with zeros to
    accomplish this, resulting in their lengths being equal. Additional
    padding is done at the end so that all sequences in a minibatch
    represent tensors with the same shape.
    """

    def __init__(
        self,
        num_bits=6,
        batch_size=1,
        min_length=1,
        max_length=1,
        min_repeats=1,
        max_repeats=2,
        norm_max=10,
        log_prob_in_bits=False,
        time_average_cost=False,
    ):
        """Creates an instance of the RepeatCopy task.

        Args:
          num_bits: The dimensionality of each random binary vector.
          batch_size: Minibatch size per realization.
          min_length: Lower limit on number of random binary vectors in the
              observation pattern.
          max_length: Upper limit on number of random binary vectors in the
              observation pattern.
          min_repeats: Lower limit on number of times the observation
              pattern is repeated in targ.
          max_repeats: Upper limit on number of times the observation
              pattern is repeated in targ.
          norm_max: Upper limit on the uniform distribution w.r.t. which the
              encoding of the number of repetitions presented in the
              observation sequence is normalised.
          log_prob_in_bits: By default, log probabilities are expressed in
              units of nats. If true, express log probabilities in bits.
          time_average_cost: If true, the cost at each time step will be
              divided by the `true` sequence length (the number of
              non-masked time steps in each sequence) before any subsequent
              reduction over the time and batch dimensions.
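
        Example (an illustrative sketch; the time dimension `T` varies with
        the sampled pattern lengths and repeat counts):

            task = RepeatCopy(num_bits=4, batch_size=2)
            batch = task()
            # batch.observations: [T, 2, 6]  (num_bits + 2 flag channels)
            # batch.target:       [T, 2, 5]  (num_bits + 1 end-marker channel)
            # batch.mask:         [T, 2]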
""" super().__init__() self._batch_size = batch_size self._num_bits = num_bits self._min_length = min_length self._max_length = max_length self._min_repeats = min_repeats self._max_repeats = max_repeats self._norm_max = norm_max self._log_prob_in_bits = log_prob_in_bits self._time_average_cost = time_average_cost def _normalise(self, val): return val / self._norm_max def _unnormalise(self, val): return val * self._norm_max @property def time_average_cost(self): return self._time_average_cost @property def log_prob_in_bits(self): return self._log_prob_in_bits @property def num_bits(self): """The dimensionality of each random binary vector in a pattern.""" return self._num_bits @property def target_size(self): """The dimensionality of the target tensor.""" return self._num_bits + 1 @property def batch_size(self): return self._batch_size def __call__(self, device=None): """Implements build method which adds ops to graph.""" # short-hand for private fields. min_length, max_length = self._min_length, self._max_length min_reps, max_reps = self._min_repeats, self._max_repeats num_bits = self.num_bits batch_size = self.batch_size # We reserve one dimension for the num-repeats and one for the start-marker. full_obs_size = num_bits + 2 # We reserve one target dimension for the end-marker. full_targ_size = num_bits + 1 start_end_flag_idx = full_obs_size - 2 num_repeats_channel_idx = full_obs_size - 1 # Samples each batch index's sequence length and the number of repeats. sub_seq_length_batch = torch.randint( low=min_length, high=max_length + 1, size=[batch_size], dtype=torch.int32 ) num_repeats_batch = torch.randint( low=min_reps, high=max_reps + 1, size=[batch_size], dtype=torch.int32 ) # Pads all the batches to have the same total sequence length. total_length_batch = sub_seq_length_batch * (num_repeats_batch + 1) + 3 max_length_batch = total_length_batch.max() residual_length_batch = max_length_batch - total_length_batch obs_batch_shape = [max_length_batch, batch_size, full_obs_size] targ_batch_shape = [max_length_batch, batch_size, full_targ_size] mask_batch_trans_shape = [batch_size, max_length_batch] obs_tensors = [] targ_tensors = [] mask_tensors = [] # Generates patterns for each batch element independently. for batch_index in range(batch_size): sub_seq_len = sub_seq_length_batch[batch_index] num_reps = num_repeats_batch[batch_index] # The observation pattern is a sequence of random binary vectors. obs_pattern_shape = [sub_seq_len, num_bits] obs_pattern = torch.randint(low=0, high=2, size=obs_pattern_shape).to( torch.float32 ) # The target pattern is the observation pattern repeated n times. # Some reshaping is required to accomplish the tiling. targ_pattern_shape = [sub_seq_len * num_reps, num_bits] flat_obs_pattern = obs_pattern.reshape([-1]) flat_targ_pattern = torch.tile(flat_obs_pattern, [num_reps]) targ_pattern = flat_targ_pattern.reshape(targ_pattern_shape) # Expand the obs_pattern to have two extra channels for flags. # Concatenate start flag and num_reps flag to the sequence. obs_flag_channel_pad = torch.zeros([sub_seq_len, 2]) obs_start_flag = F.one_hot( torch.tensor([start_end_flag_idx]), num_classes=full_obs_size ).to(torch.float32) num_reps_flag = F.one_hot( torch.tensor([num_repeats_channel_idx]), num_classes=full_obs_size ).to(torch.float32) num_reps_flag *= self._normalise(num_reps) # note the concatenation dimensions. 
            obs = torch.cat([obs_pattern, obs_flag_channel_pad], 1)
            obs = torch.cat([obs_start_flag, obs], 0)
            obs = torch.cat([obs, num_reps_flag], 0)

            # Now do the same for the targ_pattern (it only has one extra
            # channel).
            targ_flag_channel_pad = torch.zeros([sub_seq_len * num_reps, 1])
            targ_end_flag = F.one_hot(
                torch.tensor([start_end_flag_idx]), num_classes=full_targ_size
            ).to(torch.float32)
            targ = torch.cat([targ_pattern, targ_flag_channel_pad], 1)
            targ = torch.cat([targ, targ_end_flag], 0)

            # Concatenate zeros at the end of obs and the beginning of targ.
            # This aligns them s.t. the target begins as soon as the obs
            # ends.
            obs_end_pad = torch.zeros([sub_seq_len * num_reps + 1, full_obs_size])
            targ_start_pad = torch.zeros([sub_seq_len + 2, full_targ_size])

            # The mask is zero during the obs and one during the targ.
            mask_off = torch.zeros([sub_seq_len + 2])
            mask_on = torch.ones([sub_seq_len * num_reps + 1])

            obs = torch.cat([obs, obs_end_pad], 0)
            targ = torch.cat([targ_start_pad, targ], 0)
            mask = torch.cat([mask_off, mask_on], 0)

            obs_tensors.append(obs)
            targ_tensors.append(targ)
            mask_tensors.append(mask)
        # End the loop over batch index.

        # Compute how much zero padding is needed to make the tensor
        # sequences the same length for all batch elements.
        residual_obs_pad = [
            torch.zeros([int(residual_length_batch[i]), full_obs_size])
            for i in range(batch_size)
        ]
        residual_targ_pad = [
            torch.zeros([int(residual_length_batch[i]), full_targ_size])
            for i in range(batch_size)
        ]
        residual_mask_pad = [
            torch.zeros([int(residual_length_batch[i])]) for i in range(batch_size)
        ]

        # Concatenate the pad to each batch element.
        obs_tensors = [
            torch.cat([o, p], 0) for o, p in zip(obs_tensors, residual_obs_pad)
        ]
        targ_tensors = [
            torch.cat([t, p], 0) for t, p in zip(targ_tensors, residual_targ_pad)
        ]
        mask_tensors = [
            torch.cat([m, p], 0) for m, p in zip(mask_tensors, residual_mask_pad)
        ]

        # Concatenate each batch element into a single tensor.
        obs = torch.cat(obs_tensors, 1).reshape(obs_batch_shape)
        targ = torch.cat(targ_tensors, 1).reshape(targ_batch_shape)
        mask = torch.cat(mask_tensors, 0).reshape(mask_batch_trans_shape).T

        if device is not None:
            obs = obs.to(device)
            targ = targ.to(device)
            mask = mask.to(device)

        return DatasetTensors(obs, targ, mask)

    def cost(self, logits, targ, mask):
        return masked_sigmoid_cross_entropy(
            logits,
            targ,
            mask,
            time_average=self.time_average_cost,
            log_prob_in_bits=self.log_prob_in_bits,
        )

    def to_human_readable(self, data, model_output=None, whole_batch=False):
        obs = data.observations
        unnormalised_num_reps_flag = self._unnormalise(obs[:, :, -1:]).round()
        obs = torch.cat([obs[:, :, :-1], unnormalised_num_reps_flag], dim=2)
        data = DatasetTensors(
            observations=obs,
            target=data.target,
            mask=data.mask,
        )
        return bitstring_readable(data, self.batch_size, model_output, whole_batch)
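

if __name__ == "__main__":
    # A minimal usage sketch (an illustrative addition, not part of the task
    # definition above): sample a batch, score an all-zero logit baseline,
    # and print the batch in human-readable form.
    task = RepeatCopy(num_bits=4, batch_size=2, max_length=3, max_repeats=2)
    batch = task()

    # All-zero logits give sigmoid(0) = 0.5 per bit, so the masked Bernoulli
    # NLL of this baseline is ln(2) nats per predicted bit.
    baseline_logits = torch.zeros_like(batch.target)
    baseline_cost = task.cost(baseline_logits, batch.target, batch.mask)
    print("baseline cost:", baseline_cost.item())
    print(task.to_human_readable(batch, whole_batch=True))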