More differentiability: the network can read from a view of a larger memory
This commit is contained in:
parent
a6667bf98c
commit
c9eb3a5ca7
@@ -11,7 +11,7 @@ from .util import *
 
 class Index(object):
 
-  def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=30, probes=32, res=None, train=None, gpu_id=-1):
+  def __init__(self, cell_size=20, nr_cells=1024, K=4, num_lists=32, probes=32, res=None, train=None, gpu_id=-1):
     super(Index, self).__init__()
     self.cell_size = cell_size
     self.nr_cells = nr_cells
@@ -29,7 +29,7 @@ class Index(object):
       # train = T.randn(self.nr_cells * 100, self.cell_size)
 
     self.index = faiss.GpuIndexIVFFlat(self.res, self.cell_size, self.num_lists, faiss.METRIC_INNER_PRODUCT)
-    self.index.setNumProbes(self.num_lists)
+    self.index.setNumProbes(self.probes)
     self.train(train)
 
   def cuda(self, gpu_id):
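For context, a minimal sketch of the FAISS pattern this hunk fixes: an IVF-Flat index with an inner-product metric, where the number of probed lists comes from `probes` rather than `num_lists`. The sketch uses the CPU `IndexIVFFlat` with toy random data and the default sizes above; it is illustrative only, not the repository's `Index` class (the diff itself uses the GPU variant and `setNumProbes`).

# Hedged sketch: standalone FAISS IVF-Flat usage with an inner-product metric.
import numpy as np
import faiss

cell_size, nr_cells, num_lists, probes, K = 20, 1024, 32, 32, 4

data = np.random.randn(nr_cells, cell_size).astype('float32')

quantizer = faiss.IndexFlatIP(cell_size)                 # coarse quantizer
index = faiss.IndexIVFFlat(quantizer, cell_size, num_lists,
                           faiss.METRIC_INNER_PRODUCT)
index.train(data)                                        # learn the IVF centroids
index.add(data)                                          # fill the inverted lists
index.nprobe = probes                                    # search `probes` lists, not all `num_lists`

queries = np.random.randn(4, cell_size).astype('float32')
distances, positions = index.search(queries, K)          # top-K nearest cells per query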
@@ -9,6 +9,7 @@ import numpy as np
 from torch.nn.utils.rnn import pad_packed_sequence as pad
 from torch.nn.utils.rnn import pack_padded_sequence as pack
 from torch.nn.utils.rnn import PackedSequence
+from torch.nn.init import orthogonal
 
 from .util import *
 from .sparse_memory import SparseMemory
@@ -119,6 +120,7 @@ class SDNC(nn.Module):
 
     # final output layer
     self.output = nn.Linear(self.nn_output_size, self.input_size)
+    orthogonal(self.output.weight)
 
     if self.gpu_id != -1:
       [x.cuda(self.gpu_id) for x in self.rnns]
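The new call applies orthogonal initialization to the output projection's weight matrix. A minimal, hedged illustration of the same idea on a standalone layer (the sizes are made up for the example):

# Hedged sketch: orthogonal initialization of a linear layer's weight.
# `orthogonal` (no trailing underscore) matches the pre-0.4 PyTorch API
# imported in the diff above; modern PyTorch prefers `orthogonal_`.
import torch.nn as nn
from torch.nn.init import orthogonal

output = nn.Linear(256, 64)   # hypothetical stand-ins for nn_output_size / input_size
orthogonal(output.weight)     # weight rows become orthonormal (up to scaling)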
@@ -265,6 +267,8 @@ class SDNC(nn.Module):
     inputs = [self.output(i) for i in inputs]
     outputs = T.stack(inputs, 1 if self.batch_first else 0)
 
+    # outputs.register_hook(lambda x: print("========================================", x.squeeze()))
+
     if is_packed:
       outputs = pack(output, lengths)
 
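The commented-out `register_hook` line is a gradient-debugging aid: the hook fires during the backward pass with the gradient that reaches that tensor. A hedged standalone illustration (modern tensor API here; the diff's code uses autograd Variables instead):

# Hedged sketch: inspecting gradients with a backward hook.
import torch as T

x = T.randn(3, requires_grad=True)
y = (x * 2).sum()
x.register_hook(lambda grad: print("grad through x:", grad))
y.backward()   # prints the gradient when it reaches x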
@@ -51,9 +51,14 @@ class SparseMemory(nn.Module):
       self.write_vector_transform = nn.Linear(self.input_size, w)
       self.interpolation_gate_transform = nn.Linear(self.input_size, c)
       self.write_gate_transform = nn.Linear(self.input_size, 1)
+      T.nn.init.orthogonal(self.read_query_transform.weight)
+      T.nn.init.orthogonal(self.write_vector_transform.weight)
+      T.nn.init.orthogonal(self.interpolation_gate_transform.weight)
+      T.nn.init.orthogonal(self.write_gate_transform.weight)
     else:
       self.interface_size = (r * w) + w + c + 1
       self.interface_weights = nn.Linear(self.input_size, self.interface_size)
+      T.nn.init.orthogonal(self.interface_weights.weight)
 
     self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
     self.δ = 0.005  # minimum usage
@@ -95,8 +100,8 @@ class SparseMemory(nn.Module):
         # warning can be a huge chunk of contiguous memory
         'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
         'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
-        'read_weights': cuda(T.zeros(b, r, c).fill_(δ), gpu_id=self.gpu_id),
-        'write_weights': cuda(T.zeros(b, 1, c).fill_(δ), gpu_id=self.gpu_id),
+        'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
+        'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
         'last_used_mem': cuda(T.zeros(b, 1).fill_(δ), gpu_id=self.gpu_id).long(),
         'usage': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
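This hunk switches `read_weights` and `write_weights` from per-view shapes `(b, r, c)` / `(b, 1, c)` to full-memory vectors of shape `(b, m)`, so sparse reads and writes can be expressed as `gather`/`scatter_` on a single differentiable tensor. A hedged shape sketch of that layout (toy sizes, not the module's real constructor):

# Hedged sketch: keep weights over the whole memory (b, m) and gather the
# visible slice (b, c) at the currently read positions.
import torch as T

b, m, w, r, K = 2, 16, 5, 2, 4
c = r * K + 1                                  # visible cells: K neighbours per read head + last used cell

read_weights = T.zeros(b, m)                   # full-memory read weights
write_weights = T.zeros(b, m)                  # full-memory write weights
read_positions = T.randint(0, m, (b, c))

visible_read = read_weights.gather(1, read_positions)     # (b, c) view used inside write()
visible_write = write_weights.gather(1, read_positions)   # (b, c)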
@@ -131,7 +136,8 @@ class SparseMemory(nn.Module):
 
     (b, m, w) = hidden['memory'].size()
     # update memory
-    hidden['memory'].scatter_(1, positions, visible_memory)
+    hidden['memory'].scatter_(1, positions.unsqueeze(2).expand(b, self.read_heads*self.K+1, w), visible_memory)
+    print(positions)
 
     # non-differentiable operations
     pos = positions.data.cpu().numpy()
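`scatter_` wants an index with the same number of dimensions as the source, which is why the new call expands `positions` from `(b, c)` to `(b, c, w)` before writing the visible rows back into the full memory (the `print(positions)` looks like leftover debugging). A hedged toy version of that write-back:

# Hedged sketch: writing (b, c, w) visible rows back into a (b, m, w) memory.
import torch as T

b, m, w, c = 2, 16, 5, 9
memory = T.zeros(b, m, w)
visible_memory = T.randn(b, c, w)
positions = T.randint(0, m, (b, c))

# Expand the cell indices across the word dimension so every element of a
# visible row lands in the matching row of the full memory.
idx = positions.unsqueeze(2).expand(b, c, w)
memory.scatter_(1, idx, visible_memory)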
@@ -145,22 +151,27 @@ class SparseMemory(nn.Module):
 
   def write(self, interpolation_gate, write_vector, write_gate, hidden):
 
+    read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
+    write_weights = hidden['write_weights'].gather(1, hidden['read_positions'])
+
     hidden['usage'], I = self.update_usage(
         hidden['read_positions'],
-        hidden['read_weights'],
-        hidden['write_weights'],
+        read_weights,
+        write_weights,
         hidden['usage']
     )
 
     # either we write to previous read locations
-    x = interpolation_gate * hidden['read_weights']
+    x = interpolation_gate * read_weights
     # or to a new location
     y = (1 - interpolation_gate) * I
-    hidden['write_weights'] = T.prod(write_gate.unsqueeze(1) * (x + y), 1)
+    write_weights = write_gate * (x + y)
 
     # no erasing and hence no erase matrix R_{t}
-    hidden['visible_memory'] = hidden['visible_memory'] + T.bmm(hidden['write_weights'].unsqueeze(2), write_vector)
-    # hidden = self.write_into_sparse_memory(hidden)
+    # store the write weights
+    hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)
+    hidden['visible_memory'] = hidden['visible_memory'] + T.bmm(write_weights.unsqueeze(2), write_vector)
+    hidden = self.write_into_sparse_memory(hidden)
 
     return hidden
 
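The rewritten `write()` interpolates, per visible cell, between the previously read locations and the least-used location, gates the result, and only then scatters it into the full `(b, m)` write-weight vector. A hedged sketch of just that weighting step (all names are local to the example; the least-used indicator is faked):

# Hedged sketch: sparse write weighting over the visible cells only.
import torch as T

b, m, c = 2, 16, 9
read_positions = T.randint(0, m, (b, c))
prev_read_weights = T.rand(b, m).gather(1, read_positions)   # (b, c)
least_used = T.zeros(b, c)
least_used[:, 0] = 1.0                                        # toy stand-in for I from update_usage

interpolation_gate = T.rand(b, c)                             # per-cell gate
write_gate = T.rand(b, 1)                                     # one gate per batch item

x = interpolation_gate * prev_read_weights                    # write to previously read cells
y = (1 - interpolation_gate) * least_used                     # ...or to the least-used cell
write_weights_visible = write_gate * (x + y)                  # (b, c)

full_write_weights = T.zeros(b, m)
full_write_weights.scatter_(1, read_positions, write_weights_visible)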
@@ -168,7 +179,7 @@ class SparseMemory(nn.Module):
     (b, _) = read_positions.size()
     # usage is timesteps since a non-negligible memory access
     # todo store write weights of all mem and gather from that
-    u = (read_weights.sum(1) + write_weights.squeeze() > self.δ).float().view(b, -1)
+    u = (read_weights + write_weights > self.δ).float()
 
     # usage before write
     relevant_usages = usage.gather(1, read_positions)
@@ -176,7 +187,7 @@ class SparseMemory(nn.Module):
     # indicator of words with minimal memory usage
     minusage = T.min(relevant_usages, -1)[0].unsqueeze(1)
     minusage = minusage.expand(relevant_usages.size())
-    I = (relevant_usages == minusage).float().unsqueeze(1)
+    I = (relevant_usages == minusage).float()
 
     # usage after write
     relevant_usages = (self.timestep - relevant_usages) * u + relevant_usages * (1 - u)
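`update_usage` now works directly on the gathered `(b, c)` weights: a cell counts as accessed when its read plus write weight exceeds the threshold δ, and the least-used indicator is an equality test against the minimum usage among the visible cells. A hedged numeric sketch of those two steps (toy values throughout):

# Hedged sketch: usage indicator and least-used mask over the visible cells.
import torch as T

b, c = 2, 5
delta = 0.005
timestep = 10

read_weights = T.rand(b, c)
write_weights = T.rand(b, c)
relevant_usages = T.randint(0, timestep, (b, c)).float()

u = (read_weights + write_weights > delta).float()            # accessed this step?

minusage = T.min(relevant_usages, -1)[0].unsqueeze(1).expand(b, c)
I = (relevant_usages == minusage).float()                      # 1 at the least-used cell(s)

# accessed cells get a fresh usage stamp, untouched cells keep their old one
relevant_usages = (timestep - relevant_usages) * u + relevant_usages * (1 - u)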
@@ -201,11 +212,14 @@ class SparseMemory(nn.Module):
     read_positions = var(read_positions)
     read_positions = T.cat([read_positions.view(b, -1), last_used_mem], 1)
 
     # differentiable ops
     (b, m, w) = memory.size()
     visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, r*k+1, w))
 
     read_weights = F.softmax(θ(visible_memory, keys), dim=2)
+    # visible_memory.register_hook(lambda x: print("========================================", x.squeeze()))
     read_vectors = T.bmm(read_weights, visible_memory)
+    read_weights = T.prod(read_weights, 1)
 
     return read_vectors, read_positions, read_weights, visible_memory
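Reading is where the "view of a larger memory" comes in: only the `r*K + 1` candidate rows are gathered, similarities against the read keys are softmaxed over those rows, and the content read is a batched matmul; the added `T.prod(read_weights, 1)` collapses the read-head dimension so the result can be scattered back into the `(b, m)` read-weight vector. A hedged sketch, with a plain inner product standing in for the repository's similarity function θ:

# Hedged sketch: differentiable read over a gathered view of the memory.
import torch as T
import torch.nn.functional as F

b, m, w, r, K = 2, 16, 5, 2, 4
c = r * K + 1

memory = T.randn(b, m, w, requires_grad=True)
keys = T.randn(b, r, w)                                        # read queries
read_positions = T.randint(0, m, (b, c))

visible_memory = memory.gather(1, read_positions.unsqueeze(2).expand(b, c, w))   # (b, c, w)

sims = T.bmm(keys, visible_memory.transpose(1, 2))             # inner-product similarity, (b, r, c)
read_weights = F.softmax(sims, dim=2)

read_vectors = T.bmm(read_weights, visible_memory)             # (b, r, w)
read_weights = T.prod(read_weights, 1)                         # (b, c), ready to scatter into (b, m)

read_vectors.sum().backward()                                  # gradients reach only the gathered rows of `memory`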
@@ -219,8 +233,9 @@ class SparseMemory(nn.Module):
         hidden['last_used_mem'],
         hidden['usage']
     )
+
     hidden['read_positions'] = positions
-    hidden['read_weights'] = read_weights
+    hidden['read_weights'] = hidden['read_weights'].scatter_(1, positions, read_weights)
     hidden['read_vectors'] = read_vectors
     hidden['visible_memory'] = visible_memory
 
@@ -242,17 +257,20 @@ class SparseMemory(nn.Module):
       # write key (b * 1 * w)
       write_vector = self.write_vector_transform(ξ).view(b, 1, w)
       # write vector (b * 1 * r)
-      interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, 1, c)
+      interpolation_gate = F.sigmoid(self.interpolation_gate_transform(ξ)).view(b, c)
       # write gate (b * 1)
       write_gate = F.sigmoid(self.write_gate_transform(ξ).view(b, 1))
     else:
       ξ = self.interface_weights(ξ)
       # r read keys (b * r * w)
       read_query = ξ[:, :r*w].contiguous().view(b, r, w)
+      print("999999999999999999999999")
+      print(read_query.squeeze())
+      # read_query.register_hook(lambda x: print("========================================", x.squeeze()))
       # write key (b * 1 * w)
       write_vector = ξ[:, r*w: r*w + w].contiguous().view(b, 1, w)
       # write vector (b * 1 * r)
-      interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, 1, c)
+      interpolation_gate = F.sigmoid(ξ[:, r*w + w: r*w + w + c]).contiguous().view(b, c)
       # write gate (b * 1)
       write_gate = F.sigmoid(ξ[:, -1].contiguous()).unsqueeze(1).view(b, 1)
 
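In the non-independent-linears branch, the single interface vector ξ is sliced into the read queries, the write vector, the interpolation gate (now flattened to `(b, c)` to match the new weight layout), and the write gate; the two `print` lines in the diff look like leftover debugging. A hedged slicing sketch with toy sizes:

# Hedged sketch: splitting one interface vector into the memory's control signals.
import torch as T

b, r, w, K = 2, 4, 10, 4
c = r * K + 1
interface_size = (r * w) + w + c + 1

xi = T.randn(b, interface_size)

read_query = xi[:, :r * w].contiguous().view(b, r, w)                       # r read keys
write_vector = xi[:, r * w: r * w + w].contiguous().view(b, 1, w)           # write content
interpolation_gate = T.sigmoid(xi[:, r * w + w: r * w + w + c]).view(b, c)  # one gate per visible cell
write_gate = T.sigmoid(xi[:, -1]).view(b, 1)                                # global write gate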