Few bug fixes for unsqueezing batch size 1
This commit is contained in:
parent
9164e5721d
commit
1d11fae06b
@ -145,7 +145,7 @@ class SparseTemporalMemory(nn.Module):
|
|||||||
|
|
||||||
def write_into_sparse_memory(self, hidden):
|
def write_into_sparse_memory(self, hidden):
|
||||||
visible_memory = hidden['visible_memory']
|
visible_memory = hidden['visible_memory']
|
||||||
positions = hidden['read_positions'].squeeze()
|
positions = hidden['read_positions']
|
||||||
|
|
||||||
(b, m, w) = hidden['memory'].size()
|
(b, m, w) = hidden['memory'].size()
|
||||||
# update memory
|
# update memory
|
||||||
@ -181,7 +181,7 @@ class SparseTemporalMemory(nn.Module):
|
|||||||
|
|
||||||
rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
|
rev_link_matrix = (1 - temporal_write_weights_j) * rev_link_matrix + (temporal_write_weights_j * precedence_dense_i)
|
||||||
|
|
||||||
return link_matrix.squeeze() * I, rev_link_matrix.squeeze() * I
|
return link_matrix * I, rev_link_matrix * I
|
||||||
|
|
||||||
def update_precedence(self, precedence, write_weights):
|
def update_precedence(self, precedence, write_weights):
|
||||||
return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
|
return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
|
||||||
@ -255,8 +255,8 @@ class SparseTemporalMemory(nn.Module):
|
|||||||
return usage, I
|
return usage, I
|
||||||
|
|
||||||
def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
|
def directional_weightings(self, link_matrix, rev_link_matrix, temporal_read_weights):
|
||||||
f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
|
f = T.bmm(link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
|
||||||
b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze()
|
b = T.bmm(rev_link_matrix, temporal_read_weights.unsqueeze(2)).squeeze(2)
|
||||||
return f, b
|
return f, b
|
||||||
|
|
||||||
def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
|
def read_from_sparse_memory(self, memory, indexes, keys, least_used_mem, usage, forward, backward, prev_read_positions):
|
||||||
|
@ -214,7 +214,6 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
|
|
||||||
# print(.size(), target_output.size())
|
|
||||||
output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
|
output = output.sum(dim=2, keepdim=True).sum(dim=1, keepdim=True)
|
||||||
loss = cross_entropy(output, target_output)
|
loss = cross_entropy(output, target_output)
|
||||||
|
|
||||||
|
@ -220,7 +220,7 @@ if __name__ == '__main__':
|
|||||||
else:
|
else:
|
||||||
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
output, (chx, mhx, rv) = rnn(input_data, (None, mhx, None), reset_experience=True, pass_through_memory=True)
|
||||||
|
|
||||||
loss = T.mean((output[:, -1, :].sum() - target_output.sum()) ** 2, dim=-1)
|
loss = T.mean(((loss_weights * output).sum(-1, keepdim=True) - target_output) ** 2)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user