Mirror of https://github.com/JoergFranke/ADNC.git
Synced 2024-11-17 13:58:03 +08:00

add docstrings

This commit is contained in:
parent caf9e137fa
commit 88258f7656
@@ -21,21 +21,19 @@ from adnc.analysis.plot_functionality import PlotFunctionality
from adnc.analysis.prepare_variables import Bucket
from adnc.model.utils import softmax

"""

"""


class Analyser():
    def __init__(self, data_set, record_dir, save_variables=False, save_fig=False):
    """
    The analyzer helps to analyze the functionality of the DNC during training. It is used to calculate the
    memory influence and to plot function plots of the memory usage.
    """
    def __init__(self, record_dir, save_variables=False, save_fig=False):
        """
        Args:
            data_set:
            record_dir:
            save_variables:
            save_fig:
            record_dir: dir to store the function plots
            save_variables: bool, to save weights, gradients and losses in a numpy list
            save_fig: bool, save plots
        """
        self.data_set = data_set
        self.record_dir = record_dir
        self.save_variables = save_variables
        self.save_fig = save_fig
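For orientation, a minimal usage sketch of the new constructor. The signature and keyword defaults come from the diff above; the record directory path and import path are hypothetical:

from adnc.analysis import Analyser   # import path assumed

# Hypothetical record directory; keyword defaults mirror the documented signature.
analyser = Analyser(record_dir='experiments/run_01', save_variables=True, save_fig=False)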
@@ -23,6 +23,9 @@ import matplotlib.pyplot as plt
from adnc.analysis.plot_functions import PlotFunctions

"""
Functions for plotting different memory unit behaviours.
"""


class PlotFunctionality(PlotFunctions):
    def __init__(self, bucket, legend=False, title=False, text_size=16, data_type='png'):
@@ -21,6 +21,9 @@ import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import ticker

"""
Principal plot functions for gates, modes, input/output sequences or readings/writings of the DNC.
"""


class PlotFunctions():
    def __init__(self, legend=False, title=False, text_size=16):
@@ -17,6 +17,9 @@ import numpy as np
from adnc.model.utils import softmax
from adnc.model.utils import weighted_softmax

"""
The bucket prepares and provides a full sample sequence of the DNC internal states for plotting the functions.
"""


class Bucket:
    def __init__(self, variables, babi_short=True):
@@ -24,7 +24,16 @@ from adnc.data.tasks.babi import bAbI


class DataLoader():
    """
    The data loader loads and processes the datasets and provides iterators for training or inference.
    """
    def __init__(self, config, word_dict=None, re_word_dict=None):
        """
        Args:
            config: dict with the config to pre-process the dataset
            word_dict: dict with word-feature pairs, optional
            re_word_dict: dict with feature-word pairs, optional
        """
        self.config = config

        if config['data_set'] == 'copy_task':
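A rough construction sketch, using only config keys that appear somewhere in this diff ('data_set', 'batch_size', 'max_len', 'num_chached', 'threads'); the concrete values are assumptions:

from adnc.data.loader import DataLoader

# Sketch only: values are illustrative, not taken from the repository's config files.
config = {
    'data_set': 'copy_task',
    'batch_size': 32,
    'max_len': 100,
    'num_chached': 10,
    'threads': 1,
}
dl = DataLoader(config)
train_iter = dl.get_data_loader('train', shuffle=True)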
@@ -47,6 +56,14 @@ class DataLoader():
        return self.dataset.y_size

    def batch_amount(self, set_name):
        """
        Calculates the number of batches given the batch size
        Args:
            set_name: str, name of dataset (train, test, valid)

        Returns: int, number of batches

        """
        if 'max_len' in self.config.keys():
            return np.floor(
                self.dataset.sample_amount(set_name, self.config['max_len']) / self.config['batch_size']).astype(int)
@@ -63,6 +80,18 @@ class DataLoader():
        return self.dataset.decode_output(sample, prediction)

    def get_data_loader(self, set_name, shuffle=True, max_len=False, batch_size=None, get_shuffle_option=False):
        """
        Provides a data iterator of the given dataset.
        Args:
            set_name: str, name of dataset
            shuffle: bool, shuffle set or not
            max_len: int, max length in time of sample
            batch_size: int, batch size
            get_shuffle_option: bool, returns shuffle function

        Returns: iter, iterator over dataset

        """

        if batch_size == None:
            batch_size = self.config['batch_size']
@@ -77,6 +106,16 @@ class DataLoader():

    @staticmethod
    def _generate_in_background(batch_gen, num_cached=10, threads=1):
        """
        Starts threads with a parallel batch generator for faster iteration
        Args:
            batch_gen: func, the batch generator
            num_cached: int, number of cached batches
            threads: int, number of parallel threads

        Returns: iter, iterator over dataset

        """

        queue = Queue(maxsize=num_cached)
        sentinel = object()
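The queue/sentinel pair above follows the usual producer-consumer prefetch pattern. A self-contained sketch of that pattern, not the repository's exact implementation:

import threading
from queue import Queue

def generate_in_background(batch_gen, num_cached=10):
    # Producer thread fills the queue; a sentinel object marks exhaustion.
    queue = Queue(maxsize=num_cached)
    sentinel = object()

    def producer():
        for batch in batch_gen:
            queue.put(batch)
        queue.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()

    # Consumer side: yield batches until the sentinel arrives.
    item = queue.get()
    while item is not sentinel:
        yield item
        item = queue.get()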
@@ -20,11 +20,14 @@ from urllib.request import Request, urlopen

import numpy as np

"""
Downloads and pre-processes the 20 bAbI tasks. It also augments task 16 as described in the paper.
"""

DEFAULT_DATA_FOLDER = "data_babi"
LONGEST_SAMPLE_LENGTH = 1920
bAbI_URL = 'http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz'


class bAbI():
    def __init__(self, config, word_dict=None, re_word_dict=None):
@@ -24,11 +24,14 @@ import numpy as np

from adnc.data.utils.data_memorizer import DataMemorizer

"""
Downloads and pre-processes the CNN RC task.
"""

DEFAULT_DATA_FOLDER = 'data_cnn'
DEFAULT_TMP_FOLDER = 'data_tmp'
CNN_DATA_URL = 'http://cs.stanford.edu/~danqi/data/cnn.tar.gz'


class ReadingComprehension():
    def __init__(self, config, save=True, debug_max_load=None):
@@ -16,6 +16,9 @@ import numpy as np
from collections import OrderedDict
from scipy.sparse import csr_matrix

"""
Generates "repeat an input sequence" samples as described in the NTM paper.
"""


class CopyTask():
    def __init__(self, config):
@@ -18,6 +18,15 @@ import threading


class BatchGenerator():
    def __init__(self, data_set, set, batch_size, shuffle=True, max_len=False):
        """
        Creates batches out of samples from the dataset object
        Args:
            data_set: dataset object, contains the dataset
            set: str, name of dataset (train, valid, test)
            batch_size: int, batch size
            shuffle: bool, shuffle or not
            max_len: int, max length of sample in time
        """

        self.set = set
        self.data_set = data_set
@@ -36,9 +45,15 @@ class BatchGenerator():
        self.sample_count = 0

    def shuffle_order(self):
        """
        Shuffles the order of samples in the dataset
        """
        self.order = self.data_set.rng.permutation(self.order)

    def increase_sample_count(self):
        """
        Increases the global sample count; this is required because the generator can run in parallel
        """
        with self.lock:
            self.sample_count += 1
            if self.sample_count >= self.sample_amount:
@@ -50,7 +65,12 @@ class BatchGenerator():
        return next(self)

    def __next__(self):
        """
        Loads the next data samples and creates a batch. It uses the get_sample and patch_batch methods which are
        provided by the dataset
        Returns: batch of samples

        """
        batch_list = []
        for b in range(self.batch_size):
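As a rough, self-contained illustration of the iteration pattern documented above (permuted sample order, a lock-protected global counter, batches assembled sample by sample); names are simplified and this is not the repository's exact code:

import threading
import numpy as np

class MiniBatchGenerator:
    def __init__(self, samples, batch_size, seed=123):
        self.samples = samples
        self.batch_size = batch_size
        self.rng = np.random.RandomState(seed)
        self.order = self.rng.permutation(len(samples))
        self.sample_count = 0
        self.lock = threading.Lock()

    def shuffle_order(self):
        # Draw a new random permutation of the sample indices.
        self.order = self.rng.permutation(self.order)

    def increase_sample_count(self):
        # Lock-protected so several generator threads can share one counter.
        with self.lock:
            self.sample_count += 1
            if self.sample_count >= len(self.samples):
                self.sample_count = 0
                self.shuffle_order()

    def __next__(self):
        batch = []
        for _ in range(self.batch_size):
            batch.append(self.samples[self.order[self.sample_count]])
            self.increase_sample_count()
        return batch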
@@ -20,8 +20,15 @@ from collections import OrderedDict


class DataMemorizer():
    """
    Given a config, it saves the pre-processed data in a pickle dump.
    """
    def __init__(self, config, tmp_dir):
        """
        Args:
            config: dict, config of dataset
            tmp_dir: str, dir to save dataset dump
        """
        self.hash_name = self.make_config_hash(config)
        if isinstance(tmp_dir, pathlib.Path):
            self.tmp_dir = tmp_dir
@@ -35,25 +42,46 @@ class DataMemorizer():
        return self.check_existent()

    def check_existent(self):
        """
        Returns: bool, whether the dataset dump exists
        """
        file_name = self.tmp_dir / self.hash_name
        return file_name.exists()

    def load_data(self):
        """
        Returns: dataset, pickle load of dataset
        """
        with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile:
            data = pickle.load(outfile)
        return data

    def dump_data(self, data_to_save):
        """
        Args:
            data_to_save: object, what to save
        """
        with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile:
            pickle.dump(data_to_save, outfile)

    def purge_data(self):
        """
        Removes the data dump
        """
        file_name = str(self.tmp_dir / self.hash_name)
        if os.path.isfile(file_name):
            os.remove(file_name)

    @staticmethod
    def make_config_hash(dict):
        """
        Computes a hash string to name the dataset dump uniquely
        Args:
            dict: dict, config which describes the dataset

        Returns: str, hash tag of dataset

        """
        pre = sorted(((k, v) for k, v in dict.items() if k not in ['batch_size', 'num_chached', 'threads']))
        sort_dict = OrderedDict()
        for element in pre:
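A sketch of the hashing idea behind make_config_hash: sort the config items while excluding runtime-only keys, serialize, and hash. The diff only shows the sorting step; the use of hashlib.md5 and str() serialization here is an assumption:

import hashlib
from collections import OrderedDict

def make_config_hash(config):
    # Keys that do not change the pre-processed data are excluded from the hash.
    pre = sorted((k, v) for k, v in config.items()
                 if k not in ['batch_size', 'num_chached', 'threads'])
    sort_dict = OrderedDict(pre)
    # Assumption: any stable serialization plus hash works; md5 is used for brevity.
    return hashlib.md5(str(sort_dict).encode('utf-8')).hexdigest()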
@@ -18,6 +18,9 @@ from tensorflow.contrib.rnn import LayerNormBasicLSTMCell, LSTMCell, LSTMBlockCe
from adnc.model.controller_units.custom_lstm_cell import CustomLSTMCell
from adnc.model.utils import get_activation

"""
A wrapper for the controller units.
"""


def get_rnn_cell_list(config, name, reuse=False, seed=123, dtype=tf.float32):
    cell_list = []
@@ -17,6 +17,9 @@ import tensorflow as tf

from adnc.model.utils import layer_norm, get_activation

"""
An implementation of the LSTM unit; it performs a bit faster than the TF implementation and supports layer norm.
"""


class CustomLSTMCell():
    def __init__(self, num_units, layer_norm=False, activation='tanh', seed=100, reuse=False, trainable=True,
@@ -23,10 +23,19 @@ from adnc.model.memory_units.memory_unit import get_memory_unit
from adnc.model.utils import HolisticMultiRNNCell
from adnc.model.utils import WordEmbedding

"""
The memory augmented neural network (MANN) model object contains the controller and the memory unit as well as the
loss function and connects everything.
"""


class MANN():

    def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32, new_output_structure=False):
    def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32):
        """
        Args:
            config: dict, configuration of the whole model
            analyse: bool, whether the analyzer is used or not
            reuse: bool, reuse model or not
        """

        self.seed = config["seed"]
        self.rng = np.random.RandomState(seed=self.seed)
@@ -45,7 +54,6 @@ class MANN():
        self.output_mask = config["output_mask"]
        self.loss_function = config['loss_function']


        self.reuse = reuse
        self.name = name
@@ -66,23 +74,27 @@ class MANN():
        else:
            self.data = tf.placeholder(tf.float32, [None, self.batch_size, self.input_size], name='x')

        if self.architecture in ['uni', 'unidirectional']:
            unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse)
            unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config,
                                                             reuse=self.reuse)
        elif self.architecture in ['bi', 'bidirectional']:
            unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse)
            unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config,
                                                            reuse=self.reuse)
        else:
            raise UserWarning("Unknown architecture, use unidirectional or bidirectional")

        if self.analyse:
            with tf.device('/cpu:0'):
                if self.architecture in ['uni', 'unidirectional']:
                    analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True)
                    analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config,
                                                                          self.memory_unit_config, analyse=True,
                                                                          reuse=True)
                    analyse_outputs, analyse_signals = analyse_outputs
                    self.analyse = (analyse_outputs, analyse_signals, analyse_states)
                elif self.architecture in ['bi', 'bidirectional']:
                    analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True)
                    analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config,
                                                                         self.memory_unit_config, analyse=True,
                                                                         reuse=True)
                    analyse_outputs, analyse_signals = analyse_outputs
                    self.analyse = (analyse_outputs, analyse_signals, analyse_states)
@@ -90,15 +102,23 @@ class MANN():
            self.prediction, self.outputs = self._output_layer(unweighted_outputs)
            self.loss = self.get_loss(self.prediction)

    def _output_layer(self, outputs):
        """
        Calculates the weighted and activated output of the MANN model
        Args:
            outputs: TF tensor, concatenation of memory units output and controller output

        Returns: TF tensor, predictions; TF tensor, unactivated predictions

        """
        with tf.variable_scope("output_layer"):
            output_size = outputs.get_shape()[-1].value

            weights_concat = tf.get_variable("weights_concat", (output_size, self.output_size),
                                             initializer=tf.contrib.layers.xavier_initializer(seed=self.seed), collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
            bias_merge = tf.get_variable("bias_merge", (self.output_size,), initializer=tf.constant_initializer(0.), collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
                                             initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
                                             collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
            bias_merge = tf.get_variable("bias_merge", (self.output_size,), initializer=tf.constant_initializer(0.),
                                         collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)

            output_flat = tf.reshape(outputs, [-1, output_size])
            output_flat = tf.matmul(output_flat, weights_concat) + bias_merge
@@ -117,63 +137,25 @@ class MANN():

        return predictions, weighted_outputs

    @property
    def feed(self):
        return self.data, self.target, self.mask

    @property
    def controller_trainable_variables(self):
        return tf.get_collection('recurrent_unit')

    @property
    def memory_unit_trainable_variables(self):
        return tf.get_collection('memory_unit')

    @property
    def mann_trainable_variables(self):
        return tf.get_collection('mann')

    @property
    def trainable_variables(self):
        return tf.trainable_variables()

    @property
    def controller_parameter_amount(self):
        return self.count_parameter_amount(self.controller_trainable_variables)

    @property
    def memory_unit_parameter_amount(self):
        return self.count_parameter_amount(self.memory_unit_trainable_variables)

    @property
    def mann_parameter_amount(self):
        return self.count_parameter_amount(self.mann_trainable_variables)

    @staticmethod
    def count_parameter_amount(var_list):
        parameters = 0
        for variable in var_list:
            shape = variable.get_shape()
            variable_parametes = 1
            for dim in shape:
                variable_parametes *= dim.value
            parameters += variable_parametes
        return parameters

    @property
    def parameter_amount(self):
        var_list = tf.trainable_variables()
        return self.count_parameter_amount(var_list)

    def get_loss(self, prediction):
        """
        Args:
            prediction: TF tensor, activated prediction of the model

        Returns: TF scalar, loss of the current forward set

        """
        if self.loss_function == 'cross_entropy':
            if self.output_mask:
                cost = tf.reduce_sum(-1 * self.target * tf.log(tf.clip_by_value(prediction,1e-12,10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction,1e-12,10.0)), axis=2)
                cost = tf.reduce_sum(
                    -1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(
                        tf.clip_by_value(1 - prediction, 1e-12, 10.0)), axis=2)
                cost *= self.mask
                loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
            else:
                loss = tf.reduce_mean(-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction, 1e-12, 10.0)))
                loss = tf.reduce_mean(
                    -1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(
                        tf.clip_by_value(1 - prediction, 1e-12, 10.0)))

        elif self.loss_function == 'mse':
            clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0)
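The masked cross-entropy above is easier to read outside the graph. A small NumPy sketch of the same computation, with the clipping bounds taken from the diff and shapes assumed as time x batch x output:

import numpy as np

def masked_cross_entropy(prediction, target, mask, eps=1e-12):
    # Element-wise binary cross-entropy with the same clipping as the TF code.
    p = np.clip(prediction, eps, 10.0)
    q = np.clip(1.0 - prediction, eps, 10.0)
    cost = np.sum(-target * np.log(p) - (1.0 - target) * np.log(q), axis=2)
    cost *= mask  # zero out padded time steps
    return np.sum(cost) / np.sum(mask)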
@@ -189,20 +171,34 @@ class MANN():
            raise UserWarning("Unknown loss function, use cross_entropy or mse")
        return loss

    def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
        """
        Connects unidirectional controller and memory unit and performs scan over sequence
        Args:
            inputs: TF tensor, input sequence
            controller_config: dict, configuration of the controller
            memory_unit_config: dict, configuration of the memory unit
            analyse: bool, do analysis
            reuse: bool, reuse

        Returns: TF tensor, output sequence; TF tensor, hidden states

        """

        with tf.variable_scope("controller"):
            controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed, dtype=self.dtype)
            controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed,
                                                dtype=self.dtype)

        if controller_config['connect'] == 'sparse':
            memory_input_size = controller_list[-1].output_size
            mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
            mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse,
                                      reuse=reuse)
            cell = MultiRNNCell(controller_list + [mu_cell])
        else:
            controller_cell = HolisticMultiRNNCell(controller_list)
            memory_input_size = controller_cell.output_size
            mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
            mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse,
                                      reuse=reuse)
            cell = MultiRNNCell([controller_cell, mu_cell])

        batch_size = inputs.get_shape()[1].value
@@ -228,8 +224,19 @@ class MANN():

        return outputs, states

    def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
        """
        Connects bidirectional controller and memory unit and performs scan over sequence
        Args:
            inputs: TF tensor, input sequence
            controller_config: dict, configuration of the controller
            memory_unit_config: dict, configuration of the memory unit
            analyse: bool, do analysis
            reuse: bool, reuse

        Returns: TF tensor, output sequence; TF tensor, hidden states

        """

        with tf.variable_scope("controller"):
            list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype)
@@ -278,7 +285,59 @@ class MANN():

            return (output_mu, states_fw, states_mu)

        outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states, parallel_iterations=32)
        outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states,
                                                parallel_iterations=32)

        states = states_fw, states_mu
        return outputs, states

    @property
    def feed(self):
        """
        Returns: TF placeholders for the data, target and mask input to the model
        """
        return self.data, self.target, self.mask

    @property
    def controller_trainable_variables(self):
        return tf.get_collection('recurrent_unit')

    @property
    def memory_unit_trainable_variables(self):
        return tf.get_collection('memory_unit')

    @property
    def mann_trainable_variables(self):
        return tf.get_collection('mann')

    @property
    def trainable_variables(self):
        return tf.trainable_variables()

    @property
    def controller_parameter_amount(self):
        return self.count_parameter_amount(self.controller_trainable_variables)

    @property
    def memory_unit_parameter_amount(self):
        return self.count_parameter_amount(self.memory_unit_trainable_variables)

    @property
    def mann_parameter_amount(self):
        return self.count_parameter_amount(self.mann_trainable_variables)

    @staticmethod
    def count_parameter_amount(var_list):
        parameters = 0
        for variable in var_list:
            shape = variable.get_shape()
            variable_parametes = 1
            for dim in shape:
                variable_parametes *= dim.value
            parameters += variable_parametes
        return parameters

    @property
    def parameter_amount(self):
        var_list = tf.trainable_variables()
        return self.count_parameter_amount(var_list)
@@ -16,6 +16,9 @@ from abc import abstractmethod, ABCMeta
import numpy as np
import tensorflow as tf

"""
The base DNC memory unit class; all others inherit from it.
"""


class BaseMemoryUnitCell():
    def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
@@ -17,6 +17,9 @@ import tensorflow as tf
from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization

"""
The content-based memory unit.
"""


class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):
@@ -18,6 +18,9 @@ import tensorflow as tf
from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization

"""
The vanilla DNC memory unit.
"""


class DNCMemoryUnitCell(BaseMemoryUnitCell):
    def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
@@ -19,6 +19,9 @@ from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
from adnc.model.memory_units.multi_write_content_based_cell import MWContentMemoryUnitCell
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell

"""
A wrapper for the memory units.
"""


def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, seed=123, dtype=tf.float32):
    memory_length = config['memory_length']
@@ -17,6 +17,9 @@ import tensorflow as tf
from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization

"""
The content-based memory unit with multiple write heads.
"""


class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):
@@ -20,6 +20,9 @@ from adnc.model.utils import layer_norm
from adnc.model.utils import oneplus
from adnc.model.utils import unit_simplex_initialization

"""
The vanilla DNC memory unit with multiple write heads.
"""


class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
    def __init__(self, input_size, memory_length, memory_width, read_heads, write_heads, bypass_dropout=False,
@@ -14,6 +14,9 @@
# ==============================================================================
import tensorflow as tf

"""
The Optimizer class is a wrapper for the TF optimizer and performs gradient clipping and weight decay.
"""


class Optimizer:
    def __init__(self, config, loss, variables, use_locking=False):
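A rough sketch of how such a wrapper typically applies global-norm gradient clipping in TF1-style code; the optimizer type, learning rate and clip value are assumptions, not read from this diff:

import tensorflow as tf

def clipped_train_op(loss, variables, learning_rate=1e-4, clip_norm=5.0):
    # Assumed optimizer; the repository's config decides the actual one.
    opt = tf.train.RMSPropOptimizer(learning_rate)
    grads = tf.gradients(loss, variables)
    clipped, _ = tf.clip_by_global_norm(grads, clip_norm)
    return opt.apply_gradients(zip(clipped, variables))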
@@ -22,6 +22,10 @@ from shutil import copyfile
import numpy as np
import yaml

"""
The supporter class creates a folder for each training run, logs the console output to a file and saves the weights,
gradients and losses.
"""


class color_code:
    def __init__(self):
@@ -15,6 +15,9 @@
import numpy as np
from collections import deque

"""
EarlyStop returns True when called if the loss has been higher than the loss before for the last "list_len" calls.
"""


class EarlyStop():
    def __init__(self, list_len=5):
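A self-contained sketch of that behaviour, assuming the deque import above is used to keep the recent loss comparisons; this is an illustration, not the repository's exact logic:

from collections import deque

class SimpleEarlyStop:
    def __init__(self, list_len=5):
        # Keep the last `list_len` "did the loss go up?" flags.
        self.increases = deque(maxlen=list_len)
        self.previous_loss = None

    def __call__(self, loss):
        if self.previous_loss is not None:
            self.increases.append(loss > self.previous_loss)
        self.previous_loss = loss
        # Signal a stop once the loss went up `list_len` times in a row.
        return len(self.increases) == self.increases.maxlen and all(self.increases)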
@@ -21,6 +21,9 @@ import hashlib
from collections import OrderedDict
import tensorflow as tf

"""
Downloads and processes GloVe word embeddings and applies them to a given dataset vocabulary.
"""


class WordEmbedding():
    def __init__(self, embedding_size, vocabulary_size=None, word_idx_dict=None, initialization='uniform', tmp_dir='.',
@@ -25,6 +25,11 @@ import tensorflow as tf
from adnc.data.loader import DataLoader
from adnc.model.mann import MANN

"""
This script performs inference with the models provided with this repository on bAbI task 1 or on tasks 1-20. Please
add the model name when calling the script (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all).
"""

parser = argparse.ArgumentParser(description='Load model')
parser.add_argument('model', type=str, default=False, help='model name')
model_name = parser.parse_args().model
@@ -22,6 +22,10 @@ from tqdm import tqdm
from adnc.data.loader import DataLoader
from adnc.model.mann import MANN

"""
This script performs inference with the models provided with this repository on the CNN RC task.
"""

expt_dir = "experiments/pre_trained/cnn_rc_task/adnc"

# load config file
@@ -25,6 +25,11 @@ from adnc.analysis import Bucket, PlotFunctionality

tf.reset_default_graph()

"""
This script plots the memory unit functionality with the models provided with this repository on bAbI task 1 or on
tasks 1-20. Please add the model name when calling the script (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all).
"""

parser = argparse.ArgumentParser(description='Load model')
parser.add_argument('model', type=str, default=False, help='model name')
model_name = parser.parse_args().model
@@ -27,6 +27,11 @@ from adnc.data import DataLoader
from adnc.model import MANN, Optimizer, Supporter
from adnc.model.utils import EarlyStop

"""
This script starts a training run on the bAbI task. The training can be fully configured in the config.yml
file. To restore a session use the --sess and --check flags.
"""

tf.reset_default_graph()

parser = argparse.ArgumentParser(description='Process some integers.')
@@ -63,7 +68,7 @@ valid_loader = dl.get_data_loader('valid') # gets a valid data iterator
train_loader = dl.get_data_loader('train') # gets a train data iterator

if analyse:
    ana = Analyser(dataset_name, sp.session_dir, save_fig=plot_process,
    ana = Analyser(sp.session_dir, save_fig=plot_process,
                   save_variables=True) # initializes an analyzer class

sp.config(model_type)['input_size'] = dl.x_size # after the data loader is initialized, the input size