ADNC/scripts/start_training.py
#!/usr/bin/env python
# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import os
import sys
import time
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from adnc.analysis import Analyser
from adnc.data import DataLoader
from adnc.model import MANN, Optimizer, Supporter
from adnc.model.utils import EarlyStop
"""
This script performs starts a training run on the bAbI task. The training can be fully configured in the config.yml
file. To restore a session use the --sess and --check flag.
"""
tf.reset_default_graph()
parser = argparse.ArgumentParser(description='Start an ADNC training run on the bAbI task.')
parser.add_argument('--sess', type=int, default=False, help='session number')
parser.add_argument('--check', type=int, default=False, help='restore checkpoint')
args = parser.parse_args()
session_no = args.sess  # allows restoring a specific session
if not session_no:
    session_no = False
restore_checkpoint = args.check  # allows restoring a specific checkpoint
if not restore_checkpoint:
    restore_checkpoint = False
dataset_name = 'babi_task'  # defines the dataset chosen from the config
model_type = 'mann' # type of model, currently only 'mann'
experiment_name = 'github_example' # name of the experiment
project_dir = 'experiments/' # folder to save experiments
config_file = 'config.yml' # name of config file
early_stop = EarlyStop(10)  # initialize early stopping: stop after 10 consecutive increases of the validation loss
analyse = True  # enables a closer analysis of the training progress, e.g. the memory influence
plot_process = True  # if True, the analyser saves a plot after each epoch
sp = Supporter(project_dir, config_file, experiment_name, dataset_name, model_type,
               session_no)  # initializes the supporter class for experiment handling
dl = DataLoader(sp.config(dataset_name))  # initializes the data loader class
valid_loader = dl.get_data_loader('valid')  # gets a validation data iterator
train_loader = dl.get_data_loader('train')  # gets a train data iterator
if analyse:
    ana = Analyser(sp.session_dir, save_fig=plot_process,
                   save_variables=True)  # initializes an analyzer class
sp.config(model_type)['input_size'] = dl.x_size  # after the data loader is initialized, the input size
sp.config(model_type)['output_size'] = dl.y_size  # and output size are known and used for the model
model = MANN(sp.config('mann'), analyse)  # initializes the model class
data, target, mask = model.feed # TF data, target and mask placeholders for training
trainer = Optimizer(sp.config('training'), model.loss,
                    model.trainable_variables)  # initializes a trainer class with the optimizer
optimizer = trainer.optimizer  # the training op of the optimizer, run each step via sess.run
init_op = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=30)
summary_train_loss = tf.summary.scalar("train_loss", model.loss)
summary_valid_loss = tf.summary.scalar("valid_loss", model.loss)
lstm_scale = tf.summary.scalar("lstm_scale", tf.reduce_mean(model.trainable_variables[2]))
lstm_beta = tf.summary.scalar("lstm_beta", tf.reduce_mean(model.trainable_variables[3]))
sp.pub("vocabulary size: {}".format(dl.vocabulary_size)) # prints values and logs it to a log file
sp.pub("train set length: {}".format(dl.sample_amount('train')))
sp.pub("train batch amount: {}".format(dl.batch_amount('train')))
sp.pub("valid set length: {}".format(dl.sample_amount('valid')))
sp.pub("valid batch amount: {}".format(dl.batch_amount('valid')))
sp.pub("model parameter amount: {}".format(model.parameter_amount))
conf = tf.ConfigProto() # TF session config for optimal GPU usage
conf.gpu_options.per_process_gpu_memory_fraction = 0.8
conf.gpu_options.allocator_type = 'BFC'
conf.gpu_options.allow_growth = True
conf.allow_soft_placement = True
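# The session below either restores an existing checkpoint or initializes fresh weights, then runs the
# epoch loop: train on all batches, evaluate on the validation set, checkpoint the weights, optionally
# run the memory-influence analysis, and stop early if the validation loss keeps increasing.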
with tf.Session(config=conf) as sess:
    if sp.restore and restore_checkpoint:  # restores a model dump after a crash or to continue training
        saver.restore(sess, os.path.join(sp.session_dir, "model_dump_{}.ckpt".format(restore_checkpoint)))
        epoch_start = restore_checkpoint + 1
        sp.pub("restart training with checkpoint {}".format(epoch_start - 1))
    elif sp.restore and not restore_checkpoint:
        # no explicit checkpoint given: fall back to the latest one; the epoch index is
        # encoded in the checkpoint filename (model_dump_<epoch>.ckpt)
        if tf.train.latest_checkpoint(sp.session_dir) is None:
            sess.run(init_op)
            epoch_start = 0
            sp.pub("start new training")
        else:
            saver.restore(sess, tf.train.latest_checkpoint(sp.session_dir))
            epoch_start = int(tf.train.latest_checkpoint(sp.session_dir).split('_')[-1].split('.')[0]) + 1
            sp.pub("restart training with checkpoint {}".format(epoch_start - 1))
    else:
        sess.run(init_op)
        epoch_start = 0
        sp.pub("start new training")
    writer = tf.summary.FileWriter(os.path.join(sp.session_dir, "summary"), sess.graph)
    for e in range(epoch_start, sp.config('training')['epochs']):  # loop over all training epochs
        train_cost = 0
        train_count = 0
        all_corrects = 0
        all_overall = 0
        time_e = time.time()
        time_0 = time.time()
        for step in tqdm(range(int(dl.batch_amount('train')))):  # loop over all training samples
            sample = next(train_loader)  # new training sample from train iterator
            _, c, summary, lb, ls = sess.run([optimizer, model.loss, summary_train_loss, lstm_beta, lstm_scale],
                                             feed_dict={data: sample['x'], target: sample['y'], mask: sample['m']})
            train_cost += c
            train_count += 1
            writer.add_summary(summary, e * dl.batch_amount('train') + step)
            writer.add_summary(lb, e * dl.batch_amount('train') + step)
            writer.add_summary(ls, e * dl.batch_amount('train') + step)
        valid_cost = 0
        valid_count = 0
        for v in range(int(dl.batch_amount('valid'))):  # loop over all validation samples
            vsample = next(valid_loader)
            vcost, vpred, summary = sess.run([model.loss, model.prediction, summary_valid_loss],
                                             feed_dict={data: vsample['x'], target: vsample['y'], mask: vsample['m']})
            valid_cost += vcost
            valid_count += 1
            writer.add_summary(summary, e * dl.batch_amount('valid') + v)
            tm = np.argmax(vsample['y'], axis=-1)  # calculates the word error rate
            pm = np.argmax(vpred, axis=-1)
            corrects = np.equal(tm, pm)
            all_corrects += np.sum(corrects * vsample['m'])
            all_overall += np.sum(vsample['m'])
        valid_cost = valid_cost / valid_count
        train_cost = train_cost / train_count
        word_error_rate = 1 - (all_corrects / all_overall)
        if not np.isnan(valid_cost):  # check for NaN
            save_path = saver.save(sess,
                                   os.path.join(sp.session_dir, "model_dump_{}.ckpt".format(e)))  # dumps the model weights
            if analyse:  # if analysis is enabled, log the memory influence and plot the functionality
                controller_inf = []
                memory_inf = []
                all_corrects = 0
                all_overall = 0
                for vstep in range(10):  # makes ten validation inferences to get the gradients for the memory influence calculation
                    vsample = next(valid_loader)
                    analyse_values, prediction, gradients = sess.run(
                        [model.analyse, model.prediction, trainer.gradients],
                        feed_dict={data: vsample['x'], target: vsample['y'], mask: vsample['m']})
                    weights = {v.name: {'var': g[1], 'grad': g[0], 'shape': g[0].shape} for v, g in
                               zip(model.trainable_variables, gradients)}
                    if 'x_word' not in vsample.keys():
                        vsample['x_word'] = np.transpose(np.argmax(vsample['x'], axis=-1), (1, 0))
                    data_sample = [vsample['x'], vsample['y'], vsample['m'], vsample['x_word']]
                    decoded_targets, decoded_predictions = dl.decode_output(vsample, prediction)
                    save_list = [analyse_values, prediction, decoded_predictions, data_sample, weights]
                    co_inf, mu_inf = ana.feed_variables_two(save_list, e, name="states_epoch",
                                                            save_plot=vstep)  # calculates the memory influence
                    controller_inf.append(co_inf)
                    memory_inf.append(mu_inf)
                controller_inf = np.mean(controller_inf)
                memory_inf = np.mean(memory_inf)
                writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='wer', simple_value=word_error_rate)]),
                                   e * dl.batch_amount('train') + step)
                writer.add_summary(
                    tf.Summary(value=[tf.Summary.Value(tag='controller_inf', simple_value=controller_inf)]),
                    e * dl.batch_amount('train') + step)
                writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='memory_inf', simple_value=memory_inf)]),
                                   e * dl.batch_amount('train') + step)
                sp.pub(
                    "epoch {:3}, step {:5}, train cost {:4.3f}, valid cost {:4.3f}, wer {:4.3f}, controller influence {:4.3f}, "
                    "memory influence {:4.3f}, duration {:5.1f}sec, time: {}, Model saved in {}".format(
                        e, step, train_cost, valid_cost, word_error_rate, controller_inf, memory_inf,
                        time.time() - time_0, sp.time_stamp(), save_path))
                sp.monitor(["epoch", "step", "train cost", "valid cost", "duration", "controller influence",
                            "memory influence", "wer"],
                           [e, step, train_cost, valid_cost, time.time() - time_0, controller_inf, memory_inf,
                            word_error_rate])  # saves the values in a numpy array for later analysis
            else:
                sp.pub(
                    "epoch {:3}, step {:5}, train cost {:4.3f}, valid cost {:4.3f}, duration {:5.1f}sec, time: {}, Model saved in {}".format(
                        e, step, train_cost, valid_cost, time.time() - time_0, sp.time_stamp(), save_path))
                sp.monitor(["epoch", "step", "train cost", "valid cost", "duration"],
                           [e, step, train_cost, valid_cost, time.time() - time_0])
        else:
            sp.pub("ERROR: nan in training")
            sys.exit("NAN")  # end training in case of NaN
        if early_stop(valid_cost):
            sp.pub("EARLYSTOP: valid error increase")
            sys.exit("EARLYSTOP")  # end training when the validation loss keeps increasing (early stopping)