Add entity extraction, relation extraction, and event extraction models with training and prediction code

parent: b6be0be9af
commit: 4f3435fa18
.gitignore (vendored): 2 lines changed

@@ -6,4 +6,4 @@ build
 dist
 .vscode
 lightKG.egg-info/
-test/*_saves
+examples/*_saves
README.md: 135 lines changed

@@ -21,10 +21,16 @@

### Entity Recognition and Linking

- Named entity recognition, ner

### Entity Relation Extraction

- Relation extraction, re

### Event Detection and Extraction

- Semantic role labeling, srl

### Knowledge Storage and Query

### Knowledge Reasoning
@@ -147,6 +153,124 @@ print(krl.predict_head(rel='外文名', tail='Compiler'))

[('编译器', 0.998942494392395), ('译码器', 0.36795616149902344), ('计算机,单片机,编程语言', 0.36788302659988403)]
```

### ner

#### Training

```python
from lightkg.erl import NER

# Create an NER object
ner_model = NER()

train_path = '/home/lightsmile/NLP/corpus/ner/train.sample.txt'
dev_path = '/home/lightsmile/NLP/corpus/ner/test.sample.txt'
vec_path = '/home/lightsmile/NLP/embedding/char/token_vec_300.bin'

# Only the training data path is required; the pretrained character vectors,
# the dev set path, and the model save path are all optional.
ner_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./ner_saves')
```

#### Testing

```python
# Load the model; defaults to the `saves` directory under the current directory
ner_model.load('./ner_saves')
# Evaluate on the dataset at train_path
ner_model.test(train_path)
```

#### Prediction

```python
from pprint import pprint

pprint(ner_model.predict('另一个很酷的事情是,通过框架我们可以停止并在稍后恢复训练。'))
```

Prediction result:

```bash
[{'end': 15, 'entity': '我们', 'start': 14, 'type': 'Person'}]
```

### re

#### Training

```python
from lightkg.ere import RE

re = RE()

train_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/train.sample.txt'
dev_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/test.sample.txt'
vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char'

re.train(train_path, dev_path=dev_path, vectors_path=vec_path, save_path='./re_saves')
```

#### Testing

```python
re.load('./re_saves')
re.test(dev_path)
```

#### Prediction

```python
print(re.predict('钱钟书', '辛笛', '与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。'))
```

Prediction result:

```python
(0.7306928038597107, '同门')  # returned as (predicted probability, predicted label)
```

### srl

#### Training

```python
from lightkg.ede import SRL

srl_model = SRL()

train_path = '/home/lightsmile/NLP/corpus/srl/train.sample.tsv'
dev_path = '/home/lightsmile/NLP/corpus/srl/test.sample.tsv'
vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char'

srl_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./srl_saves')
```

#### Testing

```python
srl_model.load('./srl_saves')
srl_model.test(dev_path)
```

#### Prediction

```python
word_list = ['代表', '朝方', '对', '中国', '党政', '领导人', '和', '人民', '哀悼', '金日成', '主席', '逝世', '表示', '深切', '谢意', '。']
pos_list = ['VV', 'NN', 'P', 'NR', 'NN', 'NN', 'CC', 'NN', 'VV', 'NR', 'NN', 'VV', 'VV', 'JJ', 'NN', 'PU']
rel_list = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

print(srl_model.predict(word_list, pos_list, rel_list))
```

Prediction result:

```bash
{'ARG0': '中国党政领导人和人民', 'rel': '哀悼', 'ARG1': '金日成主席逝世'}
```

## Project Structure

### Project Architecture
@@ -158,10 +282,18 @@ print(krl.predict_head(rel='外文名', tail='Compiler'))

- common
    - entity.py
    - relation.py
- ede
    - srl, semantic role labeling
- ere
    - re, relation extraction
- erl
    - ner, named entity recognition
- kr
    - krl, knowledge representation learning
        - models
            - transE
        - utils
- ksq
- utils

### Architecture Notes

@@ -207,6 +339,9 @@ print(krl.predict_head(rel='外文名', tail='Compiler'))

### Features

- [x] Add relation extraction models and training/prediction code
- [x] Add event extraction models and training/prediction code
- [x] Add named entity recognition models and training/prediction code
- [x] Add translation-model-based knowledge representation learning models and training/prediction code

## References
examples/test_ner.py (new file): 15 lines

@@ -0,0 +1,15 @@

```python
from pprint import pprint

from lightkg.erl import NER

ner_model = NER()

train_path = '/home/lightsmile/NLP/corpus/ner/train.sample.txt'
dev_path = '/home/lightsmile/NLP/corpus/ner/test.sample.txt'
vec_path = '/home/lightsmile/NLP/embedding/char/token_vec_300.bin'

# ner_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./ner_saves')

ner_model.load('./ner_saves')
# ner_model.test(train_path)

pprint(ner_model.predict('另一个很酷的事情是,通过框架我们可以停止并在稍后恢复训练。'))
```
examples/test_re.py (new file): 15 lines

@@ -0,0 +1,15 @@

```python
from lightkg.ere import RE

re = RE()

train_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/train.sample.txt'
dev_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/test.sample.txt'
vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char'

# re.train(train_path, dev_path=train_path, vectors_path=vec_path, save_path='./re_saves')

re.load('./re_saves')
# re.test(train_path)

# Sample corpus line: 钱钟书 辛笛 同门 与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。
print(re.predict('钱钟书', '辛笛', '与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。'))
```
examples/test_srl.py (new file): 21 lines

@@ -0,0 +1,21 @@

```python
from lightkg.ede import SRL

srl_model = SRL()

train_path = '/home/lightsmile/NLP/corpus/srl/train.sample.tsv'
dev_path = '/home/lightsmile/NLP/corpus/srl/test.sample.tsv'
vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char'

# srl_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./srl_saves')

srl_model.load('./srl_saves')

# srl_model.test(dev_path)

word_list = ['代表', '朝方', '对', '中国', '党政', '领导人', '和', '人民', '哀悼', '金日成', '主席', '逝世', '表示', '深切', '谢意', '。']
pos_list = ['VV', 'NN', 'P', 'NR', 'NN', 'NN', 'CC', 'NN', 'VV', 'NR', 'NN', 'VV', 'VV', 'JJ', 'NN', 'PU']
rel_list = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

print(srl_model.predict(word_list, pos_list, rel_list))
```
lightkg/ede/__init__.py (new file): 2 lines

@@ -0,0 +1,2 @@

```python
from .srl.module import SRL

__all__ = ['SRL']
```
lightkg/ede/srl/config.py (new file): 22 lines

@@ -0,0 +1,22 @@

```python
from ...base.config import DEVICE  # re-exported here for model.py and tool.py

DEFAULT_CONFIG = {
    'lr': 0.02,
    'epoch': 30,
    'lr_decay': 0.05,
    'batch_size': 128,
    'dropout': 0.5,
    'static': False,
    'non_static': False,
    'embedding_dim': 300,
    'pos_dim': 50,
    'num_layers': 2,
    'pad_index': 1,
    'vector_path': '',
    'tag_num': 0,
    'vocabulary_size': 0,
    'pos_size': 0,
    'word_vocab': None,
    'pos_vocab': None,
    'tag_vocab': None,
    'save_path': './saves'
}
```
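Every key in `DEFAULT_CONFIG` can be overridden per run: `SRL.train` forwards its `**kwargs` to `Config`, which first applies these defaults and then the overrides. A minimal sketch, reusing the paths from the README example above (the values are illustrative, not tuned recommendations):

```python
# Hypothetical override of the defaults above; any DEFAULT_CONFIG key works.
# Note: the training iterator's batch size is still read from DEFAULT_CONFIG,
# so overriding 'batch_size' here only affects the model configuration.
srl_model.train(train_path,
                dev_path=dev_path,
                save_path='./srl_saves',
                epoch=10,  # overrides DEFAULT_CONFIG['epoch']
                lr=0.01)   # overrides DEFAULT_CONFIG['lr']
```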
lightkg/ede/srl/model.py (new file): 97 lines

@@ -0,0 +1,97 @@

```python
import torch
import torch.nn as nn
from torchcrf import CRF
from torchtext.vocab import Vectors
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from ...utils.log import logger
from .config import DEVICE, DEFAULT_CONFIG
from ...base.model import BaseConfig, BaseModel


class Config(BaseConfig):
    def __init__(self, word_vocab, pos_vocab, tag_vocab, vector_path, **kwargs):
        super(Config, self).__init__()
        for name, value in DEFAULT_CONFIG.items():
            setattr(self, name, value)
        self.word_vocab = word_vocab
        self.pos_vocab = pos_vocab
        self.tag_vocab = tag_vocab
        self.tag_num = len(self.tag_vocab)
        self.vocabulary_size = len(self.word_vocab)
        self.pos_size = len(self.pos_vocab)
        self.vector_path = vector_path
        for name, value in kwargs.items():
            setattr(self, name, value)


class BiLstmCrf(BaseModel):
    def __init__(self, args):
        super(BiLstmCrf, self).__init__(args)
        self.args = args
        self.hidden_dim = 300
        self.tag_num = args.tag_num
        self.batch_size = args.batch_size
        self.bidirectional = True
        self.num_layers = args.num_layers
        self.pad_index = args.pad_index
        self.dropout = args.dropout
        self.save_path = args.save_path

        vocabulary_size = args.vocabulary_size
        embedding_dimension = args.embedding_dim
        pos_size = args.pos_size
        pos_dim = args.pos_dim

        self.word_embedding = nn.Embedding(vocabulary_size, embedding_dimension).to(DEVICE)
        if args.static:
            logger.info('loading word vectors from {}'.format(args.vector_path))
            vectors = Vectors(args.vector_path).vectors
            self.word_embedding = nn.Embedding.from_pretrained(vectors, freeze=not args.non_static).to(DEVICE)
        self.pos_embedding = nn.Embedding(pos_size, pos_dim).to(DEVICE)

        # Input per token: word embedding + POS embedding + 1-dim predicate indicator
        self.lstm = nn.LSTM(embedding_dimension + pos_dim + 1, self.hidden_dim // 2, bidirectional=self.bidirectional,
                            num_layers=self.num_layers, dropout=self.dropout).to(DEVICE)
        self.hidden2label = nn.Linear(self.hidden_dim, self.tag_num).to(DEVICE)
        self.crflayer = CRF(self.tag_num).to(DEVICE)

        # self.init_weight()

    def init_weight(self):
        nn.init.xavier_normal_(self.word_embedding.weight)
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
        nn.init.xavier_normal_(self.hidden2label.weight)

    def init_hidden(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size

        h0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE)
        c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE)

        return h0, c0

    def loss(self, x, sent_lengths, pos, rel, y):
        mask = torch.ne(x, self.pad_index)
        emissions = self.lstm_forward(x, pos, rel, sent_lengths)
        return self.crflayer(emissions, y, mask=mask)

    def forward(self, x, poses, rels, sent_lengths):
        mask = torch.ne(x, self.pad_index)
        emissions = self.lstm_forward(x, poses, rels, sent_lengths)
        return self.crflayer.decode(emissions, mask=mask)

    def lstm_forward(self, sentence, poses, rels, sent_lengths):
        word = self.word_embedding(sentence.to(DEVICE)).to(DEVICE)
        pos = self.pos_embedding(poses.to(DEVICE)).to(DEVICE)
        rels = rels.view(rels.size(0), rels.size(1), 1).float().to(DEVICE)
        x = torch.cat((word, pos, rels), dim=2)
        x = pack_padded_sequence(x, sent_lengths)
        self.hidden = self.init_hidden(batch_size=len(sent_lengths))
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        lstm_out, new_batch_size = pad_packed_sequence(lstm_out)
        assert torch.equal(sent_lengths, new_batch_size.to(DEVICE))
        y = self.hidden2label(lstm_out.to(DEVICE))
        return y.to(DEVICE)
```
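The LSTM input width `embedding_dim + pos_dim + 1` comes from concatenating three per-token features: the word embedding, the POS embedding, and a one-dimensional predicate indicator. A quick standalone shape check with the default dimensions (batch and sequence sizes are illustrative):

```python
import torch

seq_len, batch = 16, 2
word = torch.zeros(seq_len, batch, 300)  # word embeddings (embedding_dim=300)
pos = torch.zeros(seq_len, batch, 50)    # POS embeddings (pos_dim=50)
rel = torch.zeros(seq_len, batch, 1)     # 1.0 where the token is the predicate
x = torch.cat((word, pos, rel), dim=2)
assert x.shape == (seq_len, batch, 351)  # embedding_dim + pos_dim + 1
```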
lightkg/ede/srl/module.py (new file): 94 lines

@@ -0,0 +1,94 @@

```python
import torch
from tqdm import tqdm

from ...utils.learning import adjust_learning_rate
from ...utils.log import logger
from ...base.module import Module

from .config import DEVICE, DEFAULT_CONFIG
from .model import Config, BiLstmCrf
from .tool import srl_tool
from .utils.convert import iobes_ranges

seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


class SRL(Module):
    """
    Semantic role labeling module: trains, evaluates, and runs prediction
    with a BiLSTM-CRF tagger.
    """
    def __init__(self):
        self._model = None
        self._word_vocab = None
        self._tag_vocab = None
        self._pos_vocab = None

    def train(self, train_path, save_path=DEFAULT_CONFIG['save_path'], dev_path=None, vectors_path=None, **kwargs):
        train_dataset = srl_tool.get_dataset(train_path)
        if dev_path:
            dev_dataset = srl_tool.get_dataset(dev_path)
            word_vocab, pos_vocab, tag_vocab = srl_tool.get_vocab(train_dataset, dev_dataset)
        else:
            word_vocab, pos_vocab, tag_vocab = srl_tool.get_vocab(train_dataset)
        self._word_vocab = word_vocab
        self._pos_vocab = pos_vocab
        self._tag_vocab = tag_vocab
        train_iter = srl_tool.get_iterator(train_dataset, batch_size=DEFAULT_CONFIG['batch_size'])
        config = Config(word_vocab, pos_vocab, tag_vocab, save_path=save_path, vector_path=vectors_path, **kwargs)
        bilstmcrf = BiLstmCrf(config)
        self._model = bilstmcrf
        optim = torch.optim.Adam(bilstmcrf.parameters(), lr=config.lr)
        for epoch in range(config.epoch):
            bilstmcrf.train()
            acc_loss = 0
            for item in tqdm(train_iter):
                bilstmcrf.zero_grad()
                item_text_sentences = item.text[0]
                item_text_lengths = item.text[1]
                # The CRF layer returns a log-likelihood; negate it for minimization
                # and normalize by batch size.
                item_loss = (-bilstmcrf.loss(item_text_sentences, item_text_lengths, item.pos, item.rel, item.tag)) / item.tag.size(1)
                acc_loss += item_loss.view(-1).cpu().data.tolist()[0]
                item_loss.backward()
                optim.step()
            logger.info('epoch: {}, acc_loss: {}'.format(epoch, acc_loss))
            if dev_path:
                dev_score = self._validate(dev_dataset)
                logger.info('dev score: {}'.format(dev_score))

            adjust_learning_rate(optim, config.lr / (1 + (epoch + 1) * config.lr_decay))
        config.save()
        bilstmcrf.save()

    def predict(self, word_list, pos_list, rel_list):
        self._model.eval()
        assert len(word_list) == len(pos_list) == len(rel_list)
        vec_text = torch.tensor([self._word_vocab.stoi[x] for x in word_list]).view(-1, 1).to(DEVICE)
        len_text = torch.tensor([len(vec_text)]).to(DEVICE)
        vec_pos = torch.tensor([self._pos_vocab.stoi[x] for x in pos_list]).view(-1, 1).to(DEVICE)
        vec_rel = torch.tensor([int(x) for x in rel_list]).view(-1, 1).to(DEVICE)
        vec_predict = self._model(vec_text, vec_pos, vec_rel, len_text)[0]
        tag_predict = [self._tag_vocab.itos[i] for i in vec_predict]
        return iobes_ranges([x for x in word_list], tag_predict)

    def load(self, save_path=DEFAULT_CONFIG['save_path']):
        config = Config.load(save_path)
        bilstmcrf = BiLstmCrf(config)
        bilstmcrf.load()
        self._model = bilstmcrf
        self._word_vocab = config.word_vocab
        self._tag_vocab = config.tag_vocab
        self._pos_vocab = config.pos_vocab

    def test(self, test_path):
        test_dataset = srl_tool.get_dataset(test_path)
        test_score = self._validate(test_dataset)
        logger.info('test score: {}'.format(test_score))

    def _validate(self, dev_dataset):
        self._model.eval()
        dev_score_list = []
        for dev_item in tqdm(dev_dataset):
            item_score = srl_tool.get_score(self._model, dev_item.text, dev_item.tag, dev_item.pos, dev_item.rel,
                                            self._word_vocab, self._tag_vocab, self._pos_vocab)
            dev_score_list.append(item_score)
        return sum(dev_score_list) / len(dev_score_list)
```
lightkg/ede/srl/tool.py (new file): 80 lines

@@ -0,0 +1,80 @@

```python
import torch
from torchtext.data import Dataset, Field, BucketIterator
from torchtext.vocab import Vectors
from torchtext.datasets import SequenceTaggingDataset
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score

from ...base.tool import Tool
from ...utils.log import logger
from .config import DEVICE, DEFAULT_CONFIG

seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def light_tokenize(sequence: str):
    return [sequence]


def post_process(arr, _):
    return [[int(item) for item in arr_item] for arr_item in arr]


TEXT = Field(sequential=True, tokenize=light_tokenize, include_lengths=True)
POS = Field(sequential=True, tokenize=light_tokenize)
REL = Field(sequential=True, use_vocab=False, unk_token=None, pad_token=0, postprocessing=post_process)
TAG = Field(sequential=True, tokenize=light_tokenize, is_target=True, unk_token=None)
Fields = [('text', TEXT), ('pos', POS), ('rel', REL), ('tag', TAG)]


class SRLTool(Tool):
    def get_dataset(self, path: str, fields=Fields, separator='\t'):
        logger.info('loading dataset from {}'.format(path))
        st_dataset = SequenceTaggingDataset(path, fields=fields, separator=separator)
        logger.info('successfully loaded dataset')
        return st_dataset

    def get_vocab(self, *dataset):
        logger.info('building word vocab...')
        TEXT.build_vocab(*dataset)
        logger.info('successfully built word vocab')
        logger.info('building pos vocab...')
        POS.build_vocab(*dataset)
        logger.info('successfully built pos vocab')
        logger.info('building tag vocab...')
        TAG.build_vocab(*dataset)
        logger.info('successfully built tag vocab')
        return TEXT.vocab, POS.vocab, TAG.vocab

    def get_vectors(self, path: str):
        logger.info('loading vectors from {}'.format(path))
        vectors = Vectors(path)
        logger.info('successfully loaded vectors')
        return vectors

    def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text), sort_within_batch=True):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                              sort_within_batch=sort_within_batch)

    def get_score(self, model, x, y, pos, rel, field_x, field_y, field_pos, score_type='f1'):
        metrics_map = {
            'f1': f1_score,
            'p': precision_score,
            'r': recall_score,
            'acc': accuracy_score
        }
        metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1']
        vec_x = torch.tensor([field_x.stoi[i] for i in x])
        len_vec_x = torch.tensor([len(vec_x)]).to(DEVICE)
        vec_pos = torch.tensor([field_pos.stoi[i] for i in pos])
        vec_rel = torch.tensor([int(r) for r in rel])
        predict_y = model(vec_x.view(-1, 1).to(DEVICE), vec_pos.view(-1, 1).to(DEVICE), vec_rel.view(-1, 1).to(DEVICE),
                          len_vec_x)[0]
        true_y = [field_y.stoi[i] for i in y]
        assert len(true_y) == len(predict_y)
        return metric_func(predict_y, true_y, average='micro')


srl_tool = SRLTool()
```
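`get_score` scores one sentence at a time at the token level. With `average='micro'`, multiclass F1 reduces to plain token accuracy and is symmetric in its two arguments, so the `(predict, true)` argument order above yields the same number as sklearn's usual `(y_true, y_pred)`. A quick check with toy tag indices:

```python
from sklearn.metrics import accuracy_score, f1_score

pred = [1, 2, 2, 0]
true = [1, 2, 0, 0]
print(f1_score(true, pred, average='micro'))  # 0.75
print(accuracy_score(true, pred))             # 0.75, same value
```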
lightkg/ede/srl/utils/convert.py (new file): 51 lines

@@ -0,0 +1,51 @@

```python
def iobes_iob(tags):
    """
    IOBES -> IOB
    """
    new_tags = []
    for i, tag in enumerate(tags):
        if tag == 'rel':
            new_tags.append(tag)
        elif tag.split('-')[0] == 'B':
            new_tags.append(tag)
        elif tag.split('-')[0] == 'I':
            new_tags.append(tag)
        elif tag.split('-')[0] == 'S':
            new_tags.append(tag.replace('S-', 'B-'))
        elif tag.split('-')[0] == 'E':
            new_tags.append(tag.replace('E-', 'I-'))
        elif tag.split('-')[0] == 'O':
            new_tags.append(tag)
        else:
            raise Exception('Invalid format!')
    return new_tags


def iob_ranges(words, tags):
    """
    IOB -> Ranges
    """
    assert len(words) == len(tags)
    events = {}

    def check_if_closing_range():
        # Close the span at the end of the sequence or when the next tag
        # does not continue it (O, rel, or a new B- span).
        if i == len(tags) - 1 or tags[i + 1].split('-')[0] != 'I':
            events[temp_type] = ''.join(words[begin: i + 1])

    for i, tag in enumerate(tags):
        if tag == 'rel':
            events['rel'] = words[i]
        elif tag.split('-')[0] == 'O':
            pass
        elif tag.split('-')[0] == 'B':
            begin = i
            temp_type = tag.split('-')[1]
            check_if_closing_range()
        elif tag.split('-')[0] == 'I':
            check_if_closing_range()
    return events


def iobes_ranges(words, tags):
    new_tags = iobes_iob(tags)
    return iob_ranges(words, new_tags)
```
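For intuition, a small worked example of the decoding path (the tag sequence is hypothetical):

```python
# iobes_ranges first rewrites IOBES tags to IOB, then collapses spans.
words = ['金', '日', '成', '逝', '世']
tags = ['B-ARG1', 'I-ARG1', 'E-ARG1', 'rel', 'O']
print(iobes_ranges(words, tags))
# {'ARG1': '金日成', 'rel': '逝'}
```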
lightkg/ere/__init__.py (new file): 2 lines

@@ -0,0 +1,2 @@

```python
from .re.module import RE

__all__ = ['RE']
```
lightkg/ere/re/__init__.py (new, empty file)
lightkg/ere/re/config.py (new file): 20 lines

@@ -0,0 +1,20 @@

```python
from ...base.config import DEVICE  # re-exported here for model.py and tool.py

DEFAULT_CONFIG = {
    'lr': 0.002,
    'epoch': 5,
    'lr_decay': 0.05,
    'batch_size': 128,
    'dropout': 0.5,
    'static': False,
    'non_static': False,
    'embedding_dim': 300,
    'vector_path': '',
    'class_num': 0,
    'vocabulary_size': 0,
    'word_vocab': None,
    'tag_vocab': None,
    'save_path': './saves',
    'filter_num': 100,
    'filter_sizes': (3, 4, 5),
    'multichannel': False
}
```
lightkg/ere/re/model.py (new file): 110 lines

@@ -0,0 +1,110 @@

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.vocab import Vectors

from ...utils.log import logger
from ...base.model import BaseConfig, BaseModel

from .config import DEVICE, DEFAULT_CONFIG


class Config(BaseConfig):
    def __init__(self, word_vocab, label_vocab, vector_path, **kwargs):
        super(Config, self).__init__()
        for name, value in DEFAULT_CONFIG.items():
            setattr(self, name, value)
        self.word_vocab = word_vocab
        self.label_vocab = label_vocab
        self.class_num = len(self.label_vocab)
        self.vocabulary_size = len(self.word_vocab)
        self.vector_path = vector_path
        for name, value in kwargs.items():
            setattr(self, name, value)


class TextCNN(BaseModel):
    def __init__(self, args):
        super(TextCNN, self).__init__(args)

        self.class_num = args.class_num
        self.channel_num = 1
        self.filter_num = args.filter_num
        self.filter_sizes = args.filter_sizes

        self.vocabulary_size = args.vocabulary_size
        self.embedding_dimension = args.embedding_dim
        self.embedding = nn.Embedding(self.vocabulary_size, self.embedding_dimension).to(DEVICE)
        if args.static:
            logger.info('loading word vectors from {}'.format(args.vector_path))
            vectors = Vectors(args.vector_path).vectors
            self.embedding = self.embedding.from_pretrained(vectors, freeze=not args.non_static).to(DEVICE)
        if args.multichannel:
            # the second channel is initialized from the same pretrained vectors
            vectors = Vectors(args.vector_path).vectors
            self.embedding2 = nn.Embedding(self.vocabulary_size, self.embedding_dimension).from_pretrained(
                vectors).to(DEVICE)
            self.channel_num += 1
        else:
            self.embedding2 = None
        self.convs = nn.ModuleList(
            [nn.Conv2d(self.channel_num, self.filter_num, (size, self.embedding_dimension)) for size in
             self.filter_sizes]).to(DEVICE)
        self.dropout = nn.Dropout(args.dropout).to(DEVICE)
        self.fc = nn.Linear(len(self.filter_sizes) * self.filter_num, self.class_num).to(DEVICE)

    def forward(self, x):
        if self.embedding2:
            # two channels: (batch, 2, seq_len, embedding_dim)
            x = torch.stack((self.embedding(x), self.embedding2(x)), dim=1).to(DEVICE)
        else:
            # single channel: (batch, 1, seq_len, embedding_dim)
            x = self.embedding(x).to(DEVICE)
            x = x.unsqueeze(1)
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs]
        x = [F.max_pool1d(item, item.size(2)).squeeze(2) for item in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits


class LSTMClassifier(BaseModel):
    def __init__(self, args):
        super(LSTMClassifier, self).__init__(args)

        self.hidden_dim = 300
        self.class_num = args.class_num
        self.batch_size = args.batch_size

        self.vocabulary_size = args.vocabulary_size
        self.embedding_dimension = args.embedding_dim

        self.embedding = nn.Embedding(self.vocabulary_size, self.embedding_dimension).to(DEVICE)
        if args.static:
            vectors = Vectors(args.vector_path).vectors
            self.embedding = self.embedding.from_pretrained(vectors, freeze=not args.non_static).to(DEVICE)
        if args.multichannel:
            vectors = Vectors(args.vector_path).vectors
            self.embedding2 = nn.Embedding(self.vocabulary_size, self.embedding_dimension).from_pretrained(
                vectors).to(DEVICE)
        else:
            self.embedding2 = None

        self.lstm = nn.LSTM(self.embedding_dimension, self.hidden_dim).to(DEVICE)
        self.hidden2label = nn.Linear(self.hidden_dim, self.class_num).to(DEVICE)
        self.hidden = self.init_hidden()

    def init_hidden(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size

        h0 = torch.zeros(1, batch_size, self.hidden_dim).to(DEVICE)
        c0 = torch.zeros(1, batch_size, self.hidden_dim).to(DEVICE)

        return h0, c0

    def forward(self, sentence):
        embeds = self.embedding(sentence).to(DEVICE)

        x = embeds.permute(1, 0, 2).to(DEVICE)
        self.hidden = self.init_hidden(sentence.size()[0])
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        lstm_out = lstm_out.to(DEVICE)
        final = lstm_out[-1].to(DEVICE)
        y = self.hidden2label(final)
        return y
```
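The TextCNN geometry is easiest to see on one convolution branch. A standalone shape walkthrough with the default `filter_num=100` and a window of 3 (batch and sequence sizes are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.zeros(4, 1, 20, 300)             # (batch, channels, seq_len, embedding_dim)
conv = nn.Conv2d(1, 100, (3, 300))         # one branch: filter_num=100, window=3
h = F.relu(conv(x)).squeeze(3)             # (4, 100, 18); 18 = 20 - 3 + 1 positions
p = F.max_pool1d(h, h.size(2)).squeeze(2)  # (4, 100) after max-over-time pooling
assert p.shape == (4, 100)
```

With `filter_sizes=(3, 4, 5)` the three pooled vectors are concatenated into 300 features before the final linear layer.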
lightkg/ere/re/module.py (new file): 101 lines

@@ -0,0 +1,101 @@

```python
import torch
import torch.nn.functional as F
from tqdm import tqdm

from ...utils.learning import adjust_learning_rate
from ...utils.log import logger
from ...base.module import Module

from .model import Config, TextCNN
from .config import DEVICE, DEFAULT_CONFIG
from .tool import re_tool, TEXT, LABEL
from .utils.preprocess import handle_line

seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


class RE(Module):
    """
    Relation extraction module: trains a TextCNN classifier over
    entity-marked, word-segmented sentences.
    """

    def __init__(self):
        self._model = None
        self._word_vocab = None
        self._label_vocab = None

    def train(self, train_path, save_path=DEFAULT_CONFIG['save_path'], dev_path=None, vectors_path=None, **kwargs):
        train_dataset = re_tool.get_dataset(train_path)
        if dev_path:
            dev_dataset = re_tool.get_dataset(dev_path)
            word_vocab, label_vocab = re_tool.get_vocab(train_dataset, dev_dataset)
        else:
            word_vocab, label_vocab = re_tool.get_vocab(train_dataset)
        self._word_vocab = word_vocab
        self._label_vocab = label_vocab
        train_iter = re_tool.get_iterator(train_dataset, batch_size=DEFAULT_CONFIG['batch_size'])
        config = Config(word_vocab, label_vocab, save_path=save_path, vector_path=vectors_path, **kwargs)
        textcnn = TextCNN(config)
        self._model = textcnn
        optim = torch.optim.Adam(textcnn.parameters(), lr=config.lr)
        for epoch in range(config.epoch):
            textcnn.train()
            acc_loss = 0
            for item in tqdm(train_iter):
                optim.zero_grad()
                logits = self._model(item.text)
                item_loss = F.cross_entropy(logits, item.label)
                acc_loss += item_loss.item()
                item_loss.backward()
                optim.step()
            logger.info('epoch: {}, acc_loss: {}'.format(epoch, acc_loss))
            if dev_path:
                dev_score = self._validate(dev_dataset)
                logger.info('dev score: {}'.format(dev_score))
            adjust_learning_rate(optim, config.lr / (1 + (epoch + 1) * config.lr_decay))
        config.save()
        textcnn.save()

    def predict(self, entity1: str, entity2: str, sentence: str):
        self._model.eval()
        text = handle_line(entity1, entity2, sentence)
        vec_text = torch.tensor([self._word_vocab.stoi[x] for x in text])
        vec_text = vec_text.reshape(1, -1).to(DEVICE)
        vec_predict = self._model(vec_text)[0]
        soft_predict = torch.softmax(vec_predict, dim=0)
        predict_prob, predict_index = torch.max(soft_predict.cpu().data, dim=0)
        predict_class = self._label_vocab.itos[predict_index]
        predict_prob = predict_prob.item()
        return predict_prob, predict_class

    def load(self, save_path=DEFAULT_CONFIG['save_path']):
        config = Config.load(save_path)
        textcnn = TextCNN(config)
        textcnn.load()
        self._model = textcnn
        self._word_vocab = config.word_vocab
        self._label_vocab = config.label_vocab
        self._check_vocab()

    def test(self, test_path):
        self._model.eval()
        test_dataset = re_tool.get_dataset(test_path)
        test_score = self._validate(test_dataset)
        logger.info('test score: {}'.format(test_score))

    def _validate(self, dev_dataset, batch_size=DEFAULT_CONFIG['batch_size']):
        self._model.eval()
        dev_score_list = []
        dev_iter = re_tool.get_iterator(dev_dataset, batch_size=batch_size)
        for dev_item in tqdm(dev_iter):
            item_score = re_tool.get_score(self._model, dev_item.text, dev_item.label)
            dev_score_list.append(item_score)
        return sum(dev_score_list) / len(dev_score_list)

    def _check_vocab(self):
        if not hasattr(TEXT, 'vocab'):
            TEXT.vocab = self._word_vocab
        if not hasattr(LABEL, 'vocab'):
            LABEL.vocab = self._label_vocab
```
lightkg/ere/re/tool.py (new file): 75 lines

@@ -0,0 +1,75 @@

```python
import torch
from torchtext.data import Field, Iterator
from torchtext.vocab import Vectors
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score

from ...base.tool import Tool
from ...utils.log import logger
from .config import DEVICE, DEFAULT_CONFIG

from .utils.dataset import REDataset

seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def light_tokenize(text):
    return text


TEXT = Field(lower=True, tokenize=light_tokenize, batch_first=True)
LABEL = Field(sequential=False, unk_token=None)
Fields = [
    ('text', TEXT),
    ('label', LABEL)
]


class RETool(Tool):
    def get_dataset(self, path: str, fields=Fields):
        logger.info('loading dataset from {}'.format(path))
        re_dataset = REDataset(path, fields)
        logger.info('successfully loaded dataset')
        return re_dataset

    def get_vocab(self, *dataset):
        logger.info('building word vocab...')
        TEXT.build_vocab(*dataset)
        logger.info('successfully built word vocab')
        logger.info('building label vocab...')
        LABEL.build_vocab(*dataset)
        logger.info('successfully built label vocab')
        return TEXT.vocab, LABEL.vocab

    def get_vectors(self, path: str):
        logger.info('loading vectors from {}'.format(path))
        vectors = Vectors(path)
        logger.info('successfully loaded vectors')
        return vectors

    def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text)):
        return Iterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key)

    def get_score(self, model, texts, labels, score_type='f1'):
        metrics_map = {
            'f1': f1_score,
            'p': precision_score,
            'r': recall_score,
            'acc': accuracy_score
        }
        metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1']
        assert texts.size(0) == len(labels)
        vec_predict = model(texts)
        soft_predict = torch.softmax(vec_predict, dim=1)
        predict_prob, predict_index = torch.max(soft_predict.cpu().data, dim=1)
        labels = labels.view(-1).cpu().data.numpy()
        return metric_func(predict_index, labels, average='micro')


re_tool = RETool()
```
lightkg/ere/re/utils/dataset.py (new file): 22 lines

@@ -0,0 +1,22 @@

```python
from torchtext.data import Dataset, Example
from .preprocess import handle_line


class REDataset(Dataset):
    """Defines a Dataset in the relation extraction format, e.g.:

    钱钟书 辛笛 同门 与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。
    元武 元华 unknown 于师傅在一次京剧表演中,选了元龙(洪金宝)、元楼(元奎)、元彪、成龙、元华、元武、元泰7人担任七小福的主角。

    Each line holds entity 1, entity 2, the relation label, and the
    sentence, separated by whitespace.
    """

    def __init__(self, path, fields, encoding="utf-8", **kwargs):
        examples = []
        with open(path, "r", encoding=encoding) as f:
            for line in f:
                chunks = line.split()
                entity_1, entity_2, relation, sentence = tuple(chunks)
                sentence_list = handle_line(entity_1, entity_2, sentence)

                examples.append(Example.fromlist((sentence_list, relation), fields))
        super(REDataset, self).__init__(examples, fields, **kwargs)
```
lightkg/ere/re/utils/preprocess.py (new file): 24 lines

@@ -0,0 +1,24 @@

```python
import jieba


def handle_line(entity1, entity2, sentence, begin_e1_token='<e1>', end_e1_token='</e1>', begin_e2_token='<e2>',
                end_e2_token='</e2>'):
    assert entity1 in sentence
    assert entity2 in sentence
    # Wrap both entities in marker tokens before segmentation.
    sentence = sentence.replace(entity1, begin_e1_token + entity1 + end_e1_token)
    sentence = sentence.replace(entity2, begin_e2_token + entity2 + end_e2_token)
    sentence = ' '.join(jieba.cut(sentence))
    # jieba splits the markers apart; glue them back together.
    sentence = sentence.replace('< e1 >', begin_e1_token)
    sentence = sentence.replace('< / e1 >', end_e1_token)
    sentence = sentence.replace('< e2 >', begin_e2_token)
    sentence = sentence.replace('< / e2 >', end_e2_token)
    return sentence.split()


if __name__ == '__main__':
    test_str = '曾经沧海难为水哈哈谁说不是呢?!呵呵 低头不见抬头见'
    e1 = '沧海'
    e2 = '不是'
    print(list(jieba.cut(test_str)))
    print(handle_line(e1, e2, test_str))
```
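The net effect of `handle_line` is a word-segmented sentence in which the entity markers survive as standalone tokens. A hypothetical illustration (the exact word boundaries depend on jieba's dictionary):

```python
tokens = handle_line('钱钟书', '辛笛', '钱钟书与辛笛是清华校友')
# e.g. ['<e1>', '钱钟书', '</e1>', '与', '<e2>', '辛笛', '</e2>', '是', '清华', '校友']
```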
lightkg/erl/__init__.py (new file): 2 lines

@@ -0,0 +1,2 @@

```python
from .ner.module import NER

__all__ = ['NER']
```
lightkg/erl/ner/__init__.py (new, empty file)
lightkg/erl/ner/config.py (new file): 19 lines

@@ -0,0 +1,19 @@

```python
from ...base.config import DEVICE  # re-exported here for model.py and tool.py

DEFAULT_CONFIG = {
    'lr': 0.02,
    'epoch': 30,
    'lr_decay': 0.05,
    'batch_size': 128,
    'dropout': 0.5,
    'static': False,
    'non_static': False,
    'embedding_dim': 300,
    'num_layers': 2,
    'pad_index': 1,
    'vector_path': '',
    'tag_num': 0,
    'vocabulary_size': 0,
    'word_vocab': None,
    'tag_vocab': None,
    'save_path': './saves'
}
```
lightkg/erl/ner/model.py (new file): 89 lines

@@ -0,0 +1,89 @@

```python
import torch
import torch.nn as nn
from torchcrf import CRF
from torchtext.vocab import Vectors
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from ...utils.log import logger
from .config import DEVICE, DEFAULT_CONFIG
from ...base.model import BaseConfig, BaseModel


class Config(BaseConfig):
    def __init__(self, word_vocab, tag_vocab, vector_path, **kwargs):
        super(Config, self).__init__()
        for name, value in DEFAULT_CONFIG.items():
            setattr(self, name, value)
        self.word_vocab = word_vocab
        self.tag_vocab = tag_vocab
        self.tag_num = len(self.tag_vocab)
        self.vocabulary_size = len(self.word_vocab)
        self.vector_path = vector_path
        for name, value in kwargs.items():
            setattr(self, name, value)


class BiLstmCrf(BaseModel):
    def __init__(self, args):
        super(BiLstmCrf, self).__init__(args)
        self.args = args
        self.hidden_dim = 300
        self.tag_num = args.tag_num
        self.batch_size = args.batch_size
        self.bidirectional = True
        self.num_layers = args.num_layers
        self.pad_index = args.pad_index
        self.dropout = args.dropout
        self.save_path = args.save_path

        vocabulary_size = args.vocabulary_size
        embedding_dimension = args.embedding_dim

        self.embedding = nn.Embedding(vocabulary_size, embedding_dimension).to(DEVICE)
        if args.static:
            logger.info('loading word vectors from {}'.format(args.vector_path))
            vectors = Vectors(args.vector_path).vectors
            self.embedding = self.embedding.from_pretrained(vectors, freeze=not args.non_static).to(DEVICE)

        self.lstm = nn.LSTM(embedding_dimension, self.hidden_dim // 2, bidirectional=self.bidirectional,
                            num_layers=self.num_layers, dropout=self.dropout).to(DEVICE)
        self.hidden2label = nn.Linear(self.hidden_dim, self.tag_num).to(DEVICE)
        self.crflayer = CRF(self.tag_num).to(DEVICE)

        # self.init_weight()

    def init_weight(self):
        nn.init.xavier_normal_(self.embedding.weight)
        for name, param in self.lstm.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
        nn.init.xavier_normal_(self.hidden2label.weight)

    def init_hidden(self, batch_size=None):
        if batch_size is None:
            batch_size = self.batch_size

        h0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE)
        c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE)

        return h0, c0

    def loss(self, x, sent_lengths, y):
        mask = torch.ne(x, self.pad_index)
        emissions = self.lstm_forward(x, sent_lengths)
        return self.crflayer(emissions, y, mask=mask)

    def forward(self, x, sent_lengths):
        mask = torch.ne(x, self.pad_index)
        emissions = self.lstm_forward(x, sent_lengths)
        return self.crflayer.decode(emissions, mask=mask)

    def lstm_forward(self, sentence, sent_lengths):
        x = self.embedding(sentence.to(DEVICE)).to(DEVICE)
        x = pack_padded_sequence(x, sent_lengths)
        self.hidden = self.init_hidden(batch_size=len(sent_lengths))
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        lstm_out, new_batch_size = pad_packed_sequence(lstm_out)
        assert torch.equal(sent_lengths, new_batch_size.to(DEVICE))
        y = self.hidden2label(lstm_out.to(DEVICE))
        return y.to(DEVICE)
```
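For reference, a minimal sketch of the `torchcrf` contract the model relies on; the mask plays the same role as the `torch.ne(x, pad_index)` call above, and all sizes here are illustrative:

```python
import torch
from torchcrf import CRF  # pytorch-crf, assumed installed as in model.py

crf = CRF(5)                                      # 5 tags
emissions = torch.randn(7, 2, 5)                  # (seq_len, batch, tag_num)
tags = torch.randint(0, 5, (7, 2))                # gold tag indices
mask = torch.ones(7, 2, dtype=torch.uint8)        # 1 = real token, 0 = padding
log_likelihood = crf(emissions, tags, mask=mask)  # scalar; negate to get a loss
best_paths = crf.decode(emissions, mask=mask)     # list of best tag sequences
```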
lightkg/erl/ner/module.py (new file): 87 lines

@@ -0,0 +1,87 @@

```python
import torch
from tqdm import tqdm

from ...utils.learning import adjust_learning_rate
from ...utils.log import logger
from ...base.module import Module

from .config import DEVICE, DEFAULT_CONFIG
from .model import Config, BiLstmCrf
from .tool import ner_tool
from .utils.convert import iob_ranges

seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


class NER(Module):
    """
    Named entity recognition module: trains, evaluates, and runs prediction
    with a character-level BiLSTM-CRF tagger.
    """
    def __init__(self):
        self._model = None
        self._word_vocab = None
        self._tag_vocab = None

    def train(self, train_path, save_path=DEFAULT_CONFIG['save_path'], dev_path=None, vectors_path=None, **kwargs):
        train_dataset = ner_tool.get_dataset(train_path)
        if dev_path:
            dev_dataset = ner_tool.get_dataset(dev_path)
            word_vocab, tag_vocab = ner_tool.get_vocab(train_dataset, dev_dataset)
        else:
            word_vocab, tag_vocab = ner_tool.get_vocab(train_dataset)
        self._word_vocab = word_vocab
        self._tag_vocab = tag_vocab
        train_iter = ner_tool.get_iterator(train_dataset, batch_size=DEFAULT_CONFIG['batch_size'])
        config = Config(word_vocab, tag_vocab, save_path=save_path, vector_path=vectors_path, **kwargs)
        bilstmcrf = BiLstmCrf(config)
        self._model = bilstmcrf
        optim = torch.optim.Adam(bilstmcrf.parameters(), lr=config.lr)
        for epoch in range(config.epoch):
            bilstmcrf.train()
            acc_loss = 0
            for item in tqdm(train_iter):
                bilstmcrf.zero_grad()
                item_text_sentences = item.text[0]
                item_text_lengths = item.text[1]
                # The CRF layer returns a log-likelihood; negate it for minimization
                # and normalize by batch size.
                item_loss = (-bilstmcrf.loss(item_text_sentences, item_text_lengths, item.tag)) / item.tag.size(1)
                acc_loss += item_loss.view(-1).cpu().data.tolist()[0]
                item_loss.backward()
                optim.step()
            logger.info('epoch: {}, acc_loss: {}'.format(epoch, acc_loss))
            if dev_path:
                dev_score = self._validate(dev_dataset)
                logger.info('dev score: {}'.format(dev_score))

            adjust_learning_rate(optim, config.lr / (1 + (epoch + 1) * config.lr_decay))
        config.save()
        bilstmcrf.save()

    def predict(self, text):
        self._model.eval()
        vec_text = torch.tensor([self._word_vocab.stoi[x] for x in text])
        len_text = torch.tensor([len(vec_text)]).to(DEVICE)
        vec_predict = self._model(vec_text.view(-1, 1).to(DEVICE), len_text)[0]
        tag_predict = [self._tag_vocab.itos[i] for i in vec_predict]
        return iob_ranges([x for x in text], tag_predict)

    def load(self, save_path=DEFAULT_CONFIG['save_path']):
        config = Config.load(save_path)
        bilstmcrf = BiLstmCrf(config)
        bilstmcrf.load()
        self._model = bilstmcrf
        self._word_vocab = config.word_vocab
        self._tag_vocab = config.tag_vocab

    def test(self, test_path):
        test_dataset = ner_tool.get_dataset(test_path)
        test_score = self._validate(test_dataset)
        logger.info('test score: {}'.format(test_score))

    def _validate(self, dev_dataset):
        self._model.eval()
        dev_score_list = []
        for dev_item in tqdm(dev_dataset):
            item_score = ner_tool.get_score(self._model, dev_item.text, dev_item.tag, self._word_vocab, self._tag_vocab)
            dev_score_list.append(item_score)
        return sum(dev_score_list) / len(dev_score_list)
```
lightkg/erl/ner/tool.py (new file): 68 lines

@@ -0,0 +1,68 @@

```python
import torch
from torchtext.data import Dataset, Field, BucketIterator, ReversibleField
from torchtext.vocab import Vectors
from torchtext.datasets import SequenceTaggingDataset
from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score

from ...base.tool import Tool
from ...utils.log import logger
from .config import DEVICE, DEFAULT_CONFIG

seed = 2019
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def light_tokenize(sequence: str):
    return [sequence]


TEXT = Field(sequential=True, tokenize=light_tokenize, include_lengths=True)
TAG = ReversibleField(sequential=True, tokenize=light_tokenize, is_target=True, unk_token=None)
Fields = [('text', TEXT), ('tag', TAG)]


class NERTool(Tool):
    def get_dataset(self, path: str, fields=Fields, separator=' '):
        logger.info('loading dataset from {}'.format(path))
        st_dataset = SequenceTaggingDataset(path, fields=fields, separator=separator)
        logger.info('successfully loaded dataset')
        return st_dataset

    def get_vocab(self, *dataset):
        logger.info('building word vocab...')
        TEXT.build_vocab(*dataset)
        logger.info('successfully built word vocab')
        logger.info('building tag vocab...')
        TAG.build_vocab(*dataset)
        logger.info('successfully built tag vocab')
        return TEXT.vocab, TAG.vocab

    def get_vectors(self, path: str):
        logger.info('loading vectors from {}'.format(path))
        vectors = Vectors(path)
        logger.info('successfully loaded vectors')
        return vectors

    def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE,
                     sort_key=lambda x: len(x.text), sort_within_batch=True):
        return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key,
                              sort_within_batch=sort_within_batch)

    def get_score(self, model, x, y, field_x, field_y, score_type='f1'):
        metrics_map = {
            'f1': f1_score,
            'p': precision_score,
            'r': recall_score,
            'acc': accuracy_score
        }
        metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1']
        vec_x = torch.tensor([field_x.stoi[i] for i in x])
        len_vec_x = torch.tensor([len(vec_x)]).to(DEVICE)
        predict_y = model(vec_x.view(-1, 1).to(DEVICE), len_vec_x)[0]
        true_y = [field_y.stoi[i] for i in y]
        assert len(true_y) == len(predict_y)
        return metric_func(predict_y, true_y, average='micro')


ner_tool = NERTool()
```
lightkg/erl/ner/utils/convert.py (new file): 26 lines

@@ -0,0 +1,26 @@

```python
def iob_ranges(words, tags):
    """
    IOB -> Ranges
    """
    assert len(words) == len(tags)
    ranges = []

    def check_if_closing_range():
        # Close the span at the end of the sequence or when the next tag
        # does not continue it (O or a new B_ span).
        if i == len(tags) - 1 or tags[i + 1].split('_')[0] != 'I':
            ranges.append({
                'entity': ''.join(words[begin: i + 1]),
                'type': temp_type,
                'start': begin,
                'end': i
            })

    for i, tag in enumerate(tags):
        if tag.split('_')[0] == 'O':
            pass
        elif tag.split('_')[0] == 'B':
            begin = i
            temp_type = tag.split('_')[1]
            check_if_closing_range()
        elif tag.split('_')[0] == 'I':
            check_if_closing_range()
    return ranges
```
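A small worked example of the span extraction, using the `B_`/`I_` tag scheme this module emits (the tags are hypothetical):

```python
words = list('我们很酷')
tags = ['B_Person', 'I_Person', 'O', 'O']
print(iob_ranges(words, tags))
# [{'entity': '我们', 'type': 'Person', 'start': 0, 'end': 1}]
```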
lightkg/ksq/__init__.py (new, empty file)