Batch Normalization and Layer Normalization

TensorFlow code

tensorflow/contrib/layers/python/layers/layers.py

  • tf.contrib.layers.batch_norm() and tf.contrib.layers.layer_norm() mainly compute the mean and variance over the dimensions you need and initialize the beta and gamma parameters.

  • tf.nn.batch_normalization() then takes the mean and variance you computed and normalizes the original input with them.
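
To make this division of labor concrete, here is a minimal sketch (my own toy example, not library code) of what tf.nn.batch_normalization() does with externally supplied statistics: it simply applies y = gamma * (x - mean) / sqrt(variance + epsilon) + beta.

    # coding: utf-8
    # Toy sketch: tf.nn.batch_normalization() only applies the normalization formula;
    # the mean/variance and the beta/gamma parameters must be supplied by the caller.
    import tensorflow as tf

    x = tf.Variable(tf.random_normal([128, 4, 2, 3]))
    mean, variance = tf.nn.moments(x, axes=[0, 1, 2])   # per-channel statistics
    beta = tf.Variable(tf.zeros([3]))                    # offset, shape [C]
    gamma = tf.Variable(tf.ones([3]))                    # scale, shape [C]
    y = tf.nn.batch_normalization(x, mean, variance, offset=beta, scale=gamma,
                                  variance_epsilon=0.001)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(y).shape)  # (128, 4, 2, 3), same shape as the input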

1. The tf.nn.moments() function
  • Definition: def moments(x, axes, name=None, keep_dims=False)
  • Arguments:
    x: the input data, e.g. of shape [batch_size, height, width, channel]
    axes: a list of the dimensions to reduce over, e.g. [0, 1, 2]
    name: an optional name for the operation
    keep_dims: whether to keep the reduced dimensions (with size 1)
  • Outputs:
    Two Tensor objects: mean and variance.
    Output shape: axes lists the dimensions that are reduced over; the mean and variance keep the shape of the remaining dimensions.
    For example:
    inputs.shape = [128,4,2,5], axes = [0,1,2], then mean.shape = [5]
    inputs.shape = [128,4,2,5], axes = [0,1], then mean.shape = [2,5]
  • Examples:
    # coding: utf-8
    import tensorflow as tf
    img = tf.Variable(tf.random_normal([128, 4, 2, 3]))
    axis = [0, 1, 2]  # only the 4th (channel) dimension remains, so the statistics have shape [3]
    mean, variance = tf.nn.moments(img, axis)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        mean_, variance_ = sess.run([mean, variance])
        print("mean:", mean_)
        print("variance:", variance_)
    # output:
    # mean: [-0.03587267  0.06021447  0.02401767]
    # variance: [0.99473494 0.93040663 0.98113006]

    Example 2:
    # coding: utf-8
    import tensorflow as tf
    img = tf.Variable(tf.random_normal([128, 4, 2, 3]))
    axis = [0, 1]  # the 3rd and 4th dimensions remain, so the statistics have shape [2, 3]
    mean, variance = tf.nn.moments(img, axis)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        mean_, variance_ = sess.run([mean, variance])
        print("mean:", mean_)
        print("variance:", variance_)
    # output:
    # mean:
    # [[-0.04313184  0.01417894  0.06847101]
    #  [ 0.04183875 -0.01508999 -0.11406976]]
    # variance:
    # [[0.976376   0.91841435 1.0207324 ]
    #  [1.0403597  0.9773739  1.0360421 ]]
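
As a side note (my own sanity check, not from the original post), the "reduce over axes, keep the rest" rule can be verified against a manual computation with tf.reduce_mean:

    # coding: utf-8
    # Sanity check: tf.nn.moments over axes [0, 1, 2] matches a manual
    # computation of the mean and the (biased) variance via tf.reduce_mean.
    import tensorflow as tf
    img = tf.Variable(tf.random_normal([128, 4, 2, 3]))
    mean, variance = tf.nn.moments(img, [0, 1, 2])
    manual_mean = tf.reduce_mean(img, axis=[0, 1, 2])
    manual_var = tf.reduce_mean(tf.square(img - manual_mean), axis=[0, 1, 2])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        m, v, mm, mv = sess.run([mean, variance, manual_mean, manual_var])
        print(m, mm)  # the two means agree
        print(v, mv)  # the two variances agree (moments returns the biased estimate)
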
2. The tf.contrib.layers.batch_norm() function

Parts of the code that are not directly related to computing the mean and variance have been removed; see the full source for the rest.

def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               batch_weights=None,
               fused=None,
               data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False,
               scope=None,
               renorm=False,
               renorm_clipping=None,
               renorm_decay=0.99,
               adjustment=None):
  """Adds a Batch Normalization layer.
  1. Can be used as a normalizer function for conv2d and fully_connected.
  2. Args:
    - inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`.
      The normalization is over all but the last dimension if `data_format` is `NHWC`
      and over all but the second dimension if `data_format` is `NCHW`.
      In other words, the channel dimension is not reduced: batch norm normalizes the features
      of the same channel across different samples.
    - data_format: A string. `NHWC` (default) and `NCHW` are supported.
    - reuse: see the source for details.
    - updates_collections: when training, the moving_mean and moving_variance need to be updated.
      By default the update ops are placed in updates_collections = tf.GraphKeys.UPDATE_OPS.
      If None, a control dependency would be added to make sure the updates are computed in place.
  3. Returns:
    A `Tensor` representing the output of the operation.
  """
  inputs = ops.convert_to_tensor(inputs)
  rank = inputs.get_shape().ndims

  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse,
                                     custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    ......

    if data_format == DATA_FORMAT_NCHW:
      moments_axes = [0] + list(range(2, inputs_rank))  # [0, 2, 3] for a rank-4 input
      params_shape = inputs_shape[1:2]  # C
      # params_shape_broadcast has the same rank as inputs; every dimension except
      # the channel dimension is 1 -> [1, C, 1, 1]
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      moments_axes = list(range(inputs_rank - 1))  # [0, 1, 2] for a rank-4 input
      params_shape = inputs_shape[-1:]  # C
      params_shape_broadcast = None

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    ......
    beta = variables.model_variable('beta', shape=params_shape, dtype=dtype,
                                    initializer=beta_initializer,
                                    collections=beta_collections, trainable=trainable)
    gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype,
                                     initializer=gamma_initializer,
                                     collections=gamma_collections, trainable=trainable)

    # Create moving_mean and moving_variance variables
    ......
    moving_mean = variables.model_variable('moving_mean', shape=params_shape, dtype=dtype,
                                           initializer=moving_mean_initializer,
                                           trainable=False, collections=moving_mean_collections)
    moving_variance = variables.model_variable('moving_variance', shape=params_shape, dtype=dtype,
                                               initializer=moving_variance_initializer,
                                               trainable=False, collections=moving_variance_collections)

    # Decide whether this is the training or the inference phase: the mean and variance are
    # computed differently in the two cases; during training they come from the batch.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch.
      ......
      if data_format == DATA_FORMAT_NCHW:
        # How does this differ from simply: mean, variance = nn.moments(inputs, moments_axes)?
        mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
        mean = array_ops.reshape(mean, [-1])
        variance = array_ops.reshape(variance, [-1])
      else:
        mean, variance = nn.moments(inputs, moments_axes)

      # During training, moving_mean and moving_variance need to be updated.
      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Internal function forces updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with ops.control_dependencies([update_moving_mean, update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        # Isn't need_moments already enough to tell whether we are in the training phase?
        mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn)
      else:
        def _delay_updates():
          """Internal function that delay updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(is_training, _delay_updates, moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)

        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)

    else:
      mean, variance = moving_mean, moving_variance
    if data_format == DATA_FORMAT_NCHW:
      mean = array_ops.reshape(mean, params_shape_broadcast)
      variance = array_ops.reshape(variance, params_shape_broadcast)
      if beta is not None:
        beta = array_ops.reshape(beta, params_shape_broadcast)
      if gamma is not None:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Compute batch_normalization.
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
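
To see the updates_collections behaviour from the docstring in practice, here is a minimal usage sketch (my own example with a dummy conv layer and loss, not from the original post): with the default updates_collections = tf.GraphKeys.UPDATE_OPS, the moving-average update ops have to be run explicitly, typically by adding a control dependency to the train op.

    # coding: utf-8
    # Usage sketch: batch_norm with the default updates_collections.
    # The moving_mean / moving_variance update ops are placed in tf.GraphKeys.UPDATE_OPS,
    # so they are attached to the train op with a control dependency.
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 32, 32, 3])
    is_training = tf.placeholder(tf.bool, [])

    h = tf.contrib.layers.conv2d(x, num_outputs=16, kernel_size=3, activation_fn=None)
    h = tf.contrib.layers.batch_norm(h, is_training=is_training, center=True, scale=True)
    h = tf.nn.relu(h)

    loss = tf.reduce_mean(tf.square(h))  # dummy loss, just for illustration
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
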
3. The tf.contrib.layers.layer_norm() function
def layer_norm(inputs,
               center=True,
               scale=True,
               activation_fn=None,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               begin_norm_axis=1,
               begin_params_axis=-1,
               scope=None):
  """Adds a Layer Normalization layer.
  1. Can be used as a normalizer function for conv2d and fully_connected.
  2. Given a tensor `inputs` of rank `R`, moments are calculated and normalization is performed
     over axes `begin_norm_axis ... R - 1`. Scaling and centering, if requested, is performed
     over axes `begin_params_axis .. R - 1`. By default, `begin_norm_axis = 1` and
     `begin_params_axis = -1`, meaning that normalization is performed over all but the first
     axis (the `HWC` if `inputs` is `NHWC`), while the `beta` and `gamma` trainable parameters
     are calculated for the rightmost axis (the `C` if `inputs` is `NHWC`). Scaling and
     recentering is performed via broadcast of the `beta` and `gamma` parameters with the
     normalized tensor.
  3. The shapes of `beta` and `gamma` are `inputs.shape[begin_params_axis:]`, and this part of
     the inputs' shape must be fully defined.
  4. Args:
    - inputs: A tensor having rank `R`. The normalization is performed over axes
      `begin_norm_axis ... R - 1` and centering and scaling parameters are calculated over
      `begin_params_axis ... R - 1`.
    - center: If True, add offset of `beta` to normalized tensor. If False, `beta` is ignored.
    - scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the next layer
      is linear (also e.g. `nn.relu`), this can be disabled since the scaling can be done by
      the next layer.
    - activation_fn: Activation function, default set to None to skip it and maintain a linear
      activation.
    - begin_norm_axis: The first normalization dimension: normalization will be performed along
      dimensions `begin_norm_axis : rank(inputs)`.
    - begin_params_axis: The first parameter (beta, gamma) dimension: scale and centering
      parameters will have dimensions `begin_params_axis : rank(inputs)` and will be broadcast
      with the normalized inputs accordingly.
  5. Returns:
    A `Tensor` representing the output of the operation, having the same shape and dtype as
    `inputs`.
  """
  with variable_scope.variable_scope(scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.shape
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    if begin_norm_axis < 0:
      begin_norm_axis = inputs_rank + begin_norm_axis
    if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank:
      raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) must be < rank(inputs) (%d)'
                       % (begin_params_axis, begin_norm_axis, inputs_rank))
    params_shape = inputs_shape[begin_params_axis:]
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s: shape(inputs)[%s:] is not fully defined: %s'
                       % (inputs.name, begin_params_axis, inputs_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if center:
      beta_collections = utils.get_variable_collections(variables_collections, 'beta')
      beta = variables.model_variable('beta', shape=params_shape, dtype=dtype,
                                      initializer=init_ops.zeros_initializer(),
                                      collections=beta_collections, trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections, 'gamma')
      gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype,
                                       initializer=init_ops.ones_initializer(),
                                       collections=gamma_collections, trainable=trainable)

    # Calculate the moments on the last axis (layer activations).
    norm_axes = list(range(begin_norm_axis, inputs_rank))
    mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)

    # Compute layer normalization using the batch_normalization function.
    variance_epsilon = 1e-12
    outputs = nn.batch_normalization(inputs, mean, variance, offset=beta, scale=gamma,
                                     variance_epsilon=variance_epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
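
As a closing illustration (my own sketch, not part of the source), layer_norm with the default begin_norm_axis=1 normalizes each sample over all of its remaining axes; the same result can be reproduced manually with tf.nn.moments and the normalization formula.

    # coding: utf-8
    # Sketch: tf.contrib.layers.layer_norm (defaults) normalizes every sample over
    # its [H, W, C] axes; a manual version with tf.nn.moments gives the same result
    # because the freshly initialized beta=0 and gamma=1 leave the output unscaled.
    import tensorflow as tf

    x = tf.Variable(tf.random_normal([8, 4, 4, 3]))
    ln = tf.contrib.layers.layer_norm(x)  # per-sample normalization over axes [1, 2, 3]

    mean, variance = tf.nn.moments(x, [1, 2, 3], keep_dims=True)
    manual = (x - mean) / tf.sqrt(variance + 1e-12)  # same epsilon as in layer_norm

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        diff = sess.run(tf.reduce_max(tf.abs(ln - manual)))
        print(diff)  # ~0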