add docstrings

This commit is contained in:
joergfranke 2018-07-11 14:35:33 +02:00
parent caf9e137fa
commit 88258f7656
27 changed files with 367 additions and 149 deletions

View File

@@ -21,21 +21,19 @@ from adnc.analysis.plot_functionality import PlotFunctionality
from adnc.analysis.prepare_variables import Bucket
from adnc.model.utils import softmax
"""
"""
class Analyser():
"""
The analyzer helps to analyze the functionality of the DNC during training. It is used to calculate the
memory influence and to plot function plots of the memory usage.
"""
def __init__(self, record_dir, save_variables=False, save_fig=False):
"""
Args:
record_dir: dir to store the function plots
save_variables: bool, to save weights, gradients and losses in a numpy list
save_fig: bool, save plots
"""
self.record_dir = record_dir
self.save_variables = save_variables
self.save_fig = save_fig

View File

@@ -23,6 +23,9 @@ import matplotlib.pyplot as plt
from adnc.analysis.plot_functions import PlotFunctions
"""
Functions for plotting the behaviour of different memory units.
"""
class PlotFunctionality(PlotFunctions):
def __init__(self, bucket, legend=False, title=False, text_size=16, data_type='png'):

View File

@@ -21,6 +21,9 @@ import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import ticker
"""
Principal plot functions for gates, modes, input/output sequences or readings/writings of the DNC
"""
class PlotFunctions():
def __init__(self, legend=False, title=False, text_size=16):

View File

@@ -17,6 +17,9 @@ import numpy as np
from adnc.model.utils import softmax
from adnc.model.utils import weighted_softmax
"""
The bucket prepares and provides a full sample sequence of the DNC internal states for plotting the functions.
"""
class Bucket:
def __init__(self, variables, babi_short=True):

View File

@@ -24,7 +24,16 @@ from adnc.data.tasks.babi import bAbI
class DataLoader():
"""
The data loader loads and processes the datasets and provides iterators for training or inference.
"""
def __init__(self, config, word_dict=None, re_word_dict=None):
"""
Args:
config: dict with the config to pre-process the dataset
word_dict: dict with word-feature pairs, optional
re_word_dict: dict with feature-word pairs, optional
"""
self.config = config
if config['data_set'] == 'copy_task':
@@ -47,6 +56,14 @@ class DataLoader():
return self.dataset.y_size
def batch_amount(self, set_name):
"""
Calculates the batch amount given a batch size
Args:
set_name: str, name of dataset (train, test, valid)
Returns: int, number of batches
"""
if 'max_len' in self.config.keys():
return np.floor(
self.dataset.sample_amount(set_name, self.config['max_len']) / self.config['batch_size']).astype(int)
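For example (hypothetical numbers, not from the repository): with 10,007 usable samples and a batch size of 32, np.floor(10007 / 32).astype(int) yields 312 batches; the last incomplete batch is dropped.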
@@ -63,6 +80,18 @@ class DataLoader():
return self.dataset.decode_output(sample, prediction)
def get_data_loader(self, set_name, shuffle=True, max_len=False, batch_size=None, get_shuffle_option=False):
"""
Provides a data iterator of the given dataset.
Args:
set_name: str, name of dataset
shuffle: bool, shuffle set or not
max_len: int, max length in time of sample
batch_size: int, batch size
get_shuffle_option: bool, returns shuffle function
Returns: iter, iterator over dataset
"""
if batch_size == None:
batch_size = self.config['batch_size']
@@ -77,6 +106,16 @@ class DataLoader():
@staticmethod
def _generate_in_background(batch_gen, num_cached=10, threads=1):
"""
Starts threads with parallel batch generator for faster iteration
Args:
batch_gen: func, the batch generator
num_cached: int, number of cached batches
threads: int, number of parallel threads
Returns: iter, iterator over dataset
"""
queue = Queue(maxsize=num_cached)
sentinel = object()
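The pattern behind _generate_in_background is a plain producer/consumer queue. A minimal, self-contained sketch of that idea (names and details are illustrative, not the repository's exact implementation):

import threading
from queue import Queue

def generate_in_background(batch_gen, num_cached=10, threads=1):
    # a bounded queue decouples the batch producers from the training loop
    queue = Queue(maxsize=num_cached)
    sentinel = object()  # marks the end of the underlying generator

    def producer():
        for batch in batch_gen:
            queue.put(batch)
        queue.put(sentinel)

    for _ in range(threads):
        thread = threading.Thread(target=producer)
        thread.daemon = True
        thread.start()

    # consumer side: yield batches until the sentinel arrives
    item = queue.get()
    while item is not sentinel:
        yield item
        item = queue.get()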

View File

@@ -20,11 +20,14 @@ from urllib.request import Request, urlopen
import numpy as np
"""
Downloads and pre-processes the 20 bAbI tasks. It also augments task 16 as described in the paper.
"""
DEFAULT_DATA_FOLDER = "data_babi"
LONGEST_SAMPLE_LENGTH = 1920
bAbI_URL = 'http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz'
class bAbI():
def __init__(self, config, word_dict=None, re_word_dict=None):

View File

@@ -24,11 +24,14 @@ import numpy as np
from adnc.data.utils.data_memorizer import DataMemorizer
"""
Downloads and pre-processes the CNN RC task.
"""
DEFAULT_DATA_FOLDER = 'data_cnn'
DEFAULT_TMP_FOLDER = 'data_tmp'
CNN_DATA_URL = 'http://cs.stanford.edu/~danqi/data/cnn.tar.gz'
class ReadingComprehension():
def __init__(self, config, save=True, debug_max_load=None):

View File

@@ -16,6 +16,9 @@ import numpy as np
from collections import OrderedDict
from scipy.sparse import csr_matrix
"""
Generates "repeat an input sequence" samples as described in the NTM paper.
"""
class CopyTask():
def __init__(self, config):
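For orientation, the kind of sample this task produces can be sketched with numpy (a toy version of the NTM copy task; the exact sample layout of this class may differ):

import numpy as np

def toy_copy_sample(rng, seq_len=5, width=8):
    # random binary pattern the model has to memorize
    pattern = rng.randint(0, 2, size=(seq_len, width)).astype(np.float32)
    delimiter = np.ones((1, width), dtype=np.float32)
    blanks = np.zeros((seq_len, width), dtype=np.float32)
    # input: pattern, delimiter, then blanks while the model recalls
    x = np.concatenate([pattern, delimiter, blanks], axis=0)
    # target: reproduce the pattern during the blank phase
    y = np.concatenate([np.zeros((seq_len + 1, width), dtype=np.float32), pattern], axis=0)
    return x, y

x, y = toy_copy_sample(np.random.RandomState(0))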

View File

@@ -18,6 +18,15 @@ import threading
class BatchGenerator():
def __init__(self, data_set, set, batch_size, shuffle=True, max_len=False):
"""
Creates batches out of samples from the dataset object
Args:
data_set: dataset object, contains the dataset
set: str, name of dataset (train, valid, test)
batch_size: int, batch size
shuffle: bool, shuffle or not
max_len: int, max length of sample in time
"""
self.set = set
self.data_set = data_set
@@ -36,9 +45,15 @@ class BatchGenerator():
self.sample_count = 0
def shuffle_order(self):
"""
Shuffles the order of samples in the dataset
"""
self.order = self.data_set.rng.permutation(self.order)
def increase_sample_count(self):
"""
Increases the global sample count; the lock is required because batch generation can run in parallel
"""
with self.lock:
self.sample_count += 1
if self.sample_count >= self.sample_amount:
@@ -50,7 +65,12 @@ class BatchGenerator():
return next(self)
def __next__(self):
"""
Loads the next data samples and creates a batch. It uses the get_sample and patch_batch methods which are
provided by the dataset
Returns: batch of samples
"""
batch_list = []
for b in range(self.batch_size):

View File

@@ -20,8 +20,15 @@ from collections import OrderedDict
class DataMemorizer():
"""
Given a config, it saves the pre-processed data in a pickle dump.
"""
def __init__(self, config, tmp_dir):
"""
Args:
config: dict, config of dataset
tmp_dir: str, dir to save dataset dump
"""
self.hash_name = self.make_config_hash(config)
if isinstance(tmp_dir, pathlib.Path):
self.tmp_dir = tmp_dir
@@ -35,25 +42,46 @@ class DataMemorizer():
return self.check_existent()
def check_existent(self):
"""
Returns: bool, if the dataset dump exists
"""
file_name = self.tmp_dir / self.hash_name
return file_name.exists()
def load_data(self):
"""
Returns: dataset, pickle load of dataset
"""
with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile:
data = pickle.load(outfile)
return data
def dump_data(self, data_to_save):
"""
Args:
data_to_save: object, what to save
"""
with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile:
pickle.dump(data_to_save, outfile)
def purge_data(self):
"""
removes data dump
"""
file_name = str(self.tmp_dir / self.hash_name)
if os.path.isfile(file_name):
os.remove(file_name)
@staticmethod
def make_config_hash(dict):
"""
computes a hash string to name the dataset dump uniquely
Args:
dict: dict, config which describes the dataset
Returns: str, hash tag of dataset
"""
pre = sorted(((k, v) for k, v in dict.items() if k not in ['batch_size', 'num_chached', 'threads']))
sort_dict = OrderedDict()
for element in pre:
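A hash of this kind is typically finished by serializing the sorted pairs and digesting them. A minimal sketch of the idea (md5 is chosen for illustration; the repository may use a different digest or serialization):

import hashlib
from collections import OrderedDict

def config_hash_sketch(config):
    # drop keys that do not change the dataset content itself
    pre = sorted((k, v) for k, v in config.items()
                 if k not in ['batch_size', 'num_chached', 'threads'])
    sort_dict = OrderedDict(pre)
    # deterministic serialization, so equal configs map to the same dump name
    return hashlib.md5(str(sort_dict).encode('utf-8')).hexdigest()

name = config_hash_sketch({'data_set': 'copy_task', 'batch_size': 32, 'seed': 123})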

View File

@@ -18,6 +18,9 @@ from tensorflow.contrib.rnn import LayerNormBasicLSTMCell, LSTMCell, LSTMBlockCe
from adnc.model.controller_units.custom_lstm_cell import CustomLSTMCell
from adnc.model.utils import get_activation
"""
A wrapper for the controller units.
"""
def get_rnn_cell_list(config, name, reuse=False, seed=123, dtype=tf.float32):
cell_list = []

View File

@@ -17,6 +17,9 @@ import tensorflow as tf
from adnc.model.utils import layer_norm, get_activation
"""
An implementation of the LSTM unit; it performs a bit faster than the TF implementation and supports layer norm.
"""
class CustomLSTMCell():
def __init__(self, num_units, layer_norm=False, activation='tanh', seed=100, reuse=False, trainable=True,

View File

@@ -23,10 +23,19 @@ from adnc.model.memory_units.memory_unit import get_memory_unit
from adnc.model.utils import HolisticMultiRNNCell
from adnc.model.utils import WordEmbedding
"""
The memory augmented neural network (MANN) model object contains the controller and the memory unit as well as the
loss function and connects everything.
"""
class MANN():
def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32):
"""
Args:
config: dict, configuration of the whole model
analyse: bool, whether the analyzer is used or not
reuse: bool, reuse model or not
"""
self.seed = config["seed"]
self.rng = np.random.RandomState(seed=self.seed)
@@ -40,12 +49,11 @@ class MANN():
self.input_embedding = config["input_embedding"]
self.architecture = config['architecture']
self.controller_config = config["controller_config"]
self.memory_unit_config = config["memory_unit_config"]
self.output_function = config["output_function"]
self.output_mask = config["output_mask"]
self.loss_function = config['loss_function']
self.reuse = reuse
self.name = name
@@ -66,39 +74,51 @@ class MANN():
else:
self.data = tf.placeholder(tf.float32, [None, self.batch_size, self.input_size], name='x')
if self.architecture in ['uni', 'unidirectional']:
unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config,
reuse=self.reuse)
elif self.architecture in ['bi', 'bidirectional']:
unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config,
reuse=self.reuse)
else:
raise UserWarning("Unknown architecture, use unidirectional or bidirectional")
if self.analyse:
with tf.device('/cpu:0'):
if self.architecture in ['uni', 'unidirectional']:
analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config,
self.memory_unit_config, analyse=True,
reuse=True)
analyse_outputs, analyse_signals = analyse_outputs
self.analyse = (analyse_outputs, analyse_signals, analyse_states)
elif self.architecture in ['bi', 'bidirectional']:
analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config,
self.memory_unit_config, analyse=True,
reuse=True)
analyse_outputs, analyse_signals = analyse_outputs
self.analyse = (analyse_outputs, analyse_signals, analyse_states)
self.unweighted_outputs = unweighted_outputs
self.prediction, self.outputs = self._output_layer(unweighted_outputs)
self.loss = self.get_loss(self.prediction)
def _output_layer(self, outputs):
"""
Calculates the weighted and activated output of the MANN model
Args:
outputs: TF tensor, concatenation of memory units output and controller output
Returns: TF tensor, predictions; TF tensor, unactivated predictions
"""
with tf.variable_scope("output_layer"):
output_size = outputs.get_shape()[-1].value
weights_concat = tf.get_variable("weights_concat", (output_size, self.output_size),
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
bias_merge = tf.get_variable("bias_merge", (self.output_size,), initializer=tf.constant_initializer(0.),
collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
output_flat = tf.reshape(outputs, [-1, output_size])
output_flat = tf.matmul(output_flat, weights_concat) + bias_merge
@@ -117,13 +137,170 @@ class MANN():
return predictions, weighted_outputs
def get_loss(self, prediction):
"""
Args:
prediction: TF tensor, activated prediction of the model
Returns: TF scalar, loss of the current forward pass
"""
if self.loss_function == 'cross_entropy':
if self.output_mask:
cost = tf.reduce_sum(
-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(
tf.clip_by_value(1 - prediction, 1e-12, 10.0)), axis=2)
cost *= self.mask
loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
else:
loss = tf.reduce_mean(
-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(
tf.clip_by_value(1 - prediction, 1e-12, 10.0)))
elif self.loss_function == 'mse':
clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0)
mse = tf.square(self.target - clipped_prediction)
mse = tf.reduce_mean(mse, axis=2)
if self.output_mask:
cost = mse * self.mask
loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
else:
loss = tf.reduce_mean(mse)
else:
raise UserWarning("Unknown loss function, use cross_entropy or mse")
return loss
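The masked branch above averages the cross-entropy only over unmasked time steps. A small numpy check of that reduction (illustrative shapes, not the model's tensors):

import numpy as np

time_steps, batch_size, num_classes = 4, 2, 3
rng = np.random.RandomState(0)
target = np.eye(num_classes)[rng.randint(num_classes, size=(time_steps, batch_size))]  # one-hot targets
prediction = np.clip(rng.rand(time_steps, batch_size, num_classes), 1e-12, 10.0)
mask = np.array([[1., 1.], [1., 1.], [1., 0.], [0., 0.]])  # padded steps get weight 0

cost = np.sum(-target * np.log(prediction)
              - (1 - target) * np.log(np.clip(1 - prediction, 1e-12, 10.0)), axis=2)
loss = np.sum(cost * mask) / np.sum(mask)  # mean over real (unmasked) steps only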
def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
"""
Connects unidirectional controller and memory unit and performs scan over sequence
Args:
inputs: TF tensor, input sequence
controller_config: dict, configuration of the controller
memory_unit_config: dict, configuration of the memory unit
analyse: bool, do analysis
reuse: bool, reuse
Returns: TF tensor, output sequence; TF tensor, hidden states
"""
with tf.variable_scope("controller"):
controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed,
dtype=self.dtype)
if controller_config['connect'] == 'sparse':
memory_input_size = controller_list[-1].output_size
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse,
reuse=reuse)
cell = MultiRNNCell(controller_list + [mu_cell])
else:
controller_cell = HolisticMultiRNNCell(controller_list)
memory_input_size = controller_cell.output_size
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse,
reuse=reuse)
cell = MultiRNNCell([controller_cell, mu_cell])
batch_size = inputs.get_shape()[1].value
cell_init_states = cell.zero_state(batch_size, dtype=self.dtype)
output_init = tf.zeros([batch_size, cell.output_size], dtype=self.dtype)
if analyse:
output_init = (output_init, mu_cell.analyse_state(batch_size, dtype=self.dtype))
init_states = (output_init, cell_init_states)
def step(pre_states, inputs):
pre_rnn_output, pre_rnn_states = pre_states
if analyse:
pre_rnn_output = pre_rnn_output[0]
controller_inputs = tf.concat([inputs, pre_rnn_output], axis=-1)
rnn_output, rnn_states = cell(controller_inputs, pre_rnn_states)
return (rnn_output, rnn_states)
outputs, states = tf.scan(step, inputs, initializer=init_states, parallel_iterations=32)
return outputs, states
def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
"""
Connects bidirectional controller and memory unit and performs scan over sequence
Args:
inputs: TF tensor, input sequence
controller_config: dict, configuration of the controller
memory_unit_config: dict, configuration of the memory unit
analyse: bool, do analysis
reuse: bool, reuse
Returns: TF tensor, output sequence; TF tensor, hidden states
"""
with tf.variable_scope("controller"):
list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype)
list_bw = get_rnn_cell_list(controller_config, name='con_bw', reuse=reuse, seed=self.seed, dtype=self.dtype)
if controller_config['connect'] == 'sparse':
cell_fw = MultiRNNCell(list_fw)
cell_bw = MultiRNNCell(list_bw)
else:
cell_fw = HolisticMultiRNNCell(list_fw)
cell_bw = HolisticMultiRNNCell(list_bw)
memory_input_size = cell_fw.output_size + cell_bw.output_size
cell_mu = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
with vs.variable_scope("bw") as bw_scope:
inputs_reverse = tf.reverse(inputs, axis=[0])
output_bw, output_state_bw = tf.nn.dynamic_rnn(cell=cell_bw, inputs=inputs_reverse, dtype=self.dtype,
parallel_iterations=32, time_major=True, scope=bw_scope)
output_bw = tf.reverse(output_bw, axis=[0])
batch_size = inputs.get_shape()[1].value
cell_fw_init_states = cell_fw.zero_state(batch_size, dtype=self.dtype)
cell_mu_init_states = cell_mu.zero_state(batch_size, dtype=self.dtype)
output_init = tf.zeros([batch_size, cell_mu.output_size], dtype=self.dtype)
if analyse:
output_init = (output_init, cell_mu.analyse_state(batch_size, dtype=self.dtype))
init_states = (output_init, cell_fw_init_states, cell_mu_init_states)
coupled_inputs = (inputs, output_bw)
with vs.variable_scope("fw") as fw_scope:
def step(pre_states, coupled_inputs):
inputs, output_bw = coupled_inputs
pre_outputs, pre_states_fw, pre_states_mu = pre_states
if analyse:
pre_outputs = pre_outputs[0]
controller_inputs = tf.concat([inputs, pre_outputs], axis=-1)
output_fw, states_fw = cell_fw(controller_inputs, pre_states_fw)
mu_inputs = tf.concat([output_fw, output_bw], axis=-1)
output_mu, states_mu = cell_mu(mu_inputs, pre_states_mu)
return (output_mu, states_fw, states_mu)
outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states,
parallel_iterations=32)
states = states_fw, states_mu
return outputs, states
@property
def feed(self):
"""
Returns: TF placeholders for data, target and mask input to the model
"""
return self.data, self.target, self.mask
@property
def controller_trainable_variables(self):
return tf.get_collection('recurrent_unit')
@property
def memory_unit_trainable_variables(self):
@@ -164,121 +341,3 @@ class MANN():
def parameter_amount(self):
var_list = tf.trainable_variables()
return self.count_parameter_amount(var_list)
def get_loss(self, prediction):
if self.loss_function == 'cross_entropy':
if self.output_mask:
cost = tf.reduce_sum(-1 * self.target * tf.log(tf.clip_by_value(prediction,1e-12,10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction,1e-12,10.0)), axis=2)
cost *= self.mask
loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
else:
loss = tf.reduce_mean(-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction, 1e-12, 10.0)))
elif self.loss_function == 'mse':
clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0)
mse = tf.square(self.target - clipped_prediction)
mse = tf.reduce_mean(mse, axis=2)
if self.output_mask:
cost = mse * self.mask
loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
else:
loss = tf.reduce_mean(mse)
else:
raise UserWarning("Unknown loss function, use cross_entropy or mse")
return loss
def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
with tf.variable_scope("controller"):
controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed, dtype=self.dtype)
if controller_config['connect'] == 'sparse':
memory_input_size = controller_list[-1].output_size
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
cell = MultiRNNCell(controller_list +[mu_cell])
else:
controller_cell = HolisticMultiRNNCell(controller_list)
memory_input_size = controller_cell.output_size
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
cell = MultiRNNCell([controller_cell, mu_cell])
batch_size = inputs.get_shape()[1].value
cell_init_states = cell.zero_state(batch_size, dtype=self.dtype)
output_init = tf.zeros([batch_size, cell.output_size], dtype=self.dtype)
if analyse:
output_init = (output_init, mu_cell.analyse_state(batch_size, dtype=self.dtype))
init_states = (output_init, cell_init_states)
def step(pre_states, inputs):
pre_rnn_output, pre_rnn_states = pre_states
if analyse:
pre_rnn_output = pre_rnn_output[0]
controller_inputs = tf.concat([inputs, pre_rnn_output], axis=-1)
rnn_output, rnn_states = cell(controller_inputs, pre_rnn_states)
return (rnn_output, rnn_states)
outputs, states = tf.scan(step, inputs, initializer=init_states, parallel_iterations=32)
return outputs, states
def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
with tf.variable_scope("controller"):
list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype)
list_bw = get_rnn_cell_list( controller_config, name='con_bw', reuse=reuse, seed=self.seed, dtype=self.dtype)
if controller_config['connect'] == 'sparse':
cell_fw = MultiRNNCell(list_fw)
cell_bw = MultiRNNCell(list_bw)
else:
cell_fw = HolisticMultiRNNCell(list_fw)
cell_bw = HolisticMultiRNNCell(list_bw)
memory_input_size = cell_fw.output_size + cell_bw.output_size
cell_mu = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
with vs.variable_scope("bw") as bw_scope:
inputs_reverse = tf.reverse(inputs, axis=[0])
output_bw, output_state_bw = tf.nn.dynamic_rnn(cell=cell_bw, inputs=inputs_reverse, dtype=self.dtype,
parallel_iterations=32, time_major=True, scope=bw_scope)
output_bw = tf.reverse(output_bw, axis=[0])
batch_size = inputs.get_shape()[1].value
cell_fw_init_states = cell_fw.zero_state(batch_size, dtype=self.dtype)
cell_mu_init_states = cell_mu.zero_state(batch_size, dtype=self.dtype)
output_init = tf.zeros([batch_size, cell_mu.output_size], dtype=self.dtype)
if analyse:
output_init = (output_init, cell_mu.analyse_state(batch_size, dtype=self.dtype))
init_states = (output_init, cell_fw_init_states, cell_mu_init_states)
coupled_inputs = (inputs, output_bw)
with vs.variable_scope("fw") as fw_scope:
def step(pre_states, coupled_inputs):
inputs, output_bw = coupled_inputs
pre_outputs, pre_states_fw, pre_states_mu = pre_states
if analyse:
pre_outputs = pre_outputs[0]
controller_inputs = tf.concat([inputs, pre_outputs], axis=-1)
output_fw, states_fw = cell_fw(controller_inputs, pre_states_fw)
mu_inputs = tf.concat([output_fw, output_bw], axis=-1)
output_mu, states_mu = cell_mu(mu_inputs, pre_states_mu)
return (output_mu, states_fw, states_mu)
outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states, parallel_iterations=32)
states = states_fw, states_mu
return outputs, states

View File

@@ -16,6 +16,9 @@ from abc import abstractmethod, ABCMeta
import numpy as np
import tensorflow as tf
"""
The base DNC memory unit class; all others inherit from this.
"""
class BaseMemoryUnitCell():
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,

View File

@@ -17,6 +17,9 @@ import tensorflow as tf
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
"""
The content-based memory unit.
"""
class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):

View File

@@ -18,6 +18,9 @@ import tensorflow as tf
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
"""
The vanilla DNC memory unit.
"""
class DNCMemoryUnitCell(BaseMemoryUnitCell):
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,

View File

@@ -19,6 +19,9 @@ from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
from adnc.model.memory_units.multi_write_content_based_cell import MWContentMemoryUnitCell
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
"""
A wrapper for the memory units.
"""
def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, seed=123, dtype=tf.float32):
memory_length = config['memory_length']

View File

@@ -17,6 +17,9 @@ import tensorflow as tf
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization
"""
The content-based memory unit with multi write heads.
"""
class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):

View File

@@ -20,6 +20,9 @@ from adnc.model.utils import layer_norm
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization
"""
The vanilla DNC memory unit with multi write heads.
"""
class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
def __init__(self, input_size, memory_length, memory_width, read_heads, write_heads, bypass_dropout=False,

View File

@@ -14,6 +14,9 @@
# ==============================================================================
import tensorflow as tf
"""
The Optimizer class is a wrapper for the TF optimizer and performs gradient clipping and weight decay.
"""
class Optimizer:
def __init__(self, config, loss, variables, use_locking=False):
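What such a wrapper typically does can be sketched in a few lines of TF1-style code (a minimal sketch assuming RMSProp, global-norm clipping and L2 weight decay; the parameter names are illustrative, not the repository's config keys):

import tensorflow as tf

def build_train_op(loss, variables, learning_rate=1e-4, weight_decay=1e-5, clip_norm=10.0):
    # weight decay: penalize the L2 norm of the trainable variables
    decayed_loss = loss + weight_decay * tf.add_n([tf.nn.l2_loss(v) for v in variables])
    optimizer = tf.train.RMSPropOptimizer(learning_rate)
    # clip gradients by global norm before the update is applied
    gradients = tf.gradients(decayed_loss, variables)
    clipped, _ = tf.clip_by_global_norm(gradients, clip_norm)
    return optimizer.apply_gradients(list(zip(clipped, variables)))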

View File

@@ -22,6 +22,10 @@ from shutil import copyfile
import numpy as np
import yaml
"""
The supporter class creates a folder for each training run, logs the prints in a file and saves the weights,
gradients and losses.
"""
class color_code:
def __init__(self):

View File

@@ -15,6 +15,9 @@
import numpy as np
from collections import deque
"""
EarlyStop returns True when called if the loss has been higher than the preceding loss for the last "list_len" steps.
"""
class EarlyStop():
def __init__(self, list_len=5):
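Under the reading above, the helper can be sketched with a deque of recent losses (a minimal sketch, not the repository's exact implementation):

from collections import deque

class EarlyStopSketch:
    def __init__(self, list_len=5):
        self.losses = deque(maxlen=list_len + 1)  # keep the last list_len + 1 losses

    def __call__(self, loss):
        self.losses.append(loss)
        if len(self.losses) < self.losses.maxlen:
            return False
        recent = list(self.losses)
        # True only if the loss went up in each of the last list_len steps
        return all(later > earlier for earlier, later in zip(recent[:-1], recent[1:]))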

View File

@@ -21,6 +21,9 @@ import hashlib
from collections import OrderedDict
import tensorflow as tf
"""
Downloads and processes GloVe word embeddings and applies them to a given vocabulary of a dataset.
"""
class WordEmbedding():
def __init__(self, embedding_size, vocabulary_size=None, word_idx_dict=None, initialization='uniform', tmp_dir='.',

View File

@@ -25,6 +25,11 @@ import tensorflow as tf
from adnc.data.loader import DataLoader
from adnc.model.mann import MANN
"""
This script performs inference with the given models of this repository on bAbI task 1 or on tasks 1-20. Please add the
model name when calling the script. (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all)
"""
parser = argparse.ArgumentParser(description='Load model')
parser.add_argument('model', type=str, default=False, help='model name')
model_name = parser.parse_args().model

View File

@@ -22,6 +22,10 @@ from tqdm import tqdm
from adnc.data.loader import DataLoader
from adnc.model.mann import MANN
"""
This script performs inference with the given models of this repository on the CNN RC task.
"""
expt_dir = "experiments/pre_trained/cnn_rc_task/adnc"
# load config file

View File

@@ -25,6 +25,11 @@ from adnc.analysis import Bucket, PlotFunctionality
tf.reset_default_graph()
"""
This script plots the memory unit functionality with the given models of this repository on bAbI task 1 or on tasks 1-20.
Please add the model name when calling the script. (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all)
"""
parser = argparse.ArgumentParser(description='Load model')
parser.add_argument('model', type=str, default=False, help='model name')
model_name = parser.parse_args().model

View File

@@ -27,6 +27,11 @@ from adnc.data import DataLoader
from adnc.model import MANN, Optimizer, Supporter
from adnc.model.utils import EarlyStop
"""
This script starts a training run on the bAbI task. The training can be fully configured in the config.yml
file. To restore a session use the --sess and --check flags.
"""
tf.reset_default_graph()
parser = argparse.ArgumentParser(description='Process some integers.')
@@ -63,7 +68,7 @@ valid_loader = dl.get_data_loader('valid') # gets a valid data iterator
train_loader = dl.get_data_loader('train') # gets a train data iterator
if analyse:
ana = Analyser(sp.session_dir, save_fig=plot_process,
save_variables=True) # initilizes a analyzer class
sp.config(model_type)['input_size'] = dl.x_size # after the data loader is initilized, the input size