!720 fix doc problem

Merge pull request !720 from JichenZhao/normfix
mindspore-ci-bot 2020-04-27 19:57:55 +08:00 committed by Gitee
commit 2538f0ba79
1 changed file with 8 additions and 2 deletions

@@ -17,6 +17,7 @@ from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore.ops.primitive import constexpr
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.context as context
@@ -166,6 +167,10 @@ class _BatchNorm(Cell):
         return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
             self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
 
+
+@constexpr
+def _channel_check(channel, num_channel):
+    if channel != num_channel:
+        raise ValueError("the input channel is not equal with num_channel")
 
 class BatchNorm1d(_BatchNorm):
     r"""
@@ -324,7 +329,7 @@ class GlobalBatchNorm(_BatchNorm):
     Args:
         num_features (int): `C` from an expected input of size (N, C, H, W).
-        device_num_each_group (int): The number of device in each group.
+        device_num_each_group (int): The number of devices in each group.
         eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
         momentum (float): A floating hyperparameter of the momentum for the
             running_mean and running_var computation. Default: 0.9.
@@ -350,7 +355,7 @@ class GlobalBatchNorm(_BatchNorm):
         Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
 
     Examples:
-        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
+        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, device_num_each_group=4)
         >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
         >>> global_bn_op(input)
     """
@@ -507,6 +512,7 @@ class GroupNorm(Cell):
     def construct(self, x):
         batch, channel, height, width = self.shape(x)
+        _channel_check(channel, self.num_channels)
         x = self.reshape(x, (batch, self.num_groups, channel*height*width/self.num_groups))
         mean = self.reduce_mean(x, 2)
         var = self.reduce_sum(self.square(x - mean), 2) / (channel * height * width / self.num_groups - 1)
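The practical effect of the added call: a channel mismatch in GroupNorm now surfaces as a clear ValueError at compile time instead of an opaque failure inside the reshape. A sketch of triggering it, assuming a standard single-device context:

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

group_norm = nn.GroupNorm(num_groups=2, num_channels=4)
bad_input = Tensor(np.ones([1, 6, 8, 8]), mindspore.float32)  # 6 channels != num_channels=4
try:
    group_norm(bad_input)
except ValueError as err:
    print(err)  # "the input channel is not equal with num_channel"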