mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
add analyzer and plot functions
This commit is contained in:
parent
98f56912f3
commit
438e9bf0a0
@ -11,4 +11,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
from .analyzer import Analyser
|
||||||
|
from .prepare_variables import Bucket
|
||||||
|
from .plot_functionality import PlotFunctionality
|
||||||
|
136
adnc/analysis/analyzer.py
Executable file
136
adnc/analysis/analyzer.py
Executable file
@ -0,0 +1,136 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from adnc.analysis.plot_functionality import PlotFunctionality
|
||||||
|
from adnc.analysis.prepare_variables import Bucket
|
||||||
|
from adnc.model.utils import softmax
|
||||||
|
|
||||||
|
|
||||||
|
class Analyser():
|
||||||
|
def __init__(self, data_set, record_dir, save_variables=False, save_fig=False):
|
||||||
|
|
||||||
|
self.data_set = data_set
|
||||||
|
self.record_dir = record_dir
|
||||||
|
self.save_variables = save_variables
|
||||||
|
self.save_fig = save_fig
|
||||||
|
|
||||||
|
self.max_batch_plot = 1
|
||||||
|
|
||||||
|
if save_fig:
|
||||||
|
self.plot_dir = os.path.join(self.record_dir, "plots")
|
||||||
|
while not os.path.isdir(self.plot_dir):
|
||||||
|
try:
|
||||||
|
os.mkdir(self.plot_dir)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def feed_variables(self, variables, epoch, name='variables'):
|
||||||
|
|
||||||
|
variable_name = name + "_{}".format(epoch)
|
||||||
|
|
||||||
|
buck = Bucket(variables)
|
||||||
|
|
||||||
|
plotter = PlotFunctionality(bucket=buck)
|
||||||
|
|
||||||
|
if self.save_variables:
|
||||||
|
self.save_variables_to_file(variables, variable_name)
|
||||||
|
|
||||||
|
if self.save_fig:
|
||||||
|
plotter.plot_basic_functionality(batch=0, plot_dir=self.plot_dir, name=variable_name, show=False)
|
||||||
|
|
||||||
|
return self.estimate_memory_usage(variables)
|
||||||
|
|
||||||
|
def feed_variables_two(self, variables, epoch, name='variables', save_plot=1):
|
||||||
|
|
||||||
|
variable_name = name + "_{}".format(epoch)
|
||||||
|
|
||||||
|
buck = Bucket(variables)
|
||||||
|
plotter = PlotFunctionality(bucket=buck)
|
||||||
|
|
||||||
|
if save_plot > 0:
|
||||||
|
if self.save_variables:
|
||||||
|
self.save_variables_to_file(variables, variable_name)
|
||||||
|
|
||||||
|
if self.save_fig:
|
||||||
|
plotter.plot_basic_functionality(batch=0, plot_dir=self.plot_dir, name=variable_name, show=False)
|
||||||
|
|
||||||
|
return self.estimate_memory_usage(variables)
|
||||||
|
|
||||||
|
def plot_analysis(self, variables, plot_dir, name='variables'):
|
||||||
|
|
||||||
|
buck = Bucket(variables)
|
||||||
|
plotter = PlotFunctionality(bucket=buck)
|
||||||
|
|
||||||
|
plotter.plot_basic_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
|
||||||
|
plotter.plot_advanced_functionality(batch=0, plot_dir=plot_dir, name=name, show=True)
|
||||||
|
|
||||||
|
def estimate_memory_usage(self, variables):
|
||||||
|
|
||||||
|
analyse_values, prediction, decoded_predictions, data_sample, weights_dict = variables
|
||||||
|
data, target, mask, x_word = data_sample
|
||||||
|
analyse_outputs, analyse_signals, analyse_states = analyse_values
|
||||||
|
controller_states, memory_states = analyse_states
|
||||||
|
|
||||||
|
if memory_states.__len__() == 6:
|
||||||
|
memory, usage_vector, write_weightings, precedence_weighting, link_matrix, read_weightings = memory_states
|
||||||
|
else:
|
||||||
|
memory, usage_vector, write_weightings, read_weightings = memory_states
|
||||||
|
read_head = read_weightings.shape[2]
|
||||||
|
memory_width = memory.shape[-1]
|
||||||
|
time_len = memory.shape[0]
|
||||||
|
memory_unit_mask = np.concatenate([np.ones([time_len, read_head * memory_width]), np.zeros(
|
||||||
|
[time_len, analyse_outputs.shape[-1] - (read_head * memory_width)])], axis=-1)
|
||||||
|
controller_mask = np.concatenate([np.zeros([time_len, read_head * memory_width]),
|
||||||
|
np.ones([time_len, analyse_outputs.shape[-1] - (read_head * memory_width)])],
|
||||||
|
axis=-1)
|
||||||
|
|
||||||
|
controller_influence = []
|
||||||
|
memory_unit_influence = []
|
||||||
|
for b in range(mask.shape[1]):
|
||||||
|
matmul = np.matmul(analyse_outputs[:, b, :], weights_dict['output_layer/weights_concat:0']['var']) + \
|
||||||
|
weights_dict['output_layer/bias_merge:0']['var']
|
||||||
|
pred_both = softmax(matmul)
|
||||||
|
|
||||||
|
matmul = np.matmul(analyse_outputs[:, b, :] * controller_mask,
|
||||||
|
weights_dict['output_layer/weights_concat:0']['var']) + \
|
||||||
|
weights_dict['output_layer/bias_merge:0']['var']
|
||||||
|
pred_c = softmax(matmul)
|
||||||
|
|
||||||
|
matmul = np.matmul(analyse_outputs[:, b, :] * memory_unit_mask,
|
||||||
|
weights_dict['output_layer/weights_concat:0']['var']) + \
|
||||||
|
weights_dict['output_layer/bias_merge:0']['var']
|
||||||
|
pred_mu = softmax(matmul)
|
||||||
|
|
||||||
|
co_inf = (np.abs(pred_both - pred_mu) * np.expand_dims(mask[:, b], 1)).sum() / mask[:, b].sum()
|
||||||
|
me_inf = (np.abs(pred_both - pred_c) * np.expand_dims(mask[:, b], 1)).sum() / mask[:, b].sum()
|
||||||
|
|
||||||
|
co_inf = (1 / (co_inf + me_inf)) * co_inf
|
||||||
|
me_inf = (1 / (co_inf + me_inf)) * me_inf
|
||||||
|
|
||||||
|
controller_influence.append(co_inf)
|
||||||
|
memory_unit_influence.append(me_inf)
|
||||||
|
|
||||||
|
controller_influence = np.mean(controller_influence)
|
||||||
|
memory_unit_influence = np.mean(memory_unit_influence)
|
||||||
|
return controller_influence, memory_unit_influence
|
||||||
|
|
||||||
|
def save_variables_to_file(self, variables, name):
|
||||||
|
save_file = os.path.join(self.record_dir, "{}.plk".format(name))
|
||||||
|
with open(save_file, "wb") as f:
|
||||||
|
pickle.dump(variables, f)
|
362
adnc/analysis/plot_functionality.py
Executable file
362
adnc/analysis/plot_functionality.py
Executable file
@ -0,0 +1,362 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import os
|
||||||
|
|
||||||
|
import matplotlib
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
matplotlib.use('agg')
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from adnc.analysis.plot_functions import PlotFunctions
|
||||||
|
|
||||||
|
|
||||||
|
class PlotFunctionality(PlotFunctions):
|
||||||
|
def __init__(self, bucket, legend=False, title=False, text_size=16, data_type='png'):
|
||||||
|
self.bucket = bucket
|
||||||
|
self.data_type = data_type
|
||||||
|
super().__init__(legend, title, text_size)
|
||||||
|
|
||||||
|
def plot_basic_functionality(self, batch, plot_dir, name, show=False):
|
||||||
|
|
||||||
|
if self.bucket.cell_type == 'dnc':
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
|
||||||
|
write_weighting, read_mode, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
|
||||||
|
batch=batch)
|
||||||
|
read_heads = self.bucket.max_read_head
|
||||||
|
write_heads = self.bucket.max_write_head
|
||||||
|
f, ax = plt.subplots((8 + 2 * read_heads + write_heads), sharex=True)
|
||||||
|
else:
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
|
||||||
|
write_weighting, read_weighting, read_head_influence, old_memory, new_memory, read_strength, max_loc = self.bucket.get_basic_functionality(
|
||||||
|
batch=batch)
|
||||||
|
read_heads = self.bucket.max_read_head
|
||||||
|
write_heads = self.bucket.max_write_head
|
||||||
|
f, ax = plt.subplots((8 + read_heads + write_heads), sharex=True)
|
||||||
|
|
||||||
|
controller_influence, memory_unit_influence = self.bucket.get_memory_influence(batch)
|
||||||
|
|
||||||
|
plt.xlim([-1, max_loc])
|
||||||
|
|
||||||
|
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
|
||||||
|
mask, ax[:2])
|
||||||
|
self.plot_modes(alloc_gate[:, 0, :], ax[2], ['y', 'b'], ['usage', 'content'], name='Alloc Gate')
|
||||||
|
self.plot_multi_modes(free_gate, ax[3], width, ['g', 'r'], ['free', 'not free'], name='Free Gates')
|
||||||
|
|
||||||
|
for i in range(write_heads):
|
||||||
|
self.plot_modes(write_gate[:, i, :], ax[4 + i], ['g', 'r'], ['write', 'write not'], name='Write Gate')
|
||||||
|
|
||||||
|
self.plot_matrix(old_memory, ax[4 + write_heads], name='Old Memeory', color='bwr')
|
||||||
|
self.plot_matrix(new_memory, ax[5 + write_heads], name='New Memeory', color='bwr')
|
||||||
|
|
||||||
|
self.plot_modes(read_head_influence, ax[6 + write_heads], [None for _ in range(read_heads)],
|
||||||
|
['head {}'.format(i + 1) for i in range(read_heads)], name='Head Influence')
|
||||||
|
|
||||||
|
for i in range(read_heads):
|
||||||
|
if self.bucket.cell_type == 'dnc':
|
||||||
|
self.plot_modes(read_strength[:, i, :], ax[7 + write_heads + i * 2], ['k'], ['strength'],
|
||||||
|
name='Read Cont. Stg.')
|
||||||
|
self.plot_modes(read_mode[:, i, :], ax[8 + write_heads + i * 2], ['m', 'b', 'c'],
|
||||||
|
['backward', 'content', 'forward'], name='Read Modes')
|
||||||
|
else:
|
||||||
|
self.plot_modes(read_strength[:, i, :], ax[7 + write_heads + i], ['k'], ['strength'],
|
||||||
|
name='Read Cont. Stg.')
|
||||||
|
|
||||||
|
influence = np.stack([memory_unit_influence, controller_influence], axis=-1)
|
||||||
|
self.plot_modes(influence, ax[-1], ['y', 'b'], ['memory usage', 'controller usage'], name='Memory Usage')
|
||||||
|
|
||||||
|
if max_loc < 80:
|
||||||
|
for ax_i in ax:
|
||||||
|
for l in line_loc:
|
||||||
|
ax_i.axvline(x=l, c='k')
|
||||||
|
|
||||||
|
for _ax in ax:
|
||||||
|
for tick in _ax.yaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
for tick in ax[-1].xaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
plt.xlabel("Time step", size='16')
|
||||||
|
|
||||||
|
f.set_size_inches(2 + max_loc * 0.2, 1 + ax.__len__() * 1.4)
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
plt.savefig(os.path.join(plot_dir, 'basic_{}_{}'.format(name, batch)), bbox_inches='tight', dpi=160)
|
||||||
|
plt.close(f)
|
||||||
|
return f
|
||||||
|
|
||||||
|
def plot_write_process(self, batch, plot_dir, name, show=False):
|
||||||
|
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
|
||||||
|
write_weighting, content_weighting, write_strength, alloc_weighting, write_vector, write_key, max_loc = self.bucket.get_write_process(
|
||||||
|
batch)
|
||||||
|
write_heads = self.bucket.max_write_head
|
||||||
|
usage_weightings = self.bucket.usage_vector
|
||||||
|
f, ax = plt.subplots((3 + 6 * write_heads), sharex=True, figsize=(12, 18))
|
||||||
|
plt.xlim([-1, max_loc])
|
||||||
|
|
||||||
|
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
|
||||||
|
mask, ax[:2])
|
||||||
|
self.plot_multi_modes(free_gate, ax[2], width, ['g', 'r'], ['Free', 'Free not'], name='Free Gates')
|
||||||
|
for i in range(write_heads):
|
||||||
|
self.plot_weightings(alloc_weighting[:, i, :], ax[3 + i * 6], name='Allocation\nWeighting')
|
||||||
|
self.plot_weightings(usage_weightings[:, i, :], ax[4 + i * 6], name='Usage\nWeighting')
|
||||||
|
# plot_modes(write_strength[:,i,:], ax[3+i*8], ['k'], ['strength'], name='Content Stg.')
|
||||||
|
# plot_weightings(write_key[:,i,:], ax[4+i*8], name='Content Key', mode='norm', color='jet')
|
||||||
|
self.plot_weightings(content_weighting[:, i, :], ax[5 + i * 6], name='Content\nWeighting')
|
||||||
|
self.plot_modes(alloc_gate[:, i, :], ax[6 + i * 6], ['y', 'b'], ['Allocation', 'Content'],
|
||||||
|
name='Allocation\nGate')
|
||||||
|
self.plot_modes(write_gate[:, i, :], ax[8 + i * 6], ['g', 'r'], ['Write', 'Write not'], name='Write Gate')
|
||||||
|
self.plot_weightings(write_weighting[:, i, :], ax[7 + i * 6], name='Write\nWeighting')
|
||||||
|
# plot_weightings(write_vector[:,i,:], ax[9+i*8], name='Write Vector', mode='norm', color='jet')
|
||||||
|
|
||||||
|
for ax_i in ax:
|
||||||
|
for l in line_loc:
|
||||||
|
ax_i.axvline(x=l, c='k')
|
||||||
|
|
||||||
|
for _ax in ax:
|
||||||
|
for tick in _ax.yaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
for tick in ax[-1].xaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
f.set_size_inches(2 + max_loc * 0.2, 1 + ax.__len__() * 1.4)
|
||||||
|
plt.savefig(os.path.join(plot_dir, '{}_{}'.format(name, batch)), bbox_inches='tight', dpi=160)
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
plt.close(f)
|
||||||
|
|
||||||
|
def plot_read_process(self, batch, plot_dir, name, show=False):
|
||||||
|
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, \
|
||||||
|
read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc = self.bucket.get_read_process(
|
||||||
|
batch)
|
||||||
|
read_heads = self.bucket.max_read_head
|
||||||
|
write_heads = self.bucket.max_write_head
|
||||||
|
|
||||||
|
f, ax = plt.subplots((3 + 5 * read_heads), sharex=True, figsize=(12, 24))
|
||||||
|
plt.xlim([-1, max_loc])
|
||||||
|
|
||||||
|
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
|
||||||
|
mask, ax[:2])
|
||||||
|
for i in range(read_heads):
|
||||||
|
for wh in range(write_heads):
|
||||||
|
self.plot_weightings(forward_weighting[:, i, wh, :], ax[2 + i * 5],
|
||||||
|
name='Forward\nWeighting\nHead {}'.format(i + 1))
|
||||||
|
self.plot_weightings(backward_weighting[:, i, wh, :], ax[3 + i * 5],
|
||||||
|
name='Backward\nWeighting\nHead {}'.format(i + 1))
|
||||||
|
# plot_modes(read_strength[:,i,:], ax[3+i*8], ['k'], ['strength'], name='Content Stg.')
|
||||||
|
# plot_weightings(read_key[:,i,:], ax[4+i*8], name='Content Key', mode='norm', color='jet')
|
||||||
|
self.plot_weightings(read_content_weighting[:, i, :], ax[4 + i * 5],
|
||||||
|
name='Content\nWeighting\nHead {}'.format(i + 1))
|
||||||
|
self.plot_modes(read_mode[:, i, :], ax[5 + i * 5], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'],
|
||||||
|
name='Read Modes\nHead {}'.format(i + 1))
|
||||||
|
self.plot_weightings(read_weighting[:, i, :], ax[6 + i * 5], name='Read Wgh. {}'.format(i + 1))
|
||||||
|
# plot_weightings(read_vector[:,i,:], ax[8+i*8], name='Read Vector', mode='norm', color='jet')
|
||||||
|
self.plot_modes(read_head_influence, ax[-1], [None for _ in range(read_heads)],
|
||||||
|
['Head {}'.format(i + 1) for i in range(read_heads)], name='Head\nInfluence')
|
||||||
|
|
||||||
|
for ax_i in ax:
|
||||||
|
for l in line_loc:
|
||||||
|
ax_i.axvline(x=l, c='k')
|
||||||
|
|
||||||
|
for _ax in ax:
|
||||||
|
for tick in _ax.yaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
for tick in ax[-1].xaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
f.set_size_inches(2 + max_loc * 0.2, 1 + ax.__len__() * 1.4)
|
||||||
|
plt.savefig(os.path.join(plot_dir, '{}_{}'.format(name, batch)), bbox_inches='tight', dpi=160)
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
plt.close(f)
|
||||||
|
|
||||||
|
def plot_short_process(self, batch, plot_dir, name, show=False):
|
||||||
|
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, alloc_gate, free_gate, write_gate, \
|
||||||
|
write_weighting, content_weighting, write_strength, alloc_weighting, write_vector, write_key, max_loc = self.bucket.get_write_process(
|
||||||
|
batch)
|
||||||
|
|
||||||
|
controller_influence, memory_unit_influence = self.bucket.get_memory_influence(batch)
|
||||||
|
influence = np.stack([memory_unit_influence, controller_influence], axis=-1)
|
||||||
|
influence = influence / influence.sum(axis=1, keepdims=True)
|
||||||
|
|
||||||
|
read_heads = self.bucket.max_read_head
|
||||||
|
write_heads = self.bucket.max_write_head
|
||||||
|
|
||||||
|
f, ax = plt.subplots((4 + 1 * read_heads + 3 * write_heads - 2), sharex=True, figsize=(12, 18))
|
||||||
|
|
||||||
|
ax[0].set_title(name, size=33, weight='bold')
|
||||||
|
|
||||||
|
plt.xlim([-1, max_loc])
|
||||||
|
|
||||||
|
line_loc, width = self.plot_data_and_prediction(correct_prediction, false_prediction, text, decoded_predictions,
|
||||||
|
mask, ax[:2])
|
||||||
|
|
||||||
|
self.plot_multi_modes(free_gate, ax[2], width, ['g', 'r'], ['Free', 'Free not'], name='Free Gates')
|
||||||
|
for i in range(write_heads):
|
||||||
|
self.plot_modes(alloc_gate[:, i, :], ax[3 + i * 3], ['y', 'b'], ['Content', 'Usage'], name='Alloc Gate')
|
||||||
|
self.plot_modes(write_gate[:, i, :], ax[4 + i * 3], ['g', 'r'], ['Write', 'Write not'], name='Write Gate')
|
||||||
|
print(write_gate[:, i, :].shape)
|
||||||
|
self.plot_modes(np.zeros([34, 3]), ax[5 + 0 * 1], ['m', 'b', 'c'], ['Backward', 'Content', 'Forward'],
|
||||||
|
name='Read Mode')
|
||||||
|
|
||||||
|
ax[5 + 0 * 1].set_yticks([])
|
||||||
|
|
||||||
|
self.plot_modes(influence, ax[-1], ['darkorange', 'blueviolet'], ['Memory', 'Controller'],
|
||||||
|
name='Output\nInfluencer')
|
||||||
|
|
||||||
|
for ax_i in ax:
|
||||||
|
for l in line_loc:
|
||||||
|
ax_i.axvline(x=l, c='k')
|
||||||
|
|
||||||
|
for _ax in ax:
|
||||||
|
for tick in _ax.yaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(self.text_size)
|
||||||
|
|
||||||
|
for tick in ax[-1].xaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
plt.xlabel("Time steps", size=self.text_size)
|
||||||
|
|
||||||
|
plt.savefig(os.path.join(plot_dir, 'function_{}_{}.{}'.format(name, batch, self.data_type)),
|
||||||
|
bbox_inches='tight', format=self.data_type, dpi=80)
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
plt.close(f)
|
||||||
|
|
||||||
|
def plot_memory_process(self, batch, plot_dir, name, show=False, dpi=160):
|
||||||
|
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, \
|
||||||
|
read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc = self.bucket.get_read_process(
|
||||||
|
batch)
|
||||||
|
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, old_memory, write_weighting, \
|
||||||
|
write_vector, erase_vector, add_memory, erase_memory, new_memory, max_loc = self.bucket.get_memory_process(
|
||||||
|
batch)
|
||||||
|
write_heads = self.bucket.max_write_head
|
||||||
|
read_heads = self.bucket.max_read_head
|
||||||
|
|
||||||
|
f, ax = plt.subplots((4 + 5 * write_heads + 2 * read_heads), sharex=False,
|
||||||
|
gridspec_kw={'height_ratios': [6, 6, 6, 1, 6, 1, 6, 6, 6, 1, 6, 1, 1]}, figsize=(12, 24))
|
||||||
|
plt.xlim([-1, max_loc * new_memory.shape[-1]])
|
||||||
|
line_loc, width = self.plot_data_plus_prediction(correct_prediction, false_prediction, text,
|
||||||
|
decoded_predictions, mask, ax[0])
|
||||||
|
ax[0].axis('tight')
|
||||||
|
ax[0].set_xlim(0 - width / 2, max_loc - width / 2)
|
||||||
|
|
||||||
|
self.plot_matrix(old_memory, ax[1], name='Old\nMemeory', color='bwr')
|
||||||
|
for i in range(write_heads):
|
||||||
|
self.plot_vector_as_matrix(write_weighting[:, i, :], vertical=True, repeats=old_memory.shape[2],
|
||||||
|
ax=ax[2 + i * 5], name='Write Wgh.', zero_width=5)
|
||||||
|
self.plot_vector_as_matrix(write_vector[:, i, :], vertical=False, repeats=old_memory.shape[1],
|
||||||
|
ax=ax[3 + i * 5], name='Write\nVector\n', zero_width=5, mode='norm', color='bwr',
|
||||||
|
legend=False)
|
||||||
|
self.plot_matrix(add_memory[:, :], ax[4 + i * 5], name='Add\nMatrix', color='bwr')
|
||||||
|
self.plot_vector_as_matrix(erase_vector[:, i, :], vertical=False, repeats=old_memory.shape[1],
|
||||||
|
ax=ax[5 + i * 5], name='Erase\nVector\n', zero_width=5, mode='norm1',
|
||||||
|
color='YlGnBu', legend=False)
|
||||||
|
self.plot_matrix(erase_memory[:, :], ax[6 + i * 5], name='Erase\nMatrix', mode='norm1', color='YlGnBu')
|
||||||
|
self.plot_matrix(new_memory, ax[7], name='New\nMemeory', color='bwr')
|
||||||
|
|
||||||
|
for i in range(read_heads):
|
||||||
|
self.plot_vector_as_matrix(read_weighting[:, i, :], ax=ax[8 + i * 2], vertical=True,
|
||||||
|
repeats=old_memory.shape[2], name='Read Wgh.\nHead {}'.format(i + 1),
|
||||||
|
zero_width=5)
|
||||||
|
self.plot_vector_as_matrix(read_vector[:, i, :], vertical=False, repeats=old_memory.shape[1],
|
||||||
|
ax=ax[9 + i * 2], name='Read\nVector\nHead {}\n'.format(i + 1), zero_width=5,
|
||||||
|
mode='norm', color='bwr', legend=False)
|
||||||
|
|
||||||
|
for ax_ in ax:
|
||||||
|
ax_.set_xticks([])
|
||||||
|
ax[-1].axis('off')
|
||||||
|
|
||||||
|
for _ax in ax:
|
||||||
|
for tick in _ax.yaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
for tick in ax[-1].xaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
f.set_size_inches(2 + max_loc * 0.4, 1 + ax.__len__() * 1.4)
|
||||||
|
plt.savefig(os.path.join(plot_dir, '{}_{}_{}'.format(name, batch, dpi)), bbox_inches='tight', dpi=dpi)
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
plt.close(f)
|
||||||
|
|
||||||
|
def plot_link_matrix_process(self, batch, plot_dir, name, show=False):
|
||||||
|
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, old_link_matrix, \
|
||||||
|
old_precedence_weighting, new_precedence_weighting, write_weighting, new_link_matrix, max_loc = self.bucket.get_link_matrix_process(
|
||||||
|
batch)
|
||||||
|
|
||||||
|
correct_prediction, false_prediction, text, decoded_predictions, mask, forward_weighting, backward_weighting, \
|
||||||
|
read_content_weighting, read_strength, read_key, read_mode, read_weighting, read_vector, read_head_influence, max_loc = self.bucket.get_read_process(
|
||||||
|
batch)
|
||||||
|
|
||||||
|
write_heads = self.bucket.max_write_head
|
||||||
|
read_heads = self.bucket.max_read_head
|
||||||
|
|
||||||
|
f, ax = plt.subplots((1 + 3 * write_heads + 2 * read_heads), sharex=False)
|
||||||
|
plt.xlim([-1, max_loc * old_link_matrix.shape[-1]])
|
||||||
|
line_loc, width = self.plot_data_plus_prediction(correct_prediction, false_prediction, text,
|
||||||
|
decoded_predictions, mask, ax[0])
|
||||||
|
ax[0].axis('tight')
|
||||||
|
ax[0].set_xlim(0 - width / 2, max_loc - width / 2)
|
||||||
|
|
||||||
|
for i in range(write_heads):
|
||||||
|
# plot_matrix(old_link_matrix[:,i,:,:], ax[1+i*5], name='Old Link Mat', color='Purples', mode='norm1', zero_add='ones')
|
||||||
|
# plot_vector_as_matrix(old_precedence_weighting[:,i,:], vertical=True, repeats=old_link_matrix.shape[2], ax=ax[2+i*5], name='Old Precedence',zero_width=5)
|
||||||
|
self.plot_vector_as_matrix(write_weighting[:, i, :], vertical=True, repeats=old_link_matrix.shape[2],
|
||||||
|
ax=ax[1 + i * 5], name='Write\nWeighting.', zero_width=5)
|
||||||
|
self.plot_vector_as_matrix(new_precedence_weighting[:, i, :], vertical=True,
|
||||||
|
repeats=old_link_matrix.shape[2], ax=ax[2 + i * 5], name='Precedence',
|
||||||
|
zero_width=5)
|
||||||
|
self.plot_matrix(new_link_matrix[:, i, :, :], ax[3 + i * 5], name='Linkage\nMatrix', color='Purples',
|
||||||
|
mode='norm1', zero_add='ones')
|
||||||
|
|
||||||
|
for i in range(read_heads):
|
||||||
|
self.plot_vector_as_matrix(forward_weighting[:, i, 0, :], vertical=True, repeats=old_link_matrix.shape[2],
|
||||||
|
ax=ax[4 + i * 2], name='Forward Wgh.\nHead {}'.format(i + 1), zero_width=5,
|
||||||
|
color='GnBu')
|
||||||
|
self.plot_vector_as_matrix(backward_weighting[:, i, 0, :], vertical=True, repeats=old_link_matrix.shape[2],
|
||||||
|
ax=ax[5 + i * 2], name='Backward Wgh.\nHead {}'.format(i + 1), zero_width=5,
|
||||||
|
color='BuGn')
|
||||||
|
|
||||||
|
for _ax in ax:
|
||||||
|
for tick in _ax.yaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
|
||||||
|
for tick in ax[-1].xaxis.get_major_ticks():
|
||||||
|
tick.label.set_fontsize(16)
|
||||||
|
plt.xlabel("Time step", size='16')
|
||||||
|
|
||||||
|
f.set_size_inches(2 + max_loc * 1.2, 1 + ax.__len__() * 1.4)
|
||||||
|
|
||||||
|
plt.savefig(os.path.join(plot_dir, '{}_{}.{}'.format(name, batch, self.date_type)), bbox_inches='tight',
|
||||||
|
dpi=160)
|
||||||
|
if show:
|
||||||
|
plt.show()
|
||||||
|
plt.close(f)
|
||||||
|
|
||||||
|
def plot_advanced_functionality(self, batch, plot_dir, name, show=False):
|
||||||
|
|
||||||
|
self.plot_write_process(batch, plot_dir, name='{}_write'.format(name), show=show)
|
||||||
|
self.plot_read_process(batch, plot_dir, name='{}_read'.format(name), show=show)
|
||||||
|
self.plot_memory_process(batch, plot_dir, name='{}_memory'.format(name), show=show)
|
||||||
|
self.plot_link_matrix_process(batch, plot_dir, name='{}_link_mat'.format(name), show=show)
|
268
adnc/analysis/plot_functions.py
Executable file
268
adnc/analysis/plot_functions.py
Executable file
@ -0,0 +1,268 @@
|
|||||||
|
# Copyright 2018 Jörg Franke
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
matplotlib.use('agg')
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib import colors
|
||||||
|
from matplotlib import ticker
|
||||||
|
|
||||||
|
|
||||||
|
class PlotFunctions():
|
||||||
|
def __init__(self, legend=False, title=False, text_size=16):
|
||||||
|
|
||||||
|
self.legend = legend
|
||||||
|
self.title = title
|
||||||
|
self.text_size = text_size
|
||||||
|
|
||||||
|
def plot_data_plus_prediction(self, correct_prediction, false_prediction, text, decoded_predictions, mask, ax):
|
||||||
|
|
||||||
|
ind = np.arange(mask.shape[0])
|
||||||
|
ax.set_ylim([0, 1])
|
||||||
|
ax.bar(ind, np.ones(correct_prediction.shape), color='lightgray')
|
||||||
|
ax_corr = ax.bar(ind, correct_prediction, color='lawngreen')
|
||||||
|
ax_false = ax.bar(ind, false_prediction, color='tomato')
|
||||||
|
count = 0
|
||||||
|
line_loc = []
|
||||||
|
for rect in ax_corr:
|
||||||
|
if count >= text.__len__():
|
||||||
|
word = '---'
|
||||||
|
else:
|
||||||
|
word = text[count]
|
||||||
|
if mask[count] == 1:
|
||||||
|
word = decoded_predictions[count]
|
||||||
|
line_loc.append(rect.get_x())
|
||||||
|
line_loc.append(rect.get_x() + rect.get_width())
|
||||||
|
|
||||||
|
yloc = 0.5
|
||||||
|
xloc = rect.get_x() + 0.4
|
||||||
|
ax.text(xloc, yloc, word, horizontalalignment='center',
|
||||||
|
verticalalignment='center', rotation='vertical', color='black', clip_on=True, size=16)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if self.legend:
|
||||||
|
ax.legend((ax_corr, ax_false), ('correct', 'wrong'), loc='center left', bbox_to_anchor=(1, 0.5),
|
||||||
|
prop={'size': self.text_size})
|
||||||
|
|
||||||
|
if self.title:
|
||||||
|
ax.set_ylabel('Task', size=self.text_size)
|
||||||
|
|
||||||
|
ax.get_yaxis().set_ticks([])
|
||||||
|
|
||||||
|
return line_loc, rect.get_width()
|
||||||
|
|
||||||
|
def plot_data_and_prediction(self, correct_prediction, false_prediction, text, decoded_predictions, mask, ax):
|
||||||
|
|
||||||
|
ind = np.arange(mask.shape[0])
|
||||||
|
ax[0].set_ylim([0, 1])
|
||||||
|
|
||||||
|
ax_aws = ax[0].bar(ind, mask, color='lightgray')
|
||||||
|
count = 0
|
||||||
|
line_loc = []
|
||||||
|
for rect in ax_aws:
|
||||||
|
if count >= text.__len__():
|
||||||
|
word = '---'
|
||||||
|
else:
|
||||||
|
word = text[count]
|
||||||
|
if mask[count] == 1:
|
||||||
|
line_loc.append(rect.get_x())
|
||||||
|
line_loc.append(rect.get_x() + rect.get_width())
|
||||||
|
|
||||||
|
yloc = 0.5
|
||||||
|
xloc = rect.get_x() + 0.4
|
||||||
|
ax[0].text(xloc, yloc, word, horizontalalignment='center',
|
||||||
|
verticalalignment='center', rotation='vertical', color='black',
|
||||||
|
clip_on=True, size=16)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if self.legend:
|
||||||
|
ax[0].legend((ax_aws), ('answer',), loc='center left', bbox_to_anchor=(1, 0.5),
|
||||||
|
prop={'size': self.text_size})
|
||||||
|
if self.title:
|
||||||
|
ax[0].annotate('Questions', xy=(0, 0.8), xytext=(-ax[0].yaxis.labelpad - 150, 0),
|
||||||
|
xycoords=ax[0].yaxis.label, textcoords='offset points', size=self.text_size, ha='left',
|
||||||
|
va='center', ma='left')
|
||||||
|
|
||||||
|
ax[1].set_ylim([0, 1])
|
||||||
|
ax_corr = ax[1].bar(ind, correct_prediction, color='lawngreen')
|
||||||
|
ax_false = ax[1].bar(ind, false_prediction, color='tomato')
|
||||||
|
count = 0
|
||||||
|
for rect in ax_corr:
|
||||||
|
yloc = 0.5
|
||||||
|
xloc = rect.get_x() + 0.4
|
||||||
|
ax[1].text(xloc, yloc, decoded_predictions[count], horizontalalignment='center',
|
||||||
|
verticalalignment='center', rotation='vertical', color='black',
|
||||||
|
clip_on=True, size=16)
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if self.legend:
|
||||||
|
ax[1].legend((ax_corr, ax_false), ('Correct', 'Wrong'), loc='center left', bbox_to_anchor=(1, 0.5),
|
||||||
|
prop={'size': self.text_size})
|
||||||
|
if self.title:
|
||||||
|
ax[1].annotate('Predictions', xy=(0, 0.8), xytext=(-ax[1].yaxis.labelpad - 150, 0),
|
||||||
|
xycoords=ax[1].yaxis.label, textcoords='offset points', size=self.text_size, ha='left',
|
||||||
|
va='center', ma='left')
|
||||||
|
|
||||||
|
ax[0].get_yaxis().set_ticks([])
|
||||||
|
ax[1].get_yaxis().set_ticks([])
|
||||||
|
|
||||||
|
return line_loc, rect.get_width()
|
||||||
|
|
||||||
|
def plot_weightings(self, weightings, ax, name='Weightings', mode='log', color='YlOrRd'):
|
||||||
|
assert weightings.shape.__len__() == 2, "plot weightings: need 2D matrix as data"
|
||||||
|
if mode == 'log':
|
||||||
|
norm = colors.LogNorm(vmin=1e-3, vmax=1)
|
||||||
|
else:
|
||||||
|
norm = colors.Normalize(vmin=0, vmax=1)
|
||||||
|
img = ax.imshow(np.transpose(weightings), interpolation='nearest', norm=norm, cmap=color,
|
||||||
|
aspect='auto') # gist_stern
|
||||||
|
ax.set_adjustable('box-forced')
|
||||||
|
if self.title:
|
||||||
|
ax.set_ylabel(name, size=self.text_size)
|
||||||
|
if self.legend:
|
||||||
|
box = ax.get_position()
|
||||||
|
ax.set_position([box.x0 - 0.001, box.y0, box.width, box.height])
|
||||||
|
axColor = plt.axes([box.x0 + box.width + 0.005, box.y0, 0.005, box.height])
|
||||||
|
cb = plt.colorbar(img, cax=axColor, orientation="vertical")
|
||||||
|
for l in cb.ax.yaxis.get_ticklabels():
|
||||||
|
l.set_size(self.text_size)
|
||||||
|
|
||||||
|
def plot_modes(self, modes, ax, mode_colors, mode_names, name='Modes'):
|
||||||
|
assert modes.shape.__len__() == 2, "plot modes: need 2D matrix as data"
|
||||||
|
assert modes.shape[1] == mode_colors.__len__() and modes.shape[
|
||||||
|
1] == mode_names.__len__(), "plot modes: not same length"
|
||||||
|
|
||||||
|
ind = np.arange(modes.shape[0])
|
||||||
|
ax_list = [ax.bar(ind, modes[:, 0], color=mode_colors[0]), ]
|
||||||
|
|
||||||
|
if modes.shape[1] > 1:
|
||||||
|
for m in range(1, modes.shape[1]):
|
||||||
|
ax_list.append(ax.bar(ind, modes[:, m], bottom=modes[:, :m].sum(axis=1), color=mode_colors[m]))
|
||||||
|
|
||||||
|
ax.set_yticks([0, 1])
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
|
||||||
|
if self.title:
|
||||||
|
if name == 'Read Mode':
|
||||||
|
ax.annotate(name, xy=(0, 0.8), xytext=(-ax.yaxis.labelpad - 150, 0), xycoords=ax.yaxis.label,
|
||||||
|
textcoords='offset points', size=self.text_size, ha='left', va='center', ma='left')
|
||||||
|
else:
|
||||||
|
ax.annotate(name, xy=(0, 0.8), xytext=(-ax.yaxis.labelpad - 135, 0), xycoords=ax.yaxis.label,
|
||||||
|
textcoords='offset points', size=self.text_size, ha='left', va='center', ma='left')
|
||||||
|
if self.legend:
|
||||||
|
ax.legend(ax_list, mode_names, loc='center left', bbox_to_anchor=(1, 0.5), prop={'size': self.text_size})
|
||||||
|
|
||||||
|
def plot_multi_modes(self, multi_modes, ax, width, mode_colors, mode_names, name='Multi Modes'):
|
||||||
|
modes = multi_modes.shape[1]
|
||||||
|
ind = np.arange(multi_modes.shape[0])
|
||||||
|
|
||||||
|
width = width / modes
|
||||||
|
for j in range(-1, modes - 1):
|
||||||
|
ax_list = [ax.bar(ind + j * width + (width * 0.5), multi_modes[:, j, 0], color=mode_colors[0], width=width,
|
||||||
|
align='center'), ]
|
||||||
|
if multi_modes.shape[2] > 1:
|
||||||
|
for m in range(1, multi_modes.shape[2]):
|
||||||
|
ax_list.append(ax.bar(ind + j * width + (width * 0.5), multi_modes[:, j, m],
|
||||||
|
bottom=multi_modes[:, j, :m].sum(axis=1), color=mode_colors[m], width=width,
|
||||||
|
align='center'))
|
||||||
|
|
||||||
|
if self.title:
|
||||||
|
ax.annotate(name, xy=(0, 0.8), xytext=(-ax.yaxis.labelpad - 135, 0), xycoords=ax.yaxis.label,
|
||||||
|
textcoords='offset points', size=self.text_size, ha='left', va='center', ma='left')
|
||||||
|
if self.legend:
|
||||||
|
ax.legend(ax_list, mode_names, loc='center left', bbox_to_anchor=(1, 0.5), prop={'size': self.text_size})
|
||||||
|
ax.set_yticks([0, 1])
|
||||||
|
ax.set_yticklabels(['0', '', '', '', '', '1'])
|
||||||
|
ax.set_ylim(0, 1)
|
||||||
|
|
||||||
|
def plot_matrix(self, matrix, ax, name='Weightings', mode='norm', color='RdYlBu', zero_width=5, zero_add='zeros'):
|
||||||
|
assert matrix.shape.__len__() == 3, "plot weightings: need 3D matrix as data"
|
||||||
|
|
||||||
|
if mode == 'log':
|
||||||
|
norm = colors.LogNorm(vmin=1e-8, vmax=0.1)
|
||||||
|
elif mode == 'norm1':
|
||||||
|
norm = colors.Normalize(vmin=0, vmax=1)
|
||||||
|
else:
|
||||||
|
norm = colors.Normalize(vmin=-1, vmax=1)
|
||||||
|
|
||||||
|
if zero_add == 'zeros':
|
||||||
|
matrix = np.concatenate([matrix, np.zeros([matrix.shape[0], matrix.shape[1], zero_width])], axis=2)
|
||||||
|
matrix = np.transpose(matrix, axes=(0, 2, 1))
|
||||||
|
flat_matrix = np.reshape(matrix, [-1, matrix.shape[2]])
|
||||||
|
flat_matrix = np.concatenate([np.zeros([zero_width, flat_matrix.shape[1]]), flat_matrix], axis=0)
|
||||||
|
else:
|
||||||
|
matrix = np.concatenate([matrix, np.ones([matrix.shape[0], matrix.shape[1], zero_width])], axis=2)
|
||||||
|
matrix = np.transpose(matrix, axes=(0, 2, 1))
|
||||||
|
flat_matrix = np.reshape(matrix, [-1, matrix.shape[2]])
|
||||||
|
flat_matrix = np.concatenate([np.ones([zero_width, flat_matrix.shape[1]]), flat_matrix], axis=0)
|
||||||
|
|
||||||
|
img = ax.imshow(np.transpose(flat_matrix), aspect='auto', interpolation='nearest', norm=norm, cmap=color)
|
||||||
|
|
||||||
|
ax.set_adjustable('box-forced')
|
||||||
|
if self.title:
|
||||||
|
ax.set_ylabel(name, size=self.text_size)
|
||||||
|
if self.legend:
|
||||||
|
box = ax.get_position()
|
||||||
|
ax.set_position([box.x0 - 0.001, box.y0, box.width, box.height])
|
||||||
|
axColor = plt.axes([box.x0 + box.width + 0.005, box.y0, 0.005, box.height])
|
||||||
|
cb = plt.colorbar(img, cax=axColor, orientation="vertical")
|
||||||
|
for l in cb.ax.yaxis.get_ticklabels():
|
||||||
|
l.set_size(self.text_size)
|
||||||
|
tick_locator = ticker.MaxNLocator(nbins=3)
|
||||||
|
cb.locator = tick_locator
|
||||||
|
cb.update_ticks()
|
||||||
|
|
||||||
|
def plot_vector_as_matrix(self, vector, vertical, repeats, ax, name='Weightings', mode='log', color='YlOrRd',
|
||||||
|
zero_width=5):
|
||||||
|
|
||||||
|
assert vector.shape.__len__() == 2, "plot weightings: need 2D matrix as data"
|
||||||
|
|
||||||
|
if mode == 'log':
|
||||||
|
norm = colors.LogNorm(vmin=1e-3, vmax=1)
|
||||||
|
elif mode == 'norm1':
|
||||||
|
norm = colors.Normalize(vmin=0, vmax=1)
|
||||||
|
else:
|
||||||
|
norm = colors.Normalize(vmin=-1, vmax=1)
|
||||||
|
|
||||||
|
if vertical:
|
||||||
|
matrix = np.repeat(vector, repeats, axis=1)
|
||||||
|
matrix = np.reshape(matrix, [vector.shape[0], vector.shape[1], repeats])
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
matrix = np.repeat(vector, repeats, axis=0)
|
||||||
|
matrix = np.reshape(matrix, [vector.shape[0], repeats, vector.shape[1]])
|
||||||
|
|
||||||
|
matrix = np.concatenate([matrix, np.zeros([matrix.shape[0], matrix.shape[1], zero_width])], axis=2)
|
||||||
|
matrix = np.transpose(matrix, axes=(0, 2, 1))
|
||||||
|
flat_matrix = np.reshape(matrix, [-1, matrix.shape[2]])
|
||||||
|
flat_matrix = np.concatenate([np.zeros([zero_width, flat_matrix.shape[1]]), flat_matrix], axis=0)
|
||||||
|
|
||||||
|
img = ax.imshow(np.transpose(flat_matrix), aspect='auto', interpolation='nearest', norm=norm, cmap=color)
|
||||||
|
|
||||||
|
ax.set_adjustable('box-forced')
|
||||||
|
box = ax.get_position()
|
||||||
|
ax.set_position([box.x0 - 0.001, box.y0, box.width, box.height])
|
||||||
|
|
||||||
|
if self.legend:
|
||||||
|
axColor = plt.axes([box.x0 + box.width + 0.005, box.y0, 0.005, box.height])
|
||||||
|
cb = plt.colorbar(img, cax=axColor, orientation="vertical")
|
||||||
|
for l in cb.ax.yaxis.get_ticklabels():
|
||||||
|
l.set_size(self.text_size)
|
||||||
|
if self.title:
|
||||||
|
ax.set_ylabel(name, labelpad=30, size=self.text_size)
|
||||||
|
ax.set_yticks([])
|
Loading…
Reference in New Issue
Block a user