From 9aeb5a21f17956ee8ac67297c5e7403c1fa240e4 Mon Sep 17 00:00:00 2001
From: Joerg Franke
Date: Thu, 5 Jul 2018 01:05:18 +0200
Subject: [PATCH] add start training script

---
 scripts/config.yml        | 106 ++++++++++++++++++
 scripts/start_training.py | 228 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 334 insertions(+)
 create mode 100755 scripts/config.yml
 create mode 100755 scripts/start_training.py

diff --git a/scripts/config.yml b/scripts/config.yml
new file mode 100755
index 0000000..5312573
--- /dev/null
+++ b/scripts/config.yml
@@ -0,0 +1,106 @@
+# Copyright 2018 Jörg Franke
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+#######################################
+###      Global Configuration       ###
+#######################################
+
+global:
+  batch_size: &batch_size 32
+
+#######################################
+###     Training Configuration      ###
+#######################################
+training:
+  epochs: 50                          # epochs to train
+  learn_rate: 0.00005                 # learning rate for the optimizer
+  optimizer: 'rmsprop'                # optimizer [rmsprop, adam, momentum, adadelta, adagrad, sgd]
+  optimizer_config: {'momentum': 0.9} # config for optimizer [momentum, nesterov]
+  gradient_clipping: 10               # gradient clipping value
+  weight_decay: False                 # weight decay, False or float
+
+
+#######################################
+###       MANN Configuration        ###
+#######################################
+mann:
+  name: 'mann1'
+  seed: 245
+  input_size: 0
+  output_size: 0
+  batch_size: *batch_size
+  input_embedding: False
+  architecture: 'uni'                 # bidirectional 172 384
+  controller_config: {"num_units": [128], "layer_norm": True, "activation": 'tanh', 'cell_type': 'clstm', 'connect': 'sparse'}
+  memory_unit_config: {"cell_type": 'cbmu', "memory_length": 64, "memory_width": 32, "read_heads": 4, "write_heads": 2, "dnc_norm": True, "bypass_dropout": False, "wgate1": False}
+  atop_rnn_config: False              # {"num_units": [32], "layer_norm": True, "activation": 'tanh', 'cell_type': 'clstm', 'connect': 'sparse', 'attention': False}
+  output_function: "softmax"          # softmax, tanh5, linear
+  output_mask: True
+  loss_function: 'cross_entropy'      # cross_entropy, mse
+  bw_input_fw: False
+
+
+###################################################################
+#######                  bAbI QA Task                       ######
+###################################################################
+babi_task:
+  data_set: 'babi'
+
+#  data_dir: 'data_babi/tasks_1-20_v1-2'
+#  tmp_dir: 'data_dir'
+
+  seed: 876
+  valid_ratio: 0.1                    # validation split as in the Nature paper
+  batch_size: *batch_size
+  max_len: 1000
+
+  set_type: ['en-10k']                # ['hn-10k', 'en-10k', 'shuffled-10k']
+  task_selection: ['1', '2', '12']    # list of task numbers (1-20) or 'all'
+  augment16: False                    # augmentation of task 16
+
+  num_chached: 5                      # number of cached samples
+  threads: 1                          # number of parallel threads
+
+
+##################################################################
+######                     Copy Task                       ######
+##################################################################
+copy_task:
+  data_set: 'copy_task'
+
+  seed: 125
+  batch_size: *batch_size
+
+  set_list:
+    train:
+      quantity: 6000                  # size of the training set
+      min_length: 20                  # min length of a training sample
+      max_length: 50                  # max length of a training sample
+    valid:
+      quantity: 600                   # size of the validation set
+      min_length: 50
+      max_length: 100
+#    test:
+#      quantity: 100
+#      min_length: 10
+#      max_length: 10
+
+  feature_width: 100                  # width of the feature vector
+
+  num_chached: 10                     # number of cached samples
+  threads: 1                          # number of parallel threads
+
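Note on scripts/config.yml: the batch size is declared once under global as a YAML anchor (&batch_size) and pulled into the mann, babi_task and copy_task sections via the *batch_size alias, so a single edit under global changes it everywhere. A minimal sketch of reading the file with PyYAML follows; this is an illustration only, the Supporter class used by the script below may load and merge the sections differently.

    # Illustration only, not part of the patch: PyYAML resolves the
    # &batch_size anchor and *batch_size aliases while parsing.
    import yaml

    with open('scripts/config.yml') as f:
        config = yaml.safe_load(f)

    print(config['global']['batch_size'])                    # 32
    print(config['babi_task']['batch_size'])                 # 32, via *batch_size
    print(config['mann']['controller_config']['num_units'])  # [128]
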
+sp.pub("train set length: {}".format(dl.sample_amount('train'))) +sp.pub("train batch amount: {}".format(dl.batch_amount('train'))) +sp.pub("valid set length: {}".format(dl.sample_amount('valid'))) +sp.pub("valid batch amount: {}".format(dl.batch_amount('valid'))) +sp.pub("model parameter amount: {}".format(model.parameter_amount)) + + +conf = tf.ConfigProto() +conf.gpu_options.per_process_gpu_memory_fraction = 0.8 +conf.gpu_options.allocator_type = 'BFC' +conf.gpu_options.allow_growth = True +conf.allow_soft_placement = True + +with tf.Session(config=conf) as sess: + + if sp.restore and restore_checkpoint: + saver.restore(sess, os.path.join(sp.session_dir, "model_dump_{}.ckpt".format(restore_checkpoint))) + epoch_start = restore_checkpoint + 1 + sp.pub("restart training with checkpoint {}".format(epoch_start - 1)) + elif sp.restore and not restore_checkpoint: + if tf.train.latest_checkpoint(sp.session_dir) == None: + sess.run(init_op) + epoch_start = 0 + sp.pub("start new training") + else: + saver.restore(sess,tf.train.latest_checkpoint(sp.session_dir)) + epoch_start = int(tf.train.latest_checkpoint(sp.session_dir).split('_')[-1].split('.')[0]) + 1 + sp.pub("restart training with checkpoint {}".format(epoch_start - 1)) + else: + sess.run(init_op) + epoch_start = 0 + sp.pub("start new training") + + writer = tf.summary.FileWriter(os.path.join(sp.session_dir, "summary"), sess.graph) + + for e in range(epoch_start, sp.config('training')['epochs']): + + train_cost = 0 + train_count = 0 + all_corrects = 0 + all_overall = 0 + time_e = time.time() + time_0 = time.time() + + for step in tqdm(range(int(dl.batch_amount('train')))): + + sample = next(train_loader) + + _, c, summary, lb, ls = sess.run([optimizer, model.loss, summary_train_loss, lstm_beta, lstm_scale],feed_dict={data: sample['x'], target: sample['y'], mask: sample['m']}) + train_cost += c + train_count += 1 + writer.add_summary(summary, e * dl.batch_amount('train') + step) + writer.add_summary(lb, e * dl.batch_amount('train') + step) + writer.add_summary(ls, e * dl.batch_amount('train') + step) + + valid_cost = 0 + valid_count = 0 + + for v in range(int(dl.batch_amount('valid'))): + vsample = next(valid_loader) + vcost, vpred, summary = sess.run([model.loss, model.prediction, summary_valid_loss],feed_dict={data: vsample['x'], target: vsample['y'], mask: vsample['m']}) + valid_cost += vcost + valid_count += 1 + writer.add_summary(summary, e * dl.batch_amount('valid') + v) + tm = np.argmax(vsample['y'], axis=-1) + pm = np.argmax(vpred, axis=-1) + corrects = np.equal(tm, pm) + all_corrects += np.sum(corrects * vsample['m']) + all_overall += np.sum(vsample['m']) + + valid_cost = valid_cost / valid_count + train_cost = train_cost /train_count + word_error_rate = 1- (all_corrects/all_overall) + + if not np.isnan(valid_cost): + + save_path = saver.save(sess, os.path.join(sp.session_dir ,"model_dump_{}.ckpt".format(e))) + + if analyse: + + controller_inf = [] + memory_inf = [] + all_corrects = 0 + all_overall = 0 + + for vstep in range(10): + vsample = next(valid_loader) + + analyse_values, prediction, gradients = sess.run([model.analyse, model.prediction, trainer.gradients], + feed_dict={data: vsample['x'], target: vsample['y'], mask: vsample['m']}) + weights = {v.name: {'var':g[1], 'grad':g[0], 'shape':g[0].shape } for v, g in zip(model.trainable_variables, gradients)} + if 'x_word' not in vsample.keys(): + vsample['x_word'] = np.transpose(np.argmax(vsample['x'], axis=-1),(1,0)) + data_sample = [vsample['x'], vsample['y'], vsample['m'], 
vsample['x_word'],] + + + decoded_targets, decoded_predictions = dl.decode_output(vsample, prediction) + + save_list = [analyse_values, prediction, decoded_predictions, data_sample, weights ] + + co_inf, mu_inf = ana.feed_variables_two(save_list, e, name="states_epoch", save_plot=vstep) + controller_inf.append(co_inf) + memory_inf.append(mu_inf) + + + controller_inf = np.mean(controller_inf) + memory_inf = np.mean(memory_inf) + + writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='wer', simple_value=word_error_rate)]), e * dl.batch_amount('train') + step) + writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='controller_inf', simple_value=controller_inf)]), e * dl.batch_amount('train') + step) + writer.add_summary(tf.Summary(value=[tf.Summary.Value(tag='memory_inf', simple_value=memory_inf)]), e * dl.batch_amount('train') + step) + + + sp.pub("epoch {:3}, step {:5}, train cost {:4.3f}, valid cost {:4.3f}, wer {:4.3f}, controller influence {:4.3f}, " + "memory influence {:4.3f}, duration {:5.1f}sec, time: {}, Model saved in {}".format( + e, step, train_cost, valid_cost, word_error_rate, controller_inf, memory_inf, time.time() - time_0, sp.time_stamp(), save_path)) + sp.monitor(["epoch", "step", "train cost", "valid cost", "duration", "controller influence", "memory influence", "wer"], + [e, step, train_cost, valid_cost, time.time() - time_0, controller_inf, memory_inf, word_error_rate]) + + else: + sp.pub("epoch {:3}, step {:5}, train cost {:4.3f}, valid cost {:4.3f}, duration {:5.1f}sec, time: {}, Model saved in {}".format( + e, step, train_cost, valid_cost, time.time() - time_0, sp.time_stamp(), save_path)) + sp.monitor(["epoch", "step", "train cost", "valid cost", "duration"], [e, step, train_cost, valid_cost, time.time() - time_0]) + else: + sp.pub("ERROR: nan in training") + sys.exit("NAN") + + if early_stop(valid_cost): + sp.pub("EARLYSTOP: valid error increase") + sys.exit("EARLYSTOP")
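
Note on scripts/start_training.py: the loop saves a checkpoint named model_dump_<epoch>.ckpt after every epoch with a finite validation cost, and the restore logic parses that epoch number back out of tf.train.latest_checkpoint, so an interrupted run can be resumed with, for example, python scripts/start_training.py --sess <session> --check <epoch>. Training also ends as soon as early_stop(valid_cost) returns a truthy value; a minimal stand-in with the same calling convention as EarlyStop(10) is sketched below, purely for illustration (the actual adnc.model.utils.EarlyStop may be implemented differently):

    # Hypothetical sketch, not the repository's implementation: stop after
    # `patience` consecutive epochs without a new best validation cost.
    class EarlyStopSketch:
        def __init__(self, patience):
            self.patience = patience
            self.best_cost = float('inf')
            self.bad_epochs = 0

        def __call__(self, valid_cost):
            if valid_cost < self.best_cost:
                self.best_cost = valid_cost
                self.bad_epochs = 0
            else:
                self.bad_epochs += 1
            return self.bad_epochs >= self.patience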