From 88258f7656c71083cb063b06e9d554da57b79cfc Mon Sep 17 00:00:00 2001
From: joergfranke
Date: Wed, 11 Jul 2018 14:35:33 +0200
Subject: [PATCH] add docstrings

---
 adnc/analysis/analyzer.py                     |  18 +-
 adnc/analysis/plot_functionality.py           |   3 +
 adnc/analysis/plot_functions.py               |   3 +
 adnc/analysis/prepare_variables.py            |   3 +
 adnc/data/loader.py                           |  39 +++
 adnc/data/tasks/babi.py                       |   5 +-
 adnc/data/tasks/cnn_rc.py                     |   5 +-
 adnc/data/tasks/repeat_copy.py                |   3 +
 adnc/data/utils/batch_generator.py            |  20 ++
 adnc/data/utils/data_memorizer.py             |  30 +-
 adnc/model/controller_units/controller.py     |   3 +
 .../controller_units/custom_lstm_cell.py      |   3 +
 adnc/model/mann.py                            | 329 +++++++++++-------
 adnc/model/memory_units/base_cell.py          |   3 +
 adnc/model/memory_units/content_based_cell.py |   3 +
 adnc/model/memory_units/dnc_cell.py           |   3 +
 adnc/model/memory_units/memory_unit.py        |   3 +
 .../multi_write_content_based_cell.py         |   3 +
 .../memory_units/multi_write_dnc_cell.py      |   3 +
 adnc/model/optimizer.py                       |   3 +
 adnc/model/supporter.py                       |   4 +
 adnc/model/utils/early_stop.py                |   3 +
 adnc/model/utils/word_embedding.py            |   3 +
 scripts/inference_babi_task.py                |   5 +
 scripts/inference_cnn_task.py                 |   4 +
 scripts/plot_function_babi_task.py            |   5 +
 scripts/start_training.py                     |   7 +-
 27 files changed, 367 insertions(+), 149 deletions(-)

diff --git a/adnc/analysis/analyzer.py b/adnc/analysis/analyzer.py
index 276f528..b733a5b 100755
--- a/adnc/analysis/analyzer.py
+++ b/adnc/analysis/analyzer.py
@@ -21,21 +21,19 @@ from adnc.analysis.plot_functionality import PlotFunctionality
 from adnc.analysis.prepare_variables import Bucket
 from adnc.model.utils import softmax
 
-"""
-
-"""
 
 
 class Analyser():
-    def __init__(self, data_set, record_dir, save_variables=False, save_fig=False):
+    """
+    The analyzer helps to analyze the functionality of the DNC during training. It is used to calculate the
+    memory influence and to create the function plots of the memory usage.
+    """
+    def __init__(self, record_dir, save_variables=False, save_fig=False):
         """
-        Args:
-            data_set:
-            record_dir:
-            save_variables:
-            save_fig:
+            record_dir: dir to store the function plots
+            save_variables: bool, save weights, gradients and losses in a numpy list
+            save_fig: bool, save plots
         """
-        self.data_set = data_set
         self.record_dir = record_dir
         self.save_variables = save_variables
         self.save_fig = save_fig
diff --git a/adnc/analysis/plot_functionality.py b/adnc/analysis/plot_functionality.py
index 8bf2962..058a721 100755
--- a/adnc/analysis/plot_functionality.py
+++ b/adnc/analysis/plot_functionality.py
@@ -23,6 +23,9 @@ import matplotlib.pyplot as plt
 
 from adnc.analysis.plot_functions import PlotFunctions
 
