From 4f3435fa183e15b0f04f79c568cb67a798f622dd Mon Sep 17 00:00:00 2001 From: lightsmile Date: Tue, 26 Mar 2019 23:32:46 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=AE=9E=E4=BD=93=E6=8A=BD?= =?UTF-8?q?=E5=8F=96=E3=80=81=E5=85=B3=E7=B3=BB=E6=8A=BD=E5=8F=96=E3=80=81?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6=E6=8A=BD=E5=8F=96=E6=A8=A1=E5=9E=8B=E5=8F=8A?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E9=A2=84=E6=B5=8B=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- README.md | 135 +++++++++++++++++++++++++++ {test => examples}/test_krl.py | 0 examples/test_ner.py | 15 +++ examples/test_re.py | 15 +++ examples/test_srl.py | 21 +++++ lightkg/ede/__init__.py | 2 + lightkg/{esq => ede/srl}/__init__.py | 0 lightkg/ede/srl/config.py | 22 +++++ lightkg/ede/srl/model.py | 97 +++++++++++++++++++ lightkg/ede/srl/module.py | 94 +++++++++++++++++++ lightkg/ede/srl/tool.py | 80 ++++++++++++++++ lightkg/ede/srl/utils/convert.py | 51 ++++++++++ lightkg/ere/__init__.py | 2 + lightkg/ere/re/__init__.py | 0 lightkg/ere/re/config.py | 20 ++++ lightkg/ere/re/model.py | 110 ++++++++++++++++++++++ lightkg/ere/re/module.py | 101 ++++++++++++++++++++ lightkg/ere/re/tool.py | 75 +++++++++++++++ lightkg/ere/re/utils/dataset.py | 22 +++++ lightkg/ere/re/utils/preprocess.py | 24 +++++ lightkg/erl/__init__.py | 2 + lightkg/erl/ner/__init__.py | 0 lightkg/erl/ner/config.py | 19 ++++ lightkg/erl/ner/model.py | 89 ++++++++++++++++++ lightkg/erl/ner/module.py | 87 +++++++++++++++++ lightkg/erl/ner/tool.py | 68 ++++++++++++++ lightkg/erl/ner/utils/convert.py | 26 ++++++ lightkg/ksq/__init__.py | 0 29 files changed, 1178 insertions(+), 1 deletion(-) rename {test => examples}/test_krl.py (100%) create mode 100644 examples/test_ner.py create mode 100644 examples/test_re.py create mode 100644 examples/test_srl.py rename lightkg/{esq => ede/srl}/__init__.py (100%) create mode 100644 lightkg/ede/srl/config.py create mode 100644 lightkg/ede/srl/model.py create mode 100644 lightkg/ede/srl/module.py create mode 100644 lightkg/ede/srl/tool.py create mode 100644 lightkg/ede/srl/utils/convert.py create mode 100644 lightkg/ere/re/__init__.py create mode 100644 lightkg/ere/re/config.py create mode 100644 lightkg/ere/re/model.py create mode 100644 lightkg/ere/re/module.py create mode 100644 lightkg/ere/re/tool.py create mode 100644 lightkg/ere/re/utils/dataset.py create mode 100644 lightkg/ere/re/utils/preprocess.py create mode 100644 lightkg/erl/ner/__init__.py create mode 100644 lightkg/erl/ner/config.py create mode 100644 lightkg/erl/ner/model.py create mode 100644 lightkg/erl/ner/module.py create mode 100644 lightkg/erl/ner/tool.py create mode 100644 lightkg/erl/ner/utils/convert.py create mode 100644 lightkg/ksq/__init__.py diff --git a/.gitignore b/.gitignore index 0843191..a545e5b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,4 @@ build dist .vscode lightKG.egg-info/ -test/*_saves \ No newline at end of file +examples/*_saves \ No newline at end of file diff --git a/README.md b/README.md index e729a9e..46537fc 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,16 @@ ### 实体识别与链接 +- 命名实体识别, ner + ### 实体关系抽取 +- 关系抽取, re + ### 事件检测与抽取 +- 语义角色标注, srl + ### 知识存储与查询 ### 知识推理 @@ -147,6 +153,124 @@ print(krl.predict_head(rel='外文名', tail='Compiler')) [('编译器', 0.998942494392395), ('译码器', 0.36795616149902344), ('计算机,单片机,编程语言', 0.36788302659988403)] ``` +### ner + +#### 训练 + +```python +from lightkg.erl import NER + +# 创建NER对象 +ner_model = NER() + +train_path = 
'/home/lightsmile/NLP/corpus/ner/train.sample.txt' +dev_path = '/home/lightsmile/NLP/corpus/ner/test.sample.txt' +vec_path = '/home/lightsmile/NLP/embedding/char/token_vec_300.bin' + +# 只需指定训练数据路径,预训练字向量可选,开发集路径可选,模型保存路径可选。 +ner_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./ner_saves') +``` + +#### 测试 + +```python +# 加载模型,默认当前目录下的`saves`目录 +ner_model.load('./ner_saves') +# 对train_path下的测试集进行读取测试 +ner_model.test(train_path) +``` + +#### 预测 + +```python +from pprint import pprint + +pprint(ner_model.predict('另一个很酷的事情是,通过框架我们可以停止并在稍后恢复训练。')) +``` + +预测结果: + +```bash +[{'end': 15, 'entity': '我们', 'start': 14, 'type': 'Person'}] +``` + +### re + +#### 训练 + +```python +from lightkg.ere import RE + +re = RE() + +train_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/train.sample.txt' +dev_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/test.sample.txt' +vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char' + +re.train(train_path, dev_path=dev_path, vectors_path=vec_path, save_path='./re_saves') + +``` + +#### 测试 + +```python +re.load('./re_saves') +re.test(dev_path) +``` + +#### 预测 + +```python +print(re.predict('钱钟书', '辛笛', '与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。')) +``` + +预测结果: + +```python +(0.7306928038597107, '同门') # return格式为(预测概率,预测标签) +``` + +### srl + +#### 训练 + +```python +from lightkg.ede import SRL + +srl_model = SRL() + +train_path = '/home/lightsmile/NLP/corpus/srl/train.sample.tsv' +dev_path = '/home/lightsmile/NLP/corpus/srl/test.sample.tsv' +vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char' + + +srl_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./srl_saves') +``` + +#### 测试 + +```python +srl_model.load('./srl_saves') + +srl_model.test(dev_path) +``` + +#### 预测 + +```python +word_list = ['代表', '朝方', '对', '中国', '党政', '领导人', '和', '人民', '哀悼', '金日成', '主席', '逝世', '表示', '深切', '谢意', '。'] +pos_list = ['VV', 'NN', 'P', 'NR', 'NN', 'NN', 'CC', 'NN', 'VV', 'NR', 'NN', 'VV', 'VV', 'JJ', 'NN', 'PU'] +rel_list = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0] + +print(srl_model.predict(word_list, pos_list, rel_list)) +``` + +预测结果: + +```bash +{'ARG0': '中国党政领导人和人民', 'rel': '哀悼', 'ARG1': '金日成主席逝世'} +``` + ## 项目组织结构 ### 项目架构 @@ -158,10 +282,18 @@ print(krl.predict_head(rel='外文名', tail='Compiler')) - common - entity.py - relation.py +- ede + - srl, 语义角色标注 +- ere + - re, 关系抽取 +- erl + - ner, 命名实体识别 +- kr - krl,知识表示学习 - models - transE - utils +- ksq - utils ### 架构说明 @@ -207,6 +339,9 @@ print(krl.predict_head(rel='外文名', tail='Compiler')) ### 功能 +- [x] 增加关系抽取相关模型以及训练预测代码 +- [x] 增加事件抽取相关模型以及训练预测代码 +- [x] 增加命名实体识别相关模型以及预测训练代码 - [x] 增加基于翻译模型的知识表示学习相关模型以及训练预测代码 ## 参考 diff --git a/test/test_krl.py b/examples/test_krl.py similarity index 100% rename from test/test_krl.py rename to examples/test_krl.py diff --git a/examples/test_ner.py b/examples/test_ner.py new file mode 100644 index 0000000..79f244c --- /dev/null +++ b/examples/test_ner.py @@ -0,0 +1,15 @@ +from lightkg.erl import NER + +ner_model = NER() + +train_path = '/home/lightsmile/NLP/corpus/ner/train.sample.txt' +dev_path = '/home/lightsmile/NLP/corpus/ner/test.sample.txt' +vec_path = '/home/lightsmile/NLP/embedding/char/token_vec_300.bin' + +# ner_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./ner_saves') + +ner_model.load('./ner_saves') +# ner_model.test(train_path) + +from pprint import pprint +pprint(ner_model.predict('另一个很酷的事情是,通过框架我们可以停止并在稍后恢复训练。')) diff --git 
a/examples/test_re.py b/examples/test_re.py new file mode 100644 index 0000000..8087f83 --- /dev/null +++ b/examples/test_re.py @@ -0,0 +1,15 @@ +from lightkg.ere import RE + +re = RE() + +train_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/train.sample.txt' +dev_path = '/home/lightsmile/Projects/NLP/ChineseNRE/data/people-relation/test.sample.txt' +vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char' + + +# re.train(train_path, dev_path=train_path, vectors_path=vec_path, save_path='./re_saves') + +re.load('./re_saves') +# re.test(train_path) +# 钱钟书 辛笛 同门 与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。 +print(re.predict('钱钟书', '辛笛', '与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。')) diff --git a/examples/test_srl.py b/examples/test_srl.py new file mode 100644 index 0000000..88a0fe2 --- /dev/null +++ b/examples/test_srl.py @@ -0,0 +1,21 @@ +from lightkg.ede import SRL + +srl_model = SRL() + +train_path = '/home/lightsmile/NLP/corpus/srl/train.sample.tsv' +dev_path = '/home/lightsmile/NLP/corpus/srl/test.sample.tsv' +vec_path = '/home/lightsmile/NLP/embedding/word/sgns.zhihu.bigram-char' + + +# srl_model.train(train_path, vectors_path=vec_path, dev_path=dev_path, save_path='./srl_saves') + +srl_model.load('./srl_saves') + +# srl_model.test(dev_path) + +word_list = ['代表', '朝方', '对', '中国', '党政', '领导人', '和', '人民', '哀悼', '金日成', '主席', '逝世', '表示', '深切', '谢意', '。'] +pos_list = ['VV', 'NN', 'P', 'NR', 'NN', 'NN', 'CC', 'NN', 'VV', 'NR', 'NN', 'VV', 'VV', 'JJ', 'NN', 'PU'] +rel_list = [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0] + +print(srl_model.predict(word_list, pos_list, rel_list)) + diff --git a/lightkg/ede/__init__.py b/lightkg/ede/__init__.py index e69de29..511c5fd 100644 --- a/lightkg/ede/__init__.py +++ b/lightkg/ede/__init__.py @@ -0,0 +1,2 @@ +from .srl.module import SRL +__all__ = ['SRL'] diff --git a/lightkg/esq/__init__.py b/lightkg/ede/srl/__init__.py similarity index 100% rename from lightkg/esq/__init__.py rename to lightkg/ede/srl/__init__.py diff --git a/lightkg/ede/srl/config.py b/lightkg/ede/srl/config.py new file mode 100644 index 0000000..3f40121 --- /dev/null +++ b/lightkg/ede/srl/config.py @@ -0,0 +1,22 @@ +from ...base.config import DEVICE +DEFAULT_CONFIG = { + 'lr': 0.02, + 'epoch': 30, + 'lr_decay': 0.05, + 'batch_size': 128, + 'dropout': 0.5, + 'static': False, + 'non_static': False, + 'embedding_dim': 300, + 'pos_dim': 50, + 'num_layers': 2, + 'pad_index': 1, + 'vector_path': '', + 'tag_num': 0, + 'vocabulary_size': 0, + 'pos_size': 0, + 'word_vocab': None, + 'pos_vocab': None, + 'tag_vocab': None, + 'save_path': './saves' +} \ No newline at end of file diff --git a/lightkg/ede/srl/model.py b/lightkg/ede/srl/model.py new file mode 100644 index 0000000..922aa6d --- /dev/null +++ b/lightkg/ede/srl/model.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from torchcrf import CRF +from torchtext.vocab import Vectors +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from ...utils.log import logger +from .config import DEVICE, DEFAULT_CONFIG +from ...base.model import BaseConfig, BaseModel + + +class Config(BaseConfig): + def __init__(self, word_vocab, pos_vocab, tag_vocab, vector_path, **kwargs): + super(Config, self).__init__() + for name, value in DEFAULT_CONFIG.items(): + setattr(self, name, value) + self.word_vocab = word_vocab + self.pos_vocab = pos_vocab + self.tag_vocab = tag_vocab + self.tag_num = len(self.tag_vocab) + self.vocabulary_size = len(self.word_vocab) + self.pos_size = len(self.pos_vocab) + self.vector_path = 
vector_path + for name, value in kwargs.items(): + setattr(self, name, value) + + +class BiLstmCrf(BaseModel): + def __init__(self, args): + super(BiLstmCrf, self).__init__(args) + self.args = args + self.hidden_dim = 300 + self.tag_num = args.tag_num + self.batch_size = args.batch_size + self.bidirectional = True + self.num_layers = args.num_layers + self.pad_index = args.pad_index + self.dropout = args.dropout + self.save_path = args.save_path + + vocabulary_size = args.vocabulary_size + embedding_dimension = args.embedding_dim + pos_size = args.pos_size + pos_dim = args.pos_dim + + self.word_embedding = nn.Embedding(vocabulary_size, embedding_dimension).to(DEVICE) + if args.static: + logger.info('logging word vectors from {}'.format(args.vector_path)) + vectors = Vectors(args.vector_path).vectors + self.word_embedding = nn.Embedding.from_pretrained(vectors, freeze=not args.non_static).to(DEVICE) + self.pos_embedding = nn.Embedding(pos_size, pos_dim).to(DEVICE) + + self.lstm = nn.LSTM(embedding_dimension + pos_dim + 1, self.hidden_dim // 2, bidirectional=self.bidirectional, + num_layers=self.num_layers, dropout=self.dropout).to(DEVICE) + self.hidden2label = nn.Linear(self.hidden_dim, self.tag_num).to(DEVICE) + self.crflayer = CRF(self.tag_num).to(DEVICE) + + # self.init_weight() + + def init_weight(self): + nn.init.xavier_normal_(self.embedding.weight) + for name, param in self.lstm.named_parameters(): + if 'weight' in name: + nn.init.xavier_normal_(param) + nn.init.xavier_normal_(self.hidden2label.weight) + + def init_hidden(self, batch_size=None): + if batch_size is None: + batch_size = self.batch_size + + h0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE) + c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE) + + return h0, c0 + + def loss(self, x, sent_lengths, pos, rel, y): + mask = torch.ne(x, self.pad_index) + emissions = self.lstm_forward(x, pos, rel, sent_lengths) + return self.crflayer(emissions, y, mask=mask) + + def forward(self, x, poses, rels, sent_lengths): + mask = torch.ne(x, self.pad_index) + emissions = self.lstm_forward(x, poses, rels, sent_lengths) + return self.crflayer.decode(emissions, mask=mask) + + def lstm_forward(self, sentence, poses, rels, sent_lengths): + word = self.word_embedding(sentence.to(DEVICE)).to(DEVICE) + pos = self.pos_embedding(poses.to(DEVICE)).to(DEVICE) + rels = rels.view(rels.size(0), rels.size(1), 1).float().to(DEVICE) + x = torch.cat((word, pos, rels), dim=2) + x = pack_padded_sequence(x, sent_lengths) + self.hidden = self.init_hidden(batch_size=len(sent_lengths)) + lstm_out, self.hidden = self.lstm(x, self.hidden) + lstm_out, new_batch_size = pad_packed_sequence(lstm_out) + assert torch.equal(sent_lengths, new_batch_size.to(DEVICE)) + y = self.hidden2label(lstm_out.to(DEVICE)) + return y.to(DEVICE) diff --git a/lightkg/ede/srl/module.py b/lightkg/ede/srl/module.py new file mode 100644 index 0000000..c520034 --- /dev/null +++ b/lightkg/ede/srl/module.py @@ -0,0 +1,94 @@ +import torch +from tqdm import tqdm + +from ...utils.learning import adjust_learning_rate +from ...utils.log import logger +from ...base.module import Module + +from .config import DEVICE, DEFAULT_CONFIG +from .model import Config, BiLstmCrf +from .tool import srl_tool +from .utils.convert import iobes_ranges + +seed = 2019 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +class SRL(Module): + """ + """ + def __init__(self): + self._model = None + self._word_vocab = None + self._tag_vocab = None + 
self._pos_vocab = None + + def train(self, train_path, save_path=DEFAULT_CONFIG['save_path'], dev_path=None, vectors_path=None, **kwargs): + train_dataset = srl_tool.get_dataset(train_path) + if dev_path: + dev_dataset = srl_tool.get_dataset(dev_path) + word_vocab, pos_vocab, tag_vocab = srl_tool.get_vocab(train_dataset, dev_dataset) + else: + word_vocab, pos_vocab, tag_vocab = srl_tool.get_vocab(train_dataset) + self._word_vocab = word_vocab + self._pos_vocab = pos_vocab + self._tag_vocab = tag_vocab + train_iter = srl_tool.get_iterator(train_dataset, batch_size=DEFAULT_CONFIG['batch_size']) + config = Config(word_vocab, pos_vocab, tag_vocab, save_path=save_path, vector_path=vectors_path, **kwargs) + bilstmcrf = BiLstmCrf(config) + self._model = bilstmcrf + optim = torch.optim.Adam(bilstmcrf.parameters(), lr=config.lr) + for epoch in range(config.epoch): + bilstmcrf.train() + acc_loss = 0 + for item in tqdm(train_iter): + bilstmcrf.zero_grad() + item_text_sentences = item.text[0] + item_text_lengths = item.text[1] + item_loss = (-bilstmcrf.loss(item_text_sentences, item_text_lengths, item.pos, item.rel, item.tag)) / item.tag.size(1) + acc_loss += item_loss.view(-1).cpu().data.tolist()[0] + item_loss.backward() + optim.step() + logger.info('epoch: {}, acc_loss: {}'.format(epoch, acc_loss)) + if dev_path: + dev_score = self._validate(dev_dataset) + logger.info('dev score:{}'.format(dev_score)) + + adjust_learning_rate(optim, config.lr / (1 + (epoch + 1) * config.lr_decay)) + config.save() + bilstmcrf.save() + + def predict(self, word_list, pos_list, rel_list): + self._model.eval() + assert len(word_list) == len(pos_list) == len(rel_list) + vec_text = torch.tensor([self._word_vocab.stoi[x] for x in word_list]).view(-1, 1).to(DEVICE) + len_text = torch.tensor([len(vec_text)]).to(DEVICE) + vec_pos = torch.tensor([self._pos_vocab.stoi[x] for x in pos_list]).view(-1, 1).to(DEVICE) + vec_rel = torch.tensor([int(x) for x in rel_list]).view(-1, 1).to(DEVICE) + vec_predict = self._model(vec_text, vec_pos, vec_rel, len_text)[0] + tag_predict = [self._tag_vocab.itos[i] for i in vec_predict] + return iobes_ranges([x for x in word_list], tag_predict) + + def load(self, save_path=DEFAULT_CONFIG['save_path']): + config = Config.load(save_path) + bilstmcrf = BiLstmCrf(config) + bilstmcrf.load() + self._model = bilstmcrf + self._word_vocab = config.word_vocab + self._tag_vocab = config.tag_vocab + self._pos_vocab = config.pos_vocab + + def test(self, test_path): + test_dataset = srl_tool.get_dataset(test_path) + test_score = self._validate(test_dataset) + logger.info('test score:{}'.format(test_score)) + + def _validate(self, dev_dataset): + self._model.eval() + dev_score_list = [] + for dev_item in tqdm(dev_dataset): + item_score = srl_tool.get_score(self._model, dev_item.text, dev_item.tag, dev_item.pos, dev_item.rel, + self._word_vocab, self._tag_vocab, self._pos_vocab) + dev_score_list.append(item_score) + return sum(dev_score_list) / len(dev_score_list) diff --git a/lightkg/ede/srl/tool.py b/lightkg/ede/srl/tool.py new file mode 100644 index 0000000..1806715 --- /dev/null +++ b/lightkg/ede/srl/tool.py @@ -0,0 +1,80 @@ +import torch +from torchtext.data import Dataset, Field, BucketIterator, ReversibleField +from torchtext.vocab import Vectors +from torchtext.datasets import SequenceTaggingDataset +from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score + +from ...base.tool import Tool +from ...utils.log import logger +from .config import DEVICE, DEFAULT_CONFIG + +seed = 2019 
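+# apply the fixed seed to the CPU and CUDA RNGs below so weight initialisation and dropout are reproducible across runs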
+torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def light_tokenize(sequence: str): + return [sequence] + + +def post_process(arr, _): + return [[int(item) for item in arr_item] for arr_item in arr] + + +TEXT = Field(sequential=True, tokenize=light_tokenize, include_lengths=True) +POS = Field(sequential=True, tokenize=light_tokenize) +REL = Field(sequential=True, use_vocab=False, unk_token=None, pad_token=0, postprocessing=post_process) +TAG = Field(sequential=True, tokenize=light_tokenize, is_target=True, unk_token=None) +Fields = [('text', TEXT), ('pos', POS), ('rel', REL), ('tag', TAG)] + + +class SRLTool(Tool): + def get_dataset(self, path: str, fields=Fields, separator='\t'): + logger.info('loading dataset from {}'.format(path)) + st_dataset = SequenceTaggingDataset(path, fields=fields, separator=separator) + logger.info('successed loading dataset') + return st_dataset + + def get_vocab(self, *dataset): + logger.info('building word vocab...') + TEXT.build_vocab(*dataset) + logger.info('successed building word vocab') + logger.info('building pos vocab...') + POS.build_vocab(*dataset) + logger.info('successed building pos vocab') + logger.info('building tag vocab...') + TAG.build_vocab(*dataset) + logger.info('successed building tag vocab') + return TEXT.vocab, POS.vocab, TAG.vocab + + def get_vectors(self, path: str): + logger.info('loading vectors from {}'.format(path)) + vectors = Vectors(path) + logger.info('successed loading vectors') + return vectors + + def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE, + sort_key=lambda x: len(x.text), sort_within_batch=True): + return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key, + sort_within_batch=sort_within_batch) + + def get_score(self, model, x, y, pos, rel, field_x, field_y, field_pos, score_type='f1'): + metrics_map = { + 'f1': f1_score, + 'p': precision_score, + 'r': recall_score, + 'acc': accuracy_score + } + metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1'] + vec_x = torch.tensor([field_x.stoi[i] for i in x]) + len_vec_x = torch.tensor([len(vec_x)]).to(DEVICE) + vec_pos = torch.tensor([field_pos.stoi[i] for i in pos]) + vec_rel = torch.tensor([int(x) for x in rel]) + predict_y = model(vec_x.view(-1, 1).to(DEVICE), vec_pos.view(-1, 1).to(DEVICE), vec_rel.view(-1, 1).to(DEVICE), + len_vec_x)[0] + true_y = [field_y.stoi[i] for i in y] + assert len(true_y) == len(predict_y) + return metric_func(predict_y, true_y, average='micro') + + +srl_tool = SRLTool() diff --git a/lightkg/ede/srl/utils/convert.py b/lightkg/ede/srl/utils/convert.py new file mode 100644 index 0000000..2b25179 --- /dev/null +++ b/lightkg/ede/srl/utils/convert.py @@ -0,0 +1,51 @@ +def iobes_iob(tags): + """ + IOBES -> IOB + """ + new_tags = [] + for i, tag in enumerate(tags): + if tag == 'rel': + new_tags.append(tag) + elif tag.split('-')[0] == 'B': + new_tags.append(tag) + elif tag.split('-')[0] == 'I': + new_tags.append(tag) + elif tag.split('-')[0] == 'S': + new_tags.append(tag.replace('S-', 'B-')) + elif tag.split('-')[0] == 'E': + new_tags.append(tag.replace('E-', 'I-')) + elif tag.split('-')[0] == 'O': + new_tags.append(tag) + else: + raise Exception('Invalid format!') + return new_tags + + +def iob_ranges(words, tags): + """ + IOB -> Ranges + """ + assert len(words) == len(tags) + events = {} + + def check_if_closing_range(): + if i == len(tags) - 1 or tags[i + 1].split('-')[0] == 'O' or tags[i+1] == 'rel': + events[temp_type] = 
''.join(words[begin: i + 1]) + + for i, tag in enumerate(tags): + if tag == 'rel': + events['rel'] = words[i] + elif tag.split('-')[0] == 'O': + pass + elif tag.split('-')[0] == 'B': + begin = i + temp_type = tag.split('-')[1] + check_if_closing_range() + elif tag.split('-')[0] == 'I': + check_if_closing_range() + return events + + +def iobes_ranges(words, tags): + new_tags = iobes_iob(tags) + return iob_ranges(words, new_tags) diff --git a/lightkg/ere/__init__.py b/lightkg/ere/__init__.py index e69de29..2046d7a 100644 --- a/lightkg/ere/__init__.py +++ b/lightkg/ere/__init__.py @@ -0,0 +1,2 @@ +from .re.module import RE +__all__ = ['RE'] \ No newline at end of file diff --git a/lightkg/ere/re/__init__.py b/lightkg/ere/re/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lightkg/ere/re/config.py b/lightkg/ere/re/config.py new file mode 100644 index 0000000..f31f302 --- /dev/null +++ b/lightkg/ere/re/config.py @@ -0,0 +1,20 @@ +from ...base.config import DEVICE +DEFAULT_CONFIG = { + 'lr': 0.002, + 'epoch': 5, + 'lr_decay': 0.05, + 'batch_size': 128, + 'dropout': 0.5, + 'static': False, + 'non_static': False, + 'embedding_dim': 300, + 'vector_path': '', + 'class_num': 0, + 'vocabulary_size': 0, + 'word_vocab': None, + 'tag_vocab': None, + 'save_path': './saves', + 'filter_num': 100, + 'filter_sizes': (3, 4, 5), + 'multichannel': False +} \ No newline at end of file diff --git a/lightkg/ere/re/model.py b/lightkg/ere/re/model.py new file mode 100644 index 0000000..bd3f858 --- /dev/null +++ b/lightkg/ere/re/model.py @@ -0,0 +1,110 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchtext.vocab import Vectors + +from ...utils.log import logger +from ...base.model import BaseConfig, BaseModel + +from .config import DEVICE, DEFAULT_CONFIG + + +class Config(BaseConfig): + def __init__(self, word_vocab, label_vocab, vector_path, **kwargs): + super(Config, self).__init__() + for name, value in DEFAULT_CONFIG.items(): + setattr(self, name, value) + self.word_vocab = word_vocab + self.label_vocab = label_vocab + self.class_num = len(self.label_vocab) + self.vocabulary_size = len(self.word_vocab) + self.vector_path = vector_path + for name, value in kwargs.items(): + setattr(self, name, value) + + +class TextCNN(BaseModel): + def __init__(self, args): + super(TextCNN, self).__init__(args) + + self.class_num = args.class_num + self.chanel_num = 1 + self.filter_num = args.filter_num + self.filter_sizes = args.filter_sizes + + self.vocabulary_size = args.vocabulary_size + self.embedding_dimension = args.embedding_dim + self.embedding = nn.Embedding(self.vocabulary_size, self.embedding_dimension).to(DEVICE) + if args.static: + logger.info('logging word vectors from {}'.format(args.vector_path)) + vectors = Vectors(args.vector_path).vectors + self.embedding = self.embedding.from_pretrained(vectors, freeze=not args.non_static).to(DEVICE) + if args.multichannel: + self.embedding2 = nn.Embedding(self.vocabulary_size, self.embedding_dimension).from_pretrained( + args.vectors).to(DEVICE) + self.chanel_num += 1 + else: + self.embedding2 = None + self.convs = nn.ModuleList( + [nn.Conv2d(self.chanel_num, self.filter_num, (size, self.embedding_dimension)) for size in + self.filter_sizes]).to(DEVICE) + self.dropout = nn.Dropout(args.dropout).to(DEVICE) + self.fc = nn.Linear(len(self.filter_sizes) * self.filter_num, self.class_num).to(DEVICE) + + def forward(self, x): + if self.embedding2: + x = torch.stack((self.embedding(x), self.embedding2(x)), dim=1).to(DEVICE) + 
else: + x = self.embedding(x).to(DEVICE) + x = x.unsqueeze(1) + x = [F.relu(conv(x)).squeeze(3) for conv in self.convs] + x = [F.max_pool1d(item, item.size(2)).squeeze(2) for item in x] + x = torch.cat(x, 1) + x = self.dropout(x) + logits = self.fc(x) + return logits + + +class LSTMClassifier(BaseModel): + def __init__(self, args): + super(LSTMClassifier, self).__init__(args) + + self.hidden_dim = 300 + self.class_num = args.class_num + self.batch_size = args.batch_size + + self.vocabulary_size = args.vocabulary_size + self.embedding_dimension = args.embedding_dim + + self.embedding = nn.Embedding(self.vocabulary_size, self.embedding_dimension).to(DEVICE) + if args.static: + self.embedding = self.embedding.from_pretrained(args.vectors, freeze=not args.non_static).to(DEVICE) + if args.multichannel: + self.embedding2 = nn.Embedding(self.vocabulary_size, self.embedding_dimension).from_pretrained( + args.vectors).to(DEVICE) + else: + self.embedding2 = None + + self.lstm = nn.LSTM(self.embedding_dimension, self.hidden_dim).to(DEVICE) + self.hidden2label = nn.Linear(self.hidden_dim, self.class_num).to(DEVICE) + self.hidden = self.init_hidden() + + def init_hidden(self, batch_size=None): + if batch_size is None: + batch_size = self.batch_size + + h0 = torch.zeros(1, batch_size, self.hidden_dim).to(DEVICE) + c0 = torch.zeros(1, batch_size, self.hidden_dim).to(DEVICE) + + return h0, c0 + + def forward(self, sentence): + embeds = self.embedding(sentence).to(DEVICE) + + x = embeds.permute(1, 0, 2).to(DEVICE) + self.hidden = self.init_hidden(sentence.size()[0]) + lstm_out, self.hidden = self.lstm(x, self.hidden) + lstm_out = lstm_out.to(DEVICE) + final = lstm_out[-1].to(DEVICE) + y = self.hidden2label(final) + return y diff --git a/lightkg/ere/re/module.py b/lightkg/ere/re/module.py new file mode 100644 index 0000000..c872d73 --- /dev/null +++ b/lightkg/ere/re/module.py @@ -0,0 +1,101 @@ +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from ...utils.learning import adjust_learning_rate +from ...utils.log import logger +from ...base.module import Module + +from .model import Config, TextCNN +from .config import DEVICE, DEFAULT_CONFIG +from .tool import re_tool, TEXT, LABEL +from .utils.preprocess import handle_line + +seed = 2019 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +class RE(Module): + """ + """ + + def __init__(self): + self._model = None + self._word_vocab = None + self._label_vocab = None + + def train(self, train_path, save_path=DEFAULT_CONFIG['save_path'], dev_path=None, vectors_path=None, **kwargs): + train_dataset = re_tool.get_dataset(train_path) + if dev_path: + dev_dataset = re_tool.get_dataset(dev_path) + word_vocab, label_vocab = re_tool.get_vocab(train_dataset, dev_dataset) + else: + word_vocab, label_vocab = re_tool.get_vocab(train_dataset) + self._word_vocab = word_vocab + self._label_vocab = label_vocab + train_iter = re_tool.get_iterator(train_dataset, batch_size=DEFAULT_CONFIG['batch_size']) + config = Config(word_vocab, label_vocab, save_path=save_path, vector_path=vectors_path, **kwargs) + textcnn = TextCNN(config) + # print(textcnn) + self._model = textcnn + optim = torch.optim.Adam(textcnn.parameters(), lr=config.lr) + for epoch in range(config.epoch): + textcnn.train() + acc_loss = 0 + for fuck in tqdm(train_iter): + optim.zero_grad() + logits = self._model(fuck.text) + item_loss = F.cross_entropy(logits, fuck.label) + acc_loss += item_loss.item() + item_loss.backward() + optim.step() + logger.info('epoch: {}, acc_loss: 
{}'.format(epoch, acc_loss)) + if dev_path: + dev_score = self._validate(dev_dataset) + logger.info('dev score:{}'.format(dev_score)) + adjust_learning_rate(optim, config.lr / (1 + (epoch + 1) * config.lr_decay)) + config.save() + textcnn.save() + + def predict(self, entity1: str, entity2: str, sentence: str): + self._model.eval() + text = handle_line(entity1, entity2, sentence) + vec_text = torch.tensor([self._word_vocab.stoi[x] for x in text]) + vec_text = vec_text.reshape(1, -1).to(DEVICE) + vec_predict = self._model(vec_text)[0] + soft_predict = torch.softmax(vec_predict, dim=0) + predict_prob, predict_index = torch.max(soft_predict.cpu().data, dim=0) + predict_class = self._label_vocab.itos[predict_index] + predict_prob = predict_prob.item() + return predict_prob, predict_class + + def load(self, save_path=DEFAULT_CONFIG['save_path']): + config = Config.load(save_path) + textcnn = TextCNN(config) + textcnn.load() + self._model = textcnn + self._word_vocab = config.word_vocab + self._label_vocab = config.label_vocab + self._check_vocab() + + def test(self, test_path): + self._model.eval() + test_dataset = re_tool.get_dataset(test_path) + test_score = self._validate(test_dataset) + logger.info('test score:{}'.format(test_score)) + + def _validate(self, dev_dataset, batch_size=DEFAULT_CONFIG['batch_size']): + self._model.eval() + dev_score_list = [] + dev_iter = re_tool.get_iterator(dev_dataset, batch_size=batch_size) + for dev_item in tqdm(dev_iter): + item_score = re_tool.get_score(self._model, dev_item.text, dev_item.label) + dev_score_list.append(item_score) + return sum(dev_score_list) / len(dev_score_list) + + def _check_vocab(self): + if not hasattr(TEXT, 'vocab'): + TEXT.vocab = self._word_vocab + if not hasattr(LABEL, 'vocab'): + LABEL.vocab = self._label_vocab diff --git a/lightkg/ere/re/tool.py b/lightkg/ere/re/tool.py new file mode 100644 index 0000000..09c63d3 --- /dev/null +++ b/lightkg/ere/re/tool.py @@ -0,0 +1,75 @@ +import re +import torch +from torchtext.data import Field, Iterator +from torchtext.vocab import Vectors +from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score + +from ...base.tool import Tool +from ...utils.log import logger +from .config import DEVICE, DEFAULT_CONFIG + +from .utils.dataset import REDataset + +seed = 2019 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def light_tokenize(text): + return text + + +TEXT = Field(lower=True, tokenize=light_tokenize, batch_first=True) +LABEL = Field(sequential=False, unk_token=None) +Fields = [ + ('text', TEXT), + ('label', LABEL) + ] + + +class RETool(Tool): + def get_dataset(self, path: str, fields=Fields): + logger.info('loading dataset from {}'.format(path)) + re_dataset = REDataset(path, fields) + logger.info('successed loading dataset') + return re_dataset + + def get_vocab(self, *dataset): + logger.info('building word vocab...') + TEXT.build_vocab(*dataset) + logger.info('successed building word vocab') + logger.info('building label vocab...') + LABEL.build_vocab(*dataset) + logger.info('successed building label vocab') + return TEXT.vocab, LABEL.vocab + + def get_vectors(self, path: str): + logger.info('loading vectors from {}'.format(path)) + vectors = Vectors(path) + logger.info('successed loading vectors') + return vectors + + def get_iterator(self, dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE, + sort_key=lambda x: len(x.text)): + return Iterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key) + + def get_score(self, model, 
texts, labels, score_type='f1'): + metrics_map = { + 'f1': f1_score, + 'p': precision_score, + 'r': recall_score, + 'acc': accuracy_score + } + metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1'] + assert texts.size(0) == len(labels) + vec_predict = model(texts) + soft_predict = torch.softmax(vec_predict, dim=1) + predict_prob, predict_index = torch.max(soft_predict.cpu().data, dim=1) + # print('prob', predict_prob) + # print('index', predict_index) + # print('labels', labels) + labels = labels.view(-1).cpu().data.numpy() + return metric_func(predict_index, labels, average='micro') + + +re_tool = RETool() diff --git a/lightkg/ere/re/utils/dataset.py b/lightkg/ere/re/utils/dataset.py new file mode 100644 index 0000000..b42d54f --- /dev/null +++ b/lightkg/ere/re/utils/dataset.py @@ -0,0 +1,22 @@ +from torchtext.data import Dataset, Example +from .preprocess import handle_line + + +class REDataset(Dataset): + """Defines a Dataset of relation extraction format. + eg: + 钱钟书 辛笛 同门 与辛笛京沪唱和聽钱钟书与钱钟书是清华校友,钱钟书高辛笛两班。 + 元武 元华 unknown 于师傅在一次京剧表演中,选了元龙(洪金宝)、元楼(元奎)、元彪、成龙、元华、元武、元泰7人担任七小福的主角。 + """ + + def __init__(self, path, fields, encoding="utf-8", **kwargs): + examples = [] + with open(path, "r", encoding=encoding) as f: + for line in f: + chunks = line.split() + entity_1, entity_2, relation, sentence = tuple(chunks) + sentence_list = handle_line(entity_1, entity_2, sentence) + + examples.append(Example.fromlist((sentence_list, relation), fields)) + super(REDataset, self).__init__(examples, fields, **kwargs) + diff --git a/lightkg/ere/re/utils/preprocess.py b/lightkg/ere/re/utils/preprocess.py new file mode 100644 index 0000000..036c0ba --- /dev/null +++ b/lightkg/ere/re/utils/preprocess.py @@ -0,0 +1,24 @@ +import jieba + + +def handle_line(entity1, entity2, sentence, begin_e1_token='', end_e1_token='', begin_e2_token='', + end_e2_token=''): + assert entity1 in sentence + assert entity2 in sentence + sentence = sentence.replace(entity1, begin_e1_token + entity1 + end_e1_token) + sentence = sentence.replace(entity2, begin_e2_token + entity2 + end_e2_token) + sentence = ' '.join(jieba.cut(sentence)) + sentence = sentence.replace('< e1 >', begin_e1_token) + sentence = sentence.replace('< / e1 >', end_e1_token) + sentence = sentence.replace('< e2 >', begin_e2_token) + sentence = sentence.replace('< / e2 >', end_e2_token) + return sentence.split() + + +if __name__ == '__main__': + test_str = '曾经沧海难为水哈哈谁说不是呢?!呵呵 低头不见抬头见' + e1 = '沧海' + e2 = '不是' + print(list(jieba.cut(test_str))) + print(handle_line(e1, e2, test_str)) + diff --git a/lightkg/erl/__init__.py b/lightkg/erl/__init__.py index e69de29..0cd1c68 100644 --- a/lightkg/erl/__init__.py +++ b/lightkg/erl/__init__.py @@ -0,0 +1,2 @@ +from .ner.module import NER +__all__ = ['NER'] diff --git a/lightkg/erl/ner/__init__.py b/lightkg/erl/ner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lightkg/erl/ner/config.py b/lightkg/erl/ner/config.py new file mode 100644 index 0000000..96ed838 --- /dev/null +++ b/lightkg/erl/ner/config.py @@ -0,0 +1,19 @@ +from ...base.config import DEVICE +DEFAULT_CONFIG = { + 'lr': 0.02, + 'epoch': 30, + 'lr_decay': 0.05, + 'batch_size': 128, + 'dropout': 0.5, + 'static': False, + 'non_static': False, + 'embedding_dim': 300, + 'num_layers': 2, + 'pad_index': 1, + 'vector_path': '', + 'tag_num': 0, + 'vocabulary_size': 0, + 'word_vocab': None, + 'tag_vocab': None, + 'save_path': './saves' +} \ No newline at end of file diff --git a/lightkg/erl/ner/model.py 
b/lightkg/erl/ner/model.py new file mode 100644 index 0000000..089b1e9 --- /dev/null +++ b/lightkg/erl/ner/model.py @@ -0,0 +1,89 @@ +import torch +import torch.nn as nn +from torchcrf import CRF +from torchtext.vocab import Vectors +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + +from ...utils.log import logger +from .config import DEVICE, DEFAULT_CONFIG +from ...base.model import BaseConfig, BaseModel + + +class Config(BaseConfig): + def __init__(self, word_vocab, tag_vocab, vector_path, **kwargs): + super(Config, self).__init__() + for name, value in DEFAULT_CONFIG.items(): + setattr(self, name, value) + self.word_vocab = word_vocab + self.tag_vocab = tag_vocab + self.tag_num = len(self.tag_vocab) + self.vocabulary_size = len(self.word_vocab) + self.vector_path = vector_path + for name, value in kwargs.items(): + setattr(self, name, value) + + +class BiLstmCrf(BaseModel): + def __init__(self, args): + super(BiLstmCrf, self).__init__(args) + self.args = args + self.hidden_dim = 300 + self.tag_num = args.tag_num + self.batch_size = args.batch_size + self.bidirectional = True + self.num_layers = args.num_layers + self.pad_index = args.pad_index + self.dropout = args.dropout + self.save_path = args.save_path + + vocabulary_size = args.vocabulary_size + embedding_dimension = args.embedding_dim + + self.embedding = nn.Embedding(vocabulary_size, embedding_dimension).to(DEVICE) + if args.static: + logger.info('logging word vectors from {}'.format(args.vector_path)) + vectors = Vectors(args.vector_path).vectors + self.embedding = self.embedding.from_pretrained(vectors, freeze=not args.non_static).to(DEVICE) + + self.lstm = nn.LSTM(embedding_dimension, self.hidden_dim // 2, bidirectional=self.bidirectional, + num_layers=self.num_layers, dropout=self.dropout).to(DEVICE) + self.hidden2label = nn.Linear(self.hidden_dim, self.tag_num).to(DEVICE) + self.crflayer = CRF(self.tag_num).to(DEVICE) + + # self.init_weight() + + def init_weight(self): + nn.init.xavier_normal_(self.embedding.weight) + for name, param in self.lstm.named_parameters(): + if 'weight' in name: + nn.init.xavier_normal_(param) + nn.init.xavier_normal_(self.hidden2label.weight) + + def init_hidden(self, batch_size=None): + if batch_size is None: + batch_size = self.batch_size + + h0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE) + c0 = torch.zeros(self.num_layers * 2, batch_size, self.hidden_dim // 2).to(DEVICE) + + return h0, c0 + + def loss(self, x, sent_lengths, y): + mask = torch.ne(x, self.pad_index) + emissions = self.lstm_forward(x, sent_lengths) + return self.crflayer(emissions, y, mask=mask) + + def forward(self, x, sent_lengths): + mask = torch.ne(x, self.pad_index) + emissions = self.lstm_forward(x, sent_lengths) + return self.crflayer.decode(emissions, mask=mask) + + def lstm_forward(self, sentence, sent_lengths): + x = self.embedding(sentence.to(DEVICE)).to(DEVICE) + x = pack_padded_sequence(x, sent_lengths) + self.hidden = self.init_hidden(batch_size=len(sent_lengths)) + lstm_out, self.hidden = self.lstm(x, self.hidden) + lstm_out, new_batch_size = pad_packed_sequence(lstm_out) + assert torch.equal(sent_lengths, new_batch_size.to(DEVICE)) + y = self.hidden2label(lstm_out.to(DEVICE)) + return y.to(DEVICE) diff --git a/lightkg/erl/ner/module.py b/lightkg/erl/ner/module.py new file mode 100644 index 0000000..f919336 --- /dev/null +++ b/lightkg/erl/ner/module.py @@ -0,0 +1,87 @@ +import torch +from tqdm import tqdm + +from ...utils.learning import 
adjust_learning_rate +from ...utils.log import logger +from ...base.module import Module + +from .config import DEVICE, DEFAULT_CONFIG +from .model import Config, BiLstmCrf +from .tool import ner_tool +from .utils.convert import iob_ranges + +seed = 2019 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +class NER(Module): + """ + """ + def __init__(self): + self._model = None + self._word_vocab = None + self._tag_vocab = None + + def train(self, train_path, save_path=DEFAULT_CONFIG['save_path'], dev_path=None, vectors_path=None, **kwargs): + train_dataset = ner_tool.get_dataset(train_path) + if dev_path: + dev_dataset = ner_tool.get_dataset(dev_path) + word_vocab, tag_vocab = ner_tool.get_vocab(train_dataset, dev_dataset) + else: + word_vocab, tag_vocab = ner_tool.get_vocab(train_dataset) + self._word_vocab = word_vocab + self._tag_vocab = tag_vocab + train_iter = ner_tool.get_iterator(train_dataset, batch_size=DEFAULT_CONFIG['batch_size']) + config = Config(word_vocab, tag_vocab, save_path=save_path, vector_path=vectors_path, **kwargs) + bilstmcrf = BiLstmCrf(config) + self._model = bilstmcrf + optim = torch.optim.Adam(bilstmcrf.parameters(), lr=config.lr) + for epoch in range(config.epoch): + bilstmcrf.train() + acc_loss = 0 + for item in tqdm(train_iter): + bilstmcrf.zero_grad() + item_text_sentences = item.text[0] + item_text_lengths = item.text[1] + item_loss = (-bilstmcrf.loss(item_text_sentences, item_text_lengths, item.tag)) / item.tag.size(1) + acc_loss += item_loss.view(-1).cpu().data.tolist()[0] + item_loss.backward() + optim.step() + logger.info('epoch: {}, acc_loss: {}'.format(epoch, acc_loss)) + if dev_path: + dev_score = self._validate(dev_dataset) + logger.info('dev score:{}'.format(dev_score)) + + adjust_learning_rate(optim, config.lr / (1 + (epoch + 1) * config.lr_decay)) + config.save() + bilstmcrf.save() + + def predict(self, text): + self._model.eval() + vec_text = torch.tensor([self._word_vocab.stoi[x] for x in text]) + len_text = torch.tensor([len(vec_text)]).to(DEVICE) + vec_predict = self._model(vec_text.view(-1, 1).to(DEVICE), len_text)[0] + tag_predict = [self._tag_vocab.itos[i] for i in vec_predict] + return iob_ranges([x for x in text], tag_predict) + + def load(self, save_path=DEFAULT_CONFIG['save_path']): + config = Config.load(save_path) + bilstmcrf = BiLstmCrf(config) + bilstmcrf.load() + self._model = bilstmcrf + self._word_vocab = config.word_vocab + self._tag_vocab = config.tag_vocab + + def test(self, test_path): + test_dataset = ner_tool.get_dataset(test_path) + test_score = self._validate(test_dataset) + logger.info('test score:{}'.format(test_score)) + + def _validate(self, dev_dataset): + self._model.eval() + dev_score_list = [] + for dev_item in tqdm(dev_dataset): + item_score = ner_tool.get_score(self._model, dev_item.text, dev_item.tag, self._word_vocab, self._tag_vocab) + dev_score_list.append(item_score) + return sum(dev_score_list) / len(dev_score_list) diff --git a/lightkg/erl/ner/tool.py b/lightkg/erl/ner/tool.py new file mode 100644 index 0000000..21835e5 --- /dev/null +++ b/lightkg/erl/ner/tool.py @@ -0,0 +1,68 @@ +import torch +from torchtext.data import Dataset, Field, BucketIterator, ReversibleField +from torchtext.vocab import Vectors +from torchtext.datasets import SequenceTaggingDataset +from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score + +from ...base.tool import Tool +from ...utils.log import logger +from .config import DEVICE, DEFAULT_CONFIG + +seed = 2019 +torch.manual_seed(seed) 
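+# torch.cuda.manual_seed is silently ignored when CUDA is unavailable, so the call below is safe on CPU-only machines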
+torch.cuda.manual_seed(seed) + + +def light_tokenize(sequence: str): + return [sequence] + + +TEXT = Field(sequential=True, tokenize=light_tokenize, include_lengths=True) +TAG = ReversibleField(sequential=True, tokenize=light_tokenize, is_target=True, unk_token=None) +Fields = [('text', TEXT), ('tag', TAG)] + + +class NERTool(Tool): + def get_dataset(self, path: str, fields=Fields, separator=' '): + logger.info('loading dataset from {}'.format(path)) + st_dataset = SequenceTaggingDataset(path, fields=fields, separator=separator) + logger.info('successed loading dataset') + return st_dataset + + def get_vocab(self, *dataset): + logger.info('building word vocab...') + TEXT.build_vocab(*dataset) + logger.info('successed building word vocab') + logger.info('building tag vocab...') + TAG.build_vocab(*dataset) + logger.info('successed building tag vocab') + return TEXT.vocab, TAG.vocab + + def get_vectors(self, path: str): + logger.info('loading vectors from {}'.format(path)) + vectors = Vectors(path) + logger.info('successed loading vectors') + return vectors + + def get_iterator(self, dataset: Dataset, batch_size=DEFAULT_CONFIG['batch_size'], device=DEVICE, + sort_key=lambda x: len(x.text), sort_within_batch=True): + return BucketIterator(dataset, batch_size=batch_size, device=device, sort_key=sort_key, + sort_within_batch=sort_within_batch) + + def get_score(self, model, x, y, field_x, field_y, score_type='f1'): + metrics_map = { + 'f1': f1_score, + 'p': precision_score, + 'r': recall_score, + 'acc': accuracy_score + } + metric_func = metrics_map[score_type] if score_type in metrics_map else metrics_map['f1'] + vec_x = torch.tensor([field_x.stoi[i] for i in x]) + len_vec_x = torch.tensor([len(vec_x)]).to(DEVICE) + predict_y = model(vec_x.view(-1, 1).to(DEVICE), len_vec_x)[0] + true_y = [field_y.stoi[i] for i in y] + assert len(true_y) == len(predict_y) + return metric_func(predict_y, true_y, average='micro') + + +ner_tool = NERTool() diff --git a/lightkg/erl/ner/utils/convert.py b/lightkg/erl/ner/utils/convert.py new file mode 100644 index 0000000..5f954d7 --- /dev/null +++ b/lightkg/erl/ner/utils/convert.py @@ -0,0 +1,26 @@ +def iob_ranges(words, tags): + """ + IOB -> Ranges + """ + assert len(words) == len(tags) + ranges = [] + + def check_if_closing_range(): + if i == len(tags) - 1 or tags[i + 1].split('_')[0] == 'O': + ranges.append({ + 'entity': ''.join(words[begin: i + 1]), + 'type': temp_type, + 'start': begin, + 'end': i + }) + + for i, tag in enumerate(tags): + if tag.split('_')[0] == 'O': + pass + elif tag.split('_')[0] == 'B': + begin = i + temp_type = tag.split('_')[1] + check_if_closing_range() + elif tag.split('_')[0] == 'I': + check_if_closing_range() + return ranges diff --git a/lightkg/ksq/__init__.py b/lightkg/ksq/__init__.py new file mode 100644 index 0000000..e69de29