Add files via upload
This commit is contained in:
parent
cafe5a84ad
commit
2d660e7ced
@ -0,0 +1,3 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES='' python run_dee_task.py --skip_train True $*
|
@ -0,0 +1,72 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
|
||||||
|
DATA_DIR=./Data
|
||||||
|
EXP_DIR=./Exps
|
||||||
|
COMMON_TASK_NAME=HelloEDAG
|
||||||
|
RESUME_TRAIN=True
|
||||||
|
SAVE_CPT=True
|
||||||
|
N_EPOCH=100
|
||||||
|
TRAIN_BS=64
|
||||||
|
EVAL_BS=2
|
||||||
|
NUM_GPUS=8
|
||||||
|
GRAD_ACC_STEP=8
|
||||||
|
# There is one parameter update for every GRAD_ACC_STEP back-propagation steps,
|
||||||
|
# so the real runtime batch size is TRAIN_BS / GRAD_ACC_STEP.
|
||||||
|
# In this way, we can achieve a large batch training with only a few GPUs
|
||||||
|
|
||||||
|
|
||||||
|
# Doc2EDAG Models: Doc2EDAG, GreedyDec
|
||||||
|
MODEL_TYPE=Doc2EDAG
|
||||||
|
MODEL_STR=Doc2EDAG
|
||||||
|
echo "---> ${MODEL_STR} Run"
|
||||||
|
./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \
|
||||||
|
--data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \
|
||||||
|
--train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \
|
||||||
|
--model_type ${MODEL_TYPE} --cpt_file_name ${MODEL_STR} --add_greedy_dec True
|
||||||
|
|
||||||
|
|
||||||
|
# DCFEE Baselines: DCFEE-O, DCFEE-M
|
||||||
|
MODEL_TYPE=DCFEE
|
||||||
|
MODEL_STR=DCFEE
|
||||||
|
echo "---> ${MODEL_STR} Run"
|
||||||
|
./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \
|
||||||
|
--data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \
|
||||||
|
--train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \
|
||||||
|
--model_type ${MODEL_TYPE} --cpt_file_name ${MODEL_STR}
|
||||||
|
|
||||||
|
|
||||||
|
# Ablation Tests of Doc2EDAG
|
||||||
|
MODEL_TYPE=Doc2EDAG
|
||||||
|
|
||||||
|
# Ablation Test 1
|
||||||
|
MODEL_STR=Doc2EDAG-NoPathMem
|
||||||
|
echo "---> ${MODEL_STR} Run"
|
||||||
|
./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \
|
||||||
|
--data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \
|
||||||
|
--train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \
|
||||||
|
--model_type ${MODEL_TYPE} --cpt_file_name ${MODEL_STR} --use_path_mem False
|
||||||
|
|
||||||
|
# Ablation Test 2
|
||||||
|
MODEL_STR=Doc2EDAG-NoScheduledSampling
|
||||||
|
echo "---> ${MODEL_STR} Run"
|
||||||
|
./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \
|
||||||
|
--data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \
|
||||||
|
--train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \
|
||||||
|
--model_type ${MODEL_TYPE} --cpt_file_name ${MODEL_STR} --use_scheduled_sampling False
|
||||||
|
|
||||||
|
# Ablation Test 3
|
||||||
|
MODEL_STR=Doc2EDAG-NoDocEnc
|
||||||
|
echo "---> ${MODEL_STR} Run"
|
||||||
|
./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \
|
||||||
|
--data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \
|
||||||
|
--train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \
|
||||||
|
--model_type ${MODEL_TYPE} --cpt_file_name ${MODEL_STR} --use_doc_enc False
|
||||||
|
|
||||||
|
# Ablation Test 4
|
||||||
|
MODEL_STR=Doc2EDAG-NoFPPenalty
|
||||||
|
echo "---> ${MODEL_STR} Run"
|
||||||
|
./train_multi.sh ${NUM_GPUS} --resume_latest_cpt ${RESUME_TRAIN} --save_cpt_flag ${SAVE_CPT} \
|
||||||
|
--data_dir ${DATA_DIR} --exp_dir ${EXP_DIR} --task_name ${COMMON_TASK_NAME} --num_train_epochs ${N_EPOCH} \
|
||||||
|
--train_batch_size ${TRAIN_BS} --gradient_accumulation_steps ${GRAD_ACC_STEP} --eval_batch_size ${EVAL_BS} \
|
||||||
|
--model_type ${MODEL_TYPE} --cpt_file_name ${MODEL_STR} --neg_field_loss_scaling 1.0
|
||||||
|
|
@ -0,0 +1,97 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# AUTHOR: Shun Zheng
|
||||||
|
# DATE: 19-9-19
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from dee.utils import set_basic_log_config, strtobool
|
||||||
|
from dee.dee_task import DEETask, DEETaskSetting
|
||||||
|
from dee.dee_helper import aggregate_task_eval_info, print_total_eval_info, print_single_vs_multi_performance
|
||||||
|
|
||||||
|
set_basic_log_config()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(in_args=None):
|
||||||
|
arg_parser = argparse.ArgumentParser()
|
||||||
|
arg_parser.add_argument('--task_name', type=str, required=True,
|
||||||
|
help='Take Name')
|
||||||
|
arg_parser.add_argument('--data_dir', type=str, default='./Data',
|
||||||
|
help='Data directory')
|
||||||
|
arg_parser.add_argument('--exp_dir', type=str, default='./Exps',
|
||||||
|
help='Experiment directory')
|
||||||
|
arg_parser.add_argument('--save_cpt_flag', type=strtobool, default=True,
|
||||||
|
help='Whether to save cpt for each epoch')
|
||||||
|
arg_parser.add_argument('--skip_train', type=strtobool, default=False,
|
||||||
|
help='Whether to skip training')
|
||||||
|
arg_parser.add_argument('--eval_model_names', type=str, default='DCFEE-O,DCFEE-M,GreedyDec,Doc2EDAG',
|
||||||
|
help="Models to be evaluated, seperated by ','")
|
||||||
|
arg_parser.add_argument('--re_eval_flag', type=strtobool, default=False,
|
||||||
|
help='Whether to re-evaluate previous predictions')
|
||||||
|
|
||||||
|
# add task setting arguments
|
||||||
|
for key, val in DEETaskSetting.base_attr_default_pairs:
|
||||||
|
if isinstance(val, bool):
|
||||||
|
arg_parser.add_argument('--' + key, type=strtobool, default=val)
|
||||||
|
else:
|
||||||
|
arg_parser.add_argument('--'+key, type=type(val), default=val)
|
||||||
|
|
||||||
|
arg_info = arg_parser.parse_args(args=in_args)
|
||||||
|
|
||||||
|
return arg_info
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
in_argv = parse_args()
|
||||||
|
|
||||||
|
task_dir = os.path.join(in_argv.exp_dir, in_argv.task_name)
|
||||||
|
if not os.path.exists(task_dir):
|
||||||
|
os.makedirs(task_dir, exist_ok=True)
|
||||||
|
|
||||||
|
in_argv.model_dir = os.path.join(task_dir, "Model")
|
||||||
|
in_argv.output_dir = os.path.join(task_dir, "Output")
|
||||||
|
|
||||||
|
# in_argv must contain 'data_dir', 'model_dir', 'output_dir'
|
||||||
|
dee_setting = DEETaskSetting(
|
||||||
|
**in_argv.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
# build task
|
||||||
|
dee_task = DEETask(dee_setting, load_train=not in_argv.skip_train)
|
||||||
|
|
||||||
|
if not in_argv.skip_train:
|
||||||
|
# dump hyper-parameter settings
|
||||||
|
if dee_task.is_master_node():
|
||||||
|
fn = '{}.task_setting.json'.format(dee_setting.cpt_file_name)
|
||||||
|
dee_setting.dump_to(task_dir, file_name=fn)
|
||||||
|
|
||||||
|
dee_task.train(save_cpt_flag=in_argv.save_cpt_flag)
|
||||||
|
else:
|
||||||
|
dee_task.logging('Skip training')
|
||||||
|
|
||||||
|
if dee_task.is_master_node():
|
||||||
|
if in_argv.re_eval_flag:
|
||||||
|
data_span_type2model_str2epoch_res_list = dee_task.reevaluate_dee_prediction(dump_flag=True)
|
||||||
|
else:
|
||||||
|
data_span_type2model_str2epoch_res_list = aggregate_task_eval_info(in_argv.output_dir, dump_flag=True)
|
||||||
|
data_type = 'test'
|
||||||
|
span_type = 'pred_span'
|
||||||
|
metric_type = 'micro'
|
||||||
|
mstr_bepoch_list = print_total_eval_info(
|
||||||
|
data_span_type2model_str2epoch_res_list, metric_type=metric_type, span_type=span_type,
|
||||||
|
model_strs=in_argv.eval_model_names.split(','),
|
||||||
|
target_set=data_type
|
||||||
|
)
|
||||||
|
print_single_vs_multi_performance(
|
||||||
|
mstr_bepoch_list, in_argv.output_dir, dee_task.test_features,
|
||||||
|
metric_type=metric_type, data_type=data_type, span_type=span_type
|
||||||
|
)
|
||||||
|
|
||||||
|
# ensure every processes exit at the same time
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user