17 lines
532 B
Python
17 lines
532 B
Python
|
# -*- encoding:utf-8 -*-
|
||
|
import torch.nn as nn
|
||
|
from uer.utils.act_fun import gelu
|
||
|
|
||
|
|
||
|
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward layer.

    Computes ``linear_2(act(linear_1(x)))``. The last dimension of ``x`` is
    projected from ``hidden_size`` up to ``feedforward_size``, passed through
    an element-wise activation, and projected back to ``hidden_size``.
    ``nn.Linear`` acts only on the last dimension, so every position of the
    input is transformed independently with the same weights.

    Args:
        hidden_size: Size of the last dimension of the input and the output.
        feedforward_size: Size of the intermediate (inner) representation.
        has_bias: Whether both linear projections learn an additive bias.
            Defaults to ``True``, matching the original behavior.
        act_fn: Optional callable applied between the two projections.
            ``None`` (the default) uses ``gelu``.
    """

    def __init__(self, hidden_size, feedforward_size, has_bias=True, act_fn=None):
        super(PositionwiseFeedForward, self).__init__()
        self.linear_1 = nn.Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_2 = nn.Linear(feedforward_size, hidden_size, bias=has_bias)
        # Store None instead of binding gelu here so the default activation
        # is looked up at call time, exactly as the original forward did.
        self.act_fn = act_fn

    def forward(self, x):
        """Apply the feed-forward transform to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``hidden_size``.
        """
        act = gelu if self.act_fn is None else self.act_fn
        inter = act(self.linear_1(x))
        return self.linear_2(inter)
|