DeepNER/main.py

219 lines
7.4 KiB
Python
Raw Permalink Normal View History

2020-12-23 19:17:20 +08:00
import time
import os
import json
import logging
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold
from src.utils.trainer import train
from src.utils.options import Args
from src.utils.model_utils import build_model
from src.utils.dataset_utils import NERDataset
from src.utils.evaluator import crf_evaluation, span_evaluation, mrc_evaluation
from src.utils.functions_utils import set_seed, get_model_path_list, load_model_and_parallel, get_time_dif
from src.preprocess.processor import NERProcessor, convert_examples_to_features
# Configure the root logger once for the whole process: timestamp, level,
# emitting logger name and the message, at INFO verbosity.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def train_base(opt, train_examples, dev_examples=None):
    """Train one model and, if a dev set is given, keep only its best checkpoint.

    The entity-to-id mapping is loaded from
    ``<opt.mid_data_dir>/<opt.task_type>_ent2id.json``. A model matching
    ``opt.task_type`` is built ('crf', 'mrc'; any other value is handled as
    'span') and trained on ``train_examples``.

    When ``dev_examples`` is provided, every checkpoint returned by
    ``get_model_path_list(opt.output_dir)`` is evaluated. The per-checkpoint
    metrics are appended to ``<opt.output_dir>/eval_metric.txt``, the path of
    the best-scoring checkpoint is appended to ``./best_ckpt_path.txt``, and
    the directories of all other checkpoints are removed from disk.

    Args:
        opt: Parsed run options. Attributes read here: ``mid_data_dir``,
            ``task_type``, ``max_seq_len``, ``bert_dir``, ``use_type_embed``,
            ``dropout_prob``, ``loss_type``, ``eval_batch_size``,
            ``output_dir`` and ``gpu_ids``.
        train_examples: Training examples handed to
            ``convert_examples_to_features``.
        dev_examples: Optional dev examples. ``None`` skips evaluation and
            checkpoint pruning entirely.
    """
    # Kept as a function-local import, as in the original implementation.
    import shutil

    ent2id_path = os.path.join(opt.mid_data_dir, f'{opt.task_type}_ent2id.json')
    with open(ent2id_path, encoding='utf-8') as f:
        ent2id = json.load(f)

    # Only the features (first element) are needed for training.
    train_features = convert_examples_to_features(opt.task_type, train_examples,
                                                  opt.max_seq_len, opt.bert_dir, ent2id)[0]
    train_dataset = NERDataset(opt.task_type, train_features, 'train',
                               use_type_embed=opt.use_type_embed)

    if opt.task_type == 'crf':
        model = build_model('crf', opt.bert_dir, num_tags=len(ent2id),
                            dropout_prob=opt.dropout_prob)
    elif opt.task_type == 'mrc':
        model = build_model('mrc', opt.bert_dir,
                            dropout_prob=opt.dropout_prob,
                            use_type_embed=opt.use_type_embed,
                            loss_type=opt.loss_type)
    else:
        # The span model gets one tag more than there are entity types.
        model = build_model('span', opt.bert_dir, num_tags=len(ent2id) + 1,
                            dropout_prob=opt.dropout_prob,
                            loss_type=opt.loss_type)

    train(opt, model, train_dataset)

    if dev_examples is None:
        return

    dev_features, dev_callback_info = convert_examples_to_features(opt.task_type, dev_examples,
                                                                   opt.max_seq_len, opt.bert_dir, ent2id)
    dev_dataset = NERDataset(opt.task_type, dev_features, 'dev',
                             use_type_embed=opt.use_type_embed)
    dev_loader = DataLoader(dev_dataset, batch_size=opt.eval_batch_size,
                            shuffle=False, num_workers=0)
    dev_info = (dev_loader, dev_callback_info)

    model_path_list = get_model_path_list(opt.output_dir)
    if not model_path_list:
        # Nothing to evaluate: do not record an empty "best checkpoint" line.
        logger.warning('No checkpoints found under %s, evaluation skipped', opt.output_dir)
        return

    metric_chunks = []
    max_f1 = 0.
    max_f1_step = 0
    max_f1_path = ''

    for model_path in model_path_list:
        # The step is the text after the last '-' in the checkpoint directory
        # name (presumably directories look like "checkpoint-<step>").
        ckpt_dir_name = os.path.basename(os.path.dirname(model_path))
        tmp_step = ckpt_dir_name.split('-')[-1]

        model, device = load_model_and_parallel(model, opt.gpu_ids[0],
                                                ckpt_path=model_path)

        if opt.task_type == 'crf':
            tmp_metric_str, tmp_f1 = crf_evaluation(model, dev_info, device, ent2id)
        elif opt.task_type == 'mrc':
            tmp_metric_str, tmp_f1 = mrc_evaluation(model, dev_info, device)
        else:
            tmp_metric_str, tmp_f1 = span_evaluation(model, dev_info, device, ent2id)

        step_report = f'In step {tmp_step}:\n {tmp_metric_str}'
        logger.info(step_report)
        metric_chunks.append(step_report + '\n\n')

        # The first checkpoint always becomes the initial best. Otherwise a run
        # where every F1 is 0 would select no checkpoint and delete all of them.
        if not max_f1_path or tmp_f1 > max_f1:
            max_f1 = tmp_f1
            max_f1_step = tmp_step
            max_f1_path = model_path

    max_metric_str = f'Max f1 is: {max_f1}, in step {max_f1_step}'
    logger.info(max_metric_str)
    metric_chunks.append(max_metric_str + '\n')

    eval_save_path = os.path.join(opt.output_dir, 'eval_metric.txt')
    with open(eval_save_path, 'a', encoding='utf-8') as f1:
        f1.write(''.join(metric_chunks))

    with open('./best_ckpt_path.txt', 'a', encoding='utf-8') as f2:
        f2.write(max_f1_path + '\n')

    # Remove the directory of every checkpoint except the best one.
    del_dir_list = [os.path.join(opt.output_dir, os.path.basename(os.path.dirname(path)))
                    for path in model_path_list if path != max_f1_path]

    for x in del_dir_list:
        shutil.rmtree(x)
        logger.info('{}已删除'.format(x))
