diff --git a/dnc/dnc.py b/dnc/dnc.py index 1c49749..8c22c2f 100644 --- a/dnc/dnc.py +++ b/dnc/dnc.py @@ -178,14 +178,15 @@ class DNC(nn.Module): input, chx = self.rnns[layer](input.unsqueeze(1), chx) input = input.squeeze(1) - # the interface vector - ξ = input # clip the controller output if self.clip != 0: output = T.clamp(input, -self.clip, self.clip) else: output = input + # the interface vector + ξ = output + # pass through memory if pass_through_memory: if self.share_memory: