writing part of temporal links
This commit is contained in:
parent
42451c346d
commit
7f4b582c52
@ -60,7 +60,7 @@ class SparseMemory(nn.Module):
|
||||
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
|
||||
T.nn.init.orthogonal(self.interface_weights.weight)
|
||||
|
||||
self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
|
||||
self.I = cuda(1 - T.eye(c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
|
||||
self.δ = 0.005 # minimum usage
|
||||
self.timestep = 0
|
||||
|
||||
@ -100,6 +100,9 @@ class SparseMemory(nn.Module):
|
||||
# warning can be a huge chunk of contiguous memory
|
||||
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
||||
'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
||||
'link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
|
||||
'rev_link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
|
||||
'precedence': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
|
||||
'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
||||
'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
||||
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
|
||||
@ -111,6 +114,9 @@ class SparseMemory(nn.Module):
|
||||
else:
|
||||
hidden['memory'] = hidden['memory'].clone()
|
||||
hidden['visible_memory'] = hidden['visible_memory'].clone()
|
||||
hidden['link_matrix'] = hidden['link_matrix'].clone()
|
||||
hidden['rev_link_matrix'] = hidden['link_matrix'].clone()
|
||||
hidden['precedence'] = hidden['precedence'].clone()
|
||||
hidden['read_weights'] = hidden['read_weights'].clone()
|
||||
hidden['write_weights'] = hidden['write_weights'].clone()
|
||||
hidden['read_vectors'] = hidden['read_vectors'].clone()
|
||||
@ -122,6 +128,9 @@ class SparseMemory(nn.Module):
|
||||
if erase:
|
||||
hidden['memory'].data.fill_(δ)
|
||||
hidden['visible_memory'].data.fill_(δ)
|
||||
hidden['link_matrix'].data.zero_()
|
||||
hidden['rev_link_matrix'].data.zero_()
|
||||
hidden['precedence'].data.zero_()
|
||||
hidden['read_weights'].data.fill_(δ)
|
||||
hidden['write_weights'].data.fill_(δ)
|
||||
hidden['read_vectors'].data.fill_(δ)
|
||||
@ -150,6 +159,16 @@ class SparseMemory(nn.Module):
|
||||
|
||||
return hidden
|
||||
|
||||
def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence):
|
||||
link_matrix = (1 - write_weights).unsqueeze(2) * link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
|
||||
|
||||
rev_link_matrix = (1 - write_weights).unsqueeze(1) * rev_link_matrix + write_weights.unsqueeze(2) * precedence.unsqueeze(1)
|
||||
|
||||
return link_matrix, rev_link_matrix
|
||||
|
||||
def update_precedence(self, precedence, write_weights):
|
||||
return (1 - T.sum(write_weights, dim=-1, keepdim=True)) * precedence + write_weights
|
||||
|
||||
def write(self, interpolation_gate, write_vector, write_gate, hidden):
|
||||
|
||||
read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
|
||||
@ -171,11 +190,23 @@ class SparseMemory(nn.Module):
|
||||
# store the write weights
|
||||
hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)
|
||||
|
||||
# erase matrix
|
||||
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
|
||||
|
||||
# write into memory
|
||||
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
|
||||
hidden = self.write_into_sparse_memory(hidden)
|
||||
|
||||
# update link_matrix and precedence
|
||||
(b, c) = write_weights.size()
|
||||
precedence = hidden['precedence'].gather(1, hidden['read_positions'])
|
||||
|
||||
hidden['link_matrix'], hidden['rev_link_matrix'] = \
|
||||
self.update_link_matrices(hidden['link_matrix'], hidden['rev_link_matrix'], write_weights, precedence)
|
||||
precedence = self.update_precedence(hidden['precedence'], hidden['write_weights'])
|
||||
|
||||
hidden['precedence'].scatter_(1, hidden['read_positions'], precedence)
|
||||
|
||||
return hidden
|
||||
|
||||
def update_usage(self, read_positions, read_weights, write_weights, usage):
|
||||
@ -225,6 +256,8 @@ class SparseMemory(nn.Module):
|
||||
|
||||
return read_vectors, read_positions, read_weights, visible_memory
|
||||
|
||||
# def
|
||||
|
||||
def read(self, read_query, hidden):
|
||||
# sparse read
|
||||
read_vectors, positions, read_weights, visible_memory = \
|
||||
|
Loading…
Reference in New Issue
Block a user