import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import sys
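

# Supported model variants (as implemented below):
#   "tan"  - bidirectional LSTM with attention over word + target embeddings
#   "tan-" - bidirectional LSTM with attention over word embeddings only
#   "lstm" - plain unidirectional LSTM baseline without attention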
class LSTM_TAN(nn.Module):
    def __init__(self, version, embedding_dim, hidden_dim, vocab_size, n_targets,
                 embedding_matrix, dropout=0.5):
        super(LSTM_TAN, self).__init__()

        if version not in ["tan-", "tan", "lstm"]:
            print("version must be one of: 'tan-', 'tan', 'lstm'")
            sys.exit(-1)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim

        # Word embeddings are initialized from the pre-trained embedding matrix
        # and left trainable so they can be fine-tuned.
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.word_embeddings.weight = nn.Parameter(torch.tensor(embedding_matrix, dtype=torch.float))
        self.word_embeddings.weight.requires_grad = True
        self.version = version
        # Attention scoring layer: "tan-" scores each token from its word embedding
        # alone, while "tan" scores the concatenation of the word embedding and the
        # averaged target embedding. The plain "lstm" baseline uses no attention.
        if version == "tan-":
            self.attention = nn.Linear(embedding_dim, 1)
        elif version == "tan":
            self.attention = nn.Linear(2 * embedding_dim, 1)
        # Bidirectional LSTM for the attention variants, unidirectional for the baseline.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=(version != "lstm"))
        self.dropout = nn.Dropout(dropout)
        # Output layer: the bidirectional variants yield a 2*hidden_dim representation,
        # the unidirectional baseline a hidden_dim one.
        if version != "lstm":
            self.hidden2target = nn.Linear(2 * self.hidden_dim, n_targets)
        else:
            self.hidden2target = nn.Linear(self.hidden_dim, n_targets)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_dim),
                torch.zeros(1, 1, self.hidden_dim))

    def forward(self, sentence, target, verbose=False):
        x_emb = self.word_embeddings(sentence)
        version = self.version

        if version != "tan-":
            # Target representation: embed the target tokens, average them into a
            # single vector, and concatenate it with every word embedding.
            t_emb = self.word_embeddings(target)
            t_emb = torch.mean(t_emb, dim=0, keepdim=True)
            xt_emb = torch.cat((x_emb, t_emb.expand(len(sentence), -1)), dim=1)

        if version == "tan-":
            # Attention weights come from the word embeddings alone; they pool the
            # BiLSTM outputs into a single sentence representation.
            lstm_out, _ = self.lstm(
                x_emb.view(len(sentence), 1, self.embedding_dim))
            a = self.attention(x_emb)
            final_hidden_state = torch.mm(F.softmax(a.view(1, -1), dim=1),
                                          lstm_out.view(len(sentence), -1))
        elif version == "tan":
            # Attention weights come from the concatenated word/target embeddings.
            a = self.attention(xt_emb)
            lstm_out, _ = self.lstm(x_emb.view(len(sentence), 1, self.embedding_dim))
            final_hidden_state = torch.mm(F.softmax(a.view(1, -1), dim=1),
                                          lstm_out.view(len(sentence), -1))
        elif version == "lstm":
            # Baseline: no attention; use the LSTM's final hidden state directly.
            _, hidden_state = self.lstm(
                x_emb.view(len(sentence), 1, self.embedding_dim))
            final_hidden_state = hidden_state[0].view(-1, self.hidden_dim)

        # Dropout, project to the stance classes, and return log-probabilities.
        target_space = self.hidden2target(self.dropout(final_hidden_state))
        target_scores = F.log_softmax(target_space, dim=1)

        return target_scores
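

# Minimal usage sketch (not part of the original model code): builds the model
# with a random embedding matrix and runs one forward pass on toy index tensors.
# All sizes and indices below are illustrative assumptions, not values from the repo.
if __name__ == "__main__":
    vocab_size, embedding_dim, hidden_dim, n_targets = 100, 50, 64, 3
    embedding_matrix = np.random.rand(vocab_size, embedding_dim)  # stand-in for pre-trained vectors
    model = LSTM_TAN("tan", embedding_dim, hidden_dim, vocab_size, n_targets, embedding_matrix)

    sentence = torch.tensor([1, 5, 20, 7], dtype=torch.long)  # toy word indices
    target = torch.tensor([3, 4], dtype=torch.long)           # toy target-phrase indices
    scores = model(sentence, target)
    print(scores.shape)  # torch.Size([1, 3]): log-probabilities over the stance classes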