add memory units and test

Joerg Franke 2018-06-25 00:32:42 +02:00
parent 11637635f0
commit 0fe3939a17
12 changed files with 2610 additions and 0 deletions

View File

@@ -0,0 +1,14 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

View File

@@ -0,0 +1,90 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from abc import abstractmethod
import numpy as np
import tensorflow as tf
class BaseMemoryUnitCell():
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
seed=100, reuse=False, analyse=False, dtype=tf.float32, name='base'):
self.rng = np.random.RandomState(seed=seed)
self.seed = seed
self.dtype = dtype
self.analyse = analyse
# dnc parameters
self.input_size = input_size
self.h_N = memory_length
self.h_W = memory_width
self.h_RH = read_heads
self.dnc_norm = dnc_norm
self.bypass_dropout = bypass_dropout
self.reuse = reuse
self.name = name
@property
@abstractmethod
def state_size(self):
pass
@abstractmethod
def zero_state(self, batch_size, dtype=tf.float32):
pass
@property
def output_size(self):
return self.h_RH * self.h_W + self.input_size
@property
def trainable_variables(self):
return tf.get_collection('memory_unit')
@property
def parameter_amount(self):
var_list = self.trainable_variables
parameters = 0
for variable in var_list:
shape = variable.get_shape()
variable_parameters = 1
for dim in shape:
variable_parameters *= dim.value
parameters += variable_parameters
return parameters
@staticmethod
def _calculate_content_weightings(memory, keys, strengths):
# cosine similarity between every key and every memory location
similarity_numerator = tf.matmul(keys, memory, adjoint_b=True)
norm_memory = tf.sqrt(tf.reduce_sum(tf.square(memory), axis=2, keepdims=True))
norm_keys = tf.sqrt(tf.reduce_sum(tf.square(keys), axis=2, keepdims=True))
similarity_denominator = tf.matmul(norm_keys, norm_memory, adjoint_b=True)
similarity = similarity_numerator / similarity_denominator
similarity = tf.squeeze(similarity)
# sharpen with the key strengths, then normalize over memory locations
adjusted_similarity = similarity * strengths
softmax_similarity = tf.nn.softmax(adjusted_similarity, dim=-1)
return softmax_similarity
@staticmethod
def _read_memory(memory, read_weightings):
read_vectors = tf.matmul(read_weightings, memory)
return read_vectors
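
`_calculate_content_weightings` implements the DNC's content-based addressing: the cosine similarity between each key and each memory row is sharpened by a per-head strength and normalized with a softmax over locations. A minimal NumPy sketch of the same computation, without the batch dimension (the function name and shapes below are illustrative, not part of the module):

import numpy as np

def content_weightings(memory, keys, strengths):
    # memory: [N, W], keys: [R, W], strengths: [R, 1]
    dot = keys @ memory.T                                    # [R, N]
    norms = (np.linalg.norm(keys, axis=1, keepdims=True)
             * np.linalg.norm(memory, axis=1)[None, :])
    scores = (dot / norms) * strengths                       # sharpened cosine similarity
    e = np.exp(scores - scores.max(axis=1, keepdims=True))   # numerically stable softmax
    return e / e.sum(axis=1, keepdims=True)                  # each row sums to 1

w = content_weightings(np.random.randn(4, 8), np.random.randn(2, 8), 1 + np.random.rand(2, 1))
assert np.allclose(w.sum(axis=1), 1)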

View File

@@ -0,0 +1,158 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
from adnc.model.utils import layer_norm
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization
class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
@property
def state_size(self):
init_memory = tf.TensorShape([self.h_N, self.h_W])
init_usage_vector = tf.TensorShape([self.h_N])
init_write_weighting = tf.TensorShape([self.h_N])
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
return (init_memory, init_usage_vector, init_write_weighting, init_read_weighting)
def zero_state(self, batch_size, dtype=tf.float32):
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_N], dtype=dtype)
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_read_weighting,)
return zero_states
def analyse_state(self, batch_size, dtype=tf.float32):
alloc_gate = tf.zeros([batch_size, 1], dtype=dtype)
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
write_gate = tf.zeros([batch_size, 1], dtype=dtype)
write_keys = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
write_strengths = tf.zeros([batch_size, 1], dtype=dtype)
write_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
erase_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths
return analyse_states
def _weight_input(self, inputs):
input_size = inputs.get_shape()[1].value
total_signal_size = (3 + self.h_RH) * self.h_W + 2 * self.h_RH + 3
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
weighted_input = tf.matmul(inputs, w_x) + b_x
if self.dnc_norm:
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
collection='memory_unit')
return weighted_input
def __call__(self, inputs, pre_states, scope=None):
self.h_B = inputs.get_shape()[0].value
memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
self.const_memory_ones = memory_ones
self.const_batch_memory_range = batch_memory_range
pre_memory, pre_usage_vector, pre_write_weightings, pre_read_weightings = pre_states
weighted_input = self._weight_input(inputs)
control_signals = self._create_control_signals(weighted_input)
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths = control_signals
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
pre_usage_vector, free_gates)
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
write_weighting = self._update_write_weighting(alloc_weightings, write_content_weighting, write_gate,
alloc_gate)
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
read_vectors = self._read_memory(memory, read_content_weightings)
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
if self.bypass_dropout:
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
else:
input_bypass = inputs
output = tf.concat([read_vectors, input_bypass], axis=-1)
if self.analyse:
output = (output, control_signals)
return output, (memory, usage_vector, write_weighting, read_content_weightings)
def _create_constant_value_tensors(self, batch_size, dtype):
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
name="batch_memory_range")
return memory_ones, batch_memory_range
def _create_control_signals(self, weighted_input):
write_keys = weighted_input[:, : self.h_W] # W
write_strengths = weighted_input[:, self.h_W: self.h_W + 1] # 1
erase_vector = weighted_input[:, self.h_W + 1: 2 * self.h_W + 1] # W
write_vector = weighted_input[:, 2 * self.h_W + 1: 3 * self.h_W + 1] # W
alloc_gates = weighted_input[:, 3 * self.h_W + 1: 3 * self.h_W + 2] # 1
write_gates = weighted_input[:, 3 * self.h_W + 2: 3 * self.h_W + 3] # 1
read_keys = weighted_input[:, 3 * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3] # R * W
read_strengths = weighted_input[:,
(self.h_RH + 3) * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH] # R
free_gates = weighted_input[:, (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH: (
self.h_RH + 3) * self.h_W + 3 + 2 * self.h_RH]
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
free_gates = tf.sigmoid(free_gates, 'free_gates')
free_gates = tf.expand_dims(free_gates, 2)
write_gates = tf.sigmoid(write_gates, 'write_gates')
write_keys = tf.expand_dims(write_keys, axis=1)
write_strengths = oneplus(write_strengths)
write_vector = tf.reshape(write_vector, [self.h_B, 1, self.h_W])
erase_vector = tf.sigmoid(erase_vector, 'erase_vector')
erase_vector = tf.reshape(erase_vector, [self.h_B, 1, self.h_W])
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
read_strengths = oneplus(read_strengths)
read_strengths = tf.expand_dims(read_strengths, axis=2)
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths
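
Because the cell exposes `state_size`, `output_size`, `zero_state`, and an `(inputs, state)` call signature, it can be driven like any TF 1.x RNN cell. A minimal sketch, assuming `tf.nn.dynamic_rnn` accepts the duck-typed cell interface (the sizes below are illustrative):

import tensorflow as tf
from adnc.model.memory_units.content_based_cell import ContentBasedMemoryUnitCell

batch_size, seq_len, input_size = 3, 10, 13
cell = ContentBasedMemoryUnitCell(input_size=input_size, memory_length=4,
                                  memory_width=4, read_heads=2, name='mu')
x = tf.placeholder(tf.float32, [batch_size, seq_len, input_size])
# outputs: [batch, time, read_heads * memory_width + input_size]
outputs, last_state = tf.nn.dynamic_rnn(cell, x, initial_state=cell.zero_state(batch_size),
                                        dtype=tf.float32)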

View File

@@ -0,0 +1,255 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
from adnc.model.utils import layer_norm
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization
class DNCMemoryUnitCell(BaseMemoryUnitCell):
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
seed=100, reuse=False, analyse=False, dtype=tf.float32, name='dnc_mu'):
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
analyse, dtype, name)
self.h_B = 0  # will be set in __call__
@property
def state_size(self):
init_memory = tf.TensorShape([self.h_N, self.h_W])
init_usage_vector = tf.TensorShape([self.h_N])
init_write_weighting = tf.TensorShape([self.h_N])
init_precedence_weightings = tf.TensorShape([self.h_N])
init_link_mat = tf.TensorShape([self.h_N, self.h_N])
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
return (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
init_link_mat, init_read_weighting)
def zero_state(self, batch_size, dtype=tf.float32):
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_N], dtype=dtype)
init_precedence_weightings = tf.zeros([batch_size, self.h_N], dtype=dtype)
init_link_mat = tf.zeros([batch_size, self.h_N, self.h_N], dtype=dtype)
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
init_link_mat, init_read_weighting,)
return zero_states
def analyse_state(self, batch_size, dtype=tf.float32):
alloc_gate = tf.zeros([batch_size, 1], dtype=dtype)
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
write_gate = tf.zeros([batch_size, 1], dtype=dtype)
write_keys = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
write_strengths = tf.zeros([batch_size, 1], dtype=dtype)
write_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
erase_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
read_modes = tf.zeros([batch_size, self.h_RH, 3], dtype=dtype)
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths, read_modes
return analyse_states
def __call__(self, inputs, pre_states, scope=None):
self.h_B = inputs.get_shape()[0].value
link_matrix_inv_eye, memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
self.const_link_matrix_inv_eye = link_matrix_inv_eye
self.const_memory_ones = memory_ones
self.const_batch_memory_range = batch_memory_range
pre_memory, pre_usage_vector, pre_write_weightings, pre_precedence_weighting, pre_link_matrix, pre_read_weightings = pre_states
weighted_input = self._weight_input(inputs)
control_signals = self._create_control_signals(weighted_input)
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths, read_modes = control_signals
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
pre_usage_vector, free_gates)
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
write_weighting = self._update_write_weighting(alloc_weightings, write_content_weighting, write_gate,
alloc_gate)
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
link_matrix, precedence_weighting = self._update_link_matrix(pre_link_matrix, write_weighting,
pre_precedence_weighting)
forward_weightings, backward_weightings = self._make_read_forward_backward_weightings(link_matrix,
pre_read_weightings)
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
read_weightings = self._make_read_weightings(forward_weightings, backward_weightings, read_content_weightings,
read_modes)
read_vectors = self._read_memory(memory, read_weightings)
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
if self.bypass_dropout:
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
else:
input_bypass = inputs
output = tf.concat([read_vectors, input_bypass], axis=-1)
if self.analyse:
output = (output, control_signals)
return output, (memory, usage_vector, write_weighting, precedence_weighting, link_matrix, read_weightings)
def _create_constant_value_tensors(self, batch_size, dtype):
link_matrix_inv_eye = 1 - tf.constant(np.identity(self.h_N), dtype=dtype, name="link_matrix_inv_eye")
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
name="batch_memory_range")
return link_matrix_inv_eye, memory_ones, batch_memory_range
def _weight_input(self, inputs):
input_size = inputs.get_shape()[1].value
total_signal_size = (3 + self.h_RH) * self.h_W + 5 * self.h_RH + 3
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
weighted_input = tf.matmul(inputs, w_x) + b_x
if self.dnc_norm:
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
collection='memory_unit')
return weighted_input
def _create_control_signals(self, weighted_input):
write_keys = weighted_input[:, : self.h_W] # W
write_strengths = weighted_input[:, self.h_W: self.h_W + 1] # 1
erase_vector = weighted_input[:, self.h_W + 1: 2 * self.h_W + 1] # W
write_vector = weighted_input[:, 2 * self.h_W + 1: 3 * self.h_W + 1] # W
alloc_gates = weighted_input[:, 3 * self.h_W + 1: 3 * self.h_W + 2] # 1
write_gates = weighted_input[:, 3 * self.h_W + 2: 3 * self.h_W + 3] # 1
read_keys = weighted_input[:, 3 * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3] # R * W
read_strengths = weighted_input[:,
(self.h_RH + 3) * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH] # R
read_modes = weighted_input[:, (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH: (
self.h_RH + 3) * self.h_W + 3 + 4 * self.h_RH] # 3R
free_gates = weighted_input[:, (self.h_RH + 3) * self.h_W + 3 + 4 * self.h_RH: (
self.h_RH + 3) * self.h_W + 3 + 5 * self.h_RH] # R
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
free_gates = tf.sigmoid(free_gates, 'free_gates')
free_gates = tf.expand_dims(free_gates, 2)
write_gates = tf.sigmoid(write_gates, 'write_gates')
write_keys = tf.expand_dims(write_keys, axis=1)
write_strengths = oneplus(write_strengths)
# write_strengths = tf.expand_dims(write_strengths, axis=2)
write_vector = tf.reshape(write_vector, [self.h_B, 1, self.h_W])
erase_vector = tf.sigmoid(erase_vector, 'erase_vector')
erase_vector = tf.reshape(erase_vector, [self.h_B, 1, self.h_W])
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
read_strengths = oneplus(read_strengths)
read_strengths = tf.expand_dims(read_strengths, axis=2)
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 3]) # 3 read modes
read_modes = tf.nn.softmax(read_modes, dim=2)
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths, read_modes
def _update_alloc_and_usage_vectors(self, pre_write_weightings, pre_read_weightings, pre_usage_vector, free_gates):
# retention: how much of each location survives the free gates of the read heads
retention_vector = tf.reduce_prod(1 - free_gates * pre_read_weightings, axis=1, keepdims=False,
name='retention_prod')
usage_vector = (
pre_usage_vector + pre_write_weightings - pre_usage_vector * pre_write_weightings) * retention_vector
# sort locations by usage (ascending) via top_k on the negated usage
sorted_usage, free_list = tf.nn.top_k(-1 * usage_vector, self.h_N)
sorted_usage = -1 * sorted_usage
cumprod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True)
# scatter the allocation terms back to their original memory positions
corrected_free_list = free_list + self.const_batch_memory_range
cumprod_sorted_usage_re = [tf.reshape(cumprod_sorted_usage, [-1, ]), ]
corrected_free_list_re = [tf.reshape(corrected_free_list, [-1]), ]
stitched_usage = tf.dynamic_stitch(corrected_free_list_re, cumprod_sorted_usage_re, name=None)
stitched_usage = tf.reshape(stitched_usage, [self.h_B, self.h_N])
alloc_weighting = (1 - usage_vector) * stitched_usage
return alloc_weighting, usage_vector
@staticmethod
def _update_write_weighting(alloc_weighting, write_content_weighting, write_gate, alloc_gate):
write_weighting = write_gate * (alloc_gate * alloc_weighting + (1 - alloc_gate) * write_content_weighting)
return write_weighting
def _update_memory(self, pre_memory, write_weighting, write_vector, erase_vector):
write_w = tf.expand_dims(write_weighting, 2)
erase_matrix = tf.multiply(pre_memory, (self.const_memory_ones - tf.matmul(write_w, erase_vector)))
write_matrix = tf.matmul(write_w, write_vector)
return erase_matrix + write_matrix
def _update_link_matrix(self, pre_link_matrix, write_weighting, pre_precedence_weighting):
precedence_weighting = (1 - tf.reduce_sum(write_weighting, 1,
keepdims=True)) * pre_precedence_weighting + write_weighting
add_mat = tf.matmul(tf.expand_dims(write_weighting, axis=2),
tf.expand_dims(pre_precedence_weighting, axis=1))
erase_mat = 1 - tf.expand_dims(write_weighting, 1) - tf.expand_dims(write_weighting, 2)
updated_link_mat = erase_mat * pre_link_matrix + add_mat
link_matrix = self.const_link_matrix_inv_eye * updated_link_mat
return link_matrix, precedence_weighting
@staticmethod
def _make_read_forward_backward_weightings(link_matrix, pre_read_weightings):
forward_weightings = tf.matmul(pre_read_weightings, link_matrix)
backward_weightings = tf.matmul(pre_read_weightings, link_matrix, adjoint_b=True)
return forward_weightings, backward_weightings
@staticmethod
def _make_read_weightings(forward_weightings, backward_weightings, read_content_weightings, read_modes):
read_weighting = tf.expand_dims(read_modes[:, :, 0], 2) * backward_weightings + \
tf.expand_dims(read_modes[:, :, 1], 2) * read_content_weightings + \
tf.expand_dims(read_modes[:, :, 2], 2) * forward_weightings
return read_weighting
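
The temporal link matrix update follows the DNC paper: the precedence weighting decays by the total amount just written, and L[i, j] tracks whether location i was written directly after location j, with the diagonal forced to zero. A single-example NumPy sketch of `_update_link_matrix` without the batch dimension (names are illustrative):

import numpy as np

def update_link(link, precedence, write_w):
    # link: [N, N], precedence: [N], write_w: [N]
    new_precedence = (1 - write_w.sum()) * precedence + write_w
    erase = 1 - write_w[:, None] - write_w[None, :]          # decay old links
    link = erase * link + np.outer(write_w, precedence)      # add w[i] * p_old[j]
    np.fill_diagonal(link, 0.0)                              # no self-links
    return link, new_precedence

N = 4
link, prec = np.zeros((N, N)), np.zeros(N)
link, prec = update_link(link, prec, np.array([0.7, 0.1, 0.1, 0.1]))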

View File

@@ -0,0 +1,160 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import tensorflow as tf
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
from adnc.model.utils import layer_norm
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization
class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
@property
def state_size(self):
init_memory = tf.TensorShape([self.h_N, self.h_W])
init_usage_vector = tf.TensorShape([self.h_N])
init_write_weighting = tf.TensorShape([self.h_WH, self.h_N])
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
return (init_memory, init_usage_vector, init_write_weighting, init_read_weighting)
def zero_state(self, batch_size, dtype=tf.float32):
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_WH, self.h_N], dtype=dtype)
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_read_weighting,)
return zero_states
def analyse_state(self, batch_size, dtype=tf.float32):
alloc_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype) # WH
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
write_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
write_keys = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
write_strengths = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
write_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
erase_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths
return analyse_states
def __call__(self, inputs, pre_states, scope=None):
self.h_B = inputs.get_shape()[0].value
memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
self.const_memory_ones = memory_ones
self.const_batch_memory_range = batch_memory_range
pre_memory, pre_usage_vector, pre_write_weightings, pre_read_weightings = pre_states
weighted_input = self._weight_input(inputs)
control_signals = self._create_control_signals(weighted_input)
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths = control_signals
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
pre_usage_vector, free_gates, write_gate)
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
write_weighting = self._update_write_weightings(alloc_weightings, write_content_weighting, write_gate,
alloc_gate)
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
read_vectors = self._read_memory(memory, read_content_weightings)
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
if self.bypass_dropout:
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
else:
input_bypass = inputs
output = tf.concat([read_vectors, input_bypass], axis=-1)
if self.analyse:
output = (output, control_signals)
return output, (memory, usage_vector, write_weighting, read_content_weightings)
def _create_constant_value_tensors(self, batch_size, dtype):
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
name="batch_memory_range")
return memory_ones, batch_memory_range
def _weight_input(self, inputs):
input_size = inputs.get_shape()[1].value
total_signal_size = self.h_RH * (2 + self.h_W) + self.h_WH * (3 + 3 * self.h_W)
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
weighted_input = tf.matmul(inputs, w_x) + b_x
if self.dnc_norm:
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
collection='memory_unit')
return weighted_input
def _create_control_signals(self, weighted_input):
alloc_gates = weighted_input[:, : self.h_WH]
free_gates = weighted_input[:, self.h_WH: self.h_WH + self.h_RH]
write_gates = weighted_input[:, self.h_WH + self.h_RH: 2 * self.h_WH + self.h_RH]
write_keys = weighted_input[:, 2 * self.h_WH + self.h_RH: (self.h_W + 2) * self.h_WH + self.h_RH]
write_strengths = weighted_input[:,
(self.h_W + 2) * self.h_WH + self.h_RH: (self.h_W + 3) * self.h_WH + self.h_RH]
write_vectors = weighted_input[:,
(self.h_W + 3) * self.h_WH + self.h_RH: (2 * self.h_W + 3) * self.h_WH + self.h_RH]
erase_vectors = weighted_input[:,
(2 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH + self.h_RH]
read_keys = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH
+ (self.h_W + 1) * self.h_RH]
read_strengths = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + (self.h_W + 1) * self.h_RH:]
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
alloc_gates = tf.expand_dims(alloc_gates, 2)
free_gates = tf.sigmoid(free_gates, 'free_gates')
free_gates = tf.expand_dims(free_gates, 2)
write_gates = tf.sigmoid(write_gates, 'write_gates')
write_gates = tf.expand_dims(write_gates, 2)
write_keys = tf.reshape(write_keys, [self.h_B, self.h_WH, self.h_W])
write_strengths = oneplus(write_strengths)
write_strengths = tf.expand_dims(write_strengths, axis=2)
write_vectors = tf.reshape(write_vectors, [self.h_B, self.h_WH, self.h_W])
erase_vectors = tf.reshape(erase_vectors, [self.h_B, self.h_WH, self.h_W])
erase_vectors = tf.sigmoid(erase_vectors, 'erase_vector')
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
read_strengths = oneplus(read_strengths)
read_strengths = tf.expand_dims(read_strengths, axis=2)
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
erase_vectors, read_keys, read_strengths
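
The slice boundaries in `_create_control_signals` have to tile the interface vector exactly, i.e. sum to `total_signal_size = R * (2 + W) + WH * (3 + 3 * W)`. A quick sanity check with hypothetical sizes (R = 2 read heads, WH = 2 write heads, W = 4):

R, WH, W = 2, 2, 4
sizes = [WH,       # alloc gates
         R,        # free gates
         WH,       # write gates
         WH * W,   # write keys
         WH,       # write strengths
         WH * W,   # write vectors
         WH * W,   # erase vectors
         R * W,    # read keys
         R]        # read strengths
assert sum(sizes) == R * (2 + W) + WH * (3 + 3 * W) == 42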

View File

@@ -0,0 +1,275 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
from adnc.model.utils import layer_norm
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization
class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
def __init__(self, input_size, memory_length, memory_width, read_heads, write_heads, bypass_dropout=False,
dnc_norm=False, seed=100, reuse=False, analyse=False, dtype=tf.float32, name='mwdnc_mu'):
self.h_WH = write_heads
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
analyse, dtype, name)
self.h_B = 0  # will be set in __call__
@property
def state_size(self):
init_memory = tf.TensorShape([self.h_N, self.h_W])
init_usage_vector = tf.TensorShape([self.h_N])
init_write_weighting = tf.TensorShape([self.h_WH, self.h_N])
init_precedence_weightings = tf.TensorShape([self.h_WH, self.h_N])
init_link_mat = tf.TensorShape([self.h_WH, self.h_N, self.h_N])
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
return (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
init_link_mat, init_read_weighting)
def zero_state(self, batch_size, dtype=tf.float32):
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_WH, self.h_N], dtype=dtype)
init_precedence_weightings = tf.zeros([batch_size, self.h_WH, self.h_N], dtype=dtype)
init_link_mat = tf.zeros([batch_size, self.h_WH, self.h_N, self.h_N], dtype=dtype)
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
init_link_mat, init_read_weighting,)
return zero_states
def analyse_state(self, batch_size, dtype=tf.float32):
alloc_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype) # WH
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
write_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
write_keys = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
write_strengths = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
write_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
erase_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
read_modes = tf.zeros([batch_size, self.h_RH, 1 + 2 * self.h_WH], dtype=dtype)
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths, read_modes
return analyse_states
def __call__(self, inputs, pre_states, scope=None):
self.h_B = inputs.get_shape()[0].value
link_matrix_inv_eye, memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
self.const_link_matrix_inv_eye = link_matrix_inv_eye
self.const_memory_ones = memory_ones
self.const_batch_memory_range = batch_memory_range
pre_memory, pre_usage_vector, pre_write_weightings, pre_precedence_weighting, pre_link_matrix, pre_read_weightings = pre_states
weighted_input = self._weight_input(inputs)
control_signals = self._create_control_signals(weighted_input)
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths, read_modes = control_signals
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
pre_usage_vector, free_gates, write_gate)
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
write_weighting = self._update_write_weightings(alloc_weightings, write_content_weighting, write_gate,
alloc_gate)
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
link_matrix, precedence_weighting = self._update_link_matrix(pre_link_matrix, write_weighting,
pre_precedence_weighting)
forward_weightings, backward_weightings = self._make_read_forward_backward_weightings(link_matrix,
pre_read_weightings)
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
read_weightings = self._make_read_weightings(forward_weightings, backward_weightings, read_content_weightings,
read_modes)
read_vectors = self._read_memory(memory, read_weightings)
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
if self.bypass_dropout:
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
else:
input_bypass = inputs
output = tf.concat([read_vectors, input_bypass], axis=-1)
if self.analyse:
output = (output, control_signals)
return output, (memory, usage_vector, write_weighting, precedence_weighting, link_matrix, read_weightings)
def _create_constant_value_tensors(self, batch_size, dtype):
link_matrix_inv_eye = 1 - tf.constant(np.identity(self.h_N), dtype=dtype, name="link_matrix_inv_eye")
link_matrix_inv_eye = tf.stack([link_matrix_inv_eye, ] * self.h_WH, axis=0)
link_matrix_inv_eye = tf.stack([link_matrix_inv_eye, ] * batch_size, axis=0)
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
name="batch_memory_range")
return link_matrix_inv_eye, memory_ones, batch_memory_range
def _weight_input(self, inputs):
input_size = inputs.get_shape()[1].value
total_signal_size = self.h_RH * (3 + 2 * self.h_WH + self.h_W) + self.h_WH * (3 + 3 * self.h_W)
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
weighted_input = tf.matmul(inputs, w_x) + b_x
if self.dnc_norm:
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
collection='memory_unit')
return weighted_input
def _create_control_signals(self, weighted_input):
alloc_gates = weighted_input[:, : self.h_WH]
free_gates = weighted_input[:, self.h_WH: self.h_WH + self.h_RH]
write_gates = weighted_input[:, self.h_WH + self.h_RH: 2 * self.h_WH + self.h_RH]
write_keys = weighted_input[:, 2 * self.h_WH + self.h_RH: (self.h_W + 2) * self.h_WH + self.h_RH]
write_strengths = weighted_input[:,
(self.h_W + 2) * self.h_WH + self.h_RH: (self.h_W + 3) * self.h_WH + self.h_RH]
write_vectors = weighted_input[:,
(self.h_W + 3) * self.h_WH + self.h_RH: (2 * self.h_W + 3) * self.h_WH + self.h_RH]
erase_vectors = weighted_input[:,
(2 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH + self.h_RH]
read_keys = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH +
(self.h_W + 1) * self.h_RH]
read_strengths = weighted_input[:,
(3 * self.h_W + 3) * self.h_WH + (self.h_W + 1) * self.h_RH: (3 * self.h_W + 3) * self.h_WH +
(self.h_W + 2) * self.h_RH]
read_modes = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + (self.h_W + 2) * self.h_RH:]
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
alloc_gates = tf.expand_dims(alloc_gates, 2)
free_gates = tf.sigmoid(free_gates, 'free_gates')
free_gates = tf.expand_dims(free_gates, 2)
write_gates = tf.sigmoid(write_gates, 'write_gates')
write_gates = tf.expand_dims(write_gates, 2)
write_keys = tf.reshape(write_keys, [self.h_B, self.h_WH, self.h_W])
write_strengths = oneplus(write_strengths)
write_strengths = tf.expand_dims(write_strengths, axis=2)
write_vectors = tf.reshape(write_vectors, [self.h_B, self.h_WH, self.h_W])
erase_vectors = tf.reshape(erase_vectors, [self.h_B, self.h_WH, self.h_W])
erase_vectors = tf.sigmoid(erase_vectors, 'erase_vector')
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
read_strengths = oneplus(read_strengths)
read_strengths = tf.expand_dims(read_strengths, axis=2)
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH])
read_modes = tf.nn.softmax(read_modes, dim=2)
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
erase_vectors, read_keys, read_strengths, read_modes
def _update_alloc_and_usage_vectors(self, pre_write_weightings, pre_read_weightings, pre_usage_vector, free_gates,
write_gates):
# usage update after write from last time step
pre_write_weighting = 1 - tf.reduce_prod(1 - pre_write_weightings, [1], keepdims=False)
usage_vector = pre_usage_vector + pre_write_weighting - pre_usage_vector * pre_write_weighting
# usage update after read
retention_vector = tf.reduce_prod(1 - free_gates * pre_read_weightings, axis=1, keepdims=False,
name='retention_prod')
usage_vector = usage_vector * retention_vector
usage_vector_cp = tf.identity(usage_vector)
alloc_list = []
for w in range(self.h_WH):
sorted_usage, free_list = tf.nn.top_k(-1 * usage_vector_cp, self.h_N)
sorted_usage = -1 * sorted_usage
cumprod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True)
corrected_free_list = free_list + self.const_batch_memory_range
corrected_free_list_un = [tf.reshape(corrected_free_list, [-1, ]), ]
cumprod_sorted_usage_un = [tf.reshape(cumprod_sorted_usage, [-1, ]), ]
stitched_usage = tf.dynamic_stitch(corrected_free_list_un, cumprod_sorted_usage_un, name=None)
stitched_usage = tf.reshape(stitched_usage, [self.h_B, self.h_N])
alloc_weighting = (1 - usage_vector_cp) * stitched_usage
alloc_list.append(alloc_weighting)
usage_vector_cp = usage_vector_cp + ((1 - usage_vector_cp) * write_gates[:, w, :] * alloc_weighting)
alloc_weighting = tf.stack(alloc_list, 1)
return alloc_weighting, usage_vector
@staticmethod
def _update_write_weightings(alloc_weighting, write_content_weighting, write_gate, alloc_gate):
write_weighting = write_gate * (alloc_gate * alloc_weighting + (1 - alloc_gate) * write_content_weighting)
return write_weighting
def _update_memory(self, pre_memory, write_weighting, write_vector, erase_vector):
write_w = tf.expand_dims(write_weighting, 3)
erase_vector = tf.expand_dims(erase_vector, 2)
erase_matrix = tf.reduce_prod(1 - write_w * erase_vector, axis=1, keepdims=False)
write_matrix = tf.matmul(write_weighting, write_vector, adjoint_a=True)
return pre_memory * erase_matrix + write_matrix
def _update_link_matrix(self, pre_link_matrices, write_weightings, pre_precedence_weightings):
precedence_weightings = (1 - tf.reduce_sum(write_weightings, 2,
keepdims=True)) * pre_precedence_weightings + write_weightings
add_mat = tf.expand_dims(write_weightings, axis=3) * tf.expand_dims(pre_precedence_weightings, axis=2)
erase_mat = 1 - tf.expand_dims(write_weightings, 2) - tf.expand_dims(write_weightings, 3)
updated_link_mat = erase_mat * pre_link_matrices + add_mat
link_matrices = self.const_link_matrix_inv_eye * updated_link_mat
return link_matrices, precedence_weightings
def _make_read_forward_backward_weightings(self, link_matrix, pre_read_weightings):
read_weightings_stacked = tf.stack([pre_read_weightings, ] * self.h_WH, axis=1)
forward_weightings = tf.matmul(read_weightings_stacked, link_matrix)
backward_weightings = tf.matmul(read_weightings_stacked, link_matrix, adjoint_b=True)
return tf.transpose(forward_weightings, (0, 2, 1, 3)), tf.transpose(backward_weightings, (0, 2, 1, 3))
def _make_read_weightings(self, forward_weightings, backward_weightings, read_content_weightings, read_modes):
read_weighting = tf.reduce_sum(tf.expand_dims(read_modes[:, :, :self.h_WH], 3) * backward_weightings, axis=2) + \
tf.expand_dims(read_modes[:, :, self.h_WH], 2) * read_content_weightings + \
tf.reduce_sum(tf.expand_dims(read_modes[:, :, self.h_WH + 1:], 3) * forward_weightings, axis=2)
return read_weighting
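
With multiple write heads, allocation is computed sequentially: each head receives an allocation weighting over the freest locations, and the usage estimate is provisionally bumped by what the previous heads are about to write, scaled by their write gates, before the next head allocates. A single-example NumPy sketch of the loop in `_update_alloc_and_usage_vectors` (no batch dimension; names are illustrative):

import numpy as np

def multi_head_alloc(usage, write_gates):
    # usage: [N], write_gates: [WH]; returns [WH, N] allocation weightings
    allocs, u = [], usage.copy()
    for g in write_gates:
        order = np.argsort(u)                          # freest locations first
        sorted_u = u[order]
        cumprod = np.cumprod(np.concatenate(([1.0], sorted_u[:-1])))  # exclusive cumprod
        alloc = np.zeros_like(u)
        alloc[order] = (1 - sorted_u) * cumprod        # DNC allocation weighting
        allocs.append(alloc)
        u = u + (1 - u) * g * alloc                    # provisional bump for the next head
    return np.stack(allocs)

alloc = multi_head_alloc(np.array([0.1, 0.9, 0.3, 0.5]), write_gates=np.array([1.0, 1.0]))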

View File

View File

@@ -0,0 +1,152 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
import pytest
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
@pytest.fixture(
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
"dnc_norm": True, "bypass_dropout": False},
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
"dnc_norm": False, "bypass_dropout": False},
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
"dnc_norm": True, "bypass_dropout": True},
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
"dnc_norm": False, "bypass_dropout": True}
])
def memory_config(request):
config = request.param
return BaseMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
memory_width=config["memory_width"],
read_heads=config["read_heads"], seed=config["seed"],
reuse=False, name='test_mu'), config
@pytest.fixture()
def session():
with tf.Session() as sess:
yield sess
tf.reset_default_graph()
@pytest.fixture()
def np_rng():
seed = np.random.randint(1, 999)
return np.random.RandomState(seed)
class TestBaseMemoryUnitCell():
def test_init(self, memory_config):
memory_unit, config = memory_config
assert isinstance(memory_unit, object)
assert isinstance(memory_unit.rng, np.random.RandomState)
assert memory_unit.h_N == config["memory_length"]
assert memory_unit.h_W == config["memory_width"]
assert memory_unit.h_RH == config["read_heads"]
def test_property_output_size(self, memory_config, session):
memory_unit, config = memory_config
output_size = memory_unit.output_size
assert output_size == config['memory_width'] * config["read_heads"] + config['input_size']
def test_calculate_content_weightings(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_memory = np_rng.normal(0, 1, (config['batch_size'], config['memory_length'], config['memory_width']))
np_keys = np_rng.normal(0, 2, (config['batch_size'], 1, config['memory_width']))
np_strengths = np_rng.uniform(1, 10, (config['batch_size'], 1))
memory = tf.constant(np_memory, dtype=tf.float32)
keys = tf.constant(np_keys, dtype=tf.float32)
strengths = tf.constant(np_strengths, dtype=tf.float32)
content_weightings = memory_unit._calculate_content_weightings(memory, keys, strengths)
weightings = content_weightings.eval()
np_similarity = np.empty([config['batch_size'], config['memory_length']])
for b in range(config['batch_size']):
for l in range(config['memory_length']):
np_similarity[b, l] = np.dot(np_memory[b, l, :], np_keys[b, 0, :]) / (
np.sqrt(np.dot(np_memory[b, l, :], np_memory[b, l, :])) * np.sqrt(
np.dot(np_keys[b, 0, :], np_keys[b, 0, :])))
def _weighted_softmax(x, s):
e_x = np.exp(x * s)
return e_x / e_x.sum(axis=1, keepdims=True)
np_weightings = _weighted_softmax(np_similarity, np_strengths)
assert weightings.shape == (config['batch_size'], config['memory_length'])
assert 0 <= weightings.min() and weightings.max() <= 1 and np.allclose(weightings.sum(axis=1), 1)
assert np.allclose(weightings, np_weightings)
np_memory = np_rng.uniform(0, 1, (config['batch_size'], config['memory_length'], config['memory_width']))
np_keys = np_rng.normal(0, 2, (config['batch_size'], config['read_heads'], config['memory_width']))
np_strengths = np_rng.uniform(1, 10, (config['batch_size'], config['read_heads'], 1))
memory = tf.constant(np_memory, dtype=tf.float32)
keys = tf.constant(np_keys, dtype=tf.float32)
strengths = tf.constant(np_strengths, dtype=tf.float32)
content_weightings = memory_unit._calculate_content_weightings(memory, keys, strengths)
weightings = content_weightings.eval()
np_similarity = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
for b in range(config['batch_size']):
for r in range(config['read_heads']):
for l in range(config['memory_length']):
np_similarity[b, r, l] = np.dot(np_memory[b, l, :], np_keys[b, r, :]) / (
np.sqrt(np.dot(np_memory[b, l, :], np_memory[b, l, :])) * np.sqrt(
np.dot(np_keys[b, r, :], np_keys[b, r, :])))
np_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
def _weighted_softmax(x, s):
e_x = np.exp(x * s)
return e_x / e_x.sum(axis=1, keepdims=True)
for r in range(config['read_heads']):
np_weightings[:, r, :] = _weighted_softmax(np_similarity[:, r, :], np_strengths[:, r])
assert weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
assert 0 <= weightings.min() and weightings.max() <= 1 and np.allclose(weightings.sum(axis=2), 1)
assert np.allclose(weightings, np_weightings)
def test_read_memory(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
memory = tf.constant(np_memory, dtype=tf.float32)
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
read_vectors = memory_unit._read_memory(memory, read_weightings)
read_vectors = read_vectors.eval()
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
for b in range(config['batch_size']):
for r in range(config['read_heads']):
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
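
The bare `tensor.eval()` calls in these tests work because the `session` fixture's `with tf.Session()` block installs a default session, and `tf.reset_default_graph()` after the yield isolates the graph between the four parametrized configurations. A hedged way to run just this module through pytest's Python API (the module path is an assumption):

import pytest
pytest.main(["-v", "test/test_memory_units/test_base_cell.py"])  # path assumed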

View File

@@ -0,0 +1,325 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
import pytest
import time
from adnc.model.memory_units.content_based_cell import ContentBasedMemoryUnitCell
@pytest.fixture(
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
"dnc_norm": True, "bypass_dropout": False},
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
"dnc_norm": False, "bypass_dropout": False},
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
"dnc_norm": True, "bypass_dropout": True},
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
"dnc_norm": False, "bypass_dropout": True}
])
def memory_config(request):
config = request.param
return ContentBasedMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
memory_width=config["memory_width"],
read_heads=config["read_heads"], seed=config["seed"],
reuse=False, name='test_mu'), config
@pytest.fixture()
def session():
with tf.Session() as sess:
yield sess
tf.reset_default_graph()
@pytest.fixture()
def np_rng():
seed = np.random.randint(1, 999)
return np.random.RandomState(seed)
class TestContentBasedMemoryUnitCell():
def test_zero_state(self, memory_config, session, np_rng):
memory_unit, config = memory_config
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
# test init_tuple
init_memory, init_usage_vector, init_write_weighting, init_read_weighting = init_tuple
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
assert init_write_weighting.eval().shape == (config['batch_size'], config['memory_length'])
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
def test_parameter_amount(self, memory_config, session, np_rng):
memory_unit, config = memory_config
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 2 * config['read_heads'] + 3)
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.constant(inputs, tf.float32)
memory_unit._weight_input(tf_input)
parameter_amount = memory_unit.parameter_amount
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
def test_create_constant_value_tensors(self, memory_config, session, np_rng):
memory_unit, config = memory_config
memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(batch_size=config['batch_size'],
dtype=tf.float32)
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
assert np.array_equal(memory_ones.eval(), np_memory_ones)
np_batch_range = np.arange(0, config['batch_size'])
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
np.expand_dims(np_repeat_memory_length, 0))
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
def test_weight_input(self, memory_config, session, np_rng):
memory_unit, config = memory_config
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
weight_inputs = memory_unit._weight_input(tf_input)
session.run(tf.global_variables_initializer())
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 2 * config['read_heads'] + 3)
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
def test_create_control_signals(self, memory_config, session, np_rng):
memory_unit, config = memory_config
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
memory_unit.h_B = config['batch_size']
control_signals = memory_unit._create_control_signals(weighted_input)
control_signals = session.run(control_signals)
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths = control_signals
assert alloc_gates.shape == (config['batch_size'], 1)
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
assert 0 <= free_gates.min() and free_gates.max() <= 1
assert write_gates.shape == (config['batch_size'], 1)
assert 0 <= write_gates.min() and write_gates.max() <= 1
assert write_keys.shape == (config['batch_size'], 1, config['memory_width'])
assert write_strengths.shape == (config['batch_size'], 1)
assert 1 <= write_strengths.min()
assert write_vector.shape == (config['batch_size'], 1, config['memory_width'])
assert erase_vector.shape == (config['batch_size'], 1, config['memory_width'])
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
# read head signals
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
assert 1 <= read_strengths.min()
def test_update_alloc_weightings_and_usage_vectors(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_pre_write_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_usage_vectors = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_free_gates = np.ones([config['batch_size'], config['read_heads'], 1]) * 0.5
pre_write_weightings = tf.constant(np_pre_write_weightings, dtype=tf.float32)
pre_usage_vectors = tf.constant(np_pre_usage_vectors, dtype=tf.float32)
free_gates = tf.constant(np_free_gates, dtype=tf.float32)
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
memory_unit(inputs, pre_states) # just for initialization
alloc_weightings, usage_vectors = memory_unit._update_alloc_and_usage_vectors(pre_write_weightings,
pre_read_weightings,
pre_usage_vectors, free_gates)
alloc_weightings, usage_vectors = session.run([alloc_weightings, usage_vectors])
np_retention_vector = np.prod(1 - np_free_gates * np_pre_read_weightings, axis=1, keepdims=False)
np_usage_vectors = (
np_pre_usage_vectors + np_pre_write_weightings - np_pre_usage_vectors * np_pre_write_weightings) * np_retention_vector
assert usage_vectors.shape == (config['batch_size'], config['memory_length'])
assert usage_vectors.min() >= 0 and usage_vectors.max() <= 1
assert np.allclose(usage_vectors, np_usage_vectors)
free_list = np.argsort(np_usage_vectors).astype(int)
np_alloc_weightings = np.zeros([config['batch_size'], config['memory_length']])
for b in range(config['batch_size']):
for j in range(config['memory_length']):
fj = free_list[b, j]
np_alloc_weightings[b, fj] = (1 - np_usage_vectors[b, fj]) * np.prod(
[np_usage_vectors[b, free_list[b, i]] for i in range(j)])
assert alloc_weightings.shape == (config['batch_size'], config['memory_length'])
assert np.allclose(alloc_weightings, np_alloc_weightings)
def test_update_write_weighting(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_alloc_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_write_content_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_write_gate = np.ones([config['batch_size'], 1]) * 0.5
np_alloc_gate = np.ones([config['batch_size'], 1]) * 0.5
alloc_weighting = tf.constant(np_alloc_weighting, dtype=tf.float32)
write_content_weighting = tf.constant(np_write_content_weighting, dtype=tf.float32)
write_gate = tf.constant(np_write_gate, dtype=tf.float32)
alloc_gate = tf.constant(np_alloc_gate, dtype=tf.float32)
write_weighting = memory_unit._update_write_weighting(alloc_weighting, write_content_weighting, write_gate,
alloc_gate)
write_weighting = write_weighting.eval()
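# Reference (Graves et al. 2016): the write weighting interpolates allocation
# and content weighting under the allocation and write gates:
#   w_write = g_write * (g_alloc * a + (1 - g_alloc) * c_write)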
np_write_weighting = np_write_gate * (
np_alloc_gate * np_alloc_weighting + (1 - np_alloc_gate) * np_write_content_weighting)
assert write_weighting.shape == (config['batch_size'], config['memory_length'])
assert 0 <= write_weighting.min() and write_weighting.max() <= 1 and (write_weighting.sum(axis=1) <= 1 + 1e-6).all()
assert np.allclose(write_weighting, np_write_weighting)
def test_update_memory(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_write_vector = np_rng.normal(0, 2, [config['batch_size'], 1, config['memory_width']])
np_erase_vector = np_rng.uniform(0, 1, [config['batch_size'], 1, config['memory_width']])
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
write_vector = tf.constant(np_write_vector, dtype=tf.float32)
erase_vector = tf.constant(np_erase_vector, dtype=tf.float32)
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
memory_unit(inputs, pre_states) # just for initialization
memory = memory_unit._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
memory = memory.eval()
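# Reference (Graves et al. 2016): memory is first erased, then written, via
# outer products with the write weighting:
#   M = M_prev * (1 - w_write (x) e) + w_write (x) v
# with (x) the outer product, e the erase and v the write vector.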
write_w = np.expand_dims(np_write_weighting, 2)
np_erase_memory = (1 - (write_w * np_erase_vector))
np_add_memory = np.matmul(write_w, np_write_vector)
np_memory = np_pre_memory * np_erase_memory + np_add_memory
assert memory.shape == (config['batch_size'], config['memory_length'], config['memory_width'])
assert np.allclose(memory, np_memory, atol=1e-06)
def test_read_memory(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
memory = tf.constant(np_memory, dtype=tf.float32)
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
read_vectors = memory_unit._read_memory(memory, read_weightings)
read_vectors = read_vectors.eval()
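# Reference (Graves et al. 2016): each read vector is the read-weighted sum
# over memory rows, r_i = M^T w_read_i; the loop below spells this out per
# batch and read head.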
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
for b in range(config['batch_size']):
for r in range(config['read_heads']):
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
def test_call(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
read_vectors, states = memory_unit(inputs, pre_states)
session.run(tf.global_variables_initializer())
read_vectors, states = session.run([read_vectors, states])
# test const initialization
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
assert np.array_equal(memory_unit.const_memory_ones.eval(), np_memory_ones)
np_batch_range = np.arange(0, config['batch_size'])
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
np.expand_dims(np_repeat_memory_length, 0))
assert np.array_equal(memory_unit.const_batch_memory_range.eval(), np_batch_memory_range)
assert read_vectors.shape == (
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])

View File

@ -0,0 +1,466 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
import pytest
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
@pytest.fixture(
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
"dnc_norm": True, "bypass_dropout": False},
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
"dnc_norm": False, "bypass_dropout": False},
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
"dnc_norm": True, "bypass_dropout": True},
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
"dnc_norm": False, "bypass_dropout": True}
])
def memory_config(request):
config = request.param
return DNCMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
memory_width=config["memory_width"],
read_heads=config["read_heads"], seed=config["seed"],
reuse=False, name='test_mu'), config
@pytest.fixture()
def session():
with tf.Session() as sess:
yield sess
tf.reset_default_graph()
@pytest.fixture()
def np_rng():
seed = np.random.randint(1, 999)
return np.random.RandomState(seed)
class TestDNCMemoryUnit():
def test_zero_state(self, memory_config, session):
memory_unit, config = memory_config
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
# test init_tuple
init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings, init_link_mat, init_read_weighting = init_tuple
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
assert init_write_weighting.eval().shape == (config['batch_size'], config['memory_length'])
assert init_precedence_weightings.eval().shape == (config['batch_size'], config['memory_length'])
assert init_link_mat.eval().shape == (config['batch_size'], config['memory_length'], config['memory_length'])
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
def test_parameter_amount(self, memory_config, session):
memory_unit, config = memory_config
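# Expected interface size of the DNC (Graves et al. 2016): R read keys (R*W),
# one write key, erase and write vector (3*W), plus R read strengths, R free
# gates, 3*R read modes, and write strength, allocation gate and write gate
# (5*R + 3 scalars), i.e. W*(R+3) + 5*R + 3 in total.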
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.constant(inputs, tf.float32)
memory_unit._weight_input(tf_input)
parameter_amount = memory_unit.parameter_amount
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
def test_create_constant_value_tensors(self, memory_config, session):
memory_unit, config = memory_config
link_matrix_inv_eye, memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(
batch_size=config['batch_size'], dtype=tf.float32)
np_link_matrix_inv_eye = np.ones([config['memory_length'], config['memory_length']]) - np.eye(
config['memory_length'])
assert np.array_equal(link_matrix_inv_eye.eval(), np_link_matrix_inv_eye)
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
assert np.array_equal(memory_ones.eval(), np_memory_ones)
np_batch_range = np.arange(0, config['batch_size'])
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
np.expand_dims(np_repeat_memory_length, 0))
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
def test_weight_input(self, memory_config, session):
memory_unit, config = memory_config
mu_weight_test = DNCMemoryUnitCell(memory_length=config["memory_length"], memory_width=config["memory_width"],
read_heads=config["read_heads"], input_size=config['input_size'],
seed=config["seed"],
reuse=False, name='dnc_mu_weight_test')
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
weight_inputs = mu_weight_test._weight_input(tf_input)
session.run(tf.global_variables_initializer())
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
def test_create_control_signals(self, memory_config, session):
memory_unit, config = memory_config
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
memory_unit.h_B = config['batch_size']
control_signals = memory_unit._create_control_signals(weighted_input)
control_signals = session.run(control_signals)
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
erase_vector, read_keys, read_strengths, read_modes = control_signals
assert alloc_gates.shape == (config['batch_size'], 1)
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
assert 0 <= free_gates.min() and free_gates.max() <= 1
assert write_gates.shape == (config['batch_size'], 1)
assert 0 <= write_gates.min() and write_gates.max() <= 1
assert write_keys.shape == (config['batch_size'], 1, config['memory_width'])
assert write_strengths.shape == (config['batch_size'], 1)
assert 1 <= write_strengths.min()
assert write_vector.shape == (config['batch_size'], 1, config['memory_width'])
assert erase_vector.shape == (config['batch_size'], 1, config['memory_width'])
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
# read head signals
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
assert 1 <= read_strengths.min()
assert read_modes.shape == (config['batch_size'], config['read_heads'], 3) # 3 read modes
assert 0 <= read_modes.min() and read_modes.max() <= 1 and np.allclose(read_modes.sum(axis=2), 1)
def test_update_alloc_weightings_and_usage_vectors(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_pre_write_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
prw_rand = np.arange(0, config['memory_length']) / config['memory_length']
np_pre_read_weightings = np.stack([prw_rand, ] * config['read_heads'], 0)
np_pre_read_weightings = np.stack([np_pre_read_weightings, ] * config['batch_size'], 0)
np_pre_usage_vectors = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_free_gates = np.ones([config['batch_size'], config['read_heads'], 1]) * 0.5
pre_write_weightings = tf.constant(np_pre_write_weightings, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_usage_vectors = tf.constant(np_pre_usage_vectors, dtype=tf.float32)
free_gates = tf.constant(np_free_gates, dtype=tf.float32)
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
memory_unit(inputs, pre_states) # just for initialization
alloc_weightings, usage_vectors = memory_unit._update_alloc_and_usage_vectors(pre_write_weightings,
pre_read_weightings,
pre_usage_vectors, free_gates)
alloc_weightings, usage_vectors = session.run([alloc_weightings, usage_vectors])
np_retention_vector = np.prod(1 - np_free_gates * np_pre_read_weightings, axis=1, keepdims=False)
np_usage_vectors = (
np_pre_usage_vectors + np_pre_write_weightings - np_pre_usage_vectors * np_pre_write_weightings) * np_retention_vector
assert usage_vectors.shape == (config['batch_size'], config['memory_length'])
assert usage_vectors.min() >= 0 and usage_vectors.max() <= 1
assert np.allclose(usage_vectors, np_usage_vectors)
free_list = np.argsort(np_usage_vectors).astype(int)
np_alloc_weightings = np.zeros([config['batch_size'], config['memory_length']])
for b in range(config['batch_size']):
for j in range(config['memory_length']):
fj = free_list[b, j]
np_alloc_weightings[b, fj] = (1 - np_usage_vectors[b, fj]) * np.prod(
[np_usage_vectors[b, free_list[b, i]] for i in range(j)])
assert alloc_weightings.shape == (config['batch_size'], config['memory_length'])
assert np.allclose(alloc_weightings, np_alloc_weightings)
def test_update_write_weighting(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_alloc_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_write_content_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_write_gate = np.ones([config['batch_size'], 1]) * 0.5
np_alloc_gate = np.ones([config['batch_size'], 1]) * 0.5
alloc_weighting = tf.constant(np_alloc_weighting, dtype=tf.float32)
write_content_weighting = tf.constant(np_write_content_weighting, dtype=tf.float32)
write_gate = tf.constant(np_write_gate, dtype=tf.float32)
alloc_gate = tf.constant(np_alloc_gate, dtype=tf.float32)
write_weighting = memory_unit._update_write_weighting(alloc_weighting, write_content_weighting, write_gate,
alloc_gate)
write_weighting = write_weighting.eval()
np_write_weighting = np_write_gate * (
np_alloc_gate * np_alloc_weighting + (1 - np_alloc_gate) * np_write_content_weighting)
assert write_weighting.shape == (config['batch_size'], config['memory_length'])
assert 0 <= write_weighting.min() and write_weighting.max() <= 1 and (write_weighting.sum(axis=1) <= 1 + 1e-6).all()
assert np.allclose(write_weighting, np_write_weighting)
def test_update_memory(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_write_vector = np_rng.normal(0, 2, [config['batch_size'], 1, config['memory_width']])
np_erase_vector = np_rng.uniform(0, 1, [config['batch_size'], 1, config['memory_width']])
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
write_vector = tf.constant(np_write_vector, dtype=tf.float32)
erase_vector = tf.constant(np_erase_vector, dtype=tf.float32)
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
memory_unit(inputs, pre_states) # just for initialization
memory = memory_unit._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
memory = memory.eval()
write_w = np.expand_dims(np_write_weighting, 2)
np_erase_memory = (1 - (write_w * np_erase_vector))
np_add_memory = np.matmul(write_w, np_write_vector)
np_memory = np_pre_memory * np_erase_memory + np_add_memory
assert memory.shape == (config['batch_size'], config['memory_length'], config['memory_width'])
assert np.allclose(memory, np_memory, atol=1e-06)
def test_update_link_matrix(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
memory_unit(inputs, pre_states) # just for initialization
link_matrix, precedence_weighting = memory_unit._update_link_matrix(pre_link_matrix, write_weighting,
pre_precedence_weighting)
link_matrix, precedence_weighting = session.run([link_matrix, precedence_weighting])
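# Reference (Graves et al. 2016): precedence and temporal link updates:
#   p = (1 - sum(w_write)) * p_prev + w_write
#   L[i,j] = (1 - w_write[i] - w_write[j]) * L_prev[i,j] + w_write[i] * p_prev[j]
# with the diagonal L[i,i] pinned to zero, as in the loop below.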
np_precedence_weighting = (1 - np.sum(np_write_weighting, axis=1,
keepdims=True)) * np_pre_precedence_weighting + np_write_weighting
for b in range(config['batch_size']):
for i in range(config['memory_length']):
for j in range(config['memory_length']):
if i == j:
np_pre_link_matrix[b, i, j] = 0
else:
np_pre_link_matrix[b, i, j] = (1 - np_write_weighting[b, i] - np_write_weighting[b, j]) * \
np_pre_link_matrix[b, i, j] + np_write_weighting[b, i] * \
np_pre_precedence_weighting[b, j]
np_link_matrix = np_pre_link_matrix
assert precedence_weighting.shape == (config['batch_size'], config['memory_length'])
assert 0 <= precedence_weighting.min() and precedence_weighting.max() <= 1 and (
precedence_weighting.sum(axis=1) <= 1 + 1e-6).all()
assert np.allclose(precedence_weighting, np_precedence_weighting)
assert link_matrix.shape == (config['batch_size'], config['memory_length'], config['memory_length'])
assert np.allclose(link_matrix, np_link_matrix)
def test_make_read_forward_backward_weightings(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
link_matrix = tf.constant(np_link_matrix, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
forward_weightings, backward_weightings = memory_unit._make_read_forward_backward_weightings(link_matrix,
pre_read_weightings)
forward_weightings, backward_weightings = session.run([forward_weightings, backward_weightings])
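# Reference (Graves et al. 2016): the link matrix propagates read weightings
# one step forward or backward in write order, f_i = L w_read_i and
# b_i = L^T w_read_i. The link matrix is all zeros here, so this mainly
# exercises shapes and broadcasting.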
np_forward_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
np_backward_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
for b in range(config['batch_size']):
for r in range(config['read_heads']):
np_forward_weightings[b, r, :] = np.matmul(np_link_matrix[b, :, :], np_pre_read_weightings[b, r, :])
np_backward_weightings[b, r, :] = np.matmul(np.transpose(np_link_matrix[b, :, :]),
np_pre_read_weightings[b, r, :])
assert forward_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
assert 0 <= forward_weightings.min() and forward_weightings.max() <= 1 and (
forward_weightings.sum(axis=2) <= 1 + 1e-6).all()
assert np.allclose(forward_weightings, np_forward_weightings)
assert backward_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
assert 0 <= backward_weightings.min() and backward_weightings.max() <= 1 and (
backward_weightings.sum(axis=2) <= 1 + 1e-6).all()
assert np.allclose(backward_weightings, np_backward_weightings)
def test_make_read_weightings(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_forward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
np_backward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
np_read_content_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'],
config['memory_length']])
np_read_modes = np.tile([0.2, 0.3, 0.5], [config['batch_size'], config['read_heads'], 1])  # modes sum to 1 per head
forward_weightings = tf.constant(np_forward_weightings, dtype=tf.float32)
backward_weightings = tf.constant(np_backward_weightings, dtype=tf.float32)
read_content_weightings = tf.constant(np_read_content_weightings, dtype=tf.float32)
read_modes = tf.constant(np_read_modes, dtype=tf.float32)
read_weightings = memory_unit._make_read_weightings(forward_weightings, backward_weightings,
read_content_weightings, read_modes)
read_weightings = read_weightings.eval()
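# Reference (Graves et al. 2016): each read head blends backward, content and
# forward weightings with its read modes pi = (pi_1, pi_2, pi_3):
#   w_read = pi_1 * b + pi_2 * c_read + pi_3 * f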
np_read_weightings = np_backward_weightings * np.expand_dims(np_read_modes[:, :, 0], 2) + \
np_read_content_weightings * np.expand_dims(np_read_modes[:, :, 1], 2) + \
np_forward_weightings * np.expand_dims(np_read_modes[:, :, 2], 2)
assert read_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
assert 0 <= read_weightings.min() and read_weightings.max() <= 1 and (read_weightings.sum(axis=2) <= 1 + 1e-6).all()
assert np.allclose(read_weightings, np_read_weightings)
def test_call(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
read_vectors, states = memory_unit(inputs, pre_states)
session.run(tf.global_variables_initializer())
read_vectors, states = session.run([read_vectors, states])
# test const initialization
np_link_matrix_inv_eye = np.ones([config['memory_length'], config['memory_length']]) - np.eye(
config['memory_length'])
assert np.array_equal(memory_unit.const_link_matrix_inv_eye.eval(), np_link_matrix_inv_eye)
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
assert np.array_equal(memory_unit.const_memory_ones.eval(), np_memory_ones)
np_batch_range = np.arange(0, config['batch_size'])
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
np.expand_dims(np_repeat_memory_length, 0))
assert np.array_equal(memory_unit.const_batch_memory_range.eval(), np_batch_memory_range)
assert read_vectors.shape == (
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])

View File

@ -0,0 +1,194 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
import pytest
from adnc.model.memory_units.multi_write_content_based_cell import MWContentMemoryUnitCell
@pytest.fixture(
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
"write_heads": 3, "dnc_norm": True, "bypass_dropout": False},
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
"write_heads": 2, "dnc_norm": False, "bypass_dropout": False},
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
"write_heads": 5, "dnc_norm": True, "bypass_dropout": True},
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
"write_heads": 9, "dnc_norm": False, "bypass_dropout": True}
])
def memory_config(request):
config = request.param
return MWContentMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
memory_width=config["memory_width"], write_heads=config["write_heads"],
read_heads=config["read_heads"], seed=config["seed"],
reuse=False, name='test_mu'), config
@pytest.fixture()
def session():
with tf.Session() as sess:
yield sess
tf.reset_default_graph()
@pytest.fixture()
def np_rng():
seed = np.random.randint(1, 999)
return np.random.RandomState(seed)
class TestMWContentMemoryUnitCell():
def test_parameter_amount(self, memory_config, session, np_rng):
memory_unit, config = memory_config
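# Expected interface size for this multi-write, content-based cell: R read
# keys (R*W), per write head a key, erase and write vector (3*WH*W), plus R
# read strengths and R free gates, and per write head a write strength,
# allocation gate and write gate (2*R + 3*WH scalars).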
total_signal_size = (
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 2 * config['read_heads'] + 3 *
config["write_heads"])
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.constant(inputs, tf.float32)
memory_unit._weight_input(tf_input)
parameter_amount = memory_unit.parameter_amount
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
def test_create_constant_value_tensors(self, memory_config, session, np_rng):
memory_unit, config = memory_config
memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(
batch_size=config['batch_size'], dtype=tf.float32)
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
assert np.array_equal(memory_ones.eval(), np_memory_ones)
np_batch_range = np.arange(0, config['batch_size'])
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
np.expand_dims(np_repeat_memory_length, 0))
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
def test_zero_state(self, memory_config, session, np_rng):
memory_unit, config = memory_config
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
# test init_tuple
init_memory, init_usage_vector, init_write_weighting, init_read_weighting = init_tuple
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
assert init_write_weighting.eval().shape == (
config['batch_size'], config["write_heads"], config['memory_length'])
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
def test_weight_input(self, memory_config, session, np_rng):
memory_unit, config = memory_config
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
weight_inputs = memory_unit._weight_input(tf_input)
session.run(tf.global_variables_initializer())
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
total_signal_size = (
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 2 * config['read_heads'] + 3 *
config["write_heads"])
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
def test_create_control_signals(self, memory_config, session, np_rng):
memory_unit, config = memory_config
total_signal_size = (
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 2 * config['read_heads'] + 3 *
config["write_heads"])
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
memory_unit.h_B = config['batch_size']
control_signals = memory_unit._create_control_signals(weighted_input)
control_signals = session.run(control_signals)
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
erase_vector, read_keys, read_strengths = control_signals
assert alloc_gates.shape == (config['batch_size'], config['write_heads'], 1)
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
assert 0 <= free_gates.min() and free_gates.max() <= 1
assert write_gates.shape == (config['batch_size'], config['write_heads'], 1)
assert 0 <= write_gates.min() and write_gates.max() <= 1
assert write_keys.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
assert write_strengths.shape == (config['batch_size'], config['write_heads'], 1)
assert 1 <= write_strengths.min()
assert write_vectors.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
assert erase_vector.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
assert 1 <= read_strengths.min()
def test_read_memory(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
memory = tf.constant(np_memory, dtype=tf.float32)
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
read_vectors = memory_unit._read_memory(memory, read_weightings)
read_vectors = read_vectors.eval()
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
for b in range(config['batch_size']):
for r in range(config['read_heads']):
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
def test_call(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
read_vectors, states = memory_unit(inputs, pre_states)
session.run(tf.global_variables_initializer())
read_vectors, states = session.run([read_vectors, states])
assert read_vectors.shape == (
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])

View File

@ -0,0 +1,521 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import tensorflow as tf
import pytest
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
@pytest.fixture(
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
"write_heads": 3, "dnc_norm": True, "bypass_dropout": False},
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
"write_heads": 2, "dnc_norm": False, "bypass_dropout": False},
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
"write_heads": 5, "dnc_norm": True, "bypass_dropout": True},
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
"write_heads": 9, "dnc_norm": False, "bypass_dropout": True}
])
def memory_config(request):
config = request.param
return MWDNCMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
memory_width=config["memory_width"], write_heads=config["write_heads"],
read_heads=config["read_heads"], seed=config["seed"],
reuse=False, name='test_mu'), config
@pytest.fixture()
def session():
with tf.Session() as sess:
yield sess
tf.reset_default_graph()
@pytest.fixture()
def np_rng():
seed = np.random.randint(1, 999)
return np.random.RandomState(seed)
class TestMWDNCMemoryUnit():
def test_init(self, memory_config, session, np_rng):
memory_unit, config = memory_config
assert isinstance(memory_unit, MWDNCMemoryUnitCell)
assert isinstance(memory_unit.rng, np.random.RandomState)
assert memory_unit.h_N == config["memory_length"]
assert memory_unit.h_W == config["memory_width"]
assert memory_unit.h_RH == config["read_heads"]
assert memory_unit.h_WH == config["write_heads"]
def test_parameter_amount(self, memory_config, session, np_rng):
memory_unit, config = memory_config
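# Expected interface size for the multi-write DNC cell: the signals of the
# multi-write content-based cell plus read modes, one content mode and a
# forward/backward pair per write head for every read head
# (R * (1 + 2*WH) extra scalars).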
total_signal_size = (
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 3 * config['read_heads'] + 3 *
config["write_heads"] + 2 * config['read_heads'] * config["write_heads"])
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.constant(inputs, tf.float32)
memory_unit._weight_input(tf_input)
parameter_amount = memory_unit.parameter_amount
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
def test_create_constant_value_tensors(self, memory_config, session, np_rng):
memory_unit, config = memory_config
link_matrix_inv_eye, memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(
batch_size=config['batch_size'], dtype=tf.float32)
np_link_matrix_inv_eye = np.ones([config['memory_length'], config['memory_length']]) - np.eye(
config['memory_length'])
np_link_matrix_inv_eye = np.stack([np_link_matrix_inv_eye, ] * config["write_heads"], axis=0)
np_link_matrix_inv_eye = np.stack([np_link_matrix_inv_eye, ] * config['batch_size'], axis=0)
assert np.array_equal(link_matrix_inv_eye.eval(), np_link_matrix_inv_eye)
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
assert np.array_equal(memory_ones.eval(), np_memory_ones)
np_batch_range = np.arange(0, config['batch_size'])
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
np.expand_dims(np_repeat_memory_length, 0))
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
def test_zero_state(self, memory_config, session, np_rng):
memory_unit, config = memory_config
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
# test init_tuple
init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings, init_link_mat, init_read_weighting = init_tuple
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
assert init_write_weighting.eval().shape == (
config['batch_size'], config["write_heads"], config['memory_length'])
assert init_precedence_weightings.eval().shape == (
config['batch_size'], config['write_heads'], config['memory_length'])
assert init_link_mat.eval().shape == (
config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length'])
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
def test_weight_input(self, memory_config, session, np_rng):
memory_unit, config = memory_config
inputs = np.ones([config['batch_size'], config['input_size']])
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
weight_inputs = memory_unit._weight_input(tf_input)
session.run(tf.global_variables_initializer())
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
total_signal_size = (
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 3 * config['read_heads'] + 3 *
config["write_heads"] + 2 * config['read_heads'] * config["write_heads"])
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
def test_create_control_signals(self, memory_config, session, np_rng):
memory_unit, config = memory_config
total_signal_size = (
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 3 * config['read_heads'] + 3 *
config["write_heads"] + 2 * config['read_heads'] * config["write_heads"])
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
memory_unit.h_B = config['batch_size']
control_signals = memory_unit._create_control_signals(weighted_input)
control_signals = session.run(control_signals)
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
erase_vector, read_keys, read_strengths, read_modes = control_signals
assert alloc_gates.shape == (config['batch_size'], config['write_heads'], 1)
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
assert 0 <= free_gates.min() and free_gates.max() <= 1
assert write_gates.shape == (config['batch_size'], config['write_heads'], 1)
assert 0 <= write_gates.min() and write_gates.max() <= 1
assert write_keys.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
assert write_strengths.shape == (config['batch_size'], config['write_heads'], 1)
assert 1 <= write_strengths.min()
assert write_vectors.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
assert erase_vector.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
assert 1 <= read_strengths.min()
assert read_modes.shape == (
config['batch_size'], config['read_heads'], 1 + 2 * config['write_heads'])  # content mode + forward/backward pair per write head
assert 0 <= read_modes.min() and read_modes.max() <= 1 and np.allclose(read_modes.sum(axis=2), 1)
def test_update_alloc_weightings_and_usage_vectors(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_link_matrix = np.zeros(
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'],
config['memory_length']])
np_pre_write_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'], config['memory_length']])
prw_rand = np.arange(0, config['memory_length']) / config['memory_length']
np_pre_read_weightings = np.stack([prw_rand, ] * config['read_heads'], 0)
np_pre_read_weightings = np.stack([np_pre_read_weightings, ] * config['batch_size'], 0)
np_pre_usage_vectors = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_free_gates = np.ones([config['batch_size'], config['read_heads'], 1]) * 0.5
np_write_gates = np.ones([config['batch_size'], config['write_heads'], 1]) * 0.5
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_write_weightings = tf.constant(np_pre_write_weightings, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_usage_vectors = tf.constant(np_pre_usage_vectors, dtype=tf.float32)
free_gates = tf.constant(np_free_gates, dtype=tf.float32)
write_gates = tf.constant(np_write_gates, dtype=tf.float32)
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vectors, pre_write_weightings, pre_precedence_weighting, pre_link_matrix,
pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
memory_unit(inputs, pre_states)
alloc_weightings, usage_vectors = memory_unit._update_alloc_and_usage_vectors(pre_write_weightings,
pre_read_weightings,
pre_usage_vectors, free_gates,
write_gates)
alloc_weightings, usage_vectors = session.run([alloc_weightings, usage_vectors])
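# Multi-write reference, matching the loops below: the heads' write
# weightings are first combined into 1 - prod_w (1 - w_write_w) before the
# usual usage and retention update; allocation is then computed head by
# head, bumping the usage with each head's gated allocation in turn.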
np_pre_write_weighting = 1 - np.prod(1 - np_pre_write_weightings, axis=1, keepdims=False)
np_usage_vector = np_pre_usage_vectors + np_pre_write_weighting - np_pre_usage_vectors * np_pre_write_weighting
np_retention_vector = np.prod(1 - np_free_gates * np_pre_read_weightings, axis=1, keepdims=False)
np_usage_vector = np_usage_vector * np_retention_vector
assert usage_vectors.shape == (config['batch_size'], config['memory_length'])
assert usage_vectors.min() >= 0 and usage_vectors.max() <= 1
assert np.allclose(usage_vectors, np_usage_vector, atol=1e-06)
np_alloc_weightings = np.zeros([config['batch_size'], config['write_heads'], config['memory_length']])
for b in range(config['batch_size']):
for w in range(config['write_heads']):
free_list = np.argsort(np_usage_vector, axis=1)
for j in range(config['memory_length']):
np_alloc_weightings[b, w, free_list[b, j]] = (1 - np_usage_vector[b, free_list[b, j]]) * np.prod(
[np_usage_vector[b, free_list[b, i]] for i in range(j)])
np_usage_vector[b, :] += (
(1 - np_usage_vector[b, :]) * np_write_gates[b, w, :] * np_alloc_weightings[b, w, :])
assert alloc_weightings.shape == (config['batch_size'], config['write_heads'], config['memory_length'])
assert np.allclose(alloc_weightings, np_alloc_weightings, atol=1e-06)
def test_calculate_content_weightings(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_memory = np_rng.uniform(0, 1, (config['batch_size'], config['memory_length'], config['memory_width']))
np_keys = np_rng.normal(0, 2, (config['batch_size'], config['read_heads'], config['memory_width']))
np_strengths = np_rng.uniform(1, 10, (config['batch_size'], config['read_heads'], 1))
memory = tf.constant(np_memory, dtype=tf.float32)
keys = tf.constant(np_keys, dtype=tf.float32)
strengths = tf.constant(np_strengths, dtype=tf.float32)
content_weightings = memory_unit._calculate_content_weightings(memory, keys, strengths)
weightings = content_weightings.eval()
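# Reference (Graves et al. 2016): content weighting is a strength-sharpened
# softmax over cosine similarities between each key and the memory rows:
#   C(M, k, beta)[i] = softmax_i(beta * cos(k, M[i, :]))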
np_similarity = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
for b in range(config['batch_size']):
for r in range(config['read_heads']):
for l in range(config['memory_length']):
np_similarity[b, r, l] = np.dot(np_memory[b, l, :], np_keys[b, r, :]) / (
np.sqrt(np.dot(np_memory[b, l, :], np_memory[b, l, :])) * np.sqrt(
np.dot(np_keys[b, r, :], np_keys[b, r, :])))
np_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
def _weighted_softmax(x, s):
e_x = np.exp(x * s)
return e_x / e_x.sum(axis=1, keepdims=True)
for r in range(config['read_heads']):
np_weightings[:, r, :] = _weighted_softmax(np_similarity[:, r, :], np_strengths[:, r])
assert weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
assert 0 <= weightings.min() and weightings.max() <= 1 and np.allclose(weightings.sum(axis=2), 1)
assert np.allclose(weightings, np_weightings)
def test_update_write_weightings(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_alloc_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'], config['memory_length']])
np_write_content_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'],
config['memory_length']])
np_write_gate = np.ones([config['batch_size'], config['write_heads'], 1]) * 0.5
np_alloc_gate = np.ones([config['batch_size'], config['write_heads'], 1]) * 0.5
alloc_weightings = tf.constant(np_alloc_weightings, dtype=tf.float32)
write_content_weightings = tf.constant(np_write_content_weighting, dtype=tf.float32)
write_gates = tf.constant(np_write_gate, dtype=tf.float32)
alloc_gates = tf.constant(np_alloc_gate, dtype=tf.float32)
write_weightings = memory_unit._update_write_weightings(alloc_weightings, write_content_weightings, write_gates,
alloc_gates)
write_weightings = write_weightings.eval()
np_write_weightings = np_write_gate * (
np_alloc_gate * np_alloc_weightings + (1 - np_alloc_gate) * np_write_content_weighting)
assert write_weightings.shape == (config['batch_size'], config['write_heads'], config['memory_length'])
assert 0 <= write_weightings.min() and write_weightings.max() <= 1 and (write_weightings.sum(axis=2) <= 1 + 1e-6).all()
assert np.allclose(write_weightings, np_write_weightings)
def test_update_memory(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'], config['memory_length']])
np_write_vector = np_rng.normal(0, 2, [config['batch_size'], config['write_heads'], config['memory_width']])
np_erase_vector = np_rng.uniform(0, 1, [config['batch_size'], config['write_heads'], config['memory_width']])
pre_memory = tf.constant(np_memory, dtype=tf.float32)
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
write_vector = tf.constant(np_write_vector, dtype=tf.float32)
erase_vector = tf.constant(np_erase_vector, dtype=tf.float32)
memory_unit.zero_state(config['batch_size'])
memory = memory_unit._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
memory = memory.eval()
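# Multi-write memory update, matching the numpy reference below: the erase
# terms of all write heads are multiplied together and their write
# contributions summed, M = M_prev * prod_w (1 - w_w (x) e_w) + sum_w w_w (x) v_w.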
np_erase_memory = (1 - np.expand_dims(np_write_weighting, 3) * np.expand_dims(np_erase_vector, 2))
np_erase_memory = np.prod(np_erase_memory, axis=1, keepdims=False)
np_add_memory = np.matmul(np.transpose(np_write_weighting, (0, 2, 1)), np_write_vector)
np_memory = np_memory * np_erase_memory + np_add_memory
assert memory.shape == (config['batch_size'], config['memory_length'], config['memory_width'])
assert np.allclose(memory, np_memory, atol=1e-06)
def test_update_link_matrix(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
np_pre_link_matrix = np.zeros(
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'], config['memory_length']])
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'],
config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
pre_read_weightings)
        memory_unit.zero_state(config['batch_size'])
        memory_unit(inputs, pre_states)  # build the unit once so the private method can be tested
link_matrix, precedence_weighting = memory_unit._update_link_matrix(pre_link_matrix, write_weighting,
pre_precedence_weighting)
link_matrix, precedence_weighting = session.run([link_matrix, precedence_weighting])
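        # NumPy reference: p = (1 - sum_i w[i]) * p_prev + w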
np_precedence_weighting = (1 - np.sum(np_write_weighting, axis=2,
keepdims=True)) * np_pre_precedence_weighting + np_write_weighting
for b in range(config['batch_size']):
for w in range(config['write_heads']):
for i in range(config['memory_length']):
for j in range(config['memory_length']):
if i == j:
np_pre_link_matrix[b, w, i, j] = 0
else:
                            np_pre_link_matrix[b, w, i, j] = (
                                (1 - np_write_weighting[b, w, i] - np_write_weighting[b, w, j])
                                * np_pre_link_matrix[b, w, i, j]
                                + np_write_weighting[b, w, i] * np_pre_precedence_weighting[b, w, j])
        np_link_matrix = np_pre_link_matrix  # the loop updated np_pre_link_matrix in place
assert precedence_weighting.shape == (config['batch_size'], config['write_heads'], config['memory_length'])
        assert 0 <= precedence_weighting.min() and precedence_weighting.max() <= 1
        assert (precedence_weighting.sum(axis=2) <= 1 + 1e-6).all()
assert np.allclose(precedence_weighting, np_precedence_weighting)
assert link_matrix.shape == (
config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length'])
assert np.allclose(link_matrix, np_link_matrix)
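
    # _make_read_forward_backward_weightings should follow the temporal links:
    # per write head, forward = w_r . L and backward = w_r . L^T for each
    # previous read weighting w_r.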
def test_make_read_forward_backward_weightings(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_link_matrix = np.zeros(
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
link_matrix = tf.constant(np_link_matrix, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
forward_weightings, backward_weightings = memory_unit._make_read_forward_backward_weightings(link_matrix,
pre_read_weightings)
forward_weightings, backward_weightings = session.run([forward_weightings, backward_weightings])
np_forward_weightings = np.empty(
[config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length']])
np_backward_weightings = np.empty(
[config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length']])
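        # NumPy reference: one vector-matrix product per (read head, write head) pair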
for b in range(config['batch_size']):
for r in range(config['read_heads']):
for w in range(config['write_heads']):
np_forward_weightings[b, r, w, :] = np.matmul(np_pre_read_weightings[b, r, :],
np_link_matrix[b, w, :, :])
np_backward_weightings[b, r, w, :] = np.matmul(np_pre_read_weightings[b, r, :],
np.transpose(np_link_matrix[b, w, :, :]))
assert forward_weightings.shape == (
config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length'])
        assert 0 <= forward_weightings.min() and forward_weightings.max() <= 1
        assert (forward_weightings.sum(axis=3) <= 1 + 1e-6).all()
assert np.allclose(forward_weightings, np_forward_weightings)
assert backward_weightings.shape == (
config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length'])
        assert 0 <= backward_weightings.min() and backward_weightings.max() <= 1
        assert (backward_weightings.sum(axis=3) <= 1 + 1e-6).all()
assert np.allclose(backward_weightings, np_backward_weightings)
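
    # _make_read_weightings should mix the backward, content and forward
    # weightings with the read modes (first W modes: backward, mode W:
    # content lookup, last W modes: forward).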
def test_make_read_weightings(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_forward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['write_heads'],
config['memory_length']])
np_backward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['write_heads'],
config['memory_length']])
np_read_content_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'],
config['memory_length']])
        np_read_modes = np.full(
            [config['batch_size'], config['read_heads'], 1 + 2 * config['write_heads']], 0.1)
forward_weightings = tf.constant(np_forward_weightings, dtype=tf.float32)
backward_weightings = tf.constant(np_backward_weightings, dtype=tf.float32)
read_content_weightings = tf.constant(np_read_content_weightings, dtype=tf.float32)
read_modes = tf.constant(np_read_modes, dtype=tf.float32)
read_weightings = memory_unit._make_read_weightings(forward_weightings, backward_weightings,
read_content_weightings, read_modes)
read_weightings = read_weightings.eval()
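        # NumPy reference: mode-weighted sum of backward, content and forward weightings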
        np_read_weightings = (
            np.sum(np_backward_weightings * np.expand_dims(np_read_modes[:, :, :config['write_heads']], 3), axis=2)
            + np_read_content_weightings * np.expand_dims(np_read_modes[:, :, config['write_heads']], 2)
            + np.sum(np_forward_weightings * np.expand_dims(np_read_modes[:, :, config['write_heads'] + 1:], 3), axis=2))
assert read_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
        assert 0 <= read_weightings.min() and read_weightings.max() <= 1
        assert (read_weightings.sum(axis=2) <= 1 + 1e-6).all()
assert np.allclose(read_weightings, np_read_weightings)
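
    # _read_memory should return r = w_r . M for every read head.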
def test_read_memory(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
memory = tf.constant(np_memory, dtype=tf.float32)
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
read_vectors = memory_unit._read_memory(memory, read_weightings)
read_vectors = read_vectors.eval()
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
for b in range(config['batch_size']):
for r in range(config['read_heads']):
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
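
    # A full forward pass through the cell: it should return the read vectors
    # concatenated with the (bypass) input, plus the updated state tuple.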
def test_call(self, memory_config, session, np_rng):
memory_unit, config = memory_config
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['memory_length']])
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'], config['memory_length']])
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['write_heads'],
config['memory_length']])
np_pre_link_matrix = np.zeros(
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
[config['batch_size'], config['read_heads'], config['memory_length']])
inputs = tf.constant(np_inputs, dtype=tf.float32)
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
pre_read_weightings)
memory_unit.zero_state(config['batch_size'])
read_vectors, states = memory_unit(inputs, pre_states)
session.run(tf.global_variables_initializer())
read_vectors, states = session.run([read_vectors, states])
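        # output width = read_heads * memory_width + input_size (read vectors + bypass input)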
assert read_vectors.shape == (
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])