A few bug fixes for unsqueezing with batch size 1
commit 1d11fae06b
parent 9164e5721d
@@ -145,7 +145,7 @@ class SparseTemporalMemory(nn.Module):
 
   def write_into_sparse_memory(self, hidden):
     visible_memory = hidden['visible_memory']
-    positions = hidden['read_positions'].squeeze()
+    positions = hidden['read_positions']
 
     (b, m, w) = hidden['memory'].size()
     # update memory
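The first hunk drops a bare `squeeze()` on the read positions. `Tensor.squeeze()` with no dimension argument removes every size-1 axis, so with batch size 1 it silently deletes the batch dimension too. A minimal sketch of the failure mode, with illustrative shapes that are not taken from the repo:

```python
import torch

positions = torch.zeros(1, 4)         # (batch=1, n_reads), illustrative
print(positions.squeeze().shape)      # torch.Size([4]) -- batch axis gone

positions = torch.zeros(8, 4)         # with batch > 1 the call is a no-op
print(positions.squeeze().shape)      # torch.Size([8, 4])
```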
@@ -181,7 +181,7 @@ class SparseTemporalMemory(nn.Module):
 
     rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
 
-    return link_matrix.squeeze() * I, rev_link_matrix.squeeze() * I
+    return link_matrix * I, rev_link_matrix * I
 
   def update_precedence(self, precedence, write_weights):
     return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
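Same story for the link matrices: with batch size 1 the old `squeeze()` collapsed `(1, m, m)` to `(m, m)`, and a 2-D matrix then breaks the batched multiplies downstream, which require 3-D inputs. A hedged sketch with made-up sizes:

```python
import torch as T

b, m = 1, 5
link_matrix = T.rand(b, m, m)
weights = T.rand(b, m)

# Keeping the batch axis lets bmm see the 3-D input it requires.
print(T.bmm(link_matrix, weights.unsqueeze(2)).shape)   # (1, 5, 1)

# After squeeze() the matrix is 2-D and bmm raises.
try:
    T.bmm(link_matrix.squeeze(), weights.unsqueeze(2))
except RuntimeError as err:
    print(err)   # complains that the input must be a 3-D tensor
```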
@@ -255,8 +255,8 @@ class SparseTemporalMemory(nn.Module):
     return usage, I
 
   def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
-    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
-    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
+    f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
+    b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
     return f, b
 
   def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
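`directional_weightings` gets the targeted form of the same fix: `T.bmm` yields `(b, m, 1)`, and `squeeze(2)` removes exactly that trailing singleton, giving `(b, m)` for every batch size, whereas the old bare `squeeze()` also stripped the batch axis when `b == 1`. An illustrative check:

```python
import torch as T

for b in (1, 8):
    out = T.bmm(T.rand(b, 5, 5), T.rand(b, 5, 1))   # (b, 5, 1)
    print(out.squeeze(2).shape)   # (1, 5) and (8, 5) -- batch axis kept
    print(out.squeeze().shape)    # (5,) for b == 1  -- batch axis lost
```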
@@ -214,7 +214,6 @@ if __name__ == '__main__':
       else:
         output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
 
       # print(.size(), target_output.size())
-      output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
       loss = cross_entropy(output, target_output)
 
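Reading the hunk, the removed line collapsed the network output to a single value per batch before the loss. If `cross_entropy` here behaves like `torch.nn.functional.cross_entropy` (an assumption; the task script may define its own criterion), it expects unreduced class scores per example:

```python
import torch as T
import torch.nn.functional as F

# Assumed shapes for illustration: raw scores per example, one class
# index per example; reducing output to a scalar first would break this.
output = T.randn(4, 10)                  # (batch, n_classes)
target_output = T.randint(0, 10, (4,))   # (batch,)
print(F.cross_entropy(output, target_output))
```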
@@ -220,7 +220,7 @@ if __name__ == '__main__':
       else:
         output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
 
-      loss = T.mean((output[:, -1, :].sum() - target_output.sum()) ** 2, dim=-1)
+      loss = T.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
 
       loss.backward()
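The last hunk swaps a loss that reduced output and target to scalars before comparing for a per-step squared error gated by a `loss_weights` mask, so only the weighted steps contribute. A sketch with assumed shapes (the real ones depend on the task script):

```python
import torch as T

output = T.rand(2, 10, 8)          # (batch, time, features), assumed
target_output = T.rand(2, 10, 1)
loss_weights = T.zeros(2, 10, 8)
loss_weights[:, -1, :] = 1.0       # e.g. only score the final step

# Masked per-step squared error, averaged over everything.
loss = T.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
print(loss)
```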