forked from mindspore-Ecosystem/mindspore
code refine for BN docs
This commit is contained in:
parent
049acf6d58
commit
8438221259
|
@ -118,7 +118,6 @@ std::pair<std::string, bool> CudaEnvChecker::IsCudaRealPath(const std::string &p
|
||||||
valid_path = (end == real_path.size() - 1) ? true : ((end == real_path.size() - 2) && (real_path.back() == '/'));
|
valid_path = (end == real_path.size() - 1) ? true : ((end == real_path.size() - 2) && (real_path.back() == '/'));
|
||||||
return {real_path.substr(0, end + 1), valid_path};
|
return {real_path.substr(0, end + 1), valid_path};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace device
|
} // namespace device
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -625,7 +625,21 @@ class FusedBatchNorm(Primitive):
|
||||||
|
|
||||||
class FusedBatchNormEx(PrimitiveWithInfer):
|
class FusedBatchNormEx(PrimitiveWithInfer):
|
||||||
r"""
|
r"""
|
||||||
FusedBatchNormEx is an extension of FusedBatchNorm
|
FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve)
|
||||||
|
than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that
|
||||||
|
moving mean and moving variance will be computed instead of being loaded.
|
||||||
|
|
||||||
|
Batch Normalization is widely used in convolutional networks. This operation applies
|
||||||
|
Batch Normalization over input to avoid internal covariate shift as described in the
|
||||||
|
paper `Batch Normalization: Accelerating Deep Network Training by Reducing Internal
|
||||||
|
Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
|
||||||
|
feature using a mini-batch of data and the learned parameters which can be described
|
||||||
|
in the following formula.
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
|
||||||
|
|
||||||
|
where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
|
mode (int): Mode of batch normalization, value is 0 or 1. Default: 0.
|
||||||
|
@ -635,21 +649,25 @@ class FusedBatchNormEx(PrimitiveWithInfer):
|
||||||
Momentum value should be [0, 1]. Default: 0.9.
|
Momentum value should be [0, 1]. Default: 0.9.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
- **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`,
|
||||||
- **scale** (Tensor) - Tensor of shape :math:`(C,)`.
|
data type: float16 or float32.
|
||||||
- **bias** (Tensor) - Tensor of shape :math:`(C,)`.
|
- **scale** (Tensor) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`,
|
||||||
- **mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
data type: float32.
|
||||||
- **variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
- **bias** (Tensor) - Parameter bias, same with beta above-mentioned, Tensor of shape :math:`(C,)`,
|
||||||
|
data type: float32.
|
||||||
|
- **mean** (Tensor) - mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||||
|
- **variance** (Tensor) - variance value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tuple of 6 Tensor, the normalized input and the updated parameters.
|
Tuple of 6 Tensor, the normalized input, the updated parameters and reserve.
|
||||||
|
|
||||||
- **output_x** (Tensor) - The same type and shape as the `input_x`.
|
- **output_x** (Tensor) - The input of FusedBatchNormEx, same type and shape as the `input_x`.
|
||||||
- **updated_scale** (Tensor) - Tensor of shape :math:`(C,)`.
|
- **updated_scale** (Tensor) - Updated parameter scale, Tensor of shape :math:`(C,)`, data type: float32.
|
||||||
- **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
|
- **updated_bias** (Tensor) - Updated parameter bias, Tensor of shape :math:`(C,)`, data type: float32.
|
||||||
- **updated_moving_mean** (Tensor) - Tensor of shape :math:`(C,)`.
|
- **updated_moving_mean** (Tensor) - Updated mean value, Tensor of shape :math:`(C,)`, data type: float32.
|
||||||
- **updated_moving_variance** (Tensor) - Tensor of shape :math:`(C,)`.
|
- **updated_moving_variance** (Tensor) - Updated variance value, Tensor of shape :math:`(C,)`,
|
||||||
- **reserve** (Tensor) - Tensor of shape :math:`(C,)`.
|
data type: float32.
|
||||||
|
- **reserve** (Tensor) - reserve space, Tensor of shape :math:`(C,)`, data type: float32.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
>>> input_x = Tensor(np.ones([128, 64, 32, 64]), mindspore.float32)
|
||||||
|
|
Loading…
Reference in New Issue