From 59993c4843080c782d5c6417c91220e777594604 Mon Sep 17 00:00:00 2001
From: zhaojichen
Date: Wed, 13 May 2020 05:48:03 -0400
Subject: [PATCH] fix bn train&eval loss problem

---
 mindspore/nn/layer/normalization.py | 32 +++++++++++++++++++----------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 16124c126a0..5119e9168b1 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -43,7 +43,7 @@ class _BatchNorm(Cell):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(_BatchNorm, self).__init__()
         if num_features < 1:
@@ -147,7 +147,11 @@
         return y
 
     def construct(self, x):
-        if self.training and self.use_batch_statistics:
+        if self.use_batch_statistics is None:
+            flag = self.training
+        else:
+            flag = self.use_batch_statistics
+        if flag:
             if self.is_ge_backend and self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
@@ -236,8 +240,10 @@
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
+            use the mean value and variance value of specified value. If None, training process will use the mean and
+            variance of current batch data and track the running mean and variance, eval process will use the running
+            mean and variance. Default: None.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -259,7 +265,7 @@
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm1d, self).__init__(num_features,
                                           eps,
                                           momentum,
@@ -307,8 +313,10 @@
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
+            use the mean value and variance value of specified value. If None, training process will use the mean and
+            variance of current batch data and track the running mean and variance, eval process will use the running
+            mean and variance. Default: None.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -330,7 +338,7 @@
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm2d, self).__init__(num_features,
                                           eps,
                                           momentum,
@@ -379,8 +387,10 @@
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
+            use the mean value and variance value of specified value. If None, training process will use the mean and
+            variance of current batch data and track the running mean and variance, eval process will use the running
+            mean and variance. Default: None.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -402,7 +412,7 @@
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(GlobalBatchNorm, self).__init__(num_features,
                                               eps,