From 1d11fae06b2715df41390cde98a32ede99c4b25a Mon Sep 17 00:00:00 2001
From: ixaxaar
Date: Tue, 19 Dec 2017 11:17:59 +0530
Subject: [PATCH] Few bug fixes for unsqueezing batch size 1

---
 dnc/sparse_temporal_memory.py | 8 ++++----
 tasks/adding_task.py          | 1 -
 tasks/adding_task_v2.py       | 2 +-
 3 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/dnc/sparse_temporal_memory.py b/dnc/sparse_temporal_memory.py
index bb066d7..1154a40 100644
--- a/dnc/sparse_temporal_memory.py
+++ b/dnc/sparse_temporal_memory.py
@@ -145,7 +145,7 @@ class SparseTemporalMemory(nn.Module):
 
   def write_into_sparse_memory(self, hidden):
     visible_memory = hidden['visible_memory']
-    positions = hidden['read_positions'].squeeze()
+    positions = hidden['read_positions']
 
     (b, m, w) = hidden['memory'].size()
     # update memory
@@ -181,7 +181,7 @@ class SparseTemporalMemory(nn.Module):
 
       rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
 
-    return link_matrix.squeeze() * I, rev_link_matrix.squeeze() * I
+    return link_matrix * I, rev_link_matrix * I
 
   def update_precedence(self, precedence, write_weights):
     return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
@@ -255,8 +255,8 @@ class SparseTemporalMemory(nn.Module):
     return usage, I
 
   def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
-    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
-    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
+    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
+    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
     return f, b
 
   def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
diff --git a/tasks/adding_task.py b/tasks/adding_task.py
index 360baba..f5a88e8 100644
--- a/tasks/adding_task.py
+++ b/tasks/adding_task.py
@@ -214,7 +214,6 @@ if __name__ == '__main__':
       else:
         output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
 
-      # print(.size(), target_output.size())
       output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
       loss = cross_entropy(output, target_output)
 
diff --git a/tasks/adding_task_v2.py b/tasks/adding_task_v2.py
index 4352771..82ef224 100644
--- a/tasks/adding_task_v2.py
+++ b/tasks/adding_task_v2.py
@@ -220,7 +220,7 @@ if __name__ == '__main__':
       else:
         output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
 
-      loss = T.mean((output[:, -1, :].sum() - target_output.sum()) ** 2, dim=-1)
+      loss = T.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
 
       loss.backward()
 
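Note: the sketches below are editorial illustrations, not part of the patch. First, a minimal standalone sketch of the shape bug the squeeze changes address: with batch size 1, a bare .squeeze() also collapses the batch dimension, while .squeeze(2) removes only the intended singleton axis. The tensor sizes here are illustrative, not taken from the repo.

    import torch as T

    link_matrix = T.rand(1, 4, 4)   # (batch=1, cells, cells)
    read_weights = T.rand(1, 4)     # (batch=1, cells)

    # bmm yields (1, 4, 1); a bare squeeze drops BOTH singleton dims,
    # losing the batch dimension and breaking (batch, cells) consumers
    f_bad = T.bmm(link_matrix, read_weights.unsqueeze(2)).squeeze()
    print(f_bad.size())    # torch.Size([4])

    # squeezing only dim 2 keeps the batch dimension intact
    f_good = T.bmm(link_matrix, read_weights.unsqueeze(2)).squeeze(2)
    print(f_good.size())   # torch.Size([1, 4])

And a similar sketch (shapes assumed, not the repo's actual training loop) of the reworked loss in tasks/adding_task_v2.py: the weighted per-cell outputs are reduced to one value per time step before taking the squared error, rather than comparing raw sums over only the final step.

    batch, length, cells = 1, 10, 4
    output = T.rand(batch, length, cells)     # assumed network output shape
    target_output = T.rand(batch, length, 1)  # assumed target shape
    loss_weights = T.rand(cells)              # assumed broadcastable weighting

    # (loss_weights * output) -> (batch, length, cells); summing over cells
    # keeps a (batch, length, 1) shape that matches target_output elementwise
    loss = T.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
    print(loss.size())     # torch.Size([]) -- a scalar, ready for backward()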