2018-07-06 14:28:17 +08:00
|
|
|
#!/usr/bin/env python
|
2018-06-25 19:47:57 +08:00
|
|
|
# Copyright 2018 Jörg Franke
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
import os
|
2018-07-05 05:59:24 +08:00
|
|
|
|
|
|
|
# Hide every CUDA device from this process; the variable is set here, ahead of
# the tensorflow import further down.
os.environ["CUDA_VISIBLE_DEVICES"] = "" # gpu not required for inference
|
2018-07-06 04:11:00 +08:00
|
|
|
|
2018-07-05 05:59:24 +08:00
|
|
|
import argparse
|
2018-06-25 19:47:57 +08:00
|
|
|
import yaml
|
|
|
|
import numpy as np
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
|
|
from adnc.data.loader import DataLoader
|
|
|
|
from adnc.model.mann import MANN
|
|
|
|
|
2018-07-11 20:35:33 +08:00
|
|
|
"""
|
|
|
|
This script performs an inference with the given models of this repository on bAbI task 1 or on tasks 1-20. Please add the
|
|
|
|
model name when calling the script. (dnc, adnc, biadnc, biadnc-all, biadnc-aug16-all)
|
|
|
|
"""
|
|
|
|
|
2018-07-05 05:59:24 +08:00
|
|
|
# Supported command-line model names mapped to the directory that holds the
# pre-trained checkpoint and its config file.
_PRE_TRAINED_DIRS = {
    'dnc': "experiments/pre_trained/babi_task_1/dnc",  # DNC trained on bAbI tasks 1
    'adnc': "experiments/pre_trained/babi_task_1/adnc",  # ADNC trained on bAbI tasks 1
    'biadnc': "experiments/pre_trained/babi_task_1/biadnc",  # BiADNC trained on bAbI tasks 1
    'biadnc-all': "experiments/pre_trained/babi_task_all/biadnc",  # BiADNC trained on all bAbI tasks
    # BiADNC trained on all bAbI tasks with task 16 augmentation
    'biadnc-aug16-all': "experiments/pre_trained/babi_task_all/biadnc_aug16",
}

parser = argparse.ArgumentParser(description='Load model')
# The positional argument is required, so no default is given. `choices` makes
# argparse reject unknown or misspelled names with a usage error instead of
# silently evaluating a different model.
parser.add_argument('model', type=str, choices=sorted(_PRE_TRAINED_DIRS), help='model name')
model_name = parser.parse_args().model

# Select the pre-trained experiment directory for the requested model.
expt_dir = _PRE_TRAINED_DIRS[model_name]
|
2018-06-25 19:47:57 +08:00
|
|
|
|
|
|
|
config_file = 'config.yml'

# Read the experiment configuration stored next to the checkpoint.
# yaml.safe_load is used because yaml.load without an explicit Loader is
# deprecated since PyYAML 5.1 and raises a TypeError from PyYAML 6.0 on.
# NOTE(review): assumes config.yml only contains standard YAML tags (plain
# mappings, lists, scalars) -- confirm against the shipped config files.
with open(os.path.join(expt_dir, config_file), 'r') as f:
    configs = yaml.safe_load(f)  # load config from file
|
2018-06-25 19:47:57 +08:00
|
|
|
|
|
|
|
# Split the loaded configuration into its data set part and its model part.
dataset_config, model_config = configs['babi_task'], configs['mann']

# Both the data loader and the model are configured for a batch size of one.
dataset_config['batch_size'] = model_config['batch_size'] = 1

# Data loading overrides used for inference.
dataset_config.update(
    threads=1,  # only one thread for data loading
    max_len=1921,  # set max length to maximal
    augment16=False,  # disable augmentation for inference
)

# Expand the configured task selection into an explicit list of task numbers:
# ['all'] stands for tasks 1 to 20, anything else is taken as a list of ids.
if dataset_config['task_selection'] != ['all']:
    task_list = [int(task_id) for task_id in dataset_config['task_selection']]
else:
    task_list = list(range(1, 21))
|
|
|
|
|
2018-07-05 05:59:24 +08:00
|
|
|
# Build the data loader from the data set configuration.
dl = DataLoader(dataset_config)

# The model takes its input and output sizes from the data loader.
model_config.update(input_size=dl.x_size, output_size=dl.y_size)
model_config['memory_unit_config'].update(bypass_dropout=False)  # no dropout during inference

# Build the memory augmented neural network and fetch its feed triple,
# which is used as the feed_dict keys of the session run.
model = MANN(model_config)
data, target, mask = model.feed

# Keep the dictionary of all tasks and its reverse mapping so they can be
# handed to the per-task data loaders created later.
word_dict, re_word_dict = dl.dataset.word_dict, dl.dataset.re_word_dict

# Short summary of the data set and the model.
print("vocabulary size: {}".format(dl.vocabulary_size))
print("train set length: {}".format(dl.sample_amount('train')))
print("valid set length: {}".format(dl.sample_amount('valid')))
print("model parameter amount: {}".format(model.parameter_amount))
|
|
|
|
|
|
|
|
# Session configuration. The GPU options only matter if a GPU is visible to
# the process.
conf = tf.ConfigProto()
gpu_options = conf.gpu_options
gpu_options.per_process_gpu_memory_fraction = 0.8
gpu_options.allocator_type = 'BFC'
gpu_options.allow_growth = True

# Saver used to restore the pre-trained checkpoint into the session.
saver = tf.train.Saver()
|
|
|
|
|
|
|
|
# Restore the pre-trained model and report the word error rate of every
# selected task, followed by the mean over all evaluated tasks.
with tf.Session(config=conf) as sess:
    saver.restore(sess, os.path.join(expt_dir, "model_dump.ckpt"))

    mean_error = []  # one word error rate per evaluated task
    for task in task_list:

        # load data loader for task; the dictionaries of the full data set are
        # passed in again (presumably to keep word indices consistent with the
        # restored model -- confirm against DataLoader)
        dataset_config['task_selection'] = [task]
        dl = DataLoader(dataset_config, word_dict, re_word_dict)
        valid_loader = dl.get_data_loader('test')  # iterates the 'test' split

        all_corrects, all_overall = 0, 0

        # infer model and accumulate the masked hit counts sample by sample,
        # so no prediction has to be kept in memory after it was scored
        for _ in range(int(dl.batch_amount('test'))):
            sample = next(valid_loader)
            # fetch the tensor directly: the result is the prediction array
            # itself instead of a one-element list wrapping it
            prediction = sess.run(model.prediction, feed_dict={data: sample['x'],
                                                               target: sample['y'],
                                                               mask: sample['m']})

            # compare the arg-max over the last axis of target and prediction
            target_ids = np.argmax(sample['y'], axis=-1)
            predicted_ids = np.argmax(prediction, axis=-1)
            corrects = np.equal(target_ids, predicted_ids)

            # only positions selected by the mask contribute to the score
            all_corrects += np.sum(corrects * sample['m'])
            all_overall += np.sum(sample['m'])

        # calculate mean error rate for task
        word_error_rate = 1 - (all_corrects / all_overall)
        mean_error.append(word_error_rate)

        print("word error rate task {:2}: {:0.3}".format(task, word_error_rate))

    print("mean word error rate : {:0.3}".format(np.mean(mean_error)))
|