Add files via upload

missQian 2020-10-04 21:55:03 +08:00 committed by GitHub
parent ce53f4c43a
commit 1eaf0b33f9
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 983 additions and 0 deletions

@@ -0,0 +1,280 @@
# -*- coding: utf-8 -*-
import numpy as np
import json
class Actions(object):
trigger_gen = 'TRIGGER-GEN-'
entity_shift = 'ENTITY-SHIFT'
entity_gen = 'ENTITY-GEN-'
entity_back = 'ENTITY-BACK'
o_delete = 'O-DELETE'
event_gen = 'EVENT-GEN-'
shift = 'SHIFT'
no_pass = 'NO-PASS'
left_pass = 'LEFT-PASS'
right_pass = 'RIGHT-PASS'
back_shift = 'DUAL-SHIFT'
# ------------------------------
copy_shift = 'COPY-SHIFT'
# ------------------------------
# event_shift = 'EVENT-SHIFT-'
# event_reduce = 'EVENT-REDUCE-'
# no_reduce = 'NO-REDUCE'
_PASS_PLACEHOLDER = 'PASS'
_SHIFT_PLACEHOLDER = 'SHIFT'
def __init__(self, action_dict, ent_dict, tri_dict, arg_dict, with_copy_shift=True):
self.entity_shift_id = action_dict[Actions.entity_shift]
self.entity_back_id = action_dict[Actions.entity_back]
self.o_del_id = action_dict[Actions.o_delete]
self.shift_id = action_dict[Actions.shift]
self.left_pass_id = action_dict[Actions.left_pass]
self.right_pass_id = action_dict[Actions.right_pass]
if with_copy_shift:
self.copy_shift_id = action_dict[Actions.copy_shift]
else:
self.back_shift_id = action_dict[Actions.back_shift]
self.no_pass_id = action_dict[Actions.no_pass]
self.ent_gen_group = set()
self.tri_gen_group = set()
#self.event_reduce_group = set()
#self.event_shift_group = set()
self.event_gen_group = set()
self.act_to_ent_id = {}
self.act_to_tri_id = {}
self.act_to_arg_id = {}
self.arg_to_act_id = {}
self.act_id_to_str = {v:k for k, v in action_dict.items()}
for name, id in action_dict.items():
if name.startswith(Actions.entity_gen):
self.ent_gen_group.add(id)
self.act_to_ent_id[id] = ent_dict[name[len(Actions.entity_gen):]]
elif name.startswith(Actions.trigger_gen):
self.tri_gen_group.add(id)
self.act_to_tri_id[id] = tri_dict[name[len(Actions.trigger_gen):]]
elif name.startswith(Actions.event_gen):
self.event_gen_group.add(id)
self.act_to_arg_id[id] = arg_dict[name[len(Actions.event_gen):]]
for k,v in self.act_to_arg_id.items():
self.arg_to_act_id[v] = k
def get_act_ids_by_args(self, arg_type_ids):
acts = []
for arg_id in arg_type_ids:
acts.append(self.arg_to_act_id[arg_id])
return acts
def get_ent_gen_list(self):
return list(self.ent_gen_group)
def get_tri_gen_list(self):
return list(self.tri_gen_group)
def get_event_gen_list(self):
return list(self.event_gen_group)
def to_act_str(self, act_id):
return self.act_id_to_str[act_id]
def to_ent_id(self, act_id):
return self.act_to_ent_id[act_id]
def to_tri_id(self, act_id):
return self.act_to_tri_id[act_id]
def to_arg_id(self, act_id):
return self.act_to_arg_id[act_id]
# action check
def is_ent_shift(self, act_id):
return self.entity_shift_id == act_id
def is_ent_back(self, act_id):
return self.entity_back_id == act_id
def is_o_del(self, act_id):
return self.o_del_id == act_id
def is_shift(self, act_id):
return self.shift_id == act_id
def is_back_shift(self, act_id):
return self.back_shift_id == act_id
def is_copy_shift(self, act_id):
return self.copy_shift_id == act_id
def is_no_pass(self, act_id):
return self.no_pass_id == act_id
def is_left_pass(self, act_id):
return self.left_pass_id == act_id
def is_right_pass(self, act_id):
return self.right_pass_id == act_id
def is_ent_gen(self, act_id):
return act_id in self.ent_gen_group
def is_tri_gen(self, act_id):
return act_id in self.tri_gen_group
def is_event_gen(self, act_id):
return act_id in self.event_gen_group
@staticmethod
def make_oracle(tokens, triggers, ents, args, with_copy_shift=True):
'''
In this dataset there are no nested entities sharing a common start index,
so we push the words of an entity e back onto the buffer, excluding the
first word of e.
# TODO with_copy_shift
trigger_list : [(idx, event_type), ...] e.g. [(27, '500'), ...]
ent_list : [[start, end, ent_type, coref_ref], ...] e.g. [[3, 3, '402', ...], ...] (coref_ref is the coreference id unpacked below)
arg_list : [[arg_start, arg_end, trigger_idx, role_type], ...] e.g. [[21, 21, 27, 'Vehicle'], ...]
'''
ent_dic = {ent[0]:ent for ent in ents}
trigger_dic = {tri[0]:tri[1] for tri in triggers}
# (tri_idx, arg_start_idx)
arg_dic = {(arg[2], arg[0]):arg for arg in args}
# for tri in trigger_dic.keys():
# if tri in ent_dic:
# print(tri,'======', ent_dic[tri])
actions = []
# GEN entities and triggers
tri_actions = {} # start_idx : actions list
ent_actions = {} # start_idx : actions list
for tri in triggers:
idx, event_type = tri
tri_actions[idx] = [Actions.trigger_gen + event_type]
for ent in ents:
start, end, ent_type, ref = ent
act = []
for _ in range(start, end + 1):
act.append(Actions.entity_shift)
act.append(Actions.entity_gen + ent_type)
act.append(Actions.entity_back)
ent_actions[start] = act
for tri_i in trigger_dic:
cur_actions = tri_actions[tri_i]
for j in range(tri_i - 1, -1, -1):
if j in trigger_dic:
cur_actions.append(Actions.no_pass)
if j in ent_dic:
key = (tri_i, j)
if key in arg_dic:
arg_start, arg_end, trigger_idx, role_type = arg_dic[key]
cur_actions.append(Actions.left_pass)
#cur_actions.append(Actions.event_gen + role_type)
else:
cur_actions.append(Actions.no_pass)
if tri_i in ent_dic:
if with_copy_shift:
cur_actions.append(Actions.copy_shift)
else:
cur_actions.append(Actions.back_shift)
else:
cur_actions.append(Actions.shift)
for ent_i in ent_dic:
cur_actions = ent_actions[ent_i]
# Take into account that a word can be a trigger as well as an entity start
if with_copy_shift and ent_i in trigger_dic:
if (ent_i, ent_i) in arg_dic:
arg_start, arg_end, trigger_idx, role_type = arg_dic[(ent_i, ent_i)]
cur_actions.append(Actions.right_pass)
#cur_actions.append(Actions.event_gen + role_type)
else:
cur_actions.append(Actions.no_pass)
for j in range(ent_i - 1, -1, -1):
if j in trigger_dic:
key = (j, ent_i)
if key in arg_dic:
arg_start, arg_end, trigger_idx, role_type = arg_dic[key]
cur_actions.append(Actions.right_pass)
#cur_actions.append(Actions.event_gen + role_type)
else:
cur_actions.append(Actions.no_pass)
if j in ent_dic:
cur_actions.append(Actions.no_pass)
cur_actions.append(Actions.shift)
#print(tri_actions)
#print(ent_actions)
for i in range(len(tokens)):
is_ent_or_tri = False
if i in tri_actions:
actions += tri_actions[i]
is_ent_or_tri = True
if i in ent_actions:
actions += ent_actions[i]
is_ent_or_tri = True
if not is_ent_or_tri:
actions.append(Actions.o_delete)
return actions #, tri_actions, ent_actions
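# --- A minimal usage sketch (not part of the original file). The token,
# trigger, entity and argument values below are hypothetical placeholders
# that follow the formats documented in make_oracle's docstring.
if __name__ == '__main__':
    toy_tokens = ['He', 'was', 'arrested', 'in', 'Baghdad']
    toy_triggers = [(2, 'Justice:Arrest-Jail')]              # (idx, event_type)
    toy_ents = [[0, 0, 'PER', 'E1'], [4, 4, 'GPE', 'E2']]    # [start, end, ent_type, coref_ref]
    toy_args = [[0, 0, 2, 'Person'], [4, 4, 2, 'Place']]     # [start, end, trigger_idx, role_type]
    oracle = Actions.make_oracle(toy_tokens, toy_triggers, toy_ents, toy_args)
    print(oracle)
    # The output is a flat list of transition strings such as
    # ENTITY-SHIFT, ENTITY-GEN-PER, ENTITY-BACK, ..., TRIGGER-GEN-..., LEFT-PASS, SHIFT, O-DELETE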

@@ -0,0 +1,34 @@
data_dir: 'data_files/'
ace05_event_dir: 'data_files/samples/'
vocab_dir: 'data_files/vocab/'
token_vocab_file: 'token_vocab.txt'
char_vocab_file: 'char_vocab.txt'
ent_type_vocab_file: 'ent_type_vocal.txt'
ent_ref_vocab_file: 'ent_ref_vocab.txt' # co-reference
tri_type_vocab_file: 'tri_type_vocab.txt'
arg_type_vocab_file: 'arg_type_vocal.txt'
action_vocab_file: 'action_vocab.txt'
pos_vocab_file: 'pos_vocab.txt'
#deptype_vocab_file: 'deptype_vocab.txt'
#nertag_vocab_file: 'nertag_vocab.txt'
#rel_vocab_file: 'rel_type_vocab.txt'
pickle_dir: 'data_files/pickle/'
vec_npy: 'data_files/pickle/word_vec.npy'
inst_pl_file: 'data_files/pickle/data.pl'
model_save_file: 'data_files/saved_models/model.ckpt'
train_sent_file: 'data_files/bert_emb/train_bert.npy'
dev_sent_file: 'data_files/bert_emb/dev_bert.npy'
test_sent_file: 'data_files/bert_emb/test_bert.npy'
embedding_dir: 'data_files/glove_emb/'
embedding_file: 'glove.6B.100d.txt'
embedding_type: 'glove'
normalize_digits: false
lower_case: false

@@ -0,0 +1,285 @@
import os
import re
import numpy as np
import dynet as dy
# code adopted from https://github.com/neulab/xnmt/blob/master/xnmt/param_collection.py
class ParamManager(object):
"""
A static class that manages the currently loaded DyNet parameters of all components.
Responsibilities are registering of all components that use DyNet parameters and loading pretrained parameters.
Components can register parameters by calling ParamManager.my_params(self) from within their __init__() method.
This allocates a subcollection with a unique identifier for this component. When loading previously saved parameters,
one or several paths are specified to look for the corresponding saved DyNet collection named after this identifier.
"""
initialized = False
@staticmethod
def init_param_col() -> None:
"""
Initializes or resets the parameter collection.
This must be invoked before every time a new model is loaded (e.g. on startup and between consecutive experiments).
"""
ParamManager.param_col = ParamCollection()
ParamManager.load_paths = []
ParamManager.initialized = True
# @staticmethod
# def set_save_file(file_name: str, save_num_checkpoints: int=1) -> None:
# assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
# ParamManager.param_col.model_file = file_name
# ParamManager.param_col.save_num_checkpoints = save_num_checkpoints
@staticmethod
def add_load_path(data_file: str) -> None:
"""
Add new data directory path to load from.
When calling populate(), pretrained parameters from all directories added in this way are searched for the
requested component identifiers.
Args:
data_file: a data directory (usually named ``*.data``) containing DyNet parameter collections.
"""
assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
if not data_file in ParamManager.load_paths: ParamManager.load_paths.append(data_file)
@staticmethod
def populate() -> None:
"""
Populate the parameter collections.
Searches the given data paths and loads parameter collections if they exist, otherwise leave parameters in their
randomly initialized state.
"""
assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
populated_subcols = []
for subcol_name in ParamManager.param_col.subcols:
for load_path in ParamManager.load_paths:
data_file = os.path.join(load_path, subcol_name)
if os.path.isfile(data_file):
ParamManager.param_col.load_subcol_from_data_file(subcol_name, data_file)
populated_subcols.append(subcol_name)
if len(ParamManager.param_col.subcols) == len(populated_subcols):
print(f"> populated DyNet weights of all components from given data files")
elif len(populated_subcols)==0:
print(f"> use randomly initialized DyNet weights of all components")
else:
print(f"> populated a subset of DyNet weights from given data files: {populated_subcols}.\n"
f" Did not populate {ParamManager.param_col.subcols.keys() - set(populated_subcols)}.\n"
f" If partial population was not intended, likely the unpopulated component or its owner"
f" does not adhere to the Serializable protocol correctly, see documentation:\n"
f" http://xnmt.readthedocs.io/en/latest/writing_xnmt_classes.html#using-serializable-subcomponents")
print(f" DyNet param count: {ParamManager.param_col._param_col.parameter_count()}")
@staticmethod
def my_params(subcol_owner) -> dy.ParameterCollection:
"""Creates a dedicated parameter subcollection for a serializable object.
This should only be called from the __init__ method of a Serializable.
Args:
subcol_owner (Serializable): The object which is requesting to be assigned a subcollection.
Returns:
The assigned subcollection.
"""
assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
assert not getattr(subcol_owner, "init_completed", False), \
f"my_params(obj) cannot be called after obj.__init__() has completed. Conflicting obj: {subcol_owner}"
if not hasattr(subcol_owner, "xnmt_subcol_name"):
raise ValueError(f"{subcol_owner} does not have an attribute 'xnmt_subcol_name'.\n"
f"Did you forget to wrap the __init__() in @serializable_init ?")
subcol_name = subcol_owner.xnmt_subcol_name
subcol = ParamManager.param_col.add_subcollection(subcol_owner, subcol_name)
subcol_owner.save_processed_arg("xnmt_subcol_name", subcol_name)
return subcol
@staticmethod
def global_collection() -> dy.ParameterCollection:
""" Access the top-level parameter collection, including all parameters.
Returns:
top-level DyNet parameter collection
"""
assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
return ParamManager.param_col._param_col
class ParamCollection(object):
def __init__(self):
self.reset()
def reset(self):
self._save_num_checkpoints = 1
self._model_file = None
self._param_col = dy.Model()
self._is_saved = False
self.subcols = {}
self.all_subcol_owners = set()
@property
def save_num_checkpoints(self):
return self._save_num_checkpoints
@save_num_checkpoints.setter
def save_num_checkpoints(self, value):
self._save_num_checkpoints = value
self._update_data_files()
@property
def model_file(self):
return self._model_file
@model_file.setter
def model_file(self, value):
self._model_file = value
self._update_data_files()
def _update_data_files(self):
if self._save_num_checkpoints>0 and self._model_file:
self._data_files = [self.model_file + '.data']
for i in range(1,self._save_num_checkpoints):
self._data_files.append(self.model_file + '.data.' + str(i))
else:
self._data_files = []
def add_subcollection(self, subcol_owner, subcol_name):
assert subcol_owner not in self.all_subcol_owners
self.all_subcol_owners.add(subcol_owner)
assert subcol_name not in self.subcols
new_subcol = self._param_col.add_subcollection(subcol_name)
self.subcols[subcol_name] = new_subcol
return new_subcol
def load_subcol_from_data_file(self, subcol_name, data_file):
self.subcols[subcol_name].populate(data_file)
def save(self):
if not self._is_saved:
self._remove_existing_history()
self._shift_saved_checkpoints()
if not os.path.exists(self._data_files[0]):
os.makedirs(self._data_files[0])
for subcol_name, subcol in self.subcols.items():
subcol.save(os.path.join(self._data_files[0], subcol_name))
self._is_saved = True
def revert_to_best_model(self):
if not self._is_saved:
raise ValueError("revert_to_best_model() is illegal because this model has never been saved.")
for subcol_name, subcol in self.subcols.items():
subcol.populate(os.path.join(self._data_files[0], subcol_name))
def _remove_existing_history(self):
for fname in self._data_files:
if os.path.exists(fname):
self._remove_data_dir(fname)
def _remove_data_dir(self, data_dir):
assert data_dir.endswith(".data") or data_dir.split(".")[-2] == "data"
try:
dir_contents = os.listdir(data_dir)
for old_file in dir_contents:
spl = old_file.split(".")
# make sure we're only deleting files with the expected filenames
if len(spl)==2:
if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", spl[0]):
if re.match(r"^[0-9a-f]{8}$", spl[1]):
os.remove(os.path.join(data_dir, old_file))
except NotADirectoryError:
os.remove(data_dir)
def _shift_saved_checkpoints(self):
if os.path.exists(self._data_files[-1]):
self._remove_data_dir(self._data_files[-1])
for i in range(len(self._data_files)-1)[::-1]:
if os.path.exists(self._data_files[i]):
os.rename(self._data_files[i], self._data_files[i+1])
class Optimizer(object):
"""
A base classe for trainers. Trainers are mostly simple wrappers of DyNet trainers but can add extra functionality.
Args:
optimizer: the underlying DyNet optimizer (trainer)
skip_noisy: keep track of a moving average and a moving standard deviation of the log of the gradient norm
values, and abort a step if the norm of the gradient exceeds four standard deviations of the
moving average. Reference: https://arxiv.org/pdf/1804.09849.pdf
"""
def __init__(self, optimizer: dy.Trainer) -> None:
self.optimizer = optimizer
def update(self) -> None:
self.optimizer.update()
def status(self):
"""
Outputs information about the trainer in the stderr.
(number of updates since last call, number of clipped gradients, learning rate, etc)
"""
return self.optimizer.status()
def set_clip_threshold(self, thr):
"""
Set clipping thershold
To deactivate clipping, set the threshold to be <=0
Args:
thr (number): Clipping threshold
"""
return self.optimizer.set_clip_threshold(thr)
def get_clip_threshold(self):
"""
Get clipping threshold
Returns:
number: Gradient clipping threshold
"""
return self.optimizer.get_clip_threshold()
def restart(self):
"""
Restarts the optimizer
Clears all momentum values and assimilate (if applicable)
"""
return self.optimizer.restart()
@property
def learning_rate(self):
return self.optimizer.learning_rate
@learning_rate.setter
def learning_rate(self, value):
self.optimizer.learning_rate = value
class AdamTrainer(Optimizer):
"""
Adam optimizer
The Adam optimizer is similar to RMSProp but uses unbiased estimates of the first and second moments of the gradient
Args:
alpha (number): Initial learning rate
beta_1 (number): Moving average parameter for the mean
beta_2 (number): Moving average parameter for the variance
eps (number): Epsilon parameter to prevent numerical instability
skip_noisy: keep track of a moving average and a moving standard deviation of the log of the gradient norm
values, and abort a step if the norm of the gradient exceeds four standard deviations of the
moving average. Reference: https://arxiv.org/pdf/1804.09849.pdf
"""
yaml_tag = '!AdamTrainer'
def __init__(self, alpha=0.001, beta_1=0.9, beta_2=0.999, eps=1e-8, update_every: int = 1, skip_noisy: bool = False):
super().__init__(optimizer=dy.AdamTrainer(ParamManager.global_collection(), alpha, beta_1, beta_2, eps))
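# --- A hedged usage sketch (not part of the original file): the intended
# ParamManager + trainer flow. _ToyComponent is a hypothetical stand-in for a
# Serializable component; it only provides the two hooks that my_params()
# relies on (an xnmt_subcol_name attribute and a save_processed_arg() method).
if __name__ == '__main__':
    ParamManager.init_param_col()

    class _ToyComponent(object):
        def __init__(self):
            self.xnmt_subcol_name = 'ToyComponent.00000000'
            subcol = ParamManager.my_params(self)          # dedicated parameter subcollection
            self.W = subcol.add_parameters((2, 4))         # a 2x4 weight matrix

        def save_processed_arg(self, key, value):
            pass  # the real Serializable records this for later re-serialization

    component = _ToyComponent()
    trainer = AdamTrainer(alpha=0.001)

    dy.renew_cg()
    W = dy.parameter(component.W)                          # load the parameter into the graph
    x = dy.inputVector([1.0, 2.0, 3.0, 4.0])
    loss = dy.squared_norm(W * x)
    loss.value()                                           # forward pass
    loss.backward()
    trainer.update()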

@@ -0,0 +1,72 @@
import os
from vocab import Vocab
from io_utils import read_yaml, read_lines, read_json_lines
from str_utils import capitalize_first_char, normalize_tok, normalize_sent, collapse_role_type
class EventConstraint(object):
'''
This class is used to make sure that (entity type, event type) -> argument role combinations obey the event schema constraints.
'''
def __init__(self, ent_dict, tri_dict, arg_dict):
constraint_file = './data_files/argrole_dict.txt'
self.constraint_list = [] # [(ent_type, tri_type, arg_type)]
for line in read_lines(constraint_file):
line = str(line).lower()
arr = line.split()
arg_type = arr[0]
for pair in arr[1:]:
pair_arr = pair.split(',')
tri_type = pair_arr[0]
ent_type = pair_arr[1]
ent_type = self._replace_ent(ent_type)
self.constraint_list.append((ent_type, tri_type, arg_type))
print('Event constraint size:',len(self.constraint_list))
# { (ent_type, tri_type) : (arg_type1, ...)}
self.ent_tri_to_arg_hash = {}
for cons in self.constraint_list:
ent_id = ent_dict[cons[0]]
tri_id = tri_dict[cons[1]]
arg_id = arg_dict[cons[2]]
# ent_id = cons[0]
# tri_id = cons[1]
# arg_id = cons[2]
if (ent_id, tri_id) not in self.ent_tri_to_arg_hash:
self.ent_tri_to_arg_hash[(ent_id, tri_id)] = set()
self.ent_tri_to_arg_hash[(ent_id, tri_id)].add(arg_id)
#print(self.ent_tri_to_arg_hash)
# single = 0
# for key, val in self.ent_tri_to_arg_hash.items():
# if len(val) == 1:
# single += 1
# print(single)
def _replace_ent(self, ent_type):
if ent_type == 'time':
return 'tim'
if ent_type == 'value':
return 'val'
return ent_type
def check_constraint(self, ent_type, tri_type, arg_type):
if (ent_type, tri_type, arg_type) in self.constraint_list:
return True
else:
return False
def get_constraint_arg_types(self, ent_type_id, tri_type_id):
return self.ent_tri_to_arg_hash.get((ent_type_id, tri_type_id), None)
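# --- A hedged usage sketch (not part of the original file). It assumes
# data_files/argrole_dict.txt is present. In the real pipeline the three
# dictionaries are the entity/trigger/argument vocabularies; _AutoId below is
# a stand-in that assigns a fresh id to every unseen type name so the sketch
# covers whatever types appear in the constraint file. The type-name strings
# used here are placeholders, so the printed results depend on that file.
if __name__ == '__main__':
    class _AutoId(dict):
        def __missing__(self, key):
            self[key] = len(self)
            return self[key]

    ent_dict, tri_dict, arg_dict = _AutoId(), _AutoId(), _AutoId()
    ec = EventConstraint(ent_dict, tri_dict, arg_dict)
    # check_constraint works on lower-cased type-name strings, while
    # get_constraint_arg_types works on vocabulary ids.
    print(ec.check_constraint('per', 'attack', 'place'))
    print(ec.get_constraint_arg_types(ent_dict['per'], tri_dict['attack']))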

@@ -0,0 +1,253 @@
from vocab import Vocab
def to_set(input):
out_set = set()
out_type_set = set()
for x in input:
out_set.add(tuple(x[:-1]))
out_type_set.add(tuple(x))
return out_set, out_type_set
class EventEval(object):
def __init__(self):
self.reset()
def reset(self):
self.correct_ent = 0.
self.correct_ent_with_type = 0.
self.num_pre_ent = 0.
self.num_gold_ent = 0.
self.correct_trigger = 0.
self.correct_trigger_with_type = 0.
self.num_pre_trigger = 0.
self.num_gold_trigger = 0.
self.correct_arg = 0.
self.correct_arg_with_role = 0.
self.num_pre_arg_no_type = 0.
self.num_gold_arg_no_type = 0.
self.num_pre_arg = 0.
self.num_gold_arg = 0.
# ------------------------------
self.num_tri_error = 0
self.num_ent_bound_error = 0
self.num_arg_error = 0
self.num_arg_error_with_role = 0
self.num_ent_not_in_arg_error = 0
self.num_tri_type_not_in_arg_error = 0
self.tri_type_error_count = {}
self.arg_type_error_count = {}
self.arg_error_chunk = {}
def get_coref_ent(self, g_ent_typed):
ent_ref_dict = {}
for ent1 in g_ent_typed:
start1, end1, ent_type1, ent_ref1 = ent1
coref_ents = []
ent_ref_dict[(start1, end1)] = coref_ents
for ent2 in g_ent_typed:
start2, end2, ent_type2, ent_ref2 = ent2
if ent_ref1 == ent_ref2:
coref_ents.append((start2, end2))
return ent_ref_dict
def split_prob(self, pred_args):
sp_args, probs = [], []
for arg in pred_args:
sp_args.append(arg[:-1])
probs.append(arg[-1])
return sp_args, probs
def update(self, pred_ents, gold_ents, pred_triggers, gold_triggers, pred_args, gold_args, eval_arg=True, words=None):
ent_ref_dict = self.get_coref_ent(gold_ents)
p_ent, p_ent_typed = to_set(pred_ents)
p_ent_to_type_dic = {(s,e):t for s,e,t in p_ent_typed}
g_ent, g_ent_typed = to_set([ent[:-1] for ent in gold_ents])
p_tri, p_tri_typed = to_set(pred_triggers)
g_tri, g_tri_typed = to_set(gold_triggers)
p_args, p_args_typed = to_set(pred_args)
g_args, g_args_typed = to_set(gold_args)
#p_args_typed = {(arg[2],arg[3]) for arg in p_args_typed}
#g_args_typed = {(arg[2], arg[3]) for arg in g_args_typed}
self.num_pre_ent += len(p_ent)
self.num_gold_ent += len(g_ent)
self.correct_ent += len(p_ent & g_ent)
self.correct_ent_with_type += len(p_ent_typed & g_ent_typed)
self.num_pre_trigger += len(p_tri)
self.num_gold_trigger += len(g_tri)
c_tri = p_tri & g_tri
c_tri_typed = p_tri_typed & g_tri_typed
self.correct_trigger += len(c_tri)
self.correct_trigger_with_type += len(c_tri_typed)
if not eval_arg:
return
c_tri_typed_indices = {tri[0] for tri in c_tri_typed}
p_tri_dic = {tri[0]: tri[1] for tri in p_tri_typed}
g_tri_dic = {tri[0]: tri[1] for tri in g_tri_typed}
p_arg_mention = {(arg[0], arg[1], p_tri_dic[arg[2]]) for arg in p_args_typed}
g_arg_mention = {(arg[0], arg[1], g_tri_dic[arg[2]]) for arg in g_args_typed}
p_arg_mention_typed = {(arg[0], arg[1], p_tri_dic[arg[2]], arg[3]) for arg in p_args_typed}
g_arg_mention_typed = {(arg[0], arg[1], g_tri_dic[arg[2]], arg[3]) for arg in g_args_typed}
self.num_pre_arg_no_type += len(p_arg_mention)
self.num_gold_arg_no_type += len(g_arg_mention)
self.num_pre_arg += len(p_arg_mention_typed)
self.num_gold_arg += len(g_arg_mention_typed)
for p_arg in p_arg_mention:
p_start, p_end, p_tri_type = p_arg
# if p_tri_idx not in c_tri_typed_indices:
# continue
if (p_start, p_end) not in ent_ref_dict:
continue
for coref_ent in ent_ref_dict[(p_start, p_end)]:
if (coref_ent[0], coref_ent[1], p_tri_type) in g_arg_mention:
self.correct_arg += 1
break
else:
self.num_arg_error += 1
# for p_arg in p_args_typed:
# #start, end, tri_idx, tri_type = p_arg
# tri_idx = p_arg[-2]
# if p_arg in g_args_typed and tri_idx in c_tri_typed_indices:
# self.correct_arg_with_role += 1
for num, p_arg in enumerate(p_arg_mention_typed):
p_start, p_end, p_tri_type, p_role_type = p_arg
p_ent_type = p_ent_to_type_dic[(p_start, p_end)]
#tri_idx = p_arg[-2]
# if p_tri_idx not in c_tri_typed_indices:
# self.num_tri_error += 1
# continue
if (p_start, p_end) not in ent_ref_dict:
self.num_ent_bound_error += 1
continue
for coref_ent in ent_ref_dict[(p_start, p_end)]:
if (coref_ent[0], coref_ent[1], p_tri_type, p_role_type) in g_arg_mention_typed:
self.correct_arg_with_role += 1
break
else:
self.num_arg_error_with_role += 1
for coref_ent in ent_ref_dict[(p_start, p_end)]:
has_ent = False
for g_arg in g_arg_mention_typed:
if coref_ent[0]==g_arg[0] and coref_ent[1]==g_arg[1]:
has_ent = True
break
if has_ent:
break
else:
self.num_ent_not_in_arg_error += 1
chunk = tuple([words[i].lower() for i in range(p_start, p_end+1)])
# print(' '.join(words))
# print(chunk, p_tri_type, p_role_type)
# print(p_arg_mention_typed)
# print(g_arg_mention_typed)
# print('==',[(words[tri[0]], tri[1]) for tri in pred_triggers])
# print('==',[(words[tri[0]], tri[1]) for tri in gold_triggers])
if chunk not in self.arg_error_chunk:
self.arg_error_chunk[chunk] = 1
else:
self.arg_error_chunk[chunk] += 1
key = p_ent_type
if key not in self.arg_type_error_count:
self.arg_type_error_count[key] = 1
else:
self.arg_type_error_count[key] += 1
for g_arg in g_arg_mention_typed:
if p_tri_type == g_arg[2]:
break
else:
self.num_tri_type_not_in_arg_error += 1
key = (p_tri_type, p_ent_type)
if key not in self.tri_type_error_count:
self.tri_type_error_count[key] = 1
else:
self.tri_type_error_count[key] += 1
def report(self):
p_ent = self.correct_ent / (self.num_pre_ent + 1e-18)
r_ent = self.correct_ent / (self.num_gold_ent + 1e-18)
f_ent = 2 * p_ent * r_ent / (p_ent + r_ent + 1e-18)
p_ent_typed = self.correct_ent_with_type / (self.num_pre_ent + 1e-18)
r_ent_typed = self.correct_ent_with_type / (self.num_gold_ent + 1e-18)
f_ent_typed = 2 * p_ent_typed * r_ent_typed / (p_ent_typed + r_ent_typed + 1e-18)
p_tri = self.correct_trigger / (self.num_pre_trigger + 1e-18)
r_tri = self.correct_trigger / (self.num_gold_trigger + 1e-18)
f_tri = 2 * p_tri * r_tri / (p_tri + r_tri + 1e-18)
p_tri_typed = self.correct_trigger_with_type / (self.num_pre_trigger + 1e-18)
r_tri_typed = self.correct_trigger_with_type / (self.num_gold_trigger + 1e-18)
f_tri_typed = 2 * p_tri_typed * r_tri_typed / (p_tri_typed + r_tri_typed + 1e-18)
p_arg = self.correct_arg / (self.num_pre_arg_no_type + 1e-18)
r_arg = self.correct_arg / (self.num_gold_arg_no_type + 1e-18)
f_arg = 2 * p_arg * r_arg / (p_arg + r_arg + 1e-18)
p_arg_typed = self.correct_arg_with_role / (self.num_pre_arg + 1e-18)
r_arg_typed = self.correct_arg_with_role / (self.num_gold_arg + 1e-18)
f_arg_typed = 2 * p_arg_typed * r_arg_typed / (p_arg_typed + r_arg_typed + 1e-18)
print('num_pre_arg:', self.num_pre_arg)
print('num_gold_arg:', self.num_gold_arg)
print('correct_arg_with_role:', self.correct_arg_with_role)
print('num_tri_error:', self.num_tri_error)
print('num_ent_bound_error:', self.num_ent_bound_error)
print('num_arg_error:', self.num_arg_error)
print('num_arg_error_with_role:', self.num_arg_error_with_role)
print('num_ent_not_in_arg_error:', self.num_ent_not_in_arg_error)
print('num_tri_type_not_in_arg_error:', self.num_tri_type_not_in_arg_error)
# for tri_type, count in self.tri_type_error_count.items():
# print(tri_type, ' : ' ,count)
#
# print(sum(self.tri_type_error_count.values()))
#
# for arg_type, count in self.arg_type_error_count.items():
# print(arg_type, ' : ' ,count)
#
# for chunk, count in self.arg_error_chunk.items():
# print(chunk, count)
# print(len(self.arg_error_chunk))
return (p_ent, r_ent, f_ent), (p_ent_typed, r_ent_typed, f_ent_typed), \
(p_tri, r_tri, f_tri), (p_tri_typed, r_tri_typed, f_tri_typed), \
(p_arg, r_arg, f_arg), (p_arg_typed, r_arg_typed, f_arg_typed)
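# --- A hedged usage sketch (not part of the original file), with tiny
# hypothetical predictions. As implied by update() above, predicted entities
# are [start, end, ent_type], gold entities are [start, end, ent_type, coref_id],
# triggers are [idx, event_type], and arguments are [start, end, trigger_idx, role_type].
if __name__ == '__main__':
    toy_words = ['He', 'was', 'arrested', 'in', 'Baghdad']
    gold_ents = [[0, 0, 'PER', 'E1'], [4, 4, 'GPE', 'E2']]
    pred_ents = [[0, 0, 'PER'], [4, 4, 'GPE']]
    gold_triggers = [[2, 'Justice:Arrest-Jail']]
    pred_triggers = [[2, 'Justice:Arrest-Jail']]
    gold_args = [[0, 0, 2, 'Person'], [4, 4, 2, 'Place']]
    pred_args = [[0, 0, 2, 'Person']]

    evaluator = EventEval()
    evaluator.update(pred_ents, gold_ents, pred_triggers, gold_triggers,
                     pred_args, gold_args, words=toy_words)
    ent_prf, ent_typed_prf, tri_prf, tri_typed_prf, arg_prf, arg_typed_prf = evaluator.report()
    print('argument role (P, R, F1):', arg_typed_prf)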

@@ -0,0 +1,59 @@
from flair.data import Sentence
from flair.models import SequenceTagger
from flair.embeddings import CharLMEmbeddings, StackedEmbeddings, BertEmbeddings
import os
import pickle
import numpy as np
from io_utils import read_yaml, read_lines, read_json_lines
data_config = read_yaml('data_config.yaml')
data_dir = data_config['data_dir']
ace05_event_dir = data_config['ace05_event_dir']
train_list = read_json_lines(os.path.join(ace05_event_dir, 'train_nlp_ner.json'))
dev_list = read_json_lines(os.path.join(ace05_event_dir, 'dev_nlp_ner.json'))
test_list = read_json_lines(os.path.join(ace05_event_dir, 'test_nlp_ner.json'))
train_sent_file = data_config['train_sent_file']
bert = BertEmbeddings(layers='-1', bert_model_or_path='bert-base-uncased').to('cuda:0')
def save_bert(inst_list, filter_tri=True, name='train'):
sents = []
sent_lens = []
for inst in inst_list:
words, trigger_list, ent_list, arg_list = inst['nlp_words'], inst['Triggers'], inst['Entities'], inst['Arguments']
# Empirically filter out sentences that have no events and fewer than 3 entities (training only)
if len(trigger_list) == 0 and len(ent_list) < 3 and filter_tri: continue
sents.append(words)
sent_lens.append(len(words))
total_word_nums = sum(sent_lens)
input_table = np.empty((total_word_nums,768 * 1))
acc_len = 0
for i, words in enumerate(sents):
if i % 100 ==0:
print('progress: %d, %d'%(i, len(sents)))
sent_len = sent_lens[i]
flair_sent = Sentence(' '.join(words))
bert.embed(flair_sent)
for j, token in enumerate(flair_sent):
start = acc_len + j
input_table[start, :] = token.embedding.cpu().detach().numpy()
acc_len += sent_len
bert_fname = data_config['train_sent_file'] if name == 'train' else \
data_config['dev_sent_file'] if name == 'dev' else data_config['test_sent_file']
np.save(bert_fname, input_table)
print('total_word_nums:', total_word_nums)
#print(len(sent_lens))
if __name__ == "__main__":
save_bert(train_list, name='train')
save_bert(dev_list,filter_tri=False, name='dev')
save_bert(test_list,filter_tri=False, name='test')
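# --- A hedged sketch (not part of the original script): recovering
# per-sentence embeddings from a saved table. It assumes the same instance
# list and the same filtering rule used in save_bert above, so that row
# offsets line up with the concatenated sentences.
def load_sentence_embeddings(inst_list, bert_fname, filter_tri=True):
    table = np.load(bert_fname)                  # shape: (total_word_nums, 768)
    embeddings, offset = [], 0
    for inst in inst_list:
        words, trigger_list, ent_list = inst['nlp_words'], inst['Triggers'], inst['Entities']
        if len(trigger_list) == 0 and len(ent_list) < 3 and filter_tri:
            continue
        embeddings.append(table[offset: offset + len(words)])
        offset += len(words)
    return embeddings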