writing part of temporal links

ixaxaar 2017-12-11 03:42:45 +05:30
parent 42451c346d
commit 7f4b582c52


@@ -60,7 +60,7 @@ class SparseMemory(nn.Module):
     self.interface_weights = nn.Linear(self.input_size, self.interface_size)
     T.nn.init.orthogonal(self.interface_weights.weight)

-    self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * n * n)
+    self.I = cuda(1 - T.eye(c).unsqueeze(0), gpu_id=self.gpu_id)  # (1 * c * c)
     self.δ = 0.005  # minimum usage
     self.timestep = 0
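The change above resizes the precomputed (1 − I) mask: it is now built over the c visible cells addressed each step rather than all m memory cells. A minimal standalone sketch of the corrected mask; c = 4 is an arbitrary choice here, and the repo's cuda()/gpu_id wrapper is omitted:

import torch as T

c = 4                          # number of visible (sparsely addressed) cells
I = 1 - T.eye(c).unsqueeze(0)  # (1, c, c): zeros on the diagonal, ones elsewhere
assert I.size() == (1, c, c)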
@@ -100,6 +100,9 @@ class SparseMemory(nn.Module):
         # warning can be a huge chunk of contiguous memory
         'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
         'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
+        'link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
+        'rev_link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
+        'precedence': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
         'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
         'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
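Three new pieces of state ride along in the hidden dict: forward and reverse temporal link matrices over the c visible cells, and a precedence weighting over all m cells. A standalone sketch of their shapes, with b, m, c chosen arbitrarily and the cuda() wrapper omitted:

import torch as T

b, m, c = 2, 16, 4                  # batch size, total memory cells, visible cells
link_matrix = T.zeros(b, c, c)      # link_matrix[:, i, j]: cell i written right after cell j
rev_link_matrix = T.zeros(b, c, c)  # reverse links, for walking the write order backwards
precedence = T.zeros(b, m)          # how strongly each cell was the last one written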
@ -111,6 +114,9 @@ class SparseMemory(nn.Module):
else: else:
hidden['memory'] = hidden['memory'].clone() hidden['memory'] = hidden['memory'].clone()
hidden['visible_memory'] = hidden['visible_memory'].clone() hidden['visible_memory'] = hidden['visible_memory'].clone()
hidden['link_matrix'] = hidden['link_matrix'].clone()
hidden['rev_link_matrix'] = hidden['link_matrix'].clone()
hidden['precedence'] = hidden['precedence'].clone()
hidden['read_weights'] = hidden['read_weights'].clone() hidden['read_weights'] = hidden['read_weights'].clone()
hidden['write_weights'] = hidden['write_weights'].clone() hidden['write_weights'] = hidden['write_weights'].clone()
hidden['read_vectors'] = hidden['read_vectors'].clone() hidden['read_vectors'] = hidden['read_vectors'].clone()
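As with the existing state tensors, the three new entries are cloned on the way in, so in-place updates during the current timestep cannot mutate tensors still referenced by the previous one.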
@ -122,6 +128,9 @@ class SparseMemory(nn.Module):
if erase: if erase:
hidden['memory'].data.fill_(δ) hidden['memory'].data.fill_(δ)
hidden['visible_memory'].data.fill_(δ) hidden['visible_memory'].data.fill_(δ)
hidden['link_matrix'].data.zero_()
hidden['rev_link_matrix'].data.zero_()
hidden['precedence'].data.zero_()
hidden['read_weights'].data.fill_(δ) hidden['read_weights'].data.fill_(δ)
hidden['write_weights'].data.fill_(δ) hidden['write_weights'].data.fill_(δ)
hidden['read_vectors'].data.fill_(δ) hidden['read_vectors'].data.fill_(δ)
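Note the asymmetry in the reset: the link matrices and precedence are zeroed outright, since no write ordering has been observed yet, while the memory tensors and weightings are refilled with the small floor constant δ defined above.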
@@ -150,6 +159,16 @@ class SparseMemory(nn.Module):
     return hidden

+  def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence):
+    link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
+    rev_link_matrix = (1 - write_weights).unsqueeze(1) * rev_link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
+    return link_matrix, rev_link_matrix
+
+  def update_precedence(self, precedence, write_weights):
+    return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
+
   def write(self, interpolation_gate, write_vector, write_gate, hidden):

     read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
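Written elementwise, with w the write weighting over the visible cells and p their precedence, the two helpers above compute

  link[i, j]     ← (1 − w[i]) · link[i, j]     + w[i] · p[j]
  rev_link[i, j] ← (1 − w[j]) · rev_link[i, j] + w[i] · p[j]
  p              ← (1 − Σⱼ w[j]) · p + w

a variant of the temporal link update from the DNC paper (Graves et al., 2016): there a single link matrix decays by the combined factor (1 − w[i] − w[j]), whereas here each of the forward and reverse matrices keeps one of the two factors. A tiny numeric sketch of the broadcasting; b = 1 and c = 3 are arbitrary:

import torch as T

b, c = 1, 3
w = T.tensor([[0., 1., 0.]])  # this step writes to cell 1
p = T.tensor([[1., 0., 0.]])  # cell 0 received the previous write
link = T.zeros(b, c, c)

# forward update, exactly as in update_link_matrices
link = (1 - w).unsqueeze(2) * link + w.unsqueeze(2) * p.unsqueeze(1)
print(link[0, 1, 0])  # tensor(1.) -- cell 1 was written right after cell 0

# precedence update, as in update_precedence
p = (1 - T.sum(w, dim=-1, keepdim=True)) * p + w
print(p)              # tensor([[0., 1., 0.]]) -- precedence moves to cell 1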
@@ -171,11 +190,23 @@ class SparseMemory(nn.Module):
     # store the write weights
     hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)

+    # erase matrix
     erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())

+    # write into memory
     hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
     hidden = self.write_into_sparse_memory(hidden)

+    # update link_matrix and precedence
+    (b, c) = write_weights.size()
+    precedence = hidden['precedence'].gather(1, hidden['read_positions'])
+    hidden['link_matrix'], hidden['rev_link_matrix'] = \
+        self.update_link_matrices(hidden['link_matrix'], hidden['rev_link_matrix'], write_weights, precedence)
+    precedence = self.update_precedence(hidden['precedence'], hidden['write_weights'])
+    hidden['precedence'].scatter_(1, hidden['read_positions'], precedence)
+
     return hidden

   def update_usage(self, read_positions, read_weights, write_weights, usage):
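The write path follows a gather → update → scatter pattern: the precedence of the c visible cells is gathered out of the dense (b, m) tensor, the link matrices are updated against it, and the refreshed precedence is scattered back through read_positions. A standalone sketch of that round trip, with hypothetical sizes and index values:

import torch as T

b, m, c = 1, 8, 3
precedence = T.zeros(b, m)
read_positions = T.tensor([[2, 5, 7]])            # indices of the visible cells

p_visible = precedence.gather(1, read_positions)  # (b, c) slice used for the link update
new_p = T.tensor([[0., 1., 0.]])                  # say cell 5 just received the write
precedence.scatter_(1, read_positions, new_p)     # write the slice back in place
print(precedence)  # tensor([[0., 0., 0., 0., 0., 1., 0., 0.]])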
@@ -225,6 +256,8 @@ class SparseMemory(nn.Module):
     return read_vectors, read_positions, read_weights, visible_memory

+  # def
+
   def read(self, read_query, hidden):
     # sparse read
     read_vectors, positions, read_weights, visible_memory = \