Add files via upload

This commit is contained in:
missQian 2020-10-04 22:28:39 +08:00 committed by GitHub
parent cafe5a84ad
commit 2d660e7ced
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 172 additions and 0 deletions

View File

@ -0,0 +1,3 @@
#! /bin/bash
# Run run_dee_task.py with training skipped (--skip_train True).
# CUDA_VISIBLE_DEVICES is set to the empty string so no CUDA device is
# visible to the Python process.
# All arguments given to this script are forwarded unchanged; "$@" (rather
# than an unquoted $*) keeps arguments containing spaces or glob characters
# intact.
CUDA_VISIBLE_DEVICES='' python run_dee_task.py --skip_train True "$@"

View File

@ -0,0 +1,72 @@
#! /bin/bash
# Train every model variant one after another through ./train_multi.sh:
# the Doc2EDAG model, the DCFEE baselines, and four Doc2EDAG ablations.
# All runs share the settings below and differ only in model type,
# checkpoint name and one optional extra flag.

DATA_DIR=./Data
EXP_DIR=./Exps
COMMON_TASK_NAME=HelloEDAG
RESUME_TRAIN=True
SAVE_CPT=True
N_EPOCH=100
TRAIN_BS=64
EVAL_BS=2
NUM_GPUS=8
GRAD_ACC_STEP=8
# There is one parameter update for every GRAD_ACC_STEP back-propagation steps,
# so the real runtime batch size is TRAIN_BS / GRAD_ACC_STEP.
# In this way, we can achieve a large batch training with only a few GPUs

# Launch one training run with the shared settings above.
#   $1    model type, passed as --model_type
#   $2    model string, echoed and passed as --cpt_file_name
#   $3..  extra flags appended verbatim after the shared ones
run_model() {
    local model_type=$1
    local model_str=$2
    shift 2

    echo "---> ${model_str} Run"
    ./train_multi.sh "${NUM_GPUS}" --resume_latest_cpt "${RESUME_TRAIN}" --save_cpt_flag "${SAVE_CPT}" \
        --data_dir "${DATA_DIR}" --exp_dir "${EXP_DIR}" --task_name "${COMMON_TASK_NAME}" --num_train_epochs "${N_EPOCH}" \
        --train_batch_size "${TRAIN_BS}" --gradient_accumulation_steps "${GRAD_ACC_STEP}" --eval_batch_size "${EVAL_BS}" \
        --model_type "${model_type}" --cpt_file_name "${model_str}" "$@"
}

# Doc2EDAG Models: Doc2EDAG, GreedyDec
run_model Doc2EDAG Doc2EDAG --add_greedy_dec True

# DCFEE Baselines: DCFEE-O, DCFEE-M
run_model DCFEE DCFEE

# Ablation Tests of Doc2EDAG
# Ablation Test 1
run_model Doc2EDAG Doc2EDAG-NoPathMem --use_path_mem False
# Ablation Test 2
run_model Doc2EDAG Doc2EDAG-NoScheduledSampling --use_scheduled_sampling False
# Ablation Test 3
run_model Doc2EDAG Doc2EDAG-NoDocEnc --use_doc_enc False
# Ablation Test 4
run_model Doc2EDAG Doc2EDAG-NoFPPenalty --neg_field_loss_scaling 1.0

View File

