A few bug fixes for unsqueezing with batch size 1

This commit is contained in:
ixaxaar 2017-12-19 11:17:59 +05:30
parent 9164e5721d
commit 1d11fae06b
3 changed files with 5 additions and 6 deletions

View File

@@ -145,7 +145,7 @@ class SparseTemporalMemory(nn.Module):
def write_into_sparse_memory(self, hidden):
visible_memory = hidden['visible_memory']
positions = hidden['read_positions'].squeeze()
positions = hidden['read_positions']
(b, m, w) = hidden['memory'].size()
# update memory
@@ -181,7 +181,7 @@ class SparseTemporalMemory(nn.Module):
rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
return link_matrix.squeeze() * I, rev_link_matrix.squeeze() * I
return link_matrix * I, rev_link_matrix * I
def update_precedence(self, precedence, write_weights):
    """Return the updated precedence weighting after a write.

    The previous precedence is decayed by how much total writing just
    happened (summed over the last dimension), then the fresh write
    weights are added on top.
    """
    # How much of the old precedence survives this write step.
    retention = 1 - T.sum(write_weights, dim=-1, keepdim=True)
    return retention * precedence + write_weights
@@ -255,8 +255,8 @@ class SparseTemporalMemory(nn.Module):
return usage, I
def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
    """Compute forward and backward directional read weightings.

    Multiplies the (reverse) temporal link matrix against the previous
    read weights to find which slots were written just after (forward)
    or just before (backward) the slots last read.

    NOTE(review): `bmm` fixes the inputs as batched matrices — link
    matrices (b, m, m) and read weights (b, m); exact slot semantics
    assumed from names, confirm against caller.

    Fix: the first pair of assignments used a bare `.squeeze()`, which
    also collapses the batch dimension when batch size is 1; they were
    dead code (immediately overwritten) and are removed. `.squeeze(2)`
    drops only the length-1 matmul column, preserving the batch dim.
    """
    # (b, m, m) @ (b, m, 1) -> (b, m, 1) -> (b, m)
    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
    return f, b
def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):

View File

@@ -214,7 +214,6 @@ if __name__ == '__main__':
else:
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
# print(.size(), target_output.size())
output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
loss = cross_entropy(output, target_output)

View File

@@ -220,7 +220,7 @@ if __name__ == '__main__':
else:
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
loss = T.mean((output[:, -1, :].sum() - target_output.sum()) ** 2, dim=-1)
loss = T.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
loss.backward()