def batch_norm(inputs, decay=0.999, center=True, scale=False, epsilon=0.001,
               activation_fn=None, param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS, is_training=True,
               reuse=None, variables_collections=None,
               outputs_collections=None, trainable=True, batch_weights=None,
               fused=None, data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False, scope=None, renorm=False,
               renorm_clipping=None, renorm_decay=0.99, adjustment=None):
  """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

  1. Can be used as a normalizer function for conv2d and fully_connected.
  2. Args:
     - inputs: A tensor with 2 or more dimensions, where the first dimension
       has `batch_size`. The normalization is over all but the last dimension
       if `data_format` is `NHWC`, and over all but the second dimension if
       `data_format` is `NCHW`. In other words, the channel dimension is not
       reduced: batch norm normalizes the features of the same channel across
       different samples.
     - data_format: A string. `NHWC` (default) and `NCHW` are supported.
     - reuse: see the source code for the details.
     - updates_collections: when training, the moving_mean and moving_variance
       need to be updated. By default the update ops are placed in
       updates_collections = tf.GraphKeys.UPDATE_OPS. If None, a control
       dependency is added to make sure the updates are computed in place.
  3. Returns: A `Tensor` representing the output of the operation.
  """
  inputs = ops.convert_to_tensor(inputs)
  rank = inputs.get_shape().ndims
  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')
  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(
      scope, 'BatchNorm', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    ......
    if data_format == DATA_FORMAT_NCHW:
      moments_axes = [0] + list(range(2, inputs_rank))
      params_shape = inputs_shape[1:2]
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      moments_axes = list(range(inputs_rank - 1))
      params_shape = inputs_shape[-1:]
      params_shape_broadcast = None
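    # Illustration (added, not in the original source): for an NCHW input of
    # shape [N, C, H, W] = [32, 64, 28, 28], moments_axes = [0, 2, 3] and
    # params_shape = [64]; for an NHWC input of shape [32, 28, 28, 64],
    # moments_axes = [0, 1, 2] and params_shape = [64]. Either way the
    # statistics are computed per channel, across the batch and spatial axes.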
    beta, gamma = None, None
    ......
    beta = variables.model_variable(
        'beta', shape=params_shape, dtype=dtype,
        initializer=beta_initializer, collections=beta_collections,
        trainable=trainable)
    gamma = variables.model_variable(
        'gamma', shape=params_shape, dtype=dtype,
        initializer=gamma_initializer, collections=gamma_collections,
        trainable=trainable)
    ......
    moving_mean = variables.model_variable(
        'moving_mean', shape=params_shape, dtype=dtype,
        initializer=moving_mean_initializer, trainable=False,
        collections=moving_mean_collections)
    moving_variance = variables.model_variable(
        'moving_variance', shape=params_shape, dtype=dtype,
        initializer=moving_variance_initializer, trainable=False,
        collections=moving_variance_collections)

    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      ......
      if data_format == DATA_FORMAT_NCHW:
        mean, variance = nn.moments(inputs, moments_axes, keep_dims=True)
        mean = array_ops.reshape(mean, [-1])
        variance = array_ops.reshape(variance, [-1])
      else:
        mean, variance = nn.moments(inputs, moments_axes)
      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Internal function forces updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with ops.control_dependencies(
              [update_moving_mean, update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        mean, variance = utils.smart_cond(is_training, _force_updates,
                                          moving_vars_fn)
      else:
        def _delay_updates():
          """Internal function that delay updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          return update_moving_mean, update_moving_variance
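        # Note (added): with updates_collections=None, _force_updates above puts
        # the moving-average assignments inside a control dependency, so they run
        # in place whenever the batch statistics are used; with the default
        # UPDATE_OPS collection, _delay_updates only creates the update ops and
        # the caller is responsible for running them (see the usage sketch after
        # the listing).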
        update_mean, update_variance = utils.smart_cond(
            is_training, _delay_updates, moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
    else:
      mean, variance = moving_mean, moving_variance
    if data_format == DATA_FORMAT_NCHW:
      mean = array_ops.reshape(mean, params_shape_broadcast)
      variance = array_ops.reshape(variance, params_shape_broadcast)
      if beta is not None:
        beta = array_ops.reshape(beta, params_shape_broadcast)
      if gamma is not None:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
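To make the computation above concrete, here is a minimal NumPy sketch of one
NHWC training step. It is illustrative only: the function name, the explicit
moving buffers and the shapes are assumptions, and it ignores
zero_debias_moving_mean, renorm, adjustment and the fused path.

import numpy as np

def batch_norm_train_step(x, gamma, beta, moving_mean, moving_variance,
                          decay=0.999, epsilon=0.001):
    """Normalize one NHWC batch with its own statistics and update the
    moving (inference-time) statistics by exponential moving average."""
    # Moments over every axis except the channel axis, mirroring
    # moments_axes = list(range(inputs_rank - 1)) for NHWC.
    axes = tuple(range(x.ndim - 1))
    mean = x.mean(axis=axes)
    variance = x.var(axis=axes)
    # assign_moving_average: moving <- moving * decay + batch_stat * (1 - decay)
    moving_mean = moving_mean * decay + mean * (1 - decay)
    moving_variance = moving_variance * decay + variance * (1 - decay)
    # nn.batch_normalization: standardize, then scale by gamma and shift by beta.
    outputs = gamma * (x - mean) / np.sqrt(variance + epsilon) + beta
    return outputs, moving_mean, moving_variance

x = np.random.randn(32, 28, 28, 64).astype(np.float32)
gamma, beta = np.ones(64, np.float32), np.zeros(64, np.float32)
moving_mean, moving_variance = np.zeros(64, np.float32), np.ones(64, np.float32)
y, moving_mean, moving_variance = batch_norm_train_step(
    x, gamma, beta, moving_mean, moving_variance)

At inference (is_training=False) the layer uses moving_mean and moving_variance
instead of the batch statistics, which is the moving_vars_fn branch of the
smart_cond calls in the listing.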
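Finally, because the default updates_collections is tf.GraphKeys.UPDATE_OPS,
the moving-average update ops are only collected, not run; the training graph
has to execute them explicitly. A typical TF 1.x usage sketch (the toy network,
optimizer and hyper-parameters here are assumptions, not from the source):

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 28, 28, 1])
labels = tf.placeholder(tf.int64, [None])
is_training = tf.placeholder(tf.bool, [])

# batch_norm used as the normalizer function of conv2d, as the docstring notes.
net = tf.contrib.layers.conv2d(
    images, 32, [3, 3],
    normalizer_fn=tf.contrib.layers.batch_norm,
    normalizer_params={'is_training': is_training, 'decay': 0.999})
net = tf.contrib.layers.flatten(net)
logits = tf.contrib.layers.fully_connected(net, 10, activation_fn=None)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

# The moving_mean / moving_variance update ops live in UPDATE_OPS by default;
# run them together with the train op, otherwise the moving statistics used
# at inference time are never updated.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

Feeding is_training=False at evaluation time switches the layer to the moving
statistics without rebuilding the graph.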