add analyzer and plot functions

This commit is contained in:
Joerg Franke 2018-07-05 00:46:41 +02:00
parent 98f56912f3
commit 438e9bf0a0
4 changed files with 770 additions and 1 deletions

View File

@ -11,4 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# ==============================================================================
from .analyzer import Analyser
from .prepare_variables import Bucket
from .plot_functionality import PlotFunctionality

136
adnc/analysis/analyzer.py Executable file
View File

@ -0,0 +1,136 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import pickle
import numpy as np
from adnc.analysis.plot_functionality import PlotFunctionality
from adnc.analysis.prepare_variables import Bucket
from adnc.model.utils import softmax
class Analyser():
def __init__(self, data_set, record_dir, save_variables=False, save_fig=False):
self.data_set = data_set
self.record_dir = record_dir
self.save_variables = save_variables
self.save_fig = save_fig
self.max_batch_plot = 1
if save_fig:
self.plot_dir = os.path.join(self.record_dir, "plots")
while not os.path.isdir(self.plot_dir):
try:
os.mkdir(self.plot_dir)
except ValueError:
pass
def feed_variables(self, variables, epoch, name='variables'):
variable_name = name + "_{}".format(epoch)
buck = Bucket(variables)
plotter = PlotFunctionality(bucket=buck)
if self.save_variables:
self.save_variables_to_file(variables, variable_name)
if self.save_fig:
plotter.plot_basic_functionality(batch=0, plot_dir=self.plot_dir, name=variable_name, show=False)
return self.estimate_memory_usage(variables)
def feed_variables_two(self, variables, epoch, name='variables', save_plot=1):
variable_name = name + "_{}".format(epoch)
buck = Bucket(variables)
plotter = PlotFunctionality(bucket=buck)
if save_plot > 0:
if self.save_variables:
self.save_variables_to_file(variables, variable_name)
if self.save_fig:
plotter.plot_basic_functionality(batch=0, plot_dir=self.plot_dir, name=variable_name, show=False)
return self.estimate_memory_usage(variables)
def plot_analysis(self, variables, plot_dir, name='variables'):
buck = Bucket(variables)
plotter = PlotFunctionality(bucket=buck)
plotter.plot_basic_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
plotter.plot_advanced_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
def estimate_memory_usage(self, variables):
analyse_values, prediction, decoded_predictions, data_sample, weights_dict = variables
data, target, mask, x_word = data_sample
analyse_outputs, analyse_signals, analyse_states = analyse_values
controller_states, memory_states = analyse_states
if memory_states.__len__() == 6:
memory, usage_vector, write_weightings, precedence_weighting, link_matrix, read_weightings = memory_states
else:
memory, usage_vector, write_weightings, read_weightings = memory_states
read_head = read_weightings.shape[2]
memory_width = memory.shape[-1]
time_len = memory.shape[0]
memory_unit_mask = np.concatenate([np.ones([time_len, read_head * memory_width]), np.zeros(
[time_len, analyse_outputs.shape[-1] - (read_head * memory_width)])], axis=-1)
controller_mask = np.concatenate([np.zeros([time_len, read_head * memory_width]),
np.ones([time_len, analyse_outputs.shape[-1] - (read_head * memory_width)])],
axis=-1)
controller_influence = []
memory_unit_influence = []
for b in range(mask.shape[1]):
matmul = np.matmul(analyse_outputs[:, b, :], weights_dict['output_layer/weights_concat:0']['var']) + \
weights_dict['output_layer/bias_merge:0']['var']
pred_both = softmax(matmul)
matmul = np.matmul(analyse_outputs[:, b, :] * controller_mask,
weights_dict['output_layer/weights_concat:0']['var']) + \
weights_dict['output_layer/bias_merge:0']['var']
pred_c = softmax(matmul)
matmul = np.matmul(analyse_outputs[:, b, :] * memory_unit_mask,
weights_dict['output_layer/weights_concat:0']['var']) + \
weights_dict['output_layer/bias_merge:0']['var']
pred_mu = softmax(matmul)
co_inf = (np.abs(pred_both - pred_mu) * np.expand_dims(mask[:, b], 1)).sum() / mask[:, b].sum()
me_inf = (np.abs(pred_both - pred_c) * np.expand_dims(mask[:, b], 1)).sum() / mask[:, b].sum()
co_inf = (1 / (co_inf + me_inf)) * co_inf
me_inf = (1 / (co_inf + me_inf)) * me_inf
controller_influence.append(co_inf)
memory_unit_influence.append(me_inf)
controller_influence = np.mean(controller_influence)
memory_unit_influence = np.mean(memory_unit_influence)
return controller_influence, memory_unit_influence
def save_variables_to_file(self, variables, name):
save_file = os.path.join(self.record_dir, "{}.plk".format(name))
with open(save_file, "wb") as f:
pickle.dump(variables, f)

