Add files via upload

This commit is contained in:
parent ce53f4c43a
commit 1eaf0b33f9

@@ -0,0 +1,280 @@
# -*- coding: utf-8 -*-

import numpy as np
import json


class Actions(object):

    trigger_gen = 'TRIGGER-GEN-'

    entity_shift = 'ENTITY-SHIFT'
    entity_gen = 'ENTITY-GEN-'
    entity_back = 'ENTITY-BACK'

    o_delete = 'O-DELETE'

    event_gen = 'EVENT-GEN-'

    shift = 'SHIFT'
    no_pass = 'NO-PASS'
    left_pass = 'LEFT-PASS'
    right_pass = 'RIGHT-PASS'

    back_shift = 'DUAL-SHIFT'
    # ------------------------------
    copy_shift = 'COPY-SHIFT'

    # ------------------------------
    # event_shift = 'EVENT-SHIFT-'
    # event_reduce = 'EVENT-REDUCE-'
    # no_reduce = 'NO-REDUCE'

    _PASS_PLACEHOLDER = 'PASS'
    _SHIFT_PLACEHOLDER = 'SHIFT'

    def __init__(self, action_dict, ent_dict, tri_dict, arg_dict, with_copy_shift=True):
        self.entity_shift_id = action_dict[Actions.entity_shift]
        self.entity_back_id = action_dict[Actions.entity_back]
        self.o_del_id = action_dict[Actions.o_delete]
        self.shift_id = action_dict[Actions.shift]
        self.left_pass_id = action_dict[Actions.left_pass]
        self.right_pass_id = action_dict[Actions.right_pass]
        if with_copy_shift:
            self.copy_shift_id = action_dict[Actions.copy_shift]
        else:
            self.back_shift_id = action_dict[Actions.back_shift]
        self.no_pass_id = action_dict[Actions.no_pass]
        self.ent_gen_group = set()
        self.tri_gen_group = set()
        # self.event_reduce_group = set()
        # self.event_shift_group = set()
        self.event_gen_group = set()

        self.act_to_ent_id = {}
        self.act_to_tri_id = {}
        self.act_to_arg_id = {}
        self.arg_to_act_id = {}

        self.act_id_to_str = {v: k for k, v in action_dict.items()}

        for name, id in action_dict.items():
            if name.startswith(Actions.entity_gen):
                self.ent_gen_group.add(id)
                self.act_to_ent_id[id] = ent_dict[name[len(Actions.entity_gen):]]

            elif name.startswith(Actions.trigger_gen):
                self.tri_gen_group.add(id)
                self.act_to_tri_id[id] = tri_dict[name[len(Actions.trigger_gen):]]

            elif name.startswith(Actions.event_gen):
                self.event_gen_group.add(id)
                self.act_to_arg_id[id] = arg_dict[name[len(Actions.event_gen):]]

        for k, v in self.act_to_arg_id.items():
            self.arg_to_act_id[v] = k

    def get_act_ids_by_args(self, arg_type_ids):
        acts = []
        for arg_id in arg_type_ids:
            acts.append(self.arg_to_act_id[arg_id])

        return acts

    def get_ent_gen_list(self):
        return list(self.ent_gen_group)

    def get_tri_gen_list(self):
        return list(self.tri_gen_group)

    def get_event_gen_list(self):
        return list(self.event_gen_group)

    def to_act_str(self, act_id):
        return self.act_id_to_str[act_id]

    def to_ent_id(self, act_id):
        return self.act_to_ent_id[act_id]

    def to_tri_id(self, act_id):
        return self.act_to_tri_id[act_id]

    def to_arg_id(self, act_id):
        return self.act_to_arg_id[act_id]

    # action check

    def is_ent_shift(self, act_id):
        return self.entity_shift_id == act_id

    def is_ent_back(self, act_id):
        return self.entity_back_id == act_id

    def is_o_del(self, act_id):
        return self.o_del_id == act_id

    def is_shift(self, act_id):
        return self.shift_id == act_id

    def is_back_shift(self, act_id):
        return self.back_shift_id == act_id

    def is_copy_shift(self, act_id):
        return self.copy_shift_id == act_id

    def is_no_pass(self, act_id):
        return self.no_pass_id == act_id

    def is_left_pass(self, act_id):
        return self.left_pass_id == act_id

    def is_right_pass(self, act_id):
        return self.right_pass_id == act_id

    def is_ent_gen(self, act_id):
        return act_id in self.ent_gen_group

    def is_tri_gen(self, act_id):
        return act_id in self.tri_gen_group

    def is_event_gen(self, act_id):
        return act_id in self.event_gen_group

    @staticmethod
    def make_oracle(tokens, triggers, ents, args, with_copy_shift=True):
        '''
        In this dataset, there are no nested entities sharing a common start idx;
        therefore, we push the words of an entity e back onto the buffer, excluding
        the first word of e.

        # TODO with_copy_shift
        trigger_list : [(idx, event_type), ...]                             e.g. [(27, '500'), ...]
        ent_list     : [[start, end, ent_type, coref_ref], ...]             e.g. [[3, 3, '402', ...], ...]
        arg_list     : [[arg_start, arg_end, trigger_idx, role_type], ...]  e.g. [[21, 21, 27, 'Vehicle'], ...]
        '''

        ent_dic = {ent[0]: ent for ent in ents}
        trigger_dic = {tri[0]: tri[1] for tri in triggers}
        # (tri_idx, arg_start_idx)
        arg_dic = {(arg[2], arg[0]): arg for arg in args}

        # for tri in trigger_dic.keys():
        #     if tri in ent_dic:
        #         print(tri, '======', ent_dic[tri])

        actions = []

        # GEN entities and triggers
        tri_actions = {}  # start_idx : actions list
        ent_actions = {}  # start_idx : actions list

        for tri in triggers:
            idx, event_type = tri
            tri_actions[idx] = [Actions.trigger_gen + event_type]

        for ent in ents:
            start, end, ent_type, ref = ent
            act = []
            for _ in range(start, end + 1):
                act.append(Actions.entity_shift)

            act.append(Actions.entity_gen + ent_type)
            act.append(Actions.entity_back)
            ent_actions[start] = act

        for tri_i in trigger_dic:
            cur_actions = tri_actions[tri_i]
            for j in range(tri_i - 1, -1, -1):

                if j in trigger_dic:
                    cur_actions.append(Actions.no_pass)

                if j in ent_dic:
                    key = (tri_i, j)
                    if key in arg_dic:
                        arg_start, arg_end, trigger_idx, role_type = arg_dic[key]
                        cur_actions.append(Actions.left_pass)
                        # cur_actions.append(Actions.event_gen + role_type)
                    else:
                        cur_actions.append(Actions.no_pass)

            if tri_i in ent_dic:
                if with_copy_shift:
                    cur_actions.append(Actions.copy_shift)
                else:
                    cur_actions.append(Actions.back_shift)
            else:
                cur_actions.append(Actions.shift)

        for ent_i in ent_dic:
            cur_actions = ent_actions[ent_i]
            # Take into account that a word can be a trigger as well as an entity start
            if with_copy_shift and ent_i in trigger_dic:
                if (ent_i, ent_i) in arg_dic:
                    arg_start, arg_end, trigger_idx, role_type = arg_dic[(ent_i, ent_i)]
                    cur_actions.append(Actions.right_pass)
                    # cur_actions.append(Actions.event_gen + role_type)
                else:
                    cur_actions.append(Actions.no_pass)

            for j in range(ent_i - 1, -1, -1):

                if j in trigger_dic:
                    key = (j, ent_i)
                    if key in arg_dic:
                        arg_start, arg_end, trigger_idx, role_type = arg_dic[key]
                        cur_actions.append(Actions.right_pass)
                        # cur_actions.append(Actions.event_gen + role_type)

                    else:
                        cur_actions.append(Actions.no_pass)

                if j in ent_dic:
                    cur_actions.append(Actions.no_pass)

            cur_actions.append(Actions.shift)

        # print(tri_actions)
        # print(ent_actions)
        for i in range(len(tokens)):
            is_ent_or_tri = False
            if i in tri_actions:
                actions += tri_actions[i]
                is_ent_or_tri = True

            if i in ent_actions:
                actions += ent_actions[i]
                is_ent_or_tri = True

            if not is_ent_or_tri:
                actions.append(Actions.o_delete)

        return actions  # , tri_actions, ent_actions
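A minimal usage sketch for the oracle above (editor's illustration, not part of the commit). It assumes the Actions class from this file is in scope; the token indices and the labels 'Justice:Arrest-Jail', 'PER', 'E1' and 'Person' are made up purely for illustration. Tracing make_oracle by hand for this toy input yields the action sequence shown in the final comment.

# Hypothetical toy example; labels and indices are illustrative only.
tokens   = ['Police', 'arrested', 'the', 'suspect', 'yesterday']
triggers = [(1, 'Justice:Arrest-Jail')]   # (token_idx, event_type)
ents     = [[3, 3, 'PER', 'E1']]          # [start, end, ent_type, coref_ref]
args     = [[3, 3, 1, 'Person']]          # [arg_start, arg_end, trigger_idx, role_type]

oracle = Actions.make_oracle(tokens, triggers, ents, args)
print(oracle)
# Expected, following the logic above:
# ['O-DELETE', 'TRIGGER-GEN-Justice:Arrest-Jail', 'SHIFT', 'O-DELETE',
#  'ENTITY-SHIFT', 'ENTITY-GEN-PER', 'ENTITY-BACK', 'RIGHT-PASS', 'SHIFT', 'O-DELETE']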
@@ -0,0 +1,34 @@

data_dir: 'data_files/'
ace05_event_dir: 'data_files/samples/'

vocab_dir: 'data_files/vocab/'
token_vocab_file: 'token_vocab.txt'
char_vocab_file: 'char_vocab.txt'
ent_type_vocab_file: 'ent_type_vocal.txt'
ent_ref_vocab_file: 'ent_ref_vocab.txt'  # co-reference
tri_type_vocab_file: 'tri_type_vocab.txt'
arg_type_vocab_file: 'arg_type_vocal.txt'
action_vocab_file: 'action_vocab.txt'
pos_vocab_file: 'pos_vocab.txt'
#deptype_vocab_file: 'deptype_vocab.txt'
#nertag_vocab_file: 'nertag_vocab.txt'
#rel_vocab_file: 'rel_type_vocab.txt'

pickle_dir: 'data_files/pickle/'
vec_npy: 'data_files/pickle/word_vec.npy'
inst_pl_file: 'data_files/pickle/data.pl'
model_save_file: 'data_files/saved_models/model.ckpt'

train_sent_file: 'data_files/bert_emb/train_bert.npy'
dev_sent_file: 'data_files/bert_emb/dev_bert.npy'
test_sent_file: 'data_files/bert_emb/test_bert.npy'

embedding_dir: 'data_files/glove_emb/'
embedding_file: 'glove.6B.100d.txt'
embedding_type: 'glove'

normalize_digits: false
lower_case: false
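The BERT embedding script at the end of this commit reads this file with read_yaml('data_config.yaml') and then indexes it like a plain dict. A minimal sketch of equivalent loading with PyYAML (editor's illustration; it assumes the config above is saved as data_config.yaml, and the repo's own io_utils.read_yaml presumably does something similar):

import os
import yaml  # PyYAML

with open('data_config.yaml') as f:
    data_config = yaml.safe_load(f)

# Keys map one-to-one to the entries above, e.g. a full vocab path:
token_vocab_path = os.path.join(data_config['vocab_dir'], data_config['token_vocab_file'])
print(token_vocab_path)  # -> data_files/vocab/token_vocab.txt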
@@ -0,0 +1,285 @@

import os
import re

import numpy as np
import dynet as dy

# code adopted from https://github.com/neulab/xnmt/blob/master/xnmt/param_collection.py


class ParamManager(object):
    """
    A static class that manages the currently loaded DyNet parameters of all components.

    Its responsibilities are registering all components that use DyNet parameters and loading pretrained parameters.
    Components can register parameters by calling ParamManager.my_params(self) from within their __init__() method.
    This allocates a subcollection with a unique identifier for this component. When loading previously saved parameters,
    one or several paths are specified to look for the corresponding saved DyNet collection named after this identifier.
    """
    initialized = False

    @staticmethod
    def init_param_col() -> None:
        """
        Initializes or resets the parameter collection.

        This must be invoked each time before a new model is loaded (e.g. on startup and between consecutive experiments).
        """
        ParamManager.param_col = ParamCollection()
        ParamManager.load_paths = []
        ParamManager.initialized = True

    # @staticmethod
    # def set_save_file(file_name: str, save_num_checkpoints: int=1) -> None:
    #     assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
    #     ParamManager.param_col.model_file = file_name
    #     ParamManager.param_col.save_num_checkpoints = save_num_checkpoints

    @staticmethod
    def add_load_path(data_file: str) -> None:
        """
        Add a new data directory path to load from.

        When calling populate(), pretrained parameters from all directories added in this way are searched for the
        requested component identifiers.

        Args:
            data_file: a data directory (usually named ``*.data``) containing DyNet parameter collections.
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        if data_file not in ParamManager.load_paths:
            ParamManager.load_paths.append(data_file)

    @staticmethod
    def populate() -> None:
        """
        Populate the parameter collections.

        Searches the given data paths and loads parameter collections if they exist; otherwise, parameters are left in
        their randomly initialized state.
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        populated_subcols = []
        for subcol_name in ParamManager.param_col.subcols:
            for load_path in ParamManager.load_paths:
                data_file = os.path.join(load_path, subcol_name)
                if os.path.isfile(data_file):
                    ParamManager.param_col.load_subcol_from_data_file(subcol_name, data_file)
                    populated_subcols.append(subcol_name)
        if len(ParamManager.param_col.subcols) == len(populated_subcols):
            print(f"> populated DyNet weights of all components from given data files")
        elif len(populated_subcols) == 0:
            print(f"> use randomly initialized DyNet weights of all components")
        else:
            print(f"> populated a subset of DyNet weights from given data files: {populated_subcols}.\n"
                  f" Did not populate {ParamManager.param_col.subcols.keys() - set(populated_subcols)}.\n"
                  f" If partial population was not intended, likely the unpopulated component or its owner"
                  f" does not adhere to the Serializable protocol correctly, see documentation:\n"
                  f" http://xnmt.readthedocs.io/en/latest/writing_xnmt_classes.html#using-serializable-subcomponents")
        print(f" DyNet param count: {ParamManager.param_col._param_col.parameter_count()}")

    @staticmethod
    def my_params(subcol_owner) -> dy.ParameterCollection:
        """Creates a dedicated parameter subcollection for a serializable object.

        This should only be called from the __init__ method of a Serializable.

        Args:
            subcol_owner (Serializable): The object which is requesting to be assigned a subcollection.

        Returns:
            The assigned subcollection.
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        assert not getattr(subcol_owner, "init_completed", False), \
            f"my_params(obj) cannot be called after obj.__init__() has completed. Conflicting obj: {subcol_owner}"
        if not hasattr(subcol_owner, "xnmt_subcol_name"):
            raise ValueError(f"{subcol_owner} does not have an attribute 'xnmt_subcol_name'.\n"
                             f"Did you forget to wrap the __init__() in @serializable_init ?")
        subcol_name = subcol_owner.xnmt_subcol_name
        subcol = ParamManager.param_col.add_subcollection(subcol_owner, subcol_name)
        subcol_owner.save_processed_arg("xnmt_subcol_name", subcol_name)
        return subcol

    @staticmethod
    def global_collection() -> dy.ParameterCollection:
        """Access the top-level parameter collection, including all parameters.

        Returns:
            top-level DyNet parameter collection
        """
        assert ParamManager.initialized, "must call ParamManager.init_param_col() first"
        return ParamManager.param_col._param_col


class ParamCollection(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self._save_num_checkpoints = 1
        self._model_file = None
        self._param_col = dy.Model()
        self._is_saved = False
        self.subcols = {}
        self.all_subcol_owners = set()

    @property
    def save_num_checkpoints(self):
        return self._save_num_checkpoints

    @save_num_checkpoints.setter
    def save_num_checkpoints(self, value):
        self._save_num_checkpoints = value
        self._update_data_files()

    @property
    def model_file(self):
        return self._model_file

    @model_file.setter
    def model_file(self, value):
        self._model_file = value
        self._update_data_files()

    def _update_data_files(self):
        if self._save_num_checkpoints > 0 and self._model_file:
            self._data_files = [self.model_file + '.data']
            for i in range(1, self._save_num_checkpoints):
                self._data_files.append(self.model_file + '.data.' + str(i))
        else:
            self._data_files = []

    def add_subcollection(self, subcol_owner, subcol_name):
        assert subcol_owner not in self.all_subcol_owners
        self.all_subcol_owners.add(subcol_owner)
        assert subcol_name not in self.subcols
        new_subcol = self._param_col.add_subcollection(subcol_name)
        self.subcols[subcol_name] = new_subcol
        return new_subcol

    def load_subcol_from_data_file(self, subcol_name, data_file):
        self.subcols[subcol_name].populate(data_file)

    def save(self):
        if not self._is_saved:
            self._remove_existing_history()
        self._shift_saved_checkpoints()
        if not os.path.exists(self._data_files[0]):
            os.makedirs(self._data_files[0])
        for subcol_name, subcol in self.subcols.items():
            subcol.save(os.path.join(self._data_files[0], subcol_name))
        self._is_saved = True

    def revert_to_best_model(self):
        if not self._is_saved:
            raise ValueError("revert_to_best_model() is illegal because this model has never been saved.")
        for subcol_name, subcol in self.subcols.items():
            subcol.populate(os.path.join(self._data_files[0], subcol_name))

    def _remove_existing_history(self):
        for fname in self._data_files:
            if os.path.exists(fname):
                self._remove_data_dir(fname)

    def _remove_data_dir(self, data_dir):
        assert data_dir.endswith(".data") or data_dir.split(".")[-2] == "data"
        try:
            dir_contents = os.listdir(data_dir)
            for old_file in dir_contents:
                spl = old_file.split(".")
                # make sure we're only deleting files with the expected filenames
                if len(spl) == 2:
                    if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", spl[0]):
                        if re.match(r"^[0-9a-f]{8}$", spl[1]):
                            os.remove(os.path.join(data_dir, old_file))
        except NotADirectoryError:
            os.remove(data_dir)

    def _shift_saved_checkpoints(self):
        if os.path.exists(self._data_files[-1]):
            self._remove_data_dir(self._data_files[-1])
        for i in reversed(range(len(self._data_files) - 1)):
            if os.path.exists(self._data_files[i]):
                os.rename(self._data_files[i], self._data_files[i + 1])


class Optimizer(object):
    """
    A base class for trainers. Trainers are mostly simple wrappers of DyNet trainers but can add extra functionality.

    Args:
        optimizer: the underlying DyNet optimizer (trainer)
        skip_noisy: keep track of a moving average and a moving standard deviation of the log of the gradient norm
                    values, and abort a step if the norm of the gradient exceeds four standard deviations of the
                    moving average. Reference: https://arxiv.org/pdf/1804.09849.pdf
    """

    def __init__(self, optimizer: dy.Trainer) -> None:
        self.optimizer = optimizer

    def update(self) -> None:
        self.optimizer.update()

    def status(self):
        """
        Outputs information about the trainer to stderr.

        (number of updates since last call, number of clipped gradients, learning rate, etc.)
        """
        return self.optimizer.status()

    def set_clip_threshold(self, thr):
        """
        Set the clipping threshold.

        To deactivate clipping, set the threshold to be <= 0.

        Args:
            thr (number): Clipping threshold
        """
        return self.optimizer.set_clip_threshold(thr)

    def get_clip_threshold(self):
        """
        Get the clipping threshold.

        Returns:
            number: Gradient clipping threshold
        """
        return self.optimizer.get_clip_threshold()

    def restart(self):
        """
        Restarts the optimizer.

        Clears all momentum values and similar accumulated statistics (if applicable).
        """
        return self.optimizer.restart()

    @property
    def learning_rate(self):
        return self.optimizer.learning_rate

    @learning_rate.setter
    def learning_rate(self, value):
        self.optimizer.learning_rate = value


class AdamTrainer(Optimizer):
    """
    Adam optimizer.

    The Adam optimizer is similar to RMSProp but uses unbiased estimates of the first and second moments of the gradient.

    Args:
        alpha (number): Initial learning rate
        beta_1 (number): Moving average parameter for the mean
        beta_2 (number): Moving average parameter for the variance
        eps (number): Epsilon parameter to prevent numerical instability
        skip_noisy: keep track of a moving average and a moving standard deviation of the log of the gradient norm
                    values, and abort a step if the norm of the gradient exceeds four standard deviations of the
                    moving average. Reference: https://arxiv.org/pdf/1804.09849.pdf
    """
    yaml_tag = '!AdamTrainer'

    def __init__(self, alpha=0.001, beta_1=0.9, beta_2=0.999, eps=1e-8, update_every: int = 1, skip_noisy: bool = False):
        super().__init__(optimizer=dy.AdamTrainer(ParamManager.global_collection(), alpha, beta_1, beta_2, eps))
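A minimal usage sketch (editor's illustration, not part of the commit), assuming the ParamManager and AdamTrainer classes above are in scope. It only shows the call order described in the docstrings; the model components themselves are elided because they must follow the Serializable protocol (calling ParamManager.my_params(self) inside __init__).

import dynet as dy

ParamManager.init_param_col()        # must be called before building any components
trainer = AdamTrainer(alpha=0.001)   # wraps dy.AdamTrainer over the global collection

# ... build model components (each calling ParamManager.my_params(self) in __init__),
# optionally ParamManager.add_load_path('<saved>.data') followed by ParamManager.populate(),
# then per training step:
#     loss = <some dy.Expression>
#     loss.backward()
#     trainer.update()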
@@ -0,0 +1,72 @@

import os
from vocab import Vocab
from io_utils import read_yaml, read_lines, read_json_lines
from str_utils import capitalize_first_char, normalize_tok, normalize_sent, collapse_role_type


class EventConstraint(object):
    '''
    Ensures that (event type, entity type) -> (argument role) combinations obey the event-schema constraints.
    '''

    def __init__(self, ent_dict, tri_dict, arg_dict):
        constraint_file = './data_files/argrole_dict.txt'

        self.constraint_list = []  # [(ent_type, tri_type, arg_type)]
        for line in read_lines(constraint_file):
            line = str(line).lower()
            arr = line.split()
            arg_type = arr[0]
            for pair in arr[1:]:
                pair_arr = pair.split(',')
                tri_type = pair_arr[0]
                ent_type = pair_arr[1]
                ent_type = self._replace_ent(ent_type)
                self.constraint_list.append((ent_type, tri_type, arg_type))

        print('Event constraint size:', len(self.constraint_list))
        # { (ent_type, tri_type) : (arg_type1, ...) }
        self.ent_tri_to_arg_hash = {}
        for cons in self.constraint_list:
            ent_id = ent_dict[cons[0]]
            tri_id = tri_dict[cons[1]]
            arg_id = arg_dict[cons[2]]
            # ent_id = cons[0]
            # tri_id = cons[1]
            # arg_id = cons[2]
            if (ent_id, tri_id) not in self.ent_tri_to_arg_hash:
                self.ent_tri_to_arg_hash[(ent_id, tri_id)] = set()

            self.ent_tri_to_arg_hash[(ent_id, tri_id)].add(arg_id)

        # print(self.ent_tri_to_arg_hash)

        # single = 0
        # for key, val in self.ent_tri_to_arg_hash.items():
        #     if len(val) == 1:
        #         single += 1
        # print(single)

    def _replace_ent(self, ent_type):
        if ent_type == 'time':
            return 'tim'

        if ent_type == 'value':
            return 'val'

        return ent_type

    def check_constraint(self, ent_type, tri_type, arg_type):
        return (ent_type, tri_type, arg_type) in self.constraint_list

    def get_constraint_arg_types(self, ent_type_id, tri_type_id):
        return self.ent_tri_to_arg_hash.get((ent_type_id, tri_type_id), None)
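A hedged sketch (editor's illustration, not in the commit) of how this constraint table could be combined with the Actions helper from the first file to restrict which EVENT-GEN actions a parser may take for a given (entity type, trigger type) pair. The dictionaries and the ent_type_id / tri_type_id values are placeholders, assumed to come from the vocab files listed in the config.

# Hypothetical ids/dicts; ent_dict, tri_dict, arg_dict, action_dict come from the vocab files.
constraint = EventConstraint(ent_dict, tri_dict, arg_dict)
actions = Actions(action_dict, ent_dict, tri_dict, arg_dict)

allowed_args = constraint.get_constraint_arg_types(ent_type_id, tri_type_id)
if allowed_args is not None:
    # map allowed argument-role ids to their EVENT-GEN-* action ids
    allowed_event_gen_acts = actions.get_act_ids_by_args(allowed_args)
else:
    allowed_event_gen_acts = []  # no role is compatible with this (entity, trigger) pair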
@@ -0,0 +1,253 @@

from vocab import Vocab


def to_set(input):
    # split each span into an untyped tuple (last field dropped) and a fully typed tuple
    out_set = set()
    out_type_set = set()
    for x in input:
        out_set.add(tuple(x[:-1]))
        out_type_set.add(tuple(x))

    return out_set, out_type_set


class EventEval(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct_ent = 0.
        self.correct_ent_with_type = 0.
        self.num_pre_ent = 0.
        self.num_gold_ent = 0.

        self.correct_trigger = 0.
        self.correct_trigger_with_type = 0.
        self.num_pre_trigger = 0.
        self.num_gold_trigger = 0.

        self.correct_arg = 0.
        self.correct_arg_with_role = 0.
        self.num_pre_arg_no_type = 0.
        self.num_gold_arg_no_type = 0.
        self.num_pre_arg = 0.
        self.num_gold_arg = 0.

        # ------------------------------
        self.num_tri_error = 0
        self.num_ent_bound_error = 0
        self.num_arg_error = 0
        self.num_arg_error_with_role = 0
        self.num_ent_not_in_arg_error = 0
        self.num_tri_type_not_in_arg_error = 0

        self.tri_type_error_count = {}
        self.arg_type_error_count = {}
        self.arg_error_chunk = {}

    def get_coref_ent(self, g_ent_typed):
        ent_ref_dict = {}
        for ent1 in g_ent_typed:
            start1, end1, ent_type1, ent_ref1 = ent1
            coref_ents = []
            ent_ref_dict[(start1, end1)] = coref_ents
            for ent2 in g_ent_typed:
                start2, end2, ent_type2, ent_ref2 = ent2
                if ent_ref1 == ent_ref2:
                    coref_ents.append((start2, end2))
        return ent_ref_dict

    def split_prob(self, pred_args):
        sp_args, probs = [], []
        for arg in pred_args:
            sp_args.append(arg[:-1])
            probs.append(arg[-1])
        return sp_args, probs

    def update(self, pred_ents, gold_ents, pred_triggers, gold_triggers, pred_args, gold_args, eval_arg=True, words=None):
        ent_ref_dict = self.get_coref_ent(gold_ents)

        p_ent, p_ent_typed = to_set(pred_ents)
        p_ent_to_type_dic = {(s, e): t for s, e, t in p_ent_typed}
        g_ent, g_ent_typed = to_set([ent[:-1] for ent in gold_ents])
        p_tri, p_tri_typed = to_set(pred_triggers)
        g_tri, g_tri_typed = to_set(gold_triggers)
        p_args, p_args_typed = to_set(pred_args)
        g_args, g_args_typed = to_set(gold_args)
        # p_args_typed = {(arg[2], arg[3]) for arg in p_args_typed}
        # g_args_typed = {(arg[2], arg[3]) for arg in g_args_typed}

        self.num_pre_ent += len(p_ent)
        self.num_gold_ent += len(g_ent)

        self.correct_ent += len(p_ent & g_ent)
        self.correct_ent_with_type += len(p_ent_typed & g_ent_typed)

        self.num_pre_trigger += len(p_tri)
        self.num_gold_trigger += len(g_tri)

        c_tri = p_tri & g_tri
        c_tri_typed = p_tri_typed & g_tri_typed
        self.correct_trigger += len(c_tri)
        self.correct_trigger_with_type += len(c_tri_typed)

        if not eval_arg:
            return

        c_tri_typed_indices = {tri[0] for tri in c_tri_typed}
        p_tri_dic = {tri[0]: tri[1] for tri in p_tri_typed}
        g_tri_dic = {tri[0]: tri[1] for tri in g_tri_typed}
        p_arg_mention = {(arg[0], arg[1], p_tri_dic[arg[2]]) for arg in p_args_typed}
        g_arg_mention = {(arg[0], arg[1], g_tri_dic[arg[2]]) for arg in g_args_typed}
        p_arg_mention_typed = {(arg[0], arg[1], p_tri_dic[arg[2]], arg[3]) for arg in p_args_typed}
        g_arg_mention_typed = {(arg[0], arg[1], g_tri_dic[arg[2]], arg[3]) for arg in g_args_typed}

        self.num_pre_arg_no_type += len(p_arg_mention)
        self.num_gold_arg_no_type += len(g_arg_mention)

        self.num_pre_arg += len(p_arg_mention_typed)
        self.num_gold_arg += len(g_arg_mention_typed)

        for p_arg in p_arg_mention:
            p_start, p_end, p_tri_type = p_arg
            # if p_tri_idx not in c_tri_typed_indices:
            #     continue
            if (p_start, p_end) not in ent_ref_dict:
                continue
            for coref_ent in ent_ref_dict[(p_start, p_end)]:
                if (coref_ent[0], coref_ent[1], p_tri_type) in g_arg_mention:
                    self.correct_arg += 1
                    break
            else:
                self.num_arg_error += 1

        # for p_arg in p_args_typed:
        #     # start, end, tri_idx, tri_type = p_arg
        #     tri_idx = p_arg[-2]
        #     if p_arg in g_args_typed and tri_idx in c_tri_typed_indices:
        #         self.correct_arg_with_role += 1

        for num, p_arg in enumerate(p_arg_mention_typed):
            p_start, p_end, p_tri_type, p_role_type = p_arg
            p_ent_type = p_ent_to_type_dic[(p_start, p_end)]
            # tri_idx = p_arg[-2]
            # if p_tri_idx not in c_tri_typed_indices:
            #     self.num_tri_error += 1
            #     continue
            if (p_start, p_end) not in ent_ref_dict:
                self.num_ent_bound_error += 1
                continue

            for coref_ent in ent_ref_dict[(p_start, p_end)]:
                if (coref_ent[0], coref_ent[1], p_tri_type, p_role_type) in g_arg_mention_typed:
                    self.correct_arg_with_role += 1
                    break

            else:
                self.num_arg_error_with_role += 1

                for coref_ent in ent_ref_dict[(p_start, p_end)]:
                    has_ent = False
                    for g_arg in g_arg_mention_typed:
                        if coref_ent[0] == g_arg[0] and coref_ent[1] == g_arg[1]:
                            has_ent = True
                            break
                    if has_ent:
                        break
                else:
                    self.num_ent_not_in_arg_error += 1
                    chunk = tuple([words[i].lower() for i in range(p_start, p_end + 1)])
                    # print(' '.join(words))
                    # print(chunk, p_tri_type, p_role_type)
                    # print(p_arg_mention_typed)
                    # print(g_arg_mention_typed)
                    # print('==', [(words[tri[0]], tri[1]) for tri in pred_triggers])
                    # print('==', [(words[tri[0]], tri[1]) for tri in gold_triggers])

                    if chunk not in self.arg_error_chunk:
                        self.arg_error_chunk[chunk] = 1
                    else:
                        self.arg_error_chunk[chunk] += 1

                    key = p_ent_type
                    if key not in self.arg_type_error_count:
                        self.arg_type_error_count[key] = 1
                    else:
                        self.arg_type_error_count[key] += 1

                for g_arg in g_arg_mention_typed:
                    if p_tri_type == g_arg[2]:
                        break
                else:
                    self.num_tri_type_not_in_arg_error += 1
                    key = (p_tri_type, p_ent_type)
                    if key not in self.tri_type_error_count:
                        self.tri_type_error_count[key] = 1
                    else:
                        self.tri_type_error_count[key] += 1

    def report(self):
        p_ent = self.correct_ent / (self.num_pre_ent + 1e-18)
        r_ent = self.correct_ent / (self.num_gold_ent + 1e-18)
        f_ent = 2 * p_ent * r_ent / (p_ent + r_ent + 1e-18)

        p_ent_typed = self.correct_ent_with_type / (self.num_pre_ent + 1e-18)
        r_ent_typed = self.correct_ent_with_type / (self.num_gold_ent + 1e-18)
        f_ent_typed = 2 * p_ent_typed * r_ent_typed / (p_ent_typed + r_ent_typed + 1e-18)

        p_tri = self.correct_trigger / (self.num_pre_trigger + 1e-18)
        r_tri = self.correct_trigger / (self.num_gold_trigger + 1e-18)
        f_tri = 2 * p_tri * r_tri / (p_tri + r_tri + 1e-18)

        p_tri_typed = self.correct_trigger_with_type / (self.num_pre_trigger + 1e-18)
        r_tri_typed = self.correct_trigger_with_type / (self.num_gold_trigger + 1e-18)
        f_tri_typed = 2 * p_tri_typed * r_tri_typed / (p_tri_typed + r_tri_typed + 1e-18)

        p_arg = self.correct_arg / (self.num_pre_arg_no_type + 1e-18)
        r_arg = self.correct_arg / (self.num_gold_arg_no_type + 1e-18)
        f_arg = 2 * p_arg * r_arg / (p_arg + r_arg + 1e-18)

        p_arg_typed = self.correct_arg_with_role / (self.num_pre_arg + 1e-18)
        r_arg_typed = self.correct_arg_with_role / (self.num_gold_arg + 1e-18)
        f_arg_typed = 2 * p_arg_typed * r_arg_typed / (p_arg_typed + r_arg_typed + 1e-18)

        print('num_pre_arg:', self.num_pre_arg)
        print('num_gold_arg:', self.num_gold_arg)
        print('correct_arg_with_role:', self.correct_arg_with_role)

        print('num_tri_error:', self.num_tri_error)
        print('num_ent_bound_error:', self.num_ent_bound_error)
        print('num_arg_error:', self.num_arg_error)
        print('num_arg_error_with_role:', self.num_arg_error_with_role)

        print('num_ent_not_in_arg_error:', self.num_ent_not_in_arg_error)
        print('num_tri_type_not_in_arg_error:', self.num_tri_type_not_in_arg_error)

        # for tri_type, count in self.tri_type_error_count.items():
        #     print(tri_type, ' : ', count)
        #
        # print(sum(self.tri_type_error_count.values()))
        #
        # for arg_type, count in self.arg_type_error_count.items():
        #     print(arg_type, ' : ', count)
        #
        # for chunk, count in self.arg_error_chunk.items():
        #     print(chunk, count)
        # print(len(self.arg_error_chunk))

        return (p_ent, r_ent, f_ent), (p_ent_typed, r_ent_typed, f_ent_typed), \
               (p_tri, r_tri, f_tri), (p_tri_typed, r_tri_typed, f_tri_typed), \
               (p_arg, r_arg, f_arg), (p_arg_typed, r_arg_typed, f_arg_typed)
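A minimal usage sketch (editor's illustration, not in the commit), assuming the EventEval class above is in scope. The list layouts follow the unpacking done in update(): gold entities carry a coreference ref as their fourth field, and the type labels ('PER', 'Attack', 'Target', 'E1') are made up for illustration.

evaluator = EventEval()

# one call per sentence
evaluator.update(
    pred_ents=[[3, 3, 'PER']],            # [start, end, ent_type]
    gold_ents=[[3, 3, 'PER', 'E1']],      # [start, end, ent_type, coref_ref]
    pred_triggers=[(1, 'Attack')],        # (token_idx, event_type)
    gold_triggers=[(1, 'Attack')],
    pred_args=[[3, 3, 1, 'Target']],      # [arg_start, arg_end, trigger_idx, role_type]
    gold_args=[[3, 3, 1, 'Target']],
    words=['Police', 'arrested', 'the', 'suspect', 'yesterday'],
)

ent_prf, ent_typed_prf, tri_prf, tri_typed_prf, arg_prf, arg_typed_prf = evaluator.report()
print('trigger identification P/R/F:', tri_prf)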
@@ -0,0 +1,59 @@

from flair.data import Sentence
from flair.models import SequenceTagger

from flair.embeddings import CharLMEmbeddings, StackedEmbeddings, BertEmbeddings
import os
import pickle

import numpy as np
from io_utils import read_yaml, read_lines, read_json_lines

data_config = read_yaml('data_config.yaml')

data_dir = data_config['data_dir']
ace05_event_dir = data_config['ace05_event_dir']

train_list = read_json_lines(os.path.join(ace05_event_dir, 'train_nlp_ner.json'))
dev_list = read_json_lines(os.path.join(ace05_event_dir, 'dev_nlp_ner.json'))
test_list = read_json_lines(os.path.join(ace05_event_dir, 'test_nlp_ner.json'))

train_sent_file = data_config['train_sent_file']

bert = BertEmbeddings(layers='-1', bert_model_or_path='bert-base-uncased').to('cuda:0')


def save_bert(inst_list, filter_tri=True, name='train'):
    sents = []
    sent_lens = []
    for inst in inst_list:
        words, trigger_list, ent_list, arg_list = inst['nlp_words'], inst['Triggers'], inst['Entities'], inst['Arguments']
        # Empirically filter out sentences (training only) that have no triggers and fewer than 3 entities
        if len(trigger_list) == 0 and len(ent_list) < 3 and filter_tri: continue
        sents.append(words)
        sent_lens.append(len(words))

    total_word_nums = sum(sent_lens)
    input_table = np.empty((total_word_nums, 768 * 1))
    acc_len = 0
    for i, words in enumerate(sents):
        if i % 100 == 0:
            print('progress: %d, %d' % (i, len(sents)))
        sent_len = sent_lens[i]
        flair_sent = Sentence(' '.join(words))
        bert.embed(flair_sent)
        for j, token in enumerate(flair_sent):
            start = acc_len + j
            input_table[start, :] = token.embedding.cpu().detach().numpy()
        acc_len += sent_len

    bert_fname = data_config['train_sent_file'] if name == 'train' else \
        data_config['dev_sent_file'] if name == 'dev' else data_config['test_sent_file']
    np.save(bert_fname, input_table)

    print('total_word_nums:', total_word_nums)
    # print(len(sent_lens))


if __name__ == "__main__":
    save_bert(train_list, name='train')
    save_bert(dev_list, filter_tri=False, name='dev')
    save_bert(test_list, filter_tri=False, name='test')