forked from mindspore-Ecosystem/mindspore
code refine for BN docs
This commit is contained in:
parent
049acf6d58
commit
8438221259
|
@ -118,7 +118,6 @@ std::pair<std::string, bool> CudaEnvChecker::IsCudaRealPath(const std::string &p
|
|||
valid_path = (end == real_path.size() - 1) ? true : ((end == real_path.size() - 2) && (real_path.back() == '/'));
|
||||
return {real_path.substr(0, end + 1), valid_path};
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -625,7 +625,21 @@ class FusedBatchNorm(Primitive):
|
|||
|
||||
class FusedBatchNormEx(PrimitiveWithInfer):
|
||||
r"""
|
||||
FusedBatchNormEx is an extension of FusedBatchNorm
|
||||
FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve)
|
||||
than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that
|
||||
moving mean and moving variance will be computed instead of being loaded.
|
||||
|
||||
Batch Normalization is widely used in convolutional networks. This operation applies
|
||||
Batch Normalization over input to avoid internal covariate shift as described in the
|
||||
paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
|
||||
Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
|
||||
feature using a mini-batch of data and the learned parameters which can be described
|
||||
in the following formula.
|
||||
|
||||
.. math::
|
||||
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||
|
||||
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||
|
||||
Args:
|
||||
mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
|
||||
|
@ -635,21 +649,25 @@ class FusedBatchNormEx(PrimitiveWithInfer):
|
|||
Momentum value should be [0, 1]. Default: 0.9.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
- **scale** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **bias** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`,
|
||||
data type: float16 or float32.
|
||||
- **scale** (Tensor) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`,
|
||||
data type: float32.
|
||||
- **bias** (Tensor) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`,
|
||||
data type: float32.
|
||||
- **mean** (Tensor) - mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **variance** (Tensor) - variance value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
|
||||
Outputs:
|
||||
Tuple of 6 Tensor, the normalized input and the updated parameters.
|
||||
Tuple of 6 Tensor, the normalized input, the updated parameters and reserve.
|
||||
|
||||
- **output_x** (Tensor) - The same type and shape as the `input_x`.
|
||||
- **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **reserve** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **output_x** (Tensor) - The input of FusedBatchNormEx, same type and shape as the `input_x`.
|
||||
- **updated_scale** (Tensor) - Updated parameter scale, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **updated_bias** (Tensor) - Updated parameter bias, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
- **updated_moving_variance** (Tensor) - Updated variance value, Tensor of shape :math:`(C,)`,
|
||||
data type: float32.
|
||||
- **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32.
|
||||
|
||||
Examples:
|
||||
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
||||
|
|
Loading…
Reference in New Issue