add ner models
This commit is contained in:
parent
cf107e0683
commit
f9bd431a12
@ -4,6 +4,9 @@ import os
|
||||
import pickle
|
||||
import sys
|
||||
|
||||
# Configure the root logger once at import time: INFO level, with the
# timestamp, logger name and severity prepended to every record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-wide handle on the root logger (getLogger() is called without a name).
logger = logging.getLogger()
|
||||
|
||||
|
||||
def pickle_dump_large_file(obj, filepath):
|
||||
max_bytes = 2 ** 31 - 1
|
||||
@ -44,3 +47,56 @@ def write_json(obj, path):
|
||||
with open(path, 'wb') as f:
|
||||
f.write(json.dumps(obj, indent=2, ensure_ascii=False).
|
||||
encode('utf-8'))
|
||||
|
||||
|
||||
def _read_conll(path, encoding='utf-8', indexes=2, dropna=True):
|
||||
"""
|
||||
Construct a generator to read conll items.
|
||||
|
||||
:param path: file path
|
||||
:param encoding: file's encoding, default: utf-8
|
||||
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
|
||||
:param dropna: weather to ignore and drop invalid data,
|
||||
:if False, raise ValueError when reading invalid data. default: True
|
||||
:return: generator, every time yield (line number, conll item)
|
||||
"""
|
||||
|
||||
def parse_conll(sample):
|
||||
sample = list(map(list, zip(*sample)))
|
||||
sample = [sample[i] for i in range(indexes)]
|
||||
for f in sample:
|
||||
if len(f) <= 0:
|
||||
raise ValueError('empty field')
|
||||
return sample
|
||||
|
||||
with open(path, 'r', encoding=encoding) as f:
|
||||
sample = []
|
||||
start = next(f).strip()
|
||||
if start != '':
|
||||
sample.append(start.split())
|
||||
for line_idx, line in enumerate(f, 1):
|
||||
line = line[:-1]
|
||||
if line == '':
|
||||
if len(sample):
|
||||
try:
|
||||
res = parse_conll(sample)
|
||||
sample = []
|
||||
yield line_idx, res
|
||||
except Exception as e:
|
||||
if dropna:
|
||||
logger.warning('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
|
||||
continue
|
||||
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
|
||||
elif line.startswith('#'):
|
||||
continue
|
||||
else:
|
||||
sample.append(line.split('\t'))
|
||||
# if len(sample) > 0:
|
||||
# try:
|
||||
# res = parse_conll(sample)
|
||||
# yield line_idx, res
|
||||
# except Exception as e:
|
||||
# if dropna:
|
||||
# return
|
||||
# logger.error('invalid instance ends at line: {}'.format(line_idx))
|
||||
# raise e
|
||||
|
151
utils/metrics.py
Normal file
151
utils/metrics.py
Normal file
@ -0,0 +1,151 @@
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class MetricBase(object):
    """
    Base interface for evaluation metrics.

    Subclasses implement ``evaluate`` to accumulate statistics for a batch and
    ``get_metric`` to turn the accumulated statistics into final scores.

    NOTE: the class does not use ``abc.ABCMeta``, so ``@abstractmethod`` is not
    enforced at instantiation time; the default bodies raise
    NotImplementedError when a subclass forgets to override them.
    """

    @abstractmethod
    def evaluate(self, *args, **kwargs):
        """Accumulate statistics for one batch; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def get_metric(self, reset=True):
        """Return the final scores; must be overridden.

        :param reset: whether the accumulated statistics should be cleared
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it would fail with TypeError instead.
        raise NotImplementedError

    def __call__(self, p_ids, pred, eval_file):
        """Shorthand for ``self.evaluate(p_ids, pred, eval_file)``."""
        return self.evaluate(p_ids, pred, eval_file)
|
||||
|
||||
|
||||
def _bmeso_tag_to_spans(tags, ignore_labels=None):
|
||||
"""
|
||||
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。
|
||||
返回[('singer', (1, 4))] (左闭右开区间)
|
||||
|
||||
:param tags: List[str],
|
||||
:param ignore_labels: List[str], 在该list中的label将被忽略
|
||||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
|
||||
"""
|
||||
ignore_labels = set(ignore_labels) if ignore_labels else set()
|
||||
|
||||
spans = []
|
||||
prev_bmes_tag = None
|
||||
for idx, tag in enumerate(tags):
|
||||
tag = tag.lower()
|
||||
bmes_tag, label = tag[:1], tag[2:]
|
||||
if bmes_tag in ('b', 's'):
|
||||
spans.append((label, [idx, idx]))
|
||||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
|
||||
spans[-1][1][1] = idx
|
||||
elif bmes_tag == 'o':
|
||||
pass
|
||||
else:
|
||||
spans.append((label, [idx, idx]))
|
||||
prev_bmes_tag = bmes_tag
|
||||
return [(span[0], (span[1][0], span[1][1] + 1))
|
||||
for span in spans
|
||||
if span[0] not in ignore_labels
|
||||
]
|
||||
|
||||
|
||||
class SpanFPreRecMetric(MetricBase):
    """
    Span-level F-score / precision / recall for sequence labelling.

    ``evaluate`` accumulates per-label true-positive, false-positive and
    false-negative span counts; ``get_metric`` turns them into scores.
    """

    def __init__(self, tag_type, pred=None, target=None, encoding_type='bmeso',
                 only_gross=True, f_type='micro', beta=1):
        """
        :param tag_type: mapping or sequence indexed by a predicted tag id,
            giving the tag string
        :param pred: unused, kept for interface compatibility
        :param target: unused, kept for interface compatibility
        :param encoding_type: tagging scheme; only 'bmeso' is supported
        :param only_gross: if False, per-label scores are reported as well
        :param f_type: 'micro' or 'macro' averaging
        :param beta: weight of recall relative to precision in the F-score
        :raises ValueError: if encoding_type is not supported
        """
        self.tag_type = tag_type
        self.only_gross = only_gross
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.encoding_type = encoding_type
        if self.encoding_type == 'bmeso':
            self.tag_to_span_func = _bmeso_tag_to_spans
        else:
            # Fail early instead of hitting an AttributeError on the first
            # evaluate() call because tag_to_span_func was never set.
            raise ValueError('Unsupported encoding_type: {}'.format(encoding_type))

        self._true_positives = defaultdict(int)
        self._false_positives = defaultdict(int)
        self._false_negatives = defaultdict(int)

    def evaluate(self, p_ids, preds, eval_file):
        """
        Accumulate span statistics for one batch.

        :param p_ids: object exposing ``tolist()`` that yields sample ids
        :param preds: object exposing ``tolist()`` that yields, per sample,
            a list of predicted tag ids
        :param eval_file: mapping from sample id to an object whose
            ``gold_answer`` attribute holds the gold tag strings
        :return: dict mapping str(sample id) to [pred_spans, gold_spans]
        """
        answer_dict = {}
        for p_id, pred in zip(p_ids.tolist(), preds.tolist()):
            gold_ = eval_file[p_id].gold_answer
            pred_ = [self.tag_type[tag] for tag in pred]
            pred_spans = self.tag_to_span_func(pred_)
            gold_spans = self.tag_to_span_func(gold_)
            # Store a copy: gold_spans is consumed below while matching, and
            # the reported answer must keep the complete gold span list.
            answer_dict[str(p_id)] = [pred_spans, list(gold_spans)]
            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    # Each gold span may be matched at most once.
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            # Whatever is left in gold_spans was never predicted.
            for span in gold_spans:
                self._false_negatives[span[0]] += 1
        return answer_dict

    def get_metric(self, reset=True):
        """
        Compute the final scores from the statistics accumulated by evaluate.

        :param reset: if True, the accumulated statistics are cleared afterwards
        :return: dict of scores rounded to 6 decimals. Always contains 'f',
            'pre' and 'rec' for f_type 'micro' or 'macro'; 'micro'
            additionally reports 'em' (matched span count), 'pre_num'
            (predicted span count) and 'gold' (gold span count); per-label
            'f-<label>', 'pre-<label>' and 'rec-<label>' are added when
            only_gross is False.
        """
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives.keys())
            tags.update(set(self._false_positives.keys()))
            tags.update(set(self._true_positives.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives[tag]
                fn = self._false_negatives[tag]
                fp = self._false_positives[tag]
                # Only the three scores are needed here; the helper also
                # returns the raw counts.
                f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)[:3]
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # skip spans that carry no label
                    f_key = 'f-{}'.format(tag)
                    pre_key = 'pre-{}'.format(tag)
                    rec_key = 'rec-{}'.format(tag)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                # Avoid dividing by zero when no span has been seen at all.
                tag_count = len(tags) or 1
                evaluate_result['f'] = f_sum / tag_count
                evaluate_result['pre'] = pre_sum / tag_count
                evaluate_result['rec'] = rec_sum / tag_count

        if self.f_type == 'micro':
            f, pre, rec, em, pre_num, gold_num = self._compute_f_pre_rec(
                sum(self._true_positives.values()),
                sum(self._false_negatives.values()),
                sum(self._false_positives.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec
            evaluate_result['em'] = em
            # Stored under its own key so it no longer overwrites precision.
            evaluate_result['pre_num'] = pre_num
            evaluate_result['gold'] = gold_num

        if reset:
            self._true_positives = defaultdict(int)
            self._false_positives = defaultdict(int)
            self._false_negatives = defaultdict(int)

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)
        print(evaluate_result)
        return evaluate_result

    def _compute_f_pre_rec(self, tp, fn, fp):
        """
        Compute the F-score, precision and recall from raw counts.

        :param tp: int, true positive
        :param fn: int, false negative
        :param fp: int, false positive
        :return: (f, pre, rec, tp, fp + tp, fn + tp), i.e. the three scores
            followed by the matched, predicted and gold span counts
        """
        # The small epsilon keeps the divisions defined when a count is zero.
        pre = tp / (fp + tp + 1e-13)
        rec = tp / (fn + tp + 1e-13)
        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)

        return f, pre, rec, tp, fp + tp, fn + tp
|
0
utils/ner_loader.py
Normal file
0
utils/ner_loader.py
Normal file
@ -4,7 +4,7 @@ from layers.encoders.transformers.bert.bert_optimization import BertAdam
|
||||
|
||||
|
||||
def set_optimizer(args, model, train_steps=None):
|
||||
if args.use_bert:
|
||||
if args.warm_up:
|
||||
print('using BertAdam')
|
||||
param_optimizer = list(model.named_parameters())
|
||||
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
|
||||
|
Loading…
Reference in New Issue
Block a user