# Copyright 2018 Jörg Franke
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

#######################################
###      Global Configuration      ###
#######################################

global:
  batch_size: &batch_size 32
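  # Note: &batch_size is a YAML anchor; the *batch_size aliases in the mann,
  # babi_task and copy_task sections below reuse this value when the file is
  # parsed, so the batch size only needs to be set here.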

#######################################
###     Training Configuration     ###
#######################################

training:
  epochs: 50                            # epochs to train
  learn_rate: 0.00005                   # learning rate for the optimizer
  optimizer: 'rmsprop'                  # optimizer [rmsprop, adam, momentum, adadelta, adagrad, sgd]
  optimizer_config: {'momentum': 0.9}   # config for optimizer [momentum, nesterov]
  gradient_clipping: 10                 # gradient clipping value
  weight_decay: False                   # weight decay, False or float
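
  # Minimal sketch of switching optimizers, assuming the keys named in the
  # comment above ('momentum', 'nesterov') are what the trainer accepts; the
  # exact options depend on the training code, so treat this as illustrative:
  #
  #   optimizer: 'momentum'
  #   optimizer_config: {'momentum': 0.9, 'nesterov': True}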

#######################################
###       MANN Configuration       ###
#######################################

mann:
  name: 'mann1'
  seed: 245
  input_size: 0
  output_size: 0
  batch_size: *batch_size
  input_embedding: False
  architecture: 'uni'               # bidirectional 172 384
  controller_config: {"num_units": [128], "layer_norm": True, "activation": 'tanh', 'cell_type': 'clstm', 'connect': 'sparse'}
  memory_unit_config: {"cell_type": 'cbmu', "memory_length": 64, "memory_width": 32, "read_heads": 4, "write_heads": 2, "dnc_norm": True, "bypass_dropout": False, "wgate1": False}
  atop_rnn_config: False            # {"num_units": [32], "layer_norm": True, "activation": 'tanh', 'cell_type': 'clstm', 'connect': 'sparse', 'attention': False}
  output_function: "softmax"        # softmax, tanh5, linear
  output_mask: True
  loss_function: 'cross_entropy'    # cross_entropy, mse
  bw_input_fw: False
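
  # Reading of the two nested configs above, inferred from the field names
  # rather than from project documentation (treat it as an assumption):
  # controller_config describes the controller network (here one layer-
  # normalized 'clstm' layer with 128 units), while memory_unit_config
  # describes the external memory, a matrix of memory_length x memory_width
  # = 64 x 32 cells accessed by 4 read heads and 2 write heads.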

###################################################################
######                    bAbI QA Task                       ######
###################################################################

babi_task:
  data_set: 'babi'

#  data_dir: 'data_babi/tasks_1-20_v1-2'
#  tmp_dir: 'data_dir'

  seed: 876
  valid_ratio: 0.1                  # validation split, as in the Nature DNC paper
  batch_size: *batch_size
  max_len: 1000

  set_type: ['en-10k']              # ['hn-10k', 'en-10k', 'shuffled-10k']
  task_selection: ['1', '2', '12']  # list of task numbers (1-20) or 'all'
  augment16: False                  # augmentation of task 16

  num_chached: 5                    # number of cached samples
  threads: 1                        # number of parallel threads
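
  # To train on all 20 tasks, the comment above suggests passing 'all' instead
  # of an explicit list; the exact accepted form ('all' vs. ['all']) depends on
  # the data-loader code, so this line is illustrative only:
  #
  #   task_selection: 'all'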

##################################################################
######                     Copy Task                        ######
##################################################################

copy_task:
  data_set: 'copy_task'

  seed: 125
  batch_size: *batch_size

  set_list:
    train:
      quantity: 6000        # quantity of the training set
      min_length: 20        # min length of a training sample
      max_length: 50        # max length of a training sample
    valid:
      quantity: 600         # quantity of the validation set
      min_length: 50
      max_length: 100
#    test:
#      quantity: 100
#      min_length: 10
#      max_length: 10

  feature_width: 100        # width of the feature vector

  num_chached: 10           # number of cached samples
  threads: 1                # number of parallel threads
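
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, with an assumed file name 'config.yml'): a
# config like this can be loaded with PyYAML, and the &batch_size anchor
# defined under 'global' is resolved automatically:
#
#   import yaml
#
#   with open('config.yml') as f:
#       cfg = yaml.safe_load(f)
#
#   assert cfg['mann']['batch_size'] == cfg['global']['batch_size']  # both 32
#   print(cfg['training']['optimizer'])                              # rmsprop
# ---------------------------------------------------------------------------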