def training(opt):
    """Run a single training pass on ``train.json`` plus ``pseudo.json``.

    Raw examples are read from ``opt.raw_data_dir``. The pseudo-labelled
    examples are appended to the training examples. When ``opt.eval_model``
    is truthy, ``dev.json`` is loaded as well and passed to ``train_base``
    for evaluation.

    Args:
        opt: Parsed run options. Attributes read here: ``task_type``,
            ``max_seq_len``, ``raw_data_dir`` and ``eval_model``.
    """
    # Use the options that were passed in, not the script-level ``args``
    # global, so the function also works when imported from another module.
    if opt.task_type == 'mrc':
        # 62 for mrc query
        processor = NERProcessor(opt.max_seq_len - 62)
    else:
        processor = NERProcessor(opt.max_seq_len)

    train_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'train.json'))

    # add pseudo data to train data
    pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
    train_raw_examples = train_raw_examples + pseudo_raw_examples

    train_examples = processor.get_examples(train_raw_examples, 'train')

    dev_examples = None
    if opt.eval_model:
        dev_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'dev.json'))
        dev_examples = processor.get_examples(dev_raw_examples, 'dev')

    train_base(opt, train_examples, dev_examples)
def stacking(opt):
    """Train one model per fold of a 5-fold split of ``stack.json``.

    The examples in ``<opt.raw_data_dir>/stack.json`` are split with
    ``KFold(5, shuffle=True, random_state=42)``. For each fold, the training
    part is extended with the examples from ``pseudo.json``, the held-out
    part is used as the dev set, and ``train_base`` is run with
    ``opt.output_dir`` pointing at ``<original output_dir>/v<fold index>``.

    ``opt.output_dir`` is restored to its original value before returning.

    Args:
        opt: Parsed run options. Attributes read here: ``task_type``,
            ``max_seq_len``, ``raw_data_dir`` and ``output_dir``.
    """
    logger.info('Start to KFold stack attribution model')

    # Use the options that were passed in, not the script-level ``args``
    # global, so the function also works when imported from another module.
    if opt.task_type == 'mrc':
        # 62 for mrc query
        processor = NERProcessor(opt.max_seq_len - 62)
    else:
        processor = NERProcessor(opt.max_seq_len)

    kf = KFold(5, shuffle=True, random_state=42)

    stack_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'stack.json'))
    pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))

    base_output_dir = opt.output_dir

    try:
        for i, (train_ids, dev_ids) in enumerate(kf.split(stack_raw_examples)):
            logger.info(f'Start to train the {i} fold')

            train_raw_examples = [stack_raw_examples[_idx] for _idx in train_ids]
            # add pseudo data to train data
            train_raw_examples = train_raw_examples + pseudo_raw_examples
            train_examples = processor.get_examples(train_raw_examples, 'train')

            dev_raw_examples = [stack_raw_examples[_idx] for _idx in dev_ids]
            dev_examples = processor.get_examples(dev_raw_examples, 'dev')

            # Each fold writes its checkpoints into its own sub-directory.
            opt.output_dir = os.path.join(base_output_dir, f'v{i}')

            train_base(opt, train_examples, dev_examples)
    finally:
        # Do not leave the caller's options pointing at the last fold.
        opt.output_dir = base_output_dir
if __name__ == '__main__':
    # Script entry point: parse options, derive the run-specific output
    # directory from them, then dispatch to plain training or K-fold stacking.
    start_time = time.time()
    logging.info('----------------开始计时----------------')
    logging.info('----------------------------------------')

    args = Args().get_parser()

    assert args.mode in ['train', 'stack'], 'mode mismatch'
    assert args.task_type in ['crf', 'span', 'mrc']

    args.output_dir = os.path.join(args.output_dir, args.bert_type)

    set_seed(args.seed)

    # Every enabled option contributes one suffix to the directory name, in
    # this fixed order, so different configurations never share a directory.
    suffixes = []
    if args.attack_train != '':
        suffixes.append(f'_{args.attack_train}')
    if args.weight_decay:
        suffixes.append('_wd')
    if args.use_fp16:
        suffixes.append('_fp16')
    if args.task_type == 'mrc' and args.use_type_embed:
        suffixes.append('_embed')
    if args.task_type in ('span', 'mrc'):
        # Only the span and mrc tasks record their loss type.
        suffixes.append(f'_{args.loss_type}')
    suffixes.append(f'_{args.task_type}')
    if args.mode == 'stack':
        suffixes.append('_stack')
    args.output_dir += ''.join(suffixes)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    logger.info(f'{args.mode} {args.task_type} in max_seq_len {args.max_seq_len}')

    # The mode was validated above, so anything other than 'train' is 'stack'.
    run = training if args.mode == 'train' else stacking
    run(args)

    time_dif = get_time_dif(start_time)
    logging.info("----------本次容器运行时长:{}-----------".format(time_dif))