add ner models

This commit is contained in:
loujie0822 2020-04-29 18:26:47 +08:00
parent cf107e0683
commit f9bd431a12
4 changed files with 208 additions and 1 deletions

View File

@ -4,6 +4,9 @@ import os
import pickle
import sys
# Module-level logger; getLogger() called without a name returns the root logger.
logger = logging.getLogger()
# Configure logging at import time: INFO level, records formatted as
# "time - logger name - level - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
def pickle_dump_large_file(obj, filepath):
max_bytes = 2 ** 31 - 1
@ -44,3 +47,56 @@ def write_json(obj, path):
with open(path, 'wb') as f:
f.write(json.dumps(obj, indent=2, ensure_ascii=False).
encode('utf-8'))
def _read_conll(path, encoding='utf-8', indexes=2, dropna=True):
    """
    Construct a generator to read CoNLL-style instances from a file.

    An instance is a run of consecutive data lines; instances are separated by
    blank (empty or whitespace-only) lines. Lines starting with '#' are skipped.
    Data lines are split on tabs and every instance is transposed, so the
    yielded item holds one list per column. The last instance is yielded even
    when the file does not end with a blank line.

    :param path: file path
    :param encoding: file's encoding, default: utf-8
    :param indexes: number of leading columns to keep; if None, all columns
        are kept. default: 2
    :param dropna: whether to ignore and drop invalid instances (a warning is
        logged); if False, raise ValueError when reading invalid data.
        default: True
    :return: generator, every time yield (line number, conll item), where the
        line number is the 0-based index of the line that ended the instance
    :raises ValueError: if an instance is invalid and ``dropna`` is False
    """

    def parse_conll(rows):
        # Transpose rows (one per token) into columns (one per field).
        columns = list(map(list, zip(*rows)))
        if indexes is not None:
            columns = [columns[i] for i in range(indexes)]
        for field in columns:
            if len(field) <= 0:
                raise ValueError('empty field')
        return columns

    def finish(rows, line_idx):
        # Parse one buffered instance; return None when it is invalid and dropped.
        try:
            return parse_conll(rows)
        except (IndexError, TypeError, ValueError) as err:
            if dropna:
                logger.warning('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
                return None
            raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) from err

    with open(path, 'r', encoding=encoding) as f:
        sample = []
        line_idx = -1
        for line_idx, line in enumerate(f):
            # Strip only the newline: slicing off the last character would
            # eat real data when the final line has no trailing newline.
            line = line.rstrip('\n')
            if line.strip() == '':
                if sample:
                    # Empty the buffer *before* parsing so that an invalid
                    # instance cannot leak into the following ones.
                    rows, sample = sample, []
                    res = finish(rows, line_idx)
                    if res is not None:
                        yield line_idx, res
            elif line.startswith('#'):
                continue
            else:
                sample.append(line.split('\t'))
        # Flush the trailing instance of a file that lacks a final blank line.
        if sample:
            res = finish(sample, line_idx)
            if res is not None:
                yield line_idx, res

151
utils/metrics.py Normal file
View File

