Reload only used part of the indedx

This commit is contained in:
ixaxaar 2017-12-06 21:14:16 +05:30
parent b22274bfc0
commit 142811a552
2 changed files with 8 additions and 3 deletions

View File

@ -46,7 +46,7 @@ class Index(object):
self.index.reset()
T.cuda.synchronize()
def add(self, other, positions=None):
def add(self, other, positions=None, last=-1):
other = ensure_gpu(other, self.gpu_id)
T.cuda.synchronize()
@ -55,6 +55,7 @@ class Index(object):
assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
else:
other = other[:last, :]
self.index.add_c(other.size(0), cast_float(ptr(other)))
T.cuda.synchronize()

View File

@ -74,9 +74,11 @@ class SparseMemory(nn.Module):
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
# add existing memory into indexes
pos = hidden['read_positions'].squeeze().data.cpu().numpy()
if not erase:
for n, i in enumerate(hidden['indexes']):
i.add(hidden['memory'][n, :self.timestep, :])
i.reset()
i.add(hidden['memory'][n], last=pos[n][-1])
else:
self.timestep = 0
@ -133,7 +135,7 @@ class SparseMemory(nn.Module):
for b in range(positions.size(0)):
# update indexes
hidden['indexes'][b].reset()
hidden['indexes'][b].add(hidden['memory'][b])
hidden['indexes'][b].add(hidden['memory'][b], last=pos[b][-1])
hidden['last_used_mem'][b] = (int(pos[b][-1]) + 1) if (pos[b][-1] + 1) < self.mem_size else 0
# print('total ', hidden['indexes'][0].index.ntotal, self.timestep)
@ -153,6 +155,8 @@ class SparseMemory(nn.Module):
hidden['write_weights'] = write_gate.unsqueeze(1) * (x + y)
# no erasing and hence no erase matrix R_{t}
# print('write_weights', hidden['write_weights'].size(), 'write_vector', write_vector.size(), write_vector.squeeze())
# print('bmm', T.bmm(hidden['write_weights'].transpose(1, 2), write_vector).size())
hidden['read_vectors'] = hidden['read_vectors'] + T.bmm(hidden['write_weights'].transpose(1, 2), write_vector)
hidden = self.write_into_sparse_memory(hidden)