mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
add pre trained models and inference babi task script
This commit is contained in:
parent
ce06bdba9a
commit
88bbe5f5aa
58
experiments/pre_trained/babi_task_1/adnc/config.yml
Normal file
58
experiments/pre_trained/babi_task_1/adnc/config.yml
Normal file
@ -0,0 +1,58 @@
|
||||
|
||||
#######################################
|
||||
### Global Configuration ###
|
||||
#######################################
|
||||
global:
|
||||
batch_size: &batch_size 32
|
||||
|
||||
#######################################
|
||||
### Training Configuration ###
|
||||
#######################################
|
||||
training:
|
||||
epochs: 75
|
||||
learn_rate: 0.0001
|
||||
optimizer: 'rmsprop'
|
||||
optimizer_config: {'momentum':0.9}
|
||||
gradient_clipping: 10
|
||||
weight_decay: False
|
||||
|
||||
|
||||
|
||||
#######################################
|
||||
### MANN Configuration ###
|
||||
#######################################
|
||||
mann:
|
||||
name: 'mann1'
|
||||
seed: 987
|
||||
input_size: 0
|
||||
output_size: 0
|
||||
batch_size: *batch_size
|
||||
input_embedding: False
|
||||
architecture: 'uni'
|
||||
controller_config: {"num_units":[64], "layer_norm":True, "activation":'tanh', 'cell_type':'clstm', 'connect':'sparse', 'attention':False}
|
||||
memory_unit_config: {"cell_type":'cbmu', "memory_length":128, "memory_width":32, "read_heads":2, "write_heads": 1, "dnc_norm":True, "bypass_dropout":0.8}
|
||||
output_function: "softmax"
|
||||
output_mask: True
|
||||
loss_function: 'cross_entropy'
|
||||
|
||||
|
||||
|
||||
###################################################################
|
||||
####### bAbI QA Task ######
|
||||
###################################################################
|
||||
babi_task:
|
||||
data_set: 'babi'
|
||||
|
||||
load_test: True
|
||||
load_vocab: True
|
||||
|
||||
seed: 212
|
||||
valid_ratio: 0.1 # like nature paper
|
||||
batch_size: *batch_size
|
||||
max_len: 500
|
||||
|
||||
set_type: ['en-10k']
|
||||
task_selection: [1]
|
||||
|
||||
num_chached: 5
|
||||
threads: 1
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3977630485b36af3ab58f9b8cf29f6f12a934f055c3821710a62785b5138aae3
|
||||
size 639880
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9ffc1a056f5937cc6d44f75ae147894093e4fd6a41ddc11559773746a442bfb1
|
||||
size 1415
|
58
experiments/pre_trained/babi_task_1/biadnc/config.yml
Normal file
58
experiments/pre_trained/babi_task_1/biadnc/config.yml
Normal file
@ -0,0 +1,58 @@
|
||||
|
||||
#######################################
|
||||
### Global Configuration ###
|
||||
#######################################
|
||||
global:
|
||||
batch_size: &batch_size 32
|
||||
|
||||
#######################################
|
||||
### Training Configuration ###
|
||||
#######################################
|
||||
training:
|
||||
epochs: 75
|
||||
learn_rate: 0.0001
|
||||
optimizer: 'rmsprop'
|
||||
optimizer_config: {'momentum':0.9}
|
||||
gradient_clipping: 10
|
||||
weight_decay: False
|
||||
|
||||
|
||||
|
||||
#######################################
|
||||
### MANN Configuration ###
|
||||
#######################################
|
||||
mann:
|
||||
name: 'mann1'
|
||||
seed: 347
|
||||
input_size: 0
|
||||
output_size: 0
|
||||
batch_size: *batch_size
|
||||
input_embedding: False
|
||||
architecture: 'bi'
|
||||
controller_config: {"num_units":[32], "layer_norm":True, "activation":'tanh', 'cell_type':'clstm', 'connect':'sparse', 'attention':False}
|
||||
memory_unit_config: {"cell_type":'cbmu', "memory_length":128, "memory_width":32, "read_heads":2, "write_heads": 1, "dnc_norm":True, "bypass_dropout":0.8}
|
||||
output_function: "softmax"
|
||||
output_mask: True
|
||||
loss_function: 'cross_entropy'
|
||||
|
||||
|
||||
|
||||
###################################################################
|
||||
####### bAbI QA Task ######
|
||||
###################################################################
|
||||
babi_task:
|
||||
data_set: 'babi'
|
||||
|
||||
load_test: True
|
||||
load_vocab: True
|
||||
|
||||
seed: 212
|
||||
valid_ratio: 0.1 # like nature paper
|
||||
batch_size: *batch_size
|
||||
max_len: 500
|
||||
|
||||
set_type: ['en-10k']
|
||||
task_selection: [1]
|
||||
|
||||
num_chached: 5
|
||||
threads: 1
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:892ab8c61d529470ccf7965cfc02f00799a69efced7a01501a900dc7ce904b4c
|
||||
size 443272
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a196248ae95aa28f1d719f1ad16f9831a14a0694f30c95a7eab24e159553a216
|
||||
size 2114
|
57
experiments/pre_trained/babi_task_1/dnc/config.yml
Normal file
57
experiments/pre_trained/babi_task_1/dnc/config.yml
Normal file
@ -0,0 +1,57 @@
|
||||
#######################################
|
||||
### Global Configuration ###
|
||||
#######################################
|
||||
global:
|
||||
batch_size: &batch_size 32
|
||||
|
||||
#######################################
|
||||
### Training Configuration ###
|
||||
#######################################
|
||||
training:
|
||||
epochs: 75
|
||||
learn_rate: 0.0001
|
||||
optimizer: 'rmsprop'
|
||||
optimizer_config: {'momentum':0.9}
|
||||
gradient_clipping: 10
|
||||
weight_decay: False
|
||||
|
||||
|
||||
|
||||
#######################################
|
||||
### MANN Configuration ###
|
||||
#######################################
|
||||
mann:
|
||||
name: 'mann1'
|
||||
seed: 148
|
||||
input_size: 0
|
||||
output_size: 0
|
||||
batch_size: *batch_size
|
||||
input_embedding: False
|
||||
architecture: 'uni'
|
||||
controller_config: {"num_units":[64], "layer_norm":True, "activation":'tanh', 'cell_type':'clstm', 'connect':'sparse', 'attention':False}
|
||||
memory_unit_config: {"cell_type":'dnc', "memory_length":128, "memory_width":32, "read_heads":2, "write_heads": 1, "dnc_norm":False, "bypass_dropout":False}
|
||||
output_function: "softmax"
|
||||
output_mask: True
|
||||
loss_function: 'cross_entropy'
|
||||
|
||||
|
||||
###################################################################
|
||||
####### bAbI QA Task ######
|
||||
###################################################################
|
||||
babi_task:
|
||||
data_set: 'babi'
|
||||
|
||||
load_test: True
|
||||
load_vocab: True
|
||||
|
||||
seed: 212
|
||||
valid_ratio: 0.1 # like nature paper
|
||||
batch_size: *batch_size
|
||||
max_len: 500
|
||||
|
||||
set_type: ['en-10k']
|
||||
task_selection: [1]
|
||||
|
||||
num_chached: 5
|
||||
threads: 1
|
||||
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dd7d2f452ce5248b102b78a9aad31af1bc49ee296c883d311dd629601904415b
|
||||
size 640552
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5ab09c717bfcd1aa920429444da94fc7bfd132b12ad3faa0d5b96e52b40a7f76
|
||||
size 1224
|
@ -0,0 +1,57 @@
|
||||
#######################################
|
||||
### Global Configuration ###
|
||||
#######################################
|
||||
|
||||
global:
|
||||
batch_size: &batch_size 32
|
||||
|
||||
#######################################
|
||||
### Training Configuration ###
|
||||
#######################################
|
||||
training:
|
||||
epochs: 100
|
||||
learn_rate: 0.00003
|
||||
optimizer: 'rmsprop' # rmsprop,
|
||||
optimizer_config: {'momentum':0.9}
|
||||
gradient_clipping: 10
|
||||
weight_decay: False
|
||||
|
||||
#######################################
|
||||
### MANN Configuration ###
|
||||
#######################################
|
||||
mann:
|
||||
name: 'mann1'
|
||||
seed: 486
|
||||
input_size: 0
|
||||
output_size: 0
|
||||
batch_size: *batch_size
|
||||
input_embedding: False
|
||||
architecture: 'bi'
|
||||
controller_config: {"num_units":[172], "layer_norm":True, "activation":'tanh', 'cell_type':'clstm', 'connect':'sparse', 'attention':False}
|
||||
memory_unit_config: {"cell_type":'cbmu', "memory_length":192, "memory_width":64, "read_heads":4, "write_heads": 1, "dnc_norm":True, "bypass_dropout":0.9, "wgate1":False}
|
||||
output_function: "softmax"
|
||||
output_mask: True
|
||||
loss_function: 'cross_entropy'
|
||||
bw_input_fw: False
|
||||
|
||||
|
||||
###################################################################
|
||||
####### bAbI QA Task ######
|
||||
###################################################################
|
||||
babi_task:
|
||||
data_set: 'babi'
|
||||
|
||||
load_test: True
|
||||
load_vocab: True
|
||||
|
||||
seed: 325
|
||||
valid_ratio: 0.1 # like nature paper
|
||||
batch_size: *batch_size
|
||||
max_len: 1000
|
||||
|
||||
set_type: ['en-10k']
|
||||
task_selection: ['all']
|
||||
augment16: True
|
||||
|
||||
num_chached: 5
|
||||
threads: 1
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1832374db1174765e6c568aa9374861de3cd8211e1dd877cb5177e5c69c6248d
|
||||
size 10694788
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54958b109d55105fdc6d6a1660b6645bd7ff65f4abee0f322c9572349fc03b5a
|
||||
size 2178
|
59
experiments/pre_trained/babi_task_all/bidanc/config.yml
Normal file
59
experiments/pre_trained/babi_task_all/bidanc/config.yml
Normal file
@ -0,0 +1,59 @@
|
||||
|
||||
#######################################
|
||||
### Global Configuration ###
|
||||
#######################################
|
||||
|
||||
global:
|
||||
batch_size: &batch_size 32
|
||||
|
||||
#######################################
|
||||
### Training Configuration ###
|
||||
#######################################
|
||||
|
||||
training:
|
||||
epochs: 50
|
||||
learn_rate: 0.00003
|
||||
optimizer: 'rmsprop' # rmsprop,
|
||||
optimizer_config: {'momentum':0.9}
|
||||
gradient_clipping: 10
|
||||
weight_decay: False
|
||||
|
||||
|
||||
#######################################
|
||||
### MANN Configuration ###
|
||||
#######################################
|
||||
mann:
|
||||
name: 'mann1'
|
||||
seed: 857
|
||||
input_size: 0
|
||||
output_size: 0
|
||||
batch_size: *batch_size
|
||||
input_embedding: False
|
||||
architecture: 'bi'
|
||||
controller_config: {"num_units":[172], "layer_norm":True, "activation":'tanh', 'cell_type':'clstm', 'connect':'sparse', 'attention':False}
|
||||
memory_unit_config: {"cell_type":'cbmu', "memory_length":128, "memory_width":64, "read_heads":4, "write_heads": 1, "dnc_norm":True, "bypass_dropout":0.9}
|
||||
output_function: "softmax"
|
||||
output_mask: True
|
||||
loss_function: 'cross_entropy'
|
||||
|
||||
###################################################################
|
||||
####### bAbI QA Task ######
|
||||
###################################################################
|
||||
babi_task:
|
||||
data_set: 'babi'
|
||||
|
||||
load_test: True
|
||||
load_vocab: True
|
||||
|
||||
seed: 211
|
||||
valid_ratio: 0.1 # like nature paper
|
||||
batch_size: *batch_size
|
||||
max_len: 500
|
||||
|
||||
set_type: ['en-10k']
|
||||
task_selection: ['all']
|
||||
|
||||
num_chached: 5
|
||||
threads: 1
|
||||
|
||||
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ed6300bfd665a00fb74537d88cbb7559e1f2378d00f552e1ed016bfc60d17c6d
|
||||
size 10694788
|
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:30495139bd63fa8c62b5e92789bbf2e268bbc32c8fc57aa86e8e17175a7c5484
|
||||
size 2178
|
115
inference_babi_task.py
Executable file
115
inference_babi_task.py
Executable file
@ -0,0 +1,115 @@
|
||||
# Copyright 2018 Jörg Franke
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"]="" #gpu not required for inference
|
||||
import yaml
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from adnc.data.loader import DataLoader
|
||||
from adnc.model.mann import MANN
|
||||
|
||||
|
||||
# Choose a pre trained model by uncomment
|
||||
|
||||
# expt_dir = "experiments/pre_trained/babi_task_1/dnc" # DNC trained on bAbI tasks 1
|
||||
# expt_dir = "experiments/pre_trained/babi_task_1/adnc" # ADNC trained on bAbI tasks 1
|
||||
# expt_dir = "experiments/pre_trained/babi_task_1/biadnc" # BiADNC trained on bAbI tasks 1
|
||||
# expt_dir = "experiments/pre_trained/babi_task_all/biadnc" # BiADNC trained on all bAbI tasks
|
||||
expt_dir = "experiments/pre_trained/babi_task_all/biadnc_aug16" # BiADNC trained on all bAbI tasks with task 16 augmentation
|
||||
|
||||
|
||||
config_file = 'config.yml'
|
||||
with open(os.path.join(expt_dir, config_file) , 'r') as f:
|
||||
configs = yaml.load(f) # load config from file
|
||||
|
||||
dataset_config = configs['babi_task']
|
||||
model_config = configs['mann']
|
||||
|
||||
dataset_config['batch_size'] = 1
|
||||
model_config['batch_size'] = 1
|
||||
|
||||
dataset_config['threads'] = 1 # only one thread for data loading
|
||||
dataset_config['max_len'] = 1921 # set max length to maximal
|
||||
dataset_config['augment16'] = False # disable augmentation for inference
|
||||
|
||||
if dataset_config['task_selection'] == ['all']:
|
||||
task_list = [i+1 for i in range(20)]
|
||||
else:
|
||||
task_list = [int(i) for i in dataset_config['task_selection']]
|
||||
|
||||
dl = DataLoader(dataset_config) # load data loader by config
|
||||
|
||||
model_config['input_size'] = dl.x_size # add data size to model config
|
||||
model_config['output_size'] = dl.y_size
|
||||
model_config['memory_unit_config']['bypass_dropout'] = False # no dropout during inference
|
||||
|
||||
model = MANN(model_config) # load memory augmented neural network model
|
||||
|
||||
data, target, mask = model.feed # create data feed for session run
|
||||
|
||||
word_dict = dl.dataset.word_dict # full dictionary of all tasks
|
||||
re_word_dict = dl.dataset.re_word_dict # reverse dictionary
|
||||
|
||||
print("vocabulary size: {}".format(dl.vocabulary_size))
|
||||
print("train set length: {}".format(dl.sample_amount('train')))
|
||||
print("valid set length: {}".format(dl.sample_amount('valid')))
|
||||
print("model parameter amount: {}".format(model.parameter_amount))
|
||||
|
||||
saver = tf.train.Saver()
|
||||
conf = tf.ConfigProto()
|
||||
conf.gpu_options.per_process_gpu_memory_fraction = 0.8
|
||||
conf.gpu_options.allocator_type = 'BFC'
|
||||
conf.gpu_options.allow_growth = True
|
||||
|
||||
with tf.Session(config=conf) as sess:
|
||||
|
||||
saver.restore(sess, os.path.join(expt_dir, "model_dump.ckpt"))
|
||||
mean_error = []
|
||||
for task in task_list:
|
||||
|
||||
# load data loader for task
|
||||
dataset_config['task_selection'] = [task]
|
||||
dl = DataLoader(dataset_config, word_dict, re_word_dict)
|
||||
valid_loader = dl.get_data_loader('test')
|
||||
|
||||
predictions, targets, masks = [], [], []
|
||||
all_corrects, all_overall = 0, 0
|
||||
|
||||
# infer model
|
||||
for v in range(int(dl.batch_amount('test'))):
|
||||
sample = next(valid_loader)
|
||||
prediction = sess.run([model.prediction, ],feed_dict={data: sample['x'],
|
||||
target: sample['y'],
|
||||
mask: sample['m']})
|
||||
predictions.append(prediction)
|
||||
targets.append(sample['y'])
|
||||
masks.append(sample['m'])
|
||||
|
||||
# calculate mean error rate for task
|
||||
for p, t, m in zip(predictions, targets, masks):
|
||||
tm = np.argmax(t, axis=-1)
|
||||
pm = np.argmax(p, axis=-1)
|
||||
corrects = np.equal(tm, pm)
|
||||
all_corrects += np.sum(corrects * m)
|
||||
all_overall += np.sum(m)
|
||||
|
||||
word_error_rate = 1- (all_corrects/all_overall)
|
||||
mean_error.append(word_error_rate)
|
||||
|
||||
print("word error rate task {:2}: {:0.3}".format(task, word_error_rate))
|
||||
print("mean word error rate : {:0.3}".format(np.mean(mean_error)))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user