| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 
# NOTE(review): this is an excerpt copied from a rendered listing. The leading
# "|" on the def line and the "......" lines are listing residue / elided code,
# not valid Python. Names such as beta_initializer, beta_collections,
# gamma_initializer and moving_mean_initializer are defined only in the elided
# parts. Restore the full upstream source before running this.
| def batch_norm(inputs,decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               batch_weights=None,
               fused=None,
               data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False,
               scope=None,
               renorm=False,
               renorm_clipping=None,
               renorm_decay=0.99,
               adjustment=None):
    """Adds a Batch Normalization layer.

    NOTE(review): the upstream docstring cited a reference after "from"; the
    link was lost when this excerpt was copied.

    Can be used as a normalizer function for conv2d and fully_connected.

    Statistics are computed per channel. For `NHWC` the moments are taken over
    every axis except the last one. For `NCHW` they are taken over every axis
    except axis 1. In other words, channels are never pooled together: each
    channel is normalized using that channel's values across all samples (and
    all other non-channel axes) in the batch.

    Args:
      inputs: A tensor with 2 or more dimensions, where the first dimension
        has `batch_size`. Its rank must be statically known.
      decay: Decay passed to `moving_averages.assign_moving_average` when
        `moving_mean` and `moving_variance` are updated.
      center: Not referenced in the visible code. Presumably controls whether
        `beta` is created (in elided code); TODO confirm.
      scale: Not referenced in the visible code. Presumably controls whether
        `gamma` is created (in elided code); TODO confirm.
      epsilon: Passed as the epsilon argument of `nn.batch_normalization`.
      activation_fn: If not None, applied to the normalized output.
      updates_collections: Collections that receive the moving-average update
        ops. Defaults to `ops.GraphKeys.UPDATE_OPS`, and those ops must then be
        run by the caller. If None, the updates are instead tied to the batch
        moments via a control dependency, so they run in place whenever the
        batch moments are used.
      is_training: Whether to normalize with batch moments (and update the
        moving statistics) or with the stored moving statistics. If its
        constant value is known to be false, no batch moments are computed.
        Otherwise `utils.smart_cond` selects between the two paths.
      reuse: Forwarded to `variable_scope`.
      outputs_collections: Passed to `utils.collect_named_outputs` together
        with the scope name.
      trainable: Passed to the `beta` and `gamma` variables. The moving
        statistics are always created with `trainable=False`.
      data_format: A string. `NHWC` (default) and `NCHW` are supported; any
        other value raises ValueError.
      zero_debias_moving_mean: Passed as `zero_debias` when updating
        `moving_mean`. `moving_variance` is always updated with
        `zero_debias=False`.
      scope: Optional scope for `variable_scope`; the default name is
        'BatchNorm'.
      param_initializers, param_regularizers, variables_collections,
      batch_weights, fused, renorm, renorm_clipping, renorm_decay, adjustment:
        Not referenced in the visible code. Presumably consumed by the elided
        sections; consult the full source.

    Returns:
      A `Tensor` representing the output of the operation (the value returned
      by `utils.collect_named_outputs`).

    Raises:
      ValueError: If `data_format` is neither NCHW nor NHWC, or if the rank of
        `inputs` is undefined.
    """
    inputs = ops.convert_to_tensor(inputs)
    # NOTE(review): `rank` is not used anywhere in the visible code; it is
    # presumably consumed by an elided section.
    rank = inputs.get_shape().ndims

    if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
        raise ValueError('data_format has to be either NCHW or NHWC.')

    # Custom getter for every variable created inside the scope below.
    layer_variable_getter = _build_variable_getter()
    with variable_scope.variable_scope(scope,'BatchNorm', [inputs],reuse=reuse, custom_getter=layer_variable_getter) as sc:
        # Repeats the conversion done above; harmless.
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        ......  # elided in this excerpt

        # Pick the reduction axes (all axes except the channel axis) and the
        # per-channel parameter shape [C].
        if data_format == DATA_FORMAT_NCHW:
            moments_axes = [0] + list(range(2, inputs_rank))
            params_shape = inputs_shape[1:2]
            # Shape [1, C, 1, ..., 1]. Used near the end to reshape the 1-D
            # per-channel tensors so they broadcast against NCHW inputs.
            # NOTE(review): `.value` is presumably None when the channel
            # dimension is not statically known; confirm this path needs a
            # static C.
            params_shape_broadcast = list([1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
        else:
            moments_axes = list(range(inputs_rank - 1))
            params_shape = inputs_shape[-1:]
            # Channel axis is last: no reshape is done for NHWC (see below).
            params_shape_broadcast = None

        # Offset (beta) and scale (gamma). Later code checks `is not None`, so
        # either may be left unset.
        beta, gamma = None, None
        ......  # elided in this excerpt
        # NOTE(review): upstream presumably creates beta only when `center` is
        # true and gamma only when `scale` is true, with the initializers and
        # collections coming from the elided code. Confirm against the source.
        beta = variables.model_variable('beta', shape=params_shape, dtype=dtype, initializer=beta_initializer, collections=beta_collections, trainable=trainable)
        gamma = variables.model_variable('gamma', shape=params_shape, dtype=dtype, initializer=gamma_initializer, collections=gamma_collections, trainable=trainable)

        ......  # elided in this excerpt
        # Running statistics used when not training; never trainable.
        moving_mean = variables.model_variable('moving_mean', shape=params_shape, dtype=dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections)
        moving_variance = variables.model_variable('moving_variance', shape=params_shape, dtype=dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections)

        # Batch moments are needed unless is_training is statically known to
        # be false. constant_value presumably returns None when is_training is
        # not a static value.
        is_training_value = utils.constant_value(is_training)
        need_moments = is_training_value is None or is_training_value
        if need_moments:

            ......  # elided in this excerpt
            if data_format == DATA_FORMAT_NCHW:
                # keep_dims=True, then flatten to 1-D. This gives mean and
                # variance the same per-channel shape as moving_mean and
                # moving_variance.
                mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
                mean = array_ops.reshape(mean, [-1])
                variance = array_ops.reshape(variance, [-1])
            else:
                mean, variance = nn.moments(inputs, moments_axes)

            moving_vars_fn = lambda: (moving_mean, moving_variance)
            if updates_collections is None:
                def _force_updates():
                    """Update the moving statistics and return batch moments that depend on the updates."""
                    update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                    update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, decay, zero_debias=False)
                    # The identities are created under the control dependency,
                    # so consuming them forces both updates to run first.
                    with ops.control_dependencies([update_moving_mean, update_moving_variance]):
                        return array_ops.identity(mean), array_ops.identity(variance)

                # Training: batch moments (with in-place updates).
                # Otherwise: the moving statistics.
                mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn)
            else:
                def _delay_updates():
                    """Create the moving-average update ops and return them without forcing them to run."""
                    update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                    update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, decay, zero_debias=False)
                    return update_moving_mean, update_moving_variance

                # The update ops are only registered in `updates_collections`.
                # Whoever builds the training step must run them.
                update_mean, update_variance = utils.smart_cond(is_training, _delay_updates, moving_vars_fn)
                ops.add_to_collections(updates_collections, update_mean)
                ops.add_to_collections(updates_collections, update_variance)

                # Use batch moments while training and the moving statistics
                # otherwise.
                vars_fn = lambda: (mean, variance)
                mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)

        else:
            # is_training is statically false: use the stored statistics.
            mean, variance = moving_mean, moving_variance
        if data_format == DATA_FORMAT_NCHW:
            # Reshape the [C] tensors to [1, C, 1, ...] so they broadcast over
            # the NCHW input.
            mean = array_ops.reshape(mean, params_shape_broadcast)
            variance = array_ops.reshape(variance, params_shape_broadcast)
            if beta is not None:
                beta = array_ops.reshape(beta, params_shape_broadcast)
            if gamma is not None:
                gamma = array_ops.reshape(gamma, params_shape_broadcast)

        outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
        # Re-attach the input's static shape to the result.
        outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
 
 |