Rewrite sdnc, more read heads
This commit is contained in:
parent
142811a552
commit
a6667bf98c
@ -25,7 +25,7 @@ class Index(object):
|
||||
self.res.initializeForDevice(self.gpu_id)
|
||||
|
||||
nr_samples = self.nr_cells * 100 * self.cell_size
|
||||
train = train if train is not None else T.arange(-nr_samples, nr_samples, 2).view(self.nr_cells * 100, self.cell_size) / nr_samples
|
||||
train = train if train is not None else T.arange(-nr_samples, nr_samples, 2).view(self.nr_cells * 100, self.cell_size) / (nr_samples/10)
|
||||
# train = T.randn(self.nr_cells * 100, self.cell_size)
|
||||
|
||||
self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
|
||||
|
12
dnc/sdnc.py
12
dnc/sdnc.py
@ -28,7 +28,8 @@ class SDNC(nn.Module):
|
||||
dropout=0,
|
||||
bidirectional=False,
|
||||
nr_cells=5,
|
||||
sparse_reads=2,
|
||||
sparse_reads=10,
|
||||
read_heads=4,
|
||||
cell_size=10,
|
||||
nonlinearity='tanh',
|
||||
gpu_id=-1,
|
||||
@ -51,6 +52,7 @@ class SDNC(nn.Module):
|
||||
self.bidirectional = bidirectional
|
||||
self.nr_cells = nr_cells
|
||||
self.sparse_reads = sparse_reads
|
||||
self.read_heads = read_heads
|
||||
self.cell_size = cell_size
|
||||
self.nonlinearity = nonlinearity
|
||||
self.gpu_id = gpu_id
|
||||
@ -60,7 +62,7 @@ class SDNC(nn.Module):
|
||||
self.clip = clip
|
||||
|
||||
self.w = self.cell_size
|
||||
self.r = self.sparse_reads
|
||||
self.r = self.read_heads
|
||||
|
||||
self.read_vectors_size = self.r * self.w
|
||||
self.output_size = self.hidden_size
|
||||
@ -90,7 +92,8 @@ class SDNC(nn.Module):
|
||||
input_size=self.output_size,
|
||||
mem_size=self.nr_cells,
|
||||
cell_size=self.w,
|
||||
sparse_reads=self.r,
|
||||
sparse_reads=self.sparse_reads,
|
||||
read_heads=self.read_heads,
|
||||
gpu_id=self.gpu_id,
|
||||
mem_gpu_id=self.gpu_id,
|
||||
independent_linears=self.independent_linears
|
||||
@ -105,7 +108,8 @@ class SDNC(nn.Module):
|
||||
input_size=self.output_size,
|
||||
mem_size=self.nr_cells,
|
||||
cell_size=self.w,
|
||||
sparse_reads=self.r,
|
||||
sparse_reads=self.sparse_reads,
|
||||
read_heads=self.read_heads,
|
||||
gpu_id=self.gpu_id,
|
||||
mem_gpu_id=self.gpu_id,
|
||||
independent_linears=self.independent_linears
|
||||
|
@ -21,10 +21,10 @@ class SparseMemory(nn.Module):
|
||||
mem_size=512,
|
||||
cell_size=32,
|
||||
independent_linears=True,
|
||||
sparse_reads=4,
|
||||
read_heads=4,
|
||||
sparse_reads=10,
|
||||
num_lists=None,
|
||||
index_checks=32,
|
||||
rebuild_indexes_after=10,
|
||||
gpu_id=-1,
|
||||
mem_gpu_id=-1
|
||||
):
|
||||
@ -37,23 +37,22 @@ class SparseMemory(nn.Module):
|
||||
self.input_size = input_size
|
||||
self.independent_linears = independent_linears
|
||||
self.K = sparse_reads if self.mem_size > sparse_reads else self.mem_size
|
||||
self.read_heads = read_heads
|
||||
self.num_lists = num_lists if num_lists is not None else int(self.mem_size / 100)
|
||||
self.index_checks = index_checks
|
||||
# self.rebuild_indexes_after = rebuild_indexes_after
|
||||
|
||||
# self.index_reset_ctr = 0
|
||||
|
||||
m = self.mem_size
|
||||
w = self.cell_size
|
||||
r = self.K + 1
|
||||
r = self.read_heads
|
||||
c = r * self.K + 1
|
||||
|
||||
if self.independent_linears:
|
||||
self.read_query_transform = nn.Linear(self.input_size, w)
|
||||
self.read_query_transform = nn.Linear(self.input_size, w*r)
|
||||
self.write_vector_transform = nn.Linear(self.input_size, w)
|
||||
self.interpolation_gate_transform = nn.Linear(self.input_size, w)
|
||||
self.interpolation_gate_transform = nn.Linear(self.input_size, c)
|
||||
self.write_gate_transform = nn.Linear(self.input_size, 1)
|
||||
else:
|
||||
self.interface_size = (2 * w) + r + 1
|
||||
self.interface_size = (r * w) + w + c + 1
|
||||
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
|
||||
|
||||
self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
|
||||
@ -88,22 +87,25 @@ class SparseMemory(nn.Module):
|
||||
m = self.mem_size
|
||||
w = self.cell_size
|
||||
b = batch_size
|
||||
r = self.K + 1
|
||||
r = self.read_heads
|
||||
c = r * self.K + 1
|
||||
|
||||
if hidden is None:
|
||||
hidden = {
|
||||
# warning can be a huge chunk of contiguous memory
|
||||
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
||||
'read_weights': cuda(T.zeros(b, 1, r).fill_(δ), gpu_id=self.gpu_id),
|
||||
'write_weights': cuda(T.zeros(b, 1, r).fill_(δ), gpu_id=self.gpu_id),
|
||||
'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
||||
'read_weights': cuda(T.zeros(b, r, c).fill_(δ), gpu_id=self.gpu_id),
|
||||
'write_weights': cuda(T.zeros(b, 1, c).fill_(δ), gpu_id=self.gpu_id),
|
||||
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
|
||||
'last_used_mem': cuda(T.zeros(b, 1).fill_(δ), gpu_id=self.gpu_id).long(),
|
||||
'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
||||
'read_positions': cuda(T.arange(0, r).expand(b, 1, r), gpu_id=self.gpu_id).long()
|
||||
'read_positions': cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
|
||||
}
|
||||
hidden = self.rebuild_indexes(hidden, erase=True)
|
||||
else:
|
||||
hidden['memory'] = hidden['memory'].clone()
|
||||
hidden['visible_memory'] = hidden['visible_memory'].clone()
|
||||
hidden['read_weights'] = hidden['read_weights'].clone()
|
||||
hidden['write_weights'] = hidden['write_weights'].clone()
|
||||
hidden['read_vectors'] = hidden['read_vectors'].clone()
|
||||
@ -114,21 +116,22 @@ class SparseMemory(nn.Module):
|
||||
|
||||
if erase:
|
||||
hidden['memory'].data.fill_(δ)
|
||||
hidden['visible_memory'].data.fill_(δ)
|
||||
hidden['read_weights'].data.fill_(δ)
|
||||
hidden['write_weights'].data.fill_(δ)
|
||||
hidden['read_vectors'].data.fill_(δ)
|
||||
hidden['last_used_mem'].data.fill_(0)
|
||||
hidden['usage'].data.fill_(δ)
|
||||
hidden['read_positions'] = cuda(T.arange(0, r).expand(b, 1, r), gpu_id=self.gpu_id).long()
|
||||
hidden['read_positions'] = cuda(T.arange(0, c).expand(b, c), gpu_id=self.gpu_id).long()
|
||||
return hidden
|
||||
|
||||
def write_into_sparse_memory(self, hidden):
|
||||
read_vectors = hidden['read_vectors']
|
||||
visible_memory = hidden['visible_memory']
|
||||
positions = hidden['read_positions'].squeeze()
|
||||
|
||||
(b, m, w) = hidden['memory'].size()
|
||||
# update memory
|
||||
hidden['memory'].scatter_(1, positions.unsqueeze(2).expand(b, self.K+1, w), read_vectors)
|
||||
hidden['memory'].scatter_(1, positions, visible_memory)
|
||||
|
||||
# non-differentiable operations
|
||||
pos = positions.data.cpu().numpy()
|
||||
@ -138,7 +141,6 @@ class SparseMemory(nn.Module):
|
||||
hidden['indexes'][b].add(hidden['memory'][b], last=pos[b][-1])
|
||||
hidden['last_used_mem'][b] = (int(pos[b][-1]) + 1) if (pos[b][-1] + 1) < self.mem_size else 0
|
||||
|
||||
# print('total ', hidden['indexes'][0].index.ntotal, self.timestep)
|
||||
return hidden
|
||||
|
||||
def write(self, interpolation_gate, write_vector, write_gate, hidden):
|
||||
@ -150,22 +152,23 @@ class SparseMemory(nn.Module):
|
||||
hidden['usage']
|
||||
)
|
||||
|
||||
# either we write to previous read locations
|
||||
x = interpolation_gate * hidden['read_weights']
|
||||
# or to a new location
|
||||
y = (1 - interpolation_gate) * I
|
||||
hidden['write_weights'] = write_gate.unsqueeze(1) * (x + y)
|
||||
hidden['write_weights'] = T.prod(write_gate.unsqueeze(1) * (x + y), 1)
|
||||
|
||||
# no erasing and hence no erase matrix R_{t}
|
||||
# print('write_weights', hidden['write_weights'].size(), 'write_vector', write_vector.size(), write_vector.squeeze())
|
||||
# print('bmm', T.bmm(hidden['write_weights'].transpose(1, 2), write_vector).size())
|
||||
hidden['read_vectors'] = hidden['read_vectors'] + T.bmm(hidden['write_weights'].transpose(1, 2), write_vector)
|
||||
hidden = self.write_into_sparse_memory(hidden)
|
||||
hidden['visible_memory'] = hidden['visible_memory'] + T.bmm(hidden['write_weights'].unsqueeze(2), write_vector)
|
||||
# hidden = self.write_into_sparse_memory(hidden)
|
||||
|
||||
return hidden
|
||||
|
||||
def update_usage(self, read_positions, read_weights, write_weights, usage):
|
||||
read_positions = read_positions.squeeze()
|
||||
(b, _) = read_positions.size()
|
||||
# usage is timesteps since a non-negligible memory access
|
||||
u = (read_weights + write_weights > self.δ).float()
|
||||
# todo store write weights of all mem and gather from that
|
||||
u = (read_weights.sum(1) + write_weights.squeeze() > self.δ).float().view(b, -1)
|
||||
|
||||
# usage before write
|
||||
relevant_usages = usage.gather(1, read_positions)
|
||||
@ -176,7 +179,7 @@ class SparseMemory(nn.Module):
|
||||
I = (relevant_usages == minusage).float().unsqueeze(1)
|
||||
|
||||
# usage after write
|
||||
relevant_usages = (self.timestep - relevant_usages) * u.squeeze() + relevant_usages * (1 - u.squeeze())
|
||||
relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)
|
||||
|
||||
usage.scatter_(1, read_positions, relevant_usages)
|
||||
|
||||
@ -185,45 +188,30 @@ class SparseMemory(nn.Module):
|
||||
def read_from_sparse_memory(self, memory, indexes, keys, last_used_mem, usage):
|
||||
b = keys.size(0)
|
||||
read_positions = []
|
||||
read_weights = []
|
||||
|
||||
# print(keys.squeeze())
|
||||
# non-differentiable operations
|
||||
# we search for k cells per read head
|
||||
for batch in range(b):
|
||||
distances, positions = indexes[batch].search(keys[batch])
|
||||
read_weights.append(distances)
|
||||
read_positions.append(T.clamp(positions, 0, self.mem_size - 1))
|
||||
|
||||
# add least used mem to read positions
|
||||
read_positions = T.stack(read_positions, 0)
|
||||
|
||||
# TODO: explore possibility of reading co-locations and such
|
||||
# if read_collocations:
|
||||
# read the previous and the next memory locations
|
||||
# read_positions = T.cat([read_positions, read_positions-1, read_positions+1], -1)
|
||||
|
||||
# add least used mem to read positions
|
||||
# TODO: explore possibility of reading co-locations or ranges and such
|
||||
(b, r, k) = read_positions.size()
|
||||
read_positions = var(read_positions)
|
||||
read_positions = T.cat([read_positions, last_used_mem.unsqueeze(1)], 2)
|
||||
# print(read_positions.squeeze())
|
||||
|
||||
# add weight of 0 for least used mem block
|
||||
read_weights = T.stack(read_weights, 0)
|
||||
new_block = read_weights.new(b, 1, 1)
|
||||
new_block.fill_(δ)
|
||||
read_weights = T.cat([read_weights, new_block], 2)
|
||||
read_weights = var(read_weights)
|
||||
# condition read weights by their usages
|
||||
relevant_usages = usage.gather(1, read_positions.squeeze())
|
||||
read_weights = (read_weights.squeeze(1) * relevant_usages).unsqueeze(1)
|
||||
read_positions = T.cat([read_positions.view(b, -1), last_used_mem], 1)
|
||||
|
||||
(b, m, w) = memory.size()
|
||||
read_vectors = memory.gather(1, read_positions.squeeze().unsqueeze(2).expand(b, self.K+1, w))
|
||||
visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))
|
||||
|
||||
return read_vectors, read_positions, read_weights
|
||||
read_weights = F.softmax(θ(visible_memory, keys), dim=2)
|
||||
read_vectors = T.bmm(read_weights, visible_memory)
|
||||
|
||||
return read_vectors, read_positions, read_weights, visible_memory
|
||||
|
||||
def read(self, read_query, hidden):
|
||||
# sparse read
|
||||
read_vectors, positions, read_weights = \
|
||||
read_vectors, positions, read_weights, visible_memory = \
|
||||
self.read_from_sparse_memory(
|
||||
hidden['memory'],
|
||||
hidden['indexes'],
|
||||
@ -234,8 +222,9 @@ class SparseMemory(nn.Module):
|
||||
hidden['read_positions'] = positions
|
||||
hidden['read_weights'] = read_weights
|
||||
hidden['read_vectors'] = read_vectors
|
||||
hidden['visible_memory'] = visible_memory
|
||||
|
||||
return hidden['read_vectors'][:, :-1, :].contiguous(), hidden
|
||||
return hidden['read_vectors'], hidden
|
||||
|
||||
def forward(self, ξ, hidden):
|
||||
t = time.time()
|
||||
@ -243,26 +232,27 @@ class SparseMemory(nn.Module):
|
||||
# ξ = ξ.detach()
|
||||
m = self.mem_size
|
||||
w = self.cell_size
|
||||
r = self.K + 1
|
||||
r = self.read_heads
|
||||
c = r * self.K + 1
|
||||
b = ξ.size()[0]
|
||||
|
||||
if self.independent_linears:
|
||||
# r read keys (b * r * w)
|
||||
read_query = self.read_query_transform(ξ).view(b, 1, w)
|
||||
read_query = self.read_query_transform(ξ).view(b, r, w)
|
||||
# write vector (b * 1 * w)
|
||||
write_vector = self.write_vector_transform(ξ).view(b, 1, w)
|
||||
# write vector (b * 1 * w)
|
||||
interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, 1, r)
|
||||
# interpolation gate (b * 1 * c)
|
||||
interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, 1, c)
|
||||
# write gate (b * 1)
|
||||
write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
|
||||
else:
|
||||
ξ = self.interface_weights(ξ)
|
||||
# r read keys (b * w * r)
|
||||
read_query = ξ[:, :w].contiguous().view(b, 1, w)
|
||||
# write key (b * w * 1)
|
||||
write_vector = ξ[:, w: 2 * w].contiguous().view(b, 1, w)
|
||||
# write vector (b * w)
|
||||
interpolation_gate = F.sigmoid(ξ[:, 2 * w: 2 * w + r]).contiguous().view(b, 1, r)
|
||||
# r read keys (b * r * w)
|
||||
read_query = ξ[:, :r*w].contiguous().view(b, r, w)
|
||||
# write vector (b * 1 * w)
|
||||
write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
|
||||
# interpolation gate (b * 1 * c)
|
||||
interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, 1, c)
|
||||
# write gate (b * 1)
|
||||
write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
|
||||
|
||||
|
@ -43,6 +43,7 @@ parser.add_argument('-batch_size', type=int, default=100, metavar='N', help='bat
|
||||
parser.add_argument('-mem_size', type=int, default=20, help='memory dimension')
|
||||
parser.add_argument('-mem_slot', type=int, default=16, help='number of memory slots')
|
||||
parser.add_argument('-read_heads', type=int, default=4, help='number of read heads')
|
||||
parser.add_argument('-sparse_reads', type=int, default=10, help='number of sparse reads per read head')
|
||||
|
||||
parser.add_argument('-sequence_max_length', type=int, default=4, metavar='N', help='sequence_max_length')
|
||||
parser.add_argument('-cuda', type=int, default=-1, help='Cuda GPU ID, -1 for CPU')
|
||||
@ -139,7 +140,8 @@ if __name__ == '__main__':
|
||||
dropout=args.dropout,
|
||||
nr_cells=mem_slot,
|
||||
cell_size=mem_size,
|
||||
sparse_reads=read_heads,
|
||||
sparse_reads=args.sparse_reads,
|
||||
read_heads=args.read_heads,
|
||||
gpu_id=args.cuda,
|
||||
debug=False,
|
||||
batch_first=True,
|
||||
|
Loading…
Reference in New Issue
Block a user