mirror of
https://github.com/JoergFranke/ADNC.git
synced 2024-11-17 13:58:03 +08:00
add layer normalization
This commit is contained in:
parent
26b78c4bfe
commit
da7babc725
30
adnc/utils/normalization.py
Normal file
30
adnc/utils/normalization.py
Normal file
@ -0,0 +1,30 @@
|
||||
# Copyright 2018 Jörg Franke
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def layer_norm(weights, name, dtype=tf.float32, reuse=False, collection='ADNC'):
|
||||
_eps = 1e-6
|
||||
|
||||
with tf.variable_scope("ln_{}".format(name), reuse=reuse):
|
||||
scale = tf.get_variable('scale', shape=[weights.get_shape()[1]], initializer=tf.constant_initializer(1.),
|
||||
collections=[collection, tf.GraphKeys.GLOBAL_VARIABLES], dtype=dtype)
|
||||
beta = tf.get_variable('beta', shape=[weights.get_shape()[1]], initializer=tf.constant_initializer(0.),
|
||||
collections=[collection, tf.GraphKeys.GLOBAL_VARIABLES], dtype=dtype)
|
||||
|
||||
mean, var = tf.nn.moments(weights, axes=[1], keep_dims=True)
|
||||
norm_weights = (weights - mean) / tf.sqrt(var + _eps)
|
||||
|
||||
return norm_weights * scale + beta
|
44
test/adnc/utils/test_normalization.py
Normal file
44
test/adnc/utils/test_normalization.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2018 Jörg Franke
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from adnc.utils.normalization import layer_norm
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def np_rng():
|
||||
seed = np.random.randint(1, 999)
|
||||
return np.random.RandomState(seed)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def session():
|
||||
with tf.Session() as sess:
|
||||
yield sess
|
||||
tf.reset_default_graph()
|
||||
|
||||
|
||||
def test_layer_norm(session, np_rng):
|
||||
np_weights = np_rng.normal(0, 1, [64, 128])
|
||||
|
||||
weights = tf.constant(np_weights, dtype=tf.float32)
|
||||
weights_ln = layer_norm(weights, 'test')
|
||||
|
||||
session.run(tf.global_variables_initializer())
|
||||
weights_ln = session.run(weights_ln)
|
||||
|
||||
assert weights_ln.shape == (64, 128)
|
Loading…
Reference in New Issue
Block a user