From 48081f315bfa11c42b8a1f53a422546c7b09938a Mon Sep 17 00:00:00 2001
From: Joerg Franke
Date: Mon, 25 Jun 2018 01:38:45 +0200
Subject: [PATCH] add mann and test

---
 adnc/model/__init__.py       |  14 ++
 adnc/model/mann.py           | 288 +++++++++++++++++++++++++++++++++++
 test/adnc/model/__init__.py  |   0
 test/adnc/model/test_mann.py | 140 +++++++++++++++++
 4 files changed, 442 insertions(+)
 create mode 100644 adnc/model/__init__.py
 create mode 100755 adnc/model/mann.py
 create mode 100644 test/adnc/model/__init__.py
 create mode 100755 test/adnc/model/test_mann.py

diff --git a/adnc/model/__init__.py b/adnc/model/__init__.py
new file mode 100644
index 0000000..f5514c0
--- /dev/null
+++ b/adnc/model/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 Jörg Franke
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
\ No newline at end of file
diff --git a/adnc/model/mann.py b/adnc/model/mann.py
new file mode 100755
index 0000000..1db6c63
--- /dev/null
+++ b/adnc/model/mann.py
@@ -0,0 +1,288 @@
+# Copyright 2018 Jörg Franke
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import tensorflow as tf
+from tensorflow.contrib.rnn import MultiRNNCell
+from tensorflow.python.ops import variable_scope as vs
+
+from adnc.model.controller_units.controller import get_rnn_cell_list
+from adnc.model.memory_units.memory_unit import get_memory_unit
+
+from adnc.model.utils import HolisticMultiRNNCell
+from adnc.model.utils import WordEmbedding
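+
+# A MANN couples a recurrent controller with an external memory unit. This
+# module builds the full TensorFlow graph for it: the (optionally GloVe
+# embedded) input, a uni- or bidirectional controller/memory stack, a
+# trainable output layer and a masked loss.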
+
+
+class MANN():
+
+    def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32, new_output_structure=False):
+
+        self.seed = config["seed"]
+        self.rng = np.random.RandomState(seed=self.seed)
+        self.dtype = dtype
+        self.analyse = analyse
+
+        self.input_size = config["input_size"]
+        self.output_size = config["output_size"]
+        self.batch_size = config["batch_size"]
+
+        self.input_embedding = config["input_embedding"]
+        self.architecture = config['architecture']
+        self.controller_config = config["controller_config"]
+        self.memory_unit_config = config["memory_unit_config"]
+        self.output_function = config["output_function"]
+        self.output_mask = config["output_mask"]
+        self.loss_function = config['loss_function']
+
+        self.reuse = reuse
+        self.name = name
+
+        self.mask = tf.placeholder(self.dtype, [None, self.batch_size], name='mask')
+        self.target = tf.placeholder(self.dtype, [None, self.batch_size, self.output_size], name='y')
+
+        if self.input_embedding:
+            word_idx_dict = self.input_embedding['word_idx_dict']
+            embedding_size = self.input_embedding['embedding_size']
+            tmp_dir = self.input_embedding['tmp_dir']
+            glove = WordEmbedding(embedding_size, word_idx_dict=word_idx_dict, initialization='glove', tmp_dir=tmp_dir)
+
+            self._data = tf.placeholder(tf.int64, [None, self.batch_size], name='x')
+            self.data = glove.embed(self._data)
+        else:
+            self.data = tf.placeholder(tf.float32, [None, self.batch_size, self.input_size], name='x')
+
+        if self.architecture in ['uni', 'unidirectional']:
+            unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse)
+        elif self.architecture in ['bi', 'bidirectional']:
+            unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse)
+        else:
+            raise UserWarning("Unknown architecture, use unidirectional or bidirectional")
+
+        if self.analyse:
+            with tf.device('/cpu:0'):
+                if self.architecture in ['uni', 'unidirectional']:
+                    analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True)
+                    analyse_outputs, analyse_signals = analyse_outputs
+                    self.analyse = (analyse_outputs, analyse_signals, analyse_states)
+                elif self.architecture in ['bi', 'bidirectional']:
+                    analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True)
+                    analyse_outputs, analyse_signals = analyse_outputs
+                    self.analyse = (analyse_outputs, analyse_signals, analyse_states)
+
+        self.unweighted_outputs = unweighted_outputs
+        self.prediction, self.outputs = self._output_layer(unweighted_outputs)
+        self.loss = self.get_loss(self.prediction)
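+
+    # The output layer flattens the time and batch dimensions, applies a
+    # single affine projection plus the configured output function, and
+    # reshapes the result back to [time, batch, output_size].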
+    def _output_layer(self, outputs):
+
+        with tf.variable_scope("output_layer"):
+            output_size = outputs.get_shape()[-1].value
+
+            weights_concat = tf.get_variable("weights_concat", (output_size, self.output_size),
+                                             initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
+                                             collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
+            bias_merge = tf.get_variable("bias_merge", (self.output_size,), initializer=tf.constant_initializer(0.),
+                                         collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
+
+            output_flat = tf.reshape(outputs, [-1, output_size])
+            output_flat = tf.matmul(output_flat, weights_concat) + bias_merge
+
+            if self.output_function == 'softmax':
+                predictions_flat = tf.nn.softmax(output_flat)
+            elif self.output_function == 'tanh':
+                predictions_flat = tf.tanh(output_flat)
+            elif self.output_function == 'linear':
+                predictions_flat = output_flat
+            else:
+                raise UserWarning("Unknown output function, use softmax, tanh or linear")
+
+            predictions = tf.reshape(predictions_flat, [-1, self.batch_size, self.output_size])
+            weighted_outputs = tf.reshape(output_flat, [-1, self.batch_size, self.output_size])
+
+        return predictions, weighted_outputs
+
+    @property
+    def feed(self):
+        return self.data, self.target, self.mask
+
+    @property
+    def controller_trainable_variables(self):
+        return tf.get_collection('recurrent_unit')
+
+    @property
+    def memory_unit_trainable_variables(self):
+        return tf.get_collection('memory_unit')
+
+    @property
+    def mann_trainable_variables(self):
+        return tf.get_collection('mann')
+
+    @property
+    def trainable_variables(self):
+        return tf.trainable_variables()
+
+    @property
+    def controller_parameter_amount(self):
+        return self.count_parameter_amount(self.controller_trainable_variables)
+
+    @property
+    def memory_unit_parameter_amount(self):
+        return self.count_parameter_amount(self.memory_unit_trainable_variables)
+
+    @property
+    def mann_parameter_amount(self):
+        return self.count_parameter_amount(self.mann_trainable_variables)
+
+    @staticmethod
+    def count_parameter_amount(var_list):
+        parameters = 0
+        for variable in var_list:
+            shape = variable.get_shape()
+            variable_parameters = 1
+            for dim in shape:
+                variable_parameters *= dim.value
+            parameters += variable_parameters
+        return parameters
+
+    @property
+    def parameter_amount(self):
+        var_list = tf.trainable_variables()
+        parameters = 0
+        for variable in var_list:
+            shape = variable.get_shape()
+            variable_parameters = 1
+            for dim in shape:
+                variable_parameters *= dim.value
+            parameters += variable_parameters
+        return parameters
+
+    def get_loss(self, prediction):
+        if self.loss_function == 'cross_entropy':
+            if self.output_mask:
+                cost = tf.reduce_sum(-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0))
+                                     - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction, 1e-12, 10.0)), axis=2)
+                cost *= self.mask
+                loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
+            else:
+                loss = tf.reduce_mean(-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0))
+                                      - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction, 1e-12, 10.0)))
+
+        elif self.loss_function == 'mse':
+            clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0)
+            mse = tf.square(self.target - clipped_prediction)
+            mse = tf.reduce_mean(mse, axis=2)
+
+            if self.output_mask:
+                cost = mse * self.mask
+                loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
+            else:
+                loss = tf.reduce_mean(mse)
+        else:
+            raise UserWarning("Unknown loss function, use cross_entropy or mse")
+        return loss
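+
+    # Unidirectional rollout: the controller cells and the memory unit are
+    # stacked into a single RNN cell and unrolled with tf.scan; the complete
+    # output of the previous time step is concatenated to the current input,
+    # so the controller also conditions on the last memory read.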
+    def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
+
+        with tf.variable_scope("controller"):
+            controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed, dtype=self.dtype)
+
+            if controller_config['connect'] == 'sparse':
+                memory_input_size = controller_list[-1].output_size
+                mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
+                cell = MultiRNNCell(controller_list + [mu_cell])
+            else:
+                controller_cell = HolisticMultiRNNCell(controller_list)
+                memory_input_size = controller_cell.output_size
+                mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
+                cell = MultiRNNCell([controller_cell, mu_cell])
+
+        batch_size = inputs.get_shape()[1].value
+        cell_init_states = cell.zero_state(batch_size, dtype=self.dtype)
+        output_init = tf.zeros([batch_size, cell.output_size], dtype=self.dtype)
+
+        if analyse:
+            output_init = (output_init, mu_cell.analyse_state(batch_size, dtype=self.dtype))
+
+        init_states = (output_init, cell_init_states)
+
+        def step(pre_states, inputs):
+            pre_rnn_output, pre_rnn_states = pre_states
+
+            if analyse:
+                pre_rnn_output = pre_rnn_output[0]
+
+            controller_inputs = tf.concat([inputs, pre_rnn_output], axis=-1)
+            rnn_output, rnn_states = cell(controller_inputs, pre_rnn_states)
+            return (rnn_output, rnn_states)
+
+        outputs, states = tf.scan(step, inputs, initializer=init_states, parallel_iterations=32)
+
+        return outputs, states
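+
+    # Bidirectional rollout: the backward controller first runs over the
+    # time-reversed sequence with tf.nn.dynamic_rnn; the forward controller
+    # and the memory unit are then unrolled together with tf.scan, feeding
+    # the memory output of the previous step back into the forward input.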
+    def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
+
+        with tf.variable_scope("controller"):
+            list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype)
+            list_bw = get_rnn_cell_list(controller_config, name='con_bw', reuse=reuse, seed=self.seed, dtype=self.dtype)
+            if controller_config['connect'] == 'sparse':
+                cell_fw = MultiRNNCell(list_fw)
+                cell_bw = MultiRNNCell(list_bw)
+            else:
+                cell_fw = HolisticMultiRNNCell(list_fw)
+                cell_bw = HolisticMultiRNNCell(list_bw)
+
+        memory_input_size = cell_fw.output_size + cell_bw.output_size
+        cell_mu = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
+
+        with vs.variable_scope("bw") as bw_scope:
+            inputs_reverse = tf.reverse(inputs, axis=[0])
+            output_bw, output_state_bw = tf.nn.dynamic_rnn(cell=cell_bw, inputs=inputs_reverse, dtype=self.dtype,
+                                                           parallel_iterations=32, time_major=True, scope=bw_scope)
+            output_bw = tf.reverse(output_bw, axis=[0])
+
+        batch_size = inputs.get_shape()[1].value
+        cell_fw_init_states = cell_fw.zero_state(batch_size, dtype=self.dtype)
+        cell_mu_init_states = cell_mu.zero_state(batch_size, dtype=self.dtype)
+        output_init = tf.zeros([batch_size, cell_mu.output_size], dtype=self.dtype)
+
+        if analyse:
+            output_init = (output_init, cell_mu.analyse_state(batch_size, dtype=self.dtype))
+
+        init_states = (output_init, cell_fw_init_states, cell_mu_init_states)
+        coupled_inputs = (inputs, output_bw)
+
+        with vs.variable_scope("fw") as fw_scope:
+
+            def step(pre_states, coupled_inputs):
+                inputs, output_bw = coupled_inputs
+                pre_outputs, pre_states_fw, pre_states_mu = pre_states
+
+                if analyse:
+                    pre_outputs = pre_outputs[0]
+
+                controller_inputs = tf.concat([inputs, pre_outputs], axis=-1)
+                output_fw, states_fw = cell_fw(controller_inputs, pre_states_fw)
+
+                mu_inputs = tf.concat([output_fw, output_bw], axis=-1)
+                output_mu, states_mu = cell_mu(mu_inputs, pre_states_mu)
+
+                return (output_mu, states_fw, states_mu)
+
+            outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states, parallel_iterations=32)
+
+        states = states_fw, states_mu
+        return outputs, states
\ No newline at end of file
diff --git a/test/adnc/model/__init__.py b/test/adnc/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/adnc/model/test_mann.py b/test/adnc/model/test_mann.py
new file mode 100755
index 0000000..6912710
--- /dev/null
+++ b/test/adnc/model/test_mann.py
@@ -0,0 +1,140 @@
+# Copyright 2018 Jörg Franke
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from adnc.model.mann import MANN
+
+INPUT_SIZE = 22
+OUTPUT_SIZE = 22
+BATCH_SIZE = 31
+
+CONFIG = {
+    "seed": 123,
+    "input_size": INPUT_SIZE,
+    "output_size": OUTPUT_SIZE,
+    "batch_size": BATCH_SIZE,
+    "input_embedding": False,
+    "architecture": 'uni',  # bidirectional
+    "controller_config": {"num_units": [67, 63], "layer_norm": False, "activation": 'tanh', 'cell_type': 'clstm',
+                          'connect': 'dense', 'attention': False},
+    "memory_unit_config": {"memory_length": 96, "memory_width": 31, "read_heads": 4, "write_heads": None,
+                           "dnc_norm": False, "bypass_dropout": False, 'cell_type': 'dnc'},
+    "output_function": "softmax",
+    "loss_function": "cross_entropy",
+    "output_mask": True,
+}
+
+
+@pytest.fixture()
+def mann():
+    tf.reset_default_graph()
+    return MANN(config=CONFIG)
+
+
+@pytest.fixture()
+def session():
+    with tf.Session() as sess:
+        yield sess
+    tf.reset_default_graph()
+
+
+@pytest.fixture()
+def np_rng():
+    seed = np.random.randint(1, 999)
+    return np.random.RandomState(seed)
+
+
+class TestMANN():
+    def test_init(self, mann):
+        assert isinstance(mann, object)
+        assert isinstance(mann.rng, np.random.RandomState)
+
+        assert mann.seed == CONFIG["seed"]
+        assert mann.input_size == CONFIG["input_size"]
+        assert mann.output_size == CONFIG["output_size"]
+        assert mann.batch_size == CONFIG["batch_size"]
+
+        assert mann.input_embedding == CONFIG["input_embedding"]
+        assert mann.architecture == CONFIG["architecture"]
+        assert mann.controller_config == CONFIG["controller_config"]
+        assert mann.memory_unit_config == CONFIG["memory_unit_config"]
+        assert mann.output_function == CONFIG["output_function"]
+        assert mann.output_mask == CONFIG["output_mask"]
+
+    def test_property_feed(self, mann):
+        data, target, mask = mann.feed
+        assert type(data) == tf.Tensor
+        assert type(target) == tf.Tensor
+        assert type(mask) == tf.Tensor
+
+    def test_property_controller_trainable_variables(self, mann):
+        assert mann.controller_trainable_variables.__len__() == CONFIG["controller_config"]['num_units'].__len__() * 2
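+
+    # The expected count below is a hand derivation, assuming the 'clstm'
+    # cell is a standard 4-gate LSTM: each layer owns
+    # (1 + inputs_to_layer) * 4 * units parameters, and with
+    # 'connect': 'dense' the first layer sees the external input, the memory
+    # reads (read_heads * memory_width) and both layers' recurrent outputs.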
+    def test_property_controller_parameter_amount(self, mann):
+        total_signal_size = (1 + INPUT_SIZE
+                             + CONFIG["memory_unit_config"]["memory_width"] * CONFIG["memory_unit_config"]["read_heads"]
+                             + CONFIG["controller_config"]['num_units'][0]
+                             + CONFIG["controller_config"]['num_units'][1]) * 4 * CONFIG["controller_config"]['num_units'][0] \
+                            + (1 + CONFIG["controller_config"]['num_units'][0]) * 4 * CONFIG["controller_config"]['num_units'][1]
+        parameter_amount = mann.controller_parameter_amount
+        assert parameter_amount == total_signal_size
+
+    def test_property_memory_unit_trainable_variables(self, mann):
+        assert mann.memory_unit_trainable_variables.__len__() == 2
+
+    def test_property_memory_unit_parameter_amount(self, mann):
+        total_signal_size = (CONFIG["memory_unit_config"]['memory_width'] * (3 + CONFIG["memory_unit_config"]["read_heads"])
+                             + 5 * CONFIG["memory_unit_config"]['read_heads'] + 3)
+        parameter_amount = mann.memory_unit_parameter_amount
+        assert parameter_amount == (sum(CONFIG["controller_config"]['num_units']) + 1) * total_signal_size
+
+    def test_property_mann_trainable_variables(self, mann):
+        assert mann.mann_trainable_variables.__len__() == 2  # weights and bias for softmax
+
+    def test_property_mann_parameter_amount(self, mann):
+        total_mann = ((CONFIG["memory_unit_config"]['memory_width'] * CONFIG["memory_unit_config"]["read_heads"]) +
+                      sum(CONFIG["controller_config"]['num_units']) + 1) * OUTPUT_SIZE
+        parameter_amount = mann.mann_parameter_amount
+        assert parameter_amount == total_mann
+
+    def test_property_trainable_variables(self, mann):
+        assert mann.trainable_variables.__len__() == CONFIG["controller_config"]['num_units'].__len__() * 2 + 2 + 2
+
+    def test_property_parameter_amount(self, mann):
+        total_mann = ((CONFIG["memory_unit_config"]['memory_width'] * CONFIG["memory_unit_config"]["read_heads"]) +
+                      sum(CONFIG["controller_config"]['num_units']) + 1) * OUTPUT_SIZE
+        parameter_amount = mann.parameter_amount
+        assert parameter_amount == mann.controller_parameter_amount + mann.memory_unit_parameter_amount + total_mann
+
+    def test_property_predictions_loss(self, mann, session):
+        np_inputs = np.ones([12, BATCH_SIZE, INPUT_SIZE])
+        np_target = np.ones([12, BATCH_SIZE, OUTPUT_SIZE])
+        np_mask = np.ones([12, BATCH_SIZE])
+
+        data, target, mask = mann.feed
+        session.run(tf.global_variables_initializer())
+
+        prediction, loss = session.run([mann.prediction, mann.loss],
+                                       feed_dict={data: np_inputs, target: np_target, mask: np_mask})
+
+        assert prediction.shape == (12, BATCH_SIZE, OUTPUT_SIZE)
+        assert 0 <= prediction.min() and prediction.max() <= 1
+        assert np.allclose(prediction.sum(axis=2), 1)
+
+        assert loss >= 0
+        assert loss.shape == ()