add function plot script and bugfix methods

This commit is contained in:
Joerg Franke 2018-07-06 08:27:56 +02:00
parent 858355dcf4
commit 1c42aae82a
3 changed files with 139 additions and 19 deletions

View File

@ -189,18 +189,24 @@ class PlotFunctionality(PlotFunctions):
def plot_short_process(self, batch, plot_dir, name, show=False):
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
write_weighting, content_weighting, write_strength, alloc_weighting, write_vector, write_key, max_loc = self.bucket.get_write_process(
batch)
if self.bucket.cell_type == 'dnc':
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
write_weighting, read_mode, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
batch=batch)
else:
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
write_weighting, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
batch=batch)
read_heads = self.bucket.max_read_head
write_heads = self.bucket.max_write_head
f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18))
controller_influence, memory_unit_influence = self.bucket.get_memory_influence(batch)
influence = np.stack([memory_unit_influence, controller_influence], axis=-1)
influence = influence / influence.sum(axis=1, keepdims=True)
read_heads = self.bucket.max_read_head
write_heads = self.bucket.max_write_head
f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18))
ax[0].set_title(name, size=33, weight='bold')
@ -213,9 +219,12 @@ class PlotFunctionality(PlotFunctions):
for i in range(write_heads):
self.plot_modes(alloc_gate[:, i, :], ax[3 + i * 3], ['y', 'b'], ['Content', 'Usage'], name='Alloc Gate')
self.plot_modes(write_gate[:, i, :], ax[4 + i * 3], ['g', 'r'], ['Write', 'Write not'], name='Write Gate')
print(write_gate[:, i, :].shape)
self.plot_modes(np.zeros([34, 3]), ax[5 + 0 * 1], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'],
name='Read Mode')
if self.bucket.cell_type == 'dnc':
self.plot_modes(read_mode[:,i,:], ax[5+0*1], ['m', 'b','c'], ['Backward', 'Content', 'Forward'], name='Read Mode')
else:
self.plot_modes(np.zeros([34, 3]), ax[5 + 0 * 1], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'],
name='Read Mode')
ax[5 + 0 * 1].set_yticks([])
@ -234,7 +243,7 @@ class PlotFunctionality(PlotFunctions):
tick.label.set_fontsize(16)
plt.xlabel("Time steps", size=self.text_size)
plt.savefig(os.path.join(plot_dir, 'function_{}_{}.{}'.format(name, batch, self.data_type)),
plt.savefig(os.path.join(plot_dir, '{}_{}.{}'.format(name, batch, self.data_type)),
bbox_inches='tight', format=self.data_type, dpi=80)
if show:
plt.show()

View File

@ -42,11 +42,11 @@ class Bucket:
self.link_matrix = link_matrix
self.precedence_weighting = precedence_weighting
else:
self.cell_type = 'cmu'
self.cell_type = 'cbmu'
memory, usage_vector, write_weightings, read_weightings = memory_states
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, erase_vector, read_keys, read_strengths = analyse_signals
self.read_mode = np.expand_dims(write_gate, -1)
self.link_matrix = np.expand_dims(memory, -1)
# self.read_mode = np.expand_dims(write_gate, -1)
# self.link_matrix = np.expand_dims(memory, -1)
self.weights_dict = weights_dict
self.batch_size = memory.shape[1]
@ -265,21 +265,23 @@ class Bucket:
text = self.x_word[batch]
decoded_predictions = self.decoded_predictions[:, batch]
read_mode = self.read_mode[:max_loc, batch, :, :]
read_weighting = self.read_weighting[:max_loc, batch, :, :]
read_strength = self.read_strength[:max_loc, batch, :, :]
read_key = self.read_keys[:max_loc, batch, :, :, ]
read_vector = self.read_vector[:max_loc, batch, :, :]
memory = self.memory[:max_loc, batch, :, :]
link_matrix = self.link_matrix[:max_loc, batch, :, :, :]
controller_output = self.controller_output[:max_loc, batch, :]
read_content_weighting = self.calculate_content_weightings(memory, read_key, read_strength[:, :, 0])
forward_weighting, backward_weighting = self.calculate_forward_backward_weightings(link_matrix, read_weighting)
read_head_influence = self.calculate_read_head_influence(read_vector, self.weights_dict, controller_output)
return correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, \
backward_weighting, read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc
if self.cell_type == 'dnc':
read_mode = self.read_mode[:max_loc, batch, :, :]
link_matrix = self.link_matrix[:max_loc, batch, :, :, :]
forward_weighting, backward_weighting = self.calculate_forward_backward_weightings(link_matrix, read_weighting)
return correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc
else:
return correct_prediction, false_prediction, text, decoded_predictions, mask, read_content_weighting, read_strength, read_key, read_weighting, read_vector, read_head_influence, max_loc
def get_memory_process(self, batch):

View File

@ -0,0 +1,109 @@
#!/usr/bin/env python
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import yaml
import argparse
import numpy as np
import tensorflow as tf
from adnc.data import DataLoader
from adnc.model import MANN, Optimizer
from adnc.analysis import Bucket, PlotFunctionality
# Script body: restore a pre-trained MANN on bAbI task 1, run one validation
# sample through it with analysis outputs enabled, and save a function plot.

# Reset the TensorFlow default graph before the model is built.
tf.reset_default_graph()

parser = argparse.ArgumentParser(description='Load model')
# Required positional argument; unknown names fall back to DEFAULT_MODEL_DIR.
parser.add_argument('model', type=str, help='model name')
model_name = parser.parse_args().model

# The pre-trained model is selected by the `model` command line argument.
PRE_TRAINED_MODEL_DIRS = {
    'dnc': "experiments/pre_trained/babi_task_1/dnc",  # DNC trained on bAbI tasks 1
    'adnc': "experiments/pre_trained/babi_task_1/adnc",  # ADNC trained on bAbI tasks 1
    'biadnc': "experiments/pre_trained/babi_task_1/biadnc",  # BiADNC trained on bAbI tasks 1
    'biadnc-all': "experiments/pre_trained/babi_task_all/biadnc",  # BiADNC trained on all bAbI tasks
}
# BiADNC trained on all bAbI tasks with task 16 augmentation
DEFAULT_MODEL_DIR = "experiments/pre_trained/babi_task_all/biadnc_aug16"
model_dir = PRE_TRAINED_MODEL_DIRS.get(model_name, DEFAULT_MODEL_DIR)

plot_dir = "experiments/"
analyse = True
BATCH_SIZE = 1

# Load the config stored next to the checkpoint. safe_load is used because
# yaml.load() without an explicit Loader can construct arbitrary objects and
# is rejected by PyYAML >= 6.
with open(os.path.join(model_dir, 'config.yml'), 'r') as f:
    configs = yaml.safe_load(f)

dataset_config = configs['babi_task']
trainer_config = configs['training']
model_config = configs['mann']

# Override the stored batch size for both the data pipeline and the model.
dataset_config['batch_size'] = BATCH_SIZE
model_config['batch_size'] = BATCH_SIZE

# The first loader supplies the input/output sizes and the word dictionaries.
dl = DataLoader(dataset_config)
model_config['input_size'] = dl.x_size
model_config['output_size'] = dl.y_size
word_dict = dl.dataset.word_dict
re_word_dict = dl.dataset.re_word_dict

# The second loader is restricted to task 1 and reuses the dictionaries of
# the first loader.
dataset_config['task_selection'] = [1]
dl2 = DataLoader(dataset_config, word_dict, re_word_dict)
valid_loader = dl2.get_data_loader('valid')

model = MANN(model_config, analyse=analyse)
data, target, mask = model.feed
trainer = Optimizer(trainer_config, model.loss, model.trainable_variables)

saver = tf.train.Saver()
conf = tf.ConfigProto()
conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True

with tf.Session(config=conf) as sess:
    saver.restore(sess, os.path.join(model_dir, "model_dump.ckpt"))

    vsample = next(valid_loader)
    analyse_values, prediction, gradients = sess.run(
        [model.analyse, model.prediction, trainer.gradients],
        feed_dict={data: vsample['x'], target: vsample['y'], mask: vsample['m']},
    )

    # Each entry of `gradients` is indexed as g[0] -> gradient, g[1] -> variable value.
    weights = {
        v.name: {'var': g[1], 'grad': g[0], 'shape': g[0].shape}
        for v, g in zip(model.trainable_variables, gradients)
    }

    # Derive word indices from the input when the loader did not provide them.
    # NOTE(review): assumes vsample['x'] is one-hot along its last axis — confirm.
    if 'x_word' not in vsample:
        vsample['x_word'] = np.transpose(np.argmax(vsample['x'], axis=-1), (1, 0))
    data_sample = [vsample['x'], vsample['y'], vsample['m'], vsample['x_word']]

    decoded_targets, decoded_predictions = dl.decode_output(vsample, prediction)

    save_list = [analyse_values, prediction, decoded_predictions, data_sample, weights]
    babi_bucket = Bucket(save_list, babi_short=True)

    plotter = PlotFunctionality(babi_bucket, title=True, legend=True, text_size=22)
    plotter.plot_short_process(batch=0, plot_dir=plot_dir, name='function plot {}'.format(model_name))
    # plot_advanced_functionality(batch=0, plot_dir=plot_dir, name='extended function plot {}'.format(model_name))