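"""Entry point for NER model training: builds and trains a CRF, span, or MRC model,
optionally evaluates the saved checkpoints on a dev set, and supports 5-fold stacking."""
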
import time
import os
import json
import shutil
import logging

from torch.utils.data import DataLoader
from sklearn.model_selection import KFold

from src.utils.trainer import train
from src.utils.options import Args
from src.utils.model_utils import build_model
from src.utils.dataset_utils import NERDataset
from src.utils.evaluator import crf_evaluation, span_evaluation, mrc_evaluation
from src.utils.functions_utils import set_seed, get_model_path_list, load_model_and_parallel, get_time_dif
from src.preprocess.processor import NERProcessor, convert_examples_to_features

logger = logging.getLogger(__name__)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO
)


def train_base(opt, train_examples, dev_examples=None):
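    """Build the model for opt.task_type, train it on train_examples and, if
    dev_examples are given, evaluate every saved checkpoint to keep the best one."""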
    with open(os.path.join(opt.mid_data_dir, f'{opt.task_type}_ent2id.json'), encoding='utf-8') as f:
        ent2id = json.load(f)

    train_features = convert_examples_to_features(opt.task_type, train_examples,
                                                  opt.max_seq_len, opt.bert_dir, ent2id)[0]

    train_dataset = NERDataset(opt.task_type, train_features, 'train', use_type_embed=opt.use_type_embed)

    if opt.task_type == 'crf':
        model = build_model('crf', opt.bert_dir, num_tags=len(ent2id),
                            dropout_prob=opt.dropout_prob)
    elif opt.task_type == 'mrc':
        model = build_model('mrc', opt.bert_dir,
                            dropout_prob=opt.dropout_prob,
                            use_type_embed=opt.use_type_embed,
                            loss_type=opt.loss_type)
    else:
        model = build_model('span', opt.bert_dir, num_tags=len(ent2id) + 1,
                            dropout_prob=opt.dropout_prob,
                            loss_type=opt.loss_type)

    train(opt, model, train_dataset)

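    # when dev data is supplied, evaluate each checkpoint saved during training
    # and keep only the best-scoring one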
    if dev_examples is not None:
        dev_features, dev_callback_info = convert_examples_to_features(opt.task_type, dev_examples,
                                                                        opt.max_seq_len, opt.bert_dir, ent2id)

        dev_dataset = NERDataset(opt.task_type, dev_features, 'dev', use_type_embed=opt.use_type_embed)

        dev_loader = DataLoader(dev_dataset, batch_size=opt.eval_batch_size,
                                shuffle=False, num_workers=0)

        dev_info = (dev_loader, dev_callback_info)

        model_path_list = get_model_path_list(opt.output_dir)

        metric_str = ''

        max_f1 = 0.
        max_f1_step = 0
        max_f1_path = ''

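        # reload each saved checkpoint and score it with the task-specific evaluator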
        for model_path in model_path_list:
            tmp_step = model_path.split('/')[-2].split('-')[-1]

            model, device = load_model_and_parallel(model, opt.gpu_ids[0],
                                                    ckpt_path=model_path)

            if opt.task_type == 'crf':
                tmp_metric_str, tmp_f1 = crf_evaluation(model, dev_info, device, ent2id)
            elif opt.task_type == 'mrc':
                tmp_metric_str, tmp_f1 = mrc_evaluation(model, dev_info, device)
            else:
                tmp_metric_str, tmp_f1 = span_evaluation(model, dev_info, device, ent2id)

            logger.info(f'In step {tmp_step}:\n {tmp_metric_str}')

            metric_str += f'In step {tmp_step}:\n {tmp_metric_str}' + '\n\n'

            if tmp_f1 > max_f1:
                max_f1 = tmp_f1
                max_f1_step = tmp_step
                max_f1_path = model_path

        max_metric_str = f'Max f1 is: {max_f1}, in step {max_f1_step}'

        logger.info(max_metric_str)

        metric_str += max_metric_str + '\n'

        eval_save_path = os.path.join(opt.output_dir, 'eval_metric.txt')

        with open(eval_save_path, 'a', encoding='utf-8') as f_metric:
            f_metric.write(metric_str)

        with open('./best_ckpt_path.txt', 'a', encoding='utf-8') as f_ckpt:
            f_ckpt.write(max_f1_path + '\n')

        # delete every checkpoint directory except the best-scoring one
        del_dir_list = [os.path.join(opt.output_dir, path.split('/')[-2])
                        for path in model_path_list if path != max_f1_path]

        for x in del_dir_list:
            shutil.rmtree(x)
            logger.info('{} has been deleted'.format(x))


def training(opt):
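    """Train one model on train.json plus pseudo.json; if opt.eval_model is set,
    also evaluate the saved checkpoints on dev.json."""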
    if opt.task_type == 'mrc':
        # reserve 62 tokens for the mrc query
        processor = NERProcessor(opt.max_seq_len - 62)
    else:
        processor = NERProcessor(opt.max_seq_len)

    train_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'train.json'))

    # add pseudo data to train data
    pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
    train_raw_examples = train_raw_examples + pseudo_raw_examples

    train_examples = processor.get_examples(train_raw_examples, 'train')

    dev_examples = None
    if opt.eval_model:
        dev_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'dev.json'))
        dev_examples = processor.get_examples(dev_raw_examples, 'dev')

    train_base(opt, train_examples, dev_examples)


def stacking(opt):
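    """Run 5-fold stacking: each fold trains on four fifths of stack.json (plus the
    pseudo data), is evaluated on the held-out fold, and writes its checkpoints to
    a per-fold subdirectory."""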
    logger.info('Start KFold stacking for the attribution model')

    if opt.task_type == 'mrc':
        # reserve 62 tokens for the mrc query
        processor = NERProcessor(opt.max_seq_len - 62)
    else:
        processor = NERProcessor(opt.max_seq_len)

    kf = KFold(5, shuffle=True, random_state=42)

    stack_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'stack.json'))
    pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))

    base_output_dir = opt.output_dir

    for i, (train_ids, dev_ids) in enumerate(kf.split(stack_raw_examples)):
        logger.info(f'Start to train fold {i}')
        train_raw_examples = [stack_raw_examples[_idx] for _idx in train_ids]

        # add pseudo data to train data
        train_raw_examples = train_raw_examples + pseudo_raw_examples
        train_examples = processor.get_examples(train_raw_examples, 'train')

        dev_raw_examples = [stack_raw_examples[_idx] for _idx in dev_ids]
        dev_examples = processor.get_examples(dev_raw_examples, 'dev')

        # give each fold its own output directory
        tmp_output_dir = os.path.join(base_output_dir, f'v{i}')
        opt.output_dir = tmp_output_dir

        train_base(opt, train_examples, dev_examples)


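# Hypothetical invocation sketch (the flag names are assumed to mirror the attribute
# names defined by Args in src/utils/options.py; check that file for the real parser):
#   python <this script> --mode train --task_type crf
#   python <this script> --mode stack --task_type span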
if __name__ == '__main__':
    start_time = time.time()
    logger.info('---------------- start timing ----------------')
    logger.info('-----------------------------------------------')

    args = Args().get_parser()

    assert args.mode in ['train', 'stack'], 'mode mismatch'
    assert args.task_type in ['crf', 'span', 'mrc']

    args.output_dir = os.path.join(args.output_dir, args.bert_type)

    set_seed(args.seed)

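    # encode the chosen training options in the output directory name so that
    # different configurations do not overwrite each other's checkpoints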
    if args.attack_train != '':
        args.output_dir += f'_{args.attack_train}'

    if args.weight_decay:
        args.output_dir += '_wd'

    if args.use_fp16:
        args.output_dir += '_fp16'

    if args.task_type == 'span':
        args.output_dir += f'_{args.loss_type}'

    if args.task_type == 'mrc':
        if args.use_type_embed:
            args.output_dir += '_embed'
        args.output_dir += f'_{args.loss_type}'

    args.output_dir += f'_{args.task_type}'

    if args.mode == 'stack':
        args.output_dir += '_stack'

    os.makedirs(args.output_dir, exist_ok=True)

    logger.info(f'{args.mode} {args.task_type} in max_seq_len {args.max_seq_len}')

    if args.mode == 'train':
        training(args)
    else:
        stacking(args)

    time_dif = get_time_dif(start_time)
    logger.info('---------- container run time: {} ----------'.format(time_dif))