26 lines
992 B
Python
Executable File
26 lines
992 B
Python
Executable File
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class LstmSubencoder(nn.Module):
|
|
def __init__(self, args, vocab_size):
|
|
super(LstmSubencoder, self).__init__()
|
|
self.hidden_size= args.emb_size
|
|
self.layers_num = args.sub_layers_num
|
|
|
|
self.embedding_layer = nn.Embedding(vocab_size, args.emb_size)
|
|
self.rnn = nn.LSTM(input_size=args.emb_size,
|
|
hidden_size=self.hidden_size,
|
|
num_layers=self.layers_num,
|
|
dropout=args.dropout,
|
|
batch_first=True)
|
|
|
|
def forward(self, ids):
|
|
batch_size, _ = ids.size() # batch_size, max_length
|
|
hidden = (torch.zeros(self.layers_num, batch_size, self.hidden_size).to(ids.device),
|
|
torch.zeros(self.layers_num, batch_size, self.hidden_size).to(ids.device))
|
|
emb = self.embedding_layer(ids)
|
|
output, hidden = self.rnn(emb, hidden)
|
|
output = output.mean(1)
|
|
return output
|