View File

@ -0,0 +1,362 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import matplotlib
import numpy as np
matplotlib.use('agg')
import matplotlib.pyplot as plt
from adnc.analysis.plot_functions import PlotFunctions
class PlotFunctionality(PlotFunctions):
def __init__(self, bucket, legend=False, title=False, text_size=16, data_type='png'):
self.bucket = bucket
self.data_type = data_type
super().__init__(legend, title, text_size)
def plot_basic_functionality(self, batch, plot_dir, name, show=False):
if self.bucket.cell_type == 'dnc':
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
write_weighting, read_mode, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
batch=batch)
read_heads = self.bucket.max_read_head
write_heads = self.bucket.max_write_head
f, ax = plt.subplots((8 + 2 * read_heads + write_heads), sharex=True)
else:
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
write_weighting, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
batch=batch)
read_heads = self.bucket.max_read_head
write_heads = self.bucket.max_write_head
f, ax = plt.subplots((8 + read_heads + write_heads), sharex=True)
controller_influence, memory_unit_influence = self.bucket.get_memory_influence(batch)
plt.xlim([-1, max_loc])
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
mask, ax[:2])
self.plot_modes(alloc_gate[:, 0, :], ax[2], ['y', 'b'], ['usage', 'content'], name='Alloc Gate')
self.plot_multi_modes(free_gate, ax[3], width, ['g', 'r'], ['free', 'not free'], name='Free Gates')
for i in range(write_heads):
self.plot_modes(write_gate[:, i, :], ax[4 + i], ['g', 'r'], ['write', 'write not'], name='Write Gate')
self.plot_matrix(old_memory, ax[4 + write_heads], name='Old Memeory', color='bwr')
self.plot_matrix(new_memory, ax[5 + write_heads], name='New Memeory', color='bwr')
self.plot_modes(read_head_influence, ax[6 + write_heads], [None for _ in range(read_heads)],
['head {}'.format(i + 1) for i in range(read_heads)], name='Head Influence')
for i in range(read_heads):
if self.bucket.cell_type == 'dnc':
self.plot_modes(read_strength[:, i, :], ax[7 + write_heads + i * 2], ['k'], ['strength'],
name='Read Cont. Stg.')
self.plot_modes(read_mode[:, i, :], ax[8 + write_heads + i * 2], ['m', 'b', 'c'],
['backward', 'content', 'forward'], name='Read Modes')
else:
self.plot_modes(read_strength[:, i, :], ax[7 + write_heads + i], ['k'], ['strength'],
name='Read Cont. Stg.')
influence = np.stack([memory_unit_influence, controller_influence], axis=-1)
self.plot_modes(influence, ax[-1], ['y', 'b'], ['memory usage', 'controller usage'], name='Memory Usage')
if max_loc < 80:
for ax_i in ax:
for l in line_loc:
ax_i.axvline(x=l, c='k')
for _ax in ax:
for tick in _ax.yaxis.get_major_ticks():
tick.label.set_fontsize(16)
for tick in ax[-1].xaxis.get_major_ticks():
tick.label.set_fontsize(16)
plt.xlabel("Time step", size='16')
f.set_size_inches(2 + max_loc * 0.2, 1 + ax.__len__() * 1.4)
if show:
plt.show()
plt.savefig(os.path.join(plot_dir, 'basic_{}_{}'.format(name, batch)), bbox_inches='tight', dpi=160)
plt.close(f)
return f
def plot_write_process(self, batch, plot_dir, name, show=False):
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
write_weighting, content_weighting, write_strength, alloc_weighting, write_vector, write_key, max_loc = self.bucket.get_write_process(
batch)
write_heads = self.bucket.max_write_head
usage_weightings = self.bucket.usage_vector
f, ax = plt.subplots((3 + 6 * write_heads), sharex=True, figsize=(12, 18))
plt.xlim([-1, max_loc])
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
mask, ax[:2])
self.plot_multi_modes(free_gate, ax[2], width, ['g', 'r'], ['Free', 'Free not'], name='Free Gates')
for i in range(write_heads):
self.plot_weightings(alloc_weighting[:, i, :], ax[3 + i * 6], name='Allocation\nWeighting')
self.plot_weightings(usage_weightings[:, i, :], ax[4 + i * 6], name='Usage\nWeighting')
# plot_modes(write_strength[:,i,:], ax[3+i*8], ['k'], ['strength'], name='Content Stg.')
# plot_weightings(write_key[:,i,:], ax[4+i*8], name='Content Key', mode='norm', color='jet')
self.plot_weightings(content_weighting[:, i, :], ax[5 + i * 6], name='Content\nWeighting')
self.plot_modes(alloc_gate[:, i, :], ax[6 + i * 6], ['y', 'b'], ['Allocation', 'Content'],
name='Allocation\nGate')
self.plot_modes(write_gate[:, i, :], ax[8 + i * 6], ['g', 'r'], ['Write', 'Write not'], name='Write Gate')
self.plot_weightings(write_weighting[:, i, :], ax[7 + i * 6], name='Write\nWeighting')
# plot_weightings(write_vector[:,i,:], ax[9+i*8], name='Write Vector', mode='norm', color='jet')
for ax_i in ax:
for l in line_loc:
ax_i.axvline(x=l, c='k')
for _ax in ax:
for tick in _ax.yaxis.get_major_ticks():
tick.label.set_fontsize(16)
for tick in ax[-1].xaxis.get_major_ticks():
tick.label.set_fontsize(16)
f.set_size_inches(2 + max_loc * 0.2, 1 + ax.__len__() * 1.4)
plt.savefig(os.path.join(plot_dir, '{}_{}'.format(name, batch)), bbox_inches='tight', dpi=160)
if show:
plt.show()
plt.close(f)
def plot_read_process(self, batch, plot_dir, name, show=False):
correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, \
read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc = self.bucket.get_read_process(
batch)
read_heads = self.bucket.max_read_head
write_heads = self.bucket.max_write_head
f, ax = plt.subplots((3 + 5 * read_heads), sharex=True, figsize=(12, 24))
plt.xlim([-1, max_loc])
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
mask, ax[:2])
for i in range(read_heads):
for wh in range(write_heads):
self.plot_weightings(forward_weighting[:, i, wh, :], ax[2 + i * 5],
name='Forward\nWeighting\nHead {}'.format(i + 1))
self.plot_weightings(backward_weighting[:, i, wh, :], ax[3 + i * 5],
name='Backward\nWeighting\nHead {}'.format(i + 1))
# plot_modes(read_strength[:,i,:], ax[3+i*8], ['k'], ['strength'], name='Content Stg.')
# plot_weightings(read_key[:,i,:], ax[4+i*8], name='Content Key', mode='norm', color='jet')
self.plot_weightings(read_content_weighting[:, i, :], ax[4 + i * 5],
name='Content\nWeighting\nHead {}'.format(i + 1))
self.plot_modes(read_mode[:, i, :], ax[5 + i * 5], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'],
name='Read Modes\nHead {}'.format(i + 1))
self.plot_weightings(read_weighting[:, i, :], ax[6 + i * 5], name='Read Wgh. {}'.format(i + 1))
# plot_weightings(read_vector[:,i,:], ax[8+i*8], name='Read Vector', mode='norm', color='jet')
self.plot_modes(read_head_influence, ax[-1], [None for _ in range(read_heads)],
['Head {}'.format(i + 1) for i in range(read_heads)], name='Head\nInfluence')
for ax_i in ax:
for l in line_loc:
ax_i.axvline(x=l, c='k')
for _ax in ax:
for tick in _ax.yaxis.get_major_ticks():
tick.label.set_fontsize(16)
for tick in ax[-1].xaxis.get_major_ticks():
tick.label.set_fontsize(16)
f.set_size_inches(2 + max_loc * 0.2, 1 + ax.__len__() * 1.4)
plt.savefig(os.path.join(plot_dir, '{}_{}'.format(name, batch)), bbox_inches='tight', dpi=160)
if show:
plt.show()
plt.close(f)
def plot_short_process(self, batch, plot_dir, name, show=False):
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
write_weighting, content_weighting, write_strength, alloc_weighting, write_vector, write_key, max_loc = self.bucket.get_write_process(
batch)
controller_influence, memory_unit_influence = self.bucket.get_memory_influence(batch)
influence = np.stack([memory_unit_influence, controller_influence], axis=-1)
influence = influence / influence.sum(axis=1, keepdims=True)
read_heads = self.bucket.max_read_head
write_heads = self.bucket.max_write_head
f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18))
ax[0].set_title(name, size=33, weight='bold')
plt.xlim([-1, max_loc])
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
mask, ax[:2])
self.plot_multi_modes(free_gate, ax[2], width, ['g', 'r'], ['Free', 'Free not'], name='Free Gates')
for i in range(write_heads):
self.plot_modes(alloc_gate[:, i, :], ax[3 + i * 3], ['y', 'b'], ['Content', 'Usage'], name='Alloc Gate')
self.plot_modes(write_gate[:, i, :], ax[4 + i * 3], ['g', 'r'], ['Write', 'Write not'], name='Write Gate')
print(write_gate[:, i, :].shape)
self.plot_modes(np.zeros([34, 3]), ax[5 + 0 * 1], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'],
name='Read Mode')
ax[5 + 0 * 1].set_yticks([])
self.plot_modes(influence, ax[-1], ['darkorange', 'blueviolet'], ['Memory', 'Controller'],
name='Output\nInfluencer')
for ax_i in ax:
for l in line_loc:
ax_i.axvline(x=l, c='k')
for _ax in ax:
for tick in _ax.yaxis.get_major_ticks():
tick.label.set_fontsize(self.text_size)
for tick in ax[-1].xaxis.get_major_ticks():
tick.label.set_fontsize(16)
plt.xlabel("Time steps", size=self.text_size)
plt.savefig(os.path.join(plot_dir, 'function_{}_{}.{}'.format(name, batch, self.data_type)),
bbox_inches='tight', format=self.data_type, dpi=80)
if show:
plt.show()
plt.close(f)
def plot_memory_process(self, batch, plot_dir, name, show=False, dpi=160):
correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, \
read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc = self.bucket.get_read_process(
batch)
correct_prediction, false_prediction, text, decoded_predictions, mask, old_memory, write_weighting, \
write_vector, erase_vector, add_memory, erase_memory, new_memory, max_loc = self.bucket.get_memory_process(
batch)
write_heads = self.bucket.max_write_head
read_heads = self.bucket.max_read_head
f, ax = plt.subplots((4 + 5 * write_heads + 2 * read_heads), sharex=False,
gridspec_kw={'height_ratios': [6, 6, 6, 1, 6, 1, 6, 6, 6, 1, 6, 1, 1]}, figsize=(12, 24))
plt.xlim([-1, max_loc * new_memory.shape[-1]])
line_loc, width = self.plot_data_plus_prediction(correct_prediction, false_prediction, text,
decoded_predictions, mask, ax[0])
ax[0].axis('tight')
ax[0].set_xlim(0 - width / 2, max_loc - width / 2)
self.plot_matrix(old_memory, ax[1], name='Old\nMemeory', color='bwr')
for i in range(write_heads):
self.plot_vector_as_matrix(write_weighting[:, i, :], vertical=True, repeats=old_memory.shape[2],
ax=ax[2 + i * 5], name='Write Wgh.', zero_width=5)
self.plot_vector_as_matrix(write_vector[:, i, :], vertical=False, repeats=old_memory.shape[1],
ax=ax[3 + i * 5], name='Write\nVector\n', zero_width=5, mode='norm', color='bwr',
legend=False)
self.plot_matrix(add_memory[:, :], ax[4 + i * 5], name='Add\nMatrix', color='bwr')
self.plot_vector_as_matrix(erase_vector[:, i, :], vertical=False, repeats=old_memory.shape[1],
ax=ax[5 + i * 5], name='Erase\nVector\n', zero_width=5, mode='norm1',
color='YlGnBu', legend=False)
self.plot_matrix(erase_memory[:, :], ax[6 + i * 5], name='Erase\nMatrix', mode='norm1', color='YlGnBu')
self.plot_matrix(new_memory, ax[7], name='New\nMemeory', color='bwr')
for i in range(read_heads):
self.plot_vector_as_matrix(read_weighting[:, i, :], ax=ax[8 + i * 2], vertical=True,
repeats=old_memory.shape[2], name='Read Wgh.\nHead {}'.format(i + 1),
zero_width=5)
self.plot_vector_as_matrix(read_vector[:, i, :], vertical=False, repeats=old_memory.shape[1],
ax=ax[9 + i * 2], name='Read\nVector\nHead {}\n'.format(i + 1), zero_width=5,
mode='norm', color='bwr', legend=False)
for ax_ in ax:
ax_.set_xticks([])
ax[-1].axis('off')
for _ax in ax:
for tick in _ax.yaxis.get_major_ticks():
tick.label.set_fontsize(16)
for tick in ax[-1].xaxis.get_major_ticks():
tick.label.set_fontsize(16)
f.set_size_inches(2 + max_loc * 0.4, 1 + ax.__len__() * 1.4)
plt.savefig(os.path.join(plot_dir, '{}_{}_{}'.format(name, batch, dpi)), bbox_inches='tight', dpi=dpi)
if show:
plt.show()
plt.close(f)
def plot_link_matrix_process(self, batch, plot_dir, name, show=False):
correct_prediction, false_prediction, text, decoded_predictions, mask, old_link_matrix, \
old_precedence_weighting, new_precedence_weighting, write_weighting, new_link_matrix, max_loc = self.bucket.get_link_matrix_process(
batch)
correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, \
read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc = self.bucket.get_read_process(
batch)
write_heads = self.bucket.max_write_head
read_heads = self.bucket.max_read_head
f, ax = plt.subplots((1 + 3 * write_heads + 2 * read_heads), sharex=False)
plt.xlim([-1, max_loc * old_link_matrix.shape[-1]])
line_loc, width = self.plot_data_plus_prediction(correct_prediction, false_prediction, text,
decoded_predictions, mask, ax[0])
ax[0].axis('tight')
ax[0].set_xlim(0 - width / 2, max_loc - width / 2)
for i in range(write_heads):
# plot_matrix(old_link_matrix[:,i,:,:], ax[1+i*5], name='Old Link Mat', color='Purples', mode='norm1', zero_add='ones')
# plot_vector_as_matrix(old_precedence_weighting[:,i,:], vertical=True, repeats=old_link_matrix.shape[2], ax=ax[2+i*5], name='Old Precedence',zero_width=5)
self.plot_vector_as_matrix(write_weighting[:, i, :], vertical=True, repeats=old_link_matrix.shape[2],
ax=ax[1 + i * 5], name='Write\nWeighting.', zero_width=5)
self.plot_vector_as_matrix(new_precedence_weighting[:, i, :], vertical=True,
repeats=old_link_matrix.shape[2], ax=ax[2 + i * 5], name='Precedence',
zero_width=5)
self.plot_matrix(new_link_matrix[:, i, :, :], ax[3 + i * 5], name='Linkage\nMatrix', color='Purples',
mode='norm1', zero_add='ones')
for i in range(read_heads):
self.plot_vector_as_matrix(forward_weighting[:, i, 0, :], vertical=True, repeats=old_link_matrix.shape[2],
ax=ax[4 + i * 2], name='Forward Wgh.\nHead {}'.format(i + 1), zero_width=5,
color='GnBu')
self.plot_vector_as_matrix(backward_weighting[:, i, 0, :], vertical=True, repeats=old_link_matrix.shape[2],
ax=ax[5 + i * 2], name='Backward Wgh.\nHead {}'.format(i + 1), zero_width=5,
color='BuGn')
for _ax in ax:
for tick in _ax.yaxis.get_major_ticks():
tick.label.set_fontsize(16)
for tick in ax[-1].xaxis.get_major_ticks():
tick.label.set_fontsize(16)
plt.xlabel("Time step", size='16')
f.set_size_inches(2 + max_loc * 1.2, 1 + ax.__len__() * 1.4)
plt.savefig(os.path.join(plot_dir, '{}_{}.{}'.format(name, batch, self.date_type)), bbox_inches='tight',
dpi=160)
if show:
plt.show()
plt.close(f)
def plot_advanced_functionality(self, batch, plot_dir, name, show=False):
self.plot_write_process(batch, plot_dir, name='{}_write'.format(name), show=show)
self.plot_read_process(batch, plot_dir, name='{}_read'.format(name), show=show)
self.plot_memory_process(batch, plot_dir, name='{}_memory'.format(name), show=show)
self.plot_link_matrix_process(batch, plot_dir, name='{}_link_mat'.format(name), show=show)

