# Differentiable Neural Computer
# inspired by (http://www.nature.com/nature/journal/v538/n7626/full/nature20101.html)
# some ideas taken from https://github.com/Mostafa-Samir/DNC-tensorflow
# Sam Greydanus. February 2017. MIT License.

import tensorflow as tf
import numpy as np

class Memory():
    def __init__(self, R, W, N, batch_size):
        '''
        Defines how the interface vector zeta interacts with the memory state of the DNC

        Parameters:
        ----------
        R: the number of DNC read heads
        W: the DNC "word length" (length of each DNC memory vector)
        N: the number of DNC word vectors (corresponds to memory size)
        batch_size: the number of batches
        '''
        self.R = R
        self.W = W
        self.N = N
        self.batch_size = batch_size

        # when we go from 2D indexes to a flat 1D vector, we need to reindex using these shifts
        ix_flat_shifts = tf.constant(np.cumsum([0] + [N] * (batch_size - 1)), dtype=tf.int32)
        self.ix_flat_shifts = tf.expand_dims(ix_flat_shifts, [1])

        # N x N identity matrix
        self.I = tf.eye(N)
        self.eps = 1e-6
        self.state_keys = ['M', 'u', 'p', 'L', 'w_w', 'w_r', 'r']

    def zero_state(self):
        '''
        Supplies the initial state of the DNC's memory

        Returns: list(7)
            dnc_state: contains initial values for (M, u, p, L, w_w, w_r, r) respectively. According to the DNC paper:
                M: (batch_size, N, W) the memory matrix
                u: (batch_size, N) the usage vector
                p: (batch_size, N) the precedence weighting (helps update L)
                L: (batch_size, N, N) the temporal linkage matrix (helps DNC remember what order things were written)
                w_w: (batch_size, N) the write weighting - says where the DNC wrote a word last time step
                w_r: (batch_size, N, R) the read weightings - say which word vectors the DNC accessed last time step
                r: (batch_size, W, R) the read vectors - the words returned by the read heads last time step
        '''
        return [
            tf.fill([self.batch_size, self.N, self.W], self.eps),  # M
            tf.zeros([self.batch_size, self.N]),                   # u
            tf.zeros([self.batch_size, self.N]),                   # p
            tf.zeros([self.batch_size, self.N, self.N]),           # L
            tf.fill([self.batch_size, self.N], self.eps),          # w_w
            tf.fill([self.batch_size, self.N, self.R], self.eps),  # w_r
            tf.fill([self.batch_size, self.W, self.R], self.eps),  # r
        ]

    def content_addressing(self, M, kappa, beta):
        '''
        Computes the probabilities that each word vector in memory was the target of a given key (see paper)
        '''
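        # Per the DNC paper, this is content-based addressing: a softmax over the
        # cosine similarities between the key kappa and each row of memory,
        # sharpened by the key strength beta:
        #   C(M, kappa, beta)[i] = softmax_i( beta * cos(M[i, :], kappa) )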
        norm_M = tf.nn.l2_normalize(M, 2)
        norm_kappa = tf.nn.l2_normalize(kappa, 1)
        similarity = tf.matmul(norm_M, norm_kappa)

        return tf.nn.softmax(similarity * tf.expand_dims(beta, 1), 1)

    def update_u(self, u, w_r, w_w, f):
        '''
        Computes the new usage vector. This tells the DNC which memory slots are being used and which are free (see paper)
        '''
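        # Per the DNC paper, the free gates f first build a retention vector psi,
        # then usage rises where we just wrote and falls where reads were freed:
        #   psi = prod_i (1 - f^i * w_r^i)
        #   u_t = (u_{t-1} + w_w - u_{t-1} * w_w) * psi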
        f = tf.expand_dims(f, 1)              # need to match w_r dimensions
        psi = tf.reduce_prod(1 - w_r * f, 2)  # psi tells us how much usage to retain
        next_u = (u + w_w - u * w_w) * psi    # update u based on what we wrote last time
        return next_u

    def get_allocation(self, next_u):
        '''
        Computes the allocation vector. This tells the DNC where it COULD write its next memory (see paper)
        '''
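        # Per the DNC paper, slots are ranked by ascending usage and allocation
        # favors the least-used slot:
        #   a[phi[j]] = (1 - u[phi[j]]) * prod_{i<j} u[phi[i]]
        # where phi sorts the slot indices from least used to most used.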
        u_sorted, u_ix = tf.nn.top_k(-1 * next_u, self.N)  # sort by ascending usage (least-used slots first)
        u_sorted = -1 * u_sorted
        a_sorted = (1 - u_sorted) * tf.cumprod(u_sorted, axis=1, exclusive=True)  # classic DNC cumprod

        # indexing wizardry to account for multiple batches
        ix_flat = u_ix + self.ix_flat_shifts
        ix_flat = tf.reshape(ix_flat, (-1,))
        flat_array = tf.TensorArray(tf.float32, self.batch_size * self.N)

        a_scattered = flat_array.scatter(ix_flat, tf.reshape(a_sorted, (-1,)))  # undo the sort
        a = a_scattered.stack()  # put back into a Tensor
        return tf.reshape(a, (self.batch_size, self.N))

    def update_w_w(self, c_w, a, g_w, g_a):
        '''
        Computes the new write weighting. This tells the DNC where (and if) it will write its next memory (see paper)
        '''
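        # Per the DNC paper, the allocation gate g_a mixes "write to a free slot"
        # with "write by content lookup", and the write gate g_w can veto writing:
        #   w_w = g_w * (g_a * a + (1 - g_a) * c_w)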
        c_w = tf.squeeze(c_w)  # want c_w as a (batched) vector
        next_w_w = g_w * (g_a * a + (1 - g_a) * c_w)  # apply the allocation and write gates
        return next_w_w

    def update_M(self, M, w_w, v, e):
        '''
        Computes the new memory matrix. This is where the DNC actually stores memories (see paper)
        '''
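        # Per the DNC paper, each written slot is partially erased and then has the
        # write vector added, both scaled by the write weighting:
        #   M_t = M_{t-1} * (1 - outer(w_w, e)) + outer(w_w, v)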
        # expand data to force matmul to behave as an outer product
        w_w = tf.expand_dims(w_w, 2)
        v = tf.expand_dims(v, 1)
        e = tf.expand_dims(e, 1)

        # think of the memory update as a bunch of elementwise interpolations
        M_erase = M * (1 - tf.matmul(w_w, e))
        M_write = tf.matmul(w_w, v)
        next_M = M_erase + M_write
        return next_M

    def update_p(self, p, w_w):
        '''
        Updates the precedence vector. This tells the DNC how much each location was the last one written to (see paper)
        '''
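        # Per the DNC paper, precedence decays in proportion to how much was just
        # written, then the new write weighting is added on top:
        #   p_t = (1 - sum_i w_w[i]) * p_{t-1} + w_w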
        interpolate = 1 - tf.reduce_sum(w_w, 1, keep_dims=True)
        next_p = interpolate * p + w_w
        return next_p

    def update_L(self, p, L, w_w):
        '''
        Updates the temporal linkage matrix. This tells the DNC what order it has written things to memory (see paper)
        '''
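        # Per the DNC paper, each link decays where either end was just written to,
        # and new links are added from the written slots to the previously written ones:
        #   L_t[i, j] = (1 - w_w[i] - w_w[j]) * L_{t-1}[i, j] + w_w[i] * p_{t-1}[j]
        # with the diagonal zeroed out (no self-links).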
        w_w = tf.expand_dims(w_w, 2)
        p = tf.expand_dims(p, 1)

        # compute "outer sum" of w_w
        c_w_w = tf.reshape(w_w, (-1, self.N, 1))
        U = tf.tile(c_w_w, [1, 1, self.N])
        w_w_outer_sum = U + tf.transpose(U, [0, 2, 1])

        next_L = (1 - w_w_outer_sum) * L + tf.matmul(w_w, p)  # update L
        return (1 - self.I) * next_L  # get rid of links to self

    def get_bf_w(self, w_r, L):
        '''
        Gets the forward and backward weightings, which point to the write locations immediately after and before the ones just read. This lets the DNC traverse memories in order (see paper)
        '''
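        # Per the DNC paper, the linkage matrix shifts the previous read weightings
        # one step forward or backward in write order:
        #   f = L * w_r      (forward weighting)
        #   b = L^T * w_r    (backward weighting)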
        f_w = tf.matmul(L, w_r)
        b_w = tf.matmul(L, w_r, adjoint_a=True)  # transpose the first argument
        return f_w, b_w

    def update_w_r(self, c_w, f_w, b_w, pi):
        '''
        Updates the read weighting. This tells the DNC's read heads which memories to extract (see paper)
        '''
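        # Per the DNC paper, each read head blends three read modes (backward,
        # content, forward) using its read mode vector pi:
        #   w_r = pi[0] * b + pi[1] * c + pi[2] * f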
        backward = tf.expand_dims(pi[:, 0, :], 1) * b_w
        content = tf.expand_dims(pi[:, 1, :], 1) * c_w
        forward = tf.expand_dims(pi[:, 2, :], 1) * f_w
        next_w_r = backward + content + forward
        return next_w_r

    def update_r(self, M, w_r):
        '''
        Gets the DNC's read vectors. These contain the outputs of the DNC's read heads (see paper)
        '''
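        # Per the DNC paper, each read head returns a weighted sum of memory rows:
        #   r^i = M^T * w_r^i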
        return tf.matmul(M, w_r, adjoint_a=True)  # transpose the first argument

    def write(self, zeta, state):
        '''
        Performs a write action on the DNC's memory

        Parameters:
        ----------
        zeta: dict
            variable names (string) mapping to tensors (Tensor). Includes:
                'kappa_r': (batch_size, W, R) read key (there are R of them)
                'beta_r': (batch_size, R) read strength
                'kappa_w': (batch_size, W, 1) write key
                'beta_w': (batch_size, 1) write strength
                'e': (batch_size, W) erase vector
                'v': (batch_size, W) write vector
                'f': (batch_size, R) free gates (R of them)
                'g_a': (batch_size, 1) allocation gate
                'g_w': (batch_size, 1) write gate
                'pi': (batch_size, 3, R) read modes (backward, content, forward)
            ... see paper for more info
        state: dict
            contains the previous values of (M, u, p, L, w_w, w_r, r) respectively. According to the DNC paper:
                M: (batch_size, N, W) the memory matrix
                u: (batch_size, N) the usage vector
                p: (batch_size, N) the precedence weighting (helps update L)
                L: (batch_size, N, N) the temporal linkage matrix (helps DNC remember what order things were written)
                w_w: (batch_size, N) the write weighting - says where the DNC wrote a word last time step
                w_r: (batch_size, N, R) the read weightings - say which word vectors the DNC accessed last time step
                r: (batch_size, W, R) the read vectors returned last time step

        Returns: Tuple(5)
            next_u: Tensor
            next_w_w: Tensor
            next_M: Tensor
            next_L: Tensor
            next_p: Tensor
        '''
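        # Note: both the linkage and precedence updates below use last step's
        # precedence weighting state['p'], as in the paper (L_t depends on p_{t-1}).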
        c_w = self.content_addressing(state['M'], zeta['kappa_w'], zeta['beta_w'])
        next_u = self.update_u(state['u'], state['w_r'], state['w_w'], zeta['f'])

        a = self.get_allocation(next_u)
        next_w_w = self.update_w_w(c_w, a, zeta['g_w'], zeta['g_a'])
        next_M = self.update_M(state['M'], next_w_w, zeta['v'], zeta['e'])
        next_L = self.update_L(state['p'], state['L'], next_w_w)
        next_p = self.update_p(state['p'], next_w_w)

        return next_u, next_w_w, next_M, next_L, next_p

    def read(self, zeta, state):
        '''
        Performs a read action on the DNC's memory

        Parameters:
        ----------
        zeta: dict
            variable names (string) mapping to tensors (Tensor). Includes:
                'kappa_r': (batch_size, W, R) read key (there are R of them)
                'beta_r': (batch_size, R) read strength
                'kappa_w': (batch_size, W, 1) write key
                'beta_w': (batch_size, 1) write strength
                'e': (batch_size, W) erase vector
                'v': (batch_size, W) write vector
                'f': (batch_size, R) free gates (R of them)
                'g_a': (batch_size, 1) allocation gate
                'g_w': (batch_size, 1) write gate
                'pi': (batch_size, 3, R) read modes (backward, content, forward)
            ... see paper for more info
        state: dict
            contains the values of (M, u, p, L, w_w, w_r, r) respectively. According to the DNC paper:
                M: (batch_size, N, W) the memory matrix
                u: (batch_size, N) the usage vector
                p: (batch_size, N) the precedence weighting (helps update L)
                L: (batch_size, N, N) the temporal linkage matrix (helps DNC remember what order things were written)
                w_w: (batch_size, N) the write weighting - says where the DNC wrote a word last time step
                w_r: (batch_size, N, R) the read weightings - say which word vectors the DNC accessed last time step
                r: (batch_size, W, R) the read vectors returned last time step

        Returns: Tuple(2)
            next_w_r: Tensor
            next_r: Tensor
        '''
        c_w = self.content_addressing(state['M'], zeta['kappa_r'], zeta['beta_r'])
        f_w, b_w = self.get_bf_w(state['w_r'], state['L'])
        next_w_r = self.update_w_r(c_w, f_w, b_w, zeta['pi'])
        next_r = self.update_r(state['M'], next_w_r)
        return next_w_r, next_r

    def step(self, zeta, state):
        '''
        Combines the read and write operations into a single memory update step.
        '''
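        # Write before read: the read below sees the freshly written M and L, while
        # state['w_r'] is still last step's read weighting when read() runs, as in the paper.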
        state['u'], state['w_w'], state['M'], state['L'], state['p'] = self.write(zeta, state)
        state['w_r'], state['r'] = self.read(zeta, state)
        return state
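

if __name__ == '__main__':
    # Minimal smoke-test sketch, assuming the TF1 graph/session API this module
    # already targets. The sizes and interface values below are arbitrary toy
    # choices (not from the paper); shapes follow the docstrings above.
    R, W, N, batch_size = 2, 6, 8, 3
    mem = Memory(R, W, N, batch_size)
    state = dict(zip(mem.state_keys, mem.zero_state()))
    zeta = {
        'kappa_r': tf.zeros([batch_size, W, R]), 'beta_r': tf.ones([batch_size, R]),
        'kappa_w': tf.zeros([batch_size, W, 1]), 'beta_w': tf.ones([batch_size, 1]),
        'e': tf.zeros([batch_size, W]), 'v': tf.zeros([batch_size, W]),
        'f': tf.zeros([batch_size, R]), 'g_a': tf.ones([batch_size, 1]),
        'g_w': tf.ones([batch_size, 1]), 'pi': tf.ones([batch_size, 3, R]) / 3.0,
    }
    next_state = mem.step(zeta, state)
    with tf.Session() as sess:
        r = sess.run(next_state['r'])
        print('read vectors shape:', r.shape)  # expect (batch_size, W, R)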