Create shift_reduce.py

missQian · 2020-10-04 21:32:22 +08:00 · commit ab696a8ba7 (parent eba476ed56)

import numpy as np
import dynet as dy
import nn
import ops
from dy_utils import ParamManager as pm
from actions import Actions
from vocab import Vocab
from event_constraints import EventConstraint
import io_utils
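
# Shift-reduce (transition-based) decoder for joint extraction of entity
# mentions, event triggers, and argument roles on top of a sentence encoder.
# (This summary is inferred from the code below: MultiTask holds the auxiliary
# argument-role and same-event classifiers, ShiftReduce the transition system.)
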
class MultiTask(object):

    def __init__(self, config, encoder_output_dim, action_dict, ent_dict, tri_dict, arg_dict):
        self.config = config
        self.model = pm.global_collection()
        bi_rnn_dim = encoder_output_dim  # config['rnn_dim'] * 2  # + config['edge_embed_dim']
        lmda_dim = config['lmda_rnn_dim']
        part_ent_dim = config['part_ent_rnn_dim']
        self.lmda_dim = lmda_dim
        self.bi_rnn_dim = bi_rnn_dim

        # argument-role classifier: parser-state embedding -> hidden -> role scores
        hidden_input_dim = lmda_dim * 3 + bi_rnn_dim * 2 + config['out_rnn_dim']
        self.hidden_arg = nn.Linear(hidden_input_dim, config['output_hidden_dim'],
                                    activation='tanh')
        self.output_arg = nn.Linear(config['output_hidden_dim'], len(arg_dict))

        # binary classifier: do two entities participate in the same event?
        hidden_input_dim_co = lmda_dim * 3 + bi_rnn_dim * 2 + config['out_rnn_dim']
        self.hidden_ent_corel = nn.Linear(hidden_input_dim_co, config['output_hidden_dim'],
                                          activation='tanh')
        self.output_ent_corel = nn.Linear(config['output_hidden_dim'], 2)

        # position-aware attention over the encoder states
        self.position_embed = nn.Embedding(500, 20)
        attn_input = self.bi_rnn_dim * 1 + 20 * 2
        self.attn_hidden = nn.Linear(attn_input, 80, activation='tanh')
        self.attn_out = nn.Linear(80, 1)

    def forward_ent_corel(self, beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                          hidden_mat, start1, end1, start2, end2, seq_len, last_h, gold_arg):
        attn_rep = self.position_aware_attn(hidden_mat, last_h, start1, end1, start2, end2, seq_len)
        state_embed = ops.cat(
            [beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed, attn_rep], dim=0)
        state_embed = dy.dropout(state_embed, 0.2)
        hidden = self.hidden_ent_corel(state_embed)
        out = self.output_ent_corel(hidden)
        loss = dy.pickneglogsoftmax(out, gold_arg)
        return loss

    def forward_arg(self, beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                    hidden_mat, tri_idx, ent_start, ent_end, seq_len, last_h, gold_arg):
        attn_rep = self.position_aware_attn(hidden_mat, last_h, tri_idx, tri_idx, ent_start, ent_end, seq_len)
        state_embed = ops.cat(
            [beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed, attn_rep], dim=0)
        state_embed = dy.dropout(state_embed, 0.25)
        hidden = self.hidden_arg(state_embed)
        out = self.output_arg(hidden)
        # focal-loss style weighting, kept for reference:
        # probs = dy.softmax(out)
        # gold_prob = dy.pick(probs, gold_arg)
        # log_gold_prob = dy.log(gold_prob)
        # loss_weight = dy.pow(1.03 - gold_prob, dy.scalarInput(2))
        # loss = - loss_weight * log_gold_prob
        loss = dy.pickneglogsoftmax(out, gold_arg)
        return loss

    def decode_arg(self, beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                   hidden_mat, tri_idx, ent_start, ent_end, seq_len, last_h):
        attn_rep = self.position_aware_attn(hidden_mat, last_h, tri_idx, tri_idx, ent_start, ent_end, seq_len)
        state_embed = ops.cat(
            [beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed, attn_rep], dim=0)
        hidden = self.hidden_arg(state_embed)
        out = self.output_arg(hidden)
        np_score = out.npvalue().flatten()
        return np.argmax(np_score)

    def position_aware_attn(self, hidden_mat, last_h, start1, end1, start2, end2, seq_len):
        tri_pos_list = []
        ent_pos_list = []
        for i in range(seq_len):
            tri_pos_list.append(io_utils.relative_position(start1, end1, i))
            ent_pos_list.append(io_utils.relative_position(start2, end2, i))

        tri_pos_emb = self.position_embed(tri_pos_list)
        tri_pos_mat = ops.cat(tri_pos_emb, 1)
        ent_pos_emb = self.position_embed(ent_pos_list)
        ent_pos_mat = ops.cat(ent_pos_emb, 1)

        # expand_last_h = nn.cat([last_h] * seq_len, 1)
        # (bi_rnn_dim + pos_emb * 2, seq_len)
        att_input = ops.cat([hidden_mat, tri_pos_mat, ent_pos_mat], 0)
        hidden = self.attn_hidden(att_input)
        attn_out = self.attn_out(hidden)
        # (1, seq_len)
        attn_prob = nn.softmax(attn_out, dim=1)
        # (bi_rnn_dim, 1)
        rep = hidden_mat * dy.transpose(attn_prob)
        return rep
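
    # Shape sketch for the attention above (assuming the DyNet convention used
    # here of (dim, seq_len) matrices built by ops.cat(..., 1)):
    #   hidden_mat : (bi_rnn_dim, seq_len)  encoder states, one column per token
    #   *_pos_mat  : (20, seq_len)          relative-position embeddings
    #   att_input  : (bi_rnn_dim + 40, seq_len)
    #   attn_prob  : (1, seq_len)           softmax over tokens
    #   rep        : (bi_rnn_dim, 1)        attention-weighted sum of states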

class ShiftReduce(object):

    def __init__(self, config, encoder_output_dim, action_dict, ent_dict, tri_dict, arg_dict):
        self.config = config
        self.model = pm.global_collection()
        self.multi_task = MultiTask(config, encoder_output_dim, action_dict, ent_dict, tri_dict, arg_dict)
        self.arg_null_id = arg_dict[Vocab.NULL]
        bi_rnn_dim = encoder_output_dim  # config['rnn_dim'] * 2  # + config['edge_embed_dim']
        lmda_dim = config['lmda_rnn_dim']
        part_ent_dim = config['part_ent_rnn_dim']
        self.lmda_dim = lmda_dim
        self.bi_rnn_dim = bi_rnn_dim
        self.lambda_var = nn.LambdaVar(lmda_dim)

        dp_state = config['dp_state']
        dp_state_h = config['dp_state_h']
        self.sigma_rnn = nn.StackLSTM(lmda_dim, lmda_dim, dp_state, dp_state_h)  # stack
        self.delta_rnn = nn.StackLSTM(lmda_dim, lmda_dim, dp_state, dp_state_h)  # will be pushed back
        self.part_ent_rnn = nn.StackLSTM(bi_rnn_dim, part_ent_dim, dp_state, dp_state_h)
        # self.beta = []  # buffer, unprocessed words
        self.actions_rnn = nn.StackLSTM(config['action_embed_dim'], config['action_rnn_dim'], dp_state, dp_state_h)
        self.out_rnn = nn.StackLSTM(bi_rnn_dim, config['out_rnn_dim'], dp_state, dp_state_h)

        self.act_table = nn.Embedding(len(action_dict), config['action_embed_dim'])
        self.ent_table = nn.Embedding(len(ent_dict), config['entity_embed_dim'])
        self.tri_table = nn.Embedding(len(tri_dict), config['trigger_embed_dim'])
        self.act = Actions(action_dict, ent_dict, tri_dict, arg_dict)

        hidden_input_dim = bi_rnn_dim + lmda_dim * 3 + part_ent_dim \
                           + config['action_rnn_dim'] + config['out_rnn_dim']
        self.hidden_linear = nn.Linear(hidden_input_dim, config['output_hidden_dim'], activation='tanh')
        self.output_linear = nn.Linear(config['output_hidden_dim'], len(action_dict))

        entity_embed_dim = config['entity_embed_dim']
        trigger_embed_dim = config['trigger_embed_dim']
        ent_to_lmda_dim = config['part_ent_rnn_dim'] + entity_embed_dim  # + config['sent_vec_dim'] * 4
        self.ent_to_lmda = nn.Linear(ent_to_lmda_dim, lmda_dim, activation='tanh')
        tri_to_lmda_dim = bi_rnn_dim + trigger_embed_dim  # + config['sent_vec_dim']
        self.tri_to_lmda = nn.Linear(tri_to_lmda_dim, lmda_dim, activation='tanh')

        self.hidden_arg = nn.Linear(lmda_dim * 2 + self.bi_rnn_dim, config['output_hidden_dim'],
                                    activation='tanh')
        self.output_arg = nn.Linear(config['output_hidden_dim'], len(arg_dict))

        self.empty_buffer_emb = self.model.add_parameters((bi_rnn_dim,), name='bufferGuardEmb')
        self.event_cons = EventConstraint(ent_dict, tri_dict, arg_dict)
        # self.cached_valid_args = self.cache_valid_args(ent_dict, tri_dict)
        self.empty_times = 0
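
    # Parser registers, in the shift-reduce notation suggested by the variable
    # names above (our reading of the code, not an official description):
    #   buffer (beta)        unprocessed token states, consumed left to right
    #   lambda_var (lambda)  the entity or trigger currently being compared
    #   sigma_rnn (sigma)    processed entities/triggers to the left of lambda
    #   delta_rnn (delta)    items popped off sigma, pushed back on SHIFT
    #   part_ent_rnn (e)     tokens of a partially recognized entity mention
    #   out_rnn (out)        tokens discarded by O-DELETE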

    def cache_valid_args(self, ent_dict, tri_dict):
        cached_valid_args = {}
        for ent_type in ent_dict.vals():
            for tri_type in tri_dict.vals():
                valid_args = self.event_cons.get_constraint_arg_types(ent_type, tri_type)
                if valid_args is None:
                    valid_actions = []
                else:
                    valid_actions = self.act.get_act_ids_by_args(valid_args)
                cached_valid_args[(ent_type, tri_type)] = valid_actions
        print('cached_valid_args:', len(cached_valid_args))
        return cached_valid_args

    def get_valid_args(self, ent_type_id, tri_type_id):
        # NOTE: relies on self.cached_valid_args, whose construction in
        # __init__ is currently commented out
        return self.cached_valid_args[(ent_type_id, tri_type_id)]

    def __call__(self, toks, hidden_state_list, last_h, oracle_actions=None,
                 oracle_action_strs=None, is_train=True, ents=None, tris=None, args=None):
        ent_dic = dict()
        tri_dic = dict()
        gold_arg_dict = {(arg[0], arg[2]): arg[-1] for arg in args}  # (ent_start, tri_idx): role_type
        same_event_ents = self.same(args)
        args = []

        hidden_mat = ops.cat(hidden_state_list, 1)
        seq_len = len(toks)
        buffer = nn.Buffer(self.bi_rnn_dim, hidden_state_list)
        losses = []
        loss_rels = []
        loss_roles = []
        pred_action_strs = []

        self.sigma_rnn.init_sequence(not is_train)
        self.delta_rnn.init_sequence(not is_train)
        self.part_ent_rnn.init_sequence(not is_train)
        self.actions_rnn.init_sequence(not is_train)
        self.out_rnn.init_sequence(not is_train)

        steps = 0
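
        # Main transition loop: run until the buffer, lambda, and the partial
        # entity stack are all empty. Each step: (1) derive the valid actions
        # from the current state, (2) score all actions against the state
        # embedding, (3) follow the oracle (training) or the argmax (decoding),
        # and (4) execute the chosen action.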
        while not (buffer.is_empty() and self.lambda_var.is_empty() and self.part_ent_rnn.is_empty()):
            pre_action = None if self.actions_rnn.is_empty() else self.actions_rnn.last_idx()

            # based on the parser state, collect the currently valid actions
            valid_actions = []
            if pre_action is not None and self.act.is_ent_gen(pre_action):
                valid_actions += [self.act.entity_back_id]  # [self.act.entity_back_id, self.act.entity_shift_id]
            elif not self.part_ent_rnn.is_empty():
                # parts of an entity are in e; finish it before any other action
                valid_actions += [self.act.entity_shift_id]
                valid_actions += self.act.get_ent_gen_list()
            elif not self.lambda_var.is_empty():
                if self.sigma_rnn.is_empty():
                    valid_actions += [self.act.shift_id, self.act.copy_shift_id]
                else:
                    valid_actions += [self.act.no_pass_id]
                    lmda_idx = self.lambda_var.idx
                    sigma_idx = self.sigma_rnn.last_idx()
                    if lmda_idx in ent_dic and sigma_idx in tri_dic:
                        valid_actions += [self.act.right_pass_id]
                    elif lmda_idx in tri_dic and sigma_idx in ent_dic:
                        valid_actions += [self.act.left_pass_id]
            else:
                valid_actions += [self.act.entity_shift_id, self.act.o_del_id]
                valid_actions += self.act.get_tri_gen_list()
            action = None
            if buffer.is_empty():
                self.empty_times += 1

            beta_embed = self.empty_buffer_emb if buffer.is_empty() else buffer.hidden_embedding()
            lmda_embed = self.lambda_var.embedding()
            sigma_embed = self.sigma_rnn.embedding()
            delta_embed = self.delta_rnn.embedding()
            part_ent_embed = self.part_ent_rnn.embedding()
            action_embed = self.actions_rnn.embedding()
            out_embed = self.out_rnn.embedding()
            state_embed = ops.cat([beta_embed, lmda_embed, sigma_embed, delta_embed,
                                   part_ent_embed, action_embed, out_embed], dim=0)
            if is_train:
                state_embed = dy.dropout(state_embed, self.config['dp_out'])

            hidden_rep = self.hidden_linear(state_embed)
            logits = self.output_linear(hidden_rep)
            # softmax restricted to the currently valid actions
            log_probs = dy.log_softmax(logits, valid_actions)

            if is_train:
                action = oracle_actions[steps]
                action_str = oracle_action_strs[steps]
                if action not in valid_actions:
                    raise RuntimeError('Action %s is not in valid_actions' % action_str)
                # append the action-specific loss
                # if self.act.is_o_del(action) or self.act.is_tri_gen(action):
                losses.append(dy.pick(log_probs, action))
                # val, idx = log_probs.tensor_value().topk(0, 5)
            else:
                np_log_probs = log_probs.npvalue()
                act_prob = np.max(np_log_probs)
                action = np.argmax(np_log_probs)
                action_str = self.act.to_act_str(action)
                pred_action_strs.append(action_str)
            # print(action_str)
            # if True: continue

            # execute the action to update the parser state
            if self.act.is_o_del(action):
                # discard the front token of the buffer onto the out stack
                hx, idx = buffer.pop()
                self.out_rnn.push(hx, idx)
            elif self.act.is_tri_gen(action):
                # label the front token as a trigger and move it to lambda
                hx, idx = buffer.pop()
                type_id = self.act.to_tri_id(action)
                tri_dic[idx] = (idx, type_id)
                tri_embed = self.tri_table[type_id]
                tri_rep = self.tri_to_lmda(ops.cat([hx, tri_embed], dim=0))
                # tri_rep = self.tri_to_lmda(hx)
                self.lambda_var.push(tri_rep, idx, nn.LambdaVar.TRIGGER)
            elif self.act.is_ent_shift(action):
                # move the front token into the partial-entity stack e
                if buffer.is_empty():
                    break
                hx, idx = buffer.pop()
                self.part_ent_rnn.push(hx, idx)
            elif self.act.is_ent_gen(action):
                # close the partial entity, assign its type, and move it to lambda
                start, end = self.part_ent_rnn.idx_range()
                type_id = self.act.to_ent_id(action)
                ent = (start, end, type_id)
                ent_dic[start] = ent
                hx, _ = self.part_ent_rnn.last_state()
                ent_embed = self.ent_table[type_id]
                ent_rep = self.ent_to_lmda(ops.cat([hx, ent_embed], dim=0))
                # ent_rep = self.ent_to_lmda(hx)
                self.lambda_var.push(ent_rep, start, nn.LambdaVar.ENTITY)
            elif self.act.is_ent_back(action):
                # rewind the buffer pointer past the just-recognized entity
                new_idx = buffer.idx
                new_idx -= len(self.part_ent_rnn) - 1
                buffer.move_pointer(new_idx)
                self.part_ent_rnn.clear()
            elif self.act.is_shift(action):
                # lambda is fully processed: restore delta into sigma, then
                # push lambda itself onto sigma
                while not self.delta_rnn.is_empty():
                    self.sigma_rnn.push(*self.delta_rnn.pop())
                self.sigma_rnn.push(*self.lambda_var.pop())
            elif self.act.is_copy_shift(action):
                # same as SHIFT, but also move the buffer pointer back
                while not self.delta_rnn.is_empty():
                    self.sigma_rnn.push(*self.delta_rnn.pop())
                self.sigma_rnn.push(*self.lambda_var.pop())
                buffer.move_back()
            elif self.act.is_no_pass(action):
                # no relation between lambda and the top of sigma; park it in delta
                lmda_idx = self.lambda_var.idx
                sigma_last_embed, sigma_last_idx = self.sigma_rnn.pop()
                if lmda_idx in ent_dic and sigma_last_idx in ent_dic:
                    # auxiliary task: do these two entities co-occur in one event?
                    ent_start1, ent_end1, _ = ent_dic[lmda_idx]
                    ent_start2, ent_end2, _ = ent_dic[sigma_last_idx]
                    corel = 1 if (lmda_idx, sigma_last_idx) in same_event_ents else 0
                    loss_corel = self.multi_task.forward_ent_corel(
                        beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                        hidden_mat, ent_start1, ent_end1, ent_start2, ent_end2, seq_len, last_h, corel)
                    loss_rels.append(loss_corel)
                self.delta_rnn.push(sigma_last_embed, sigma_last_idx)
            elif self.act.is_left_pass(action):
                # lambda is a trigger, the top of sigma an entity: predict the role
                lmda_idx = self.lambda_var.idx
                sigma_last_embed, sigma_last_idx = self.sigma_rnn.pop()
                tri_idx = lmda_idx
                ent_start, ent_end, _ = ent_dic[sigma_last_idx]
                if is_train:
                    role_label = gold_arg_dict.get((ent_start, tri_idx), self.arg_null_id)
                    loss_role = self.multi_task.forward_arg(
                        beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                        hidden_mat, tri_idx, ent_start, ent_end, seq_len, last_h, role_label)
                    loss_roles.append(loss_role)
                else:
                    role_label = self.multi_task.decode_arg(
                        beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                        hidden_mat, tri_idx, ent_start, ent_end, seq_len, last_h)
                event = (ent_start, ent_end, tri_idx, role_label)
                args.append(event)
                self.delta_rnn.push(sigma_last_embed, sigma_last_idx)
            elif self.act.is_right_pass(action):
                # lambda is an entity, the top of sigma a trigger: predict the role
                lmda_idx = self.lambda_var.idx
                sigma_last_embed, sigma_last_idx = self.sigma_rnn.pop()
                tri_idx = sigma_last_idx
                ent_start, ent_end, _ = ent_dic[lmda_idx]
                if is_train:
                    role_label = gold_arg_dict.get((ent_start, tri_idx), self.arg_null_id)
                    loss_role = self.multi_task.forward_arg(
                        beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                        hidden_mat, tri_idx, ent_start, ent_end, seq_len, last_h, role_label)
                    loss_roles.append(loss_role)
                else:
                    role_label = self.multi_task.decode_arg(
                        beta_embed, lmda_embed, sigma_embed, delta_embed, out_embed,
                        hidden_mat, tri_idx, ent_start, ent_end, seq_len, last_h)
                event = (ent_start, ent_end, tri_idx, role_label)
                args.append(event)
                # push the popped sigma entry itself (the original pushed the
                # whole-stack embedding sigma_embed here, which looks like a bug)
                self.delta_rnn.push(sigma_last_embed, sigma_last_idx)
            else:
                raise RuntimeError('Unknown action type: ' + str(action))

            self.actions_rnn.push(self.act_table[action], action)
            steps += 1
            # if not is_train: print(len(self.actions_rnn.indices), self.actions_rnn.indices)

        pred_args = []
        if is_train:
            pred_args = set(args)
        else:
            # at decode time, keep only arguments allowed by the event-schema constraints
            for arg in args:
                ent_start, ent_end, tri_idx, role_type = arg
                ent_type_id = ent_dic[ent_start][-1]
                tri_type_id = tri_dic[tri_idx][-1]
                valid_args = self.event_cons.get_constraint_arg_types(ent_type_id, tri_type_id)
                if valid_args and role_type in valid_args:
                    pred_args.append(arg)

        self.clear()
        return losses, loss_roles, loss_rels, set(ent_dic.values()), set(tri_dic.values()), pred_args, pred_action_strs

    def clear(self):
        self.sigma_rnn.clear()
        self.delta_rnn.clear()
        self.part_ent_rnn.clear()
        self.actions_rnn.clear()
        self.lambda_var.clear()
        self.out_rnn.clear()

    def same(self, args):
        # collect (unordered) pairs of entity starts that appear in the same event
        same_event_ents = set()
        for arg1 in args:
            ent_start1, ent_end1, tri_idx1, _ = arg1
            for arg2 in args:
                ent_start2, ent_end2, tri_idx2, _ = arg2
                if tri_idx1 == tri_idx2:
                    same_event_ents.add((ent_start1, ent_start2))
                    same_event_ents.add((ent_start2, ent_start1))
        return same_event_ents
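

# A minimal usage sketch (hypothetical: `config`, `encoder`, the vocab dicts,
# and the oracle come from the surrounding project, not from this file). It
# only illustrates the calling convention of ShiftReduce.__call__ above. Note
# that `losses` holds log-probabilities while `loss_roles`/`loss_rels` hold
# negative log-likelihoods, so a trainer would presumably combine them as:
#
#   parser = ShiftReduce(config, encoder_output_dim, action_dict,
#                        ent_dict, tri_dict, arg_dict)
#   dy.renew_cg()
#   hidden_states, last_h = encoder(toks)  # one state per token
#   losses, loss_roles, loss_rels, ents, tris, pred_args, pred_acts = parser(
#       toks, hidden_states, last_h, oracle_actions=oracle_ids,
#       oracle_action_strs=oracle_strs, is_train=True, args=gold_args)
#   total_loss = -dy.esum(losses) + dy.esum(loss_roles) + dy.esum(loss_rels)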