268
adnc/analysis/plot_functions.py Executable file
View File

@ -0,0 +1,268 @@
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import ticker
class PlotFunctions():
def __init__(self, legend=False, title=False, text_size=16):
self.legend = legend
self.title = title
self.text_size = text_size
def plot_data_plus_prediction(self, correct_prediction, false_prediction, text, decoded_predictions, mask, ax):
ind = np.arange(mask.shape[0])
ax.set_ylim([0, 1])
ax.bar(ind, np.ones(correct_prediction.shape), color='lightgray')
ax_corr = ax.bar(ind, correct_prediction, color='lawngreen')
ax_false = ax.bar(ind, false_prediction, color='tomato')
count = 0
line_loc = []
for rect in ax_corr:
if count >= text.__len__():
word = '---'
else:
word = text[count]
if mask[count] == 1:
word = decoded_predictions[count]
line_loc.append(rect.get_x())
line_loc.append(rect.get_x() + rect.get_width())
yloc = 0.5
xloc = rect.get_x() + 0.4
ax.text(xloc, yloc, word, horizontalalignment='center',
verticalalignment='center', rotation='vertical', color='black', clip_on=True, size=16)
count += 1
if self.legend:
ax.legend((ax_corr, ax_false), ('correct', 'wrong'), loc='center left', bbox_to_anchor=(1, 0.5),
prop={'size': self.text_size})
if self.title:
ax.set_ylabel('Task', size=self.text_size)
ax.get_yaxis().set_ticks([])
return line_loc, rect.get_width()
def plot_data_and_prediction(self, correct_prediction, false_prediction, text, decoded_predictions, mask, ax):
ind = np.arange(mask.shape[0])
ax[0].set_ylim([0, 1])
ax_aws = ax[0].bar(ind, mask, color='lightgray')
count = 0
line_loc = []
for rect in ax_aws:
if count >= text.__len__():
word = '---'
else:
word = text[count]
if mask[count] == 1:
line_loc.append(rect.get_x())
line_loc.append(rect.get_x() + rect.get_width())
yloc = 0.5
xloc = rect.get_x() + 0.4
ax[0].text(xloc, yloc, word, horizontalalignment='center',
verticalalignment='center', rotation='vertical', color='black',
clip_on=True, size=16)
count += 1
if self.legend:
ax[0].legend((ax_aws), ('answer',), loc='center left', bbox_to_anchor=(1, 0.5),
prop={'size': self.text_size})
if self.title:
ax[0].annotate('Questions', xy=(0, 0.8), xytext=(-ax[0].yaxis.labelpad - 150, 0),
xycoords=ax[0].yaxis.label, textcoords='offset points', size=self.text_size, ha='left',
va='center', ma='left')
ax[1].set_ylim([0, 1])
ax_corr = ax[1].bar(ind, correct_prediction, color='lawngreen')
ax_false = ax[1].bar(ind, false_prediction, color='tomato')
count = 0
for rect in ax_corr:
yloc = 0.5
xloc = rect.get_x() + 0.4
ax[1].text(xloc, yloc, decoded_predictions[count], horizontalalignment='center',
verticalalignment='center', rotation='vertical', color='black',
clip_on=True, size=16)
count += 1
if self.legend:
ax[1].legend((ax_corr, ax_false), ('Correct', 'Wrong'), loc='center left', bbox_to_anchor=(1, 0.5),
prop={'size': self.text_size})
if self.title:
ax[1].annotate('Predictions', xy=(0, 0.8), xytext=(-ax[1].yaxis.labelpad - 150, 0),
xycoords=ax[1].yaxis.label, textcoords='offset points', size=self.text_size, ha='left',
va='center', ma='left')
ax[0].get_yaxis().set_ticks([])
ax[1].get_yaxis().set_ticks([])
return line_loc, rect.get_width()
def plot_weightings(self, weightings, ax, name='Weightings', mode='log', color='YlOrRd'):
assert weightings.shape.__len__() == 2, "plot weightings: need 2D matrix as data"
if mode == 'log':
norm = colors.LogNorm(vmin=1e-3, vmax=1)
else:
norm = colors.Normalize(vmin=0, vmax=1)
img = ax.imshow(np.transpose(weightings), interpolation='nearest', norm=norm, cmap=color,
aspect='auto') # gist_stern
ax.set_adjustable('box-forced')
if self.title:
ax.set_ylabel(name, size=self.text_size)
if self.legend:
box = ax.get_position()
ax.set_position([box.x0 - 0.001, box.y0, box.width, box.height])
axColor = plt.axes([box.x0 + box.width + 0.005, box.y0, 0.005, box.height])
cb = plt.colorbar(img, cax=axColor, orientation="vertical")
for l in cb.ax.yaxis.get_ticklabels():
l.set_size(self.text_size)
def plot_modes(self, modes, ax, mode_colors, mode_names, name='Modes'):
assert modes.shape.__len__() == 2, "plot modes: need 2D matrix as data"
assert modes.shape[1] == mode_colors.__len__() and modes.shape[
1] == mode_names.__len__(), "plot modes: not same length"
ind = np.arange(modes.shape[0])
ax_list = [ax.bar(ind, modes[:, 0], color=mode_colors[0]), ]
if modes.shape[1] > 1:
for m in range(1, modes.shape[1]):
ax_list.append(ax.bar(ind, modes[:, m], bottom=modes[:, :m].sum(axis=1), color=mode_colors[m]))
ax.set_yticks([0, 1])
ax.set_ylim(0, 1)
if self.title:
if name == 'Read Mode':
ax.annotate(name, xy=(0, 0.8), xytext=(-ax.yaxis.labelpad - 150, 0), xycoords=ax.yaxis.label,
textcoords='offset points', size=self.text_size, ha='left', va='center', ma='left')
else:
ax.annotate(name, xy=(0, 0.8), xytext=(-ax.yaxis.labelpad - 135, 0), xycoords=ax.yaxis.label,
textcoords='offset points', size=self.text_size, ha='left', va='center', ma='left')
if self.legend:
ax.legend(ax_list, mode_names, loc='center left', bbox_to_anchor=(1, 0.5), prop={'size': self.text_size})
def plot_multi_modes(self, multi_modes, ax, width, mode_colors, mode_names, name='Multi Modes'):
modes = multi_modes.shape[1]
ind = np.arange(multi_modes.shape[0])
width = width / modes
for j in range(-1, modes - 1):
ax_list = [ax.bar(ind + j * width + (width * 0.5), multi_modes[:, j, 0], color=mode_colors[0], width=width,
align='center'), ]
if multi_modes.shape[2] > 1:
for m in range(1, multi_modes.shape[2]):
ax_list.append(ax.bar(ind + j * width + (width * 0.5), multi_modes[:, j, m],
bottom=multi_modes[:, j, :m].sum(axis=1), color=mode_colors[m], width=width,
align='center'))
if self.title:
ax.annotate(name, xy=(0, 0.8), xytext=(-ax.yaxis.labelpad - 135, 0), xycoords=ax.yaxis.label,
textcoords='offset points', size=self.text_size, ha='left', va='center', ma='left')
if self.legend:
ax.legend(ax_list, mode_names, loc='center left', bbox_to_anchor=(1, 0.5), prop={'size': self.text_size})
ax.set_yticks([0, 1])
ax.set_yticklabels(['0', '', '', '', '', '1'])
ax.set_ylim(0, 1)
def plot_matrix(self, matrix, ax, name='Weightings', mode='norm', color='RdYlBu', zero_width=5, zero_add='zeros'):
assert matrix.shape.__len__() == 3, "plot weightings: need 3D matrix as data"
if mode == 'log':
norm = colors.LogNorm(vmin=1e-8, vmax=0.1)
elif mode == 'norm1':
norm = colors.Normalize(vmin=0, vmax=1)
else:
norm = colors.Normalize(vmin=-1, vmax=1)
if zero_add == 'zeros':
matrix = np.concatenate([matrix, np.zeros([matrix.shape[0], matrix.shape[1], zero_width])], axis=2)
matrix = np.transpose(matrix, axes=(0, 2, 1))
flat_matrix = np.reshape(matrix, [-1, matrix.shape[2]])
flat_matrix = np.concatenate([np.zeros([zero_width, flat_matrix.shape[1]]), flat_matrix], axis=0)
else:
matrix = np.concatenate([matrix, np.ones([matrix.shape[0], matrix.shape[1], zero_width])], axis=2)
matrix = np.transpose(matrix, axes=(0, 2, 1))
flat_matrix = np.reshape(matrix, [-1, matrix.shape[2]])
flat_matrix = np.concatenate([np.ones([zero_width, flat_matrix.shape[1]]), flat_matrix], axis=0)
img = ax.imshow(np.transpose(flat_matrix), aspect='auto', interpolation='nearest', norm=norm, cmap=color)
ax.set_adjustable('box-forced')
if self.title:
ax.set_ylabel(name, size=self.text_size)
if self.legend:
box = ax.get_position()
ax.set_position([box.x0 - 0.001, box.y0, box.width, box.height])
axColor = plt.axes([box.x0 + box.width + 0.005, box.y0, 0.005, box.height])
cb = plt.colorbar(img, cax=axColor, orientation="vertical")
for l in cb.ax.yaxis.get_ticklabels():
l.set_size(self.text_size)
tick_locator = ticker.MaxNLocator(nbins=3)
cb.locator = tick_locator
cb.update_ticks()
def plot_vector_as_matrix(self, vector, vertical, repeats, ax, name='Weightings', mode='log', color='YlOrRd',
zero_width=5):
assert vector.shape.__len__() == 2, "plot weightings: need 2D matrix as data"
if mode == 'log':
norm = colors.LogNorm(vmin=1e-3, vmax=1)
elif mode == 'norm1':
norm = colors.Normalize(vmin=0, vmax=1)
else:
norm = colors.Normalize(vmin=-1, vmax=1)
if vertical:
matrix = np.repeat(vector, repeats, axis=1)
matrix = np.reshape(matrix, [vector.shape[0], vector.shape[1], repeats])
else:
matrix = np.repeat(vector, repeats, axis=0)
matrix = np.reshape(matrix, [vector.shape[0], repeats, vector.shape[1]])
matrix = np.concatenate([matrix, np.zeros([matrix.shape[0], matrix.shape[1], zero_width])], axis=2)
matrix = np.transpose(matrix, axes=(0, 2, 1))
flat_matrix = np.reshape(matrix, [-1, matrix.shape[2]])
flat_matrix = np.concatenate([np.zeros([zero_width, flat_matrix.shape[1]]), flat_matrix], axis=0)
img = ax.imshow(np.transpose(flat_matrix), aspect='auto', interpolation='nearest', norm=norm, cmap=color)
ax.set_adjustable('box-forced')
box = ax.get_position()
ax.set_position([box.x0 - 0.001, box.y0, box.width, box.height])
if self.legend:
axColor = plt.axes([box.x0 + box.width + 0.005, box.y0, 0.005, box.height])
cb = plt.colorbar(img, cax=axColor, orientation="vertical")
for l in cb.ax.yaxis.get_ticklabels():
l.set_size(self.text_size)
if self.title:
ax.set_ylabel(name, labelpad=30, size=self.text_size)
ax.set_yticks([])