first commit

zxx 2020-12-23 19:17:20 +08:00
commit 465eb0ace3
43 changed files with 436547 additions and 0 deletions

230
README.md Normal file

@ -0,0 +1,230 @@
# Chinese-DeepNER-Pytorch
## Open-sourced first-place solution of the Tianchi TCM Drug Instruction NER Challenge
### Once the official dataset is released, this project will be refactored into the DeepNER project: less code redundancy, better readability, a complete pipeline for data processing, training, validation, testing, and deployment, detailed code comments and model descriptions, and a more general pretrained-model-based Chinese NER solution that works out of the box. Stars and issues are welcome!
The code is built on pytorch and transformers and supports the latest pytorch releases. The framework is highly **reusable, decoupled, and readable**, and can easily be adapted to other NLP tasks.
# Competition Background
## Task Description
Artificial intelligence is accelerating the inheritance and innovation of traditional Chinese medicine (TCM). Information extraction from TCM texts is the core of building a TCM knowledge graph, which in turn underpins downstream applications such as clinical decision support systems (CDSS). This NER challenge requires extracting the key information from TCM drug instructions, covering 13 entity types including drugs, drug ingredients, diseases, symptoms, and syndromes, to build a TCM drug knowledge base.
## Exploratory Data Analysis
The training data has three notable characteristics:
- Most drug instructions are long documents
<img src="md_files/1.png" style="zoom:50%;" />
- Annotated samples are scarce in the medical domain
<img src="md_files/2.png" style="zoom:50%;" />
- The label distribution is imbalanced
<img src="md_files/3.png" style="zoom:50%;" />
# Core Approach
## Data Preprocessing
The instruction texts are first cleaned and then split. Cleaning removes invalid characters. To handle long documents, a two-level sentence-splitting strategy is used; because the resulting sentences can be too short, they are merged back together so that each merged chunk stays within the configured maximum length. The pipeline is illustrated below (a usage sketch of the split-and-merge helper follows the flowchart):
<img src="md_files/5.png" style="zoom: 33%;" />
In addition, an entity knowledge base is built from all annotated data and used as a domain prior dictionary.
## Baseline: BERT-CRF
<img src="md_files/6.png" style="zoom: 33%;" />
- Baseline details
  - Pretrained model: UER-large (24 layers) [1]; UER continues pretraining under the RoBERTa-wwm framework on large-scale, high-quality Chinese corpora and ranked first among single models on the CLUE tasks
  - Differential learning rates: 2e-5 for the BERT layers, 2e-3 for the other layers (see the sketch at the end of this section)
  - Parameter initialization: the non-BERT modules are initialized the same way as BERT
  - Stochastic weight averaging: averaging the weights of the last few epochs yields a smoother and better-performing model
- Baseline bad-case analysis
<img src="md_files/7.png" style="zoom: 33%;" />
## Optimization 1: Adversarial Training
- Motivation: adversarial training mitigates the model's lack of robustness and improves generalization (a training-loop sketch follows this list)
- Adversarial training injects noise during training, which regularizes the parameters and improves robustness and generalization
- Fast Gradient Method (FGM): adds a perturbation to the embedding layer along the gradient direction
- Projected Gradient Descent (PGD) [2]: applies iterative perturbations, each projected back into an allowed range
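A minimal sketch of how FGM can be folded into the training loop, using the `FGM` class from `src/utils/attack_train_utils.py`; it assumes `model`, `loader`, and `optimizer` already exist and that the model returns the loss as its first output when labels are provided:

```python
from src.utils.attack_train_utils import FGM

fgm = FGM(model, eps=1.0)
for batch in loader:
    loss = model(**batch)[0]
    loss.backward()        # gradients on the clean inputs
    fgm.attack()           # perturb the word embeddings along the gradient direction
    loss_adv = model(**batch)[0]
    loss_adv.backward()    # accumulate gradients on the perturbed inputs
    fgm.restore()          # restore the original embeddings
    optimizer.step()
    optimizer.zero_grad()
```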
## Optimization 2: Mixed-Precision Training (FP16)
- Motivation: adversarial training slows training down, so mixed precision is used to reduce the training time (see the sketch after this list)
- Mixed-precision training
  - Store tensors and run multiplications in FP16 to speed things up
  - Accumulate in FP32 to avoid rounding errors
- Loss scaling
  - Scale the loss by 2^k before back-propagation to prevent loss underflow
  - Un-scale the weight gradients after back-propagation
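A minimal sketch of mixed-precision training with dynamic loss scaling via `torch.cuda.amp` (available from the pinned pytorch 1.6.0 onward); whether the repository uses this backend or another FP16 implementation is an assumption:

```python
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # maintains the dynamic loss-scale factor (the 2^k above)
for batch in loader:
    optimizer.zero_grad()
    with autocast():               # FP16 storage / matmuls, FP32 accumulation
        loss = model(**batch)[0]
    scaler.scale(loss).backward()  # back-propagate the scaled loss
    scaler.step(optimizer)         # un-scales gradients, skips the step on inf/nan
    scaler.update()
```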
## Optimization 3: Multi-Model Ensemble
- Motivation: the baseline's errors are concentrated on ambiguous cases, so a multi-level medical NER ensemble is used to resolve the ambiguity
- Method: an ensemble of deliberately diversified models
  - Diversified model architectures: BERT-CRF & BERT-SPAN & BERT-MRC
  - Diversified training data: different random seeds, different sentence-split lengths (256, 512)
- Multi-level ensemble
  - Ensemble model 1 — BERT-SPAN
    - Replaces the CRF module with SPAN pointers to speed up training
    - Uses a half-pointer, half-tagging structure that predicts entity start/end positions and assigns the entity type during tagging
    - Uses strict decoding (for overlapping entities, keep the one with the largest logit) to favour precision
    - Uses label smoothing to mitigate overfitting
<img src="md_files/10.png" style="zoom:33%;" />
  - Ensemble model 2 — BERT-MRC
    - Treats NER as a machine reading comprehension task
    - query: a textual description of the entity type
    - doc: the split original text
    - One sample is built per entity type; training generates many negative samples, so a random 30% of them are kept and the rest dropped for efficiency
    - At prediction time a sample is built for every entity type, and the decoded output is not restricted, to favour recall
    - Uses label smoothing to mitigate overfitting
    - MRC performs poorly on this dataset and is slow to train and run; it is kept only as a recall booster, and the code is provided for learning purposes rather than everyday use
<img src="md_files/11.png" style="zoom:33%;" />
- Multi-level fusion strategy (an entity-voting sketch follows the figure below)
  - Level 1, probability fusion: for each of CRF/SPAN/MRC, average the logits of the 5-fold cross-validation models and decode entities from the averaged logits
  - Level 2, vote fusion: the probability-fused CRF/SPAN/MRC models vote to produce the final result
<img src="md_files/12.png" style="zoom:33%;" />
## Optimization 4: Semi-Supervised Learning
- Motivation: to mitigate the scarcity of annotated medical data, semi-supervised learning (pseudo labels) is used to exploit the 500 unlabeled preliminary-round test documents
- Strategy: dynamic pseudo labels (a loss-weighting sketch follows this section)
  - First train a base model M on the original annotated data
  - Use M to predict the preliminary test set and obtain pseudo labels
  - Add the pseudo-labeled data to the training set with a dynamic, learnable weight (alpha in the figure) and train it together with the real labels to obtain the final model
<img src="md_files/13.png" style="zoom: 25%;" />
- Tips: use the ensembled base model to reduce pseudo-label noise; the weight can also be fixed (try a few values to see which works best). In essence, down-weighting the pseudo-label loss is one way to mitigate pseudo-label noise.
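A minimal sketch of down-weighting the pseudo-label loss, assuming each batch carries the `pseudo` flag produced by the preprocessing pipeline (0 = real label, 1 = pseudo label); the fixed-weight variant is shown, and the learnable-alpha variant simply makes `pseudo_weight` a trainable parameter:

```python
import torch

def pseudo_weighted_loss(per_sample_loss, pseudo_flags, pseudo_weight=0.5):
    """per_sample_loss: [B] unreduced losses; pseudo_flags: [B], 1 for pseudo-labeled samples."""
    # real samples keep weight 1.0, pseudo-labeled samples are down-weighted
    weights = torch.where(pseudo_flags == 1,
                          torch.full_like(per_sample_loss, pseudo_weight),
                          torch.ones_like(per_sample_loss))
    return (weights * per_sample_loss).sum() / weights.sum()
```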
## Other Attempts Without Clear Gains
- Dynamically weighting the outputs of the last four BERT layers: no obvious improvement
- Adding a BiLSTM / IDCNN module on top of BERT: severe overfitting and much slower training
- Data augmentation by randomly replacing entities of the same type to enlarge the training data
- Using focal loss / dice loss in BERT-SPAN / MRC to mitigate label imbalance
- Post-correcting the model output with the constructed domain dictionary
## Final Online Score: 72.90%, Rank 1 in the Semi-Final and Rank 1 in the Final
# Ref
[1] Zhao et al., UER: An Open-Source Toolkit for Pre-training Models, EMNLP-IJCNLP, 2019.
[2] Madry et al., Towards Deep Learning Models Resistant to Adversarial Attacks, ICLR, 2018.
## Environment
```python
python3.7
pytorch==1.6.0 +
transformers==2.10.0
pytorch-crf==0.7.2
```
## Project Layout
```shell
DeepNER
├── data                                    # data folder
│   ├── mid_data                            # intermediate data
│   │   ├── crf_ent2id.json                 # schema for the crf model
│   │   ├── span_ent2id.json                # schema for the span model
│   │   └── mrc_ent2id.json                 # schema for the mrc model
│   ├── raw_data                            # converted data
│   │   ├── dev.json                        # converted dev set
│   │   ├── test.json                       # converted preliminary-round test set
│   │   ├── pseudo.json                     # converted semi-supervised (pseudo-label) data
│   │   ├── stack.json                      # converted full data
│   │   └── train.json                      # converted training set
├── out                                     # trained models
│   ├── ...
│   └── ...
├── src
│   ├── preprocess
│   │   ├── convert_raw_data.py             # converts the raw data
│   │   └── processor.py                    # converts data into BERT model inputs
│   ├── utils
│   │   ├── attack_train_utils.py           # adversarial training: FGM / PGD
│   │   ├── dataset_utils.py                # torch Dataset
│   │   ├── evaluator.py                    # model evaluation
│   │   ├── functions_utils.py              # functions shared across files
│   │   ├── model_utils.py                  # Span & CRF & MRC models (pytorch)
│   │   ├── options.py                      # command-line arguments
│   │   └── trainer.py                      # trainer
│
├── competition_predict.py                  # inference on the semi-final data for submission
├── README.md                               # ...
├── convert_test_data.py                    # converts the semi-final test set into json
├── run.sh                                  # run script
└── main.py                                 # main entry (training / evaluation)
```
## Usage
### Pretrained Models
* Tencent pretrained model UER-large (24 layers): https://github.com/dbiir/UER-py/wiki/Modelzoo
* HIT pretrained models: https://github.com/ymcui/Chinese-BERT-wwm
Baidu Netdisk download:
Link: https://pan.baidu.com/s/1axdkovbzGaszl8bXIn4sPw
Extraction code: jjba
(Note: two [unused] tokens in vocab.txt must be manually replaced with [INV] and [BLANK]; see the sketch below)
tips: uer, roberta-wwm, and roberta-wwm-large are recommended
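A minimal sketch of the manual vocab.txt patch; which two `[unusedX]` entries are repurposed is your choice (this sketch takes the first two), as long as the resulting tokens match the `[INV]` / `[BLANK]` placeholders used by `fine_grade_tokenize` in `src/preprocess/processor.py`:

```python
# patch_vocab.py -- replace two [unused] placeholders with [INV] and [BLANK]
vocab_path = 'bert/torch_uer_large/vocab.txt'  # illustrative path
replacements = ['[INV]', '[BLANK]']

with open(vocab_path, encoding='utf-8') as f:
    lines = f.read().splitlines()

for i, line in enumerate(lines):
    if line.startswith('[unused') and replacements:
        lines[i] = replacements.pop(0)
    if not replacements:
        break

with open(vocab_path, 'w', encoding='utf-8') as f:
    f.write('\n'.join(lines) + '\n')
```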
### Data Conversion
**Note: the converted data is already provided, there is no need to run this**
```python
python src/preprocess/convert_raw_data.py
```
### Training
```shell
bash run.sh
```
**Note: BERT_DIR in the script points to the folder that holds the BERT weights; download the BERT model into that folder first**
##### BERT-CRF training
```python
task_type='crf'
mode='train' or 'stack'    # train: single-model training and validation; stack: 5-fold training and validation
swa_start: the epoch at which SWA weight averaging starts
attack_train: 'pgd' / 'fgm' / ''    # adversarial training; fgm roughly doubles training time, pgd roughly triples it; pgd clearly helps on this dataset
```
##### BERT-SPAN training
```python
task_type='span'
mode: same as above
attack_train: same as above
loss_type: 'ce': cross-entropy; 'ls_ce': label-smoothed cross-entropy (see the sketch below); 'focal': focal loss
```
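`ls_ce` stands for label-smoothed cross-entropy; a minimal sketch of what that loss computes is shown below (the actual implementation lives in `src/utils/model_utils.py`, so this is an illustrative re-implementation rather than the project's exact code):

```python
import torch
import torch.nn.functional as F

def label_smoothing_ce(logits, target, eps=0.1):
    """logits: [N, C]; target: [N] integer class ids."""
    n_classes = logits.size(-1)
    log_probs = F.log_softmax(logits, dim=-1)
    # mix the one-hot target with a uniform distribution over the remaining classes
    smooth_target = torch.full_like(log_probs, eps / (n_classes - 1))
    smooth_target.scatter_(-1, target.unsqueeze(-1), 1.0 - eps)
    return -(smooth_target * log_probs).sum(dim=-1).mean()
```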
##### BERT-MRC training
```python
task_type='mrc'
mode: same as above
attack_train: same as above
loss_type: same as above
```
### Predicting the semi-final test files (after the models above have been trained)
**Note: there is no data to run on yet; this becomes runnable once the official data is open-sourced**
```shell
# convert the test data
python convert_test_data.py
# predict
python competition_predict.py
```

320
competition_predict.py Normal file

@ -0,0 +1,320 @@
import os
import json
import torch
from collections import defaultdict
from transformers import BertTokenizer
from src.utils.model_utils import CRFModel, SpanModel, EnsembleCRFModel, EnsembleSpanModel
from src.utils.evaluator import crf_decode, span_decode
from src.utils.functions_utils import load_model_and_parallel, ensemble_vote
from src.preprocess.processor import cut_sent, fine_grade_tokenize
MID_DATA_DIR = "./data/mid_data"
RAW_DATA_DIR = "./data/raw_data_random"
SUBMIT_DIR = "./result"
GPU_IDS = "0"
LAMBDA = 0.3
THRESHOLD = 0.9
MAX_SEQ_LEN = 512
TASK_TYPE = "crf" # choose crf or span
VOTE = True # choose True or False
VERSION = "mixed"  # choose single or ensemble or mixed; if mixed, VOTE and TASK_TYPE are ignored
# single_predict
BERT_TYPE = "uer_large" # roberta_wwm / ernie_1 / uer_large
BERT_DIR = f"./bert/torch_{BERT_TYPE}"
with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
CKPT_PATH = f.read().strip()
# ensemble_predict
BERT_DIR_LIST = ["./bert/torch_uer_large", "./bert/torch_roberta_wwm"]
with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
ENSEMBLE_DIR_LIST = f.readlines()
print('ENSEMBLE_DIR_LIST:{}'.format(ENSEMBLE_DIR_LIST))
# mixed_predict
MIX_BERT_DIR = "./bert/torch_uer_large"
with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
MIX_DIR_LIST = f.readlines()
print('MIX_DIR_LIST:{}'.format(MIX_DIR_LIST))
def prepare_info():
info_dict = {}
with open(os.path.join(MID_DATA_DIR, f'{TASK_TYPE}_ent2id.json'), encoding='utf-8') as f:
ent2id = json.load(f)
with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
info_dict['examples'] = json.load(f)
info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}
info_dict['tokenizer'] = BertTokenizer(os.path.join(BERT_DIR, 'vocab.txt'))
return info_dict
def mixed_prepare_info(mixed='crf'):
info_dict = {}
with open(os.path.join(MID_DATA_DIR, f'{mixed}_ent2id.json'), encoding='utf-8') as f:
ent2id = json.load(f)
with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
info_dict['examples'] = json.load(f)
info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}
info_dict['tokenizer'] = BertTokenizer(os.path.join(BERT_DIR, 'vocab.txt'))
return info_dict
def base_predict(model, device, info_dict, ensemble=False, mixed=''):
labels = defaultdict(list)
tokenizer = info_dict['tokenizer']
id2ent = info_dict['id2ent']
with torch.no_grad():
for _ex in info_dict['examples']:
ex_idx = _ex['id']
raw_text = _ex['text']
if not len(raw_text):
labels[ex_idx] = []
print('{} is empty'.format(ex_idx))
continue
sentences = cut_sent(raw_text, MAX_SEQ_LEN)
start_index = 0
for sent in sentences:
sent_tokens = fine_grade_tokenize(sent, tokenizer)
encode_dict = tokenizer.encode_plus(text=sent_tokens,
max_length=MAX_SEQ_LEN,
is_pretokenized=True,
pad_to_max_length=False,
return_tensors='pt',
return_token_type_ids=True,
return_attention_mask=True)
model_inputs = {'token_ids': encode_dict['input_ids'],
'attention_masks': encode_dict['attention_mask'],
'token_type_ids': encode_dict['token_type_ids']}
for key in model_inputs:
model_inputs[key] = model_inputs[key].to(device)
if ensemble:
if TASK_TYPE == 'crf':
if VOTE:
decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
else:
pred_tokens = model.predict(model_inputs)[0]
decode_entities = crf_decode(pred_tokens, sent, id2ent)
else:
if VOTE:
decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
else:
start_logits, end_logits = model.predict(model_inputs)
start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
else:
if mixed:
if mixed == 'crf':
pred_tokens = model(**model_inputs)[0][0]
decode_entities = crf_decode(pred_tokens, sent, id2ent)
else:
start_logits, end_logits = model(**model_inputs)
start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
else:
if TASK_TYPE == 'crf':
pred_tokens = model(**model_inputs)[0][0]
decode_entities = crf_decode(pred_tokens, sent, id2ent)
else:
start_logits, end_logits = model(**model_inputs)
start_logits = start_logits[0].cpu().numpy()[1:1+len(sent)]
end_logits = end_logits[0].cpu().numpy()[1:1+len(sent)]
decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
for _ent_type in decode_entities:
for _ent in decode_entities[_ent_type]:
tmp_start = _ent[1] + start_index
tmp_end = tmp_start + len(_ent[0])
assert raw_text[tmp_start: tmp_end] == _ent[0]
labels[ex_idx].append((_ent_type, tmp_start, tmp_end, _ent[0]))
start_index += len(sent)
if not len(labels[ex_idx]):
labels[ex_idx] = []
return labels
def single_predict():
save_dir = os.path.join(SUBMIT_DIR, VERSION)
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
info_dict = prepare_info()
if TASK_TYPE == 'crf':
model = CRFModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent']))
else:
model = SpanModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent'])+1)
print(f'Load model from {CKPT_PATH}')
model, device = load_model_and_parallel(model, GPU_IDS, CKPT_PATH)
model.eval()
labels = base_predict(model, device, info_dict)
for key in labels.keys():
with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
if not len(labels[key]):
print(key)
f.write("")
else:
for idx, _label in enumerate(labels[key]):
f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
def ensemble_predict():
save_dir = os.path.join(SUBMIT_DIR, VERSION)
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
info_dict = prepare_info()
model_path_list = [x.strip() for x in ENSEMBLE_DIR_LIST]
print('model_path_list:{}'.format(model_path_list))
device = torch.device(f'cuda:{GPU_IDS[0]}')
if TASK_TYPE == 'crf':
model = EnsembleCRFModel(model_path_list=model_path_list,
bert_dir_list=BERT_DIR_LIST,
num_tags=len(info_dict['id2ent']),
device=device,
lamb=LAMBDA)
else:
model = EnsembleSpanModel(model_path_list=model_path_list,
bert_dir_list=BERT_DIR_LIST,
num_tags=len(info_dict['id2ent'])+1,
device=device)
labels = base_predict(model, device, info_dict, ensemble=True)
for key in labels.keys():
with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
if not len(labels[key]):
print(key)
f.write("")
else:
for idx, _label in enumerate(labels[key]):
f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
def mixed_predict():
save_dir = os.path.join(SUBMIT_DIR, VERSION)
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
model_path_list = [x.strip() for x in MIX_DIR_LIST]
print('model_path_list:{}'.format(model_path_list))
all_labels = []
for i, model_path in enumerate(model_path_list):
if i <= 4:
info_dict = mixed_prepare_info(mixed='span')
model = SpanModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']) + 1)
print(f'Load model from {model_path}')
model, device = load_model_and_parallel(model, GPU_IDS, model_path)
model.eval()
labels = base_predict(model, device, info_dict, ensemble=False, mixed='span')
else:
info_dict = mixed_prepare_info(mixed='crf')
model = CRFModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']))
print(f'Load model from {model_path}')
model, device = load_model_and_parallel(model, GPU_IDS, model_path)
model.eval()
labels = base_predict(model, device, info_dict, ensemble=False, mixed='crf')
all_labels.append(labels)
labels = ensemble_vote(all_labels, THRESHOLD)
# for key in labels.keys():
for key in range(1500, 1997):
with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
if not len(labels[key]):
print(key)
f.write("")
else:
for idx, _label in enumerate(labels[key]):
f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
if __name__ == '__main__':
assert VERSION in ['single', 'ensemble', 'mixed'], 'VERSION mismatch'
if VERSION == 'single':
single_predict()
elif VERSION == 'ensemble':
if VOTE:
print("————————开始投票:————————")
ensemble_predict()
elif VERSION == 'mixed':
print("————————开始混合投票:————————")
mixed_predict()
# compress the result directory into result.zip
import zipfile
def zip_file(src_dir):
zip_name = src_dir + '.zip'
z = zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED)
for dirpath, dirnames, filenames in os.walk(src_dir):
fpath = dirpath.replace(src_dir, '')
fpath = fpath and fpath + os.sep or ''
for filename in filenames:
z.write(os.path.join(dirpath, filename), fpath + filename)
print('== zip archive written ==')
z.close()
zip_file('./result')

32
convert_test_data.py Normal file

@ -0,0 +1,32 @@
import os
import json
from tqdm import trange
def save_info(data_dir, data, desc):
with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def convert_test_data_to_json(test_dir, save_dir):
test_examples = []
# process test examples
for i in trange(1500, 1997):
with open(os.path.join(test_dir, f'{i}.txt'), encoding='utf-8') as f:
text = f.read()
test_examples.append({'id': i,
'text': text})
save_info(save_dir, test_examples, 'test')
if __name__ == '__main__':
test_dir = './tcdata/juesai'
save_dir = './data/raw_data_random'
convert_test_data_to_json(test_dir, save_dir)
print('test data conversion finished')

data/mid_data/crf_ent2id.json Normal file

@ -0,0 +1,55 @@
{
"O": 0,
"B-DRUG_GROUP": 1,
"B-DRUG_DOSAGE": 2,
"B-FOOD": 3,
"B-DRUG_EFFICACY": 4,
"B-FOOD_GROUP": 5,
"B-SYMPTOM": 6,
"B-DISEASE_GROUP": 7,
"B-SYNDROME": 8,
"B-PERSON_GROUP": 9,
"B-DRUG_TASTE": 10,
"B-DRUG_INGREDIENT": 11,
"B-DRUG": 12,
"B-DISEASE": 13,
"I-DRUG_GROUP": 14,
"I-DRUG_DOSAGE": 15,
"I-FOOD": 16,
"I-DRUG_EFFICACY": 17,
"I-FOOD_GROUP": 18,
"I-SYMPTOM": 19,
"I-DISEASE_GROUP": 20,
"I-SYNDROME": 21,
"I-PERSON_GROUP": 22,
"I-DRUG_TASTE": 23,
"I-DRUG_INGREDIENT": 24,
"I-DRUG": 25,
"I-DISEASE": 26,
"E-DRUG_GROUP": 27,
"E-DRUG_DOSAGE": 28,
"E-FOOD": 29,
"E-DRUG_EFFICACY": 30,
"E-FOOD_GROUP": 31,
"E-SYMPTOM": 32,
"E-DISEASE_GROUP": 33,
"E-SYNDROME": 34,
"E-PERSON_GROUP": 35,
"E-DRUG_TASTE": 36,
"E-DRUG_INGREDIENT": 37,
"E-DRUG": 38,
"E-DISEASE": 39,
"S-DRUG_GROUP": 40,
"S-DRUG_DOSAGE": 41,
"S-FOOD": 42,
"S-DRUG_EFFICACY": 43,
"S-FOOD_GROUP": 44,
"S-SYMPTOM": 45,
"S-DISEASE_GROUP": 46,
"S-SYNDROME": 47,
"S-PERSON_GROUP": 48,
"S-DRUG_TASTE": 49,
"S-DRUG_INGREDIENT": 50,
"S-DRUG": 51,
"S-DISEASE": 52
}

data/mid_data/mrc_ent2id.json Normal file

@ -0,0 +1,15 @@
{
"DRUG": "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。",
"DRUG_INGREDIENT": "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。",
"DISEASE": "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。",
"SYMPTOM": "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。",
"SYNDROME": "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。",
"DISEASE_GROUP": "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。",
"FOOD": "找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。",
"FOOD_GROUP": "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性,同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。",
"PERSON_GROUP": "找出人群:中医药的适用及禁忌范围内相关特定人群。",
"DRUG_GROUP": "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药",
"DRUG_DOSAGE": "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。",
"DRUG_TASTE": "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。",
"DRUG_EFFICACY": "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀"
}

data/mid_data/span_ent2id.json Normal file

@ -0,0 +1,15 @@
{
"DRUG_GROUP": 1,
"DRUG_DOSAGE": 2,
"FOOD": 3,
"DRUG_EFFICACY": 4,
"FOOD_GROUP": 5,
"SYMPTOM": 6,
"DISEASE_GROUP": 7,
"SYNDROME": 8,
"PERSON_GROUP": 9,
"DRUG_TASTE": 10,
"DRUG_INGREDIENT": 11,
"DRUG": 12,
"DISEASE": 13
}

24384
data/raw_data/dev.json Normal file

File diff suppressed because it is too large

85588
data/raw_data/pseudo.json Normal file

File diff suppressed because one or more lines are too long

164121
data/raw_data/stack.json Normal file

File diff suppressed because one or more lines are too long

19417
data/raw_data/test.json Normal file

File diff suppressed because one or more lines are too long

139739
data/raw_data/train.json Normal file

File diff suppressed because one or more lines are too long

218
main.py Normal file

@ -0,0 +1,218 @@
import time
import os
import json
import logging
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold
from src.utils.trainer import train
from src.utils.options import Args
from src.utils.model_utils import build_model
from src.utils.dataset_utils import NERDataset
from src.utils.evaluator import crf_evaluation, span_evaluation, mrc_evaluation
from src.utils.functions_utils import set_seed, get_model_path_list, load_model_and_parallel, get_time_dif
from src.preprocess.processor import NERProcessor, convert_examples_to_features
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO
)
def train_base(opt, train_examples, dev_examples=None):
with open(os.path.join(opt.mid_data_dir, f'{opt.task_type}_ent2id.json'), encoding='utf-8') as f:
ent2id = json.load(f)
train_features = convert_examples_to_features(opt.task_type, train_examples,
opt.max_seq_len, opt.bert_dir, ent2id)[0]
train_dataset = NERDataset(opt.task_type, train_features, 'train', use_type_embed=opt.use_type_embed)
if opt.task_type == 'crf':
model = build_model('crf', opt.bert_dir, num_tags=len(ent2id),
dropout_prob=opt.dropout_prob)
elif opt.task_type == 'mrc':
model = build_model('mrc', opt.bert_dir,
dropout_prob=opt.dropout_prob,
use_type_embed=opt.use_type_embed,
loss_type=opt.loss_type)
else:
model = build_model('span', opt.bert_dir, num_tags=len(ent2id)+1,
dropout_prob=opt.dropout_prob,
loss_type=opt.loss_type)
train(opt, model, train_dataset)
if dev_examples is not None:
dev_features, dev_callback_info = convert_examples_to_features(opt.task_type, dev_examples,
opt.max_seq_len, opt.bert_dir, ent2id)
dev_dataset = NERDataset(opt.task_type, dev_features, 'dev', use_type_embed=opt.use_type_embed)
dev_loader = DataLoader(dev_dataset, batch_size=opt.eval_batch_size,
shuffle=False, num_workers=0)
dev_info = (dev_loader, dev_callback_info)
model_path_list = get_model_path_list(opt.output_dir)
metric_str = ''
max_f1 = 0.
max_f1_step = 0
max_f1_path = ''
for idx, model_path in enumerate(model_path_list):
tmp_step = model_path.split('/')[-2].split('-')[-1]
model, device = load_model_and_parallel(model, opt.gpu_ids[0],
ckpt_path=model_path)
if opt.task_type == 'crf':
tmp_metric_str, tmp_f1 = crf_evaluation(model, dev_info, device, ent2id)
elif opt.task_type == 'mrc':
tmp_metric_str, tmp_f1 = mrc_evaluation(model, dev_info, device)
else:
tmp_metric_str, tmp_f1 = span_evaluation(model, dev_info, device, ent2id)
logger.info(f'In step {tmp_step}:\n {tmp_metric_str}')
metric_str += f'In step {tmp_step}:\n {tmp_metric_str}' + '\n\n'
if tmp_f1 > max_f1:
max_f1 = tmp_f1
max_f1_step = tmp_step
max_f1_path = model_path
max_metric_str = f'Max f1 is: {max_f1}, in step {max_f1_step}'
logger.info(max_metric_str)
metric_str += max_metric_str + '\n'
eval_save_path = os.path.join(opt.output_dir, 'eval_metric.txt')
with open(eval_save_path, 'a', encoding='utf-8') as f1:
f1.write(metric_str)
with open('./best_ckpt_path.txt', 'a', encoding='utf-8') as f2:
f2.write(max_f1_path + '\n')
del_dir_list = [os.path.join(opt.output_dir, path.split('/')[-2])
for path in model_path_list if path != max_f1_path]
import shutil
for x in del_dir_list:
shutil.rmtree(x)
logger.info('{} removed'.format(x))
def training(opt):
if opt.task_type == 'mrc':
# reserve 62 tokens for the mrc query
processor = NERProcessor(opt.max_seq_len - 62)
else:
processor = NERProcessor(opt.max_seq_len)
train_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'train.json'))
# add pseudo data to train data
pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
train_raw_examples = train_raw_examples + pseudo_raw_examples
train_examples = processor.get_examples(train_raw_examples, 'train')
dev_examples = None
if opt.eval_model:
dev_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'dev.json'))
dev_examples = processor.get_examples(dev_raw_examples, 'dev')
train_base(opt, train_examples, dev_examples)
def stacking(opt):
logger.info('Start to KFold stack attribution model')
if opt.task_type == 'mrc':
# reserve 62 tokens for the mrc query
processor = NERProcessor(opt.max_seq_len - 62)
else:
processor = NERProcessor(opt.max_seq_len)
kf = KFold(5, shuffle=True, random_state=42)
stack_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'stack.json'))
pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
base_output_dir = opt.output_dir
for i, (train_ids, dev_ids) in enumerate(kf.split(stack_raw_examples)):
logger.info(f'Start to train the {i} fold')
train_raw_examples = [stack_raw_examples[_idx] for _idx in train_ids]
# add pseudo data to train data
train_raw_examples = train_raw_examples + pseudo_raw_examples
train_examples = processor.get_examples(train_raw_examples, 'train')
dev_raw_examples = [stack_raw_examples[_idx] for _idx in dev_ids]
dev_info = processor.get_examples(dev_raw_examples, 'dev')
tmp_output_dir = os.path.join(base_output_dir, f'v{i}')
opt.output_dir = tmp_output_dir
train_base(opt, train_examples, dev_info)
if __name__ == '__main__':
start_time = time.time()
logging.info('----------------timer started----------------')
logging.info('----------------------------------------------')
args = Args().get_parser()
assert args.mode in ['train', 'stack'], 'mode mismatch'
assert args.task_type in ['crf', 'span', 'mrc']
args.output_dir = os.path.join(args.output_dir, args.bert_type)
set_seed(args.seed)
if args.attack_train != '':
args.output_dir += f'_{args.attack_train}'
if args.weight_decay:
args.output_dir += '_wd'
if args.use_fp16:
args.output_dir += '_fp16'
if args.task_type == 'span':
args.output_dir += f'_{args.loss_type}'
if args.task_type == 'mrc':
if args.use_type_embed:
args.output_dir += f'_embed'
args.output_dir += f'_{args.loss_type}'
args.output_dir += f'_{args.task_type}'
if args.mode == 'stack':
args.output_dir += '_stack'
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
logger.info(f'{args.mode} {args.task_type} in max_seq_len {args.max_seq_len}')
if args.mode == 'train':
training(args)
else:
stacking(args)
time_dif = get_time_dif(start_time)
logging.info("----------本次容器运行时长:{}-----------".format(time_dif))

BIN  md_files/1.png   (new file, 26 KiB)
BIN  md_files/10.png  (new file, 159 KiB)
BIN  md_files/11.png  (new file, 207 KiB)
BIN  md_files/12.png  (new file, 188 KiB)
BIN  md_files/13.png  (new file, 667 KiB)
BIN  md_files/2.png   (new file, 57 KiB)
BIN  md_files/3.png   (new file, 57 KiB)
BIN  md_files/4.png   (new file, 279 KiB)
BIN  md_files/5.png   (new file, 59 KiB)
BIN  md_files/6.png   (new file, 177 KiB)
BIN  md_files/7.png   (new file, 104 KiB)
BIN  md_files/8.png   (new file, 25 KiB)
BIN  md_files/9.png   (new file, 9.2 KiB)

35
run.sh Normal file

@ -0,0 +1,35 @@
#!/usr/bin/env bash
export MID_DATA_DIR="./data/mid_data"
export RAW_DATA_DIR="./data/raw_data"
export OUTPUT_DIR="./out"
export GPU_IDS="0"
export BERT_TYPE="roberta_wwm" # roberta_wwm / roberta_wwm_large / uer_large
export BERT_DIR="../bert/torch_$BERT_TYPE"
export MODE="train"
export TASK_TYPE="crf"
python main.py \
--gpu_ids=$GPU_IDS \
--output_dir=$OUTPUT_DIR \
--mid_data_dir=$MID_DATA_DIR \
--mode=$MODE \
--task_type=$TASK_TYPE \
--raw_data_dir=$RAW_DATA_DIR \
--bert_dir=$BERT_DIR \
--bert_type=$BERT_TYPE \
--train_epochs=10 \
--swa_start=5 \
--attack_train="" \
--train_batch_size=24 \
--dropout_prob=0.1 \
--max_seq_len=512 \
--lr=2e-5 \
--other_lr=2e-3 \
--seed=123 \
--weight_decay=0.01 \
--loss_type='ls_ce' \
--eval_model \
#--use_fp16


src/preprocess/convert_raw_data.py Normal file

@ -0,0 +1,183 @@
import os
import json
from tqdm import trange
from sklearn.model_selection import train_test_split, KFold
def save_info(data_dir, data, desc):
with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def convert_data_to_json(base_dir, save_data=False, save_dict=False):
stack_examples = []
pseudo_examples = []
test_examples = []
stack_dir = os.path.join(base_dir, 'train')
pseudo_dir = os.path.join(base_dir, 'pseudo')
test_dir = os.path.join(base_dir, 'test')
# process train examples
for i in trange(1000):
with open(os.path.join(stack_dir, f'{i}.txt'), encoding='utf-8') as f:
text = f.read()
labels = []
with open(os.path.join(stack_dir, f'{i}.ann'), encoding='utf-8') as f:
for line in f.readlines():
tmp_label = line.strip().split('\t')
assert len(tmp_label) == 3
tmp_mid = tmp_label[1].split()
tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]]
labels.append(tmp_label)
tmp_label[2] = int(tmp_label[2])
tmp_label[3] = int(tmp_label[3])
assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], 'index mismatch for label {} in file {}'.format(tmp_label, i)
stack_examples.append({'id': i,
'text': text,
'labels': labels,
'pseudo': 0})
# build the entity knowledge base (distant-supervision candidate dictionary)
kf = KFold(10)
entities = set()
ent_types = set()
for _now_id, _candidate_id in kf.split(stack_examples):
now = [stack_examples[_id] for _id in _now_id]
candidate = [stack_examples[_id] for _id in _candidate_id]
now_entities = set()
for _ex in now:
for _label in _ex['labels']:
ent_types.add(_label[1])
if len(_label[-1]) > 1:
now_entities.add(_label[-1])
entities.add(_label[-1])
# print(len(now_entities))
for _ex in candidate:
text = _ex['text']
candidate_entities = []
for _ent in now_entities:
if _ent in text:
candidate_entities.append(_ent)
_ex['candidate_entities'] = candidate_entities
assert len(ent_types) == 13
# process test examples predicted by the preliminary model
for i in trange(1000, 1500):
with open(os.path.join(pseudo_dir, f'{i}.txt'), encoding='utf-8') as f:
text = f.read()
candidate_entities = []
for _ent in entities:
if _ent in text:
candidate_entities.append(_ent)
labels = []
with open(os.path.join(pseudo_dir, f'{i}.ann'), encoding='utf-8') as f:
for line in f.readlines():
tmp_label = line.strip().split('\t')
assert len(tmp_label) == 3
tmp_mid = tmp_label[1].split()
tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]]
labels.append(tmp_label)
tmp_label[2] = int(tmp_label[2])
tmp_label[3] = int(tmp_label[3])
assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], 'index mismatch for label {} in file {}'.format(tmp_label, i)
pseudo_examples.append({'id': i,
'text': text,
'labels': labels,
'candidate_entities': candidate_entities,
'pseudo': 1})
# process test examples
for i in trange(1000, 1500):
with open(os.path.join(test_dir, f'{i}.txt'), encoding='utf-8') as f:
text = f.read()
candidate_entities = []
for _ent in entities:
if _ent in text:
candidate_entities.append(_ent)
test_examples.append({'id': i,
'text': text,
'candidate_entities': candidate_entities})
train, dev = train_test_split(stack_examples, shuffle=True, random_state=123, test_size=0.15)
if save_data:
save_info(base_dir, stack_examples, 'stack')
save_info(base_dir, train, 'train')
save_info(base_dir, dev, 'dev')
save_info(base_dir, test_examples, 'test')
save_info(base_dir, pseudo_examples, 'pseudo')
if save_dict:
ent_types = list(ent_types)
span_ent2id = {_type: i+1 for i, _type in enumerate(ent_types)}
ent_types = ['O'] + [p + '-' + _type for p in ['B', 'I', 'E', 'S'] for _type in list(ent_types)]
crf_ent2id = {ent: i for i, ent in enumerate(ent_types)}
mid_data_dir = os.path.join(os.path.split(base_dir)[0], 'mid_data')
if not os.path.exists(mid_data_dir):
os.mkdir(mid_data_dir)
save_info(mid_data_dir, span_ent2id, 'span_ent2id')
save_info(mid_data_dir, crf_ent2id, 'crf_ent2id')
def build_ent2query(data_dir):
# use the competition's entity-type descriptions as MRC queries
ent2query = {
# 药物
'DRUG': "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。",
# 药物成分
'DRUG_INGREDIENT': "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。",
# 疾病
'DISEASE': "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。",
# 症状
'SYMPTOM': "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。",
# 症候
'SYNDROME': "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。",
# 疾病分组
'DISEASE_GROUP': "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。",
# 食物
'FOOD': "找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。",
# 食物分组
'FOOD_GROUP': "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性,"
"同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。",
# 人群
'PERSON_GROUP': "找出人群:中医药的适用及禁忌范围内相关特定人群。",
# 药品分组
'DRUG_GROUP': "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药",
# 药物剂量
'DRUG_DOSAGE': "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。",
# 药物性味
'DRUG_TASTE': "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。",
# 中药功效
'DRUG_EFFICACY': "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀"
}
with open(os.path.join(data_dir, 'mrc_ent2id.json'), 'w', encoding='utf-8') as f:
json.dump(ent2query, f, ensure_ascii=False, indent=2)
if __name__ == '__main__':
convert_data_to_json('../../data/raw_data', save_data=True, save_dict=True)
build_ent2query('../../data/mid_data')

654
src/preprocess/processor.py Normal file

@ -0,0 +1,654 @@
import os
import re
import json
import logging
from transformers import BertTokenizer
from collections import defaultdict
import random
logger = logging.getLogger(__name__)
ENTITY_TYPES = ['DRUG', 'DRUG_INGREDIENT', 'DISEASE', 'SYMPTOM', 'SYNDROME', 'DISEASE_GROUP',
'FOOD', 'FOOD_GROUP', 'PERSON_GROUP', 'DRUG_GROUP', 'DRUG_DOSAGE', 'DRUG_TASTE',
'DRUG_EFFICACY']
class InputExample:
def __init__(self,
set_type,
text,
labels=None,
pseudo=None,
distant_labels=None):
self.set_type = set_type
self.text = text
self.labels = labels
self.pseudo = pseudo
self.distant_labels = distant_labels
class BaseFeature:
def __init__(self,
token_ids,
attention_masks,
token_type_ids):
# BERT 输入
self.token_ids = token_ids
self.attention_masks = attention_masks
self.token_type_ids = token_type_ids
class CRFFeature(BaseFeature):
def __init__(self,
token_ids,
attention_masks,
token_type_ids,
labels=None,
pseudo=None,
distant_labels=None):
super(CRFFeature, self).__init__(token_ids=token_ids,
attention_masks=attention_masks,
token_type_ids=token_type_ids)
# labels
self.labels = labels
# pseudo
self.pseudo = pseudo
# distant labels
self.distant_labels = distant_labels
class SpanFeature(BaseFeature):
def __init__(self,
token_ids,
attention_masks,
token_type_ids,
start_ids=None,
end_ids=None,
pseudo=None):
super(SpanFeature, self).__init__(token_ids=token_ids,
attention_masks=attention_masks,
token_type_ids=token_type_ids)
self.start_ids = start_ids
self.end_ids = end_ids
# pseudo
self.pseudo = pseudo
class MRCFeature(BaseFeature):
def __init__(self,
token_ids,
attention_masks,
token_type_ids,
ent_type=None,
start_ids=None,
end_ids=None,
pseudo=None):
super(MRCFeature, self).__init__(token_ids=token_ids,
attention_masks=attention_masks,
token_type_ids=token_type_ids)
self.ent_type = ent_type
self.start_ids = start_ids
self.end_ids = end_ids
# pseudo
self.pseudo = pseudo
class NERProcessor:
def __init__(self, cut_sent_len=256):
self.cut_sent_len = cut_sent_len
@staticmethod
def read_json(file_path):
with open(file_path, encoding='utf-8') as f:
raw_examples = json.load(f)
return raw_examples
@staticmethod
def _refactor_labels(sent, labels, distant_labels, start_index):
"""
rebuild the label offsets after sentence splitting
:param sent: the sentence after splitting and re-merging
:param labels: the original document-level labels
:param distant_labels: distant-supervision labels
:param start_index: start offset of this sentence within the document
:return: (type, entity, offset)
"""
new_labels, new_distant_labels = [], []
end_index = start_index + len(sent)
for _label in labels:
if start_index <= _label[2] <= _label[3] <= end_index:
new_offset = _label[2] - start_index
assert sent[new_offset: new_offset + len(_label[-1])] == _label[-1]
new_labels.append((_label[1], _label[-1], new_offset))
# the label is cut by the sentence boundary
elif _label[2] < end_index < _label[3]:
raise RuntimeError(f'{sent}, {_label}')
for _label in distant_labels:
if _label in sent:
new_distant_labels.append(_label)
return new_labels, new_distant_labels
def get_examples(self, raw_examples, set_type):
examples = []
for i, item in enumerate(raw_examples):
text = item['text']
distant_labels = item['candidate_entities']
pseudo = item['pseudo']
sentences = cut_sent(text, self.cut_sent_len)
start_index = 0
for sent in sentences:
labels, tmp_distant_labels = self._refactor_labels(sent, item['labels'], distant_labels, start_index)
start_index += len(sent)
examples.append(InputExample(set_type=set_type,
text=sent,
labels=labels,
pseudo=pseudo,
distant_labels=tmp_distant_labels))
return examples
def fine_grade_tokenize(raw_text, tokenizer):
"""
the BERT tokenizer can shift annotation offsets in sequence labeling tasks,
so tokenize at the character level instead
"""
tokens = []
for _ch in raw_text:
if _ch in [' ', '\t', '\n']:
tokens.append('[BLANK]')
else:
if not len(tokenizer.tokenize(_ch)):
tokens.append('[INV]')
else:
tokens.append(_ch)
return tokens
def cut_sentences_v1(sent):
"""
the first rank of sentence cut
"""
sent = re.sub('([。!?\?])([^”’])', r"\1\n\2", sent)  # single-character sentence terminators
sent = re.sub('(\.{6})([^”’])', r"\1\n\2", sent)  # English ellipsis
sent = re.sub('(\…{2})([^”’])', r"\1\n\2", sent)  # Chinese ellipsis
sent = re.sub('([。!?\?][”’])([^,。!?\?])', r"\1\n\2", sent)
# if a terminator is followed by a closing quote, the quote ends the sentence, so put the \n after the quote
return sent.split("\n")
def cut_sentences_v2(sent):
"""
the second-level sentence split, split on ';'
"""
sent = re.sub('([;])([^”’])', r"\1\n\2", sent)
return sent.split("\n")
def cut_sent(text, max_seq_len):
# split the text into sentences at a fine granularity, then merge them back
sentences = []
# fine-grained split
sentences_v1 = cut_sentences_v1(text)
for sent_v1 in sentences_v1:
if len(sent_v1) > max_seq_len - 2:
sentences_v2 = cut_sentences_v2(sent_v1)
sentences.extend(sentences_v2)
else:
sentences.append(sent_v1)
assert ''.join(sentences) == text
# merge
merged_sentences = []
start_index_ = 0
while start_index_ < len(sentences):
tmp_text = sentences[start_index_]
end_index_ = start_index_ + 1
while end_index_ < len(sentences) and \
len(tmp_text) + len(sentences[end_index_]) <= max_seq_len - 2:
tmp_text += sentences[end_index_]
end_index_ += 1
start_index_ = end_index_
merged_sentences.append(tmp_text)
return merged_sentences
def sent_mask(sent, stop_mask_range_list, mask_prob=0.15):
"""
randomly mask tokens of the sentence with probability mask_prob
(a selected token is replaced by [MASK]; the others stay unchanged)
:param sent: list of tokens
:param stop_mask_range_list: spans that must not be masked (e.g. entities)
:param mask_prob: max number of masked tokens = len(sent) * mask_prob
:return:
"""
max_mask_token_nums = int(len(sent) * mask_prob)
mask_nums = 0
mask_sent = []
for i in range(len(sent)):
flag = False
for _stop_range in stop_mask_range_list:
if _stop_range[0] <= i <= _stop_range[1]:
flag = True
break
if flag:
mask_sent.append(sent[i])
continue
if mask_nums < max_mask_token_nums:
# with probability mask_prob the token is replaced by [MASK]; otherwise it is kept unchanged
if random.random() < mask_prob:
mask_sent.append('[MASK]')
mask_nums += 1
else:
mask_sent.append(sent[i])
else:
mask_sent.append(sent[i])
return mask_sent
def convert_crf_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
max_seq_len, ent2id):
set_type = example.set_type
raw_text = example.text
entities = example.labels
pseudo = example.pseudo
callback_info = (raw_text,)
callback_labels = {x: [] for x in ENTITY_TYPES}
for _label in entities:
callback_labels[_label[0]].append((_label[1], _label[2]))
callback_info += (callback_labels,)
tokens = fine_grade_tokenize(raw_text, tokenizer)
assert len(tokens) == len(raw_text)
label_ids = None
if set_type == 'train':
# information for dev callback
label_ids = [0] * len(tokens)
# tag labels ent ex. (T1, DRUG_DOSAGE, 447, 450, 小蜜丸)
for ent in entities:
ent_type = ent[0]
ent_start = ent[-1]
ent_end = ent_start + len(ent[1]) - 1
if ent_start == ent_end:
label_ids[ent_start] = ent2id['S-' + ent_type]
else:
label_ids[ent_start] = ent2id['B-' + ent_type]
label_ids[ent_end] = ent2id['E-' + ent_type]
for i in range(ent_start + 1, ent_end):
label_ids[i] = ent2id['I-' + ent_type]
if len(label_ids) > max_seq_len - 2:
label_ids = label_ids[:max_seq_len - 2]
label_ids = [0] + label_ids + [0]
# pad
if len(label_ids) < max_seq_len:
pad_length = max_seq_len - len(label_ids)
label_ids = label_ids + [0] * pad_length  # CLS / SEP / PAD labels are all O
assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
encode_dict = tokenizer.encode_plus(text=tokens,
max_length=max_seq_len,
pad_to_max_length=True,
is_pretokenized=True,
return_token_type_ids=True,
return_attention_mask=True)
token_ids = encode_dict['input_ids']
attention_masks = encode_dict['attention_mask']
token_type_ids = encode_dict['token_type_ids']
# if ex_idx < 3:
# logger.info(f"*** {set_type}_example-{ex_idx} ***")
# logger.info(f'text: {" ".join(tokens)}')
# logger.info(f"token_ids: {token_ids}")
# logger.info(f"attention_masks: {attention_masks}")
# logger.info(f"token_type_ids: {token_type_ids}")
# logger.info(f"labels: {label_ids}")
feature = CRFFeature(
# bert inputs
token_ids=token_ids,
attention_masks=attention_masks,
token_type_ids=token_type_ids,
labels=label_ids,
pseudo=pseudo
)
return feature, callback_info
def convert_span_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
max_seq_len, ent2id):
set_type = example.set_type
raw_text = example.text
entities = example.labels
pseudo = example.pseudo
tokens = fine_grade_tokenize(raw_text, tokenizer)
assert len(tokens) == len(raw_text)
callback_labels = {x: [] for x in ENTITY_TYPES}
for _label in entities:
callback_labels[_label[0]].append((_label[1], _label[2]))
callback_info = (raw_text, callback_labels,)
start_ids, end_ids = None, None
if set_type == 'train':
start_ids = [0] * len(tokens)
end_ids = [0] * len(tokens)
for _ent in entities:
ent_type = ent2id[_ent[0]]
ent_start = _ent[-1]
ent_end = ent_start + len(_ent[1]) - 1
start_ids[ent_start] = ent_type
end_ids[ent_end] = ent_type
if len(start_ids) > max_seq_len - 2:
start_ids = start_ids[:max_seq_len - 2]
end_ids = end_ids[:max_seq_len - 2]
start_ids = [0] + start_ids + [0]
end_ids = [0] + end_ids + [0]
# pad
if len(start_ids) < max_seq_len:
pad_length = max_seq_len - len(start_ids)
start_ids = start_ids + [0] * pad_length  # CLS / SEP / PAD labels are all O
end_ids = end_ids + [0] * pad_length
assert len(start_ids) == max_seq_len
assert len(end_ids) == max_seq_len
encode_dict = tokenizer.encode_plus(text=tokens,
max_length=max_seq_len,
pad_to_max_length=True,
is_pretokenized=True,
return_token_type_ids=True,
return_attention_mask=True)
token_ids = encode_dict['input_ids']
attention_masks = encode_dict['attention_mask']
token_type_ids = encode_dict['token_type_ids']
# if ex_idx < 3:
# logger.info(f"*** {set_type}_example-{ex_idx} ***")
# logger.info(f'text: {" ".join(tokens)}')
# logger.info(f"token_ids: {token_ids}")
# logger.info(f"attention_masks: {attention_masks}")
# logger.info(f"token_type_ids: {token_type_ids}")
# if start_ids and end_ids:
# logger.info(f"start_ids: {start_ids}")
# logger.info(f"end_ids: {end_ids}")
feature = SpanFeature(token_ids=token_ids,
attention_masks=attention_masks,
token_type_ids=token_type_ids,
start_ids=start_ids,
end_ids=end_ids,
pseudo=pseudo)
return feature, callback_info
def convert_mrc_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
max_seq_len, ent2id, ent2query, mask_prob=None):
set_type = example.set_type
text_b = example.text
entities = example.labels
pseudo = example.pseudo
features = []
callback_info = []
tokens_b = fine_grade_tokenize(text_b, tokenizer)
assert len(tokens_b) == len(text_b)
label_dict = defaultdict(list)
for ent in entities:
ent_type = ent[0]
ent_start = ent[-1]
ent_end = ent_start + len(ent[1]) - 1
label_dict[ent_type].append((ent_start, ent_end, ent[1]))
# build training features
if set_type == 'train':
# one example per entity type
# for _type in label_dict.keys():
for _type in ENTITY_TYPES:
start_ids = [0] * len(tokens_b)
end_ids = [0] * len(tokens_b)
stop_mask_ranges = []
text_a = ent2query[_type]
tokens_a = fine_grade_tokenize(text_a, tokenizer)
for _label in label_dict[_type]:
start_ids[_label[0]] = 1
end_ids[_label[1]] = 1
stop_mask_ranges.append((_label[0], _label[1]))
if len(start_ids) > max_seq_len - len(tokens_a) - 3:
start_ids = start_ids[:max_seq_len - len(tokens_a) - 3]
end_ids = end_ids[:max_seq_len - len(tokens_a) - 3]
print('unexpected truncation occurred')
start_ids = [0] + [0] * len(tokens_a) + [0] + start_ids + [0]
end_ids = [0] + [0] * len(tokens_a) + [0] + end_ids + [0]
# pad
if len(start_ids) < max_seq_len:
pad_length = max_seq_len - len(start_ids)
start_ids = start_ids + [0] * pad_length  # CLS / SEP / PAD labels are all O
end_ids = end_ids + [0] * pad_length
assert len(start_ids) == max_seq_len
assert len(end_ids) == max_seq_len
# random masking
if mask_prob:
tokens_b = sent_mask(tokens_b, stop_mask_ranges, mask_prob=mask_prob)
encode_dict = tokenizer.encode_plus(text=tokens_a,
text_pair=tokens_b,
max_length=max_seq_len,
pad_to_max_length=True,
truncation_strategy='only_second',
is_pretokenized=True,
return_token_type_ids=True,
return_attention_mask=True)
token_ids = encode_dict['input_ids']
attention_masks = encode_dict['attention_mask']
token_type_ids = encode_dict['token_type_ids']
# if ex_idx < 3:
# logger.info(f"*** {set_type}_example-{ex_idx} ***")
# logger.info(f'text: {" ".join(tokens_b)}')
# logger.info(f"token_ids: {token_ids}")
# logger.info(f"attention_masks: {attention_masks}")
# logger.info(f"token_type_ids: {token_type_ids}")
# logger.info(f'entity type: {_type}')
# logger.info(f"start_ids: {start_ids}")
# logger.info(f"end_ids: {end_ids}")
feature = MRCFeature(token_ids=token_ids,
attention_masks=attention_masks,
token_type_ids=token_type_ids,
ent_type=ent2id[_type],
start_ids=start_ids,
end_ids=end_ids,
pseudo=pseudo
)
features.append(feature)
# build test features: one separate example per entity type
else:
for _type in ENTITY_TYPES:
text_a = ent2query[_type]
tokens_a = fine_grade_tokenize(text_a, tokenizer)
encode_dict = tokenizer.encode_plus(text=tokens_a,
text_pair=tokens_b,
max_length=max_seq_len,
pad_to_max_length=True,
truncation_strategy='only_second',
is_pretokenized=True,
return_token_type_ids=True,
return_attention_mask=True)
token_ids = encode_dict['input_ids']
attention_masks = encode_dict['attention_mask']
token_type_ids = encode_dict['token_type_ids']
tmp_callback = (text_b, len(tokens_a) + 2, _type) # (text, text_offset, type, labels)
tmp_callback_labels = []
for _label in label_dict[_type]:
tmp_callback_labels.append((_label[2], _label[0]))
tmp_callback += (tmp_callback_labels, )
callback_info.append(tmp_callback)
feature = MRCFeature(token_ids=token_ids,
attention_masks=attention_masks,
token_type_ids=token_type_ids,
ent_type=ent2id[_type])
features.append(feature)
return features, callback_info
def convert_examples_to_features(task_type, examples, max_seq_len, bert_dir, ent2id):
assert task_type in ['crf', 'span', 'mrc']
tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))
features = []
callback_info = []
logger.info(f'Convert {len(examples)} examples to features')
type2id = {x: i for i, x in enumerate(ENTITY_TYPES)}
for i, example in enumerate(examples):
if task_type == 'crf':
feature, tmp_callback = convert_crf_example(
ex_idx=i,
example=example,
max_seq_len=max_seq_len,
ent2id=ent2id,
tokenizer=tokenizer
)
elif task_type == 'mrc':
feature, tmp_callback = convert_mrc_example(
ex_idx=i,
example=example,
max_seq_len=max_seq_len,
ent2id=type2id,
ent2query=ent2id,
tokenizer=tokenizer
)
else:
feature, tmp_callback = convert_span_example(
ex_idx=i,
example=example,
max_seq_len=max_seq_len,
ent2id=ent2id,
tokenizer=tokenizer
)
if feature is None:
continue
if task_type == 'mrc':
features.extend(feature)
callback_info.extend(tmp_callback)
else:
features.append(feature)
callback_info.append(tmp_callback)
logger.info(f'Build {len(features)} features')
out = (features, )
if not len(callback_info):
return out
type_weight = {}  # per-type proportion, used to compute micro-f1
for _type in ENTITY_TYPES:
type_weight[_type] = 0.
count = 0.
if task_type == 'mrc':
for _callback in callback_info:
type_weight[_callback[-2]] += len(_callback[-1])
count += len(_callback[-1])
else:
for _callback in callback_info:
for _type in _callback[1]:
type_weight[_type] += len(_callback[1][_type])
count += len(_callback[1][_type])
for key in type_weight:
type_weight[key] /= count
out += ((callback_info, type_weight), )
return out
if __name__ == '__main__':
pass


src/utils/attack_train_utils.py Normal file

@ -0,0 +1,76 @@
import torch
import torch.nn as nn
# FGM
class FGM:
def __init__(self, model: nn.Module, eps=1.):
self.model = (
model.module if hasattr(model, "module") else model
)
self.eps = eps
self.backup = {}
# only attack word embedding
def attack(self, emb_name='word_embeddings'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm and not torch.isnan(norm):
r_at = self.eps * param.grad / norm
param.data.add_(r_at)
def restore(self, emb_name='word_embeddings'):
for name, para in self.model.named_parameters():
if para.requires_grad and emb_name in name:
assert name in self.backup
para.data = self.backup[name]
self.backup = {}
# PGD
class PGD:
def __init__(self, model, eps=1., alpha=0.3):
self.model = (
model.module if hasattr(model, "module") else model
)
self.eps = eps
self.alpha = alpha
self.emb_backup = {}
self.grad_backup = {}
def attack(self, emb_name='word_embeddings', is_first_attack=False):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
if is_first_attack:
self.emb_backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0 and not torch.isnan(norm):
r_at = self.alpha * param.grad / norm
param.data.add_(r_at)
param.data = self.project(name, param.data)
def restore(self, emb_name='word_embeddings'):
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.emb_backup
param.data = self.emb_backup[name]
self.emb_backup = {}
def project(self, param_name, param_data):
r = param_data - self.emb_backup[param_name]
if torch.norm(r) > self.eps:
r = self.eps * r / torch.norm(r)
return self.emb_backup[param_name] + r
def backup_grad(self):
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None:
self.grad_backup[name] = param.grad.clone()
def restore_grad(self):
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None:
param.grad = self.grad_backup[name]

src/utils/dataset_utils.py Normal file

@ -0,0 +1,53 @@
import torch
from torch.utils.data import Dataset
class NERDataset(Dataset):
def __init__(self, task_type, features, mode, **kwargs):
self.nums = len(features)
self.token_ids = [torch.tensor(example.token_ids).long() for example in features]
self.attention_masks = [torch.tensor(example.attention_masks).float() for example in features]
self.token_type_ids = [torch.tensor(example.token_type_ids).long() for example in features]
self.labels = None
self.start_ids, self.end_ids = None, None
self.ent_type = None
self.pseudo = None
if mode == 'train':
self.pseudo = [torch.tensor(example.pseudo).long() for example in features]
if task_type == 'crf':
self.labels = [torch.tensor(example.labels) for example in features]
else:
self.start_ids = [torch.tensor(example.start_ids).long() for example in features]
self.end_ids = [torch.tensor(example.end_ids).long() for example in features]
if kwargs.pop('use_type_embed', False):
self.ent_type = [torch.tensor(example.ent_type) for example in features]
def __len__(self):
return self.nums
def __getitem__(self, index):
data = {'token_ids': self.token_ids[index],
'attention_masks': self.attention_masks[index],
'token_type_ids': self.token_type_ids[index]}
if self.ent_type is not None:
data['ent_type'] = self.ent_type[index]
if self.labels is not None:
data['labels'] = self.labels[index]
if self.pseudo is not None:
data['pseudo'] = self.pseudo[index]
if self.start_ids is not None:
data['start_ids'] = self.start_ids[index]
data['end_ids'] = self.end_ids[index]
return data

286
src/utils/evaluator.py Normal file

@ -0,0 +1,286 @@
import torch
import logging
import numpy as np
from collections import defaultdict
from src.preprocess.processor import ENTITY_TYPES
logger = logging.getLogger(__name__)
def get_base_out(model, loader, device):
"""
the forward pass is identical for every task, so wrap it here
"""
model.eval()
with torch.no_grad():
for idx, _batch in enumerate(loader):
for key in _batch.keys():
_batch[key] = _batch[key].to(device)
tmp_out = model(**_batch)
yield tmp_out
def crf_decode(decode_tokens, raw_text, id2ent):
"""
CRF decoding: turn a predicted BIOES tag sequence into entities
"""
predict_entities = {}
decode_tokens = decode_tokens[1:-1]  # drop the CLS / SEP tokens
index_ = 0
while index_ < len(decode_tokens):
token_label = id2ent[decode_tokens[index_]].split('-')
if token_label[0].startswith('S'):
token_type = token_label[1]
tmp_ent = raw_text[index_]
if token_type not in predict_entities:
predict_entities[token_type] = [(tmp_ent, index_)]
else:
predict_entities[token_type].append((tmp_ent, int(index_)))
index_ += 1
elif token_label[0].startswith('B'):
token_type = token_label[1]
start_index = index_
index_ += 1
while index_ < len(decode_tokens):
temp_token_label = id2ent[decode_tokens[index_]].split('-')
if temp_token_label[0].startswith('I') and token_type == temp_token_label[1]:
index_ += 1
elif temp_token_label[0].startswith('E') and token_type == temp_token_label[1]:
end_index = index_
index_ += 1
tmp_ent = raw_text[start_index: end_index + 1]
if token_type not in predict_entities:
predict_entities[token_type] = [(tmp_ent, start_index)]
else:
predict_entities[token_type].append((tmp_ent, int(start_index)))
break
else:
break
else:
index_ += 1
return predict_entities
# strict decoding baseline
def span_decode(start_logits, end_logits, raw_text, id2ent):
predict_entities = defaultdict(list)
start_pred = np.argmax(start_logits, -1)
end_pred = np.argmax(end_logits, -1)
for i, s_type in enumerate(start_pred):
if s_type == 0:
continue
for j, e_type in enumerate(end_pred[i:]):
if s_type == e_type:
tmp_ent = raw_text[i:i + j + 1]
predict_entities[id2ent[s_type]].append((tmp_ent, i))
break
return predict_entities
# strict decoding baseline
def mrc_decode(start_logits, end_logits, raw_text):
predict_entities = []
start_pred = np.argmax(start_logits, -1)
end_pred = np.argmax(end_logits, -1)
for i, s_type in enumerate(start_pred):
if s_type == 0:
continue
for j, e_type in enumerate(end_pred[i:]):
if s_type == e_type:
tmp_ent = raw_text[i:i+j+1]
predict_entities.append((tmp_ent, i))
break
return predict_entities
def calculate_metric(gt, predict):
"""
compute tp / fp / fn
"""
tp, fp, fn = 0, 0, 0
for entity_predict in predict:
flag = 0
for entity_gt in gt:
if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]:
flag = 1
tp += 1
break
if flag == 0:
fp += 1
fn = len(gt) - tp
return np.array([tp, fp, fn])
def get_p_r_f(tp, fp, fn):
p = tp / (tp + fp) if tp + fp != 0 else 0
r = tp / (tp + fn) if tp + fn != 0 else 0
f1 = 2 * p * r / (p + r) if p + r != 0 else 0
return np.array([p, r, f1])
def crf_evaluation(model, dev_info, device, ent2id):
dev_loader, (dev_callback_info, type_weight) = dev_info
pred_tokens = []
for tmp_pred in get_base_out(model, dev_loader, device):
pred_tokens.extend(tmp_pred[0])
assert len(pred_tokens) == len(dev_callback_info)
id2ent = {ent2id[key]: key for key in ent2id.keys()}
role_metric = np.zeros([13, 3])
mirco_metrics = np.zeros(3)
for tmp_tokens, tmp_callback in zip(pred_tokens, dev_callback_info):
text, gt_entities = tmp_callback
tmp_metric = np.zeros([13, 3])
pred_entities = crf_decode(tmp_tokens, text, id2ent)
for idx, _type in enumerate(ENTITY_TYPES):
if _type not in pred_entities:
pred_entities[_type] = []
tmp_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type])
role_metric += tmp_metric
for idx, _type in enumerate(ENTITY_TYPES):
temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])
mirco_metrics += temp_metric * type_weight[_type]
metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \
f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'
return metric_str, mirco_metrics[2]
def span_evaluation(model, dev_info, device, ent2id):
dev_loader, (dev_callback_info, type_weight) = dev_info
start_logits, end_logits = None, None
model.eval()
for tmp_pred in get_base_out(model, dev_loader, device):
tmp_start_logits = tmp_pred[0].cpu().numpy()
tmp_end_logits = tmp_pred[1].cpu().numpy()
if start_logits is None:
start_logits = tmp_start_logits
end_logits = tmp_end_logits
else:
start_logits = np.append(start_logits, tmp_start_logits, axis=0)
end_logits = np.append(end_logits, tmp_end_logits, axis=0)
assert len(start_logits) == len(end_logits) == len(dev_callback_info)
role_metric = np.zeros([13, 3])
mirco_metrics = np.zeros(3)
id2ent = {ent2id[key]: key for key in ent2id.keys()}
for tmp_start_logits, tmp_end_logits, tmp_callback \
in zip(start_logits, end_logits, dev_callback_info):
text, gt_entities = tmp_callback
tmp_start_logits = tmp_start_logits[1:1 + len(text)]
tmp_end_logits = tmp_end_logits[1:1 + len(text)]
pred_entities = span_decode(tmp_start_logits, tmp_end_logits, text, id2ent)
for idx, _type in enumerate(ENTITY_TYPES):
if _type not in pred_entities:
pred_entities[_type] = []
role_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type])
for idx, _type in enumerate(ENTITY_TYPES):
temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])
mirco_metrics += temp_metric * type_weight[_type]
metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \
f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'
return metric_str, mirco_metrics[2]
def mrc_evaluation(model, dev_info, device):
dev_loader, (dev_callback_info, type_weight) = dev_info
start_logits, end_logits = None, None
model.eval()
for tmp_pred in get_base_out(model, dev_loader, device):
tmp_start_logits = tmp_pred[0].cpu().numpy()
tmp_end_logits = tmp_pred[1].cpu().numpy()
if start_logits is None:
start_logits = tmp_start_logits
end_logits = tmp_end_logits
else:
start_logits = np.append(start_logits, tmp_start_logits, axis=0)
end_logits = np.append(end_logits, tmp_end_logits, axis=0)
assert len(start_logits) == len(end_logits) == len(dev_callback_info)
role_metric = np.zeros([13, 3])
mirco_metrics = np.zeros(3)
id2ent = {x: i for i, x in enumerate(ENTITY_TYPES)}
for tmp_start_logits, tmp_end_logits, tmp_callback \
in zip(start_logits, end_logits, dev_callback_info):
text, text_offset, ent_type, gt_entities = tmp_callback
tmp_start_logits = tmp_start_logits[text_offset:text_offset+len(text)]
tmp_end_logits = tmp_end_logits[text_offset:text_offset+len(text)]
pred_entities = mrc_decode(tmp_start_logits, tmp_end_logits, text)
role_metric[id2ent[ent_type]] += calculate_metric(gt_entities, pred_entities)
for idx, _type in enumerate(ENTITY_TYPES):
temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])
mirco_metrics += temp_metric * type_weight[_type]
metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \
f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'
return metric_str, mirco_metrics[2]

src/utils/functions_utils.py Normal file

@ -0,0 +1,159 @@
import os
import copy
import torch
import random
import numpy as np
from collections import defaultdict
from datetime import timedelta
import time
import logging
logger = logging.getLogger(__name__)
def get_time_dif(start_time):
"""
Get the elapsed time since start_time
:param start_time:
:return:
"""
end_time = time.time()
time_dif = end_time - start_time
return timedelta(seconds=int(round(time_dif)))
def set_seed(seed):
"""
Set the random seed for reproducibility
:param seed:
:return:
"""
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
def load_model_and_parallel(model, gpu_ids, ckpt_path=None, strict=True):
"""
Load the model & move it to the GPU(s) (single card / multi-card)
"""
gpu_ids = gpu_ids.split(',')
# set the device to the first specified cuda device
device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0])
if ckpt_path is not None:
logger.info(f'Load ckpt from {ckpt_path}')
model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')), strict=strict)
model.to(device)
if len(gpu_ids) > 1:
logger.info(f'Use multi gpus in: {gpu_ids}')
gpu_ids = [int(x) for x in gpu_ids]
model = torch.nn.DataParallel(model, device_ids=gpu_ids)
else:
logger.info(f'Use single gpu in: {gpu_ids}')
return model, device
def get_model_path_list(base_dir):
"""
Collect the paths of all model.pt files under the given directory
"""
model_lists = []
for root, dirs, files in os.walk(base_dir):
for _file in files:
if 'model.pt' == _file:
model_lists.append(os.path.join(root, _file))
model_lists = sorted(model_lists,
key=lambda x: (x.split('/')[-3], int(x.split('/')[-2].split('-')[-1])))
return model_lists
def swa(model, model_dir, swa_start=1):
"""
SWA (stochastic weight averaging) over saved checkpoints; typically applied only after training has stabilized
"""
model_path_list = get_model_path_list(model_dir)
assert 1 <= swa_start < len(model_path_list) - 1, \
f'Using swa, swa_start should be >= 1 and smaller than {len(model_path_list) - 1}'
swa_model = copy.deepcopy(model)
swa_n = 0.
with torch.no_grad():
for _ckpt in model_path_list[swa_start:]:
logger.info(f'Load model from {_ckpt}')
model.load_state_dict(torch.load(_ckpt, map_location=torch.device('cpu')))
tmp_para_dict = dict(model.named_parameters())
alpha = 1. / (swa_n + 1.)
for name, para in swa_model.named_parameters():
para.copy_(tmp_para_dict[name].data.clone() * alpha + para.data.clone() * (1. - alpha))
swa_n += 1
# name the SWA checkpoint "checkpoint-100000" to avoid clashing with real training steps
swa_model_dir = os.path.join(model_dir, f'checkpoint-100000')
if not os.path.exists(swa_model_dir):
os.mkdir(swa_model_dir)
logger.info(f'Save swa model in: {swa_model_dir}')
swa_model_path = os.path.join(swa_model_dir, 'model.pt')
torch.save(swa_model.state_dict(), swa_model_path)
return swa_model
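# Note (added for clarity): the update above is an incremental running mean --
# alpha = 1 / (swa_n + 1), so after averaging k checkpoints starting from
# swa_start, every checkpoint contributes an equal weight of 1 / k.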
def vote(entities_list, threshold=0.9):
"""
Entity-level voting over (entity_type, entity_start, entity_end, entity_text) tuples
:param entities_list: entities predicted for one document by all models
:param threshold: an entity is kept only if at least this fraction of the models predict it
:return:
"""
threshold_nums = int(len(entities_list)*threshold)
entities_dict = defaultdict(int)
entities = defaultdict(list)
for _entities in entities_list:
for _type in _entities:
for _ent in _entities[_type]:
entities_dict[(_type, _ent[0], _ent[1])] += 1
for key in entities_dict:
if entities_dict[key] >= threshold_nums:
entities[key[0]].append((key[1], key[2]))
return entities
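# Illustrative usage sketch (hypothetical data, assuming each per-type entity is a
# (start, end, ...) tuple as produced by the decode helpers):
# preds = [{'DRUG': [(0, 3)]}, {'DRUG': [(0, 3)]}, {'DRUG': []}]
# vote(preds, threshold=0.9)
# -> {'DRUG': [(0, 3)]}   # kept because 2 >= int(3 * 0.9) = 2 models agree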
def ensemble_vote(entities_list, threshold=0.9):
"""
Voting for ensemble models
Entity-level voting over (entity_type, entity_start, entity_end, entity_text) tuples
"""
threshold_nums = int(len(entities_list)*threshold)
entities_dict = defaultdict(int)
entities = defaultdict(list)
for _entities in entities_list:
for _id in _entities:
for _ent in _entities[_id]:
entities_dict[(_id, ) + _ent] += 1
for key in entities_dict:
if entities_dict[key] >= threshold_nums:
entities[key[0]].append(key[1:])
return entities

646
src/utils/model_utils.py Normal file

@ -0,0 +1,646 @@
import os
import math
import torch
import torch.nn as nn
from torchcrf import CRF
from itertools import repeat
from transformers import BertModel
from src.utils.functions_utils import vote
from src.utils.evaluator import crf_decode, span_decode
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothingCrossEntropy, self).__init__()
self.eps = eps
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, output, target):
c = output.size()[-1]
log_pred = torch.log_softmax(output, dim=-1)
if self.reduction == 'sum':
loss = -log_pred.sum()
else:
loss = -log_pred.sum(dim=-1)
if self.reduction == 'mean':
loss = loss.mean()
return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target,
reduction=self.reduction,
ignore_index=self.ignore_index)
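# For reference (standard label smoothing): with C classes and smoothing eps,
# loss = eps / C * sum_j(-log p_j) + (1 - eps) * (-log p_target),
# i.e. a mixture of the uniform-label cross entropy and the usual NLL term,
# which is exactly what the return statement above computes.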
class FocalLoss(nn.Module):
"""Multi-class Focal loss implementation"""
def __init__(self, gamma=2, weight=None, reduction='mean', ignore_index=-100):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, input, target):
"""
input: [N, C]
target: [N, ]
"""
log_pt = torch.log_softmax(input, dim=1)
pt = torch.exp(log_pt)
log_pt = (1 - pt) ** self.gamma * log_pt
loss = torch.nn.functional.nll_loss(log_pt, target, self.weight, reduction=self.reduction, ignore_index=self.ignore_index)
return loss
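# For reference: this implements FL(p_t) = -(1 - p_t)^gamma * log(p_t)
# (optionally class-weighted), which down-weights easy, well-classified
# tokens so training focuses on the rarer / harder labels.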
class SpatialDropout(nn.Module):
"""
Spatial dropout: drop entire embedding dimensions of the character-level vectors
"""
def __init__(self, drop_prob):
super(SpatialDropout, self).__init__()
self.drop_prob = drop_prob
@staticmethod
def _make_noise(input):
return input.new().resize_(input.size(0), *repeat(1, input.dim() - 2), input.size(2))
def forward(self, inputs):
output = inputs.clone()
if not self.training or self.drop_prob == 0:
return inputs
else:
noise = self._make_noise(inputs)
if self.drop_prob == 1:
noise.fill_(0)
else:
noise.bernoulli_(1 - self.drop_prob).div_(1 - self.drop_prob)
noise = noise.expand_as(inputs)
output.mul_(noise)
return output
class ConditionalLayerNorm(nn.Module):
def __init__(self,
normalized_shape,
cond_shape,
eps=1e-12):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.Tensor(normalized_shape))
self.bias = nn.Parameter(torch.Tensor(normalized_shape))
self.weight_dense = nn.Linear(cond_shape, normalized_shape, bias=False)
self.bias_dense = nn.Linear(cond_shape, normalized_shape, bias=False)
self.reset_weight_and_bias()
def reset_weight_and_bias(self):
"""
This initialization makes the conditional layer norm act as a plain LayerNorm at the start of training
"""
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
nn.init.zeros_(self.weight_dense.weight)
nn.init.zeros_(self.bias_dense.weight)
def forward(self, inputs, cond=None):
assert cond is not None, 'A conditional tensor must be provided when using conditional layer norm'
cond = torch.unsqueeze(cond, 1) # (b, 1, h*2)
weight = self.weight_dense(cond) + self.weight # (b, 1, h)
bias = self.bias_dense(cond) + self.bias # (b, 1, h)
mean = torch.mean(inputs, dim=-1, keepdim=True) # b, s, 1
outputs = inputs - mean # (b, s, h)
variance = torch.mean(outputs ** 2, dim=-1, keepdim=True)
std = torch.sqrt(variance + self.eps) # (b, s, 1)
outputs = outputs / std # (b, s, h)
outputs = outputs * weight + bias
return outputs
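# In short: a standard layer norm over the hidden dimension, but gamma / beta are
# shifted by linear projections of the condition vector (the entity-type embedding
# in the MRC model). Because those projections are zero-initialized, the module
# starts out as a plain LayerNorm.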
class BaseModel(nn.Module):
def __init__(self,
bert_dir,
dropout_prob):
super(BaseModel, self).__init__()
config_path = os.path.join(bert_dir, 'config.json')
assert os.path.exists(bert_dir) and os.path.exists(config_path), \
'pretrained bert file does not exist'
self.bert_module = BertModel.from_pretrained(bert_dir,
output_hidden_states=True,
hidden_dropout_prob=dropout_prob)
self.bert_config = self.bert_module.config
@staticmethod
def _init_weights(blocks, **kwargs):
"""
Parameter initialization: initialize Linear / Embedding / LayerNorm the same way as BERT
"""
for block in blocks:
for module in block.modules():
if isinstance(module, nn.Linear):
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0, std=kwargs.pop('initializer_range', 0.02))
elif isinstance(module, nn.LayerNorm):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
# baseline
class CRFModel(BaseModel):
def __init__(self,
bert_dir,
num_tags,
dropout_prob=0.1,
**kwargs):
super(CRFModel, self).__init__(bert_dir=bert_dir, dropout_prob=dropout_prob)
out_dims = self.bert_config.hidden_size
mid_linear_dims = kwargs.pop('mid_linear_dims', 128)
self.mid_linear = nn.Sequential(
nn.Linear(out_dims, mid_linear_dims),
nn.ReLU(),
nn.Dropout(dropout_prob)
)
out_dims = mid_linear_dims
self.classifier = nn.Linear(out_dims, num_tags)
self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.loss_weight.data.fill_(-0.2)
self.crf_module = CRF(num_tags=num_tags, batch_first=True)
init_blocks = [self.mid_linear, self.classifier]
self._init_weights(init_blocks, initializer_range=self.bert_config.initializer_range)
def forward(self,
token_ids,
attention_masks,
token_type_ids,
labels=None,
pseudo=None):
bert_outputs = self.bert_module(
input_ids=token_ids,
attention_mask=attention_masks,
token_type_ids=token_type_ids
)
# regular path: token-level sequence output from BERT
seq_out = bert_outputs[0]
seq_out = self.mid_linear(seq_out)
emissions = self.classifier(seq_out)
if labels is not None:
if pseudo is not None:
# (batch,)
tokens_loss = -1. * self.crf_module(emissions=emissions,
tags=labels.long(),
mask=attention_masks.byte(),
reduction='none')
# nums of pseudo data
pseudo_nums = pseudo.sum().item()
total_nums = token_ids.shape[0]
# learning parameter
rate = torch.sigmoid(self.loss_weight)
if pseudo_nums == 0:
loss_0 = tokens_loss.mean()
loss_1 = (rate*pseudo*tokens_loss).sum()
else:
if total_nums == pseudo_nums:
loss_0 = 0
else:
loss_0 = ((1 - rate) * (1 - pseudo) * tokens_loss).sum() / (total_nums - pseudo_nums)
loss_1 = (rate*pseudo*tokens_loss).sum() / pseudo_nums
tokens_loss = loss_0 + loss_1
else:
tokens_loss = -1. * self.crf_module(emissions=emissions,
tags=labels.long(),
mask=attention_masks.byte(),
reduction='mean')
out = (tokens_loss,)
else:
tokens_out = self.crf_module.decode(emissions=emissions, mask=attention_masks.byte())
out = (tokens_out, emissions)
return out
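# Note on the pseudo-label weighting above: rate = sigmoid(loss_weight) is a
# learnable scalar; with the initial value loss_weight = -0.2, rate starts at
# sigmoid(-0.2) ~= 0.45, so pseudo-labeled sentences initially contribute slightly
# less than real annotations, and the balance is adjusted during training.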
class SpanModel(BaseModel):
def __init__(self,
bert_dir,
num_tags,
dropout_prob=0.1,
loss_type='ce',
**kwargs):
"""
Span-based NER head: predict entity start / end positions together with the entity type
:param loss_type: train loss type in ['ce', 'ls_ce', 'focal']
"""
super(SpanModel, self).__init__(bert_dir, dropout_prob=dropout_prob)
out_dims = self.bert_config.hidden_size
mid_linear_dims = kwargs.pop('mid_linear_dims', 128)
self.num_tags = num_tags
self.mid_linear = nn.Sequential(
nn.Linear(out_dims, mid_linear_dims),
nn.ReLU(),
nn.Dropout(dropout_prob)
)
out_dims = mid_linear_dims
self.start_fc = nn.Linear(out_dims, num_tags)
self.end_fc = nn.Linear(out_dims, num_tags)
reduction = 'none'
if loss_type == 'ce':
self.criterion = nn.CrossEntropyLoss(reduction=reduction)
elif loss_type == 'ls_ce':
self.criterion = LabelSmoothingCrossEntropy(reduction=reduction)
else:
self.criterion = FocalLoss(reduction=reduction)
self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.loss_weight.data.fill_(-0.2)
init_blocks = [self.mid_linear, self.start_fc, self.end_fc]
self._init_weights(init_blocks)
def forward(self,
token_ids,
attention_masks,
token_type_ids,
start_ids=None,
end_ids=None,
pseudo=None):
bert_outputs = self.bert_module(
input_ids=token_ids,
attention_mask=attention_masks,
token_type_ids=token_type_ids
)
seq_out = bert_outputs[0]
seq_out = self.mid_linear(seq_out)
start_logits = self.start_fc(seq_out)
end_logits = self.end_fc(seq_out)
out = (start_logits, end_logits, )
if start_ids is not None and end_ids is not None and self.training:
start_logits = start_logits.view(-1, self.num_tags)
end_logits = end_logits.view(-1, self.num_tags)
# drop labels at padding positions and compute the loss on real tokens only
active_loss = attention_masks.view(-1) == 1
active_start_logits = start_logits[active_loss]
active_end_logits = end_logits[active_loss]
active_start_labels = start_ids.view(-1)[active_loss]
active_end_labels = end_ids.view(-1)[active_loss]
if pseudo is not None:
# (batch,)
start_loss = self.criterion(start_logits, start_ids.view(-1)).view(-1, 512).mean(dim=-1)
end_loss = self.criterion(end_logits, end_ids.view(-1)).view(-1, 512).mean(dim=-1)
# nums of pseudo data
pseudo_nums = pseudo.sum().item()
total_nums = token_ids.shape[0]
# learning parameter
rate = torch.sigmoid(self.loss_weight)
if pseudo_nums == 0:
start_loss = start_loss.mean()
end_loss = end_loss.mean()
else:
if total_nums == pseudo_nums:
start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums
end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums
else:
start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums \
+ ((1 - rate) * (1 - pseudo) * start_loss).sum() / (total_nums - pseudo_nums)
end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums \
+ ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums)
else:
start_loss = self.criterion(active_start_logits, active_start_labels)
end_loss = self.criterion(active_end_logits, active_end_labels)
loss = start_loss + end_loss
out = (loss, ) + out
return out
class MRCModel(BaseModel):
def __init__(self,
bert_dir,
dropout_prob=0.1,
use_type_embed=False,
loss_type='ce',
**kwargs):
"""
MRC-style NER head: predict answer start / end positions for a given entity-type query
:param use_type_embed: type embedding for the sentence
:param loss_type: train loss type in ['ce', 'ls_ce', 'focal']
"""
super(MRCModel, self).__init__(bert_dir, dropout_prob=dropout_prob)
self.use_type_embed = use_type_embed
self.use_smooth = loss_type
out_dims = self.bert_config.hidden_size
if self.use_type_embed:
embed_dims = kwargs.pop('predicate_embed_dims', self.bert_config.hidden_size)
self.type_embedding = nn.Embedding(13, embed_dims)
self.conditional_layer_norm = ConditionalLayerNorm(out_dims, embed_dims,
eps=self.bert_config.layer_norm_eps)
mid_linear_dims = kwargs.pop('mid_linear_dims', 128)
self.mid_linear = nn.Sequential(
nn.Linear(out_dims, mid_linear_dims),
nn.ReLU(),
nn.Dropout(dropout_prob)
)
out_dims = mid_linear_dims
self.start_fc = nn.Linear(out_dims, 2)
self.end_fc = nn.Linear(out_dims, 2)
reduction = 'none'
if loss_type == 'ce':
self.criterion = nn.CrossEntropyLoss(reduction=reduction)
elif loss_type == 'ls_ce':
self.criterion = LabelSmoothingCrossEntropy(reduction=reduction)
else:
self.criterion = FocalLoss(reduction=reduction)
self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.loss_weight.data.fill_(-0.2)
init_blocks = [self.mid_linear, self.start_fc, self.end_fc]
if self.use_type_embed:
init_blocks.append(self.type_embedding)
self._init_weights(init_blocks)
def forward(self,
token_ids,
attention_masks,
token_type_ids,
ent_type=None,
start_ids=None,
end_ids=None,
pseudo=None):
bert_outputs = self.bert_module(
input_ids=token_ids,
attention_mask=attention_masks,
token_type_ids=token_type_ids
)
seq_out = bert_outputs[0]
if self.use_type_embed:
assert ent_type is not None, \
'When using type embedding, ent_type must be provided'
predicate_feature = self.type_embedding(ent_type)
seq_out = self.conditional_layer_norm(seq_out, predicate_feature)
seq_out = self.mid_linear(seq_out)
start_logits = self.start_fc(seq_out)
end_logits = self.end_fc(seq_out)
out = (start_logits, end_logits, )
if start_ids is not None and end_ids is not None:
start_logits = start_logits.view(-1, 2)
end_logits = end_logits.view(-1, 2)
# drop labels of text_a (the query) and padding, compute the loss on the doc tokens only
active_loss = token_type_ids.view(-1) == 1
active_start_logits = start_logits[active_loss]
active_end_logits = end_logits[active_loss]
active_start_labels = start_ids.view(-1)[active_loss]
active_end_labels = end_ids.view(-1)[active_loss]
if pseudo is not None:
# (batch,)
start_loss = self.criterion(start_logits, start_ids.view(-1)).view(-1, 512).mean(dim=-1)
end_loss = self.criterion(end_logits, end_ids.view(-1)).view(-1, 512).mean(dim=-1)
# nums of pseudo data
pseudo_nums = pseudo.sum().item()
total_nums = token_ids.shape[0]
# learning parameter
rate = torch.sigmoid(self.loss_weight)
if pseudo_nums == 0:
start_loss = start_loss.mean()
end_loss = end_loss.mean()
else:
if total_nums == pseudo_nums:
start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums
end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums
else:
start_loss = (rate*pseudo*start_loss).sum() / pseudo_nums \
+ ((1 - rate) * (1 - pseudo) * start_loss).sum() / (total_nums - pseudo_nums)
end_loss = (rate*pseudo*end_loss).sum() / pseudo_nums \
+ ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums)
else:
start_loss = self.criterion(active_start_logits, active_start_labels)
end_loss = self.criterion(active_end_logits, active_end_labels)
loss = start_loss + end_loss
out = (loss, ) + out
return out
class EnsembleCRFModel:
def __init__(self, model_path_list, bert_dir_list, num_tags, device, lamb=1/3):
self.models = []
self.crf_module = CRF(num_tags=num_tags, batch_first=True)
self.lamb = lamb
for idx, _path in enumerate(model_path_list):
print(f'Load model from {_path}')
print(f'Load model type: {bert_dir_list[0]}')
model = CRFModel(bert_dir=bert_dir_list[0], num_tags=num_tags)
model.load_state_dict(torch.load(_path, map_location=torch.device('cpu')))
model.eval()
model.to(device)
self.models.append(model)
if idx == 0:
print(f'Load CRF weight from {_path}')
self.crf_module.load_state_dict(model.crf_module.state_dict())
self.crf_module.to(device)
def weight(self, t):
"""
Weighted fusion based on Newton's law of cooling
"""
return math.exp(-self.lamb*t)
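# Worked example: with the default lamb = 1/3, the models loaded first get
# exponentially larger weights (normalized by their sum in predict()):
# t = 0 -> exp(0)     = 1.000
# t = 1 -> exp(-1/3) ~= 0.717
# t = 2 -> exp(-2/3) ~= 0.513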
def predict(self, model_inputs):
weight_sum = 0.
logits = None
attention_masks = model_inputs['attention_masks']
for idx, model in enumerate(self.models):
# probability fusion weighted by Newton's law of cooling
weight = self.weight(idx)
# (alternative) simple probability averaging
# weight = 1 / len(self.models)
tmp_logits = model(**model_inputs)[1] * weight
weight_sum += weight
if logits is None:
logits = tmp_logits
else:
logits += tmp_logits
logits = logits / weight_sum
tokens_out = self.crf_module.decode(emissions=logits, mask=attention_masks.byte())
return tokens_out
def vote_entities(self, model_inputs, sent, id2ent, threshold):
entities_ls = []
for idx, model in enumerate(self.models):
tmp_tokens = model(**model_inputs)[0][0]
tmp_entities = crf_decode(tmp_tokens, sent, id2ent)
entities_ls.append(tmp_entities)
return vote(entities_ls, threshold)
class EnsembleSpanModel:
def __init__(self, model_path_list, bert_dir_list, num_tags, device):
self.models = []
for idx, _path in enumerate(model_path_list):
print(f'Load model from {_path}')
print(f'Load model type: {bert_dir_list[0]}')
model = SpanModel(bert_dir=bert_dir_list[0], num_tags=num_tags)
model.load_state_dict(torch.load(_path, map_location=torch.device('cpu')))
model.eval()
model.to(device)
self.models.append(model)
def predict(self, model_inputs):
start_logits, end_logits = None, None
for idx, model in enumerate(self.models):
# simple probability averaging across models
weight = 1 / len(self.models)
tmp_start_logits, tmp_end_logits = model(**model_inputs)
tmp_start_logits = tmp_start_logits * weight
tmp_end_logits = tmp_end_logits * weight
if start_logits is None:
start_logits = tmp_start_logits
end_logits = tmp_end_logits
else:
start_logits += tmp_start_logits
end_logits += tmp_end_logits
return start_logits, end_logits
def vote_entities(self, model_inputs, sent, id2ent, threshold):
entities_ls = []
for idx, model in enumerate(self.models):
start_logits, end_logits = model(**model_inputs)
start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
entities_ls.append(decode_entities)
return vote(entities_ls, threshold)
def build_model(task_type, bert_dir, **kwargs):
assert task_type in ['crf', 'span', 'mrc']
if task_type == 'crf':
model = CRFModel(bert_dir=bert_dir,
num_tags=kwargs.pop('num_tags'),
dropout_prob=kwargs.pop('dropout_prob', 0.1))
elif task_type == 'mrc':
model = MRCModel(bert_dir=bert_dir,
dropout_prob=kwargs.pop('dropout_prob', 0.1),
use_type_embed=kwargs.pop('use_type_embed'),
loss_type=kwargs.pop('loss_type', 'ce'))
else:
model = SpanModel(bert_dir=bert_dir,
num_tags=kwargs.pop('num_tags'),
dropout_prob=kwargs.pop('dropout_prob', 0.1),
loss_type=kwargs.pop('loss_type', 'ce'))
return model
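# Hypothetical usage sketch (the bert_dir path and tag counts below are
# placeholders, not values taken from this repo's configs):
# crf_model = build_model('crf', bert_dir='../bert/uer_large', num_tags=53)
# span_model = build_model('span', bert_dir='../bert/uer_large', num_tags=14,
#                          loss_type='ls_ce')
# mrc_model = build_model('mrc', bert_dir='../bert/uer_large',
#                         use_type_embed=False, loss_type='ls_ce')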

98
src/utils/options.py Normal file

@ -0,0 +1,98 @@
import argparse
class Args:
@staticmethod
def parse():
parser = argparse.ArgumentParser()
return parser
@staticmethod
def initialize(parser: argparse.ArgumentParser):
# args for path
parser.add_argument('--raw_data_dir', default='./data/raw_data',
help='the data dir of raw data')
parser.add_argument('--mid_data_dir', default='./data/mid_data',
help='the mid data dir')
parser.add_argument('--output_dir', default='./out/',
help='the output dir for model checkpoints')
parser.add_argument('--bert_dir', default='../bert/torch_roberta_wwm',
help='bert dir for ernie / roberta-wwm / uer')
parser.add_argument('--bert_type', default='roberta_wwm',
help='roberta_wwm / ernie_1 / uer_large')
parser.add_argument('--task_type', default='crf',
help='crf / span / mrc')
parser.add_argument('--loss_type', default='ls_ce',
help='loss type for span bert')
parser.add_argument('--use_type_embed', default=False, action='store_true',
help='whether to use type embedding (conditional layer norm) in the mrc model')
parser.add_argument('--use_fp16', default=False, action='store_true',
help='whether to use fp16 during training')
# other args
parser.add_argument('--seed', type=int, default=123, help='random seed')
parser.add_argument('--gpu_ids', type=str, default='0',
help='gpu ids to use, -1 for cpu, "0,1" for multi gpu')
parser.add_argument('--mode', type=str, default='train',
help='train / stack')
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--eval_batch_size', default=64, type=int)
parser.add_argument('--swa_start', default=3, type=int,
help='the epoch when swa start')
# train args
parser.add_argument('--train_epochs', default=10, type=int,
help='Max training epoch')
parser.add_argument('--dropout_prob', default=0.1, type=float,
help='drop out probability')
parser.add_argument('--lr', default=2e-5, type=float,
help='learning rate for the bert module')
parser.add_argument('--other_lr', default=2e-3, type=float,
help='learning rate for the module except bert')
parser.add_argument('--max_grad_norm', default=1.0, type=float,
help='max grad clip')
parser.add_argument('--warmup_proportion', default=0.1, type=float)
parser.add_argument('--weight_decay', default=0.00, type=float)
parser.add_argument('--adam_epsilon', default=1e-8, type=float)
parser.add_argument('--train_batch_size', default=24, type=int)
parser.add_argument('--eval_model', default=True, action='store_true',
help='whether to eval model after training')
parser.add_argument('--attack_train', default='', type=str,
help='fgm / pgd attack train when training')
# test args
parser.add_argument('--version', default='v0', type=str,
help='submit version')
parser.add_argument('--submit_dir', default='./submit', type=str)
parser.add_argument('--ckpt_dir', default='', type=str)
return parser
def get_parser(self):
parser = self.parse()
parser = self.initialize(parser)
return parser.parse_args()
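# Hypothetical invocation sketch (assuming an entry script that wires
# Args().get_parser() into the data pipeline and trainer; the script name and
# flag values below are illustrative only):
# python train.py --task_type span --bert_type uer_large \
#     --loss_type ls_ce --attack_train pgd --use_fp16 --gpu_ids 0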

223
src/utils/trainer.py Normal file

@ -0,0 +1,223 @@
import os
import copy
import torch
import logging
from torch.cuda.amp import autocast as ac
from torch.utils.data import DataLoader, RandomSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from src.utils.attack_train_utils import FGM, PGD
from src.utils.functions_utils import load_model_and_parallel, swa
logger = logging.getLogger(__name__)
def save_model(opt, model, global_step):
output_dir = os.path.join(opt.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True)
# take care of model distributed / parallel training
model_to_save = (
model.module if hasattr(model, "module") else model
)
logger.info(f'Saving model & optimizer & scheduler checkpoint to {output_dir}')
torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt'))
def build_optimizer_and_scheduler(opt, model, t_total):
module = (
model.module if hasattr(model, "module") else model
)
# layer-wise (differential) learning rates
no_decay = ["bias", "LayerNorm.weight"]
model_param = list(module.named_parameters())
bert_param_optimizer = []
other_param_optimizer = []
for name, para in model_param:
space = name.split('.')
if space[0] == 'bert_module':
bert_param_optimizer.append((name, para))
else:
other_param_optimizer.append((name, para))
optimizer_grouped_parameters = [
# BERT module parameters
{"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
"weight_decay": opt.weight_decay, 'lr': opt.lr},
{"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
"weight_decay": 0.0, 'lr': opt.lr},
# other modules, with a different (larger) learning rate
{"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)],
"weight_decay": opt.weight_decay, 'lr': opt.other_lr},
{"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)],
"weight_decay": 0.0, 'lr': opt.other_lr},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=opt.lr, eps=opt.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=int(opt.warmup_proportion * t_total), num_training_steps=t_total
)
return optimizer, scheduler
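# Note: this realizes the differential learning rate described in the README --
# BERT parameters use opt.lr (e.g. 2e-5) while the task-specific heads use
# opt.other_lr (e.g. 2e-3), with linear warmup over the first
# warmup_proportion * t_total steps and linear decay afterwards.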
def train(opt, model, train_dataset):
swa_raw_model = copy.deepcopy(model)
train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(dataset=train_dataset,
batch_size=opt.train_batch_size,
sampler=train_sampler,
num_workers=0)
scaler = None
if opt.use_fp16:
scaler = torch.cuda.amp.GradScaler()
model, device = load_model_and_parallel(model, opt.gpu_ids)
use_n_gpus = False
if hasattr(model, "module"):
use_n_gpus = True
t_total = len(train_loader) * opt.train_epochs
optimizer, scheduler = build_optimizer_and_scheduler(opt, model, t_total)
# Train
logger.info("***** Running training *****")
logger.info(f" Num Examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {opt.train_epochs}")
logger.info(f" Total training batch size = {opt.train_batch_size}")
logger.info(f" Total optimization steps = {t_total}")
global_step = 0
model.zero_grad()
fgm, pgd = None, None
attack_train_mode = opt.attack_train.lower()
if attack_train_mode == 'fgm':
fgm = FGM(model=model)
elif attack_train_mode == 'pgd':
pgd = PGD(model=model)
pgd_k = 3
save_steps = t_total // opt.train_epochs
eval_steps = save_steps
logger.info(f'Save model in {save_steps} steps; Eval model in {eval_steps} steps')
log_loss_steps = 20
avg_loss = 0.
for epoch in range(opt.train_epochs):
for step, batch_data in enumerate(train_loader):
model.train()
for key in batch_data.keys():
batch_data[key] = batch_data[key].to(device)
if opt.use_fp16:
with ac():
loss = model(**batch_data)[0]
else:
loss = model(**batch_data)[0]
if use_n_gpus:
loss = loss.mean()
if opt.use_fp16:
scaler.scale(loss).backward()
else:
loss.backward()
if fgm is not None:
fgm.attack()
if opt.use_fp16:
with ac():
loss_adv = model(**batch_data)[0]
else:
loss_adv = model(**batch_data)[0]
if use_n_gpus:
loss_adv = loss_adv.mean()
if opt.use_fp16:
scaler.scale(loss_adv).backward()
else:
loss_adv.backward()
fgm.restore()
elif pgd is not None:
pgd.backup_grad()
for _t in range(pgd_k):
pgd.attack(is_first_attack=(_t == 0))
if _t != pgd_k - 1:
model.zero_grad()
else:
pgd.restore_grad()
if opt.use_fp16:
with ac():
loss_adv = model(**batch_data)[0]
else:
loss_adv = model(**batch_data)[0]
if use_n_gpus:
loss_adv = loss_adv.mean()
if opt.use_fp16:
scaler.scale(loss_adv).backward()
else:
loss_adv.backward()
pgd.restore()
if opt.use_fp16:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
# optimizer.step()
if opt.use_fp16:
scaler.step(optimizer)
scaler.update()
else:
optimizer.step()
scheduler.step()
model.zero_grad()
global_step += 1
if global_step % log_loss_steps == 0:
avg_loss /= log_loss_steps
logger.info('Step: %d / %d ----> total loss: %.5f' % (global_step, t_total, avg_loss))
avg_loss = 0.
else:
avg_loss += loss.item()
if global_step % save_steps == 0:
save_model(opt, model, global_step)
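# swa_raw_model was deep-copied before training, so the call below re-loads the
# saved per-epoch checkpoints (save_steps == one epoch of optimizer steps),
# averages their weights, and writes the result to checkpoint-100000/model.pt.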
swa(swa_raw_model, opt.output_dir, swa_start=opt.swa_start)
# clear cuda cache to avoid OOM
torch.cuda.empty_cache()
logger.info('Train done')