+"""
+Functions for plotting the behaviour of the different memory units.
+""" class PlotFunctionality(PlotFunctions): def __init__(self, bucket, legend=False, title=False, text_size=16, data_type='png'): diff --git a/adnc/analysis/plot_functions.py b/adnc/analysis/plot_functions.py index 2d90228..45c436f 100755 --- a/adnc/analysis/plot_functions.py +++ b/adnc/analysis/plot_functions.py @@ -21,6 +21,9 @@ import matplotlib.pyplot as plt from matplotlib import colors from matplotlib import ticker +""" +Principle plot functions for gates, modes, input/output sequences or reading/writings of the DNC +""" class PlotFunctions(): def __init__(self, legend=False, title=False, text_size=16): diff --git a/adnc/analysis/prepare_variables.py b/adnc/analysis/prepare_variables.py index 3de2417..6633004 100755 --- a/adnc/analysis/prepare_variables.py +++ b/adnc/analysis/prepare_variables.py @@ -17,6 +17,9 @@ import numpy as np from adnc.model.utils import softmax from adnc.model.utils import weighted_softmax +""" +The bucket prepaires and provide a full samples sequence of the DNC internal states for plotting the functions. +""" class Bucket: def __init__(self, variables, babi_short=True): diff --git a/adnc/data/loader.py b/adnc/data/loader.py index d440792..ae5d515 100644 --- a/adnc/data/loader.py +++ b/adnc/data/loader.py @@ -24,7 +24,16 @@ from adnc.data.tasks.babi import bAbI class DataLoader(): + """ + The data loader loads and process the datasets and provides iterators for training or inference. + """ def __init__(self, config, word_dict=None, re_word_dict=None): + """ + Args: + config: dict with the config to pre-process the dataset + word_dict: dict with word-feature pairs, optional + re_word_dict: dict with feature-word pairs, optional + """ self.config = config if config['data_set'] == 'copy_task': @@ -47,6 +56,14 @@ class DataLoader(): return self.dataset.y_size def batch_amount(self, set_name): + """ + Calculates the batch amount given a batch size + Args: + set_name: str, name of dataset (train, test, valid) + + Returns: int, number of batches + + """ if 'max_len' in self.config.keys(): return np.floor( self.dataset.sample_amount(set_name, self.config['max_len']) / self.config['batch_size']).astype(int) @@ -63,6 +80,18 @@ class DataLoader(): return self.dataset.decode_output(sample, prediction) def get_data_loader(self, set_name, shuffle=True, max_len=False, batch_size=None, get_shuffle_option=False): + """ + Provides a data iterator of the given dataset. + Args: + set_name: str, name of dataset + shuffle: bool, shuffle set or not + max_len: int, max length in time of sample + batch_size: int, batch size + get_shuffle_option: bool, returns shuffle function + + Returns: iter, iterator over dataset + + """ if batch_size == None: batch_size = self.config['batch_size'] @@ -77,6 +106,16 @@ class DataLoader(): @staticmethod def _generate_in_background(batch_gen, num_cached=10, threads=1): + """ + Starts threads with parallel batch generator for faster iteration + Args: + batch_gen: func, the batch generator + num_cached: int, numb of caches batches + threads: int, numb of parallel threads + + Returns: iter, iterator over dataset + + """ queue = Queue(maxsize=num_cached) sentinel = object() diff --git a/adnc/data/tasks/babi.py b/adnc/data/tasks/babi.py index 384ed65..0eadbaa 100755 --- a/adnc/data/tasks/babi.py +++ b/adnc/data/tasks/babi.py @@ -20,11 +20,14 @@ from urllib.request import Request, urlopen import numpy as np +""" +Downloads and pre-preocess the 20 bAbI task. It also augmets task 16 as described in paper. 
+""" + DEFAULT_DATA_FOLDER = "data_babi" LONGEST_SAMPLE_LENGTH = 1920 bAbI_URL = 'http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz' - class bAbI(): def __init__(self, config, word_dict=None, re_word_dict=None): diff --git a/adnc/data/tasks/cnn_rc.py b/adnc/data/tasks/cnn_rc.py index 88c468f..e499e1e 100644 --- a/adnc/data/tasks/cnn_rc.py +++ b/adnc/data/tasks/cnn_rc.py @@ -24,11 +24,14 @@ import numpy as np from adnc.data.utils.data_memorizer import DataMemorizer +""" +Downloads and pre-preocess the CNN RC task. +""" + DEFAULT_DATA_FOLDER = 'data_cnn' DEFAULT_TMP_FOLDER = 'data_tmp' CNN_DATA_URL = 'http://cs.stanford.edu/~danqi/data/cnn.tar.gz' - class ReadingComprehension(): def __init__(self, config, save=True, debug_max_load=None): diff --git a/adnc/data/tasks/repeat_copy.py b/adnc/data/tasks/repeat_copy.py index 7537988..56f5ed3 100755 --- a/adnc/data/tasks/repeat_copy.py +++ b/adnc/data/tasks/repeat_copy.py @@ -16,6 +16,9 @@ import numpy as np from collections import OrderedDict from scipy.sparse import csr_matrix +""" +Generates "repeat a input sequences"-samples as described in NTM paper. +""" class CopyTask(): def __init__(self, config): diff --git a/adnc/data/utils/batch_generator.py b/adnc/data/utils/batch_generator.py index 90e01b4..6f052b4 100755 --- a/adnc/data/utils/batch_generator.py +++ b/adnc/data/utils/batch_generator.py @@ -18,6 +18,15 @@ import threading class BatchGenerator(): def __init__(self, data_set, set, batch_size, shuffle=True, max_len=False): + """ + Creates batches out of samples from the dataset object + Args: + data_set: dataset object, contains the dataset + set: str, name of dataset (train, valid, test) + batch_size: int, batch size + shuffle: bool, shuffle or not + max_len: int, max length of sample in time + """ self.set = set self.data_set = data_set @@ -36,9 +45,15 @@ class BatchGenerator(): self.sample_count = 0 def shuffle_order(self): + """ + Shuffles the order of sample in dataset + """ self.order = self.data_set.rng.permutation(self.order) def increase_sample_count(self): + """ + Increase the global sample count, it is required because it can run in parallel + """ with self.lock: self.sample_count += 1 if self.sample_count >= self.sample_amount: @@ -50,7 +65,12 @@ class BatchGenerator(): return next(self) def __next__(self): + """ + Loads the next data samples and creates a batch. It uses the get_sample and patch_batch methods which are + provided by the dataset + Returns: batch of samples + """ batch_list = [] for b in range(self.batch_size): diff --git a/adnc/data/utils/data_memorizer.py b/adnc/data/utils/data_memorizer.py index 4521852..fb58a7d 100755 --- a/adnc/data/utils/data_memorizer.py +++ b/adnc/data/utils/data_memorizer.py @@ -20,8 +20,15 @@ from collections import OrderedDict class DataMemorizer(): + """ + Given a config, it saves the pre-processed data in a pickle dump. 
+ """ def __init__(self, config, tmp_dir): - + """ + Args: + config: dict, config of dataset + tmp_dir: str, dir to save dataset dump + """ self.hash_name = self.make_config_hash(config) if isinstance(tmp_dir, pathlib.Path): self.tmp_dir = tmp_dir @@ -35,25 +42,46 @@ class DataMemorizer(): return self.check_existent() def check_existent(self): + """ + Returns: bool, if the dataset dump exists + """ file_name = self.tmp_dir / self.hash_name return file_name.exists() def load_data(self): + """ + Returns: dataset, pickle load of dataset + """ with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile: data = pickle.load(outfile) return data def dump_data(self, data_to_save): + """ + Args: + data_to_save: object, what to save + """ with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile: pickle.dump(data_to_save, outfile) def purge_data(self): + """ + removes data dump + """ file_name = str(self.tmp_dir / self.hash_name) if os.path.isfile(file_name): os.remove(file_name) @staticmethod def make_config_hash(dict): + """ + computes a hash string to name the dataset dump uniquely + Args: + dict: dict, config which describes the dataset + + Returns: str, hash tag of dataset + + """ pre = sorted(((k, v) for k, v in dict.items() if k not in ['batch_size', 'num_chached', 'threads'])) sort_dict = OrderedDict() for element in pre: diff --git a/adnc/model/controller_units/controller.py b/adnc/model/controller_units/controller.py index e1454e6..04135d9 100644 --- a/adnc/model/controller_units/controller.py +++ b/adnc/model/controller_units/controller.py @@ -18,6 +18,9 @@ from tensorflow.contrib.rnn import LayerNormBasicLSTMCell, LSTMCell, LSTMBlockCe from adnc.model.controller_units.custom_lstm_cell import CustomLSTMCell from adnc.model.utils import get_activation +""" +A wrapper for the controller units. +""" def get_rnn_cell_list(config, name, reuse=False, seed=123, dtype=tf.float32): cell_list = [] diff --git a/adnc/model/controller_units/custom_lstm_cell.py b/adnc/model/controller_units/custom_lstm_cell.py index 7e4b52d..b38df49 100755 --- a/adnc/model/controller_units/custom_lstm_cell.py +++ b/adnc/model/controller_units/custom_lstm_cell.py @@ -17,6 +17,9 @@ import tensorflow as tf from adnc.model.utils import layer_norm, get_activation +""" +A implementation of the LSTM unit, it performs a bit faster as the TF implementation and implements layer norm. +""" class CustomLSTMCell(): def __init__(self, num_units, layer_norm=False, activation='tanh', seed=100, reuse=False, trainable=True, diff --git a/adnc/model/mann.py b/adnc/model/mann.py index e20e47f..f96af9a 100755 --- a/adnc/model/mann.py +++ b/adnc/model/mann.py @@ -23,10 +23,19 @@ from adnc.model.memory_units.memory_unit import get_memory_unit from adnc.model.utils import HolisticMultiRNNCell from adnc.model.utils import WordEmbedding +""" +The memory augmented neural network (MANN) model object contains the controller and the memory unit as well as the +loss function and connects everything. 
+""" class MANN(): - - def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32, new_output_structure=False): + def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32): + """ + Args: + config: dict, configuration of the whole model + analyse: bool, is analyzer is used or not + reuse: bool, reuse model or not + """ self.seed = config["seed"] self.rng = np.random.RandomState(seed=self.seed) @@ -40,12 +49,11 @@ class MANN(): self.input_embedding = config["input_embedding"] self.architecture = config['architecture'] self.controller_config = config["controller_config"] - self.memory_unit_config = config["memory_unit_config"] + self.memory_unit_config = config["memory_unit_config"] self.output_function = config["output_function"] self.output_mask = config["output_mask"] self.loss_function = config['loss_function'] - self.reuse = reuse self.name = name @@ -66,39 +74,51 @@ class MANN(): else: self.data = tf.placeholder(tf.float32, [None, self.batch_size, self.input_size], name='x') - if self.architecture in ['uni', 'unidirectional']: - unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse) + unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, + reuse=self.reuse) elif self.architecture in ['bi', 'bidirectional']: - unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse) + unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, + reuse=self.reuse) else: raise UserWarning("Unknown architecture, use unidirectional or bidirectional") - if self.analyse: with tf.device('/cpu:0'): if self.architecture in ['uni', 'unidirectional']: - analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True) + analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config, + self.memory_unit_config, analyse=True, + reuse=True) analyse_outputs, analyse_signals = analyse_outputs - self.analyse =(analyse_outputs, analyse_signals, analyse_states) + self.analyse = (analyse_outputs, analyse_signals, analyse_states) elif self.architecture in ['bi', 'bidirectional']: - analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True) + analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config, + self.memory_unit_config, analyse=True, + reuse=True) analyse_outputs, analyse_signals = analyse_outputs - self.analyse =(analyse_outputs, analyse_signals, analyse_states) + self.analyse = (analyse_outputs, analyse_signals, analyse_states) self.unweighted_outputs = unweighted_outputs self.prediction, self.outputs = self._output_layer(unweighted_outputs) self.loss = self.get_loss(self.prediction) - def _output_layer(self, outputs): + """ + Calculates the weighted and activated output of the MANN model + Args: + outputs: TF tensor, concatenation of memory units output and controller output + Returns: TF tensor, predictions; TF tensor, unactivated predictions + + """ with tf.variable_scope("output_layer"): output_size = outputs.get_shape()[-1].value - weights_concat = tf.get_variable("weights_concat",(output_size, self.output_size), - initializer=tf.contrib.layers.xavier_initializer(seed=self.seed), collections=['mann', 
tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype) - bias_merge = tf.get_variable("bias_merge",(self.output_size,), initializer=tf.constant_initializer(0.), collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype) + weights_concat = tf.get_variable("weights_concat", (output_size, self.output_size), + initializer=tf.contrib.layers.xavier_initializer(seed=self.seed), + collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype) + bias_merge = tf.get_variable("bias_merge", (self.output_size,), initializer=tf.constant_initializer(0.), + collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype) output_flat = tf.reshape(outputs, [-1, output_size]) output_flat = tf.matmul(output_flat, weights_concat) + bias_merge @@ -117,13 +137,170 @@ class MANN(): return predictions, weighted_outputs + def get_loss(self, prediction): + """ + Args: + prediction: TF tensor, activated prediction of the model + + Returns: TF scalar, loss of the current forward set + + """ + if self.loss_function == 'cross_entropy': + if self.output_mask: + cost = tf.reduce_sum( + -1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log( + tf.clip_by_value(1 - prediction, 1e-12, 10.0)), axis=2) + cost *= self.mask + loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask) + else: + loss = tf.reduce_mean( + -1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log( + tf.clip_by_value(1 - prediction, 1e-12, 10.0))) + + elif self.loss_function == 'mse': + clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0) + mse = tf.square(self.target - clipped_prediction) + mse = tf.reduce_mean(mse, axis=2) + + if self.output_mask: + cost = mse * self.mask + loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask) + else: + loss = tf.reduce_mean(mse) + else: + raise UserWarning("Unknown loss function, use cross_entropy or mse") + return loss + + def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False): + """ + Connects unidirectional controller and memory unit and performs scan over sequence + Args: + inputs: TF tensor, input sequence + controller_config: dict, configuration of the controller + memory_unit_config: dict, configuration of the memory unit + analyse: bool, do analysis + reuse: bool, reuse + + Returns: TF tensor, output sequence; TF tensor, hidden states + + """ + + with tf.variable_scope("controller"): + controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed, + dtype=self.dtype) + + if controller_config['connect'] == 'sparse': + memory_input_size = controller_list[-1].output_size + mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, + reuse=reuse) + cell = MultiRNNCell(controller_list + [mu_cell]) + else: + controller_cell = HolisticMultiRNNCell(controller_list) + memory_input_size = controller_cell.output_size + mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, + reuse=reuse) + cell = MultiRNNCell([controller_cell, mu_cell]) + + batch_size = inputs.get_shape()[1].value + cell_init_states = cell.zero_state(batch_size, dtype=self.dtype) + output_init = tf.zeros([batch_size, cell.output_size], dtype=self.dtype) + + if analyse: + output_init = (output_init, mu_cell.analyse_state(batch_size, dtype=self.dtype)) + + init_states = (output_init, cell_init_states) + + def step(pre_states, inputs): + pre_rnn_output, pre_rnn_states = pre_states + + 
if analyse: + pre_rnn_output = pre_rnn_output[0] + + controller_inputs = tf.concat([inputs, pre_rnn_output], axis=-1) + rnn_output, rnn_states = cell(controller_inputs, pre_rnn_states) + return (rnn_output, rnn_states) + + outputs, states = tf.scan(step, inputs, initializer=init_states, parallel_iterations=32) + + return outputs, states + + def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False): + """ + Connects bidirectional controller and memory unit and performs scan over sequence + Args: + inputs: TF tensor, input sequence + controller_config: dict, configuration of the controller + memory_unit_config: dict, configuration of the memory unit + analyse: bool, do analysis + reuse: bool, reuse + + Returns: TF tensor, output sequence; TF tensor, hidden states + + """ + + with tf.variable_scope("controller"): + list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype) + list_bw = get_rnn_cell_list(controller_config, name='con_bw', reuse=reuse, seed=self.seed, dtype=self.dtype) + if controller_config['connect'] == 'sparse': + cell_fw = MultiRNNCell(list_fw) + cell_bw = MultiRNNCell(list_bw) + else: + cell_fw = HolisticMultiRNNCell(list_fw) + cell_bw = HolisticMultiRNNCell(list_bw) + + memory_input_size = cell_fw.output_size + cell_bw.output_size + cell_mu = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse) + + with vs.variable_scope("bw") as bw_scope: + inputs_reverse = tf.reverse(inputs, axis=[0]) + output_bw, output_state_bw = tf.nn.dynamic_rnn(cell=cell_bw, inputs=inputs_reverse, dtype=self.dtype, + parallel_iterations=32, time_major=True, scope=bw_scope) + output_bw = tf.reverse(output_bw, axis=[0]) + + batch_size = inputs.get_shape()[1].value + cell_fw_init_states = cell_fw.zero_state(batch_size, dtype=self.dtype) + cell_mu_init_states = cell_mu.zero_state(batch_size, dtype=self.dtype) + output_init = tf.zeros([batch_size, cell_mu.output_size], dtype=self.dtype) + + if analyse: + output_init = (output_init, cell_mu.analyse_state(batch_size, dtype=self.dtype)) + + init_states = (output_init, cell_fw_init_states, cell_mu_init_states) + coupled_inputs = (inputs, output_bw) + + with vs.variable_scope("fw") as fw_scope: + + def step(pre_states, coupled_inputs): + inputs, output_bw = coupled_inputs + pre_outputs, pre_states_fw, pre_states_mu = pre_states + + if analyse: + pre_outputs = pre_outputs[0] + + controller_inputs = tf.concat([inputs, pre_outputs], axis=-1) + output_fw, states_fw = cell_fw(controller_inputs, pre_states_fw) + + mu_inputs = tf.concat([output_fw, output_bw], axis=-1) + output_mu, states_mu = cell_mu(mu_inputs, pre_states_mu) + + return (output_mu, states_fw, states_mu) + + outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states, + parallel_iterations=32) + + states = states_fw, states_mu + return outputs, states + @property def feed(self): + """ + Returns: TF placeholder for data, target and mask inout to the model + """ return self.data, self.target, self.mask @property def controller_trainable_variables(self): - return tf.get_collection('recurrent_unit') + return tf.get_collection('recurrent_unit') @property def memory_unit_trainable_variables(self): @@ -164,121 +341,3 @@ class MANN(): def parameter_amount(self): var_list = tf.trainable_variables() return self.count_parameter_amount(var_list) - - - def get_loss(self, prediction): - if self.loss_function == 'cross_entropy': - if self.output_mask: - 
cost = tf.reduce_sum(-1 * self.target * tf.log(tf.clip_by_value(prediction,1e-12,10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction,1e-12,10.0)), axis=2) - cost *= self.mask - loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask) - else: - loss = tf.reduce_mean(-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction, 1e-12, 10.0))) - - elif self.loss_function == 'mse': - clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0) - mse = tf.square(self.target - clipped_prediction) - mse = tf.reduce_mean(mse, axis=2) - - if self.output_mask: - cost = mse * self.mask - loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask) - else: - loss = tf.reduce_mean(mse) - else: - raise UserWarning("Unknown loss function, use cross_entropy or mse") - return loss - - - def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False): - - with tf.variable_scope("controller"): - controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed, dtype=self.dtype) - - if controller_config['connect'] == 'sparse': - memory_input_size = controller_list[-1].output_size - mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse) - cell = MultiRNNCell(controller_list +[mu_cell]) - else: - controller_cell = HolisticMultiRNNCell(controller_list) - memory_input_size = controller_cell.output_size - mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse) - cell = MultiRNNCell([controller_cell, mu_cell]) - - batch_size = inputs.get_shape()[1].value - cell_init_states = cell.zero_state(batch_size, dtype=self.dtype) - output_init = tf.zeros([batch_size, cell.output_size], dtype=self.dtype) - - if analyse: - output_init = (output_init, mu_cell.analyse_state(batch_size, dtype=self.dtype)) - - init_states = (output_init, cell_init_states) - - def step(pre_states, inputs): - pre_rnn_output, pre_rnn_states = pre_states - - if analyse: - pre_rnn_output = pre_rnn_output[0] - - controller_inputs = tf.concat([inputs, pre_rnn_output], axis=-1) - rnn_output, rnn_states = cell(controller_inputs, pre_rnn_states) - return (rnn_output, rnn_states) - - outputs, states = tf.scan(step, inputs, initializer=init_states, parallel_iterations=32) - - return outputs, states - - - def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False): - - with tf.variable_scope("controller"): - list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype) - list_bw = get_rnn_cell_list( controller_config, name='con_bw', reuse=reuse, seed=self.seed, dtype=self.dtype) - if controller_config['connect'] == 'sparse': - cell_fw = MultiRNNCell(list_fw) - cell_bw = MultiRNNCell(list_bw) - else: - cell_fw = HolisticMultiRNNCell(list_fw) - cell_bw = HolisticMultiRNNCell(list_bw) - - memory_input_size = cell_fw.output_size + cell_bw.output_size - cell_mu = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse) - - with vs.variable_scope("bw") as bw_scope: - inputs_reverse = tf.reverse(inputs, axis=[0]) - output_bw, output_state_bw = tf.nn.dynamic_rnn(cell=cell_bw, inputs=inputs_reverse, dtype=self.dtype, - parallel_iterations=32, time_major=True, scope=bw_scope) - output_bw = tf.reverse(output_bw, axis=[0]) - - batch_size = inputs.get_shape()[1].value - cell_fw_init_states = 
cell_fw.zero_state(batch_size, dtype=self.dtype) - cell_mu_init_states = cell_mu.zero_state(batch_size, dtype=self.dtype) - output_init = tf.zeros([batch_size, cell_mu.output_size], dtype=self.dtype) - - if analyse: - output_init = (output_init, cell_mu.analyse_state(batch_size, dtype=self.dtype)) - - init_states = (output_init, cell_fw_init_states, cell_mu_init_states) - coupled_inputs = (inputs, output_bw) - - with vs.variable_scope("fw") as fw_scope: - - def step(pre_states, coupled_inputs): - inputs, output_bw = coupled_inputs - pre_outputs, pre_states_fw, pre_states_mu = pre_states - - if analyse: - pre_outputs = pre_outputs[0] - - controller_inputs = tf.concat([inputs, pre_outputs], axis=-1) - output_fw, states_fw = cell_fw(controller_inputs, pre_states_fw) - - mu_inputs = tf.concat([output_fw, output_bw], axis=-1) - output_mu, states_mu = cell_mu(mu_inputs, pre_states_mu) - - return (output_mu, states_fw, states_mu) - - outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states, parallel_iterations=32) - - states = states_fw, states_mu - return outputs, states \ No newline at end of file diff --git a/adnc/model/memory_units/base_cell.py b/adnc/model/memory_units/base_cell.py index 6812285..9c42b56 100644 --- a/adnc/model/memory_units/base_cell.py +++ b/adnc/model/memory_units/base_cell.py @@ -16,6 +16,9 @@ from abc import abstractmethod, ABCMeta import numpy as np import tensorflow as tf +""" +The basis DNC memory unit class, all other inherit from this. +""" class BaseMemoryUnitCell(): def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False, diff --git a/adnc/model/memory_units/content_based_cell.py b/adnc/model/memory_units/content_based_cell.py index 16146b5..c08b398 100755 --- a/adnc/model/memory_units/content_based_cell.py +++ b/adnc/model/memory_units/content_based_cell.py @@ -17,6 +17,9 @@ import tensorflow as tf from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization +""" +The content-based memory unit. +""" class ContentBasedMemoryUnitCell(DNCMemoryUnitCell): diff --git a/adnc/model/memory_units/dnc_cell.py b/adnc/model/memory_units/dnc_cell.py index ec4d7c3..08a4ad1 100755 --- a/adnc/model/memory_units/dnc_cell.py +++ b/adnc/model/memory_units/dnc_cell.py @@ -18,6 +18,9 @@ import tensorflow as tf from adnc.model.memory_units.base_cell import BaseMemoryUnitCell from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization +""" +The vanilla DNC memory unit. 
+""" class DNCMemoryUnitCell(BaseMemoryUnitCell): def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False, diff --git a/adnc/model/memory_units/memory_unit.py b/adnc/model/memory_units/memory_unit.py index e0ef025..aee72bd 100644 --- a/adnc/model/memory_units/memory_unit.py +++ b/adnc/model/memory_units/memory_unit.py @@ -19,6 +19,9 @@ from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell from adnc.model.memory_units.multi_write_content_based_cell import MWContentMemoryUnitCell from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell +""" +A warpper for the memory units +""" def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, seed=123, dtype=tf.float32): memory_length = config['memory_length'] diff --git a/adnc/model/memory_units/multi_write_content_based_cell.py b/adnc/model/memory_units/multi_write_content_based_cell.py index b6afd96..f91cb1e 100644 --- a/adnc/model/memory_units/multi_write_content_based_cell.py +++ b/adnc/model/memory_units/multi_write_content_based_cell.py @@ -17,6 +17,9 @@ import tensorflow as tf from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization +""" +The content-based memory unit with multi write heads. +""" class MWContentMemoryUnitCell(MWDNCMemoryUnitCell): diff --git a/adnc/model/memory_units/multi_write_dnc_cell.py b/adnc/model/memory_units/multi_write_dnc_cell.py index b42ceb2..cd5e721 100644 --- a/adnc/model/memory_units/multi_write_dnc_cell.py +++ b/adnc/model/memory_units/multi_write_dnc_cell.py @@ -20,6 +20,9 @@ from adnc.model.utils import layer_norm from adnc.model.utils import oneplus from adnc.model.utils import unit_simplex_initialization +""" +The vanilla DNC memory unit with multi write heads. +""" class MWDNCMemoryUnitCell(BaseMemoryUnitCell): def __init__(self, input_size, memory_length, memory_width, read_heads, write_heads, bypass_dropout=False, diff --git a/adnc/model/optimizer.py b/adnc/model/optimizer.py index 962235c..bd48b2b 100755 --- a/adnc/model/optimizer.py +++ b/adnc/model/optimizer.py @@ -14,6 +14,9 @@ # ============================================================================== import tensorflow as tf +""" +The Optimizer clas is a wrapper for the TF optimizer and performs gradient clipping and weight decay +""" class Optimizer: def __init__(self, config, loss, variables, use_locking=False): diff --git a/adnc/model/supporter.py b/adnc/model/supporter.py index ab4c0e3..3ee41c4 100755 --- a/adnc/model/supporter.py +++ b/adnc/model/supporter.py @@ -22,6 +22,10 @@ from shutil import copyfile import numpy as np import yaml +""" +The supporter class creates for each training run a folder, logs the prints in a file and saves the weights, +gradients and losses. +""" class color_code: def __init__(self): diff --git a/adnc/model/utils/early_stop.py b/adnc/model/utils/early_stop.py index ba853b5..3a7ae07 100644 --- a/adnc/model/utils/early_stop.py +++ b/adnc/model/utils/early_stop.py @@ -15,6 +15,9 @@ import numpy as np from collections import deque +""" +EarlyStop has a true call return if the loss was the last "list_len" higher as the loss before. 
+""" class EarlyStop(): def __init__(self, list_len=5): diff --git a/adnc/model/utils/word_embedding.py b/adnc/model/utils/word_embedding.py index c769cea..848a1cf 100644 --- a/adnc/model/utils/word_embedding.py +++ b/adnc/model/utils/word_embedding.py @@ -21,6 +21,9 @@ import hashlib from collections import OrderedDict import tensorflow as tf +""" +Downloads and process glove word embeddings, applies them to a given vocabulary of a dataset. +""" class WordEmbedding(): def __init__(self, embedding_size, vocabulary_size=None, word_idx_dict=None, initialization='uniform', tmp_dir='.', diff --git a/scripts/inference_babi_task.py b/scripts/inference_babi_task.py index a124dc9..ef345d5 100755 --- a/scripts/inference_babi_task.py +++ b/scripts/inference_babi_task.py @@ -25,6 +25,11 @@ import tensorflow as tf from adnc.data.loader import DataLoader from adnc.model.mann import MANN +""" +This script performs a inference with the given models of this repository on the bAbI task 1 or on 1-20. Please add the +model name when calling the script. (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all) +""" + parser = argparse.ArgumentParser(description='Load model') parser.add_argument('model', type=str, default=False, help='model name') model_name = parser.parse_args().model diff --git a/scripts/inference_cnn_task.py b/scripts/inference_cnn_task.py index 53d3e01..766cf46 100644 --- a/scripts/inference_cnn_task.py +++ b/scripts/inference_cnn_task.py @@ -22,6 +22,10 @@ from tqdm import tqdm from adnc.data.loader import DataLoader from adnc.model.mann import MANN +""" +This script performs a inference with the given models of this repository on the CNN RC task. +""" + expt_dir = "experiments/pre_trained/cnn_rc_task/adnc" # load config file diff --git a/scripts/plot_function_babi_task.py b/scripts/plot_function_babi_task.py index 4e16968..add78bb 100644 --- a/scripts/plot_function_babi_task.py +++ b/scripts/plot_function_babi_task.py @@ -25,6 +25,11 @@ from adnc.analysis import Bucket, PlotFunctionality tf.reset_default_graph() +""" +This script plot the memory unit functionality with the given models of this repository on the bAbI task 1 or on 1-20. +Please add the model name when calling the script. (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all) +""" + parser = argparse.ArgumentParser(description='Load model') parser.add_argument('model', type=str, default=False, help='model name') model_name = parser.parse_args().model diff --git a/scripts/start_training.py b/scripts/start_training.py index 6e343c7..b5d6f4d 100755 --- a/scripts/start_training.py +++ b/scripts/start_training.py @@ -27,6 +27,11 @@ from adnc.data import DataLoader from adnc.model import MANN, Optimizer, Supporter from adnc.model.utils import EarlyStop +""" +This script performs starts a training run on the bAbI task. The training can be fully configured in the config.yml +file. To restore a session use the --sess and --check flag. +""" + tf.reset_default_graph() parser = argparse.ArgumentParser(description='Process some integers.') @@ -63,7 +68,7 @@ valid_loader = dl.get_data_loader('valid') # gets a valid data iterator train_loader = dl.get_data_loader('train') # gets a train data iterator if analyse: - ana = Analyser(dataset_name, sp.session_dir, save_fig=plot_process, + ana = Analyser(sp.session_dir, save_fig=plot_process, save_variables=True) # initilizes a analyzer class sp.config(model_type)['input_size'] = dl.x_size # after the data loader is initilized, the input size