From 1c42aae82aa6e37f987c17978629e79e4b0ee1ee Mon Sep 17 00:00:00 2001 From: Joerg Franke Date: Fri, 6 Jul 2018 08:27:56 +0200 Subject: [PATCH] add function plot script and bugfix methods --- adnc/analysis/plot_functionality.py | 31 +++++--- adnc/analysis/prepare_variables.py | 18 +++-- scripts/plot_function_babi_task.py | 109 ++++++++++++++++++++++++++++ 3 files changed, 139 insertions(+), 19 deletions(-) create mode 100644 scripts/plot_function_babi_task.py diff --git a/adnc/analysis/plot_functionality.py b/adnc/analysis/plot_functionality.py index b20f443..8bf2962 100755 --- a/adnc/analysis/plot_functionality.py +++ b/adnc/analysis/plot_functionality.py @@ -189,18 +189,24 @@ class PlotFunctionality(PlotFunctions): def plot_short_process(self, batch, plot_dir, name, show=False): - correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \ - write_weighting, content_weighting, write_strength, alloc_weighting, write_vector, write_key, max_loc = self.bucket.get_write_process( - batch) + + if self.bucket.cell_type == 'dnc': + correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \ + write_weighting, read_mode, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality( + batch=batch) + else: + correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \ + write_weighting, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality( + batch=batch) + read_heads = self.bucket.max_read_head + write_heads = self.bucket.max_write_head + f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18)) + controller_influence, memory_unit_influence = self.bucket.get_memory_influence(batch) influence = np.stack([memory_unit_influence, controller_influence], axis=-1) influence = influence / influence.sum(axis=1, keepdims=True) - read_heads = self.bucket.max_read_head - write_heads = self.bucket.max_write_head - - f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18)) ax[0].set_title(name, size=33, weight='bold') @@ -213,9 +219,12 @@ class PlotFunctionality(PlotFunctions): for i in range(write_heads): self.plot_modes(alloc_gate[:, i, :], ax[3 + i * 3], ['y', 'b'], ['Content', 'Usage'], name='Alloc Gate') self.plot_modes(write_gate[:, i, :], ax[4 + i * 3], ['g', 'r'], ['Write', 'Write not'], name='Write Gate') - print(write_gate[:, i, :].shape) - self.plot_modes(np.zeros([34, 3]), ax[5 + 0 * 1], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'], - name='Read Mode') + + if self.bucket.cell_type == 'dnc': + self.plot_modes(read_mode[:,i,:], ax[5+0*1], ['m', 'b','c'], ['Backward', 'Content', 'Forward'], name='Read Mode') + else: + self.plot_modes(np.zeros([34, 3]), ax[5 + 0 * 1], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'], + name='Read Mode') ax[5 + 0 * 1].set_yticks([]) @@ -234,7 +243,7 @@ class PlotFunctionality(PlotFunctions): tick.label.set_fontsize(16) plt.xlabel("Time steps", size=self.text_size) - plt.savefig(os.path.join(plot_dir, 'function_{}_{}.{}'.format(name, batch, self.data_type)), + plt.savefig(os.path.join(plot_dir, '{}_{}.{}'.format(name, batch, self.data_type)), bbox_inches='tight', format=self.data_type, dpi=80) if show: plt.show() diff --git a/adnc/analysis/prepare_variables.py b/adnc/analysis/prepare_variables.py index be5a800..3de2417 100755 --- a/adnc/analysis/prepare_variables.py +++ b/adnc/analysis/prepare_variables.py @@ -42,11 +42,11 @@ class Bucket: self.link_matrix = link_matrix self.precedence_weighting = precedence_weighting else: - self.cell_type = 'cmu' + self.cell_type = 'cbmu' memory, usage_vector, write_weightings, read_weightings = memory_states alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, erase_vector, read_keys, read_strengths = analyse_signals - self.read_mode = np.expand_dims(write_gate, -1) - self.link_matrix = np.expand_dims(memory, -1) + # self.read_mode = np.expand_dims(write_gate, -1) + # self.link_matrix = np.expand_dims(memory, -1) self.weights_dict = weights_dict self.batch_size = memory.shape[1] @@ -265,21 +265,23 @@ class Bucket: text = self.x_word[batch] decoded_predictions = self.decoded_predictions[:, batch] - read_mode = self.read_mode[:max_loc, batch, :, :] read_weighting = self.read_weighting[:max_loc, batch, :, :] read_strength = self.read_strength[:max_loc, batch, :, :] read_key = self.read_keys[:max_loc, batch, :, :, ] read_vector = self.read_vector[:max_loc, batch, :, :] memory = self.memory[:max_loc, batch, :, :] - link_matrix = self.link_matrix[:max_loc, batch, :, :, :] controller_output = self.controller_output[:max_loc, batch, :] read_content_weighting = self.calculate_content_weightings(memory, read_key, read_strength[:, :, 0]) - forward_weighting, backward_weighting = self.calculate_forward_backward_weightings(link_matrix, read_weighting) read_head_influence = self.calculate_read_head_influence(read_vector, self.weights_dict, controller_output) - return correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, \ - backward_weighting, read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc + if self.cell_type == 'dnc': + read_mode = self.read_mode[:max_loc, batch, :, :] + link_matrix = self.link_matrix[:max_loc, batch, :, :, :] + forward_weighting, backward_weighting = self.calculate_forward_backward_weightings(link_matrix, read_weighting) + return correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc + else: + return correct_prediction, false_prediction, text, decoded_predictions, mask, read_content_weighting, read_strength, read_key, read_weighting, read_vector, read_head_influence, max_loc def get_memory_process(self, batch): diff --git a/scripts/plot_function_babi_task.py b/scripts/plot_function_babi_task.py new file mode 100644 index 0000000..840d511 --- /dev/null +++ b/scripts/plot_function_babi_task.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python +# Copyright 2018 Jörg Franke +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import yaml +import argparse +import numpy as np +import tensorflow as tf + +from adnc.data import DataLoader +from adnc.model import MANN, Optimizer +from adnc.analysis import Bucket, PlotFunctionality + +tf.reset_default_graph() + +parser = argparse.ArgumentParser(description='Load model') +parser.add_argument('model', type=str, default=False, help='model name') +model_name = parser.parse_args().model + +# Choose a pre trained model by uncomment +if model_name == 'dnc': + model_dir = "experiments/pre_trained/babi_task_1/dnc" # DNC trained on bAbI tasks 1 +elif model_name == 'adnc': + model_dir = "experiments/pre_trained/babi_task_1/adnc" # ADNC trained on bAbI tasks 1 +elif model_name == 'biadnc': + model_dir = "experiments/pre_trained/babi_task_1/biadnc" # BiADNC trained on bAbI tasks 1 +elif model_name == 'biadnc-all': + model_dir = "experiments/pre_trained/babi_task_all/biadnc" # BiADNC trained on all bAbI tasks +else: + model_dir = "experiments/pre_trained/babi_task_all/biadnc_aug16" # BiADNC trained on all bAbI tasks with task 16 augmentation + +plot_dir = "experiments/" + + +analyse = True +BATCH_SIZE = 1 + + +# load config from file +with open(os.path.join(model_dir, 'config.yml'), 'r') as f: + configs = yaml.load(f) +dataset_config = configs['babi_task'] +trainer_config = configs['training'] +model_config = configs['mann'] + + +dataset_config['batch_size'] = BATCH_SIZE +model_config['batch_size'] = BATCH_SIZE + +dl = DataLoader(dataset_config) + +model_config['input_size'] = dl.x_size +model_config['output_size'] = dl.y_size + +word_dict = dl.dataset.word_dict +re_word_dict = dl.dataset.re_word_dict +dataset_config['task_selection'] = [1] +dl2 = DataLoader(dataset_config, word_dict, re_word_dict) +valid_loader = dl2.get_data_loader('valid') + + +model = MANN(model_config, analyse=True) + +data, target, mask = model.feed + +trainer = Optimizer(trainer_config, model.loss, model.trainable_variables) +saver = tf.train.Saver() + + +conf = tf.ConfigProto() +conf.gpu_options.allocator_type = 'BFC' +conf.gpu_options.allow_growth = True +with tf.Session(config=conf) as sess: + + saver.restore(sess, os.path.join(model_dir, "model_dump.ckpt")) + + vsample = next(valid_loader) + + analyse_values, prediction, gradients = sess.run([model.analyse, model.prediction, trainer.gradients], + feed_dict={data: vsample['x'], target: vsample['y'], mask: vsample['m']}) + weights = {v.name: {'var':g[1], 'grad':g[0], 'shape':g[0].shape } for v, g in zip(model.trainable_variables, gradients)} + if 'x_word' not in vsample.keys(): + vsample['x_word'] = np.transpose(np.argmax(vsample['x'], axis=-1),(1,0)) + data_sample = [vsample['x'], vsample['y'], vsample['m'], vsample['x_word'],] + + decoded_targets, decoded_predictions = dl.decode_output(vsample, prediction) + + save_list = [analyse_values, prediction, decoded_predictions, data_sample, weights ] + babi_bucket = Bucket(save_list, babi_short=True) + + plotter = PlotFunctionality(babi_bucket, title=True, legend=True, text_size=22) + plotter.plot_short_process(batch=0, plot_dir=plot_dir, name='function plot {}'.format(model_name)) + # plot_advanced_functionality(batch=0, plot_dir=plot_dir, name='extended function plot {}'.format(model_name)) + + + +