K-BERT/uer/layers/position_ffn.py
2019-12-12 19:37:32 +08:00

17 lines
532 B
Python

# -*- encoding:utf-8 -*-
import torch.nn as nn
from uer.utils.act_fun import gelu
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward layer: Linear -> activation -> Linear.

    The two projections are ``nn.Linear`` layers, which act on the last
    dimension of the input only. The same weights are therefore applied
    independently at every position of any leading dimensions.

    Args:
        hidden_size: Number of input and output features (the last
            dimension of ``x``).
        feedforward_size: Number of features of the intermediate projection.
        hidden_act: Optional callable applied between the two projections.
            When omitted (``None``), ``gelu`` from ``uer.utils.act_fun`` is
            used, which matches the layer's original behavior.
    """

    def __init__(self, hidden_size, feedforward_size, hidden_act=None):
        super(PositionwiseFeedForward, self).__init__()
        # Up-projection to the intermediate size, then back down to hidden_size.
        self.linear_1 = nn.Linear(hidden_size, feedforward_size)
        self.linear_2 = nn.Linear(feedforward_size, hidden_size)
        # Only touch `gelu` when no activation was supplied. Existing callers
        # (which never pass hidden_act) keep the original gelu behavior.
        self.hidden_act = gelu if hidden_act is None else hidden_act

    def forward(self, x):
        """Apply the feed-forward transform.

        Args:
            x: Tensor whose last dimension is ``hidden_size``. Any number of
                leading dimensions is accepted, as with ``nn.Linear``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        inter = self.hidden_act(self.linear_1(x))
        return self.linear_2(inter)