@ -0,0 +1,151 @@
from abc import abstractmethod
from collections import defaultdict
class MetricBase(object):
    """
    Base interface for metrics.

    Subclasses override ``evaluate`` to accumulate statistics for a batch and
    ``get_metric`` to turn the accumulated statistics into final scores.
    Calling the instance forwards to ``evaluate``.

    NOTE: the class does not use ``ABCMeta``, so ``@abstractmethod`` is only a
    marker here; the base methods raise ``NotImplementedError`` when called.
    """

    @abstractmethod
    def evaluate(self, *args, **kwargs):
        """Accumulate statistics for one batch; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def get_metric(self, reset=True):
        """Return the final scores; must be overridden."""
        # NotImplemented is a comparison sentinel, not an exception: raising
        # it produces a TypeError. NotImplementedError is the intended signal.
        raise NotImplementedError

    def __call__(self, p_ids, pred, eval_file):
        """Shorthand for ``self.evaluate(p_ids, pred, eval_file)``."""
        return self.evaluate(p_ids, pred, eval_file)
def _bmeso_tag_to_spans(tags, ignore_labels=None):
    """
    Convert a BMESO tag sequence into labelled spans.

    For example ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'] gives
    [('singer', (1, 4))] (half-open interval: start inclusive, end exclusive).
    Tags are lowercased before parsing, so returned labels are lowercase.

    :param tags: List[str], one tag per token
    :param ignore_labels: List[str], spans whose label is in this list are
        dropped; compared case-insensitively because labels are lowercased
    :return: List[Tuple[str, Tuple[int, int]]], i.e. [(label, (start, end))]
    """
    # Lowercase the ignore list too; otherwise a mixed-case entry could never
    # match the lowercased labels produced below.
    ignored = {label.lower() for label in ignore_labels} if ignore_labels else set()

    spans = []
    prev_bmes_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        bmes_tag, label = tag[:1], tag[2:]
        if bmes_tag in ('b', 's'):
            # A begin/single tag always opens a new span.
            spans.append((label, [idx, idx]))
        elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
            # A well-formed continuation extends the currently open span.
            spans[-1][1][1] = idx
        elif bmes_tag == 'o':
            pass
        else:
            # A dangling m/e tag (or unknown prefix) starts a span of its own.
            spans.append((label, [idx, idx]))
        prev_bmes_tag = bmes_tag

    # Stored ends are inclusive; convert them to exclusive ends on output.
    return [
        (label, (bounds[0], bounds[1] + 1))
        for label, bounds in spans
        if label not in ignored
    ]
class SpanFPreRecMetric(MetricBase):
    """
    Span-level precision / recall / F-beta for sequence-labelling predictions.

    ``evaluate`` accumulates per-label true-positive, false-positive and
    false-negative span counts; ``get_metric`` converts them into scores.
    """

    def __init__(self, tag_type, pred=None, target=None, encoding_type='bmeso',
                 only_gross=True, f_type='micro', beta=1):
        """
        :param tag_type: object indexed by predicted tag id that yields the
            tag string (``tag_type[tag_id]``)
        :param pred: unused, kept for interface compatibility
        :param target: unused, kept for interface compatibility
        :param encoding_type: tagging scheme; only 'bmeso' is supported
        :param only_gross: if True report only overall scores, otherwise also
            report ``f-<label>`` / ``pre-<label>`` / ``rec-<label>``
        :param f_type: 'micro' or 'macro' averaging
        :param beta: weight of recall relative to precision in F-beta
        :raises ValueError: if ``encoding_type`` is not supported
        """
        self.tag_type = tag_type
        self.only_gross = only_gross
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.encoding_type = encoding_type
        if self.encoding_type == 'bmeso':
            self.tag_to_span_func = _bmeso_tag_to_spans
        else:
            # Fail here instead of with an AttributeError on first evaluate().
            raise ValueError("Only support 'bmeso' encoding type, got {!r}".format(encoding_type))
        self._true_positives = defaultdict(int)
        self._false_positives = defaultdict(int)
        self._false_negatives = defaultdict(int)

    def evaluate(self, p_ids, preds, eval_file):
        """
        Accumulate span statistics for one batch.

        :param p_ids: sample ids; must provide ``tolist()``
        :param preds: predicted tag ids, one sequence per sample; must
            provide ``tolist()``
        :param eval_file: indexed by sample id, yielding an object whose
            ``gold_answer`` attribute is the gold tag-string sequence
        :return: dict mapping ``str(p_id)`` to ``[pred_spans, gold_spans]``
        """
        answer_dict = {}
        for p_id, pred in zip(p_ids.tolist(), preds.tolist()):
            gold_ = eval_file[p_id].gold_answer
            pred_ = [self.tag_type[tag] for tag in pred]
            pred_spans = self.tag_to_span_func(pred_)
            gold_spans = self.tag_to_span_func(gold_)
            answer_dict[str(p_id)] = [pred_spans, gold_spans]

            # Match against a copy so the gold spans stored in answer_dict
            # are not emptied as matches are consumed.
            unmatched = list(gold_spans)
            for span in pred_spans:
                if span in unmatched:
                    self._true_positives[span[0]] += 1
                    unmatched.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            for span in unmatched:
                self._false_negatives[span[0]] += 1
        return answer_dict

    def get_metric(self, reset=True):
        """
        Compute the final scores from the statistics accumulated by ``evaluate``.

        :param reset: if True, clear the accumulated counts afterwards
        :return: dict of values rounded to 6 digits. Always holds 'f', 'pre'
            (precision) and 'rec' for a 'micro' or 'macro' ``f_type``; 'micro'
            adds the counts 'em' (matched spans), 'pre_num' (predicted spans)
            and 'gold' (gold spans); per-label keys are added when
            ``only_gross`` is False
        """
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives) | set(self._false_positives) | set(self._true_positives)
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives.get(tag, 0)
                fn = self._false_negatives.get(tag, 0)
                fp = self._false_positives.get(tag, 0)
                # The helper also returns raw counts; only the scores are needed here.
                f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)[:3]
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards against untagged spans
                    evaluate_result['f-{}'.format(tag)] = f
                    evaluate_result['pre-{}'.format(tag)] = pre
                    evaluate_result['rec-{}'.format(tag)] = rec

            if self.f_type == 'macro':
                # With nothing accumulated yet the sums are 0; avoid dividing by zero.
                num_tags = max(len(tags), 1)
                evaluate_result['f'] = f_sum / num_tags
                evaluate_result['pre'] = pre_sum / num_tags
                evaluate_result['rec'] = rec_sum / num_tags

        if self.f_type == 'micro':
            f, pre, rec, em, pre_num, gold_num = self._compute_f_pre_rec(
                sum(self._true_positives.values()),
                sum(self._false_negatives.values()),
                sum(self._false_positives.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec
            evaluate_result['em'] = em
            # Stored under its own key so it no longer overwrites precision ('pre').
            evaluate_result['pre_num'] = pre_num
            evaluate_result['gold'] = gold_num

        if reset:
            self._true_positives = defaultdict(int)
            self._false_positives = defaultdict(int)
            self._false_negatives = defaultdict(int)

        evaluate_result = {key: round(value, 6) for key, value in evaluate_result.items()}
        print(evaluate_result)
        return evaluate_result

    def _compute_f_pre_rec(self, tp, fn, fp):
        """
        :param tp: int, true positive
        :param fn: int, false negative
        :param fp: int, false positive
        :return: (f, pre, rec, tp, fp + tp, fn + tp), i.e. the three scores
            followed by the matched, predicted and gold span counts
        """
        # The 1e-13 terms keep the divisions defined when a count is zero.
        pre = tp / (fp + tp + 1e-13)
        rec = tp / (fn + tp + 1e-13)
        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)
        return f, pre, rec, tp, fp + tp, fn + tp

0
utils/ner_loader.py Normal file
View File

View File

@ -4,7 +4,7 @@ from layers.encoders.transformers.bert.bert_optimization import BertAdam
def set_optimizer(args, model, train_steps=None):
if args.use_bert:
if args.warm_up:
print('using BertAdam')
param_optimizer = list(model.named_parameters())
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]