17 lines
532 B
Python
17 lines
532 B
Python
# -*- encoding:utf-8 -*-
|
|
import torch.nn as nn
|
|
from uer.utils.act_fun import gelu
|
|
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """Feed-forward layer: linear -> activation -> linear.

    Projects the last dimension of the input from ``hidden_size`` up to
    ``feedforward_size``, applies an activation, then projects back down to
    ``hidden_size``.

    Args:
        hidden_size: Size of the input and output feature (last) dimension.
        feedforward_size: Size of the intermediate projection.
        act: Optional callable applied to the output of ``linear_1``. When
            ``None`` (the default) the original ``gelu`` activation from
            ``uer.utils.act_fun`` is used, so existing callers are unaffected.
            If an ``nn.Module`` is passed it is registered as the submodule
            ``act``; any parameters it has will then appear in ``state_dict``.
    """

    def __init__(self, hidden_size, feedforward_size, act=None):
        super(PositionwiseFeedForward, self).__init__()
        # These attribute names become state_dict keys ("linear_1.weight",
        # "linear_2.bias", ...), so renaming them would break checkpoint loading.
        self.linear_1 = nn.Linear(hidden_size, feedforward_size)
        self.linear_2 = nn.Linear(feedforward_size, hidden_size)
        # None selects the historical default activation.
        self.act = gelu if act is None else act

    def forward(self, x):
        """Apply the feed-forward transform.

        Args:
            x: Tensor whose last dimension is ``hidden_size`` (required by
                ``linear_1``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``hidden_size``.
        """
        inter = self.act(self.linear_1(x))
        output = self.linear_2(inter)
        return output
|