mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
add function plot script and bugfix methods
This commit is contained in:
parent
858355dcf4
commit
1c42aae82a
@ -189,18 +189,24 @@ class PlotFunctionality(PlotFunctions):
|
||||
|
||||
def plot_short_process(self, batch, plot_dir, name, show=False):
|
||||
|
||||
|
||||
if self.bucket.cell_type == 'dnc':
|
||||
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
|
||||
write_weighting, content_weighting, write_strength, alloc_weighting, write_vector, write_key, max_loc = self.bucket.get_write_process(
|
||||
batch)
|
||||
write_weighting, read_mode, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
|
||||
batch=batch)
|
||||
else:
|
||||
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
|
||||
write_weighting, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
|
||||
batch=batch)
|
||||
read_heads = self.bucket.max_read_head
|
||||
write_heads = self.bucket.max_write_head
|
||||
f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18))
|
||||
|
||||
|
||||
controller_influence, memory_unit_influence = self.bucket.get_memory_influence(batch)
|
||||
influence = np.stack([memory_unit_influence, controller_influence], axis=-1)
|
||||
influence = influence / influence.sum(axis=1, keepdims=True)
|
||||
|
||||
read_heads = self.bucket.max_read_head
|
||||
write_heads = self.bucket.max_write_head
|
||||
|
||||
f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18))
|
||||
|
||||
ax[0].set_title(name, size=33, weight='bold')
|
||||
|
||||
@ -213,7 +219,10 @@ class PlotFunctionality(PlotFunctions):
|
||||
for i in range(write_heads):
|
||||
self.plot_modes(alloc_gate[:, i, :], ax[3 + i * 3], ['y', 'b'], ['Content', 'Usage'], name='Alloc Gate')
|
||||
self.plot_modes(write_gate[:, i, :], ax[4 + i * 3], ['g', 'r'], ['Write', 'Write not'], name='Write Gate')
|
||||
print(write_gate[:, i, :].shape)
|
||||
|
||||
if self.bucket.cell_type == 'dnc':
|
||||
self.plot_modes(read_mode[:,i,:], ax[5+0*1], ['m', 'b','c'], ['Backward', 'Content', 'Forward'], name='Read Mode')
|
||||
else:
|
||||
self.plot_modes(np.zeros([34, 3]), ax[5 + 0 * 1], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'],
|
||||
name='Read Mode')
|
||||
|
||||
@ -234,7 +243,7 @@ class PlotFunctionality(PlotFunctions):
|
||||
tick.label.set_fontsize(16)
|
||||
plt.xlabel("Time steps", size=self.text_size)
|
||||
|
||||
plt.savefig(os.path.join(plot_dir, 'function_{}_{}.{}'.format(name, batch, self.data_type)),
|
||||
plt.savefig(os.path.join(plot_dir, '{}_{}.{}'.format(name, batch, self.data_type)),
|
||||
bbox_inches='tight', format=self.data_type, dpi=80)
|
||||
if show:
|
||||
plt.show()
|
||||
|
@ -42,11 +42,11 @@ class Bucket:
|
||||
self.link_matrix = link_matrix
|
||||
self.precedence_weighting = precedence_weighting
|
||||
else:
|
||||
self.cell_type = 'cmu'
|
||||
self.cell_type = 'cbmu'
|
||||
memory, usage_vector, write_weightings, read_weightings = memory_states
|
||||
alloc_gate, free_gates, write_gate, write_keys, write_strengths, write_vector, erase_vector, read_keys, read_strengths = analyse_signals
|
||||
self.read_mode = np.expand_dims(write_gate, -1)
|
||||
self.link_matrix = np.expand_dims(memory, -1)
|
||||
# self.read_mode = np.expand_dims(write_gate, -1)
|
||||
# self.link_matrix = np.expand_dims(memory, -1)
|
||||
|
||||
self.weights_dict = weights_dict
|
||||
self.batch_size = memory.shape[1]
|
||||
@ -265,21 +265,23 @@ class Bucket:
|
||||
text = self.x_word[batch]
|
||||
decoded_predictions = self.decoded_predictions[:, batch]
|
||||
|
||||
read_mode = self.read_mode[:max_loc, batch, :, :]
|
||||
read_weighting = self.read_weighting[:max_loc, batch, :, :]
|
||||
read_strength = self.read_strength[:max_loc, batch, :, :]
|
||||
read_key = self.read_keys[:max_loc, batch, :, :, ]
|
||||
read_vector = self.read_vector[:max_loc, batch, :, :]
|
||||
memory = self.memory[:max_loc, batch, :, :]
|
||||
link_matrix = self.link_matrix[:max_loc, batch, :, :, :]
|
||||
controller_output = self.controller_output[:max_loc, batch, :]
|
||||
|
||||
read_content_weighting = self.calculate_content_weightings(memory, read_key, read_strength[:, :, 0])
|
||||
forward_weighting, backward_weighting = self.calculate_forward_backward_weightings(link_matrix, read_weighting)
|
||||
read_head_influence = self.calculate_read_head_influence(read_vector, self.weights_dict, controller_output)
|
||||
|
||||
return correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, \
|
||||
backward_weighting, read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc
|
||||
if self.cell_type == 'dnc':
|
||||
read_mode = self.read_mode[:max_loc, batch, :, :]
|
||||
link_matrix = self.link_matrix[:max_loc, batch, :, :, :]
|
||||
forward_weighting, backward_weighting = self.calculate_forward_backward_weightings(link_matrix, read_weighting)
|
||||
return correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc
|
||||
else:
|
||||
return correct_prediction, false_prediction, text, decoded_predictions, mask, read_content_weighting, read_strength, read_key, read_weighting, read_vector, read_head_influence, max_loc
|
||||
|
||||
def get_memory_process(self, batch):
|
||||
|
||||
|
109
scripts/plot_function_babi_task.py
Normal file
109
scripts/plot_function_babi_task.py
Normal file
@ -0,0 +1,109 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2018 Jörg Franke
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import os
|
||||
import yaml
|
||||
import argparse
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from adnc.data import DataLoader
|
||||
from adnc.model import MANN, Optimizer
|
||||
from adnc.analysis import Bucket, PlotFunctionality
|
||||
|
||||
tf.reset_default_graph()
|
||||
|
||||
parser = argparse.ArgumentParser(description='Load model')
|
||||
parser.add_argument('model', type=str, default=False, help='model name')
|
||||
model_name = parser.parse_args().model
|
||||
|
||||
# Choose a pre trained model by uncomment
|
||||
if model_name == 'dnc':
|
||||
model_dir = "experiments/pre_trained/babi_task_1/dnc" # DNC trained on bAbI tasks 1
|
||||
elif model_name == 'adnc':
|
||||
model_dir = "experiments/pre_trained/babi_task_1/adnc" # ADNC trained on bAbI tasks 1
|
||||
elif model_name == 'biadnc':
|
||||
model_dir = "experiments/pre_trained/babi_task_1/biadnc" # BiADNC trained on bAbI tasks 1
|
||||
elif model_name == 'biadnc-all':
|
||||
model_dir = "experiments/pre_trained/babi_task_all/biadnc" # BiADNC trained on all bAbI tasks
|
||||
else:
|
||||
model_dir = "experiments/pre_trained/babi_task_all/biadnc_aug16" # BiADNC trained on all bAbI tasks with task 16 augmentation
|
||||
|
||||
plot_dir = "experiments/"
|
||||
|
||||
|
||||
analyse = True
|
||||
BATCH_SIZE = 1
|
||||
|
||||
|
||||
# load config from file
|
||||
with open(os.path.join(model_dir, 'config.yml'), 'r') as f:
|
||||
configs = yaml.load(f)
|
||||
dataset_config = configs['babi_task']
|
||||
trainer_config = configs['training']
|
||||
model_config = configs['mann']
|
||||
|
||||
|
||||
dataset_config['batch_size'] = BATCH_SIZE
|
||||
model_config['batch_size'] = BATCH_SIZE
|
||||
|
||||
dl = DataLoader(dataset_config)
|
||||
|
||||
model_config['input_size'] = dl.x_size
|
||||
model_config['output_size'] = dl.y_size
|
||||
|
||||
word_dict = dl.dataset.word_dict
|
||||
re_word_dict = dl.dataset.re_word_dict
|
||||
dataset_config['task_selection'] = [1]
|
||||
dl2 = DataLoader(dataset_config, word_dict, re_word_dict)
|
||||
valid_loader = dl2.get_data_loader('valid')
|
||||
|
||||
|
||||
model = MANN(model_config, analyse=True)
|
||||
|
||||
data, target, mask = model.feed
|
||||
|
||||
trainer = Optimizer(trainer_config, model.loss, model.trainable_variables)
|
||||
saver = tf.train.Saver()
|
||||
|
||||
|
||||
conf = tf.ConfigProto()
|
||||
conf.gpu_options.allocator_type = 'BFC'
|
||||
conf.gpu_options.allow_growth = True
|
||||
with tf.Session(config=conf) as sess:
|
||||
|
||||
saver.restore(sess, os.path.join(model_dir, "model_dump.ckpt"))
|
||||
|
||||
vsample = next(valid_loader)
|
||||
|
||||
analyse_values, prediction, gradients = sess.run([model.analyse, model.prediction, trainer.gradients],
|
||||
feed_dict={data: vsample['x'], target: vsample['y'], mask: vsample['m']})
|
||||
weights = {v.name: {'var':g[1], 'grad':g[0], 'shape':g[0].shape } for v, g in zip(model.trainable_variables, gradients)}
|
||||
if 'x_word' not in vsample.keys():
|
||||
vsample['x_word'] = np.transpose(np.argmax(vsample['x'], axis=-1),(1,0))
|
||||
data_sample = [vsample['x'], vsample['y'], vsample['m'], vsample['x_word'],]
|
||||
|
||||
decoded_targets, decoded_predictions = dl.decode_output(vsample, prediction)
|
||||
|
||||
save_list = [analyse_values, prediction, decoded_predictions, data_sample, weights ]
|
||||
babi_bucket = Bucket(save_list, babi_short=True)
|
||||
|
||||
plotter = PlotFunctionality(babi_bucket, title=True, legend=True, text_size=22)
|
||||
plotter.plot_short_process(batch=0, plot_dir=plot_dir, name='function plot {}'.format(model_name))
|
||||
# plot_advanced_functionality(batch=0, plot_dir=plot_dir, name='extended function plot {}'.format(model_name))
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user