first commit
230
README.md
Normal file
@@ -0,0 +1,230 @@
# Chinese-DeepNER-Pytorch

## Open-source release of the first-place solution to the Tianchi Traditional Chinese Medicine (TCM) Drug Instruction Entity Recognition Challenge

### Once the official dataset is released, this project will be refactored and upgraded into the DeepNER project: less redundant code, better readability, a complete data-processing / training / validation / testing / deployment pipeline, detailed code comments and model descriptions, and a more general pretraining-based Chinese NER solution that works out of the box. Stars and issues are welcome.

(The code is built on PyTorch and Transformers and supports the latest PyTorch releases. The framework is highly **reusable, decoupled, and readable**, making it easy to adapt to other NLP tasks.)
# Competition Background

## Task Description

Artificial intelligence is accelerating innovation in traditional Chinese medicine (TCM). Information extraction from TCM texts is at the core of building a TCM knowledge graph, which in turn underpins downstream applications such as clinical decision support systems (CDSS). This NER challenge asks participants to extract key information from TCM drug instruction sheets, covering 13 entity types including drugs, drug ingredients, diseases, symptoms, and syndromes, in order to build a TCM drug knowledge base.

## Exploratory Data Analysis

The training data has three notable characteristics:
- Drug instruction sheets are mostly long documents

<img src="md_files/1.png" style="zoom:50%;" />

- Annotated samples are scarce in this medical domain

<img src="md_files/2.png" style="zoom:50%;" />

- The label distribution is highly imbalanced

<img src="md_files/3.png" style="zoom:50%;" />
# Core Approach

## Data Preprocessing

The instruction texts are first pre-cleaned and split. Pre-cleaning filters out invalid characters. To handle long documents, a two-level sentence-splitting strategy is used; because the resulting sentences may be too short, adjacent short sentences are then merged so that each merged chunk stays within the configured maximum length. The flow chart is shown below:

<img src="md_files/5.png" style="zoom: 33%;" />

In addition, an entity knowledge base is built from all annotated data and used as a domain prior dictionary.
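Below is a minimal sketch of the split-then-merge step described above. The repository's actual implementation lives in `src/preprocess/processor.py` (`cut_sentences_v1`, `cut_sentences_v2`, `cut_sent`); the punctuation rules here are simplified.

```python
import re

def split_then_merge(text, max_len=256):
    # First-level split on sentence-final punctuation; second-level split on
    # semicolons for pieces that are still too long (rules simplified here).
    pieces = []
    for part in re.split(r'(?<=[。!?])', text):
        if len(part) > max_len - 2:
            pieces.extend(p for p in re.split(r'(?<=[;;])', part) if p)
        elif part:
            pieces.append(part)

    # Greedily merge adjacent pieces so each chunk stays within max_len.
    chunks, buf = [], ''
    for piece in pieces:
        if len(buf) + len(piece) <= max_len - 2:
            buf += piece
        else:
            if buf:
                chunks.append(buf)
            buf = piece
    if buf:
        chunks.append(buf)

    assert ''.join(chunks) == text  # entity offsets can be re-based per chunk
    return chunks
```

Keeping `''.join(chunks) == text` is what allows the document-level label offsets to be shifted into each chunk without losing alignment.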
## Baseline: BERT-CRF

<img src="md_files/6.png" style="zoom: 33%;" />

- Baseline details

  - Pretrained model: UER-large (24 layers) [1]. UER continues pretraining a RoBERTa-wwm-style model on large, high-quality Chinese corpora and ranked first among single models on the CLUE tasks.

  - Differential learning rates: 2e-5 for the BERT layers, 2e-3 for the other layers.

  - Parameter initialization: the non-BERT modules are initialized the same way as BERT.

  - Weight averaging (SWA): the weights of the last few epochs are averaged to obtain a smoother and better-performing model (see the sketch after this list).

- Baseline bad-case analysis

<img src="md_files/7.png" style="zoom: 33%;" />
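A standalone sketch of the checkpoint-weight averaging described above, assuming checkpoints saved as plain `state_dict` files; the repository drives this through the `swa_start` argument (see `run.sh`), and the trainer code that applies it is not shown in this excerpt.

```python
import copy
import torch

def average_checkpoints(model, ckpt_paths):
    """Average the floating-point weights of the last few epoch checkpoints."""
    states = [torch.load(p, map_location='cpu') for p in ckpt_paths]
    avg_state = copy.deepcopy(model.state_dict())
    for key, value in avg_state.items():
        if value.is_floating_point():
            avg_state[key] = sum(s[key].float() for s in states) / len(states)
    model.load_state_dict(avg_state)  # non-float buffers are left untouched
    return model
```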
## Optimization 1: Adversarial Training

- Motivation: adversarial training is used to address the model's limited robustness and to improve generalization.

  - Adversarial training injects noise during training, which regularizes the parameters and improves robustness and generalization.

  - Fast Gradient Method (FGM): adds a perturbation to the embedding layer along the gradient direction (a usage sketch follows this list).

  - Projected Gradient Descent (PGD) [2]: applies the perturbation iteratively, projecting it back into a fixed-radius ball after each step.
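A minimal sketch of how FGM wraps one training step, using the `FGM` class that ships in `src/utils/attack_train_utils.py` (included later in this commit). The assumption that the model returns the loss as its first output is not confirmed by this excerpt.

```python
from src.utils.attack_train_utils import FGM

def train_step_with_fgm(model, batch, optimizer, fgm: FGM):
    """One training step with FGM (assumes model(**batch) returns the loss first)."""
    loss = model(**batch)[0]
    loss.backward()            # gradients w.r.t. the clean embeddings

    fgm.attack()               # perturb word embeddings along the gradient direction
    loss_adv = model(**batch)[0]
    loss_adv.backward()        # accumulate gradients from the perturbed forward pass
    fgm.restore()              # restore the original embedding weights

    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```

PGD is used the same way, except that `attack` / forward / backward are repeated for a few inner steps with the embedding and gradient backups restored in between.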
## Optimization 2: Mixed-Precision Training (FP16)

- Motivation: adversarial training reduces computational efficiency, so mixed precision is used to cut training time (a usage sketch follows this list).

- Mixed-precision training

  - Storage and multiplications are done in FP16 to save memory and speed up computation.

  - Accumulations are done in FP32 to avoid rounding errors.

- Loss scaling

  - The loss is scaled up by 2^k before backpropagation to prevent gradient underflow.

  - The weight gradients are scaled back down after backpropagation.
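A minimal sketch using `torch.cuda.amp`, which is available from PyTorch 1.6 (the version this project targets); the repository's own `--use_fp16` path lives in the trainer and may rely on a different library such as NVIDIA Apex, so treat this as illustrative only.

```python
import torch

def train_step_fp16(model, batch, optimizer, scaler: torch.cuda.amp.GradScaler):
    """One mixed-precision step: FP16 forward, dynamically scaled loss, FP32 update."""
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():        # run eligible ops in FP16
        loss = model(**batch)[0]           # assumption: the model returns the loss first
    scaler.scale(loss).backward()          # scale the loss (the 2^k factor) to avoid underflow
    scaler.step(optimizer)                 # unscale gradients; skip the step if they overflowed
    scaler.update()                        # adjust the scale factor for the next iteration
    return loss.item()
```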
## Optimization 3: Multi-Model Fusion

- Motivation: the baseline errors concentrate on ambiguous cases, so a multi-stage medical NER fusion system is used to resolve the ambiguity.

- Method: a multi-stage fusion system built from deliberately diversified models

  - Diversified model architectures: BERT-CRF, BERT-SPAN, and BERT-MRC

  - Diversified training data: different random seeds and different sentence-split lengths (256 and 512)

- Multi-stage fusion strategy
- Fusion model 1: BERT-SPAN

  - Replaces the CRF module with span pointers, which speeds up training.

  - Predicts entity start and end positions with a half-pointer, half-tagging scheme, assigning the entity type while tagging.

  - Uses strict decoding: among overlapping entities, only the one with the largest logit is kept, which favors precision (a simplified decoding sketch follows).

  - Uses label smoothing to reduce overfitting.

<img src="md_files/10.png" style="zoom:33%;" />
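A minimal sketch of half-pointer span decoding. The repository's real decoder is `span_decode` in `src/utils/evaluator.py` (not shown in this excerpt); this version keeps the first matching end position and omits the logit-based selection among overlapping candidates.

```python
import numpy as np

def decode_spans(start_logits, end_logits, text, id2ent):
    """Pair each predicted start with the nearest end of the same entity type (0 = non-entity)."""
    entities = {}
    start_pred = np.argmax(start_logits, axis=-1)   # shape (seq_len,), one type id per character
    end_pred = np.argmax(end_logits, axis=-1)
    for i, ent_type_id in enumerate(start_pred):
        if ent_type_id == 0:
            continue
        for j in range(i, len(end_pred)):
            if end_pred[j] == ent_type_id:
                entities.setdefault(id2ent[ent_type_id], []).append((text[i:j + 1], i))
                break
    return entities
```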
- Fusion model 2: BERT-MRC

  - Treats NER as a machine reading comprehension (MRC) task (an input-construction sketch follows this list).

  - query: a natural-language description of the entity type.

  - doc: the split sentence serves as the document.

  - One sample is built per entity type, which yields many negative samples at training time; to keep training efficient, a random 30% of the negatives are kept and the rest are discarded.

  - At prediction time a sample is built for every type and the decoded output is not restricted, which favors recall.

  - Uses label smoothing to reduce overfitting.

  - MRC showed weak precision on this dataset and is slower to train and run, so it is kept only as a recall booster. The code is provided for reference and is not recommended for routine use.

<img src="md_files/11.png" style="zoom:33%;" />
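A minimal sketch of how the per-type (query, doc) inputs are assembled, reusing the type descriptions stored in `data/mid_data/mrc_ent2id.json`; the tokenizer path and the transformers-2.x keyword arguments mirror the rest of this repository but are assumptions here.

```python
import json
from transformers import BertTokenizer

def build_mrc_inputs(sentence, bert_dir='./bert/torch_uer_large', max_len=512):
    """Encode one (query, doc) pair per entity type: the type description is text_a, the sentence is text_b."""
    with open('./data/mid_data/mrc_ent2id.json', encoding='utf-8') as f:
        ent2query = json.load(f)

    tokenizer = BertTokenizer.from_pretrained(bert_dir)
    inputs = {}
    for ent_type, query in ent2query.items():
        inputs[ent_type] = tokenizer.encode_plus(query, sentence,
                                                 max_length=max_len,
                                                 truncation_strategy='only_second',  # never truncate the query
                                                 pad_to_max_length=True,
                                                 return_token_type_ids=True,
                                                 return_attention_mask=True)
    return inputs
```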
- Multi-stage fusion strategy (a voting sketch follows)

  - Stage 1, probability fusion: for each of CRF / SPAN / MRC, the five models from 5-fold cross-validation are fused by averaging their logits before decoding entities.

  - Stage 2, vote fusion: the probability-fused CRF / SPAN / MRC models are combined by voting to obtain the final result.

<img src="md_files/12.png" style="zoom:33%;" />
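A minimal sketch of the second-stage entity vote. The repository's actual implementation is `ensemble_vote` in `src/utils/functions_utils.py` (not shown in this excerpt), so the threshold semantics here are an assumption.

```python
from collections import Counter, defaultdict

def vote_entities(all_labels, threshold=0.9):
    """Keep an entity when at least `threshold` of the models predicted it.

    all_labels: list of dicts, one per model, mapping doc_id to a list of
                (ent_type, start, end, text) tuples.
    """
    merged = defaultdict(list)
    n_models = len(all_labels)
    for doc_id in all_labels[0]:
        counts = Counter()
        for labels in all_labels:
            counts.update(set(labels.get(doc_id, [])))
        merged[doc_id] = [ent for ent, c in counts.items() if c / n_models >= threshold]
    return merged
```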
## Optimization 4: Semi-Supervised Learning

- Motivation: to ease the scarcity of annotated medical text, semi-supervised learning (pseudo-labeling) is used to exploit the 500 unlabeled documents of the preliminary-round test set.

- Strategy: dynamic pseudo-labels

  - First train a base model M on the original annotated data.

  - Use M to predict the preliminary-round test set and obtain pseudo-labels.

  - Add the pseudo-labeled data to the training set with a dynamic, learnable weight (alpha in the figure) and train together with the gold-labeled data to obtain model M'.

<img src="md_files/13.png" style="zoom: 25%;" />

- Tips: use an ensembled base model to reduce pseudo-label noise. The weight can also be fixed; which choice works better has to be found empirically. In either case the idea is to lower the loss weight of pseudo-labeled samples, which is one way to mitigate pseudo-label noise (a sketch follows).
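A minimal sketch of down-weighting the loss on pseudo-labeled samples. The per-sample `pseudo` flag mirrors the field carried through this repository's features, while the fixed `alpha` stands in for the learnable weight shown in the figure.

```python
import torch

def pseudo_weighted_loss(per_sample_loss, is_pseudo, alpha=0.4):
    """per_sample_loss: (batch,) tensor of losses; is_pseudo: (batch,) 0/1 tensor marking pseudo-labeled samples."""
    weights = torch.where(is_pseudo.bool(),
                          torch.full_like(per_sample_loss, alpha),   # down-weight pseudo-labeled samples
                          torch.ones_like(per_sample_loss))
    return (weights * per_sample_loss).sum() / weights.sum()
```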
## Other attempts without significant gains

- Dynamically weighting the last four BERT layers for the output: no clear improvement.

- Adding a BiLSTM / IDCNN module on top of the BERT output: severe overfitting and much slower training.

- Data augmentation by randomly swapping entity mentions of the same type to enlarge the training data.

- Using focal loss / dice loss in the BERT-SPAN / MRC models to counter label imbalance.

- Post-correcting the model output with the constructed domain dictionary.
## Final online score: 72.90%, Rank 1 in the semi-final, Rank 1 in the final

# Ref

[1] Zhao et al., UER: An Open-Source Toolkit for Pre-training Models, EMNLP-IJCNLP, 2019.

[2] Madry et al., Towards Deep Learning Models Resistant to Adversarial Attacks, ICLR, 2018.
## Environment

```python
python 3.7
pytorch>=1.6.0
transformers==2.10.0
pytorch-crf==0.7.2
```
## Project layout

```shell
DeepNER
│
├── data                          # data folder
│   ├── mid_data                  # intermediate data
│   │   ├── crf_ent2id.json       # schema for the CRF model
│   │   ├── span_ent2id.json      # schema for the span model
│   │   └── mrc_ent2id.json       # schema for the MRC model
│   │
│   ├── raw_data                  # converted data
│   │   ├── dev.json              # converted dev set
│   │   ├── test.json             # converted preliminary-round test set
│   │   ├── pseudo.json           # converted semi-supervised (pseudo-label) data
│   │   ├── stack.json            # converted full dataset
│   │   └── train.json            # converted training set
│
├── out                           # trained models
│   ├── ...
│   └── ...
│
├── src
│   ├── preprocess
│   │   ├── convert_raw_data.py   # converts the raw data
│   │   └── processor.py          # turns the data into BERT model inputs
│   ├── utils
│   │   ├── attack_train_utils.py # adversarial training (FGM / PGD)
│   │   ├── dataset_utils.py      # torch Dataset
│   │   ├── evaluator.py          # model evaluation
│   │   ├── functions_utils.py    # helper functions shared across files
│   │   ├── model_utils.py        # Span & CRF & MRC models (PyTorch)
│   │   ├── options.py            # command-line arguments
│   │   └── trainer.py            # trainer
│
├── competition_predict.py        # inference on the final-round data for submission
├── README.md                     # ...
├── convert_test_data.py          # converts the final-round test set to JSON format
├── run.sh                        # run script
└── main.py                       # main entry point (training / evaluation)
```
## Usage

### Pretrained models

* Tencent UER-large (24 layers): https://github.com/dbiir/UER-py/wiki/Modelzoo

* HIT Chinese-BERT-wwm models: https://github.com/ymcui/Chinese-BERT-wwm

Baidu Netdisk download:

Link: https://pan.baidu.com/s/1axdkovbzGaszl8bXIn4sPw
Extraction code: jjba

(Note: two of the [unused] tokens in vocab.txt must be manually replaced with [INV] and [BLANK].)

Tips: uer, roberta-wwm, and roberta-wwm-large are recommended.
### Data conversion

**Note: converted data is already provided, so this step does not need to be run.**

```python
python src/preprocess/convert_raw_data.py
```
### Training

```shell
bash run.sh
```

**Note: BERT_DIR in the script is the folder holding the BERT weights; download the pretrained model into that folder first.**
##### BERT-CRF training

```python
task_type='crf'
mode='train' or 'stack'    # train: single-model training and validation; stack: 5-fold training and validation

swa_start: the epoch at which SWA weight averaging starts
attack_train: 'pgd' / 'fgm' / ''    # adversarial training; fgm roughly doubles training time and pgd is slower still; pgd gave a clear gain on this dataset
```
##### BERT-SPAN training

```python
task_type='span'
mode: same as above
attack_train: same as above
loss_type: 'ce' (cross-entropy) / 'ls_ce' (label smoothing; sketched below) / 'focal' (focal loss)
```
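A minimal sketch of the label-smoothed cross-entropy that `loss_type='ls_ce'` refers to; the repository's own implementation sits in `src/utils/model_utils.py` (not shown in this excerpt) and may differ in detail.

```python
import torch
import torch.nn.functional as F

class LabelSmoothingCE(torch.nn.Module):
    """Cross-entropy with a uniform smoothing term (eps mass spread over all classes)."""
    def __init__(self, eps=0.1):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        # logits: (N, num_classes), targets: (N,) class indices
        log_probs = F.log_softmax(logits, dim=-1)
        nll = F.nll_loss(log_probs, targets)              # standard cross-entropy part
        uniform = -log_probs.mean(dim=-1).mean()          # uniform-distribution part
        return (1.0 - self.eps) * nll + self.eps * uniform
```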
##### BERT-MRC training

```python
task_type='mrc'
mode: same as above
attack_train: same as above
loss_type: same as above
```
### Predicting the final-round test set (after the models above have been trained)

**Note: this cannot be run yet; it will work once the official data is open-sourced.**

```shell
# convert the test data
python convert_test_data.py
# predict
python competition_predict.py
```
320
competition_predict.py
Normal file
@@ -0,0 +1,320 @@
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from transformers import BertTokenizer
|
||||
from src.utils.model_utils import CRFModel, SpanModel, EnsembleCRFModel, EnsembleSpanModel
|
||||
from src.utils.evaluator import crf_decode, span_decode
|
||||
from src.utils.functions_utils import load_model_and_parallel, ensemble_vote
|
||||
from src.preprocess.processor import cut_sent, fine_grade_tokenize
|
||||
|
||||
MID_DATA_DIR = "./data/mid_data"
|
||||
RAW_DATA_DIR = "./data/raw_data_random"
|
||||
SUBMIT_DIR = "./result"
|
||||
GPU_IDS = "0"
|
||||
|
||||
LAMBDA = 0.3
|
||||
THRESHOLD = 0.9
|
||||
MAX_SEQ_LEN = 512
|
||||
|
||||
TASK_TYPE = "crf" # choose crf or span
|
||||
VOTE = True # choose True or False
|
||||
VERSION = "mixed" # choose single or ensemble or mixed ; if mixed VOTE and TAST_TYPE is useless.
|
||||
|
||||
# single_predict
|
||||
BERT_TYPE = "uer_large" # roberta_wwm / ernie_1 / uer_large
|
||||
|
||||
BERT_DIR = f"./bert/torch_{BERT_TYPE}"
|
||||
with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
|
||||
CKPT_PATH = f.read().strip()
|
||||
|
||||
# ensemble_predict
|
||||
BERT_DIR_LIST = ["./bert/torch_uer_large", "./bert/torch_roberta_wwm"]
|
||||
|
||||
with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
|
||||
ENSEMBLE_DIR_LIST = f.readlines()
|
||||
print('ENSEMBLE_DIR_LIST:{}'.format(ENSEMBLE_DIR_LIST))
|
||||
|
||||
|
||||
# mixed_predict
|
||||
MIX_BERT_DIR = "./bert/torch_uer_large"
|
||||
|
||||
with open('./best_ckpt_path.txt', 'r', encoding='utf-8') as f:
|
||||
MIX_DIR_LIST = f.readlines()
|
||||
print('MIX_DIR_LIST:{}'.format(MIX_DIR_LIST))
|
||||
|
||||
|
||||
def prepare_info():
|
||||
info_dict = {}
|
||||
with open(os.path.join(MID_DATA_DIR, f'{TASK_TYPE}_ent2id.json'), encoding='utf-8') as f:
|
||||
ent2id = json.load(f)
|
||||
|
||||
with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
|
||||
info_dict['examples'] = json.load(f)
|
||||
|
||||
info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}
|
||||
|
||||
info_dict['tokenizer'] = BertTokenizer(os.path.join(BERT_DIR, 'vocab.txt'))
|
||||
|
||||
return info_dict
|
||||
|
||||
|
||||
def mixed_prepare_info(mixed='crf'):
|
||||
info_dict = {}
|
||||
with open(os.path.join(MID_DATA_DIR, f'{mixed}_ent2id.json'), encoding='utf-8') as f:
|
||||
ent2id = json.load(f)
|
||||
|
||||
with open(os.path.join(RAW_DATA_DIR, 'test.json'), encoding='utf-8') as f:
|
||||
info_dict['examples'] = json.load(f)
|
||||
|
||||
info_dict['id2ent'] = {ent2id[key]: key for key in ent2id.keys()}
|
||||
|
||||
info_dict['tokenizer'] = BertTokenizer(os.path.join(BERT_DIR, 'vocab.txt'))
|
||||
|
||||
return info_dict
|
||||
|
||||
|
||||
def base_predict(model, device, info_dict, ensemble=False, mixed=''):
|
||||
labels = defaultdict(list)
|
||||
|
||||
tokenizer = info_dict['tokenizer']
|
||||
id2ent = info_dict['id2ent']
|
||||
|
||||
with torch.no_grad():
|
||||
for _ex in info_dict['examples']:
|
||||
ex_idx = _ex['id']
|
||||
raw_text = _ex['text']
|
||||
|
||||
if not len(raw_text):
|
||||
labels[ex_idx] = []
|
||||
print('{}为空'.format(ex_idx))
|
||||
continue
|
||||
|
||||
sentences = cut_sent(raw_text, MAX_SEQ_LEN)
|
||||
|
||||
start_index = 0
|
||||
|
||||
for sent in sentences:
|
||||
|
||||
sent_tokens = fine_grade_tokenize(sent, tokenizer)
|
||||
|
||||
encode_dict = tokenizer.encode_plus(text=sent_tokens,
|
||||
max_length=MAX_SEQ_LEN,
|
||||
is_pretokenized=True,
|
||||
pad_to_max_length=False,
|
||||
return_tensors='pt',
|
||||
return_token_type_ids=True,
|
||||
return_attention_mask=True)
|
||||
|
||||
model_inputs = {'token_ids': encode_dict['input_ids'],
|
||||
'attention_masks': encode_dict['attention_mask'],
|
||||
'token_type_ids': encode_dict['token_type_ids']}
|
||||
|
||||
for key in model_inputs:
|
||||
model_inputs[key] = model_inputs[key].to(device)
|
||||
|
||||
if ensemble:
|
||||
if TASK_TYPE == 'crf':
|
||||
if VOTE:
|
||||
decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
|
||||
else:
|
||||
pred_tokens = model.predict(model_inputs)[0]
|
||||
decode_entities = crf_decode(pred_tokens, sent, id2ent)
|
||||
else:
|
||||
if VOTE:
|
||||
decode_entities = model.vote_entities(model_inputs, sent, id2ent, THRESHOLD)
|
||||
else:
|
||||
start_logits, end_logits = model.predict(model_inputs)
|
||||
start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
|
||||
end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
|
||||
|
||||
decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
|
||||
|
||||
else:
|
||||
|
||||
if mixed:
|
||||
if mixed == 'crf':
|
||||
pred_tokens = model(**model_inputs)[0][0]
|
||||
decode_entities = crf_decode(pred_tokens, sent, id2ent)
|
||||
else:
|
||||
start_logits, end_logits = model(**model_inputs)
|
||||
|
||||
start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
|
||||
end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]
|
||||
|
||||
decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
|
||||
|
||||
else:
|
||||
if TASK_TYPE == 'crf':
|
||||
pred_tokens = model(**model_inputs)[0][0]
|
||||
decode_entities = crf_decode(pred_tokens, sent, id2ent)
|
||||
else:
|
||||
start_logits, end_logits = model(**model_inputs)
|
||||
|
||||
start_logits = start_logits[0].cpu().numpy()[1:1+len(sent)]
|
||||
end_logits = end_logits[0].cpu().numpy()[1:1+len(sent)]
|
||||
|
||||
decode_entities = span_decode(start_logits, end_logits, sent, id2ent)
|
||||
|
||||
|
||||
for _ent_type in decode_entities:
|
||||
for _ent in decode_entities[_ent_type]:
|
||||
tmp_start = _ent[1] + start_index
|
||||
tmp_end = tmp_start + len(_ent[0])
|
||||
|
||||
assert raw_text[tmp_start: tmp_end] == _ent[0]
|
||||
|
||||
labels[ex_idx].append((_ent_type, tmp_start, tmp_end, _ent[0]))
|
||||
|
||||
start_index += len(sent)
|
||||
|
||||
if not len(labels[ex_idx]):
|
||||
labels[ex_idx] = []
|
||||
|
||||
return labels
|
||||
|
||||
|
||||
def single_predict():
|
||||
save_dir = os.path.join(SUBMIT_DIR, VERSION)
|
||||
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
info_dict = prepare_info()
|
||||
|
||||
if TASK_TYPE == 'crf':
|
||||
model = CRFModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent']))
|
||||
else:
|
||||
model = SpanModel(bert_dir=BERT_DIR, num_tags=len(info_dict['id2ent'])+1)
|
||||
|
||||
print(f'Load model from {CKPT_PATH}')
|
||||
model, device = load_model_and_parallel(model, GPU_IDS, CKPT_PATH)
|
||||
model.eval()
|
||||
|
||||
labels = base_predict(model, device, info_dict)
|
||||
|
||||
for key in labels.keys():
|
||||
with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
|
||||
if not len(labels[key]):
|
||||
print(key)
|
||||
f.write("")
|
||||
else:
|
||||
for idx, _label in enumerate(labels[key]):
|
||||
f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
|
||||
|
||||
|
||||
def ensemble_predict():
|
||||
save_dir = os.path.join(SUBMIT_DIR, VERSION)
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
info_dict = prepare_info()
|
||||
|
||||
model_path_list = [x.strip() for x in ENSEMBLE_DIR_LIST]
|
||||
print('model_path_list:{}'.format(model_path_list))
|
||||
|
||||
device = torch.device(f'cuda:{GPU_IDS[0]}')
|
||||
|
||||
|
||||
if TASK_TYPE == 'crf':
|
||||
model = EnsembleCRFModel(model_path_list=model_path_list,
|
||||
bert_dir_list=BERT_DIR_LIST,
|
||||
num_tags=len(info_dict['id2ent']),
|
||||
device=device,
|
||||
lamb=LAMBDA)
|
||||
else:
|
||||
model = EnsembleSpanModel(model_path_list=model_path_list,
|
||||
bert_dir_list=BERT_DIR_LIST,
|
||||
num_tags=len(info_dict['id2ent'])+1,
|
||||
device=device)
|
||||
|
||||
|
||||
labels = base_predict(model, device, info_dict, ensemble=True)
|
||||
|
||||
|
||||
for key in labels.keys():
|
||||
with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
|
||||
if not len(labels[key]):
|
||||
print(key)
|
||||
f.write("")
|
||||
else:
|
||||
for idx, _label in enumerate(labels[key]):
|
||||
f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
|
||||
|
||||
def mixed_predict():
|
||||
save_dir = os.path.join(SUBMIT_DIR, VERSION)
|
||||
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
model_path_list = [x.strip() for x in MIX_DIR_LIST]
|
||||
print('model_path_list:{}'.format(model_path_list))
|
||||
|
||||
all_labels = []
|
||||
|
||||
for i, model_path in enumerate(model_path_list):
|
||||
if i <= 4:
|
||||
info_dict = mixed_prepare_info(mixed='span')
|
||||
|
||||
model = SpanModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']) + 1)
|
||||
print(f'Load model from {model_path}')
|
||||
model, device = load_model_and_parallel(model, GPU_IDS, model_path)
|
||||
model.eval()
|
||||
labels = base_predict(model, device, info_dict, ensemble=False, mixed='span')
|
||||
|
||||
else:
|
||||
info_dict = mixed_prepare_info(mixed='crf')
|
||||
|
||||
model = CRFModel(bert_dir=MIX_BERT_DIR, num_tags=len(info_dict['id2ent']))
|
||||
print(f'Load model from {model_path}')
|
||||
model, device = load_model_and_parallel(model, GPU_IDS, model_path)
|
||||
model.eval()
|
||||
labels = base_predict(model, device, info_dict, ensemble=False, mixed='crf')
|
||||
|
||||
all_labels.append(labels)
|
||||
|
||||
labels = ensemble_vote(all_labels, THRESHOLD)
|
||||
|
||||
# for key in labels.keys():
|
||||
|
||||
for key in range(1500, 1997):
|
||||
with open(os.path.join(save_dir, f'{key}.ann'), 'w', encoding='utf-8') as f:
|
||||
if not len(labels[key]):
|
||||
print(key)
|
||||
f.write("")
|
||||
else:
|
||||
for idx, _label in enumerate(labels[key]):
|
||||
f.write(f'T{idx + 1}\t{_label[0]} {_label[1]} {_label[2]}\t{_label[3]}\n')
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
assert VERSION in ['single', 'ensemble', 'mixed'], 'VERSION mismatch'
|
||||
|
||||
if VERSION == 'single':
|
||||
single_predict()
|
||||
elif VERSION == 'ensemble':
|
||||
if VOTE:
|
||||
print("————————开始投票:————————")
|
||||
ensemble_predict()
|
||||
|
||||
elif VERSION == 'mixed':
|
||||
print("————————开始混合投票:————————")
|
||||
mixed_predict()
|
||||
|
||||
# 压缩result.zip
|
||||
import zipfile
|
||||
|
||||
def zip_file(src_dir):
|
||||
zip_name = src_dir + '.zip'
|
||||
z = zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED)
|
||||
for dirpath, dirnames, filenames in os.walk(src_dir):
|
||||
fpath = dirpath.replace(src_dir, '')
|
||||
fpath = fpath and fpath + os.sep or ''
|
||||
for filename in filenames:
|
||||
z.write(os.path.join(dirpath, filename), fpath + filename)
|
||||
print('==压缩成功==')
|
||||
z.close()
|
||||
|
||||
zip_file('./result')
|
||||
|
32
convert_test_data.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import os
|
||||
import json
|
||||
from tqdm import trange
|
||||
|
||||
|
||||
def save_info(data_dir, data, desc):
|
||||
with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def convert_test_data_to_json(test_dir, save_dir):
|
||||
|
||||
test_examples = []
|
||||
|
||||
|
||||
# process test examples
|
||||
for i in trange(1500, 1997):
|
||||
with open(os.path.join(test_dir, f'{i}.txt'), encoding='utf-8') as f:
|
||||
text = f.read()
|
||||
|
||||
test_examples.append({'id': i,
|
||||
'text': text})
|
||||
|
||||
save_info(save_dir, test_examples, 'test')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dir = './tcdata/juesai'
|
||||
save_dir = './data/raw_data_random'
|
||||
convert_test_data_to_json(test_dir, save_dir)
|
||||
print('测试数据转换完成')
|
||||
|
55
data/mid_data/crf_ent2id.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"O": 0,
|
||||
"B-DRUG_GROUP": 1,
|
||||
"B-DRUG_DOSAGE": 2,
|
||||
"B-FOOD": 3,
|
||||
"B-DRUG_EFFICACY": 4,
|
||||
"B-FOOD_GROUP": 5,
|
||||
"B-SYMPTOM": 6,
|
||||
"B-DISEASE_GROUP": 7,
|
||||
"B-SYNDROME": 8,
|
||||
"B-PERSON_GROUP": 9,
|
||||
"B-DRUG_TASTE": 10,
|
||||
"B-DRUG_INGREDIENT": 11,
|
||||
"B-DRUG": 12,
|
||||
"B-DISEASE": 13,
|
||||
"I-DRUG_GROUP": 14,
|
||||
"I-DRUG_DOSAGE": 15,
|
||||
"I-FOOD": 16,
|
||||
"I-DRUG_EFFICACY": 17,
|
||||
"I-FOOD_GROUP": 18,
|
||||
"I-SYMPTOM": 19,
|
||||
"I-DISEASE_GROUP": 20,
|
||||
"I-SYNDROME": 21,
|
||||
"I-PERSON_GROUP": 22,
|
||||
"I-DRUG_TASTE": 23,
|
||||
"I-DRUG_INGREDIENT": 24,
|
||||
"I-DRUG": 25,
|
||||
"I-DISEASE": 26,
|
||||
"E-DRUG_GROUP": 27,
|
||||
"E-DRUG_DOSAGE": 28,
|
||||
"E-FOOD": 29,
|
||||
"E-DRUG_EFFICACY": 30,
|
||||
"E-FOOD_GROUP": 31,
|
||||
"E-SYMPTOM": 32,
|
||||
"E-DISEASE_GROUP": 33,
|
||||
"E-SYNDROME": 34,
|
||||
"E-PERSON_GROUP": 35,
|
||||
"E-DRUG_TASTE": 36,
|
||||
"E-DRUG_INGREDIENT": 37,
|
||||
"E-DRUG": 38,
|
||||
"E-DISEASE": 39,
|
||||
"S-DRUG_GROUP": 40,
|
||||
"S-DRUG_DOSAGE": 41,
|
||||
"S-FOOD": 42,
|
||||
"S-DRUG_EFFICACY": 43,
|
||||
"S-FOOD_GROUP": 44,
|
||||
"S-SYMPTOM": 45,
|
||||
"S-DISEASE_GROUP": 46,
|
||||
"S-SYNDROME": 47,
|
||||
"S-PERSON_GROUP": 48,
|
||||
"S-DRUG_TASTE": 49,
|
||||
"S-DRUG_INGREDIENT": 50,
|
||||
"S-DRUG": 51,
|
||||
"S-DISEASE": 52
|
||||
}
|
15
data/mid_data/mrc_ent2id.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"DRUG": "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。",
|
||||
"DRUG_INGREDIENT": "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。",
|
||||
"DISEASE": "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。",
|
||||
"SYMPTOM": "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。",
|
||||
"SYNDROME": "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。",
|
||||
"DISEASE_GROUP": "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。",
|
||||
"FOOD": "找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。",
|
||||
"FOOD_GROUP": "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性,同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。",
|
||||
"PERSON_GROUP": "找出人群:中医药的适用及禁忌范围内相关特定人群。",
|
||||
"DRUG_GROUP": "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药",
|
||||
"DRUG_DOSAGE": "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。",
|
||||
"DRUG_TASTE": "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。",
|
||||
"DRUG_EFFICACY": "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀"
|
||||
}
|
15
data/mid_data/span_ent2id.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"DRUG_GROUP": 1,
|
||||
"DRUG_DOSAGE": 2,
|
||||
"FOOD": 3,
|
||||
"DRUG_EFFICACY": 4,
|
||||
"FOOD_GROUP": 5,
|
||||
"SYMPTOM": 6,
|
||||
"DISEASE_GROUP": 7,
|
||||
"SYNDROME": 8,
|
||||
"PERSON_GROUP": 9,
|
||||
"DRUG_TASTE": 10,
|
||||
"DRUG_INGREDIENT": 11,
|
||||
"DRUG": 12,
|
||||
"DISEASE": 13
|
||||
}
|
New data files (content not shown): data/raw_data/dev.json (24384 lines), data/raw_data/pseudo.json (85588 lines), data/raw_data/stack.json (164121 lines), data/raw_data/test.json (19417 lines), data/raw_data/train.json (139739 lines)
218
main.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import time
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from torch.utils.data import DataLoader
|
||||
from sklearn.model_selection import KFold
|
||||
from src.utils.trainer import train
|
||||
from src.utils.options import Args
|
||||
from src.utils.model_utils import build_model
|
||||
from src.utils.dataset_utils import NERDataset
|
||||
from src.utils.evaluator import crf_evaluation, span_evaluation, mrc_evaluation
|
||||
from src.utils.functions_utils import set_seed, get_model_path_list, load_model_and_parallel, get_time_dif
|
||||
from src.preprocess.processor import NERProcessor, convert_examples_to_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO
|
||||
)
|
||||
|
||||
def train_base(opt, train_examples, dev_examples=None):
|
||||
with open(os.path.join(opt.mid_data_dir, f'{opt.task_type}_ent2id.json'), encoding='utf-8') as f:
|
||||
ent2id = json.load(f)
|
||||
|
||||
train_features = convert_examples_to_features(opt.task_type, train_examples,
|
||||
opt.max_seq_len, opt.bert_dir, ent2id)[0]
|
||||
|
||||
train_dataset = NERDataset(opt.task_type, train_features, 'train', use_type_embed=opt.use_type_embed)
|
||||
|
||||
if opt.task_type == 'crf':
|
||||
model = build_model('crf', opt.bert_dir, num_tags=len(ent2id),
|
||||
dropout_prob=opt.dropout_prob)
|
||||
elif opt.task_type == 'mrc':
|
||||
model = build_model('mrc', opt.bert_dir,
|
||||
dropout_prob=opt.dropout_prob,
|
||||
use_type_embed=opt.use_type_embed,
|
||||
loss_type=opt.loss_type)
|
||||
else:
|
||||
model = build_model('span', opt.bert_dir, num_tags=len(ent2id)+1,
|
||||
dropout_prob=opt.dropout_prob,
|
||||
loss_type=opt.loss_type)
|
||||
|
||||
train(opt, model, train_dataset)
|
||||
|
||||
if dev_examples is not None:
|
||||
|
||||
dev_features, dev_callback_info = convert_examples_to_features(opt.task_type, dev_examples,
|
||||
opt.max_seq_len, opt.bert_dir, ent2id)
|
||||
|
||||
dev_dataset = NERDataset(opt.task_type, dev_features, 'dev', use_type_embed=opt.use_type_embed)
|
||||
|
||||
dev_loader = DataLoader(dev_dataset, batch_size=opt.eval_batch_size,
|
||||
shuffle=False, num_workers=0)
|
||||
|
||||
dev_info = (dev_loader, dev_callback_info)
|
||||
|
||||
model_path_list = get_model_path_list(opt.output_dir)
|
||||
|
||||
metric_str = ''
|
||||
|
||||
max_f1 = 0.
|
||||
max_f1_step = 0
|
||||
|
||||
max_f1_path = ''
|
||||
|
||||
for idx, model_path in enumerate(model_path_list):
|
||||
|
||||
tmp_step = model_path.split('/')[-2].split('-')[-1]
|
||||
|
||||
|
||||
model, device = load_model_and_parallel(model, opt.gpu_ids[0],
|
||||
ckpt_path=model_path)
|
||||
|
||||
if opt.task_type == 'crf':
|
||||
tmp_metric_str, tmp_f1 = crf_evaluation(model, dev_info, device, ent2id)
|
||||
elif opt.task_type == 'mrc':
|
||||
tmp_metric_str, tmp_f1 = mrc_evaluation(model, dev_info, device)
|
||||
else:
|
||||
tmp_metric_str, tmp_f1 = span_evaluation(model, dev_info, device, ent2id)
|
||||
|
||||
logger.info(f'In step {tmp_step}:\n {tmp_metric_str}')
|
||||
|
||||
metric_str += f'In step {tmp_step}:\n {tmp_metric_str}' + '\n\n'
|
||||
|
||||
if tmp_f1 > max_f1:
|
||||
max_f1 = tmp_f1
|
||||
max_f1_step = tmp_step
|
||||
max_f1_path = model_path
|
||||
|
||||
max_metric_str = f'Max f1 is: {max_f1}, in step {max_f1_step}'
|
||||
|
||||
logger.info(max_metric_str)
|
||||
|
||||
metric_str += max_metric_str + '\n'
|
||||
|
||||
eval_save_path = os.path.join(opt.output_dir, 'eval_metric.txt')
|
||||
|
||||
with open(eval_save_path, 'a', encoding='utf-8') as f1:
|
||||
f1.write(metric_str)
|
||||
|
||||
with open('./best_ckpt_path.txt', 'a', encoding='utf-8') as f2:
|
||||
f2.write(max_f1_path + '\n')
|
||||
|
||||
del_dir_list = [os.path.join(opt.output_dir, path.split('/')[-2])
|
||||
for path in model_path_list if path != max_f1_path]
|
||||
|
||||
import shutil
|
||||
for x in del_dir_list:
|
||||
shutil.rmtree(x)
|
||||
logger.info('{}已删除'.format(x))
|
||||
|
||||
|
||||
def training(opt):
|
||||
if args.task_type == 'mrc':
|
||||
# 62 for mrc query
|
||||
processor = NERProcessor(opt.max_seq_len-62)
|
||||
else:
|
||||
processor = NERProcessor(opt.max_seq_len)
|
||||
|
||||
train_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'train.json'))
|
||||
|
||||
# add pseudo data to train data
|
||||
pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
|
||||
train_raw_examples = train_raw_examples + pseudo_raw_examples
|
||||
|
||||
train_examples = processor.get_examples(train_raw_examples, 'train')
|
||||
|
||||
dev_examples = None
|
||||
if opt.eval_model:
|
||||
dev_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'dev.json'))
|
||||
dev_examples = processor.get_examples(dev_raw_examples, 'dev')
|
||||
|
||||
train_base(opt, train_examples, dev_examples)
|
||||
|
||||
|
||||
def stacking(opt):
|
||||
logger.info('Start to KFold stack attribution model')
|
||||
|
||||
if args.task_type == 'mrc':
|
||||
# 62 for mrc query
|
||||
processor = NERProcessor(opt.max_seq_len-62)
|
||||
else:
|
||||
processor = NERProcessor(opt.max_seq_len)
|
||||
|
||||
kf = KFold(5, shuffle=True, random_state=42)
|
||||
|
||||
stack_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'stack.json'))
|
||||
|
||||
pseudo_raw_examples = processor.read_json(os.path.join(opt.raw_data_dir, 'pseudo.json'))
|
||||
|
||||
base_output_dir = opt.output_dir
|
||||
|
||||
for i, (train_ids, dev_ids) in enumerate(kf.split(stack_raw_examples)):
|
||||
logger.info(f'Start to train the {i} fold')
|
||||
train_raw_examples = [stack_raw_examples[_idx] for _idx in train_ids]
|
||||
|
||||
# add pseudo data to train data
|
||||
train_raw_examples = train_raw_examples + pseudo_raw_examples
|
||||
train_examples = processor.get_examples(train_raw_examples, 'train')
|
||||
|
||||
dev_raw_examples = [stack_raw_examples[_idx] for _idx in dev_ids]
|
||||
dev_info = processor.get_examples(dev_raw_examples, 'dev')
|
||||
|
||||
tmp_output_dir = os.path.join(base_output_dir, f'v{i}')
|
||||
|
||||
opt.output_dir = tmp_output_dir
|
||||
|
||||
train_base(opt, train_examples, dev_info)
|
||||
|
||||
if __name__ == '__main__':
|
||||
start_time = time.time()
|
||||
logging.info('----------------开始计时----------------')
|
||||
logging.info('----------------------------------------')
|
||||
|
||||
args = Args().get_parser()
|
||||
|
||||
assert args.mode in ['train', 'stack'], 'mode mismatch'
|
||||
assert args.task_type in ['crf', 'span', 'mrc']
|
||||
|
||||
args.output_dir = os.path.join(args.output_dir, args.bert_type)
|
||||
|
||||
set_seed(args.seed)
|
||||
|
||||
if args.attack_train != '':
|
||||
args.output_dir += f'_{args.attack_train}'
|
||||
|
||||
if args.weight_decay:
|
||||
args.output_dir += '_wd'
|
||||
|
||||
if args.use_fp16:
|
||||
args.output_dir += '_fp16'
|
||||
|
||||
if args.task_type == 'span':
|
||||
args.output_dir += f'_{args.loss_type}'
|
||||
|
||||
if args.task_type == 'mrc':
|
||||
if args.use_type_embed:
|
||||
args.output_dir += f'_embed'
|
||||
args.output_dir += f'_{args.loss_type}'
|
||||
|
||||
args.output_dir += f'_{args.task_type}'
|
||||
|
||||
if args.mode == 'stack':
|
||||
args.output_dir += '_stack'
|
||||
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
logger.info(f'{args.mode} {args.task_type} in max_seq_len {args.max_seq_len}')
|
||||
|
||||
if args.mode == 'train':
|
||||
training(args)
|
||||
else:
|
||||
stacking(args)
|
||||
|
||||
time_dif = get_time_dif(start_time)
|
||||
logging.info("----------本次容器运行时长:{}-----------".format(time_dif))
|
Binary image files added under md_files/ (figures referenced by the README): 1.png (26 KiB), 2.png (57 KiB), 3.png (57 KiB), 4.png (279 KiB), 5.png (59 KiB), 6.png (177 KiB), 7.png (104 KiB), 8.png (25 KiB), 9.png (9.2 KiB), 10.png (159 KiB), 11.png (207 KiB), 12.png (188 KiB), 13.png (667 KiB)
35
run.sh
Normal file
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
export MID_DATA_DIR="./data/mid_data"
|
||||
export RAW_DATA_DIR="./data/raw_data"
|
||||
export OUTPUT_DIR="./out"
|
||||
|
||||
export GPU_IDS="0"
|
||||
export BERT_TYPE="roberta_wwm" # roberta_wwm / roberta_wwm_large / uer_large
|
||||
export BERT_DIR="../bert/torch_$BERT_TYPE"
|
||||
|
||||
export MODE="train"
|
||||
export TASK_TYPE="crf"
|
||||
|
||||
python main.py \
|
||||
--gpu_ids=$GPU_IDS \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mid_data_dir=$MID_DATA_DIR \
|
||||
--mode=$MODE \
|
||||
--task_type=$TASK_TYPE \
|
||||
--raw_data_dir=$RAW_DATA_DIR \
|
||||
--bert_dir=$BERT_DIR \
|
||||
--bert_type=$BERT_TYPE \
|
||||
--train_epochs=10 \
|
||||
--swa_start=5 \
|
||||
--attack_train="" \
|
||||
--train_batch_size=24 \
|
||||
--dropout_prob=0.1 \
|
||||
--max_seq_len=512 \
|
||||
--lr=2e-5 \
|
||||
--other_lr=2e-3 \
|
||||
--seed=123 \
|
||||
--weight_decay=0.01 \
|
||||
--loss_type='ls_ce' \
|
||||
--eval_model \
|
||||
#--use_fp16
|
Binary file added: src/preprocess/__pycache__/processor.cpython-36.pyc
183
src/preprocess/convert_raw_data.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import os
|
||||
import json
|
||||
from tqdm import trange
|
||||
from sklearn.model_selection import train_test_split, KFold
|
||||
|
||||
|
||||
def save_info(data_dir, data, desc):
|
||||
with open(os.path.join(data_dir, f'{desc}.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def convert_data_to_json(base_dir, save_data=False, save_dict=False):
|
||||
stack_examples = []
|
||||
pseudo_examples = []
|
||||
test_examples = []
|
||||
|
||||
stack_dir = os.path.join(base_dir, 'train')
|
||||
pseudo_dir = os.path.join(base_dir, 'pseudo')
|
||||
test_dir = os.path.join(base_dir, 'test')
|
||||
|
||||
|
||||
# process train examples
|
||||
for i in trange(1000):
|
||||
with open(os.path.join(stack_dir, f'{i}.txt'), encoding='utf-8') as f:
|
||||
text = f.read()
|
||||
|
||||
labels = []
|
||||
with open(os.path.join(stack_dir, f'{i}.ann'), encoding='utf-8') as f:
|
||||
for line in f.readlines():
|
||||
tmp_label = line.strip().split('\t')
|
||||
assert len(tmp_label) == 3
|
||||
tmp_mid = tmp_label[1].split()
|
||||
tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]]
|
||||
|
||||
labels.append(tmp_label)
|
||||
tmp_label[2] = int(tmp_label[2])
|
||||
tmp_label[3] = int(tmp_label[3])
|
||||
|
||||
assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], '{},{}索引抽取错误'.format(tmp_label, i)
|
||||
|
||||
stack_examples.append({'id': i,
|
||||
'text': text,
|
||||
'labels': labels,
|
||||
'pseudo': 0})
|
||||
|
||||
|
||||
# 构建实体知识库
|
||||
kf = KFold(10)
|
||||
entities = set()
|
||||
ent_types = set()
|
||||
for _now_id, _candidate_id in kf.split(stack_examples):
|
||||
now = [stack_examples[_id] for _id in _now_id]
|
||||
candidate = [stack_examples[_id] for _id in _candidate_id]
|
||||
now_entities = set()
|
||||
|
||||
for _ex in now:
|
||||
for _label in _ex['labels']:
|
||||
ent_types.add(_label[1])
|
||||
|
||||
if len(_label[-1]) > 1:
|
||||
now_entities.add(_label[-1])
|
||||
entities.add(_label[-1])
|
||||
# print(len(now_entities))
|
||||
for _ex in candidate:
|
||||
text = _ex['text']
|
||||
candidate_entities = []
|
||||
|
||||
for _ent in now_entities:
|
||||
if _ent in text:
|
||||
candidate_entities.append(_ent)
|
||||
|
||||
_ex['candidate_entities'] = candidate_entities
|
||||
assert len(ent_types) == 13
|
||||
|
||||
# process test examples predicted by the preliminary model
|
||||
for i in trange(1000, 1500):
|
||||
with open(os.path.join(pseudo_dir, f'{i}.txt'), encoding='utf-8') as f:
|
||||
text = f.read()
|
||||
|
||||
candidate_entities = []
|
||||
for _ent in entities:
|
||||
if _ent in text:
|
||||
candidate_entities.append(_ent)
|
||||
|
||||
labels = []
|
||||
with open(os.path.join(pseudo_dir, f'{i}.ann'), encoding='utf-8') as f:
|
||||
for line in f.readlines():
|
||||
tmp_label = line.strip().split('\t')
|
||||
assert len(tmp_label) == 3
|
||||
tmp_mid = tmp_label[1].split()
|
||||
tmp_label = [tmp_label[0]] + tmp_mid + [tmp_label[2]]
|
||||
|
||||
labels.append(tmp_label)
|
||||
tmp_label[2] = int(tmp_label[2])
|
||||
tmp_label[3] = int(tmp_label[3])
|
||||
|
||||
assert text[tmp_label[2]:tmp_label[3]] == tmp_label[-1], '{},{}索引抽取错误'.format(tmp_label, i)
|
||||
|
||||
pseudo_examples.append({'id': i,
|
||||
'text': text,
|
||||
'labels': labels,
|
||||
'candidate_entities': candidate_entities,
|
||||
'pseudo': 1})
|
||||
|
||||
# process test examples
|
||||
for i in trange(1000, 1500):
|
||||
with open(os.path.join(test_dir, f'{i}.txt'), encoding='utf-8') as f:
|
||||
text = f.read()
|
||||
|
||||
candidate_entities = []
|
||||
for _ent in entities:
|
||||
if _ent in text:
|
||||
candidate_entities.append(_ent)
|
||||
|
||||
test_examples.append({'id': i,
|
||||
'text': text,
|
||||
'candidate_entities': candidate_entities})
|
||||
|
||||
train, dev = train_test_split(stack_examples, shuffle=True, random_state=123, test_size=0.15)
|
||||
|
||||
if save_data:
|
||||
save_info(base_dir, stack_examples, 'stack')
|
||||
save_info(base_dir, train, 'train')
|
||||
save_info(base_dir, dev, 'dev')
|
||||
save_info(base_dir, test_examples, 'test')
|
||||
|
||||
save_info(base_dir, pseudo_examples, 'pseudo')
|
||||
|
||||
if save_dict:
|
||||
ent_types = list(ent_types)
|
||||
span_ent2id = {_type: i+1 for i, _type in enumerate(ent_types)}
|
||||
|
||||
ent_types = ['O'] + [p + '-' + _type for p in ['B', 'I', 'E', 'S'] for _type in list(ent_types)]
|
||||
crf_ent2id = {ent: i for i, ent in enumerate(ent_types)}
|
||||
|
||||
mid_data_dir = os.path.join(os.path.split(base_dir)[0], 'mid_data')
|
||||
if not os.path.exists(mid_data_dir):
|
||||
os.mkdir(mid_data_dir)
|
||||
|
||||
save_info(mid_data_dir, span_ent2id, 'span_ent2id')
|
||||
save_info(mid_data_dir, crf_ent2id, 'crf_ent2id')
|
||||
|
||||
|
||||
def build_ent2query(data_dir):
|
||||
# 利用比赛实体类型简介来描述 query
|
||||
ent2query = {
|
||||
# 药物
|
||||
'DRUG': "找出药物:用于预防、治疗、诊断疾病并具有康复与保健作用的物质。",
|
||||
# 药物成分
|
||||
'DRUG_INGREDIENT': "找出药物成分:中药组成成分,指中药复方中所含有的所有与该复方临床应用目的密切相关的药理活性成分。",
|
||||
# 疾病
|
||||
'DISEASE': "找出疾病:指人体在一定原因的损害性作用下,因自稳调节紊乱而发生的异常生命活动过程,会影响生物体的部分或是所有器官。",
|
||||
# 症状
|
||||
'SYMPTOM': "找出症状:指疾病过程中机体内的一系列机能、代谢和形态结构异常变化所引起的病人主观上的异常感觉或某些客观病态改变。",
|
||||
# 症候
|
||||
'SYNDROME': "找出症候:概括为一系列有相互关联的症状总称,是指不同症状和体征的综合表现。",
|
||||
# 疾病分组
|
||||
'DISEASE_GROUP': "找出疾病分组:疾病涉及有人体组织部位的疾病名称的统称概念,非某项具体医学疾病。",
|
||||
# 食物
|
||||
'FOOD': "找出食物:指能够满足机体正常生理和生化能量需求,并能延续正常寿命的物质。",
|
||||
# 食物分组
|
||||
'FOOD_GROUP': "找出食物分组:中医中饮食养生中,将食物分为寒热温凉四性,"
|
||||
"同时中医药禁忌中对于具有某类共同属性食物的统称,记为食物分组。",
|
||||
# 人群
|
||||
'PERSON_GROUP': "找出人群:中医药的适用及禁忌范围内相关特定人群。",
|
||||
# 药品分组
|
||||
'DRUG_GROUP': "找出药品分组:具有某一类共同属性的药品类统称概念,非某项具体药品名。例子:止咳药、退烧药",
|
||||
# 药物剂量
|
||||
'DRUG_DOSAGE': "找出药物剂量:药物在供给临床使用前,均必须制成适合于医疗和预防应用的形式,成为药物剂型。",
|
||||
# 药物性味
|
||||
'DRUG_TASTE': "找出药物性味:药品的性质和气味。例子:味甘、酸涩、气凉。",
|
||||
# 中药功效
|
||||
'DRUG_EFFICACY': "找出中药功效:药品的主治功能和效果的统称。例子:滋阴补肾、去瘀生新、活血化瘀"
|
||||
}
|
||||
|
||||
with open(os.path.join(data_dir, 'mrc_ent2id.json'), 'w', encoding='utf-8') as f:
|
||||
json.dump(ent2query, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
convert_data_to_json('../../data/raw_data', save_data=True, save_dict=True)
|
||||
build_ent2query('../../data/mid_data')
|
||||
|
654
src/preprocess/processor.py
Normal file
@@ -0,0 +1,654 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
from transformers import BertTokenizer
|
||||
from collections import defaultdict
|
||||
import random
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ENTITY_TYPES = ['DRUG', 'DRUG_INGREDIENT', 'DISEASE', 'SYMPTOM', 'SYNDROME', 'DISEASE_GROUP',
|
||||
'FOOD', 'FOOD_GROUP', 'PERSON_GROUP', 'DRUG_GROUP', 'DRUG_DOSAGE', 'DRUG_TASTE',
|
||||
'DRUG_EFFICACY']
|
||||
|
||||
|
||||
class InputExample:
|
||||
def __init__(self,
|
||||
set_type,
|
||||
text,
|
||||
labels=None,
|
||||
pseudo=None,
|
||||
distant_labels=None):
|
||||
self.set_type = set_type
|
||||
self.text = text
|
||||
self.labels = labels
|
||||
self.pseudo = pseudo
|
||||
self.distant_labels = distant_labels
|
||||
|
||||
|
||||
class BaseFeature:
|
||||
def __init__(self,
|
||||
token_ids,
|
||||
attention_masks,
|
||||
token_type_ids):
|
||||
# BERT 输入
|
||||
self.token_ids = token_ids
|
||||
self.attention_masks = attention_masks
|
||||
self.token_type_ids = token_type_ids
|
||||
|
||||
|
||||
class CRFFeature(BaseFeature):
|
||||
def __init__(self,
|
||||
token_ids,
|
||||
attention_masks,
|
||||
token_type_ids,
|
||||
labels=None,
|
||||
pseudo=None,
|
||||
distant_labels=None):
|
||||
super(CRFFeature, self).__init__(token_ids=token_ids,
|
||||
attention_masks=attention_masks,
|
||||
token_type_ids=token_type_ids)
|
||||
# labels
|
||||
self.labels = labels
|
||||
|
||||
# pseudo
|
||||
self.pseudo = pseudo
|
||||
|
||||
# distant labels
|
||||
self.distant_labels = distant_labels
|
||||
|
||||
|
||||
class SpanFeature(BaseFeature):
|
||||
def __init__(self,
|
||||
token_ids,
|
||||
attention_masks,
|
||||
token_type_ids,
|
||||
start_ids=None,
|
||||
end_ids=None,
|
||||
pseudo=None):
|
||||
super(SpanFeature, self).__init__(token_ids=token_ids,
|
||||
attention_masks=attention_masks,
|
||||
token_type_ids=token_type_ids)
|
||||
self.start_ids = start_ids
|
||||
self.end_ids = end_ids
|
||||
# pseudo
|
||||
self.pseudo = pseudo
|
||||
|
||||
class MRCFeature(BaseFeature):
|
||||
def __init__(self,
|
||||
token_ids,
|
||||
attention_masks,
|
||||
token_type_ids,
|
||||
ent_type=None,
|
||||
start_ids=None,
|
||||
end_ids=None,
|
||||
pseudo=None):
|
||||
super(MRCFeature, self).__init__(token_ids=token_ids,
|
||||
attention_masks=attention_masks,
|
||||
token_type_ids=token_type_ids)
|
||||
self.ent_type = ent_type
|
||||
self.start_ids = start_ids
|
||||
self.end_ids = end_ids
|
||||
|
||||
# pseudo
|
||||
self.pseudo = pseudo
|
||||
|
||||
|
||||
class NERProcessor:
|
||||
def __init__(self, cut_sent_len=256):
|
||||
self.cut_sent_len = cut_sent_len
|
||||
|
||||
@staticmethod
|
||||
def read_json(file_path):
|
||||
with open(file_path, encoding='utf-8') as f:
|
||||
raw_examples = json.load(f)
|
||||
return raw_examples
|
||||
|
||||
@staticmethod
|
||||
def _refactor_labels(sent, labels, distant_labels, start_index):
|
||||
"""
|
||||
分句后需要重构 labels 的 offset
|
||||
:param sent: 切分并重新合并后的句子
|
||||
:param labels: 原始文档级的 labels
|
||||
:param distant_labels: 远程监督 label
|
||||
:param start_index: 该句子在文档中的起始 offset
|
||||
:return (type, entity, offset)
|
||||
"""
|
||||
new_labels, new_distant_labels = [], []
|
||||
end_index = start_index + len(sent)
|
||||
|
||||
for _label in labels:
|
||||
if start_index <= _label[2] <= _label[3] <= end_index:
|
||||
new_offset = _label[2] - start_index
|
||||
|
||||
assert sent[new_offset: new_offset + len(_label[-1])] == _label[-1]
|
||||
|
||||
new_labels.append((_label[1], _label[-1], new_offset))
|
||||
# label 被截断的情况
|
||||
elif _label[2] < end_index < _label[3]:
|
||||
raise RuntimeError(f'{sent}, {_label}')
|
||||
|
||||
for _label in distant_labels:
|
||||
if _label in sent:
|
||||
new_distant_labels.append(_label)
|
||||
|
||||
return new_labels, new_distant_labels
|
||||
|
||||
def get_examples(self, raw_examples, set_type):
|
||||
examples = []
|
||||
|
||||
for i, item in enumerate(raw_examples):
|
||||
text = item['text']
|
||||
distant_labels = item['candidate_entities']
|
||||
pseudo = item['pseudo']
|
||||
|
||||
sentences = cut_sent(text, self.cut_sent_len)
|
||||
start_index = 0
|
||||
|
||||
for sent in sentences:
|
||||
labels, tmp_distant_labels = self._refactor_labels(sent, item['labels'], distant_labels, start_index)
|
||||
|
||||
start_index += len(sent)
|
||||
|
||||
examples.append(InputExample(set_type=set_type,
|
||||
text=sent,
|
||||
labels=labels,
|
||||
pseudo=pseudo,
|
||||
distant_labels=tmp_distant_labels))
|
||||
|
||||
return examples
|
||||
|
||||
|
||||
def fine_grade_tokenize(raw_text, tokenizer):
|
||||
"""
|
||||
序列标注任务 BERT 分词器可能会导致标注偏移,
|
||||
用 char-level 来 tokenize
|
||||
"""
|
||||
tokens = []
|
||||
|
||||
for _ch in raw_text:
|
||||
if _ch in [' ', '\t', '\n']:
|
||||
tokens.append('[BLANK]')
|
||||
else:
|
||||
if not len(tokenizer.tokenize(_ch)):
|
||||
tokens.append('[INV]')
|
||||
else:
|
||||
tokens.append(_ch)
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def cut_sentences_v1(sent):
|
||||
"""
|
||||
the first rank of sentence cut
|
||||
"""
|
||||
sent = re.sub('([。!?\?])([^”’])', r"\1\n\2", sent) # 单字符断句符
|
||||
sent = re.sub('(\.{6})([^”’])', r"\1\n\2", sent) # 英文省略号
|
||||
sent = re.sub('(\…{2})([^”’])', r"\1\n\2", sent) # 中文省略号
|
||||
sent = re.sub('([。!?\?][”’])([^,。!?\?])', r"\1\n\2", sent)
|
||||
# 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后
|
||||
return sent.split("\n")
|
||||
|
||||
|
||||
def cut_sentences_v2(sent):
|
||||
"""
|
||||
the second rank of spilt sentence, split ';' | ';'
|
||||
"""
|
||||
sent = re.sub('([;;])([^”’])', r"\1\n\2", sent)
|
||||
return sent.split("\n")
|
||||
|
||||
|
||||
def cut_sent(text, max_seq_len):
|
||||
# 将句子分句,细粒度分句后再重新合并
|
||||
sentences = []
|
||||
|
||||
# 细粒度划分
|
||||
sentences_v1 = cut_sentences_v1(text)
|
||||
for sent_v1 in sentences_v1:
|
||||
if len(sent_v1) > max_seq_len - 2:
|
||||
sentences_v2 = cut_sentences_v2(sent_v1)
|
||||
sentences.extend(sentences_v2)
|
||||
else:
|
||||
sentences.append(sent_v1)
|
||||
|
||||
assert ''.join(sentences) == text
|
||||
|
||||
# 合并
|
||||
merged_sentences = []
|
||||
start_index_ = 0
|
||||
|
||||
while start_index_ < len(sentences):
|
||||
tmp_text = sentences[start_index_]
|
||||
|
||||
end_index_ = start_index_ + 1
|
||||
|
||||
while end_index_ < len(sentences) and \
|
||||
len(tmp_text) + len(sentences[end_index_]) <= max_seq_len - 2:
|
||||
tmp_text += sentences[end_index_]
|
||||
end_index_ += 1
|
||||
|
||||
start_index_ = end_index_
|
||||
|
||||
merged_sentences.append(tmp_text)
|
||||
|
||||
return merged_sentences
|
||||
|
||||
def sent_mask(sent, stop_mask_range_list, mask_prob=0.15):
|
||||
"""
|
||||
将句子中的词以 mask prob 的概率随机 mask,
|
||||
其中 85% 概率被置为 [mask] 15% 的概率不变。
|
||||
:param sent: list of segment words
|
||||
:param stop_mask_range_list: 不能 mask 的区域
|
||||
:param mask_prob: max mask nums: len(sent) * max_mask_prob
|
||||
:return:
|
||||
"""
|
||||
max_mask_token_nums = int(len(sent) * mask_prob)
|
||||
mask_nums = 0
|
||||
mask_sent = []
|
||||
|
||||
for i in range(len(sent)):
|
||||
|
||||
flag = False
|
||||
for _stop_range in stop_mask_range_list:
|
||||
if _stop_range[0] <= i <= _stop_range[1]:
|
||||
flag = True
|
||||
break
|
||||
|
||||
if flag:
|
||||
mask_sent.append(sent[i])
|
||||
continue
|
||||
|
||||
if mask_nums < max_mask_token_nums:
|
||||
# mask_prob 的概率进行 mask, 80% 概率被置为 [mask],10% 概率被替换, 10% 的概率不变
|
||||
if random.random() < mask_prob:
|
||||
mask_sent.append('[MASK]')
|
||||
mask_nums += 1
|
||||
else:
|
||||
mask_sent.append(sent[i])
|
||||
else:
|
||||
mask_sent.append(sent[i])
|
||||
|
||||
return mask_sent
|
||||
|
||||
|
||||
def convert_crf_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
|
||||
max_seq_len, ent2id):
|
||||
set_type = example.set_type
|
||||
raw_text = example.text
|
||||
entities = example.labels
|
||||
pseudo = example.pseudo
|
||||
|
||||
callback_info = (raw_text,)
|
||||
callback_labels = {x: [] for x in ENTITY_TYPES}
|
||||
|
||||
for _label in entities:
|
||||
callback_labels[_label[0]].append((_label[1], _label[2]))
|
||||
|
||||
callback_info += (callback_labels,)
|
||||
|
||||
tokens = fine_grade_tokenize(raw_text, tokenizer)
|
||||
assert len(tokens) == len(raw_text)
|
||||
|
||||
label_ids = None
|
||||
|
||||
if set_type == 'train':
|
||||
# information for dev callback
|
||||
label_ids = [0] * len(tokens)
|
||||
|
||||
# tag labels ent ex. (T1, DRUG_DOSAGE, 447, 450, 小蜜丸)
|
||||
for ent in entities:
|
||||
ent_type = ent[0]
|
||||
|
||||
ent_start = ent[-1]
|
||||
ent_end = ent_start + len(ent[1]) - 1
|
||||
|
||||
if ent_start == ent_end:
|
||||
label_ids[ent_start] = ent2id['S-' + ent_type]
|
||||
else:
|
||||
label_ids[ent_start] = ent2id['B-' + ent_type]
|
||||
label_ids[ent_end] = ent2id['E-' + ent_type]
|
||||
for i in range(ent_start + 1, ent_end):
|
||||
label_ids[i] = ent2id['I-' + ent_type]
|
||||
|
||||
if len(label_ids) > max_seq_len - 2:
|
||||
label_ids = label_ids[:max_seq_len - 2]
|
||||
|
||||
label_ids = [0] + label_ids + [0]
|
||||
|
||||
# pad
|
||||
if len(label_ids) < max_seq_len:
|
||||
pad_length = max_seq_len - len(label_ids)
|
||||
label_ids = label_ids + [0] * pad_length # CLS SEP PAD label都为O
|
||||
|
||||
assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
|
||||
|
||||
encode_dict = tokenizer.encode_plus(text=tokens,
|
||||
max_length=max_seq_len,
|
||||
pad_to_max_length=True,
|
||||
is_pretokenized=True,
|
||||
return_token_type_ids=True,
|
||||
return_attention_mask=True)
|
||||
|
||||
token_ids = encode_dict['input_ids']
|
||||
attention_masks = encode_dict['attention_mask']
|
||||
token_type_ids = encode_dict['token_type_ids']
|
||||
|
||||
# if ex_idx < 3:
|
||||
# logger.info(f"*** {set_type}_example-{ex_idx} ***")
|
||||
# logger.info(f'text: {" ".join(tokens)}')
|
||||
# logger.info(f"token_ids: {token_ids}")
|
||||
# logger.info(f"attention_masks: {attention_masks}")
|
||||
# logger.info(f"token_type_ids: {token_type_ids}")
|
||||
# logger.info(f"labels: {label_ids}")
|
||||
|
||||
feature = CRFFeature(
|
||||
# bert inputs
|
||||
token_ids=token_ids,
|
||||
attention_masks=attention_masks,
|
||||
token_type_ids=token_type_ids,
|
||||
labels=label_ids,
|
||||
pseudo=pseudo
|
||||
)
|
||||
|
||||
return feature, callback_info
|
||||
|
||||
|
||||
def convert_span_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
|
||||
max_seq_len, ent2id):
|
||||
set_type = example.set_type
|
||||
raw_text = example.text
|
||||
entities = example.labels
|
||||
pseudo = example.pseudo
|
||||
|
||||
tokens = fine_grade_tokenize(raw_text, tokenizer)
|
||||
assert len(tokens) == len(raw_text)
|
||||
|
||||
callback_labels = {x: [] for x in ENTITY_TYPES}
|
||||
|
||||
for _label in entities:
|
||||
callback_labels[_label[0]].append((_label[1], _label[2]))
|
||||
|
||||
callback_info = (raw_text, callback_labels,)
|
||||
|
||||
start_ids, end_ids = None, None
|
||||
|
||||
if set_type == 'train':
|
||||
start_ids = [0] * len(tokens)
|
||||
end_ids = [0] * len(tokens)
|
||||
|
||||
for _ent in entities:
|
||||
|
||||
ent_type = ent2id[_ent[0]]
|
||||
ent_start = _ent[-1]
|
||||
ent_end = ent_start + len(_ent[1]) - 1
|
||||
|
||||
start_ids[ent_start] = ent_type
|
||||
end_ids[ent_end] = ent_type
|
||||
|
||||
if len(start_ids) > max_seq_len - 2:
|
||||
start_ids = start_ids[:max_seq_len - 2]
|
||||
end_ids = end_ids[:max_seq_len - 2]
|
||||
|
||||
start_ids = [0] + start_ids + [0]
|
||||
end_ids = [0] + end_ids + [0]
|
||||
|
||||
# pad
|
||||
if len(start_ids) < max_seq_len:
|
||||
pad_length = max_seq_len - len(start_ids)
|
||||
|
||||
start_ids = start_ids + [0] * pad_length # CLS SEP PAD label都为O
|
||||
end_ids = end_ids + [0] * pad_length
|
||||
|
||||
assert len(start_ids) == max_seq_len
|
||||
assert len(end_ids) == max_seq_len
|
||||
|
||||
encode_dict = tokenizer.encode_plus(text=tokens,
|
||||
max_length=max_seq_len,
|
||||
pad_to_max_length=True,
|
||||
is_pretokenized=True,
|
||||
return_token_type_ids=True,
|
||||
return_attention_mask=True)
|
||||
|
||||
token_ids = encode_dict['input_ids']
|
||||
attention_masks = encode_dict['attention_mask']
|
||||
token_type_ids = encode_dict['token_type_ids']
|
||||
|
||||
# if ex_idx < 3:
|
||||
# logger.info(f"*** {set_type}_example-{ex_idx} ***")
|
||||
# logger.info(f'text: {" ".join(tokens)}')
|
||||
# logger.info(f"token_ids: {token_ids}")
|
||||
# logger.info(f"attention_masks: {attention_masks}")
|
||||
# logger.info(f"token_type_ids: {token_type_ids}")
|
||||
# if start_ids and end_ids:
|
||||
# logger.info(f"start_ids: {start_ids}")
|
||||
# logger.info(f"end_ids: {end_ids}")
|
||||
|
||||
feature = SpanFeature(token_ids=token_ids,
|
||||
attention_masks=attention_masks,
|
||||
token_type_ids=token_type_ids,
|
||||
start_ids=start_ids,
|
||||
end_ids=end_ids,
|
||||
pseudo=pseudo)
|
||||
|
||||
return feature, callback_info
|
||||
|
||||
def convert_mrc_example(ex_idx, example: InputExample, tokenizer: BertTokenizer,
|
||||
max_seq_len, ent2id, ent2query, mask_prob=None):
|
||||
set_type = example.set_type
|
||||
text_b = example.text
|
||||
entities = example.labels
|
||||
pseudo = example.pseudo
|
||||
|
||||
features = []
|
||||
callback_info = []
|
||||
|
||||
tokens_b = fine_grade_tokenize(text_b, tokenizer)
|
||||
assert len(tokens_b) == len(text_b)
|
||||
|
||||
label_dict = defaultdict(list)
|
||||
|
||||
for ent in entities:
|
||||
ent_type = ent[0]
|
||||
ent_start = ent[-1]
|
||||
ent_end = ent_start + len(ent[1]) - 1
|
||||
label_dict[ent_type].append((ent_start, ent_end, ent[1]))
|
||||
|
||||
# 训练数据中构造
|
||||
if set_type == 'train':
|
||||
|
||||
# 每一类为一个 example
|
||||
# for _type in label_dict.keys():
|
||||
for _type in ENTITY_TYPES:
|
||||
start_ids = [0] * len(tokens_b)
|
||||
end_ids = [0] * len(tokens_b)
|
||||
|
||||
stop_mask_ranges = []
|
||||
|
||||
text_a = ent2query[_type]
|
||||
tokens_a = fine_grade_tokenize(text_a, tokenizer)
|
||||
|
||||
for _label in label_dict[_type]:
|
||||
start_ids[_label[0]] = 1
|
||||
end_ids[_label[1]] = 1
|
||||
|
||||
stop_mask_ranges.append((_label[0], _label[1]))
|
||||
|
||||
if len(start_ids) > max_seq_len - len(tokens_a) - 3:
|
||||
start_ids = start_ids[:max_seq_len - len(tokens_a) - 3]
|
||||
end_ids = end_ids[:max_seq_len - len(tokens_a) - 3]
|
||||
print('产生了不该有的截断')
|
||||
|
||||
start_ids = [0] + [0] * len(tokens_a) + [0] + start_ids + [0]
|
||||
end_ids = [0] + [0] * len(tokens_a) + [0] + end_ids + [0]
|
||||
|
||||
# pad
|
||||
if len(start_ids) < max_seq_len:
|
||||
pad_length = max_seq_len - len(start_ids)
|
||||
|
||||
start_ids = start_ids + [0] * pad_length # CLS SEP PAD label都为O
|
||||
end_ids = end_ids + [0] * pad_length
|
||||
|
||||
assert len(start_ids) == max_seq_len
|
||||
assert len(end_ids) == max_seq_len
|
||||
|
||||
# 随机mask
|
||||
if mask_prob:
|
||||
tokens_b = sent_mask(tokens_b, stop_mask_ranges, mask_prob=mask_prob)
|
||||
|
||||
encode_dict = tokenizer.encode_plus(text=tokens_a,
|
||||
text_pair=tokens_b,
|
||||
max_length=max_seq_len,
|
||||
pad_to_max_length=True,
|
||||
truncation_strategy='only_second',
|
||||
is_pretokenized=True,
|
||||
return_token_type_ids=True,
|
||||
return_attention_mask=True)
|
||||
|
||||
token_ids = encode_dict['input_ids']
|
||||
attention_masks = encode_dict['attention_mask']
|
||||
token_type_ids = encode_dict['token_type_ids']
|
||||
            # if ex_idx < 3:
            #     logger.info(f"*** {set_type}_example-{ex_idx} ***")
            #     logger.info(f'text: {" ".join(tokens_b)}')
            #     logger.info(f"token_ids: {token_ids}")
            #     logger.info(f"attention_masks: {attention_masks}")
            #     logger.info(f"token_type_ids: {token_type_ids}")
            #     logger.info(f'entity type: {_type}')
            #     logger.info(f"start_ids: {start_ids}")
            #     logger.info(f"end_ids: {end_ids}")

            feature = MRCFeature(token_ids=token_ids,
                                 attention_masks=attention_masks,
                                 token_type_ids=token_type_ids,
                                 ent_type=ent2id[_type],
                                 start_ids=start_ids,
                                 end_ids=end_ids,
                                 pseudo=pseudo
                                 )

            features.append(feature)

    # test data: build a separate example for each entity type
    else:
        for _type in ENTITY_TYPES:
            text_a = ent2query[_type]
            tokens_a = fine_grade_tokenize(text_a, tokenizer)

            encode_dict = tokenizer.encode_plus(text=tokens_a,
                                                text_pair=tokens_b,
                                                max_length=max_seq_len,
                                                pad_to_max_length=True,
                                                truncation_strategy='only_second',
                                                is_pretokenized=True,
                                                return_token_type_ids=True,
                                                return_attention_mask=True)

            token_ids = encode_dict['input_ids']
            attention_masks = encode_dict['attention_mask']
            token_type_ids = encode_dict['token_type_ids']

            tmp_callback = (text_b, len(tokens_a) + 2, _type)  # (text, text_offset, type, labels)
            tmp_callback_labels = []

            for _label in label_dict[_type]:
                tmp_callback_labels.append((_label[2], _label[0]))

            tmp_callback += (tmp_callback_labels, )

            callback_info.append(tmp_callback)

            feature = MRCFeature(token_ids=token_ids,
                                 attention_masks=attention_masks,
                                 token_type_ids=token_type_ids,
                                 ent_type=ent2id[_type])

            features.append(feature)

    return features, callback_info


def convert_examples_to_features(task_type, examples, max_seq_len, bert_dir, ent2id):
    assert task_type in ['crf', 'span', 'mrc']

    tokenizer = BertTokenizer(os.path.join(bert_dir, 'vocab.txt'))

    features = []

    callback_info = []

    logger.info(f'Convert {len(examples)} examples to features')
    type2id = {x: i for i, x in enumerate(ENTITY_TYPES)}

    for i, example in enumerate(examples):
        if task_type == 'crf':
            feature, tmp_callback = convert_crf_example(
                ex_idx=i,
                example=example,
                max_seq_len=max_seq_len,
                ent2id=ent2id,
                tokenizer=tokenizer
            )
        elif task_type == 'mrc':
            feature, tmp_callback = convert_mrc_example(
                ex_idx=i,
                example=example,
                max_seq_len=max_seq_len,
                ent2id=type2id,
                ent2query=ent2id,
                tokenizer=tokenizer
            )
        else:
            feature, tmp_callback = convert_span_example(
                ex_idx=i,
                example=example,
                max_seq_len=max_seq_len,
                ent2id=ent2id,
                tokenizer=tokenizer
            )

        if feature is None:
            continue

        if task_type == 'mrc':
            features.extend(feature)
            callback_info.extend(tmp_callback)
        else:
            features.append(feature)
            callback_info.append(tmp_callback)

    logger.info(f'Build {len(features)} features')

    out = (features, )

    if not len(callback_info):
        return out

    type_weight = {}  # proportion of each entity type, used to compute micro-F1
    for _type in ENTITY_TYPES:
        type_weight[_type] = 0.

    count = 0.

    if task_type == 'mrc':
        for _callback in callback_info:
            type_weight[_callback[-2]] += len(_callback[-1])
            count += len(_callback[-1])
    else:
        for _callback in callback_info:
            for _type in _callback[1]:
                type_weight[_type] += len(_callback[1][_type])
                count += len(_callback[1][_type])

    for key in type_weight:
        type_weight[key] /= count

    out += ((callback_info, type_weight), )

    return out


if __name__ == '__main__':
    pass
BIN
src/utils/__pycache__/attack_train_utils.cpython-36.pyc
Normal file
BIN
src/utils/__pycache__/dataset_utils.cpython-36.pyc
Normal file
BIN
src/utils/__pycache__/evaluator.cpython-36.pyc
Normal file
BIN
src/utils/__pycache__/functions_utils.cpython-36.pyc
Normal file
BIN
src/utils/__pycache__/model_utils.cpython-36.pyc
Normal file
BIN
src/utils/__pycache__/options.cpython-36.pyc
Normal file
BIN
src/utils/__pycache__/trainer.cpython-36.pyc
Normal file
76
src/utils/attack_train_utils.py
Normal file
@ -0,0 +1,76 @@
import torch
import torch.nn as nn


# FGM
class FGM:
    def __init__(self, model: nn.Module, eps=1.):
        self.model = (
            model.module if hasattr(model, "module") else model
        )
        self.eps = eps
        self.backup = {}

    # only attack word embedding
    def attack(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm and not torch.isnan(norm):
                    r_at = self.eps * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        for name, para in self.model.named_parameters():
            if para.requires_grad and emb_name in name:
                assert name in self.backup
                para.data = self.backup[name]

        self.backup = {}


# PGD
class PGD:
    def __init__(self, model, eps=1., alpha=0.3):
        self.model = (
            model.module if hasattr(model, "module") else model
        )
        self.eps = eps
        self.alpha = alpha
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, emb_name='word_embeddings', is_first_attack=False):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data)

    def restore(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data):
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > self.eps:
            r = self.eps * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]
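
For reference, a minimal sketch of how FGM is typically wrapped around one training step; `model`, `batch` and `optimizer` are placeholder names, and the real wiring lives in `src/utils/trainer.py` further below.

```python
# Hedged sketch: one FGM-augmented training step (placeholder names).
fgm = FGM(model)                 # model: any nn.Module containing a word_embeddings layer
loss = model(**batch)[0]
loss.backward()                  # gradients of the clean batch
fgm.attack()                     # add eps * grad / ||grad|| to the embedding weights
loss_adv = model(**batch)[0]
loss_adv.backward()              # accumulate the adversarial gradients
fgm.restore()                    # restore the original embedding weights
optimizer.step()
model.zero_grad()
```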
53
src/utils/dataset_utils.py
Normal file
@ -0,0 +1,53 @@
import torch
from torch.utils.data import Dataset


class NERDataset(Dataset):
    def __init__(self, task_type, features, mode, **kwargs):

        self.nums = len(features)

        self.token_ids = [torch.tensor(example.token_ids).long() for example in features]
        self.attention_masks = [torch.tensor(example.attention_masks).float() for example in features]
        self.token_type_ids = [torch.tensor(example.token_type_ids).long() for example in features]

        self.labels = None
        self.start_ids, self.end_ids = None, None
        self.ent_type = None
        self.pseudo = None
        if mode == 'train':
            self.pseudo = [torch.tensor(example.pseudo).long() for example in features]
            if task_type == 'crf':
                self.labels = [torch.tensor(example.labels) for example in features]
            else:
                self.start_ids = [torch.tensor(example.start_ids).long() for example in features]
                self.end_ids = [torch.tensor(example.end_ids).long() for example in features]

        if kwargs.pop('use_type_embed', False):
            self.ent_type = [torch.tensor(example.ent_type) for example in features]

    def __len__(self):
        return self.nums

    def __getitem__(self, index):
        data = {'token_ids': self.token_ids[index],
                'attention_masks': self.attention_masks[index],
                'token_type_ids': self.token_type_ids[index]}

        if self.ent_type is not None:
            data['ent_type'] = self.ent_type[index]

        if self.labels is not None:
            data['labels'] = self.labels[index]

        if self.pseudo is not None:
            data['pseudo'] = self.pseudo[index]

        if self.start_ids is not None:
            data['start_ids'] = self.start_ids[index]
            data['end_ids'] = self.end_ids[index]

        return data
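
A minimal usage sketch, assuming `train_features` was produced by `convert_examples_to_features` above:

```python
from torch.utils.data import DataLoader
from src.utils.dataset_utils import NERDataset

# build the dataset for the CRF task; mode='train' also exposes labels and pseudo flags
train_dataset = NERDataset(task_type='crf', features=train_features, mode='train')
train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True)

batch = next(iter(train_loader))
# dict with token_ids / attention_masks / token_type_ids (+ labels / pseudo in train mode)
print({k: v.shape for k, v in batch.items()})
```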
286
src/utils/evaluator.py
Normal file
@ -0,0 +1,286 @@
import torch
import logging
import numpy as np
from collections import defaultdict
from src.preprocess.processor import ENTITY_TYPES

logger = logging.getLogger(__name__)


def get_base_out(model, loader, device):
    """
    The forward pass is the same for every task, so it is wrapped up here
    """
    model.eval()

    with torch.no_grad():
        for idx, _batch in enumerate(loader):

            for key in _batch.keys():
                _batch[key] = _batch[key].to(device)

            tmp_out = model(**_batch)

            yield tmp_out


def crf_decode(decode_tokens, raw_text, id2ent):
    """
    CRF decoding: turn the predicted BIOES tag sequence back into entities
    """
    predict_entities = {}

    decode_tokens = decode_tokens[1:-1]  # drop the CLS / SEP tokens

    index_ = 0

    while index_ < len(decode_tokens):

        token_label = id2ent[decode_tokens[index_]].split('-')

        if token_label[0].startswith('S'):
            token_type = token_label[1]
            tmp_ent = raw_text[index_]

            if token_type not in predict_entities:
                predict_entities[token_type] = [(tmp_ent, index_)]
            else:
                predict_entities[token_type].append((tmp_ent, int(index_)))

            index_ += 1

        elif token_label[0].startswith('B'):
            token_type = token_label[1]
            start_index = index_

            index_ += 1
            while index_ < len(decode_tokens):
                temp_token_label = id2ent[decode_tokens[index_]].split('-')

                if temp_token_label[0].startswith('I') and token_type == temp_token_label[1]:
                    index_ += 1
                elif temp_token_label[0].startswith('E') and token_type == temp_token_label[1]:
                    end_index = index_
                    index_ += 1

                    tmp_ent = raw_text[start_index: end_index + 1]

                    if token_type not in predict_entities:
                        predict_entities[token_type] = [(tmp_ent, start_index)]
                    else:
                        predict_entities[token_type].append((tmp_ent, int(start_index)))

                    break
                else:
                    break
        else:
            index_ += 1

    return predict_entities


# strict decoding baseline
def span_decode(start_logits, end_logits, raw_text, id2ent):
    predict_entities = defaultdict(list)

    start_pred = np.argmax(start_logits, -1)
    end_pred = np.argmax(end_logits, -1)

    for i, s_type in enumerate(start_pred):
        if s_type == 0:
            continue
        for j, e_type in enumerate(end_pred[i:]):
            if s_type == e_type:
                tmp_ent = raw_text[i:i + j + 1]
                predict_entities[id2ent[s_type]].append((tmp_ent, i))
                break

    return predict_entities


# strict decoding baseline
def mrc_decode(start_logits, end_logits, raw_text):
    predict_entities = []
    start_pred = np.argmax(start_logits, -1)
    end_pred = np.argmax(end_logits, -1)

    for i, s_type in enumerate(start_pred):
        if s_type == 0:
            continue
        for j, e_type in enumerate(end_pred[i:]):
            if s_type == e_type:
                tmp_ent = raw_text[i:i + j + 1]
                predict_entities.append((tmp_ent, i))
                break

    return predict_entities


def calculate_metric(gt, predict):
    """
    Count tp / fp / fn for one sample
    """
    tp, fp, fn = 0, 0, 0
    for entity_predict in predict:
        flag = 0
        for entity_gt in gt:
            if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]:
                flag = 1
                tp += 1
                break
        if flag == 0:
            fp += 1

    fn = len(gt) - tp

    return np.array([tp, fp, fn])


def get_p_r_f(tp, fp, fn):
    p = tp / (tp + fp) if tp + fp != 0 else 0
    r = tp / (tp + fn) if tp + fn != 0 else 0
    f1 = 2 * p * r / (p + r) if p + r != 0 else 0
    return np.array([p, r, f1])


def crf_evaluation(model, dev_info, device, ent2id):
    dev_loader, (dev_callback_info, type_weight) = dev_info

    pred_tokens = []

    for tmp_pred in get_base_out(model, dev_loader, device):
        pred_tokens.extend(tmp_pred[0])

    assert len(pred_tokens) == len(dev_callback_info)

    id2ent = {ent2id[key]: key for key in ent2id.keys()}

    role_metric = np.zeros([13, 3])

    mirco_metrics = np.zeros(3)

    for tmp_tokens, tmp_callback in zip(pred_tokens, dev_callback_info):

        text, gt_entities = tmp_callback

        tmp_metric = np.zeros([13, 3])

        pred_entities = crf_decode(tmp_tokens, text, id2ent)

        for idx, _type in enumerate(ENTITY_TYPES):
            if _type not in pred_entities:
                pred_entities[_type] = []

            tmp_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type])

        role_metric += tmp_metric

    for idx, _type in enumerate(ENTITY_TYPES):
        temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])

        mirco_metrics += temp_metric * type_weight[_type]

    metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \
                 f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'

    return metric_str, mirco_metrics[2]


def span_evaluation(model, dev_info, device, ent2id):
    dev_loader, (dev_callback_info, type_weight) = dev_info

    start_logits, end_logits = None, None

    model.eval()

    for tmp_pred in get_base_out(model, dev_loader, device):
        tmp_start_logits = tmp_pred[0].cpu().numpy()
        tmp_end_logits = tmp_pred[1].cpu().numpy()

        if start_logits is None:
            start_logits = tmp_start_logits
            end_logits = tmp_end_logits
        else:
            start_logits = np.append(start_logits, tmp_start_logits, axis=0)
            end_logits = np.append(end_logits, tmp_end_logits, axis=0)

    assert len(start_logits) == len(end_logits) == len(dev_callback_info)

    role_metric = np.zeros([13, 3])

    mirco_metrics = np.zeros(3)

    id2ent = {ent2id[key]: key for key in ent2id.keys()}

    for tmp_start_logits, tmp_end_logits, tmp_callback \
            in zip(start_logits, end_logits, dev_callback_info):

        text, gt_entities = tmp_callback

        tmp_start_logits = tmp_start_logits[1:1 + len(text)]
        tmp_end_logits = tmp_end_logits[1:1 + len(text)]

        pred_entities = span_decode(tmp_start_logits, tmp_end_logits, text, id2ent)

        for idx, _type in enumerate(ENTITY_TYPES):
            if _type not in pred_entities:
                pred_entities[_type] = []

            role_metric[idx] += calculate_metric(gt_entities[_type], pred_entities[_type])

    for idx, _type in enumerate(ENTITY_TYPES):
        temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])

        mirco_metrics += temp_metric * type_weight[_type]

    metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \
                 f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'

    return metric_str, mirco_metrics[2]


def mrc_evaluation(model, dev_info, device):
    dev_loader, (dev_callback_info, type_weight) = dev_info

    start_logits, end_logits = None, None

    model.eval()

    for tmp_pred in get_base_out(model, dev_loader, device):
        tmp_start_logits = tmp_pred[0].cpu().numpy()
        tmp_end_logits = tmp_pred[1].cpu().numpy()

        if start_logits is None:
            start_logits = tmp_start_logits
            end_logits = tmp_end_logits
        else:
            start_logits = np.append(start_logits, tmp_start_logits, axis=0)
            end_logits = np.append(end_logits, tmp_end_logits, axis=0)

    assert len(start_logits) == len(end_logits) == len(dev_callback_info)

    role_metric = np.zeros([13, 3])

    mirco_metrics = np.zeros(3)

    ent2idx = {x: i for i, x in enumerate(ENTITY_TYPES)}

    for tmp_start_logits, tmp_end_logits, tmp_callback \
            in zip(start_logits, end_logits, dev_callback_info):

        text, text_offset, ent_type, gt_entities = tmp_callback

        tmp_start_logits = tmp_start_logits[text_offset:text_offset + len(text)]
        tmp_end_logits = tmp_end_logits[text_offset:text_offset + len(text)]

        pred_entities = mrc_decode(tmp_start_logits, tmp_end_logits, text)

        role_metric[ent2idx[ent_type]] += calculate_metric(gt_entities, pred_entities)

    for idx, _type in enumerate(ENTITY_TYPES):
        temp_metric = get_p_r_f(role_metric[idx][0], role_metric[idx][1], role_metric[idx][2])

        mirco_metrics += temp_metric * type_weight[_type]

    metric_str = f'[MIRCO] precision: {mirco_metrics[0]:.4f}, ' \
                 f'recall: {mirco_metrics[1]:.4f}, f1: {mirco_metrics[2]:.4f}'

    return metric_str, mirco_metrics[2]
159
src/utils/functions_utils.py
Normal file
@ -0,0 +1,159 @@
import os
import copy
import torch
import random
import numpy as np
from collections import defaultdict
from datetime import timedelta
import time
import logging

logger = logging.getLogger(__name__)


def get_time_dif(start_time):
    """
    Get the elapsed time since start_time
    :param start_time:
    :return:
    """
    end_time = time.time()
    time_dif = end_time - start_time
    return timedelta(seconds=int(round(time_dif)))


def set_seed(seed):
    """
    Set the random seeds for reproducibility
    :param seed:
    :return:
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def load_model_and_parallel(model, gpu_ids, ckpt_path=None, strict=True):
    """
    Load the model & move it to GPU(s) (single card / multi card)
    """
    gpu_ids = gpu_ids.split(',')

    # set the device to the first cuda card
    device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0])

    if ckpt_path is not None:
        logger.info(f'Load ckpt from {ckpt_path}')
        model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')), strict=strict)

    model.to(device)

    if len(gpu_ids) > 1:
        logger.info(f'Use multi gpus in: {gpu_ids}')
        gpu_ids = [int(x) for x in gpu_ids]
        model = torch.nn.DataParallel(model, device_ids=gpu_ids)
    else:
        logger.info(f'Use single gpu in: {gpu_ids}')

    return model, device


def get_model_path_list(base_dir):
    """
    Collect the paths of all model.pt checkpoints under base_dir
    """
    model_lists = []

    for root, dirs, files in os.walk(base_dir):
        for _file in files:
            if 'model.pt' == _file:
                model_lists.append(os.path.join(root, _file))

    model_lists = sorted(model_lists,
                         key=lambda x: (x.split('/')[-3], int(x.split('/')[-2].split('-')[-1])))

    return model_lists


def swa(model, model_dir, swa_start=1):
    """
    SWA (stochastic weight averaging); usually applied once training has stabilised
    """
    model_path_list = get_model_path_list(model_dir)

    assert 1 <= swa_start < len(model_path_list) - 1, \
        f'Using swa, swa start should be smaller than {len(model_path_list) - 1} and bigger than 0'

    swa_model = copy.deepcopy(model)
    swa_n = 0.

    with torch.no_grad():
        for _ckpt in model_path_list[swa_start:]:
            logger.info(f'Load model from {_ckpt}')
            model.load_state_dict(torch.load(_ckpt, map_location=torch.device('cpu')))
            tmp_para_dict = dict(model.named_parameters())

            alpha = 1. / (swa_n + 1.)

            for name, para in swa_model.named_parameters():
                para.copy_(tmp_para_dict[name].data.clone() * alpha + para.data.clone() * (1. - alpha))

            swa_n += 1

    # use 100000 to mark the swa checkpoint and avoid clashing with real step numbers
    swa_model_dir = os.path.join(model_dir, f'checkpoint-100000')
    if not os.path.exists(swa_model_dir):
        os.mkdir(swa_model_dir)

    logger.info(f'Save swa model in: {swa_model_dir}')

    swa_model_path = os.path.join(swa_model_dir, 'model.pt')

    torch.save(swa_model.state_dict(), swa_model_path)

    return swa_model


def vote(entities_list, threshold=0.9):
    """
    Entity-level voting (entity_type, entity_start, entity_end, entity_text)
    :param entities_list: entities predicted for one document by all models
    :param threshold: an entity is kept only if at least `threshold` of the models predict it
    :return:
    """
    threshold_nums = int(len(entities_list) * threshold)
    entities_dict = defaultdict(int)
    entities = defaultdict(list)

    for _entities in entities_list:
        for _type in _entities:
            for _ent in _entities[_type]:
                entities_dict[(_type, _ent[0], _ent[1])] += 1

    for key in entities_dict:
        if entities_dict[key] >= threshold_nums:
            entities[key[0]].append((key[1], key[2]))

    return entities


def ensemble_vote(entities_list, threshold=0.9):
    """
    Voting for the ensemble models
    Entity-level voting (entity_type, entity_start, entity_end, entity_text)
    """
    threshold_nums = int(len(entities_list) * threshold)
    entities_dict = defaultdict(int)

    entities = defaultdict(list)

    for _entities in entities_list:
        for _id in _entities:
            for _ent in _entities[_id]:
                entities_dict[(_id, ) + _ent] += 1

    for key in entities_dict:
        if entities_dict[key] >= threshold_nums:
            entities[key[0]].append(key[1:])

    return entities
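
A small toy example of the entity-level voting above; the entity dicts are made up, and note that `threshold_nums` is truncated with `int`, so with few models the effective threshold can be lower than expected.

```python
from src.utils.functions_utils import vote

preds_a = {'DRUG': [('连花清瘟胶囊', 0)], 'SYMPTOM': [('发热', 12)]}
preds_b = {'DRUG': [('连花清瘟胶囊', 0)]}
preds_c = {'DRUG': [('连花清瘟胶囊', 0)], 'SYMPTOM': [('咳嗽', 20)]}

# int(3 * 0.9) = 2, so an entity must be predicted by at least 2 of the 3 models
merged = vote([preds_a, preds_b, preds_c], threshold=0.9)
print(dict(merged))   # {'DRUG': [('连花清瘟胶囊', 0)]}
```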
646
src/utils/model_utils.py
Normal file
@ -0,0 +1,646 @@
import os
import math
import torch
import torch.nn as nn
from torchcrf import CRF
from itertools import repeat
from transformers import BertModel
from src.utils.functions_utils import vote
from src.utils.evaluator import crf_decode, span_decode


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        c = output.size()[-1]
        log_pred = torch.log_softmax(output, dim=-1)
        if self.reduction == 'sum':
            loss = -log_pred.sum()
        else:
            loss = -log_pred.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()

        return loss * self.eps / c + (1 - self.eps) * torch.nn.functional.nll_loss(log_pred, target,
                                                                                   reduction=self.reduction,
                                                                                   ignore_index=self.ignore_index)


class FocalLoss(nn.Module):
    """Multi-class Focal loss implementation"""
    def __init__(self, gamma=2, weight=None, reduction='mean', ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        log_pt = torch.log_softmax(input, dim=1)
        pt = torch.exp(log_pt)
        log_pt = (1 - pt) ** self.gamma * log_pt
        loss = torch.nn.functional.nll_loss(log_pt, target, self.weight,
                                            reduction=self.reduction, ignore_index=self.ignore_index)
        return loss


class SpatialDropout(nn.Module):
    """
    Dropout that drops whole character-level embedding vectors at once
    """
    def __init__(self, drop_prob):
        super(SpatialDropout, self).__init__()
        self.drop_prob = drop_prob

    @staticmethod
    def _make_noise(input):
        return input.new().resize_(input.size(0), *repeat(1, input.dim() - 2), input.size(2))

    def forward(self, inputs):
        output = inputs.clone()
        if not self.training or self.drop_prob == 0:
            return inputs
        else:
            noise = self._make_noise(inputs)
            if self.drop_prob == 1:
                noise.fill_(0)
            else:
                noise.bernoulli_(1 - self.drop_prob).div_(1 - self.drop_prob)
            noise = noise.expand_as(inputs)
            output.mul_(noise)
        return output


class ConditionalLayerNorm(nn.Module):
    def __init__(self,
                 normalized_shape,
                 cond_shape,
                 eps=1e-12):
        super().__init__()

        self.eps = eps

        self.weight = nn.Parameter(torch.Tensor(normalized_shape))
        self.bias = nn.Parameter(torch.Tensor(normalized_shape))

        self.weight_dense = nn.Linear(cond_shape, normalized_shape, bias=False)
        self.bias_dense = nn.Linear(cond_shape, normalized_shape, bias=False)

        self.reset_weight_and_bias()

    def reset_weight_and_bias(self):
        """
        This initialization keeps the conditional layer norm inactive at the start of training
        """
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

        nn.init.zeros_(self.weight_dense.weight)
        nn.init.zeros_(self.bias_dense.weight)

    def forward(self, inputs, cond=None):
        assert cond is not None, 'Conditional tensor need to input when use conditional layer norm'
        cond = torch.unsqueeze(cond, 1)  # (b, 1, h*2)

        weight = self.weight_dense(cond) + self.weight  # (b, 1, h)
        bias = self.bias_dense(cond) + self.bias  # (b, 1, h)

        mean = torch.mean(inputs, dim=-1, keepdim=True)  # (b, s, 1)
        outputs = inputs - mean  # (b, s, h)

        variance = torch.mean(outputs ** 2, dim=-1, keepdim=True)
        std = torch.sqrt(variance + self.eps)  # (b, s, 1)

        outputs = outputs / std  # (b, s, h)

        outputs = outputs * weight + bias

        return outputs


class BaseModel(nn.Module):
    def __init__(self,
                 bert_dir,
                 dropout_prob):
        super(BaseModel, self).__init__()
        config_path = os.path.join(bert_dir, 'config.json')

        assert os.path.exists(bert_dir) and os.path.exists(config_path), \
            'pretrained bert file does not exist'

        self.bert_module = BertModel.from_pretrained(bert_dir,
                                                     output_hidden_states=True,
                                                     hidden_dropout_prob=dropout_prob)

        self.bert_config = self.bert_module.config

    @staticmethod
    def _init_weights(blocks, **kwargs):
        """
        Parameter initialization: initialize Linear / Embedding / LayerNorm the same way as BERT
        """
        for block in blocks:
            for module in block.modules():
                if isinstance(module, nn.Linear):
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                elif isinstance(module, nn.Embedding):
                    nn.init.normal_(module.weight, mean=0, std=kwargs.pop('initializer_range', 0.02))
                elif isinstance(module, nn.LayerNorm):
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)


# baseline
class CRFModel(BaseModel):
    def __init__(self,
                 bert_dir,
                 num_tags,
                 dropout_prob=0.1,
                 **kwargs):
        super(CRFModel, self).__init__(bert_dir=bert_dir, dropout_prob=dropout_prob)

        out_dims = self.bert_config.hidden_size

        mid_linear_dims = kwargs.pop('mid_linear_dims', 128)

        self.mid_linear = nn.Sequential(
            nn.Linear(out_dims, mid_linear_dims),
            nn.ReLU(),
            nn.Dropout(dropout_prob)
        )

        out_dims = mid_linear_dims

        self.classifier = nn.Linear(out_dims, num_tags)

        self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.loss_weight.data.fill_(-0.2)

        self.crf_module = CRF(num_tags=num_tags, batch_first=True)

        init_blocks = [self.mid_linear, self.classifier]

        self._init_weights(init_blocks, initializer_range=self.bert_config.initializer_range)

    def forward(self,
                token_ids,
                attention_masks,
                token_type_ids,
                labels=None,
                pseudo=None):

        bert_outputs = self.bert_module(
            input_ids=token_ids,
            attention_mask=attention_masks,
            token_type_ids=token_type_ids
        )

        # regular path
        seq_out = bert_outputs[0]

        seq_out = self.mid_linear(seq_out)

        emissions = self.classifier(seq_out)

        if labels is not None:
            if pseudo is not None:
                # (batch,)
                tokens_loss = -1. * self.crf_module(emissions=emissions,
                                                    tags=labels.long(),
                                                    mask=attention_masks.byte(),
                                                    reduction='none')

                # nums of pseudo data
                pseudo_nums = pseudo.sum().item()
                total_nums = token_ids.shape[0]

                # learnable weight between real and pseudo-label losses
                rate = torch.sigmoid(self.loss_weight)
                if pseudo_nums == 0:
                    loss_0 = tokens_loss.mean()
                    loss_1 = (rate * pseudo * tokens_loss).sum()
                else:
                    if total_nums == pseudo_nums:
                        loss_0 = 0
                    else:
                        loss_0 = ((1 - rate) * (1 - pseudo) * tokens_loss).sum() / (total_nums - pseudo_nums)
                    loss_1 = (rate * pseudo * tokens_loss).sum() / pseudo_nums

                tokens_loss = loss_0 + loss_1

            else:
                tokens_loss = -1. * self.crf_module(emissions=emissions,
                                                    tags=labels.long(),
                                                    mask=attention_masks.byte(),
                                                    reduction='mean')

            out = (tokens_loss,)

        else:
            tokens_out = self.crf_module.decode(emissions=emissions, mask=attention_masks.byte())

            out = (tokens_out, emissions)

        return out


class SpanModel(BaseModel):
    def __init__(self,
                 bert_dir,
                 num_tags,
                 dropout_prob=0.1,
                 loss_type='ce',
                 **kwargs):
        """
        predict the start / end positions of entities with a half-pointer, half-tag structure
        :param loss_type: train loss type in ['ce', 'ls_ce', 'focal']
        """
        super(SpanModel, self).__init__(bert_dir, dropout_prob=dropout_prob)

        out_dims = self.bert_config.hidden_size

        mid_linear_dims = kwargs.pop('mid_linear_dims', 128)

        self.num_tags = num_tags

        self.mid_linear = nn.Sequential(
            nn.Linear(out_dims, mid_linear_dims),
            nn.ReLU(),
            nn.Dropout(dropout_prob)
        )

        out_dims = mid_linear_dims

        self.start_fc = nn.Linear(out_dims, num_tags)
        self.end_fc = nn.Linear(out_dims, num_tags)

        reduction = 'none'
        if loss_type == 'ce':
            self.criterion = nn.CrossEntropyLoss(reduction=reduction)
        elif loss_type == 'ls_ce':
            self.criterion = LabelSmoothingCrossEntropy(reduction=reduction)
        else:
            self.criterion = FocalLoss(reduction=reduction)

        self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.loss_weight.data.fill_(-0.2)

        init_blocks = [self.mid_linear, self.start_fc, self.end_fc]

        self._init_weights(init_blocks)

    def forward(self,
                token_ids,
                attention_masks,
                token_type_ids,
                start_ids=None,
                end_ids=None,
                pseudo=None):

        bert_outputs = self.bert_module(
            input_ids=token_ids,
            attention_mask=attention_masks,
            token_type_ids=token_type_ids
        )

        seq_out = bert_outputs[0]

        seq_out = self.mid_linear(seq_out)

        start_logits = self.start_fc(seq_out)
        end_logits = self.end_fc(seq_out)

        out = (start_logits, end_logits, )

        if start_ids is not None and end_ids is not None and self.training:

            start_logits = start_logits.view(-1, self.num_tags)
            end_logits = end_logits.view(-1, self.num_tags)

            # drop the labels of the padding positions and compute the loss on real tokens only
            active_loss = attention_masks.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_ids.view(-1)[active_loss]
            active_end_labels = end_ids.view(-1)[active_loss]

            if pseudo is not None:
                # (batch,)
                start_loss = self.criterion(start_logits, start_ids.view(-1)).view(-1, 512).mean(dim=-1)
                end_loss = self.criterion(end_logits, end_ids.view(-1)).view(-1, 512).mean(dim=-1)

                # nums of pseudo data
                pseudo_nums = pseudo.sum().item()
                total_nums = token_ids.shape[0]

                # learnable weight between real and pseudo-label losses
                rate = torch.sigmoid(self.loss_weight)
                if pseudo_nums == 0:
                    start_loss = start_loss.mean()
                    end_loss = end_loss.mean()
                else:
                    if total_nums == pseudo_nums:
                        start_loss = (rate * pseudo * start_loss).sum() / pseudo_nums
                        end_loss = (rate * pseudo * end_loss).sum() / pseudo_nums
                    else:
                        start_loss = (rate * pseudo * start_loss).sum() / pseudo_nums \
                                     + ((1 - rate) * (1 - pseudo) * start_loss).sum() / (total_nums - pseudo_nums)
                        end_loss = (rate * pseudo * end_loss).sum() / pseudo_nums \
                                   + ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums)
            else:
                start_loss = self.criterion(active_start_logits, active_start_labels)
                end_loss = self.criterion(active_end_logits, active_end_labels)

            loss = start_loss + end_loss

            out = (loss, ) + out

        return out


class MRCModel(BaseModel):
    def __init__(self,
                 bert_dir,
                 dropout_prob=0.1,
                 use_type_embed=False,
                 loss_type='ce',
                 **kwargs):
        """
        predict the answer span for the entity type asked by the query
        :param use_type_embed: type embedding for the sentence
        :param loss_type: train loss type in ['ce', 'ls_ce', 'focal']
        """
        super(MRCModel, self).__init__(bert_dir, dropout_prob=dropout_prob)

        self.use_type_embed = use_type_embed
        self.use_smooth = loss_type

        out_dims = self.bert_config.hidden_size

        if self.use_type_embed:
            embed_dims = kwargs.pop('predicate_embed_dims', self.bert_config.hidden_size)
            self.type_embedding = nn.Embedding(13, embed_dims)

            self.conditional_layer_norm = ConditionalLayerNorm(out_dims, embed_dims,
                                                               eps=self.bert_config.layer_norm_eps)

        mid_linear_dims = kwargs.pop('mid_linear_dims', 128)

        self.mid_linear = nn.Sequential(
            nn.Linear(out_dims, mid_linear_dims),
            nn.ReLU(),
            nn.Dropout(dropout_prob)
        )

        out_dims = mid_linear_dims

        self.start_fc = nn.Linear(out_dims, 2)
        self.end_fc = nn.Linear(out_dims, 2)

        reduction = 'none'
        if loss_type == 'ce':
            self.criterion = nn.CrossEntropyLoss(reduction=reduction)
        elif loss_type == 'ls_ce':
            self.criterion = LabelSmoothingCrossEntropy(reduction=reduction)
        else:
            self.criterion = FocalLoss(reduction=reduction)

        self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
        self.loss_weight.data.fill_(-0.2)

        init_blocks = [self.mid_linear, self.start_fc, self.end_fc]

        if self.use_type_embed:
            init_blocks.append(self.type_embedding)

        self._init_weights(init_blocks)

    def forward(self,
                token_ids,
                attention_masks,
                token_type_ids,
                ent_type=None,
                start_ids=None,
                end_ids=None,
                pseudo=None):

        bert_outputs = self.bert_module(
            input_ids=token_ids,
            attention_mask=attention_masks,
            token_type_ids=token_type_ids
        )

        seq_out = bert_outputs[0]

        if self.use_type_embed:
            assert ent_type is not None, \
                'Using predicate embedding, predicate should be implemented'

            predicate_feature = self.type_embedding(ent_type)
            seq_out = self.conditional_layer_norm(seq_out, predicate_feature)

        seq_out = self.mid_linear(seq_out)

        start_logits = self.start_fc(seq_out)
        end_logits = self.end_fc(seq_out)

        out = (start_logits, end_logits, )

        if start_ids is not None and end_ids is not None:
            start_logits = start_logits.view(-1, 2)
            end_logits = end_logits.view(-1, 2)

            # drop the labels of text_a and padding, compute the loss on the doc tokens only
            active_loss = token_type_ids.view(-1) == 1
            active_start_logits = start_logits[active_loss]
            active_end_logits = end_logits[active_loss]

            active_start_labels = start_ids.view(-1)[active_loss]
            active_end_labels = end_ids.view(-1)[active_loss]

            if pseudo is not None:
                # (batch,)
                start_loss = self.criterion(start_logits, start_ids.view(-1)).view(-1, 512).mean(dim=-1)
                end_loss = self.criterion(end_logits, end_ids.view(-1)).view(-1, 512).mean(dim=-1)

                # nums of pseudo data
                pseudo_nums = pseudo.sum().item()
                total_nums = token_ids.shape[0]

                # learnable weight between real and pseudo-label losses
                rate = torch.sigmoid(self.loss_weight)
                if pseudo_nums == 0:
                    start_loss = start_loss.mean()
                    end_loss = end_loss.mean()
                else:
                    if total_nums == pseudo_nums:
                        start_loss = (rate * pseudo * start_loss).sum() / pseudo_nums
                        end_loss = (rate * pseudo * end_loss).sum() / pseudo_nums
                    else:
                        start_loss = (rate * pseudo * start_loss).sum() / pseudo_nums \
                                     + ((1 - rate) * (1 - pseudo) * start_loss).sum() / (total_nums - pseudo_nums)
                        end_loss = (rate * pseudo * end_loss).sum() / pseudo_nums \
                                   + ((1 - rate) * (1 - pseudo) * end_loss).sum() / (total_nums - pseudo_nums)
            else:
                start_loss = self.criterion(active_start_logits, active_start_labels)
                end_loss = self.criterion(active_end_logits, active_end_labels)

            loss = start_loss + end_loss

            out = (loss, ) + out

        return out


class EnsembleCRFModel:
    def __init__(self, model_path_list, bert_dir_list, num_tags, device, lamb=1/3):

        self.models = []
        self.crf_module = CRF(num_tags=num_tags, batch_first=True)
        self.lamb = lamb

        for idx, _path in enumerate(model_path_list):
            print(f'Load model from {_path}')

            print(f'Load model type: {bert_dir_list[0]}')
            model = CRFModel(bert_dir=bert_dir_list[0], num_tags=num_tags)

            model.load_state_dict(torch.load(_path, map_location=torch.device('cpu')))

            model.eval()
            model.to(device)

            self.models.append(model)
            if idx == 0:
                print(f'Load CRF weight from {_path}')
                self.crf_module.load_state_dict(model.crf_module.state_dict())
                self.crf_module.to(device)

    def weight(self, t):
        """
        Weighted fusion following Newton's law of cooling
        """
        return math.exp(-self.lamb * t)

    def predict(self, model_inputs):
        weight_sum = 0.
        logits = None
        attention_masks = model_inputs['attention_masks']

        for idx, model in enumerate(self.models):
            # Newton-cooling weighted probability fusion
            weight = self.weight(idx)

            # plain probability averaging
            # weight = 1 / len(self.models)

            tmp_logits = model(**model_inputs)[1] * weight
            weight_sum += weight

            if logits is None:
                logits = tmp_logits
            else:
                logits += tmp_logits

        logits = logits / weight_sum

        tokens_out = self.crf_module.decode(emissions=logits, mask=attention_masks.byte())

        return tokens_out

    def vote_entities(self, model_inputs, sent, id2ent, threshold):
        entities_ls = []
        for idx, model in enumerate(self.models):
            tmp_tokens = model(**model_inputs)[0][0]
            tmp_entities = crf_decode(tmp_tokens, sent, id2ent)
            entities_ls.append(tmp_entities)

        return vote(entities_ls, threshold)


class EnsembleSpanModel:
    def __init__(self, model_path_list, bert_dir_list, num_tags, device):

        self.models = []

        for idx, _path in enumerate(model_path_list):
            print(f'Load model from {_path}')

            print(f'Load model type: {bert_dir_list[0]}')
            model = SpanModel(bert_dir=bert_dir_list[0], num_tags=num_tags)

            model.load_state_dict(torch.load(_path, map_location=torch.device('cpu')))

            model.eval()
            model.to(device)

            self.models.append(model)

    def predict(self, model_inputs):
        start_logits, end_logits = None, None

        for idx, model in enumerate(self.models):

            # plain probability averaging
            weight = 1 / len(self.models)

            tmp_start_logits, tmp_end_logits = model(**model_inputs)

            tmp_start_logits = tmp_start_logits * weight
            tmp_end_logits = tmp_end_logits * weight

            if start_logits is None:
                start_logits = tmp_start_logits
                end_logits = tmp_end_logits
            else:
                start_logits += tmp_start_logits
                end_logits += tmp_end_logits

        return start_logits, end_logits

    def vote_entities(self, model_inputs, sent, id2ent, threshold):
        entities_ls = []

        for idx, model in enumerate(self.models):

            start_logits, end_logits = model(**model_inputs)
            start_logits = start_logits[0].cpu().numpy()[1:1 + len(sent)]
            end_logits = end_logits[0].cpu().numpy()[1:1 + len(sent)]

            decode_entities = span_decode(start_logits, end_logits, sent, id2ent)

            entities_ls.append(decode_entities)

        return vote(entities_ls, threshold)


def build_model(task_type, bert_dir, **kwargs):
    assert task_type in ['crf', 'span', 'mrc']

    if task_type == 'crf':
        model = CRFModel(bert_dir=bert_dir,
                         num_tags=kwargs.pop('num_tags'),
                         dropout_prob=kwargs.pop('dropout_prob', 0.1))

    elif task_type == 'mrc':
        model = MRCModel(bert_dir=bert_dir,
                         dropout_prob=kwargs.pop('dropout_prob', 0.1),
                         use_type_embed=kwargs.pop('use_type_embed'),
                         loss_type=kwargs.pop('loss_type', 'ce'))

    else:
        model = SpanModel(bert_dir=bert_dir,
                          num_tags=kwargs.pop('num_tags'),
                          dropout_prob=kwargs.pop('dropout_prob', 0.1),
                          loss_type=kwargs.pop('loss_type', 'ce'))

    return model
98
src/utils/options.py
Normal file
@ -0,0 +1,98 @@
import argparse


class Args:
    @staticmethod
    def parse():
        parser = argparse.ArgumentParser()
        return parser

    @staticmethod
    def initialize(parser: argparse.ArgumentParser):
        # args for path
        parser.add_argument('--raw_data_dir', default='./data/raw_data',
                            help='the data dir of raw data')

        parser.add_argument('--mid_data_dir', default='./data/mid_data',
                            help='the mid data dir')

        parser.add_argument('--output_dir', default='./out/',
                            help='the output dir for model checkpoints')

        parser.add_argument('--bert_dir', default='../bert/torch_roberta_wwm',
                            help='bert dir for ernie / roberta-wwm / uer')

        parser.add_argument('--bert_type', default='roberta_wwm',
                            help='roberta_wwm / ernie_1 / uer_large')

        parser.add_argument('--task_type', default='crf',
                            help='crf / span / mrc')

        parser.add_argument('--loss_type', default='ls_ce',
                            help='loss type for span bert')

        parser.add_argument('--use_type_embed', default=False, action='store_true',
                            help='whether to use the type embedding in the mrc model')

        parser.add_argument('--use_fp16', default=False, action='store_true',
                            help='whether to use fp16 during training')

        # other args
        parser.add_argument('--seed', type=int, default=123, help='random seed')

        parser.add_argument('--gpu_ids', type=str, default='0',
                            help='gpu ids to use, -1 for cpu, "0,1" for multi gpu')

        parser.add_argument('--mode', type=str, default='train',
                            help='train / stack')

        parser.add_argument('--max_seq_len', default=512, type=int)

        parser.add_argument('--eval_batch_size', default=64, type=int)

        parser.add_argument('--swa_start', default=3, type=int,
                            help='the epoch when swa starts')

        # train args
        parser.add_argument('--train_epochs', default=10, type=int,
                            help='Max training epoch')

        parser.add_argument('--dropout_prob', default=0.1, type=float,
                            help='drop out probability')

        parser.add_argument('--lr', default=2e-5, type=float,
                            help='learning rate for the bert module')

        parser.add_argument('--other_lr', default=2e-3, type=float,
                            help='learning rate for the modules except bert')

        parser.add_argument('--max_grad_norm', default=1.0, type=float,
                            help='max grad clip')

        parser.add_argument('--warmup_proportion', default=0.1, type=float)

        parser.add_argument('--weight_decay', default=0.00, type=float)

        parser.add_argument('--adam_epsilon', default=1e-8, type=float)

        parser.add_argument('--train_batch_size', default=24, type=int)

        parser.add_argument('--eval_model', default=True, action='store_true',
                            help='whether to eval model after training')

        parser.add_argument('--attack_train', default='', type=str,
                            help='fgm / pgd attack train when training')

        # test args
        parser.add_argument('--version', default='v0', type=str,
                            help='submit version')

        parser.add_argument('--submit_dir', default='./submit', type=str)

        parser.add_argument('--ckpt_dir', default='', type=str)

        return parser

    def get_parser(self):
        parser = self.parse()
        parser = self.initialize(parser)
        return parser.parse_args()
223
src/utils/trainer.py
Normal file
@ -0,0 +1,223 @@
import os
import copy
import torch
import logging
from torch.cuda.amp import autocast as ac
from torch.utils.data import DataLoader, RandomSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from src.utils.attack_train_utils import FGM, PGD
from src.utils.functions_utils import load_model_and_parallel, swa


logger = logging.getLogger(__name__)


def save_model(opt, model, global_step):
    output_dir = os.path.join(opt.output_dir, 'checkpoint-{}'.format(global_step))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # take care of model distributed / parallel training
    model_to_save = (
        model.module if hasattr(model, "module") else model
    )
    logger.info(f'Saving model & optimizer & scheduler checkpoint to {output_dir}')
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt'))


def build_optimizer_and_scheduler(opt, model, t_total):
    module = (
        model.module if hasattr(model, "module") else model
    )

    # differential learning rates
    no_decay = ["bias", "LayerNorm.weight"]
    model_param = list(module.named_parameters())

    bert_param_optimizer = []
    other_param_optimizer = []

    for name, para in model_param:
        space = name.split('.')
        if space[0] == 'bert_module':
            bert_param_optimizer.append((name, para))
        else:
            other_param_optimizer.append((name, para))

    optimizer_grouped_parameters = [
        # bert module
        {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": opt.weight_decay, 'lr': opt.lr},
        {"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': opt.lr},

        # other modules, with a differential learning rate
        {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": opt.weight_decay, 'lr': opt.other_lr},
        {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0, 'lr': opt.other_lr},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=opt.lr, eps=opt.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=int(opt.warmup_proportion * t_total), num_training_steps=t_total
    )

    return optimizer, scheduler


def train(opt, model, train_dataset):
    swa_raw_model = copy.deepcopy(model)

    train_sampler = RandomSampler(train_dataset)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=opt.train_batch_size,
                              sampler=train_sampler,
                              num_workers=0)

    scaler = None
    if opt.use_fp16:
        scaler = torch.cuda.amp.GradScaler()

    model, device = load_model_and_parallel(model, opt.gpu_ids)

    use_n_gpus = False
    if hasattr(model, "module"):
        use_n_gpus = True

    t_total = len(train_loader) * opt.train_epochs

    optimizer, scheduler = build_optimizer_and_scheduler(opt, model, t_total)

    # Train
    logger.info("***** Running training *****")
    logger.info(f"  Num Examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {opt.train_epochs}")
    logger.info(f"  Total training batch size = {opt.train_batch_size}")
    logger.info(f"  Total optimization steps = {t_total}")

    global_step = 0

    model.zero_grad()

    fgm, pgd = None, None

    attack_train_mode = opt.attack_train.lower()
    if attack_train_mode == 'fgm':
        fgm = FGM(model=model)
    elif attack_train_mode == 'pgd':
        pgd = PGD(model=model)

    pgd_k = 3

    save_steps = t_total // opt.train_epochs
    eval_steps = save_steps

    logger.info(f'Save model in {save_steps} steps; Eval model in {eval_steps} steps')

    log_loss_steps = 20

    avg_loss = 0.

    for epoch in range(opt.train_epochs):

        for step, batch_data in enumerate(train_loader):

            model.train()

            for key in batch_data.keys():
                batch_data[key] = batch_data[key].to(device)

            if opt.use_fp16:
                with ac():
                    loss = model(**batch_data)[0]
            else:
                loss = model(**batch_data)[0]

            if use_n_gpus:
                loss = loss.mean()

            if opt.use_fp16:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            if fgm is not None:
                fgm.attack()

                if opt.use_fp16:
                    with ac():
                        loss_adv = model(**batch_data)[0]
                else:
                    loss_adv = model(**batch_data)[0]

                if use_n_gpus:
                    loss_adv = loss_adv.mean()

                if opt.use_fp16:
                    scaler.scale(loss_adv).backward()
                else:
                    loss_adv.backward()

                fgm.restore()

            elif pgd is not None:
                pgd.backup_grad()

                for _t in range(pgd_k):
                    pgd.attack(is_first_attack=(_t == 0))

                    if _t != pgd_k - 1:
                        model.zero_grad()
                    else:
                        pgd.restore_grad()

                    if opt.use_fp16:
                        with ac():
                            loss_adv = model(**batch_data)[0]
                    else:
                        loss_adv = model(**batch_data)[0]

                    if use_n_gpus:
                        loss_adv = loss_adv.mean()

                    if opt.use_fp16:
                        scaler.scale(loss_adv).backward()
                    else:
                        loss_adv.backward()

                pgd.restore()

            if opt.use_fp16:
                scaler.unscale_(optimizer)

            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)

            # optimizer.step()
            if opt.use_fp16:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            scheduler.step()
            model.zero_grad()

            global_step += 1

            if global_step % log_loss_steps == 0:
                avg_loss /= log_loss_steps
                logger.info('Step: %d / %d ----> total loss: %.5f' % (global_step, t_total, avg_loss))
                avg_loss = 0.
            else:
                avg_loss += loss.item()

            if global_step % save_steps == 0:
                save_model(opt, model, global_step)

    swa(swa_raw_model, opt.output_dir, swa_start=opt.swa_start)

    # clear cuda cache to avoid OOM
    torch.cuda.empty_cache()
    logger.info('Train done')
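
A hedged end-to-end sketch of how these pieces fit together for the CRF task. The actual entry script is not part of this diff, so `train_features`, `ent2id` and the tag count are assumptions, not the authors' exact wiring.

```python
from src.utils.options import Args
from src.utils.functions_utils import set_seed
from src.utils.model_utils import build_model
from src.utils.dataset_utils import NERDataset
from src.utils.trainer import train

opt = Args().get_parser()          # defaults: task_type='crf', train_batch_size=24, ...
set_seed(opt.seed)

# train_features / ent2id are assumed to come from the preprocessing code above, e.g.
# convert_examples_to_features(opt.task_type, train_examples, opt.max_seq_len, opt.bert_dir, ent2id)
train_dataset = NERDataset(task_type=opt.task_type, features=train_features, mode='train')

# num_tags for a BIOES tagging scheme over len(ent2id) types is an assumption here
model = build_model(opt.task_type, opt.bert_dir, num_tags=len(ent2id) * 4 + 1)
train(opt, model, train_dataset)
```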