Reload only used part of the indedx
This commit is contained in:
parent
b22274bfc0
commit
142811a552
@ -46,7 +46,7 @@ class Index(object):
|
||||
self.index.reset()
|
||||
T.cuda.synchronize()
|
||||
|
||||
def add(self, other, positions=None):
|
||||
def add(self, other, positions=None, last=-1):
|
||||
other = ensure_gpu(other, self.gpu_id)
|
||||
|
||||
T.cuda.synchronize()
|
||||
@ -55,6 +55,7 @@ class Index(object):
|
||||
assert positions.size(0) == other.size(0), "Mismatch in number of positions and vectors"
|
||||
self.index.add_with_ids_c(other.size(0), cast_float(ptr(other)), cast_long(ptr(positions + 1)))
|
||||
else:
|
||||
other = other[:last, :]
|
||||
self.index.add_c(other.size(0), cast_float(ptr(other)))
|
||||
T.cuda.synchronize()
|
||||
|
||||
|
@ -74,9 +74,11 @@ class SparseMemory(nn.Module):
|
||||
probes=self.index_checks, gpu_id=self.mem_gpu_id) for x in range(b)]
|
||||
|
||||
# add existing memory into indexes
|
||||
pos = hidden['read_positions'].squeeze().data.cpu().numpy()
|
||||
if not erase:
|
||||
for n, i in enumerate(hidden['indexes']):
|
||||
i.add(hidden['memory'][n, :self.timestep, :])
|
||||
i.reset()
|
||||
i.add(hidden['memory'][n], last=pos[n][-1])
|
||||
else:
|
||||
self.timestep = 0
|
||||
|
||||
@ -133,7 +135,7 @@ class SparseMemory(nn.Module):
|
||||
for b in range(positions.size(0)):
|
||||
# update indexes
|
||||
hidden['indexes'][b].reset()
|
||||
hidden['indexes'][b].add(hidden['memory'][b])
|
||||
hidden['indexes'][b].add(hidden['memory'][b], last=pos[b][-1])
|
||||
hidden['last_used_mem'][b] = (int(pos[b][-1]) + 1) if (pos[b][-1] + 1) < self.mem_size else 0
|
||||
|
||||
# print('total ', hidden['indexes'][0].index.ntotal, self.timestep)
|
||||
@ -153,6 +155,8 @@ class SparseMemory(nn.Module):
|
||||
hidden['write_weights'] = write_gate.unsqueeze(1) * (x + y)
|
||||
|
||||
# no erasing and hence no erase matrix R_{t}
|
||||
# print('write_weights', hidden['write_weights'].size(), 'write_vector', write_vector.size(), write_vector.squeeze())
|
||||
# print('bmm', T.bmm(hidden['write_weights'].transpose(1, 2), write_vector).size())
|
||||
hidden['read_vectors'] = hidden['read_vectors'] + T.bmm(hidden['write_weights'].transpose(1, 2), write_vector)
|
||||
hidden = self.write_into_sparse_memory(hidden)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user