writing part of temporal links
This commit is contained in:
parent
42451c346d
commit
7f4b582c52
@ -60,7 +60,7 @@ class SparseMemory(nn.Module):
|
|||||||
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
|
self.interface_weights = nn.Linear(self.input_size, self.interface_size)
|
||||||
T.nn.init.orthogonal(self.interface_weights.weight)
|
T.nn.init.orthogonal(self.interface_weights.weight)
|
||||||
|
|
||||||
self.I = cuda(1 - T.eye(m).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
|
self.I = cuda(1 - T.eye(c).unsqueeze(0), gpu_id=self.gpu_id) # (1 * n * n)
|
||||||
self.δ = 0.005 # minimum usage
|
self.δ = 0.005 # minimum usage
|
||||||
self.timestep = 0
|
self.timestep = 0
|
||||||
|
|
||||||
@ -100,6 +100,9 @@ class SparseMemory(nn.Module):
|
|||||||
# warning can be a huge chunk of contiguous memory
|
# warning can be a huge chunk of contiguous memory
|
||||||
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
'memory': cuda(T.zeros(b, m, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
||||||
'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
'visible_memory': cuda(T.zeros(b, c, w).fill_(δ), gpu_id=self.mem_gpu_id),
|
||||||
|
'link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
|
||||||
|
'rev_link_matrix': cuda(T.zeros(b, c, c), gpu_id=self.gpu_id),
|
||||||
|
'precedence': cuda(T.zeros(b, m), gpu_id=self.gpu_id),
|
||||||
'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
'read_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
||||||
'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
'write_weights': cuda(T.zeros(b, m).fill_(δ), gpu_id=self.gpu_id),
|
||||||
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
|
'read_vectors': cuda(T.zeros(b, r, w).fill_(δ), gpu_id=self.gpu_id),
|
||||||
@ -111,6 +114,9 @@ class SparseMemory(nn.Module):
|
|||||||
else:
|
else:
|
||||||
hidden['memory'] = hidden['memory'].clone()
|
hidden['memory'] = hidden['memory'].clone()
|
||||||
hidden['visible_memory'] = hidden['visible_memory'].clone()
|
hidden['visible_memory'] = hidden['visible_memory'].clone()
|
||||||
|
hidden['link_matrix'] = hidden['link_matrix'].clone()
|
||||||
|
hidden['rev_link_matrix'] = hidden['link_matrix'].clone()
|
||||||
|
hidden['precedence'] = hidden['precedence'].clone()
|
||||||
hidden['read_weights'] = hidden['read_weights'].clone()
|
hidden['read_weights'] = hidden['read_weights'].clone()
|
||||||
hidden['write_weights'] = hidden['write_weights'].clone()
|
hidden['write_weights'] = hidden['write_weights'].clone()
|
||||||
hidden['read_vectors'] = hidden['read_vectors'].clone()
|
hidden['read_vectors'] = hidden['read_vectors'].clone()
|
||||||
@ -122,6 +128,9 @@ class SparseMemory(nn.Module):
|
|||||||
if erase:
|
if erase:
|
||||||
hidden['memory'].data.fill_(δ)
|
hidden['memory'].data.fill_(δ)
|
||||||
hidden['visible_memory'].data.fill_(δ)
|
hidden['visible_memory'].data.fill_(δ)
|
||||||
|
hidden['link_matrix'].data.zero_()
|
||||||
|
hidden['rev_link_matrix'].data.zero_()
|
||||||
|
hidden['precedence'].data.zero_()
|
||||||
hidden['read_weights'].data.fill_(δ)
|
hidden['read_weights'].data.fill_(δ)
|
||||||
hidden['write_weights'].data.fill_(δ)
|
hidden['write_weights'].data.fill_(δ)
|
||||||
hidden['read_vectors'].data.fill_(δ)
|
hidden['read_vectors'].data.fill_(δ)
|
||||||
@ -150,6 +159,16 @@ class SparseMemory(nn.Module):
|
|||||||
|
|
||||||
return hidden
|
return hidden
|
||||||
|
|
||||||
|
def update_link_matrices(self, link_matrix, rev_link_matrix, write_weights, precedence):
    """Update the forward and reverse temporal link matrices after a write.

    Both matrices are decayed where a write happened and then receive a
    fresh link built from the current write weights and the previous
    precedence. The reverse matrix follows the transposed recurrence of
    the forward one, so if the two inputs are transposes of each other,
    the two outputs are as well.

    Args:
        link_matrix: forward links, broadcast-compatible with (b, c, c).
        rev_link_matrix: reverse links, same shape as ``link_matrix``.
        write_weights: write weighting over the c slots, shape (b, c).
        precedence: precedence weighting over the same slots, shape (b, c).

    Returns:
        Tuple ``(link_matrix, rev_link_matrix)`` with the updated values.
    """
    # Index convention for an entry [batch, i, j]:
    #   *_i broadcasts along rows (depends on i), *_j along columns (on j).
    write_i = write_weights.unsqueeze(2)
    write_j = write_weights.unsqueeze(1)
    precedence_i = precedence.unsqueeze(2)
    precedence_j = precedence.unsqueeze(1)

    # forward: L[i, j] = (1 - w[i]) * L[i, j] + w[i] * p[j]
    link_matrix = (1 - write_i) * link_matrix + write_i * precedence_j

    # reverse: R[i, j] = (1 - w[j]) * R[i, j] + w[j] * p[i]
    # The write term must be indexed like the decay term (by j); using the
    # forward outer product here would break the transpose relationship.
    rev_link_matrix = (1 - write_j) * rev_link_matrix + write_j * precedence_i

    return link_matrix, rev_link_matrix
|
||||||
|
|
||||||
|
def update_precedence(self, precedence, write_weights):
    """Return the new precedence weighting after a write.

    The previous precedence is scaled down by the total amount written
    (sum of ``write_weights`` over the last dimension) and the current
    write weights are added on top.

    Args:
        precedence: previous precedence weighting.
        write_weights: current write weighting; must broadcast against
            ``precedence``.

    Returns:
        The updated precedence tensor.
    """
    # total write mass per batch row, kept as a trailing singleton dim so
    # it broadcasts across every slot
    written_total = T.sum(write_weights, dim=-1, keepdim=True)
    retained = (1 - written_total) * precedence
    return retained + write_weights
|
||||||
|
|
||||||
def write(self, interpolation_gate, write_vector, write_gate, hidden):
|
def write(self, interpolation_gate, write_vector, write_gate, hidden):
|
||||||
|
|
||||||
read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
|
read_weights = hidden['read_weights'].gather(1, hidden['read_positions'])
|
||||||
@ -171,11 +190,23 @@ class SparseMemory(nn.Module):
|
|||||||
# store the write weights
|
# store the write weights
|
||||||
hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)
|
hidden['write_weights'].scatter_(1, hidden['read_positions'], write_weights)
|
||||||
|
|
||||||
|
# erase matrix
|
||||||
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
|
erase_matrix = I.unsqueeze(2).expand(hidden['visible_memory'].size())
|
||||||
|
|
||||||
|
# write into memory
|
||||||
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
|
hidden['visible_memory'] = hidden['visible_memory'] * (1 - erase_matrix) + T.bmm(write_weights.unsqueeze(2), write_vector)
|
||||||
hidden = self.write_into_sparse_memory(hidden)
|
hidden = self.write_into_sparse_memory(hidden)
|
||||||
|
|
||||||
|
# update link_matrix and precedence
|
||||||
|
(b, c) = write_weights.size()
|
||||||
|
precedence = hidden['precedence'].gather(1, hidden['read_positions'])
|
||||||
|
|
||||||
|
hidden['link_matrix'], hidden['rev_link_matrix'] = \
|
||||||
|
self.update_link_matrices(hidden['link_matrix'], hidden['rev_link_matrix'], write_weights, precedence)
|
||||||
|
precedence = self.update_precedence(hidden['precedence'], hidden['write_weights'])
|
||||||
|
|
||||||
|
hidden['precedence'].scatter_(1, hidden['read_positions'], precedence)
|
||||||
|
|
||||||
return hidden
|
return hidden
|
||||||
|
|
||||||
def update_usage(self, read_positions, read_weights, write_weights, usage):
|
def update_usage(self, read_positions, read_weights, write_weights, usage):
|
||||||
@ -225,6 +256,8 @@ class SparseMemory(nn.Module):
|
|||||||
|
|
||||||
return read_vectors, read_positions, read_weights, visible_memory
|
return read_vectors, read_positions, read_weights, visible_memory
|
||||||
|
|
||||||
|
# def
|
||||||
|
|
||||||
def read(self, read_query, hidden):
|
def read(self, read_query, hidden):
|
||||||
# sparse read
|
# sparse read
|
||||||
read_vectors, positions, read_weights, visible_memory = \
|
read_vectors, positions, read_weights, visible_memory = \
|
||||||
|
Loading…
Reference in New Issue
Block a user