mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 22:08:04 +08:00
add memory units and test
This commit is contained in:
parent
11637635f0
commit
0fe3939a17
14
adnc/model/memory_units/__init__.py
Normal file
14
adnc/model/memory_units/__init__.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
90
adnc/model/memory_units/base_cell.py
Normal file
90
adnc/model/memory_units/base_cell.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
from abc import abstractmethod, ABCMeta
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMemoryUnitCell():
|
||||||
|
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
|
||||||
|
seed=100, reuse=False, analyse=False, dtype=tf.float32, name='base'):
|
||||||
|
|
||||||
|
self.rng = np.random.RandomState(seed=seed)
|
||||||
|
self.seed = seed
|
||||||
|
self.dtype = dtype
|
||||||
|
self.analyse = analyse
|
||||||
|
|
||||||
|
# dnc parameters
|
||||||
|
self.input_size = input_size
|
||||||
|
self.h_N = memory_length
|
||||||
|
self.h_W = memory_width
|
||||||
|
self.h_RH = read_heads
|
||||||
|
|
||||||
|
self.dnc_norm = dnc_norm
|
||||||
|
self.bypass_dropout = bypass_dropout
|
||||||
|
|
||||||
|
self.reuse = reuse
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def state_size(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def zero_state(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_size(self):
|
||||||
|
return self.h_RH * self.h_W + self.input_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trainable_variables(self):
|
||||||
|
return tf.get_collection('memory_unit')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameter_amount(self):
|
||||||
|
var_list = self.trainable_variables
|
||||||
|
parameters = 0
|
||||||
|
for variable in var_list:
|
||||||
|
shape = variable.get_shape()
|
||||||
|
variable_parametes = 1
|
||||||
|
for dim in shape:
|
||||||
|
variable_parametes *= dim.value
|
||||||
|
parameters += variable_parametes
|
||||||
|
return parameters
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _calculate_content_weightings(memory, keys, strengths):
|
||||||
|
|
||||||
|
similarity_numerator = tf.matmul(keys, memory, adjoint_b=True)
|
||||||
|
|
||||||
|
norm_memory = tf.sqrt(tf.reduce_sum(tf.square(memory), axis=2, keepdims=True))
|
||||||
|
norm_keys = tf.sqrt(tf.reduce_sum(tf.square(keys), axis=2, keepdims=True))
|
||||||
|
similarity_denominator = tf.matmul(norm_keys, norm_memory, adjoint_b=True)
|
||||||
|
|
||||||
|
similarity = similarity_numerator / similarity_denominator
|
||||||
|
similarity = tf.squeeze(similarity)
|
||||||
|
adjusted_similarity = similarity * strengths
|
||||||
|
|
||||||
|
softmax_similarity = tf.nn.softmax(adjusted_similarity, dim=-1)
|
||||||
|
|
||||||
|
return softmax_similarity
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _read_memory(memory, read_weightings):
|
||||||
|
read_vectors = tf.matmul(read_weightings, memory)
|
||||||
|
return read_vectors
|
158
adnc/model/memory_units/content_based_cell.py
Executable file
158
adnc/model/memory_units/content_based_cell.py
Executable file
@ -0,0 +1,158 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
|
||||||
|
from adnc.model.utils import oneplus
|
||||||
|
from adnc.model.utils import unit_simplex_initialization
|
||||||
|
|
||||||
|
|
||||||
|
class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
init_memory = tf.TensorShape([self.h_N, self.h_W])
|
||||||
|
init_usage_vector = tf.TensorShape([self.h_N])
|
||||||
|
init_write_weighting = tf.TensorShape([self.h_N])
|
||||||
|
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
|
||||||
|
return (init_memory, init_usage_vector, init_write_weighting, init_read_weighting)
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
|
||||||
|
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
|
||||||
|
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_N], dtype=dtype)
|
||||||
|
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
|
||||||
|
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_read_weighting,)
|
||||||
|
|
||||||
|
return zero_states
|
||||||
|
|
||||||
|
def analyse_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
alloc_gate = tf.zeros([batch_size, 1], dtype=dtype)
|
||||||
|
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
write_gate = tf.zeros([batch_size, 1], dtype=dtype)
|
||||||
|
write_keys = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
|
||||||
|
write_strengths = tf.zeros([batch_size, 1], dtype=dtype)
|
||||||
|
write_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
|
||||||
|
erase_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
|
||||||
|
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
|
||||||
|
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
|
||||||
|
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths
|
||||||
|
|
||||||
|
return analyse_states
|
||||||
|
|
||||||
|
def _weight_input(self, inputs):
|
||||||
|
|
||||||
|
input_size = inputs.get_shape()[1].value
|
||||||
|
total_signal_size = (3 + self.h_RH) * self.h_W + 2 * self.h_RH + 3
|
||||||
|
|
||||||
|
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
|
||||||
|
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
|
||||||
|
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
|
||||||
|
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
|
||||||
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
|
|
||||||
|
if self.dnc_norm:
|
||||||
|
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype)
|
||||||
|
|
||||||
|
return weighted_input
|
||||||
|
|
||||||
|
def __call__(self, inputs, pre_states, scope=None):
|
||||||
|
|
||||||
|
self.h_B = inputs.get_shape()[0].value
|
||||||
|
|
||||||
|
memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
|
||||||
|
self.const_memory_ones = memory_ones
|
||||||
|
self.const_batch_memory_range = batch_memory_range
|
||||||
|
|
||||||
|
pre_memory, pre_usage_vector, pre_write_weightings, pre_read_weightings = pre_states
|
||||||
|
|
||||||
|
weighted_input = self._weight_input(inputs)
|
||||||
|
|
||||||
|
control_signals = self._create_control_signals(weighted_input)
|
||||||
|
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths = control_signals
|
||||||
|
|
||||||
|
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
|
||||||
|
pre_usage_vector, free_gates)
|
||||||
|
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
|
||||||
|
write_weighting = self._update_write_weighting(alloc_weightings, write_content_weighting, write_gate,
|
||||||
|
alloc_gate)
|
||||||
|
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
|
||||||
|
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
|
||||||
|
read_vectors = self._read_memory(memory, read_content_weightings)
|
||||||
|
|
||||||
|
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
|
||||||
|
|
||||||
|
if self.bypass_dropout:
|
||||||
|
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
|
||||||
|
else:
|
||||||
|
input_bypass = inputs
|
||||||
|
|
||||||
|
output = tf.concat([read_vectors, input_bypass], axis=-1)
|
||||||
|
|
||||||
|
if self.analyse:
|
||||||
|
output = (output, control_signals)
|
||||||
|
|
||||||
|
return output, (memory, usage_vector, write_weighting, read_content_weightings)
|
||||||
|
|
||||||
|
def _create_constant_value_tensors(self, batch_size, dtype):
|
||||||
|
|
||||||
|
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
|
||||||
|
|
||||||
|
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
|
||||||
|
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
|
||||||
|
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
|
||||||
|
name="batch_memory_range")
|
||||||
|
return memory_ones, batch_memory_range
|
||||||
|
|
||||||
|
def _create_control_signals(self, weighted_input):
|
||||||
|
|
||||||
|
write_keys = weighted_input[:, : self.h_W] # W
|
||||||
|
write_strengths = weighted_input[:, self.h_W: self.h_W + 1] # 1
|
||||||
|
erase_vector = weighted_input[:, self.h_W + 1: 2 * self.h_W + 1] # W
|
||||||
|
write_vector = weighted_input[:, 2 * self.h_W + 1: 3 * self.h_W + 1] # W
|
||||||
|
alloc_gates = weighted_input[:, 3 * self.h_W + 1: 3 * self.h_W + 2] # 1
|
||||||
|
write_gates = weighted_input[:, 3 * self.h_W + 2: 3 * self.h_W + 3] # 1
|
||||||
|
read_keys = weighted_input[:, 3 * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3] # R * W
|
||||||
|
read_strengths = weighted_input[:,
|
||||||
|
(self.h_RH + 3) * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH] # R
|
||||||
|
free_gates = weighted_input[:, (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH: (
|
||||||
|
self.h_RH + 3) * self.h_W + 3 + 2 * self.h_RH]
|
||||||
|
|
||||||
|
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
|
||||||
|
free_gates = tf.sigmoid(free_gates, 'free_gates')
|
||||||
|
free_gates = tf.expand_dims(free_gates, 2)
|
||||||
|
write_gates = tf.sigmoid(write_gates, 'write_gates')
|
||||||
|
|
||||||
|
write_keys = tf.expand_dims(write_keys, axis=1)
|
||||||
|
write_strengths = oneplus(write_strengths)
|
||||||
|
write_vector = tf.reshape(write_vector, [self.h_B, 1, self.h_W])
|
||||||
|
erase_vector = tf.sigmoid(erase_vector, 'erase_vector')
|
||||||
|
erase_vector = tf.reshape(erase_vector, [self.h_B, 1, self.h_W])
|
||||||
|
|
||||||
|
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
|
||||||
|
read_strengths = oneplus(read_strengths)
|
||||||
|
read_strengths = tf.expand_dims(read_strengths, axis=2)
|
||||||
|
|
||||||
|
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths
|
255
adnc/model/memory_units/dnc_cell.py
Executable file
255
adnc/model/memory_units/dnc_cell.py
Executable file
@ -0,0 +1,255 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
|
||||||
|
from adnc.model.utils import layer_norm
|
||||||
|
from adnc.model.utils import oneplus
|
||||||
|
from adnc.model.utils import unit_simplex_initialization
|
||||||
|
|
||||||
|
|
||||||
|
class DNCMemoryUnitCell(BaseMemoryUnitCell):
|
||||||
|
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
|
||||||
|
seed=100, reuse=False, analyse=False, dtype=tf.float32, name='dnc_mu'):
|
||||||
|
|
||||||
|
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
|
||||||
|
analyse, dtype, name)
|
||||||
|
self.h_B = 0 # will set in call
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
init_memory = tf.TensorShape([self.h_N, self.h_W])
|
||||||
|
init_usage_vector = tf.TensorShape([self.h_N])
|
||||||
|
init_write_weighting = tf.TensorShape([self.h_N])
|
||||||
|
init_precedence_weightings = tf.TensorShape([self.h_N])
|
||||||
|
init_link_mat = tf.TensorShape([self.h_N, self.h_N])
|
||||||
|
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
|
||||||
|
return (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
|
||||||
|
init_link_mat, init_read_weighting)
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
|
||||||
|
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
|
||||||
|
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_N], dtype=dtype)
|
||||||
|
init_precedence_weightings = tf.zeros([batch_size, self.h_N], dtype=dtype)
|
||||||
|
|
||||||
|
init_link_mat = tf.zeros([batch_size, self.h_N, self.h_N], dtype=dtype)
|
||||||
|
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
|
||||||
|
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
|
||||||
|
init_link_mat, init_read_weighting,)
|
||||||
|
|
||||||
|
return zero_states
|
||||||
|
|
||||||
|
def analyse_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
alloc_gate = tf.zeros([batch_size, 1], dtype=dtype)
|
||||||
|
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
write_gate = tf.zeros([batch_size, 1], dtype=dtype)
|
||||||
|
write_keys = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
|
||||||
|
write_strengths = tf.zeros([batch_size, 1], dtype=dtype)
|
||||||
|
write_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
|
||||||
|
erase_vector = tf.zeros([batch_size, 1, self.h_W], dtype=dtype)
|
||||||
|
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
|
||||||
|
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
read_modes = tf.zeros([batch_size, self.h_RH, 3], dtype=dtype)
|
||||||
|
|
||||||
|
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths, read_modes
|
||||||
|
|
||||||
|
return analyse_states
|
||||||
|
|
||||||
|
def __call__(self, inputs, pre_states, scope=None):
|
||||||
|
|
||||||
|
self.h_B = inputs.get_shape()[0].value
|
||||||
|
|
||||||
|
link_matrix_inv_eye, memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
|
||||||
|
self.const_link_matrix_inv_eye = link_matrix_inv_eye
|
||||||
|
self.const_memory_ones = memory_ones
|
||||||
|
self.const_batch_memory_range = batch_memory_range
|
||||||
|
|
||||||
|
pre_memory, pre_usage_vector, pre_write_weightings, pre_precedence_weighting, pre_link_matrix, pre_read_weightings = pre_states
|
||||||
|
|
||||||
|
weighted_input = self._weight_input(inputs)
|
||||||
|
|
||||||
|
control_signals = self._create_control_signals(weighted_input)
|
||||||
|
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths, read_modes = control_signals
|
||||||
|
|
||||||
|
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
|
||||||
|
pre_usage_vector, free_gates)
|
||||||
|
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
|
||||||
|
write_weighting = self._update_write_weighting(alloc_weightings, write_content_weighting, write_gate,
|
||||||
|
alloc_gate)
|
||||||
|
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
|
||||||
|
link_matrix, precedence_weighting = self._update_link_matrix(pre_link_matrix, write_weighting,
|
||||||
|
pre_precedence_weighting)
|
||||||
|
forward_weightings, backward_weightings = self._make_read_forward_backward_weightings(link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
|
||||||
|
read_weightings = self._make_read_weightings(forward_weightings, backward_weightings, read_content_weightings,
|
||||||
|
read_modes)
|
||||||
|
read_vectors = self._read_memory(memory, read_weightings)
|
||||||
|
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
|
||||||
|
|
||||||
|
if self.bypass_dropout:
|
||||||
|
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
|
||||||
|
else:
|
||||||
|
input_bypass = inputs
|
||||||
|
|
||||||
|
output = tf.concat([read_vectors, input_bypass], axis=-1)
|
||||||
|
|
||||||
|
if self.analyse:
|
||||||
|
output = (output, control_signals)
|
||||||
|
|
||||||
|
return output, (memory, usage_vector, write_weighting, precedence_weighting, link_matrix, read_weightings)
|
||||||
|
|
||||||
|
def _create_constant_value_tensors(self, batch_size, dtype):
|
||||||
|
|
||||||
|
link_matrix_inv_eye = 1 - tf.constant(np.identity(self.h_N), dtype=dtype, name="link_matrix_inv_eye")
|
||||||
|
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
|
||||||
|
|
||||||
|
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
|
||||||
|
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
|
||||||
|
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
|
||||||
|
name="batch_memory_range")
|
||||||
|
return link_matrix_inv_eye, memory_ones, batch_memory_range
|
||||||
|
|
||||||
|
def _weight_input(self, inputs):
|
||||||
|
|
||||||
|
input_size = inputs.get_shape()[1].value
|
||||||
|
total_signal_size = (3 + self.h_RH) * self.h_W + 5 * self.h_RH + 3
|
||||||
|
|
||||||
|
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
|
||||||
|
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
|
||||||
|
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
|
||||||
|
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
|
||||||
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
|
|
||||||
|
if self.dnc_norm:
|
||||||
|
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
|
||||||
|
collection='memory_unit')
|
||||||
|
return weighted_input
|
||||||
|
|
||||||
|
def _create_control_signals(self, weighted_input):
|
||||||
|
|
||||||
|
write_keys = weighted_input[:, : self.h_W] # W
|
||||||
|
write_strengths = weighted_input[:, self.h_W: self.h_W + 1] # 1
|
||||||
|
erase_vector = weighted_input[:, self.h_W + 1: 2 * self.h_W + 1] # W
|
||||||
|
write_vector = weighted_input[:, 2 * self.h_W + 1: 3 * self.h_W + 1] # W
|
||||||
|
alloc_gates = weighted_input[:, 3 * self.h_W + 1: 3 * self.h_W + 2] # 1
|
||||||
|
write_gates = weighted_input[:, 3 * self.h_W + 2: 3 * self.h_W + 3] # 1
|
||||||
|
read_keys = weighted_input[:, 3 * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3] # R * W
|
||||||
|
read_strengths = weighted_input[:,
|
||||||
|
(self.h_RH + 3) * self.h_W + 3: (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH] # R
|
||||||
|
read_modes = weighted_input[:, (self.h_RH + 3) * self.h_W + 3 + 1 * self.h_RH: (
|
||||||
|
self.h_RH + 3) * self.h_W + 3 + 4 * self.h_RH] # 3R
|
||||||
|
free_gates = weighted_input[:, (self.h_RH + 3) * self.h_W + 3 + 4 * self.h_RH: (
|
||||||
|
self.h_RH + 3) * self.h_W + 3 + 5 * self.h_RH] # R
|
||||||
|
|
||||||
|
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
|
||||||
|
free_gates = tf.sigmoid(free_gates, 'free_gates')
|
||||||
|
free_gates = tf.expand_dims(free_gates, 2)
|
||||||
|
write_gates = tf.sigmoid(write_gates, 'write_gates')
|
||||||
|
|
||||||
|
write_keys = tf.expand_dims(write_keys, axis=1)
|
||||||
|
write_strengths = oneplus(write_strengths)
|
||||||
|
# write_strengths = tf.expand_dims(write_strengths, axis=2)
|
||||||
|
write_vector = tf.reshape(write_vector, [self.h_B, 1, self.h_W])
|
||||||
|
erase_vector = tf.sigmoid(erase_vector, 'erase_vector')
|
||||||
|
erase_vector = tf.reshape(erase_vector, [self.h_B, 1, self.h_W])
|
||||||
|
|
||||||
|
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
|
||||||
|
read_strengths = oneplus(read_strengths)
|
||||||
|
read_strengths = tf.expand_dims(read_strengths, axis=2)
|
||||||
|
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 3]) # 3 read modes
|
||||||
|
read_modes = tf.nn.softmax(read_modes, dim=2)
|
||||||
|
|
||||||
|
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths, read_modes
|
||||||
|
|
||||||
|
def _update_alloc_and_usage_vectors(self, pre_write_weightings, pre_read_weightings, pre_usage_vector, free_gates):
|
||||||
|
|
||||||
|
retention_vector = tf.reduce_prod(1 - free_gates * pre_read_weightings, axis=1, keepdims=False,
|
||||||
|
name='retention_prod')
|
||||||
|
usage_vector = (
|
||||||
|
pre_usage_vector + pre_write_weightings - pre_usage_vector * pre_write_weightings) * retention_vector
|
||||||
|
|
||||||
|
sorted_usage, free_list = tf.nn.top_k(-1 * usage_vector, self.h_N)
|
||||||
|
sorted_usage = -1 * sorted_usage
|
||||||
|
|
||||||
|
cumprod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True)
|
||||||
|
corrected_free_list = free_list + self.const_batch_memory_range
|
||||||
|
|
||||||
|
cumprod_sorted_usage_re = [tf.reshape(cumprod_sorted_usage, [-1, ]), ]
|
||||||
|
corrected_free_list_re = [tf.reshape(corrected_free_list, [-1]), ]
|
||||||
|
|
||||||
|
stitched_usage = tf.dynamic_stitch(corrected_free_list_re, cumprod_sorted_usage_re, name=None)
|
||||||
|
|
||||||
|
stitched_usage = tf.reshape(stitched_usage, [self.h_B, self.h_N])
|
||||||
|
|
||||||
|
alloc_weighting = (1 - usage_vector) * stitched_usage
|
||||||
|
|
||||||
|
return alloc_weighting, usage_vector
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _update_write_weighting(alloc_weighting, write_content_weighting, write_gate, alloc_gate):
|
||||||
|
|
||||||
|
write_weighting = write_gate * (alloc_gate * alloc_weighting + (1 - alloc_gate) * write_content_weighting)
|
||||||
|
|
||||||
|
return write_weighting
|
||||||
|
|
||||||
|
def _update_memory(self, pre_memory, write_weighting, write_vector, erase_vector):
|
||||||
|
|
||||||
|
write_w = tf.expand_dims(write_weighting, 2)
|
||||||
|
erase_matrix = tf.multiply(pre_memory, (self.const_memory_ones - tf.matmul(write_w, erase_vector)))
|
||||||
|
write_matrix = tf.matmul(write_w, write_vector)
|
||||||
|
return erase_matrix + write_matrix
|
||||||
|
|
||||||
|
def _update_link_matrix(self, pre_link_matrix, write_weighting, pre_precedence_weighting):
|
||||||
|
|
||||||
|
precedence_weighting = (1 - tf.reduce_sum(write_weighting, 1,
|
||||||
|
keepdims=True)) * pre_precedence_weighting + write_weighting
|
||||||
|
|
||||||
|
add_mat = tf.matmul(tf.expand_dims(write_weighting, axis=2),
|
||||||
|
tf.expand_dims(pre_precedence_weighting, axis=1))
|
||||||
|
erase_mat = 1 - tf.expand_dims(write_weighting, 1) - tf.expand_dims(write_weighting, 2)
|
||||||
|
|
||||||
|
updated_link_mat = erase_mat * pre_link_matrix + add_mat
|
||||||
|
link_matrix = self.const_link_matrix_inv_eye * updated_link_mat
|
||||||
|
|
||||||
|
return link_matrix, precedence_weighting
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_read_forward_backward_weightings(link_matrix, pre_read_weightings):
|
||||||
|
|
||||||
|
forward_weightings = tf.matmul(pre_read_weightings, link_matrix)
|
||||||
|
backward_weightings = tf.matmul(pre_read_weightings, link_matrix, adjoint_b=True)
|
||||||
|
|
||||||
|
return forward_weightings, backward_weightings
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_read_weightings(forward_weightings, backward_weightings, read_content_weightings, read_modes):
|
||||||
|
|
||||||
|
read_weighting = tf.expand_dims(read_modes[:, :, 0], 2) * backward_weightings + \
|
||||||
|
tf.expand_dims(read_modes[:, :, 1], 2) * read_content_weightings + \
|
||||||
|
tf.expand_dims(read_modes[:, :, 2], 2) * forward_weightings
|
||||||
|
|
||||||
|
return read_weighting
|
160
adnc/model/memory_units/multi_write_content_based_cell.py
Normal file
160
adnc/model/memory_units/multi_write_content_based_cell.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
|
||||||
|
from adnc.model.utils import layer_norm
|
||||||
|
from adnc.model.utils import oneplus
|
||||||
|
from adnc.model.utils import unit_simplex_initialization
|
||||||
|
|
||||||
|
|
||||||
|
class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
init_memory = tf.TensorShape([self.h_N, self.h_W])
|
||||||
|
init_usage_vector = tf.TensorShape([self.h_N])
|
||||||
|
init_write_weighting = tf.TensorShape([self.h_WH, self.h_N])
|
||||||
|
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
|
||||||
|
return (init_memory, init_usage_vector, init_write_weighting, init_read_weighting)
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
|
||||||
|
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
|
||||||
|
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_WH, self.h_N], dtype=dtype)
|
||||||
|
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
|
||||||
|
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_read_weighting,)
|
||||||
|
return zero_states
|
||||||
|
|
||||||
|
def analyse_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
alloc_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype) # WH
|
||||||
|
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
write_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
|
||||||
|
write_keys = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
|
||||||
|
write_strengths = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
|
||||||
|
write_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
|
||||||
|
erase_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
|
||||||
|
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
|
||||||
|
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
|
||||||
|
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths
|
||||||
|
|
||||||
|
return analyse_states
|
||||||
|
|
||||||
|
def __call__(self, inputs, pre_states, scope=None):
|
||||||
|
|
||||||
|
self.h_B = inputs.get_shape()[0].value
|
||||||
|
|
||||||
|
memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
|
||||||
|
self.const_memory_ones = memory_ones
|
||||||
|
self.const_batch_memory_range = batch_memory_range
|
||||||
|
|
||||||
|
pre_memory, pre_usage_vector, pre_write_weightings, pre_read_weightings = pre_states
|
||||||
|
|
||||||
|
weighted_input = self._weight_input(inputs)
|
||||||
|
|
||||||
|
control_signals = self._create_control_signals(weighted_input)
|
||||||
|
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths = control_signals
|
||||||
|
|
||||||
|
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
|
||||||
|
pre_usage_vector, free_gates, write_gate)
|
||||||
|
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
|
||||||
|
write_weighting = self._update_write_weightings(alloc_weightings, write_content_weighting, write_gate,
|
||||||
|
alloc_gate)
|
||||||
|
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
|
||||||
|
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
|
||||||
|
read_vectors = self._read_memory(memory, read_content_weightings)
|
||||||
|
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
|
||||||
|
|
||||||
|
if self.bypass_dropout:
|
||||||
|
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
|
||||||
|
else:
|
||||||
|
input_bypass = inputs
|
||||||
|
|
||||||
|
output = tf.concat([read_vectors, input_bypass], axis=-1)
|
||||||
|
|
||||||
|
if self.analyse:
|
||||||
|
output = (output, control_signals)
|
||||||
|
|
||||||
|
return output, (memory, usage_vector, write_weighting, read_content_weightings)
|
||||||
|
|
||||||
|
def _create_constant_value_tensors(self, batch_size, dtype):
|
||||||
|
|
||||||
|
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
|
||||||
|
|
||||||
|
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
|
||||||
|
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
|
||||||
|
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
|
||||||
|
name="batch_memory_range")
|
||||||
|
return memory_ones, batch_memory_range
|
||||||
|
|
||||||
|
def _weight_input(self, inputs):
|
||||||
|
|
||||||
|
input_size = inputs.get_shape()[1].value
|
||||||
|
total_signal_size = self.h_RH * (2 + self.h_W) + self.h_WH * (3 + 3 * self.h_W)
|
||||||
|
|
||||||
|
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
|
||||||
|
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
|
||||||
|
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
|
||||||
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
|
if self.dnc_norm:
|
||||||
|
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
|
||||||
|
collection='memory_unit')
|
||||||
|
return weighted_input
|
||||||
|
|
||||||
|
def _create_control_signals(self, weighted_input):
|
||||||
|
|
||||||
|
alloc_gates = weighted_input[:, : self.h_WH]
|
||||||
|
free_gates = weighted_input[:, self.h_WH: self.h_WH + self.h_RH]
|
||||||
|
write_gates = weighted_input[:, self.h_WH + self.h_RH: 2 * self.h_WH + self.h_RH]
|
||||||
|
write_keys = weighted_input[:, 2 * self.h_WH + self.h_RH: (self.h_W + 2) * self.h_WH + self.h_RH]
|
||||||
|
write_strengths = weighted_input[:,
|
||||||
|
(self.h_W + 2) * self.h_WH + self.h_RH: (self.h_W + 3) * self.h_WH + self.h_RH]
|
||||||
|
write_vectors = weighted_input[:,
|
||||||
|
(self.h_W + 3) * self.h_WH + self.h_RH: (2 * self.h_W + 3) * self.h_WH + self.h_RH]
|
||||||
|
erase_vectors = weighted_input[:,
|
||||||
|
(2 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH + self.h_RH]
|
||||||
|
read_keys = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH
|
||||||
|
+ (self.h_W + 1) * self.h_RH]
|
||||||
|
read_strengths = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + (self.h_W + 1) * self.h_RH:]
|
||||||
|
|
||||||
|
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
|
||||||
|
alloc_gates = tf.expand_dims(alloc_gates, 2)
|
||||||
|
free_gates = tf.sigmoid(free_gates, 'free_gates')
|
||||||
|
free_gates = tf.expand_dims(free_gates, 2)
|
||||||
|
write_gates = tf.sigmoid(write_gates, 'write_gates')
|
||||||
|
write_gates = tf.expand_dims(write_gates, 2)
|
||||||
|
|
||||||
|
write_keys = tf.reshape(write_keys, [self.h_B, self.h_WH, self.h_W])
|
||||||
|
write_strengths = oneplus(write_strengths)
|
||||||
|
write_strengths = tf.expand_dims(write_strengths, axis=2)
|
||||||
|
write_vectors = tf.reshape(write_vectors, [self.h_B, self.h_WH, self.h_W])
|
||||||
|
erase_vectors = tf.reshape(erase_vectors, [self.h_B, self.h_WH, self.h_W])
|
||||||
|
erase_vectors = tf.sigmoid(erase_vectors, 'erase_vector')
|
||||||
|
|
||||||
|
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
|
||||||
|
read_strengths = oneplus(read_strengths)
|
||||||
|
read_strengths = tf.expand_dims(read_strengths, axis=2)
|
||||||
|
|
||||||
|
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
|
||||||
|
erase_vectors, read_keys, read_strengths
|
275
adnc/model/memory_units/multi_write_dnc_cell.py
Normal file
275
adnc/model/memory_units/multi_write_dnc_cell.py
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
|
||||||
|
from adnc.model.utils import layer_norm
|
||||||
|
from adnc.model.utils import oneplus
|
||||||
|
from adnc.model.utils import unit_simplex_initialization
|
||||||
|
|
||||||
|
|
||||||
|
class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
|
||||||
|
def __init__(self, input_size, memory_length, memory_width, read_heads, write_heads, bypass_dropout=False,
|
||||||
|
dnc_norm=False, seed=100, reuse=False, analyse=False, dtype=tf.float32, name='mwdnc_mu'):
|
||||||
|
|
||||||
|
self.h_WH = write_heads
|
||||||
|
super().__init__(input_size, memory_length, memory_width, read_heads, bypass_dropout, dnc_norm, seed, reuse,
|
||||||
|
analyse, dtype, name)
|
||||||
|
self.h_B = 0 # will set in call
|
||||||
|
|
||||||
|
@property
|
||||||
|
def state_size(self):
|
||||||
|
init_memory = tf.TensorShape([self.h_N, self.h_W])
|
||||||
|
init_usage_vector = tf.TensorShape([self.h_N])
|
||||||
|
init_write_weighting = tf.TensorShape([self.h_WH, self.h_N])
|
||||||
|
init_precedence_weightings = tf.TensorShape([self.h_WH, self.h_N])
|
||||||
|
init_link_mat = tf.TensorShape([self.h_WH, self.h_N, self.h_N])
|
||||||
|
init_read_weighting = tf.TensorShape([self.h_RH, self.h_N])
|
||||||
|
return (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
|
||||||
|
init_link_mat, init_read_weighting)
|
||||||
|
|
||||||
|
def zero_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
init_memory = tf.fill([batch_size, self.h_N, self.h_W], tf.cast(1 / (self.h_N * self.h_W), dtype=dtype))
|
||||||
|
init_usage_vector = tf.zeros([batch_size, self.h_N], dtype=dtype)
|
||||||
|
init_write_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_WH, self.h_N], dtype=dtype)
|
||||||
|
init_precedence_weightings = tf.zeros([batch_size, self.h_WH, self.h_N], dtype=dtype)
|
||||||
|
init_link_mat = tf.zeros([batch_size, self.h_WH, self.h_N, self.h_N], dtype=dtype)
|
||||||
|
init_read_weighting = unit_simplex_initialization(self.rng, batch_size, [self.h_RH, self.h_N], dtype=dtype)
|
||||||
|
zero_states = (init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings,
|
||||||
|
init_link_mat, init_read_weighting,)
|
||||||
|
return zero_states
|
||||||
|
|
||||||
|
def analyse_state(self, batch_size, dtype=tf.float32):
|
||||||
|
|
||||||
|
alloc_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype) # WH
|
||||||
|
free_gates = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
write_gate = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
|
||||||
|
write_keys = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
|
||||||
|
write_strengths = tf.zeros([batch_size, self.h_WH, 1], dtype=dtype)
|
||||||
|
write_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
|
||||||
|
erase_vector = tf.zeros([batch_size, self.h_WH, self.h_W], dtype=dtype)
|
||||||
|
read_keys = tf.zeros([batch_size, self.h_RH, self.h_W], dtype=dtype)
|
||||||
|
read_strengths = tf.zeros([batch_size, self.h_RH, 1], dtype=dtype)
|
||||||
|
read_modes = tf.zeros([batch_size, self.h_RH, 1 + 2 * self.h_WH], dtype=dtype)
|
||||||
|
|
||||||
|
analyse_states = alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths, read_modes
|
||||||
|
|
||||||
|
return analyse_states
|
||||||
|
|
||||||
|
def __call__(self, inputs, pre_states, scope=None):
|
||||||
|
|
||||||
|
self.h_B = inputs.get_shape()[0].value
|
||||||
|
|
||||||
|
link_matrix_inv_eye, memory_ones, batch_memory_range = self._create_constant_value_tensors(self.h_B, self.dtype)
|
||||||
|
self.const_link_matrix_inv_eye = link_matrix_inv_eye
|
||||||
|
self.const_memory_ones = memory_ones
|
||||||
|
self.const_batch_memory_range = batch_memory_range
|
||||||
|
|
||||||
|
pre_memory, pre_usage_vector, pre_write_weightings, pre_precedence_weighting, pre_link_matrix, pre_read_weightings = pre_states
|
||||||
|
|
||||||
|
weighted_input = self._weight_input(inputs)
|
||||||
|
|
||||||
|
control_signals = self._create_control_signals(weighted_input)
|
||||||
|
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths, read_modes = control_signals
|
||||||
|
|
||||||
|
alloc_weightings, usage_vector = self._update_alloc_and_usage_vectors(pre_write_weightings, pre_read_weightings,
|
||||||
|
pre_usage_vector, free_gates, write_gate)
|
||||||
|
write_content_weighting = self._calculate_content_weightings(pre_memory, write_keys, write_strengths)
|
||||||
|
write_weighting = self._update_write_weightings(alloc_weightings, write_content_weighting, write_gate,
|
||||||
|
alloc_gate)
|
||||||
|
memory = self._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
|
||||||
|
link_matrix, precedence_weighting = self._update_link_matrix(pre_link_matrix, write_weighting,
|
||||||
|
pre_precedence_weighting)
|
||||||
|
forward_weightings, backward_weightings = self._make_read_forward_backward_weightings(link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
read_content_weightings = self._calculate_content_weightings(memory, read_keys, read_strengths)
|
||||||
|
read_weightings = self._make_read_weightings(forward_weightings, backward_weightings, read_content_weightings,
|
||||||
|
read_modes)
|
||||||
|
read_vectors = self._read_memory(memory, read_weightings)
|
||||||
|
read_vectors = tf.reshape(read_vectors, [self.h_B, self.h_W * self.h_RH])
|
||||||
|
|
||||||
|
if self.bypass_dropout:
|
||||||
|
input_bypass = tf.nn.dropout(inputs, self.bypass_dropout)
|
||||||
|
else:
|
||||||
|
input_bypass = inputs
|
||||||
|
|
||||||
|
output = tf.concat([read_vectors, input_bypass], axis=-1)
|
||||||
|
|
||||||
|
if self.analyse:
|
||||||
|
output = (output, control_signals)
|
||||||
|
|
||||||
|
return output, (memory, usage_vector, write_weighting, precedence_weighting, link_matrix, read_weightings)
|
||||||
|
|
||||||
|
def _create_constant_value_tensors(self, batch_size, dtype):
|
||||||
|
|
||||||
|
link_matrix_inv_eye = 1 - tf.constant(np.identity(self.h_N), dtype=dtype, name="link_matrix_inv_eye")
|
||||||
|
link_matrix_inv_eye = tf.stack([link_matrix_inv_eye, ] * self.h_WH, axis=0)
|
||||||
|
link_matrix_inv_eye = tf.stack([link_matrix_inv_eye, ] * batch_size, axis=0)
|
||||||
|
|
||||||
|
memory_ones = tf.ones([batch_size, self.h_N, self.h_W], dtype=dtype, name="memory_ones")
|
||||||
|
|
||||||
|
batch_range = tf.range(0, batch_size, delta=1, dtype=tf.int32, name="batch_range")
|
||||||
|
repeat_memory_length = tf.fill([self.h_N], tf.constant(self.h_N, dtype=tf.int32), name="repeat_memory_length")
|
||||||
|
batch_memory_range = tf.matmul(tf.expand_dims(batch_range, -1), tf.expand_dims(repeat_memory_length, 0),
|
||||||
|
name="batch_memory_range")
|
||||||
|
return link_matrix_inv_eye, memory_ones, batch_memory_range
|
||||||
|
|
||||||
|
def _weight_input(self, inputs):
|
||||||
|
|
||||||
|
input_size = inputs.get_shape()[1].value
|
||||||
|
total_signal_size = self.h_RH * (3 + 2 * self.h_WH + self.h_W) + self.h_WH * (3 + 3 * self.h_W)
|
||||||
|
|
||||||
|
with tf.variable_scope('{}'.format(self.name), reuse=self.reuse):
|
||||||
|
w_x = tf.get_variable("mu_w_x", (input_size, total_signal_size),
|
||||||
|
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
b_x = tf.get_variable("mu_b_x", (total_signal_size,), initializer=tf.constant_initializer(0.),
|
||||||
|
collections=['memory_unit', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
|
||||||
|
weighted_input = tf.matmul(inputs, w_x) + b_x
|
||||||
|
if self.dnc_norm:
|
||||||
|
weighted_input = layer_norm(weighted_input, name='dnc_norm', dtype=self.dtype,
|
||||||
|
collection='memory_unit')
|
||||||
|
return weighted_input
|
||||||
|
|
||||||
|
def _create_control_signals(self, weighted_input):
|
||||||
|
|
||||||
|
alloc_gates = weighted_input[:, : self.h_WH]
|
||||||
|
free_gates = weighted_input[:, self.h_WH: self.h_WH + self.h_RH]
|
||||||
|
write_gates = weighted_input[:, self.h_WH + self.h_RH: 2 * self.h_WH + self.h_RH]
|
||||||
|
write_keys = weighted_input[:, 2 * self.h_WH + self.h_RH: (self.h_W + 2) * self.h_WH + self.h_RH]
|
||||||
|
write_strengths = weighted_input[:,
|
||||||
|
(self.h_W + 2) * self.h_WH + self.h_RH: (self.h_W + 3) * self.h_WH + self.h_RH]
|
||||||
|
write_vectors = weighted_input[:,
|
||||||
|
(self.h_W + 3) * self.h_WH + self.h_RH: (2 * self.h_W + 3) * self.h_WH + self.h_RH]
|
||||||
|
erase_vectors = weighted_input[:,
|
||||||
|
(2 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH + self.h_RH]
|
||||||
|
read_keys = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + self.h_RH: (3 * self.h_W + 3) * self.h_WH +
|
||||||
|
(self.h_W + 1) * self.h_RH]
|
||||||
|
read_strengths = weighted_input[:,
|
||||||
|
(3 * self.h_W + 3) * self.h_WH + (self.h_W + 1) * self.h_RH: (3 * self.h_W + 3) * self.h_WH +
|
||||||
|
(self.h_W + 2) * self.h_RH]
|
||||||
|
read_modes = weighted_input[:, (3 * self.h_W + 3) * self.h_WH + (self.h_W + 2) * self.h_RH:]
|
||||||
|
|
||||||
|
alloc_gates = tf.sigmoid(alloc_gates, 'alloc_gates')
|
||||||
|
alloc_gates = tf.expand_dims(alloc_gates, 2)
|
||||||
|
free_gates = tf.sigmoid(free_gates, 'free_gates')
|
||||||
|
free_gates = tf.expand_dims(free_gates, 2)
|
||||||
|
write_gates = tf.sigmoid(write_gates, 'write_gates')
|
||||||
|
write_gates = tf.expand_dims(write_gates, 2)
|
||||||
|
|
||||||
|
write_keys = tf.reshape(write_keys, [self.h_B, self.h_WH, self.h_W])
|
||||||
|
write_strengths = oneplus(write_strengths)
|
||||||
|
write_strengths = tf.expand_dims(write_strengths, axis=2)
|
||||||
|
write_vectors = tf.reshape(write_vectors, [self.h_B, self.h_WH, self.h_W])
|
||||||
|
erase_vectors = tf.reshape(erase_vectors, [self.h_B, self.h_WH, self.h_W])
|
||||||
|
erase_vectors = tf.sigmoid(erase_vectors, 'erase_vector')
|
||||||
|
|
||||||
|
read_keys = tf.reshape(read_keys, [self.h_B, self.h_RH, self.h_W])
|
||||||
|
read_strengths = oneplus(read_strengths)
|
||||||
|
read_strengths = tf.expand_dims(read_strengths, axis=2)
|
||||||
|
read_modes = tf.reshape(read_modes, [self.h_B, self.h_RH, 1 + 2 * self.h_WH])
|
||||||
|
read_modes = tf.nn.softmax(read_modes, dim=2)
|
||||||
|
|
||||||
|
return alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
|
||||||
|
erase_vectors, read_keys, read_strengths, read_modes
|
||||||
|
|
||||||
|
def _update_alloc_and_usage_vectors(self, pre_write_weightings, pre_read_weightings, pre_usage_vector, free_gates,
|
||||||
|
write_gates):
|
||||||
|
|
||||||
|
# usage update after write from last time step
|
||||||
|
pre_write_weighting = 1 - tf.reduce_prod(1 - pre_write_weightings, [1], keepdims=False)
|
||||||
|
usage_vector = pre_usage_vector + pre_write_weighting - pre_usage_vector * pre_write_weighting
|
||||||
|
|
||||||
|
# usage update after read
|
||||||
|
retention_vector = tf.reduce_prod(1 - free_gates * pre_read_weightings, axis=1, keepdims=False,
|
||||||
|
name='retention_prod')
|
||||||
|
usage_vector = usage_vector * retention_vector
|
||||||
|
|
||||||
|
usage_vector_cp = tf.identity(usage_vector)
|
||||||
|
|
||||||
|
alloc_list = []
|
||||||
|
for w in range(self.h_WH):
|
||||||
|
sorted_usage, free_list = tf.nn.top_k(-1 * usage_vector_cp, self.h_N)
|
||||||
|
sorted_usage = -1 * sorted_usage
|
||||||
|
|
||||||
|
cumprod_sorted_usage = tf.cumprod(sorted_usage, axis=1, exclusive=True)
|
||||||
|
corrected_free_list = free_list + self.const_batch_memory_range
|
||||||
|
|
||||||
|
corrected_free_list_un = [tf.reshape(corrected_free_list, [-1, ]), ]
|
||||||
|
cumprod_sorted_usage_un = [tf.reshape(cumprod_sorted_usage, [-1, ]), ]
|
||||||
|
|
||||||
|
stitched_usage = tf.dynamic_stitch(corrected_free_list_un, cumprod_sorted_usage_un, name=None)
|
||||||
|
stitched_usage = tf.reshape(stitched_usage, [self.h_B, self.h_N])
|
||||||
|
|
||||||
|
alloc_weighting = (1 - usage_vector_cp) * stitched_usage
|
||||||
|
|
||||||
|
alloc_list.append(alloc_weighting)
|
||||||
|
usage_vector_cp = usage_vector_cp + ((1 - usage_vector_cp) * write_gates[:, w, :] * alloc_weighting)
|
||||||
|
|
||||||
|
alloc_weighting = tf.stack(alloc_list, 1)
|
||||||
|
|
||||||
|
return alloc_weighting, usage_vector
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _update_write_weightings(alloc_weighting, write_content_weighting, write_gate, alloc_gate):
|
||||||
|
|
||||||
|
write_weighting = write_gate * (alloc_gate * alloc_weighting + (1 - alloc_gate) * write_content_weighting)
|
||||||
|
|
||||||
|
return write_weighting
|
||||||
|
|
||||||
|
def _update_memory(self, pre_memory, write_weighting, write_vector, erase_vector):
|
||||||
|
|
||||||
|
write_w = tf.expand_dims(write_weighting, 3)
|
||||||
|
erase_vector = tf.expand_dims(erase_vector, 2)
|
||||||
|
erase_matrix = tf.reduce_prod(1 - write_w * erase_vector, axis=1, keepdims=False)
|
||||||
|
write_matrix = tf.matmul(write_weighting, write_vector, adjoint_a=True)
|
||||||
|
|
||||||
|
return pre_memory * erase_matrix + write_matrix
|
||||||
|
|
||||||
|
def _update_link_matrix(self, pre_link_matrices, write_weightings, pre_precedence_weightings):
|
||||||
|
|
||||||
|
precedence_weightings = (1 - tf.reduce_sum(write_weightings, 2,
|
||||||
|
keepdims=True)) * pre_precedence_weightings + write_weightings
|
||||||
|
|
||||||
|
add_mat = tf.expand_dims(write_weightings, axis=3) * tf.expand_dims(pre_precedence_weightings, axis=2)
|
||||||
|
erase_mat = 1 - tf.expand_dims(write_weightings, 2) - tf.expand_dims(write_weightings, 3)
|
||||||
|
|
||||||
|
updated_link_mat = erase_mat * pre_link_matrices + add_mat
|
||||||
|
|
||||||
|
link_matrices = self.const_link_matrix_inv_eye * updated_link_mat
|
||||||
|
|
||||||
|
return link_matrices, precedence_weightings
|
||||||
|
|
||||||
|
def _make_read_forward_backward_weightings(self, link_matrix, pre_read_weightings):
|
||||||
|
|
||||||
|
read_weightings_stacked = tf.stack([pre_read_weightings, ] * self.h_WH, axis=1)
|
||||||
|
|
||||||
|
forward_weightings = tf.matmul(read_weightings_stacked, link_matrix)
|
||||||
|
backward_weightings = tf.matmul(read_weightings_stacked, link_matrix, adjoint_b=True)
|
||||||
|
|
||||||
|
return tf.transpose(forward_weightings, (0, 2, 1, 3)), tf.transpose(backward_weightings, (0, 2, 1, 3))
|
||||||
|
|
||||||
|
def _make_read_weightings(self, forward_weightings, backward_weightings, read_content_weightings, read_modes):
|
||||||
|
|
||||||
|
read_weighting = tf.reduce_sum(tf.expand_dims(read_modes[:, :, :self.h_WH], 3) * backward_weightings, axis=2) + \
|
||||||
|
tf.expand_dims(read_modes[:, :, self.h_WH], 2) * read_content_weightings + \
|
||||||
|
tf.reduce_sum(tf.expand_dims(read_modes[:, :, self.h_WH + 1:], 3) * forward_weightings, axis=2)
|
||||||
|
|
||||||
|
return read_weighting
|
0
test/adnc/model/memory_units/__init__.py
Executable file
0
test/adnc/model/memory_units/__init__.py
Executable file
152
test/adnc/model/memory_units/test_base_cell.py
Normal file
152
test/adnc/model/memory_units/test_base_cell.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
|
||||||
|
"dnc_norm": True, "bypass_dropout": False},
|
||||||
|
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
|
||||||
|
"dnc_norm": False, "bypass_dropout": False},
|
||||||
|
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
|
||||||
|
"dnc_norm": True, "bypass_dropout": True},
|
||||||
|
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
|
||||||
|
"dnc_norm": False, "bypass_dropout": True}
|
||||||
|
])
|
||||||
|
def memory_config(request):
|
||||||
|
config = request.param
|
||||||
|
return BaseMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
|
||||||
|
memory_width=config["memory_width"],
|
||||||
|
read_heads=config["read_heads"], seed=config["seed"],
|
||||||
|
reuse=False, name='test_mu'), config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def session():
|
||||||
|
with tf.Session() as sess:
|
||||||
|
yield sess
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def np_rng():
|
||||||
|
seed = np.random.randint(1, 999)
|
||||||
|
return np.random.RandomState(seed)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDNCMemoryUnit():
|
||||||
|
def test_init(self, memory_config):
|
||||||
|
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
assert isinstance(memory_unit, object)
|
||||||
|
assert isinstance(memory_unit.rng, np.random.RandomState)
|
||||||
|
|
||||||
|
assert memory_unit.h_N == config["memory_length"]
|
||||||
|
assert memory_unit.h_W == config["memory_width"]
|
||||||
|
assert memory_unit.h_RH == config["read_heads"]
|
||||||
|
|
||||||
|
def test_property_output_size(self, memory_config, session):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
output_size = memory_unit.output_size
|
||||||
|
assert output_size == config['memory_width'] * config["read_heads"] + config['input_size']
|
||||||
|
|
||||||
|
def test_calculate_content_weightings(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_memory = np_rng.normal(0, 1, (config['batch_size'], config['memory_length'], config['memory_width']))
|
||||||
|
np_keys = np_rng.normal(0, 2, (config['batch_size'], 1, config['memory_width']))
|
||||||
|
np_strengths = np_rng.uniform(1, 10, (config['batch_size'], 1))
|
||||||
|
|
||||||
|
memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
keys = tf.constant(np_keys, dtype=tf.float32)
|
||||||
|
strengths = tf.constant(np_strengths, dtype=tf.float32)
|
||||||
|
|
||||||
|
content_weightings = memory_unit._calculate_content_weightings(memory, keys, strengths)
|
||||||
|
weightings = content_weightings.eval()
|
||||||
|
|
||||||
|
np_similarity = np.empty([config['batch_size'], config['memory_length']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for l in range(config['memory_length']):
|
||||||
|
np_similarity[b, l] = np.dot(np_memory[b, l, :], np_keys[b, 0, :]) / (
|
||||||
|
np.sqrt(np.dot(np_memory[b, l, :], np_memory[b, l, :])) * np.sqrt(
|
||||||
|
np.dot(np_keys[b, 0, :], np_keys[b, 0, :])))
|
||||||
|
|
||||||
|
def _weighted_softmax(x, s):
|
||||||
|
e_x = np.exp(x * s)
|
||||||
|
return e_x / e_x.sum(axis=1, keepdims=True)
|
||||||
|
|
||||||
|
np_weightings = _weighted_softmax(np_similarity, np_strengths)
|
||||||
|
|
||||||
|
assert weightings.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert 0 <= weightings.min() and weightings.max() <= 1 and weightings.sum(axis=1).all() <= 1
|
||||||
|
assert np.allclose(weightings, np_weightings)
|
||||||
|
|
||||||
|
np_memory = np_rng.uniform(0, 1, (config['batch_size'], config['memory_length'], config['memory_width']))
|
||||||
|
np_keys = np_rng.normal(0, 2, (config['batch_size'], config['read_heads'], config['memory_width']))
|
||||||
|
np_strengths = np_rng.uniform(1, 10, (config['batch_size'], config['read_heads'], 1))
|
||||||
|
|
||||||
|
memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
keys = tf.constant(np_keys, dtype=tf.float32)
|
||||||
|
strengths = tf.constant(np_strengths, dtype=tf.float32)
|
||||||
|
|
||||||
|
content_weightings = memory_unit._calculate_content_weightings(memory, keys, strengths)
|
||||||
|
weightings = content_weightings.eval()
|
||||||
|
|
||||||
|
np_similarity = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
for l in range(config['memory_length']):
|
||||||
|
np_similarity[b, r, l] = np.dot(np_memory[b, l, :], np_keys[b, r, :]) / (
|
||||||
|
np.sqrt(np.dot(np_memory[b, l, :], np_memory[b, l, :])) * np.sqrt(
|
||||||
|
np.dot(np_keys[b, r, :], np_keys[b, r, :])))
|
||||||
|
np_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
def _weighted_softmax(x, s):
|
||||||
|
e_x = np.exp(x * s)
|
||||||
|
return e_x / e_x.sum(axis=1, keepdims=True)
|
||||||
|
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
np_weightings[:, r, :] = _weighted_softmax(np_similarity[:, r, :], np_strengths[:, r])
|
||||||
|
|
||||||
|
assert weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
|
||||||
|
assert 0 <= weightings.min() and weightings.max() <= 1 and weightings.sum(axis=2).all() <= 1
|
||||||
|
assert np.allclose(weightings, np_weightings)
|
||||||
|
|
||||||
|
def test_read_memory(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
read_vectors = memory_unit._read_memory(memory, read_weightings)
|
||||||
|
read_vectors = read_vectors.eval()
|
||||||
|
|
||||||
|
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
|
||||||
|
|
||||||
|
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
|
325
test/adnc/model/memory_units/test_content_based_cell.py
Executable file
325
test/adnc/model/memory_units/test_content_based_cell.py
Executable file
@ -0,0 +1,325 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
|
||||||
|
from adnc.model.memory_units.content_based_cell import ContentBasedMemoryUnitCell
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
|
||||||
|
"dnc_norm": True, "bypass_dropout": False},
|
||||||
|
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
|
||||||
|
"dnc_norm": False, "bypass_dropout": False},
|
||||||
|
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
|
||||||
|
"dnc_norm": True, "bypass_dropout": True},
|
||||||
|
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
|
||||||
|
"dnc_norm": False, "bypass_dropout": True}
|
||||||
|
])
|
||||||
|
def memory_config(request):
|
||||||
|
config = request.param
|
||||||
|
return ContentBasedMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
|
||||||
|
memory_width=config["memory_width"],
|
||||||
|
read_heads=config["read_heads"], seed=config["seed"],
|
||||||
|
reuse=False, name='test_mu'), config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def session():
|
||||||
|
with tf.Session() as sess:
|
||||||
|
yield sess
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def np_rng():
|
||||||
|
seed = np.random.randint(1, 999)
|
||||||
|
return np.random.RandomState(seed)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContentBasedMemoryUnitCell():
|
||||||
|
def test_zero_state(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
|
||||||
|
|
||||||
|
# test init_tuple
|
||||||
|
init_memory, init_usage_vector, init_write_weighting, init_read_weighting = init_tuple
|
||||||
|
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
|
||||||
|
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert init_write_weighting.eval().shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
|
||||||
|
|
||||||
|
def test_parameter_amount(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 2 * config['read_heads'] + 3)
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.constant(inputs, tf.float32)
|
||||||
|
|
||||||
|
memory_unit._weight_input(tf_input)
|
||||||
|
parameter_amount = memory_unit.parameter_amount
|
||||||
|
|
||||||
|
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
|
||||||
|
|
||||||
|
def test_create_constant_value_tensors(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(batch_size=config['batch_size'],
|
||||||
|
dtype=tf.float32)
|
||||||
|
|
||||||
|
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
assert np.array_equal(memory_ones.eval(), np_memory_ones)
|
||||||
|
|
||||||
|
np_batch_range = np.arange(0, config['batch_size'])
|
||||||
|
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
|
||||||
|
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
|
||||||
|
np.expand_dims(np_repeat_memory_length, 0))
|
||||||
|
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
|
||||||
|
|
||||||
|
def test_weight_input(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
|
||||||
|
|
||||||
|
weight_inputs = memory_unit._weight_input(tf_input)
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
|
||||||
|
|
||||||
|
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 2 * config['read_heads'] + 3)
|
||||||
|
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
|
||||||
|
|
||||||
|
def test_create_control_signals(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
|
||||||
|
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
|
||||||
|
|
||||||
|
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
|
||||||
|
|
||||||
|
memory_unit.h_B = config['batch_size']
|
||||||
|
control_signals = memory_unit._create_control_signals(weighted_input)
|
||||||
|
control_signals = session.run(control_signals)
|
||||||
|
|
||||||
|
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths = control_signals
|
||||||
|
|
||||||
|
assert alloc_gates.shape == (config['batch_size'], 1)
|
||||||
|
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
|
||||||
|
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 0 <= free_gates.min() and free_gates.max() <= 1
|
||||||
|
assert write_gates.shape == (config['batch_size'], 1)
|
||||||
|
assert 0 <= write_gates.min() and write_gates.max() <= 1
|
||||||
|
|
||||||
|
assert write_keys.shape == (config['batch_size'], 1, config['memory_width'])
|
||||||
|
assert write_strengths.shape == (config['batch_size'], 1)
|
||||||
|
assert 1 <= write_strengths.min()
|
||||||
|
assert write_vector.shape == (config['batch_size'], 1, config['memory_width'])
|
||||||
|
assert erase_vector.shape == (config['batch_size'], 1, config['memory_width'])
|
||||||
|
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
|
||||||
|
# comment
|
||||||
|
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 1 <= read_strengths.min()
|
||||||
|
|
||||||
|
def test_update_alloc_weightings_and_usage_vectors(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_pre_write_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_usage_vectors = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_free_gates = np.ones([config['batch_size'], config['read_heads'], 1]) * 0.5
|
||||||
|
|
||||||
|
pre_write_weightings = tf.constant(np_pre_write_weightings, dtype=tf.float32)
|
||||||
|
pre_usage_vectors = tf.constant(np_pre_usage_vectors, dtype=tf.float32)
|
||||||
|
free_gates = tf.constant(np_free_gates, dtype=tf.float32)
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory_unit(inputs, pre_states) # just for initialization
|
||||||
|
|
||||||
|
alloc_weightings, usage_vectors = memory_unit._update_alloc_and_usage_vectors(pre_write_weightings,
|
||||||
|
pre_read_weightings,
|
||||||
|
pre_usage_vectors, free_gates)
|
||||||
|
alloc_weightings, usage_vectors = session.run([alloc_weightings, usage_vectors])
|
||||||
|
|
||||||
|
np_retention_vector = np.prod(1 - np_free_gates * np_pre_read_weightings, axis=1, keepdims=False)
|
||||||
|
np_usage_vectors = (
|
||||||
|
np_pre_usage_vectors + np_pre_write_weightings - np_pre_usage_vectors * np_pre_write_weightings) * np_retention_vector
|
||||||
|
|
||||||
|
assert usage_vectors.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert usage_vectors.min() >= 0 and usage_vectors.max() <= 1
|
||||||
|
assert np.allclose(usage_vectors, np_usage_vectors)
|
||||||
|
|
||||||
|
free_list = np.argsort(np_usage_vectors).astype(int)
|
||||||
|
|
||||||
|
np_alloc_weightings = np.zeros([config['batch_size'], config['memory_length']])
|
||||||
|
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for j in range(config['memory_length']):
|
||||||
|
fj = free_list[b, j]
|
||||||
|
np_alloc_weightings[b, fj] = (1 - np_usage_vectors[b, fj]) * np.prod(
|
||||||
|
[np_usage_vectors[b, free_list[b, i]] for i in range(j)])
|
||||||
|
|
||||||
|
assert alloc_weightings.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert np.allclose(alloc_weightings, np_alloc_weightings)
|
||||||
|
|
||||||
|
def test_update_write_weighting(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_alloc_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_write_content_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_write_gate = np.ones([config['batch_size'], 1]) * 0.5
|
||||||
|
np_alloc_gate = np.ones([config['batch_size'], 1]) * 0.5
|
||||||
|
|
||||||
|
alloc_weighting = tf.constant(np_alloc_weighting, dtype=tf.float32)
|
||||||
|
write_content_weighting = tf.constant(np_write_content_weighting, dtype=tf.float32)
|
||||||
|
write_gate = tf.constant(np_write_gate, dtype=tf.float32)
|
||||||
|
alloc_gate = tf.constant(np_alloc_gate, dtype=tf.float32)
|
||||||
|
|
||||||
|
write_weighting = memory_unit._update_write_weighting(alloc_weighting, write_content_weighting, write_gate,
|
||||||
|
alloc_gate)
|
||||||
|
write_weighting = write_weighting.eval()
|
||||||
|
|
||||||
|
np_write_weighting = np_write_gate * (
|
||||||
|
np_alloc_gate * np_alloc_weighting + (1 - np_alloc_gate) * np_write_content_weighting)
|
||||||
|
|
||||||
|
assert write_weighting.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert 0 <= write_weighting.min() and write_weighting.max() <= 1 and write_weighting.sum(axis=1).all() <= 1
|
||||||
|
assert np.allclose(write_weighting, np_write_weighting)
|
||||||
|
|
||||||
|
def test_update_memory(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_write_vector = np_rng.normal(0, 2, [config['batch_size'], 1, config['memory_width']])
|
||||||
|
np_erase_vector = np_rng.uniform(0, 1, [config['batch_size'], 1, config['memory_width']])
|
||||||
|
|
||||||
|
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
|
||||||
|
write_vector = tf.constant(np_write_vector, dtype=tf.float32)
|
||||||
|
erase_vector = tf.constant(np_erase_vector, dtype=tf.float32)
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory_unit(inputs, pre_states) # just for initialization
|
||||||
|
|
||||||
|
memory = memory_unit._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
|
||||||
|
memory = memory.eval()
|
||||||
|
|
||||||
|
write_w = np.expand_dims(np_write_weighting, 2)
|
||||||
|
np_erase_memory = (1 - (write_w * np_erase_vector))
|
||||||
|
np_add_memory = np.matmul(write_w, np_write_vector)
|
||||||
|
np_memory = np_pre_memory * np_erase_memory + np_add_memory
|
||||||
|
|
||||||
|
assert memory.shape == (config['batch_size'], config['memory_length'], config['memory_width'])
|
||||||
|
assert np.allclose(memory, np_memory, atol=1e-06)
|
||||||
|
|
||||||
|
def test_read_memory(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
read_vectors = memory_unit._read_memory(memory, read_weightings)
|
||||||
|
read_vectors = read_vectors.eval()
|
||||||
|
|
||||||
|
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
|
||||||
|
|
||||||
|
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
|
||||||
|
|
||||||
|
def test_call(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
read_vectors, states = memory_unit(inputs, pre_states)
|
||||||
|
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
read_vectors, states = session.run([read_vectors, states])
|
||||||
|
|
||||||
|
# test const initialization
|
||||||
|
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
assert np.array_equal(memory_unit.const_memory_ones.eval(), np_memory_ones)
|
||||||
|
np_batch_range = np.arange(0, config['batch_size'])
|
||||||
|
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
|
||||||
|
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
|
||||||
|
np.expand_dims(np_repeat_memory_length, 0))
|
||||||
|
assert np.array_equal(memory_unit.const_batch_memory_range.eval(), np_batch_memory_range)
|
||||||
|
|
||||||
|
assert read_vectors.shape == (
|
||||||
|
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])
|
466
test/adnc/model/memory_units/test_dnc_cell.py
Executable file
466
test/adnc/model/memory_units/test_dnc_cell.py
Executable file
@ -0,0 +1,466 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
|
||||||
|
"dnc_norm": True, "bypass_dropout": False},
|
||||||
|
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
|
||||||
|
"dnc_norm": False, "bypass_dropout": False},
|
||||||
|
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
|
||||||
|
"dnc_norm": True, "bypass_dropout": True},
|
||||||
|
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
|
||||||
|
"dnc_norm": False, "bypass_dropout": True}
|
||||||
|
])
|
||||||
|
def memory_config(request):
|
||||||
|
config = request.param
|
||||||
|
return DNCMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
|
||||||
|
memory_width=config["memory_width"],
|
||||||
|
read_heads=config["read_heads"], seed=config["seed"],
|
||||||
|
reuse=False, name='test_mu'), config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def session():
|
||||||
|
with tf.Session() as sess:
|
||||||
|
yield sess
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def np_rng():
|
||||||
|
seed = np.random.randint(1, 999)
|
||||||
|
return np.random.RandomState(seed)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDNCMemoryUnit():
|
||||||
|
def test_zero_state(self, memory_config, session):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
|
||||||
|
|
||||||
|
# test init_tuple
|
||||||
|
init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings, init_link_mat, init_read_weighting = init_tuple
|
||||||
|
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
|
||||||
|
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert init_write_weighting.eval().shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert init_precedence_weightings.eval().shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert init_link_mat.eval().shape == (config['batch_size'], config['memory_length'], config['memory_length'])
|
||||||
|
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
|
||||||
|
|
||||||
|
def test_parameter_amount(self, memory_config, session):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.constant(inputs, tf.float32)
|
||||||
|
|
||||||
|
memory_unit._weight_input(tf_input)
|
||||||
|
parameter_amount = memory_unit.parameter_amount
|
||||||
|
|
||||||
|
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
|
||||||
|
|
||||||
|
def test_create_constant_value_tensors(self, memory_config, session):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
link_matrix_inv_eye, memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(
|
||||||
|
batch_size=config['batch_size'], dtype=tf.float32)
|
||||||
|
|
||||||
|
np_link_matrix_inv_eye = np.ones([config['memory_length'], config['memory_length']]) - np.eye(
|
||||||
|
config['memory_length'])
|
||||||
|
assert np.array_equal(link_matrix_inv_eye.eval(), np_link_matrix_inv_eye)
|
||||||
|
|
||||||
|
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
assert np.array_equal(memory_ones.eval(), np_memory_ones)
|
||||||
|
|
||||||
|
np_batch_range = np.arange(0, config['batch_size'])
|
||||||
|
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
|
||||||
|
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
|
||||||
|
np.expand_dims(np_repeat_memory_length, 0))
|
||||||
|
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
|
||||||
|
|
||||||
|
def test_weight_input(self, memory_config, session):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
mu_weight_test = DNCMemoryUnitCell(memory_length=config["memory_length"], memory_width=config["memory_width"],
|
||||||
|
read_heads=config["read_heads"], input_size=config['input_size'],
|
||||||
|
seed=config["seed"],
|
||||||
|
reuse=False, name='dnc_mu_weight_test')
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
|
||||||
|
|
||||||
|
weight_inputs = mu_weight_test._weight_input(tf_input)
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
|
||||||
|
|
||||||
|
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
|
||||||
|
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
|
||||||
|
|
||||||
|
def test_create_control_signals(self, memory_config, session):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (config['memory_width'] * (3 + config["read_heads"]) + 5 * config['read_heads'] + 3)
|
||||||
|
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
|
||||||
|
|
||||||
|
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
|
||||||
|
|
||||||
|
memory_unit.h_B = config['batch_size']
|
||||||
|
control_signals = memory_unit._create_control_signals(weighted_input)
|
||||||
|
control_signals = session.run(control_signals)
|
||||||
|
|
||||||
|
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vector, \
|
||||||
|
erase_vector, read_keys, read_strengths, read_modes = control_signals
|
||||||
|
|
||||||
|
assert alloc_gates.shape == (config['batch_size'], 1)
|
||||||
|
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
|
||||||
|
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 0 <= free_gates.min() and free_gates.max() <= 1
|
||||||
|
assert write_gates.shape == (config['batch_size'], 1)
|
||||||
|
assert 0 <= write_gates.min() and write_gates.max() <= 1
|
||||||
|
|
||||||
|
assert write_keys.shape == (config['batch_size'], 1, config['memory_width'])
|
||||||
|
assert write_strengths.shape == (config['batch_size'], 1)
|
||||||
|
assert 1 <= write_strengths.min()
|
||||||
|
assert write_vector.shape == (config['batch_size'], 1, config['memory_width'])
|
||||||
|
assert erase_vector.shape == (config['batch_size'], 1, config['memory_width'])
|
||||||
|
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
|
||||||
|
# comment
|
||||||
|
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 1 <= read_strengths.min()
|
||||||
|
assert read_modes.shape == (config['batch_size'], config['read_heads'], 3) # 3 read modes
|
||||||
|
assert 0 <= read_modes.min() and read_modes.max() <= 1 and read_modes.sum(axis=2).all() == 1
|
||||||
|
|
||||||
|
def test_update_alloc_weightings_and_usage_vectors(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_pre_write_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
prw_rand = np.arange(0, config['memory_length']) / config['memory_length']
|
||||||
|
np_pre_read_weightings = np.stack([prw_rand, ] * config['read_heads'], 0)
|
||||||
|
np_pre_read_weightings = np.stack([np_pre_read_weightings, ] * config['batch_size'], 0)
|
||||||
|
np_pre_usage_vectors = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_free_gates = np.ones([config['batch_size'], config['read_heads'], 1]) * 0.5
|
||||||
|
|
||||||
|
pre_write_weightings = tf.constant(np_pre_write_weightings, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
pre_usage_vectors = tf.constant(np_pre_usage_vectors, dtype=tf.float32)
|
||||||
|
free_gates = tf.constant(np_free_gates, dtype=tf.float32)
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
|
||||||
|
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory_unit(inputs, pre_states) # just for initialization
|
||||||
|
|
||||||
|
alloc_weightings, usage_vectors = memory_unit._update_alloc_and_usage_vectors(pre_write_weightings,
|
||||||
|
pre_read_weightings,
|
||||||
|
pre_usage_vectors, free_gates)
|
||||||
|
alloc_weightings, usage_vectors = session.run([alloc_weightings, usage_vectors])
|
||||||
|
|
||||||
|
np_retention_vector = np.prod(1 - np_free_gates * np_pre_read_weightings, axis=1, keepdims=False)
|
||||||
|
np_usage_vectors = (
|
||||||
|
np_pre_usage_vectors + np_pre_write_weightings - np_pre_usage_vectors * np_pre_write_weightings) * np_retention_vector
|
||||||
|
|
||||||
|
assert usage_vectors.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert usage_vectors.min() >= 0 and usage_vectors.max() <= 1
|
||||||
|
assert np.allclose(usage_vectors, np_usage_vectors)
|
||||||
|
|
||||||
|
free_list = np.argsort(np_usage_vectors).astype(int)
|
||||||
|
|
||||||
|
np_alloc_weightings = np.zeros([config['batch_size'], config['memory_length']])
|
||||||
|
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for j in range(config['memory_length']):
|
||||||
|
fj = free_list[b, j]
|
||||||
|
np_alloc_weightings[b, fj] = (1 - np_usage_vectors[b, fj]) * np.prod(
|
||||||
|
[np_usage_vectors[b, free_list[b, i]] for i in range(j)])
|
||||||
|
|
||||||
|
assert alloc_weightings.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert np.allclose(alloc_weightings, np_alloc_weightings)
|
||||||
|
|
||||||
|
def test_update_write_weighting(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_alloc_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_write_content_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_write_gate = np.ones([config['batch_size'], 1]) * 0.5
|
||||||
|
np_alloc_gate = np.ones([config['batch_size'], 1]) * 0.5
|
||||||
|
|
||||||
|
alloc_weighting = tf.constant(np_alloc_weighting, dtype=tf.float32)
|
||||||
|
write_content_weighting = tf.constant(np_write_content_weighting, dtype=tf.float32)
|
||||||
|
write_gate = tf.constant(np_write_gate, dtype=tf.float32)
|
||||||
|
alloc_gate = tf.constant(np_alloc_gate, dtype=tf.float32)
|
||||||
|
|
||||||
|
write_weighting = memory_unit._update_write_weighting(alloc_weighting, write_content_weighting, write_gate,
|
||||||
|
alloc_gate)
|
||||||
|
write_weighting = write_weighting.eval()
|
||||||
|
|
||||||
|
np_write_weighting = np_write_gate * (
|
||||||
|
np_alloc_gate * np_alloc_weighting + (1 - np_alloc_gate) * np_write_content_weighting)
|
||||||
|
|
||||||
|
assert write_weighting.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert 0 <= write_weighting.min() and write_weighting.max() <= 1 and write_weighting.sum(axis=1).all() <= 1
|
||||||
|
assert np.allclose(write_weighting, np_write_weighting)
|
||||||
|
|
||||||
|
def test_update_memory(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_write_vector = np_rng.normal(0, 2, [config['batch_size'], 1, config['memory_width']])
|
||||||
|
np_erase_vector = np_rng.uniform(0, 1, [config['batch_size'], 1, config['memory_width']])
|
||||||
|
|
||||||
|
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
|
||||||
|
write_vector = tf.constant(np_write_vector, dtype=tf.float32)
|
||||||
|
erase_vector = tf.constant(np_erase_vector, dtype=tf.float32)
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
|
||||||
|
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory_unit(inputs, pre_states) # just for initialization
|
||||||
|
|
||||||
|
memory = memory_unit._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
|
||||||
|
memory = memory.eval()
|
||||||
|
|
||||||
|
write_w = np.expand_dims(np_write_weighting, 2)
|
||||||
|
np_erase_memory = (1 - (write_w * np_erase_vector))
|
||||||
|
np_add_memory = np.matmul(write_w, np_write_vector)
|
||||||
|
np_memory = np_pre_memory * np_erase_memory + np_add_memory
|
||||||
|
|
||||||
|
assert memory.shape == (config['batch_size'], config['memory_length'], config['memory_width'])
|
||||||
|
assert np.allclose(memory, np_memory, atol=1e-06)
|
||||||
|
|
||||||
|
def test_update_link_matrix(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
|
||||||
|
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
|
||||||
|
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory_unit(inputs, pre_states) # just for initialization
|
||||||
|
|
||||||
|
link_matrix, precedence_weighting = memory_unit._update_link_matrix(pre_link_matrix, write_weighting,
|
||||||
|
pre_precedence_weighting)
|
||||||
|
link_matrix, precedence_weighting = session.run([link_matrix, precedence_weighting])
|
||||||
|
|
||||||
|
np_precedence_weighting = (1 - np.sum(np_write_weighting, axis=1,
|
||||||
|
keepdims=True)) * np_pre_precedence_weighting + np_write_weighting
|
||||||
|
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for i in range(config['memory_length']):
|
||||||
|
for j in range(config['memory_length']):
|
||||||
|
if i == j:
|
||||||
|
np_pre_link_matrix[b, i, j] = 0
|
||||||
|
else:
|
||||||
|
np_pre_link_matrix[b, i, j] = (1 - np_write_weighting[b, i] - np_write_weighting[b, j]) * \
|
||||||
|
np_pre_link_matrix[b, i, j] + np_write_weighting[b, i] * \
|
||||||
|
np_pre_precedence_weighting[b, j]
|
||||||
|
np_link_matrix = np_pre_link_matrix
|
||||||
|
|
||||||
|
assert precedence_weighting.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert 0 <= precedence_weighting.min() and precedence_weighting.max() <= 1 and precedence_weighting.sum(
|
||||||
|
axis=1).all() <= 1
|
||||||
|
assert np.allclose(precedence_weighting, np_precedence_weighting)
|
||||||
|
|
||||||
|
assert link_matrix.shape == (config['batch_size'], config['memory_length'], config['memory_length'])
|
||||||
|
assert np.allclose(link_matrix, np_link_matrix)
|
||||||
|
|
||||||
|
def test_make_read_forward_backward_weightings(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
link_matrix = tf.constant(np_link_matrix, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
forward_weightings, backward_weightings = memory_unit._make_read_forward_backward_weightings(link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
forward_weightings, backward_weightings = session.run([forward_weightings, backward_weightings])
|
||||||
|
|
||||||
|
np_forward_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
np_backward_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
np_forward_weightings[b, r, :] = np.matmul(np_link_matrix[b, :, :], np_pre_read_weightings[b, r, :])
|
||||||
|
np_backward_weightings[b, r, :] = np.matmul(np.transpose(np_link_matrix[b, :, :]),
|
||||||
|
np_pre_read_weightings[b, r, :])
|
||||||
|
|
||||||
|
assert forward_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
|
||||||
|
assert 0 <= forward_weightings.min() and forward_weightings.max() <= 1 and forward_weightings.sum(
|
||||||
|
axis=1).all() <= 1
|
||||||
|
assert np.allclose(forward_weightings, np_forward_weightings)
|
||||||
|
|
||||||
|
assert backward_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
|
||||||
|
assert 0 <= backward_weightings.min() and backward_weightings.max() <= 1 and backward_weightings.sum(
|
||||||
|
axis=1).all() <= 1
|
||||||
|
assert np.allclose(backward_weightings, np_backward_weightings)
|
||||||
|
|
||||||
|
def test_make_read_weightings(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_forward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
np_backward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
np_read_content_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
np_read_modes = np.reshape(np.repeat([0.2, 0.3, 0.5], config['batch_size'] * config['read_heads']),
|
||||||
|
[config['batch_size'], config['read_heads'], 3])
|
||||||
|
|
||||||
|
forward_weightings = tf.constant(np_forward_weightings, dtype=tf.float32)
|
||||||
|
backward_weightings = tf.constant(np_backward_weightings, dtype=tf.float32)
|
||||||
|
read_content_weightings = tf.constant(np_read_content_weightings, dtype=tf.float32)
|
||||||
|
read_modes = tf.constant(np_read_modes, dtype=tf.float32)
|
||||||
|
|
||||||
|
read_weightings = memory_unit._make_read_weightings(forward_weightings, backward_weightings,
|
||||||
|
read_content_weightings, read_modes)
|
||||||
|
read_weightings = read_weightings.eval()
|
||||||
|
|
||||||
|
np_read_weightings = np_backward_weightings * np.expand_dims(np_read_modes[:, :, 0], 2) + \
|
||||||
|
np_read_content_weightings * np.expand_dims(np_read_modes[:, :, 1], 2) + \
|
||||||
|
np_forward_weightings * np.expand_dims(np_read_modes[:, :, 2], 2)
|
||||||
|
|
||||||
|
assert read_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
|
||||||
|
assert 0 <= read_weightings.min() and read_weightings.max() <= 1 and read_weightings.sum(axis=1).all() <= 1
|
||||||
|
assert np.allclose(read_weightings, np_read_weightings)
|
||||||
|
|
||||||
|
def test_call(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_link_matrix = np.zeros([config['batch_size'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
|
||||||
|
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
read_vectors, states = memory_unit(inputs, pre_states)
|
||||||
|
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
read_vectors, states = session.run([read_vectors, states])
|
||||||
|
|
||||||
|
# test const initialization
|
||||||
|
np_link_matrix_inv_eye = np.ones([config['memory_length'], config['memory_length']]) - np.eye(
|
||||||
|
config['memory_length'])
|
||||||
|
assert np.array_equal(memory_unit.const_link_matrix_inv_eye.eval(), np_link_matrix_inv_eye)
|
||||||
|
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
assert np.array_equal(memory_unit.const_memory_ones.eval(), np_memory_ones)
|
||||||
|
np_batch_range = np.arange(0, config['batch_size'])
|
||||||
|
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
|
||||||
|
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
|
||||||
|
np.expand_dims(np_repeat_memory_length, 0))
|
||||||
|
assert np.array_equal(memory_unit.const_batch_memory_range.eval(), np_batch_memory_range)
|
||||||
|
|
||||||
|
assert read_vectors.shape == (
|
||||||
|
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])
|
194
test/adnc/model/memory_units/test_multi_write_content_based_cell.py
Executable file
194
test/adnc/model/memory_units/test_multi_write_content_based_cell.py
Executable file
@ -0,0 +1,194 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from adnc.model.memory_units.multi_write_content_based_cell import MWContentMemoryUnitCell
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
|
||||||
|
"write_heads": 3, "dnc_norm": True, "bypass_dropout": False},
|
||||||
|
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
|
||||||
|
"write_heads": 2, "dnc_norm": False, "bypass_dropout": False},
|
||||||
|
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
|
||||||
|
"write_heads": 5, "dnc_norm": True, "bypass_dropout": True},
|
||||||
|
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
|
||||||
|
"write_heads": 9, "dnc_norm": False, "bypass_dropout": True}
|
||||||
|
])
|
||||||
|
def memory_config(request):
|
||||||
|
config = request.param
|
||||||
|
return MWContentMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
|
||||||
|
memory_width=config["memory_width"], write_heads=config["write_heads"],
|
||||||
|
read_heads=config["read_heads"], seed=config["seed"],
|
||||||
|
reuse=False, name='test_mu'), config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def session():
|
||||||
|
with tf.Session() as sess:
|
||||||
|
yield sess
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def np_rng():
|
||||||
|
seed = np.random.randint(1, 999)
|
||||||
|
return np.random.RandomState(seed)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMWContentMemoryUnitCell():
|
||||||
|
def test_parameter_amount(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (
|
||||||
|
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 2 * config['read_heads'] + 3 *
|
||||||
|
config["write_heads"])
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.constant(inputs, tf.float32)
|
||||||
|
|
||||||
|
memory_unit._weight_input(tf_input)
|
||||||
|
parameter_amount = memory_unit.parameter_amount
|
||||||
|
|
||||||
|
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
|
||||||
|
|
||||||
|
def test_create_constant_value_tensors(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(
|
||||||
|
batch_size=config['batch_size'], dtype=tf.float32)
|
||||||
|
|
||||||
|
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
assert np.array_equal(memory_ones.eval(), np_memory_ones)
|
||||||
|
|
||||||
|
np_batch_range = np.arange(0, config['batch_size'])
|
||||||
|
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
|
||||||
|
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
|
||||||
|
np.expand_dims(np_repeat_memory_length, 0))
|
||||||
|
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
|
||||||
|
|
||||||
|
def test_zero_state(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
|
||||||
|
|
||||||
|
# test init_tuple
|
||||||
|
init_memory, init_usage_vector, init_write_weighting, init_read_weighting = init_tuple
|
||||||
|
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
|
||||||
|
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert init_write_weighting.eval().shape == (
|
||||||
|
config['batch_size'], config["write_heads"], config['memory_length'])
|
||||||
|
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
|
||||||
|
|
||||||
|
def test_weight_input(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
|
||||||
|
|
||||||
|
weight_inputs = memory_unit._weight_input(tf_input)
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
|
||||||
|
|
||||||
|
total_signal_size = (
|
||||||
|
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 2 * config['read_heads'] + 3 *
|
||||||
|
config["write_heads"])
|
||||||
|
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
|
||||||
|
|
||||||
|
def test_create_control_signals(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (
|
||||||
|
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 2 * config['read_heads'] + 3 *
|
||||||
|
config["write_heads"])
|
||||||
|
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
|
||||||
|
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
|
||||||
|
|
||||||
|
memory_unit.h_B = config['batch_size']
|
||||||
|
control_signals = memory_unit._create_control_signals(weighted_input)
|
||||||
|
control_signals = session.run(control_signals)
|
||||||
|
|
||||||
|
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
|
||||||
|
erase_vector, read_keys, read_strengths = control_signals
|
||||||
|
|
||||||
|
assert alloc_gates.shape == (config['batch_size'], config['write_heads'], 1)
|
||||||
|
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
|
||||||
|
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 0 <= free_gates.min() and free_gates.max() <= 1
|
||||||
|
assert write_gates.shape == (config['batch_size'], config['write_heads'], 1)
|
||||||
|
assert 0 <= write_gates.min() and write_gates.max() <= 1
|
||||||
|
|
||||||
|
assert write_keys.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
|
||||||
|
assert write_strengths.shape == (config['batch_size'], config['write_heads'], 1)
|
||||||
|
assert 1 <= write_strengths.min()
|
||||||
|
assert write_vectors.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
|
||||||
|
assert erase_vector.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
|
||||||
|
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
|
||||||
|
|
||||||
|
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 1 <= read_strengths.min()
|
||||||
|
|
||||||
|
def test_read_memory(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
read_vectors = memory_unit._read_memory(memory, read_weightings)
|
||||||
|
read_vectors = read_vectors.eval()
|
||||||
|
|
||||||
|
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
|
||||||
|
|
||||||
|
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
|
||||||
|
|
||||||
|
def test_call(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
read_vectors, states = memory_unit(inputs, pre_states)
|
||||||
|
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
read_vectors, states = session.run([read_vectors, states])
|
||||||
|
|
||||||
|
assert read_vectors.shape == (
|
||||||
|
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])
|
521
test/adnc/model/memory_units/test_multi_write_dnc_cell.py
Executable file
521
test/adnc/model/memory_units/test_multi_write_dnc_cell.py
Executable file
@ -0,0 +1,521 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(
|
||||||
|
params=[{"seed": 123, "input_size": 13, "batch_size": 3, "memory_length": 4, "memory_width": 4, "read_heads": 3,
|
||||||
|
"write_heads": 3, "dnc_norm": True, "bypass_dropout": False},
|
||||||
|
{"seed": 124, "input_size": 11, "batch_size": 3, "memory_length": 256, "memory_width": 23, "read_heads": 2,
|
||||||
|
"write_heads": 2, "dnc_norm": False, "bypass_dropout": False},
|
||||||
|
{"seed": 125, "input_size": 5, "batch_size": 3, "memory_length": 4, "memory_width": 11, "read_heads": 8,
|
||||||
|
"write_heads": 5, "dnc_norm": True, "bypass_dropout": True},
|
||||||
|
{"seed": 126, "input_size": 2, "batch_size": 3, "memory_length": 56, "memory_width": 9, "read_heads": 11,
|
||||||
|
"write_heads": 9, "dnc_norm": False, "bypass_dropout": True}
|
||||||
|
])
|
||||||
|
def memory_config(request):
|
||||||
|
config = request.param
|
||||||
|
return MWDNCMemoryUnitCell(input_size=config['input_size'], memory_length=config["memory_length"],
|
||||||
|
memory_width=config["memory_width"], write_heads=config["write_heads"],
|
||||||
|
read_heads=config["read_heads"], seed=config["seed"],
|
||||||
|
reuse=False, name='test_mu'), config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def session():
|
||||||
|
with tf.Session() as sess:
|
||||||
|
yield sess
|
||||||
|
tf.reset_default_graph()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def np_rng():
|
||||||
|
seed = np.random.randint(1, 999)
|
||||||
|
return np.random.RandomState(seed)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMWDNCMemoryUnit():
|
||||||
|
def test_init(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
assert isinstance(memory_unit, object)
|
||||||
|
assert isinstance(memory_unit.rng, np.random.RandomState)
|
||||||
|
|
||||||
|
assert memory_unit.h_N == config["memory_length"]
|
||||||
|
assert memory_unit.h_W == config["memory_width"]
|
||||||
|
assert memory_unit.h_RH == config["read_heads"]
|
||||||
|
assert memory_unit.h_WH == config["write_heads"]
|
||||||
|
|
||||||
|
def test_parameter_amount(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (
|
||||||
|
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 3 * config['read_heads'] + 3 *
|
||||||
|
config["write_heads"] + 2 * config['read_heads'] * config["write_heads"])
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.constant(inputs, tf.float32)
|
||||||
|
|
||||||
|
memory_unit._weight_input(tf_input)
|
||||||
|
parameter_amount = memory_unit.parameter_amount
|
||||||
|
|
||||||
|
assert parameter_amount == (config['input_size'] + 1) * total_signal_size
|
||||||
|
|
||||||
|
def test_create_constant_value_tensors(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
link_matrix_inv_eye, memory_ones, batch_memory_range = memory_unit._create_constant_value_tensors(
|
||||||
|
batch_size=config['batch_size'], dtype=tf.float32)
|
||||||
|
|
||||||
|
np_link_matrix_inv_eye = np.ones([config['memory_length'], config['memory_length']]) - np.eye(
|
||||||
|
config['memory_length'])
|
||||||
|
np_link_matrix_inv_eye = np.stack([np_link_matrix_inv_eye, ] * config["write_heads"], axis=0)
|
||||||
|
np_link_matrix_inv_eye = np.stack([np_link_matrix_inv_eye, ] * config['batch_size'], axis=0)
|
||||||
|
|
||||||
|
assert np.array_equal(link_matrix_inv_eye.eval(), np_link_matrix_inv_eye)
|
||||||
|
|
||||||
|
np_memory_ones = np.ones([config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
assert np.array_equal(memory_ones.eval(), np_memory_ones)
|
||||||
|
|
||||||
|
np_batch_range = np.arange(0, config['batch_size'])
|
||||||
|
np_repeat_memory_length = np.repeat(config['memory_length'], config['memory_length'])
|
||||||
|
np_batch_memory_range = np.matmul(np.expand_dims(np_batch_range, axis=-1),
|
||||||
|
np.expand_dims(np_repeat_memory_length, 0))
|
||||||
|
assert np.array_equal(batch_memory_range.eval(), np_batch_memory_range)
|
||||||
|
|
||||||
|
def test_zero_state(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
init_tuple = memory_unit.zero_state(batch_size=config['batch_size'], dtype=tf.float32)
|
||||||
|
|
||||||
|
# test init_tuple
|
||||||
|
init_memory, init_usage_vector, init_write_weighting, init_precedence_weightings, init_link_mat, init_read_weighting = init_tuple
|
||||||
|
assert init_memory.eval().shape == (config['batch_size'], config['memory_length'], config['memory_width'])
|
||||||
|
assert init_usage_vector.eval().shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert init_write_weighting.eval().shape == (
|
||||||
|
config['batch_size'], config["write_heads"], config['memory_length'])
|
||||||
|
assert init_precedence_weightings.eval().shape == (
|
||||||
|
config['batch_size'], config['write_heads'], config['memory_length'])
|
||||||
|
assert init_link_mat.eval().shape == (
|
||||||
|
config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length'])
|
||||||
|
assert init_read_weighting.eval().shape == (config['batch_size'], config["read_heads"], config['memory_length'])
|
||||||
|
|
||||||
|
def test_weight_input(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
inputs = np.ones([config['batch_size'], config['input_size']])
|
||||||
|
tf_input = tf.placeholder(tf.float32, [config['batch_size'], config['input_size']], name='x')
|
||||||
|
|
||||||
|
weight_inputs = memory_unit._weight_input(tf_input)
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
np_weight_inputs = weight_inputs.eval(session=session, feed_dict={tf_input: inputs})
|
||||||
|
|
||||||
|
total_signal_size = (
|
||||||
|
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 3 * config['read_heads'] + 3 *
|
||||||
|
config["write_heads"] + 2 * config['read_heads'] * config["write_heads"])
|
||||||
|
assert np_weight_inputs.shape == (config['batch_size'], total_signal_size)
|
||||||
|
|
||||||
|
def test_create_control_signals(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
total_signal_size = (
|
||||||
|
config['memory_width'] * (3 * config["write_heads"] + config["read_heads"]) + 3 * config['read_heads'] + 3 *
|
||||||
|
config["write_heads"] + 2 * config['read_heads'] * config["write_heads"])
|
||||||
|
np_weighted_input = np.array([np.arange(1, 1 + total_signal_size)] * config['batch_size'])
|
||||||
|
weighted_input = tf.constant(np_weighted_input, dtype=tf.float32)
|
||||||
|
|
||||||
|
memory_unit.h_B = config['batch_size']
|
||||||
|
control_signals = memory_unit._create_control_signals(weighted_input)
|
||||||
|
control_signals = session.run(control_signals)
|
||||||
|
|
||||||
|
alloc_gates, free_gates, write_gates, write_keys, write_strengths, write_vectors, \
|
||||||
|
erase_vector, read_keys, read_strengths, read_modes = control_signals
|
||||||
|
|
||||||
|
assert alloc_gates.shape == (config['batch_size'], config['write_heads'], 1)
|
||||||
|
assert 0 <= alloc_gates.min() and alloc_gates.max() <= 1
|
||||||
|
assert free_gates.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 0 <= free_gates.min() and free_gates.max() <= 1
|
||||||
|
assert write_gates.shape == (config['batch_size'], config['write_heads'], 1)
|
||||||
|
assert 0 <= write_gates.min() and write_gates.max() <= 1
|
||||||
|
|
||||||
|
assert write_keys.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
|
||||||
|
assert write_strengths.shape == (config['batch_size'], config['write_heads'], 1)
|
||||||
|
assert 1 <= write_strengths.min()
|
||||||
|
assert write_vectors.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
|
||||||
|
assert erase_vector.shape == (config['batch_size'], config['write_heads'], config['memory_width'])
|
||||||
|
assert 0 <= erase_vector.min() and erase_vector.max() <= 1
|
||||||
|
|
||||||
|
assert read_keys.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert read_strengths.shape == (config['batch_size'], config['read_heads'], 1)
|
||||||
|
assert 1 <= read_strengths.min()
|
||||||
|
assert read_modes.shape == (
|
||||||
|
config['batch_size'], config['read_heads'], 1 + 2 * config['write_heads']) # 3 read modes
|
||||||
|
assert 0 <= read_modes.min() and read_modes.max() <= 1 and read_modes.sum(axis=2).all() == 1
|
||||||
|
|
||||||
|
def test_update_alloc_weightings_and_usage_vectors(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_link_matrix = np.zeros(
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
np_pre_write_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
prw_rand = np.arange(0, config['memory_length']) / config['memory_length']
|
||||||
|
np_pre_read_weightings = np.stack([prw_rand, ] * config['read_heads'], 0)
|
||||||
|
np_pre_read_weightings = np.stack([np_pre_read_weightings, ] * config['batch_size'], 0)
|
||||||
|
np_pre_usage_vectors = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_free_gates = np.ones([config['batch_size'], config['read_heads'], 1]) * 0.5
|
||||||
|
np_write_gates = np.ones([config['batch_size'], config['write_heads'], 1]) * 0.5
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_write_weightings = tf.constant(np_pre_write_weightings, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
pre_usage_vectors = tf.constant(np_pre_usage_vectors, dtype=tf.float32)
|
||||||
|
free_gates = tf.constant(np_free_gates, dtype=tf.float32)
|
||||||
|
write_gates = tf.constant(np_write_gates, dtype=tf.float32)
|
||||||
|
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vectors, pre_write_weightings, pre_precedence_weighting, pre_link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory_unit(inputs, pre_states)
|
||||||
|
alloc_weightings, usage_vectors = memory_unit._update_alloc_and_usage_vectors(pre_write_weightings,
|
||||||
|
pre_read_weightings,
|
||||||
|
pre_usage_vectors, free_gates,
|
||||||
|
write_gates)
|
||||||
|
alloc_weightings, usage_vectors = session.run([alloc_weightings, usage_vectors])
|
||||||
|
|
||||||
|
np_pre_write_weighting = 1 - np.prod(1 - np_pre_write_weightings, axis=1, keepdims=False)
|
||||||
|
np_usage_vector = np_pre_usage_vectors + np_pre_write_weighting - np_pre_usage_vectors * np_pre_write_weighting
|
||||||
|
|
||||||
|
np_retention_vector = np.prod(1 - np_free_gates * np_pre_read_weightings, axis=1, keepdims=False)
|
||||||
|
np_usage_vector = np_usage_vector * np_retention_vector
|
||||||
|
|
||||||
|
assert usage_vectors.shape == (config['batch_size'], config['memory_length'])
|
||||||
|
assert usage_vectors.min() >= 0 and usage_vectors.max() <= 1
|
||||||
|
assert np.allclose(usage_vectors, np_usage_vector, atol=1e-06)
|
||||||
|
|
||||||
|
np_alloc_weightings = np.zeros([config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for w in range(config['write_heads']):
|
||||||
|
|
||||||
|
free_list = np.argsort(np_usage_vector, axis=1)
|
||||||
|
|
||||||
|
for j in range(config['memory_length']):
|
||||||
|
np_alloc_weightings[b, w, free_list[b, j]] = (1 - np_usage_vector[b, free_list[b, j]]) * np.prod(
|
||||||
|
[np_usage_vector[b, free_list[b, i]] for i in range(j)])
|
||||||
|
|
||||||
|
np_usage_vector[b, :] += (
|
||||||
|
(1 - np_usage_vector[b, :]) * np_write_gates[b, w, :] * np_alloc_weightings[b, w, :])
|
||||||
|
|
||||||
|
assert alloc_weightings.shape == (config['batch_size'], config['write_heads'], config['memory_length'])
|
||||||
|
assert np.allclose(alloc_weightings, np_alloc_weightings, atol=1e-06)
|
||||||
|
|
||||||
|
def test_calculate_content_weightings(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_memory = np_rng.uniform(0, 1, (config['batch_size'], config['memory_length'], config['memory_width']))
|
||||||
|
np_keys = np_rng.normal(0, 2, (config['batch_size'], config['read_heads'], config['memory_width']))
|
||||||
|
np_strengths = np_rng.uniform(1, 10, (config['batch_size'], config['read_heads'], 1))
|
||||||
|
|
||||||
|
memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
keys = tf.constant(np_keys, dtype=tf.float32)
|
||||||
|
strengths = tf.constant(np_strengths, dtype=tf.float32)
|
||||||
|
|
||||||
|
content_weightings = memory_unit._calculate_content_weightings(memory, keys, strengths)
|
||||||
|
weightings = content_weightings.eval()
|
||||||
|
|
||||||
|
np_similarity = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
for l in range(config['memory_length']):
|
||||||
|
np_similarity[b, r, l] = np.dot(np_memory[b, l, :], np_keys[b, r, :]) / (
|
||||||
|
np.sqrt(np.dot(np_memory[b, l, :], np_memory[b, l, :])) * np.sqrt(
|
||||||
|
np.dot(np_keys[b, r, :], np_keys[b, r, :])))
|
||||||
|
np_weightings = np.empty([config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
def _weighted_softmax(x, s):
|
||||||
|
e_x = np.exp(x * s)
|
||||||
|
return e_x / e_x.sum(axis=1, keepdims=True)
|
||||||
|
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
np_weightings[:, r, :] = _weighted_softmax(np_similarity[:, r, :], np_strengths[:, r])
|
||||||
|
|
||||||
|
assert weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
|
||||||
|
assert 0 <= weightings.min() and weightings.max() <= 1 and weightings.sum(axis=2).all() <= 1
|
||||||
|
assert np.allclose(weightings, np_weightings)
|
||||||
|
|
||||||
|
def test_update_write_weightings(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_alloc_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
np_write_content_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
np_write_gate = np.ones([config['batch_size'], config['write_heads'], 1]) * 0.5
|
||||||
|
np_alloc_gate = np.ones([config['batch_size'], config['write_heads'], 1]) * 0.5
|
||||||
|
|
||||||
|
alloc_weightings = tf.constant(np_alloc_weightings, dtype=tf.float32)
|
||||||
|
write_content_weightings = tf.constant(np_write_content_weighting, dtype=tf.float32)
|
||||||
|
write_gates = tf.constant(np_write_gate, dtype=tf.float32)
|
||||||
|
alloc_gates = tf.constant(np_alloc_gate, dtype=tf.float32)
|
||||||
|
|
||||||
|
write_weightings = memory_unit._update_write_weightings(alloc_weightings, write_content_weightings, write_gates,
|
||||||
|
alloc_gates)
|
||||||
|
write_weightings = write_weightings.eval()
|
||||||
|
|
||||||
|
np_write_weightings = np_write_gate * (
|
||||||
|
np_alloc_gate * np_alloc_weightings + (1 - np_alloc_gate) * np_write_content_weighting)
|
||||||
|
|
||||||
|
assert write_weightings.shape == (config['batch_size'], config['write_heads'], config['memory_length'])
|
||||||
|
assert 0 <= write_weightings.min() and write_weightings.max() <= 1 and write_weightings.sum(axis=2).all() <= 1
|
||||||
|
assert np.allclose(write_weightings, np_write_weightings)
|
||||||
|
|
||||||
|
def test_update_memory(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
np_write_vector = np_rng.normal(0, 2, [config['batch_size'], config['write_heads'], config['memory_width']])
|
||||||
|
np_erase_vector = np_rng.uniform(0, 1, [config['batch_size'], config['write_heads'], config['memory_width']])
|
||||||
|
|
||||||
|
pre_memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
|
||||||
|
write_vector = tf.constant(np_write_vector, dtype=tf.float32)
|
||||||
|
erase_vector = tf.constant(np_erase_vector, dtype=tf.float32)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory = memory_unit._update_memory(pre_memory, write_weighting, write_vector, erase_vector)
|
||||||
|
memory = memory.eval()
|
||||||
|
|
||||||
|
np_erase_memory = (1 - np.expand_dims(np_write_weighting, 3) * np.expand_dims(np_erase_vector, 2))
|
||||||
|
np_erase_memory = np.prod(np_erase_memory, axis=1, keepdims=False)
|
||||||
|
|
||||||
|
np_add_memory = np.matmul(np.transpose(np_write_weighting, (0, 2, 1)), np_write_vector)
|
||||||
|
|
||||||
|
np_memory = np_memory * np_erase_memory + np_add_memory
|
||||||
|
|
||||||
|
assert memory.shape == (config['batch_size'], config['memory_length'], config['memory_width'])
|
||||||
|
assert np.allclose(memory, np_memory, atol=1e-06)
|
||||||
|
|
||||||
|
def test_update_link_matrix(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
np_pre_link_matrix = np.zeros(
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
|
||||||
|
np_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
|
||||||
|
write_weighting = tf.constant(np_write_weighting, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
memory_unit(inputs, pre_states)
|
||||||
|
link_matrix, precedence_weighting = memory_unit._update_link_matrix(pre_link_matrix, write_weighting,
|
||||||
|
pre_precedence_weighting)
|
||||||
|
link_matrix, precedence_weighting = session.run([link_matrix, precedence_weighting])
|
||||||
|
|
||||||
|
np_precedence_weighting = (1 - np.sum(np_write_weighting, axis=2,
|
||||||
|
keepdims=True)) * np_pre_precedence_weighting + np_write_weighting
|
||||||
|
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for w in range(config['write_heads']):
|
||||||
|
for i in range(config['memory_length']):
|
||||||
|
for j in range(config['memory_length']):
|
||||||
|
if i == j:
|
||||||
|
np_pre_link_matrix[b, w, i, j] = 0
|
||||||
|
else:
|
||||||
|
np_pre_link_matrix[b, w, i, j] = (1 - np_write_weighting[b, w, i] - np_write_weighting[
|
||||||
|
b, w, j]) * np_pre_link_matrix[b, w, i, j] + \
|
||||||
|
np_write_weighting[b, w, i] * np_pre_precedence_weighting[
|
||||||
|
b, w, j]
|
||||||
|
np_link_matrix = np_pre_link_matrix
|
||||||
|
|
||||||
|
assert precedence_weighting.shape == (config['batch_size'], config['write_heads'], config['memory_length'])
|
||||||
|
assert 0 <= precedence_weighting.min() and precedence_weighting.max() <= 1 and precedence_weighting.sum(
|
||||||
|
axis=1).all() <= 1
|
||||||
|
assert np.allclose(precedence_weighting, np_precedence_weighting)
|
||||||
|
|
||||||
|
assert link_matrix.shape == (
|
||||||
|
config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length'])
|
||||||
|
assert np.allclose(link_matrix, np_link_matrix)
|
||||||
|
|
||||||
|
def test_make_read_forward_backward_weightings(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_link_matrix = np.zeros(
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
link_matrix = tf.constant(np_link_matrix, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
forward_weightings, backward_weightings = memory_unit._make_read_forward_backward_weightings(link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
forward_weightings, backward_weightings = session.run([forward_weightings, backward_weightings])
|
||||||
|
|
||||||
|
np_forward_weightings = np.empty(
|
||||||
|
[config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length']])
|
||||||
|
np_backward_weightings = np.empty(
|
||||||
|
[config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
for w in range(config['write_heads']):
|
||||||
|
np_forward_weightings[b, r, w, :] = np.matmul(np_pre_read_weightings[b, r, :],
|
||||||
|
np_link_matrix[b, w, :, :])
|
||||||
|
np_backward_weightings[b, r, w, :] = np.matmul(np_pre_read_weightings[b, r, :],
|
||||||
|
np.transpose(np_link_matrix[b, w, :, :]))
|
||||||
|
|
||||||
|
assert forward_weightings.shape == (
|
||||||
|
config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length'])
|
||||||
|
assert 0 <= forward_weightings.min() and forward_weightings.max() <= 1 and forward_weightings.sum(
|
||||||
|
axis=3).all() <= 1
|
||||||
|
assert np.allclose(forward_weightings, np_forward_weightings)
|
||||||
|
|
||||||
|
assert backward_weightings.shape == (
|
||||||
|
config['batch_size'], config['read_heads'], config['write_heads'], config['memory_length'])
|
||||||
|
assert 0 <= backward_weightings.min() and backward_weightings.max() <= 1 and backward_weightings.sum(
|
||||||
|
axis=3).all() <= 1
|
||||||
|
assert np.allclose(backward_weightings, np_backward_weightings)
|
||||||
|
|
||||||
|
def test_make_read_weightings(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_forward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['write_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
np_backward_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['write_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
np_read_content_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
np_read_modes = np.reshape(
|
||||||
|
np.repeat([0.1, ], config['batch_size'] * config['read_heads'] * (2 * config['write_heads'] + 1)),
|
||||||
|
[config['batch_size'], config['read_heads'], 1 + 2 * config['write_heads']])
|
||||||
|
|
||||||
|
forward_weightings = tf.constant(np_forward_weightings, dtype=tf.float32)
|
||||||
|
backward_weightings = tf.constant(np_backward_weightings, dtype=tf.float32)
|
||||||
|
read_content_weightings = tf.constant(np_read_content_weightings, dtype=tf.float32)
|
||||||
|
read_modes = tf.constant(np_read_modes, dtype=tf.float32)
|
||||||
|
|
||||||
|
read_weightings = memory_unit._make_read_weightings(forward_weightings, backward_weightings,
|
||||||
|
read_content_weightings, read_modes)
|
||||||
|
read_weightings = read_weightings.eval()
|
||||||
|
|
||||||
|
np_read_weightings = np.sum(
|
||||||
|
np_backward_weightings * np.expand_dims(np_read_modes[:, :, : config['write_heads']], 3), axis=2) + \
|
||||||
|
np_read_content_weightings * np.expand_dims(np_read_modes[:, :, config['write_heads']],
|
||||||
|
2) + \
|
||||||
|
np.sum(
|
||||||
|
np_forward_weightings * np.expand_dims(np_read_modes[:, :, config['write_heads'] + 1:],
|
||||||
|
3), axis=2)
|
||||||
|
|
||||||
|
assert read_weightings.shape == (config['batch_size'], config['read_heads'], config['memory_length'])
|
||||||
|
assert 0 <= read_weightings.min() and read_weightings.max() <= 1 and read_weightings.sum(axis=1).all() <= 1
|
||||||
|
assert np.allclose(read_weightings, np_read_weightings)
|
||||||
|
|
||||||
|
def test_read_memory(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
memory = tf.constant(np_memory, dtype=tf.float32)
|
||||||
|
read_weightings = tf.constant(np_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
read_vectors = memory_unit._read_memory(memory, read_weightings)
|
||||||
|
read_vectors = read_vectors.eval()
|
||||||
|
|
||||||
|
np_read_vectors = np.empty([config['batch_size'], config['read_heads'], config['memory_width']])
|
||||||
|
for b in range(config['batch_size']):
|
||||||
|
for r in range(config['read_heads']):
|
||||||
|
np_read_vectors[b, r, :] = np.matmul(np.expand_dims(np_read_weightings[b, r, :], 0), np_memory[b, :, :])
|
||||||
|
|
||||||
|
assert read_vectors.shape == (config['batch_size'], config['read_heads'], config['memory_width'])
|
||||||
|
assert np.allclose(read_vectors, np_read_vectors, atol=1e-06)
|
||||||
|
|
||||||
|
def test_call(self, memory_config, session, np_rng):
|
||||||
|
memory_unit, config = memory_config
|
||||||
|
|
||||||
|
np_inputs = np_rng.normal(0, 1, [config['batch_size'], config['input_size']])
|
||||||
|
np_pre_memory = np_rng.normal(0, 1, [config['batch_size'], config['memory_length'], config['memory_width']])
|
||||||
|
np_pre_usage_vector = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['memory_length']])
|
||||||
|
np_pre_write_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length']])
|
||||||
|
np_pre_precedence_weighting = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['write_heads'],
|
||||||
|
config['memory_length']])
|
||||||
|
np_pre_link_matrix = np.zeros(
|
||||||
|
[config['batch_size'], config['write_heads'], config['memory_length'], config['memory_length']])
|
||||||
|
np_pre_read_weightings = np_rng.uniform(0, 1 / config['memory_length'],
|
||||||
|
[config['batch_size'], config['read_heads'], config['memory_length']])
|
||||||
|
|
||||||
|
inputs = tf.constant(np_inputs, dtype=tf.float32)
|
||||||
|
pre_memory = tf.constant(np_pre_memory, dtype=tf.float32)
|
||||||
|
pre_usage_vector = tf.constant(np_pre_usage_vector, dtype=tf.float32)
|
||||||
|
pre_write_weighting = tf.constant(np_pre_write_weighting, dtype=tf.float32)
|
||||||
|
pre_precedence_weighting = tf.constant(np_pre_precedence_weighting, dtype=tf.float32)
|
||||||
|
pre_link_matrix = tf.constant(np_pre_link_matrix, dtype=tf.float32)
|
||||||
|
pre_read_weightings = tf.constant(np_pre_read_weightings, dtype=tf.float32)
|
||||||
|
|
||||||
|
pre_states = (pre_memory, pre_usage_vector, pre_write_weighting, pre_precedence_weighting, pre_link_matrix,
|
||||||
|
pre_read_weightings)
|
||||||
|
|
||||||
|
memory_unit.zero_state(config['batch_size'])
|
||||||
|
read_vectors, states = memory_unit(inputs, pre_states)
|
||||||
|
|
||||||
|
session.run(tf.global_variables_initializer())
|
||||||
|
read_vectors, states = session.run([read_vectors, states])
|
||||||
|
|
||||||
|
assert read_vectors.shape == (
|
||||||
|
config['batch_size'], config['memory_width'] * config['read_heads'] + config['input_size'])
|
Loading…
Reference in New Issue
Block a user