mirror of https://github.com/JoergFranke/ADNC.git (synced 2024-11-17 13:58:03 +08:00)

add docstrings

This commit is contained in:
parent caf9e137fa
commit 88258f7656
@@ -21,21 +21,19 @@ from adnc.analysis.plot_functionality import PlotFunctionality
 from adnc.analysis.prepare_variables import Bucket
 from adnc.model.utils import softmax

-"""
-
-"""

 class Analyser():
-    def __init__(self, data_set, record_dir, save_variables=False, save_fig=False):
     """
+    The analyzer helps to analyze the functionality of the DNC during training. It is used to calculate the
+    memory influence and to plot function plots of the memory usage.
+    """
+    def __init__(self, record_dir, save_variables=False, save_fig=False):
+        """
         Args:
-            data_set:
-            record_dir:
-            save_variables:
-            save_fig:
+            record_dir: dir to store the function plots
+            save_variables: bool, to save weights, gradients and losses in a numpy list
+            save_fig: bool, save plots
         """
-        self.data_set = data_set
         self.record_dir = record_dir
         self.save_variables = save_variables
         self.save_fig = save_fig
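The new constructor drops the unused data_set argument. A minimal usage sketch based only on the signature above; the import path and record directory are assumptions, not part of this commit:

    from adnc.analysis.analyser import Analyser   # import path assumed from the file layout

    ana = Analyser(record_dir='experiments/run_01', save_variables=True, save_fig=True)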
|
@@ -23,6 +23,9 @@ import matplotlib.pyplot as plt

 from adnc.analysis.plot_functions import PlotFunctions

+"""
+Functions for plotting the behaviour of the different memory units.
+"""

 class PlotFunctionality(PlotFunctions):
     def __init__(self, bucket, legend=False, title=False, text_size=16, data_type='png'):
@@ -21,6 +21,9 @@ import matplotlib.pyplot as plt
 from matplotlib import colors
 from matplotlib import ticker

+"""
+Principal plot functions for gates, modes, input/output sequences, and readings/writings of the DNC.
+"""

 class PlotFunctions():
     def __init__(self, legend=False, title=False, text_size=16):
@@ -17,6 +17,9 @@ import numpy as np
 from adnc.model.utils import softmax
 from adnc.model.utils import weighted_softmax

+"""
+The bucket prepares and provides a full sample sequence of the DNC internal states for plotting the functions.
+"""

 class Bucket:
     def __init__(self, variables, babi_short=True):
@@ -24,7 +24,16 @@ from adnc.data.tasks.babi import bAbI


 class DataLoader():
+    """
+    The data loader loads and processes the datasets and provides iterators for training or inference.
+    """
     def __init__(self, config, word_dict=None, re_word_dict=None):
+        """
+        Args:
+            config: dict with the config to pre-process the dataset
+            word_dict: dict with word-feature pairs, optional
+            re_word_dict: dict with feature-word pairs, optional
+        """
         self.config = config

         if config['data_set'] == 'copy_task':
@@ -47,6 +56,14 @@ class DataLoader():
         return self.dataset.y_size

     def batch_amount(self, set_name):
+        """
+        Calculates the number of batches for a given batch size.
+        Args:
+            set_name: str, name of dataset (train, test, valid)
+
+        Returns: int, number of batches
+
+        """
         if 'max_len' in self.config.keys():
             return np.floor(
                 self.dataset.sample_amount(set_name, self.config['max_len']) / self.config['batch_size']).astype(int)
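For illustration, batch_amount is a plain floor division of the sample count by the batch size; a small sketch with made-up numbers:

    import numpy as np

    sample_amount = 10000                 # hypothetical size of the 'train' split
    batch_size = 32                       # from the config
    num_batches = np.floor(sample_amount / batch_size).astype(int)
    print(num_batches)                    # 312; the last 16 samples do not fill a batch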
@@ -63,6 +80,18 @@ class DataLoader():
         return self.dataset.decode_output(sample, prediction)

     def get_data_loader(self, set_name, shuffle=True, max_len=False, batch_size=None, get_shuffle_option=False):
+        """
+        Provides a data iterator over the given dataset.
+        Args:
+            set_name: str, name of dataset
+            shuffle: bool, shuffle set or not
+            max_len: int, max length in time of sample
+            batch_size: int, batch size
+            get_shuffle_option: bool, returns shuffle function
+
+        Returns: iter, iterator over dataset
+
+        """

         if batch_size == None:
             batch_size = self.config['batch_size']
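A typical call mirrors the training script further below; the config shown here is only a placeholder, the full set of keys lives in the repository's config.yml:

    from adnc.data.loader import DataLoader

    config = {'data_set': 'copy_task', 'batch_size': 32,
              'num_chached': 10, 'threads': 1}          # placeholder keys, not the complete config
    dl = DataLoader(config)
    valid_loader = dl.get_data_loader('valid')          # iterator over the validation split
    batch = next(valid_loader)                          # one batch of samples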
@@ -77,6 +106,16 @@ class DataLoader():

     @staticmethod
     def _generate_in_background(batch_gen, num_cached=10, threads=1):
+        """
+        Starts threads with a parallel batch generator for faster iteration.
+        Args:
+            batch_gen: func, the batch generator
+            num_cached: int, number of cached batches
+            threads: int, number of parallel threads
+
+        Returns: iter, iterator over dataset
+
+        """

         queue = Queue(maxsize=num_cached)
         sentinel = object()
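Only the queue and sentinel setup is visible in this hunk; a minimal self-contained sketch of the producer/consumer pattern it points to, not the repository's exact implementation:

    import threading
    from queue import Queue

    def generate_in_background(batch_gen, num_cached=10):
        # One producer thread fills a bounded queue, the consumer yields from it.
        queue = Queue(maxsize=num_cached)
        sentinel = object()                      # marks the end of the generator

        def producer():
            for batch in batch_gen:
                queue.put(batch)                 # blocks once num_cached batches are waiting
            queue.put(sentinel)

        threading.Thread(target=producer, daemon=True).start()
        item = queue.get()
        while item is not sentinel:
            yield item
            item = queue.get()

    fast_iter = generate_in_background(iter(range(100)))
    print(sum(fast_iter))                        # 4950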
@@ -20,11 +20,14 @@ from urllib.request import Request, urlopen

 import numpy as np

+"""
+Downloads and pre-processes the 20 bAbI tasks. It also augments task 16 as described in the paper.
+"""

 DEFAULT_DATA_FOLDER = "data_babi"
 LONGEST_SAMPLE_LENGTH = 1920
 bAbI_URL = 'http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz'


 class bAbI():
     def __init__(self, config, word_dict=None, re_word_dict=None):
@@ -24,11 +24,14 @@ import numpy as np

 from adnc.data.utils.data_memorizer import DataMemorizer

+"""
+Downloads and pre-processes the CNN RC task.
+"""

 DEFAULT_DATA_FOLDER = 'data_cnn'
 DEFAULT_TMP_FOLDER = 'data_tmp'
 CNN_DATA_URL = 'http://cs.stanford.edu/~danqi/data/cnn.tar.gz'


 class ReadingComprehension():
     def __init__(self, config, save=True, debug_max_load=None):
@@ -16,6 +16,9 @@ import numpy as np
 from collections import OrderedDict
 from scipy.sparse import csr_matrix

+"""
+Generates "repeat an input sequence" samples as described in the NTM paper.
+"""

 class CopyTask():
     def __init__(self, config):
@@ -18,6 +18,15 @@ import threading

 class BatchGenerator():
     def __init__(self, data_set, set, batch_size, shuffle=True, max_len=False):
+        """
+        Creates batches out of samples from the dataset object.
+        Args:
+            data_set: dataset object, contains the dataset
+            set: str, name of dataset (train, valid, test)
+            batch_size: int, batch size
+            shuffle: bool, shuffle or not
+            max_len: int, max length of sample in time
+        """

         self.set = set
         self.data_set = data_set
@@ -36,9 +45,15 @@ class BatchGenerator():
         self.sample_count = 0

     def shuffle_order(self):
+        """
+        Shuffles the order of the samples in the dataset.
+        """
         self.order = self.data_set.rng.permutation(self.order)

     def increase_sample_count(self):
+        """
+        Increases the global sample count; this is required because batches can be generated in parallel.
+        """
         with self.lock:
             self.sample_count += 1
             if self.sample_count >= self.sample_amount:
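The with self.lock: block above is the usual guard for a counter shared between generator threads. A standalone sketch of that pattern; the hunk is cut off, so the wrap-around behaviour is an assumption:

    import threading

    class SampleCounter:
        def __init__(self, sample_amount):
            self.lock = threading.Lock()
            self.sample_count = 0
            self.sample_amount = sample_amount

        def increase(self):
            with self.lock:                      # only one thread may update the count at a time
                self.sample_count += 1
                if self.sample_count >= self.sample_amount:
                    self.sample_count = 0        # assumed: reset once every sample has been used
                    return True                  # assumed: signals the end of an epoch
            return False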
@@ -50,7 +65,12 @@ class BatchGenerator():
             return next(self)

     def __next__(self):
+        """
+        Loads the next data samples and creates a batch. It uses the get_sample and patch_batch methods which are
+        provided by the dataset.
+        Returns: batch of samples
+        """
         batch_list = []
         for b in range(self.batch_size):
@@ -20,8 +20,15 @@ from collections import OrderedDict


 class DataMemorizer():
+    """
+    Given a config, it saves the pre-processed data in a pickle dump.
+    """
     def __init__(self, config, tmp_dir):
+        """
+        Args:
+            config: dict, config of dataset
+            tmp_dir: str, dir to save dataset dump
+        """
         self.hash_name = self.make_config_hash(config)
         if isinstance(tmp_dir, pathlib.Path):
             self.tmp_dir = tmp_dir
@@ -35,25 +42,46 @@ class DataMemorizer():
         return self.check_existent()

     def check_existent(self):
+        """
+        Returns: bool, whether the dataset dump exists
+        """
         file_name = self.tmp_dir / self.hash_name
         return file_name.exists()

     def load_data(self):
+        """
+        Returns: dataset, pickle load of dataset
+        """
         with open(str(self.tmp_dir / self.hash_name), 'rb') as outfile:
             data = pickle.load(outfile)
         return data

     def dump_data(self, data_to_save):
+        """
+        Args:
+            data_to_save: object, what to save
+        """
         with open(str(self.tmp_dir / self.hash_name), 'wb') as outfile:
             pickle.dump(data_to_save, outfile)

     def purge_data(self):
+        """
+        Removes the data dump.
+        """
         file_name = str(self.tmp_dir / self.hash_name)
         if os.path.isfile(file_name):
             os.remove(file_name)

     @staticmethod
     def make_config_hash(dict):
+        """
+        Computes a hash string to name the dataset dump uniquely.
+        Args:
+            dict: dict, config which describes the dataset
+
+        Returns: str, hash tag of dataset
+
+        """
         pre = sorted(((k, v) for k, v in dict.items() if k not in ['batch_size', 'num_chached', 'threads']))
         sort_dict = OrderedDict()
         for element in pre:
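The hash is built from the sorted config without the loader-only keys, so runs that differ only in batch size or threading reuse the same dump. A sketch of that idea; the concrete hash function is not visible in this hunk, md5 is an assumption:

    import hashlib
    from collections import OrderedDict

    def make_config_hash(config):
        pre = sorted((k, v) for k, v in config.items()
                     if k not in ['batch_size', 'num_chached', 'threads'])
        sort_dict = OrderedDict(pre)
        return hashlib.md5(str(sort_dict).encode('utf8')).hexdigest()

    a = make_config_hash({'data_set': 'copy_task', 'seed': 123, 'batch_size': 32})
    b = make_config_hash({'data_set': 'copy_task', 'seed': 123, 'batch_size': 64})
    assert a == b                                # same dump file name for both runs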
@@ -18,6 +18,9 @@ from tensorflow.contrib.rnn import LayerNormBasicLSTMCell, LSTMCell, LSTMBlockCe
 from adnc.model.controller_units.custom_lstm_cell import CustomLSTMCell
 from adnc.model.utils import get_activation

+"""
+A wrapper for the controller units.
+"""

 def get_rnn_cell_list(config, name, reuse=False, seed=123, dtype=tf.float32):
     cell_list = []
@@ -17,6 +17,9 @@ import tensorflow as tf

 from adnc.model.utils import layer_norm, get_activation

+"""
+An implementation of the LSTM unit; it performs a bit faster than the TF implementation and implements layer norm.
+"""

 class CustomLSTMCell():
     def __init__(self, num_units, layer_norm=False, activation='tanh', seed=100, reuse=False, trainable=True,
@@ -23,10 +23,19 @@ from adnc.model.memory_units.memory_unit import get_memory_unit
 from adnc.model.utils import HolisticMultiRNNCell
 from adnc.model.utils import WordEmbedding

+"""
+The memory augmented neural network (MANN) model object contains the controller and the memory unit as well as the
+loss function and connects everything.
+"""

 class MANN():
-    def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32):
+    def __init__(self, config, analyse=False, reuse=False, name='mann', dtype=tf.float32, new_output_structure=False):
+        """
+        Args:
+            config: dict, configuration of the whole model
+            analyse: bool, whether the analyzer is used or not
+            reuse: bool, reuse model or not
+        """

         self.seed = config["seed"]
         self.rng = np.random.RandomState(seed=self.seed)
@@ -45,7 +54,6 @@ class MANN():
         self.output_mask = config["output_mask"]
         self.loss_function = config['loss_function']

-
         self.reuse = reuse
         self.name = name

@@ -66,23 +74,27 @@ class MANN():
         else:
             self.data = tf.placeholder(tf.float32, [None, self.batch_size, self.input_size], name='x')

         if self.architecture in ['uni', 'unidirectional']:
-            unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse)
+            unweighted_outputs, states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config,
+                                                             reuse=self.reuse)
         elif self.architecture in ['bi', 'bidirectional']:
-            unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, reuse=self.reuse)
+            unweighted_outputs, states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config,
+                                                            reuse=self.reuse)
         else:
             raise UserWarning("Unknown architecture, use unidirectional or bidirectional")

         if self.analyse:
             with tf.device('/cpu:0'):
                 if self.architecture in ['uni', 'unidirectional']:
-                    analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True)
+                    analyse_outputs, analyse_states = self.unidirectional(self.data, self.controller_config,
+                                                                          self.memory_unit_config, analyse=True,
+                                                                          reuse=True)
                     analyse_outputs, analyse_signals = analyse_outputs
                     self.analyse = (analyse_outputs, analyse_signals, analyse_states)
                 elif self.architecture in ['bi', 'bidirectional']:
-                    analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config, self.memory_unit_config, analyse=True, reuse=True)
+                    analyse_outputs, analyse_states = self.bidirectional(self.data, self.controller_config,
+                                                                         self.memory_unit_config, analyse=True,
+                                                                         reuse=True)
                     analyse_outputs, analyse_signals = analyse_outputs
                     self.analyse = (analyse_outputs, analyse_signals, analyse_states)

|
|||||||
self.prediction, self.outputs = self._output_layer(unweighted_outputs)
|
self.prediction, self.outputs = self._output_layer(unweighted_outputs)
|
||||||
self.loss = self.get_loss(self.prediction)
|
self.loss = self.get_loss(self.prediction)
|
||||||
|
|
||||||
|
|
||||||
def _output_layer(self, outputs):
|
def _output_layer(self, outputs):
|
||||||
|
"""
|
||||||
|
Calculates the weighted and activated output of the MANN model
|
||||||
|
Args:
|
||||||
|
outputs: TF tensor, concatenation of memory units output and controller output
|
||||||
|
|
||||||
|
Returns: TF tensor, predictions; TF tensor, unactivated predictions
|
||||||
|
|
||||||
|
"""
|
||||||
with tf.variable_scope("output_layer"):
|
with tf.variable_scope("output_layer"):
|
||||||
output_size = outputs.get_shape()[-1].value
|
output_size = outputs.get_shape()[-1].value
|
||||||
|
|
||||||
weights_concat = tf.get_variable("weights_concat", (output_size, self.output_size),
|
weights_concat = tf.get_variable("weights_concat", (output_size, self.output_size),
|
||||||
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed), collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
initializer=tf.contrib.layers.xavier_initializer(seed=self.seed),
|
||||||
bias_merge = tf.get_variable("bias_merge",(self.output_size,), initializer=tf.constant_initializer(0.), collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
bias_merge = tf.get_variable("bias_merge", (self.output_size,), initializer=tf.constant_initializer(0.),
|
||||||
|
collections=['mann', tf.GraphKeys.GLOBAL_VARIABLES], dtype=self.dtype)
|
||||||
|
|
||||||
output_flat = tf.reshape(outputs, [-1, output_size])
|
output_flat = tf.reshape(outputs, [-1, output_size])
|
||||||
output_flat = tf.matmul(output_flat, weights_concat) + bias_merge
|
output_flat = tf.matmul(output_flat, weights_concat) + bias_merge
|
||||||
@ -117,63 +137,25 @@ class MANN():
|
|||||||
|
|
||||||
return predictions, weighted_outputs
|
return predictions, weighted_outputs
|
||||||
|
|
||||||
@property
|
|
||||||
def feed(self):
|
|
||||||
return self.data, self.target, self.mask
|
|
||||||
|
|
||||||
@property
|
|
||||||
def controller_trainable_variables(self):
|
|
||||||
return tf.get_collection('recurrent_unit')
|
|
||||||
|
|
||||||
@property
|
|
||||||
def memory_unit_trainable_variables(self):
|
|
||||||
return tf.get_collection('memory_unit')
|
|
||||||
|
|
||||||
@property
|
|
||||||
def mann_trainable_variables(self):
|
|
||||||
return tf.get_collection('mann')
|
|
||||||
|
|
||||||
@property
|
|
||||||
def trainable_variables(self):
|
|
||||||
return tf.trainable_variables()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def controller_parameter_amount(self):
|
|
||||||
return self.count_parameter_amount(self.controller_trainable_variables)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def memory_unit_parameter_amount(self):
|
|
||||||
return self.count_parameter_amount(self.memory_unit_trainable_variables)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def mann_parameter_amount(self):
|
|
||||||
return self.count_parameter_amount(self.mann_trainable_variables)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def count_parameter_amount(var_list):
|
|
||||||
parameters = 0
|
|
||||||
for variable in var_list:
|
|
||||||
shape = variable.get_shape()
|
|
||||||
variable_parametes = 1
|
|
||||||
for dim in shape:
|
|
||||||
variable_parametes *= dim.value
|
|
||||||
parameters += variable_parametes
|
|
||||||
return parameters
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameter_amount(self):
|
|
||||||
var_list = tf.trainable_variables()
|
|
||||||
return self.count_parameter_amount(var_list)
|
|
||||||
|
|
||||||
|
|
||||||
def get_loss(self, prediction):
|
def get_loss(self, prediction):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
prediction: TF tensor, activated prediction of the model
|
||||||
|
|
||||||
|
Returns: TF scalar, loss of the current forward set
|
||||||
|
|
||||||
|
"""
|
||||||
if self.loss_function == 'cross_entropy':
|
if self.loss_function == 'cross_entropy':
|
||||||
if self.output_mask:
|
if self.output_mask:
|
||||||
cost = tf.reduce_sum(-1 * self.target * tf.log(tf.clip_by_value(prediction,1e-12,10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction,1e-12,10.0)), axis=2)
|
cost = tf.reduce_sum(
|
||||||
|
-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(
|
||||||
|
tf.clip_by_value(1 - prediction, 1e-12, 10.0)), axis=2)
|
||||||
cost *= self.mask
|
cost *= self.mask
|
||||||
loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
|
loss = tf.reduce_sum(cost) / tf.reduce_sum(self.mask)
|
||||||
else:
|
else:
|
||||||
loss = tf.reduce_mean(-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(tf.clip_by_value(1 - prediction, 1e-12, 10.0)))
|
loss = tf.reduce_mean(
|
||||||
|
-1 * self.target * tf.log(tf.clip_by_value(prediction, 1e-12, 10.0)) - (1 - self.target) * tf.log(
|
||||||
|
tf.clip_by_value(1 - prediction, 1e-12, 10.0)))
|
||||||
|
|
||||||
elif self.loss_function == 'mse':
|
elif self.loss_function == 'mse':
|
||||||
clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0)
|
clipped_prediction = tf.clip_by_value(prediction, 1e-12, 10.0)
|
||||||
@ -189,20 +171,34 @@ class MANN():
|
|||||||
raise UserWarning("Unknown loss function, use cross_entropy or mse")
|
raise UserWarning("Unknown loss function, use cross_entropy or mse")
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
|
def unidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
|
||||||
|
"""
|
||||||
|
Connects unidirectional controller and memory unit and performs scan over sequence
|
||||||
|
Args:
|
||||||
|
inputs: TF tensor, input sequence
|
||||||
|
controller_config: dict, configuration of the controller
|
||||||
|
memory_unit_config: dict, configuration of the memory unit
|
||||||
|
analyse: bool, do analysis
|
||||||
|
reuse: bool, reuse
|
||||||
|
|
||||||
|
Returns: TF tensor, output sequence; TF tensor, hidden states
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
with tf.variable_scope("controller"):
|
with tf.variable_scope("controller"):
|
||||||
controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed, dtype=self.dtype)
|
controller_list = get_rnn_cell_list(controller_config, name='controller', reuse=reuse, seed=self.seed,
|
||||||
|
dtype=self.dtype)
|
||||||
|
|
||||||
if controller_config['connect'] == 'sparse':
|
if controller_config['connect'] == 'sparse':
|
||||||
memory_input_size = controller_list[-1].output_size
|
memory_input_size = controller_list[-1].output_size
|
||||||
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
|
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse,
|
||||||
|
reuse=reuse)
|
||||||
cell = MultiRNNCell(controller_list + [mu_cell])
|
cell = MultiRNNCell(controller_list + [mu_cell])
|
||||||
else:
|
else:
|
||||||
controller_cell = HolisticMultiRNNCell(controller_list)
|
controller_cell = HolisticMultiRNNCell(controller_list)
|
||||||
memory_input_size = controller_cell.output_size
|
memory_input_size = controller_cell.output_size
|
||||||
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse, reuse=reuse)
|
mu_cell = get_memory_unit(memory_input_size, memory_unit_config, 'memory_unit', analyse=analyse,
|
||||||
|
reuse=reuse)
|
||||||
cell = MultiRNNCell([controller_cell, mu_cell])
|
cell = MultiRNNCell([controller_cell, mu_cell])
|
||||||
|
|
||||||
batch_size = inputs.get_shape()[1].value
|
batch_size = inputs.get_shape()[1].value
|
||||||
@ -228,8 +224,19 @@ class MANN():
|
|||||||
|
|
||||||
return outputs, states
|
return outputs, states
|
||||||
|
|
||||||
|
|
||||||
def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
|
def bidirectional(self, inputs, controller_config, memory_unit_config, analyse=False, reuse=False):
|
||||||
|
"""
|
||||||
|
Connects bidirectional controller and memory unit and performs scan over sequence
|
||||||
|
Args:
|
||||||
|
inputs: TF tensor, input sequence
|
||||||
|
controller_config: dict, configuration of the controller
|
||||||
|
memory_unit_config: dict, configuration of the memory unit
|
||||||
|
analyse: bool, do analysis
|
||||||
|
reuse: bool, reuse
|
||||||
|
|
||||||
|
Returns: TF tensor, output sequence; TF tensor, hidden states
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
with tf.variable_scope("controller"):
|
with tf.variable_scope("controller"):
|
||||||
list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype)
|
list_fw = get_rnn_cell_list(controller_config, name='con_fw', reuse=reuse, seed=self.seed, dtype=self.dtype)
|
||||||
@ -278,7 +285,59 @@ class MANN():
|
|||||||
|
|
||||||
return (output_mu, states_fw, states_mu)
|
return (output_mu, states_fw, states_mu)
|
||||||
|
|
||||||
outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states, parallel_iterations=32)
|
outputs, states_fw, states_mu = tf.scan(step, coupled_inputs, initializer=init_states,
|
||||||
|
parallel_iterations=32)
|
||||||
|
|
||||||
states = states_fw, states_mu
|
states = states_fw, states_mu
|
||||||
return outputs, states
|
return outputs, states
|
||||||
|
|
||||||
|
@property
|
||||||
|
def feed(self):
|
||||||
|
"""
|
||||||
|
Returns: TF placeholder for data, target and mask inout to the model
|
||||||
|
"""
|
||||||
|
return self.data, self.target, self.mask
|
||||||
|
|
||||||
|
@property
|
||||||
|
def controller_trainable_variables(self):
|
||||||
|
return tf.get_collection('recurrent_unit')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def memory_unit_trainable_variables(self):
|
||||||
|
return tf.get_collection('memory_unit')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mann_trainable_variables(self):
|
||||||
|
return tf.get_collection('mann')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def trainable_variables(self):
|
||||||
|
return tf.trainable_variables()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def controller_parameter_amount(self):
|
||||||
|
return self.count_parameter_amount(self.controller_trainable_variables)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def memory_unit_parameter_amount(self):
|
||||||
|
return self.count_parameter_amount(self.memory_unit_trainable_variables)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mann_parameter_amount(self):
|
||||||
|
return self.count_parameter_amount(self.mann_trainable_variables)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def count_parameter_amount(var_list):
|
||||||
|
parameters = 0
|
||||||
|
for variable in var_list:
|
||||||
|
shape = variable.get_shape()
|
||||||
|
variable_parametes = 1
|
||||||
|
for dim in shape:
|
||||||
|
variable_parametes *= dim.value
|
||||||
|
parameters += variable_parametes
|
||||||
|
return parameters
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameter_amount(self):
|
||||||
|
var_list = tf.trainable_variables()
|
||||||
|
return self.count_parameter_amount(var_list)
|
||||||
|
@ -16,6 +16,9 @@ from abc import abstractmethod, ABCMeta
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
"""
|
||||||
|
The basis DNC memory unit class, all other inherit from this.
|
||||||
|
"""
|
||||||
|
|
||||||
class BaseMemoryUnitCell():
|
class BaseMemoryUnitCell():
|
||||||
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
|
def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
|
||||||
|
@@ -17,6 +17,9 @@ import tensorflow as tf
 from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
 from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization

+"""
+The content-based memory unit.
+"""

 class ContentBasedMemoryUnitCell(DNCMemoryUnitCell):

@@ -18,6 +18,9 @@ import tensorflow as tf
 from adnc.model.memory_units.base_cell import BaseMemoryUnitCell
 from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization

+"""
+The vanilla DNC memory unit.
+"""

 class DNCMemoryUnitCell(BaseMemoryUnitCell):
     def __init__(self, input_size, memory_length, memory_width, read_heads, bypass_dropout=False, dnc_norm=False,
@@ -19,6 +19,9 @@ from adnc.model.memory_units.dnc_cell import DNCMemoryUnitCell
 from adnc.model.memory_units.multi_write_content_based_cell import MWContentMemoryUnitCell
 from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell

+"""
+A wrapper for the memory units.
+"""

 def get_memory_unit(input_size, config, name='mu', analyse=False, reuse=False, seed=123, dtype=tf.float32):
     memory_length = config['memory_length']
@@ -17,6 +17,9 @@ import tensorflow as tf
 from adnc.model.memory_units.multi_write_dnc_cell import MWDNCMemoryUnitCell
 from adnc.model.utils import oneplus, layer_norm, unit_simplex_initialization

+"""
+The content-based memory unit with multiple write heads.
+"""

 class MWContentMemoryUnitCell(MWDNCMemoryUnitCell):

@@ -20,6 +20,9 @@ from adnc.model.utils import layer_norm
 from adnc.model.utils import oneplus
 from adnc.model.utils import unit_simplex_initialization

+"""
+The vanilla DNC memory unit with multiple write heads.
+"""

 class MWDNCMemoryUnitCell(BaseMemoryUnitCell):
     def __init__(self, input_size, memory_length, memory_width, read_heads, write_heads, bypass_dropout=False,
@@ -14,6 +14,9 @@
 # ==============================================================================
 import tensorflow as tf

+"""
+The Optimizer class is a wrapper around the TF optimizers and performs gradient clipping and weight decay.
+"""

 class Optimizer:
     def __init__(self, config, loss, variables, use_locking=False):
@@ -22,6 +22,10 @@ from shutil import copyfile
 import numpy as np
 import yaml

+"""
+The supporter class creates a folder for each training run, logs the prints in a file and saves the weights,
+gradients and losses.
+"""

 class color_code:
     def __init__(self):
@@ -15,6 +15,9 @@
 import numpy as np
 from collections import deque

+"""
+EarlyStop returns True once the loss has been higher than the earlier loss for the last "list_len" evaluations.
+"""

 class EarlyStop():
     def __init__(self, list_len=5):
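A minimal sketch of the early-stopping rule described in the docstring, using the same deque import; the exact comparison used in the repository may differ:

    from collections import deque

    class EarlyStop:
        def __init__(self, list_len=5):
            self.losses = deque(maxlen=list_len)

        def __call__(self, loss):
            # stop once the new loss is at least as bad as the worst of the last list_len losses
            stop = len(self.losses) == self.losses.maxlen and loss >= max(self.losses)
            self.losses.append(loss)
            return stop

    early_stop = EarlyStop(list_len=3)
    for loss in [1.0, 0.8, 0.7, 0.71, 0.72, 0.75]:
        if early_stop(loss):
            print("stop at loss", loss)          # triggers once the loss stops improving
            break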
@@ -21,6 +21,9 @@ import hashlib
 from collections import OrderedDict
 import tensorflow as tf

+"""
+Downloads and processes the GloVe word embeddings and applies them to a given vocabulary of a dataset.
+"""

 class WordEmbedding():
     def __init__(self, embedding_size, vocabulary_size=None, word_idx_dict=None, initialization='uniform', tmp_dir='.',
@@ -25,6 +25,11 @@ import tensorflow as tf
 from adnc.data.loader import DataLoader
 from adnc.model.mann import MANN

+"""
+This script performs inference with the given models of this repository on bAbI task 1 or on tasks 1-20. Please add
+the model name when calling the script. (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all)
+"""

 parser = argparse.ArgumentParser(description='Load model')
 parser.add_argument('model', type=str, default=False, help='model name')
 model_name = parser.parse_args().model
@@ -22,6 +22,10 @@ from tqdm import tqdm
 from adnc.data.loader import DataLoader
 from adnc.model.mann import MANN

+"""
+This script performs inference with the given model of this repository on the CNN RC task.
+"""

 expt_dir = "experiments/pre_trained/cnn_rc_task/adnc"

 # load config file
@@ -25,6 +25,11 @@ from adnc.analysis import Bucket, PlotFunctionality

 tf.reset_default_graph()

+"""
+This script plots the memory unit functionality with the given models of this repository on bAbI task 1 or on
+tasks 1-20. Please add the model name when calling the script. (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all)
+"""

 parser = argparse.ArgumentParser(description='Load model')
 parser.add_argument('model', type=str, default=False, help='model name')
 model_name = parser.parse_args().model
@@ -27,6 +27,11 @@ from adnc.data import DataLoader
 from adnc.model import MANN, Optimizer, Supporter
 from adnc.model.utils import EarlyStop

+"""
+This script starts a training run on the bAbI task. The training can be fully configured in the config.yml file.
+To restore a session, use the --sess and --check flags.
+"""

 tf.reset_default_graph()

 parser = argparse.ArgumentParser(description='Process some integers.')
@@ -63,7 +68,7 @@ valid_loader = dl.get_data_loader('valid') # gets a valid data iterator
 train_loader = dl.get_data_loader('train') # gets a train data iterator

 if analyse:
-    ana = Analyser(dataset_name, sp.session_dir, save_fig=plot_process,
+    ana = Analyser(sp.session_dir, save_fig=plot_process,
                    save_variables=True)  # initializes an analyzer class

 sp.config(model_type)['input_size'] = dl.x_size  # after the data loader is initialized, the input size