@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
# AUTHOR: Shun Zheng
# DATE: 19-9-19
import argparse
import os
import torch.distributed as dist
from dee.utils import set_basic_log_config, strtobool
from dee.dee_task import DEETask, DEETaskSetting
from dee.dee_helper import aggregate_task_eval_info, print_total_eval_info, print_single_vs_multi_performance
set_basic_log_config()
def parse_args(in_args=None):
    """Parse the command-line options of the task runner.

    Besides the fixed options declared here, one ``--<key>`` option is
    registered for every ``(key, default)`` pair found in
    ``DEETaskSetting.base_attr_default_pairs``.

    Args:
        in_args: Optional list of argument strings. When ``None``, argparse
            falls back to reading ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding all parsed values.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--task_name', type=str, required=True,
                            help='Task Name')
    arg_parser.add_argument('--data_dir', type=str, default='./Data',
                            help='Data directory')
    arg_parser.add_argument('--exp_dir', type=str, default='./Exps',
                            help='Experiment directory')
    arg_parser.add_argument('--save_cpt_flag', type=strtobool, default=True,
                            help='Whether to save cpt for each epoch')
    arg_parser.add_argument('--skip_train', type=strtobool, default=False,
                            help='Whether to skip training')
    arg_parser.add_argument('--eval_model_names', type=str, default='DCFEE-O,DCFEE-M,GreedyDec,Doc2EDAG',
                            help="Models to be evaluated, separated by ','")
    arg_parser.add_argument('--re_eval_flag', type=strtobool, default=False,
                            help='Whether to re-evaluate previous predictions')

    # add task setting arguments
    for key, val in DEETaskSetting.base_attr_default_pairs:
        if isinstance(val, bool):
            # bool('False') is True, so boolean settings cannot use type=bool;
            # they go through strtobool like the boolean flags above.
            arg_parser.add_argument('--' + key, type=strtobool, default=val)
        else:
            # Convert the command-line string with the type of the default value.
            arg_parser.add_argument('--' + key, type=type(val), default=val)

    arg_info = arg_parser.parse_args(args=in_args)

    return arg_info
if __name__ == '__main__':
    # Entry point: parse options, build the task, train unless --skip_train
    # is set, then print evaluation results from the master node.
    in_argv = parse_args()

    task_dir = os.path.join(in_argv.exp_dir, in_argv.task_name)
    # exist_ok=True already tolerates an existing directory, so no separate
    # os.path.exists() pre-check is needed.
    os.makedirs(task_dir, exist_ok=True)

    in_argv.model_dir = os.path.join(task_dir, "Model")
    in_argv.output_dir = os.path.join(task_dir, "Output")

    # in_argv must contain 'data_dir', 'model_dir', 'output_dir'
    dee_setting = DEETaskSetting(
        **in_argv.__dict__
    )

    # build task
    dee_task = DEETask(dee_setting, load_train=not in_argv.skip_train)

    if not in_argv.skip_train:
        # dump hyper-parameter settings
        if dee_task.is_master_node():
            fn = '{}.task_setting.json'.format(dee_setting.cpt_file_name)
            dee_setting.dump_to(task_dir, file_name=fn)

        dee_task.train(save_cpt_flag=in_argv.save_cpt_flag)
    else:
        dee_task.logging('Skip training')

    # Only the master node aggregates and prints the evaluation results.
    if dee_task.is_master_node():
        if in_argv.re_eval_flag:
            data_span_type2model_str2epoch_res_list = dee_task.reevaluate_dee_prediction(dump_flag=True)
        else:
            data_span_type2model_str2epoch_res_list = aggregate_task_eval_info(in_argv.output_dir, dump_flag=True)

        data_type = 'test'
        span_type = 'pred_span'
        metric_type = 'micro'
        # Strip each entry so that a list written as "A, B" still yields the
        # bare model names.
        eval_model_names = [name.strip() for name in in_argv.eval_model_names.split(',')]
        mstr_bepoch_list = print_total_eval_info(
            data_span_type2model_str2epoch_res_list, metric_type=metric_type, span_type=span_type,
            model_strs=eval_model_names,
            target_set=data_type
        )
        print_single_vs_multi_performance(
            mstr_bepoch_list, in_argv.output_dir, dee_task.test_features,
            metric_type=metric_type, data_type=data_type, span_type=span_type
        )

    # Ensure all processes exit at the same time. This stays outside the
    # master-only branch so that every process reaches the barrier.
    if dist.is_initialized():
        dist.barrier()