!1136 fix bn train&eval loss problem

Merge pull request !1136 from JichenZhao/bn_train_eval_loss_issue
mindspore-ci-bot 2020-05-28 09:35:11 +08:00 committed by Gitee
commit 93a5201210
1 changed file with 21 additions and 11 deletions

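This PR makes `use_batch_statistics` a three-state flag. It previously defaulted to True, and `construct` used batch statistics only when `self.training and self.use_batch_statistics` both held, so an explicit setting could never override the cell's train/eval mode. After this change, None means "follow the training flag", while an explicit True or False forces batch or moving statistics regardless of mode. A minimal sketch of the two rules in plain Python (illustrative only, with `training` standing in for `Cell.training`):

    def old_uses_batch_stats(training, use_batch_statistics=True):
        # Old rule: batch statistics only when both flags hold, so an explicit
        # True could never take effect while the cell was in eval mode.
        return training and use_batch_statistics

    def new_uses_batch_stats(training, use_batch_statistics=None):
        # New rule: None defers to the training flag; an explicit bool wins.
        if use_batch_statistics is None:
            return training
        return use_batch_statistics

    # The two rules differ only when training is False and the flag is
    # explicitly True.
    for training in (True, False):
        for flag in (None, True, False):
            print(training, flag, new_uses_batch_stats(training, flag))

With the new default, training behavior is unchanged; the user-visible change is that an explicit True now takes effect during evaluation as well.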

@@ -43,7 +43,7 @@ class _BatchNorm(Cell):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(_BatchNorm, self).__init__()
         if num_features < 1:
@@ -147,7 +147,11 @@ class _BatchNorm(Cell):
         return y

     def construct(self, x):
-        if self.training and self.use_batch_statistics:
+        if self.use_batch_statistics is None:
+            flag = self.training
+        else:
+            flag = self.use_batch_statistics
+        if flag:
             if self.is_ge_backend and self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
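With the default of None, normalization now tracks the cell's train/eval mode end to end. A minimal usage sketch, assuming a standard MindSpore install and a backend where BatchNorm is supported (PyNative mode so the cell can be called directly):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.context as context
    from mindspore import Tensor

    context.set_context(mode=context.PYNATIVE_MODE)

    bn = nn.BatchNorm2d(num_features=3)  # use_batch_statistics defaults to None
    x = Tensor(np.random.randn(4, 3, 8, 8).astype(np.float32))

    bn.set_train(True)   # training mode: normalize with the batch statistics
    y_train = bn(x)

    bn.set_train(False)  # eval mode: normalize with the tracked moving statistics
    y_eval = bn(x)

    # Typically False: different statistics are used in each mode.
    print(np.allclose(y_train.asnumpy(), y_eval.asnumpy()))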
@@ -236,8 +240,10 @@ class BatchNorm1d(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false, use
+            the tracked moving mean and moving variance. If None, training uses the mean and variance of the
+            current batch data and updates the moving statistics, and evaluation uses the moving mean and
+            variance. Default: None.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -259,7 +265,7 @@ class BatchNorm1d(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm1d, self).__init__(num_features,
                                           eps,
                                           momentum,
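An explicit False selects the moving statistics even while the cell is in training mode. A hedged sketch (the frozen-BN fine-tuning framing is illustrative, not from this PR):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.context as context
    from mindspore import Tensor

    context.set_context(mode=context.PYNATIVE_MODE)

    # Force the moving statistics even in training, e.g. frozen-BN fine-tuning.
    bn = nn.BatchNorm1d(num_features=16, use_batch_statistics=False)
    bn.set_train(True)
    x = Tensor(np.random.randn(8, 16).astype(np.float32))
    y = bn(x)  # normalized with moving_mean/moving_variance, not batch statistics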
@@ -307,8 +313,10 @@ class BatchNorm2d(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false, use
+            the tracked moving mean and moving variance. If None, training uses the mean and variance of the
+            current batch data and updates the moving statistics, and evaluation uses the moving mean and
+            variance. Default: None.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -330,7 +338,7 @@ class BatchNorm2d(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm2d, self).__init__(num_features,
                                           eps,
                                           momentum,
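Conversely, an explicit True keeps batch statistics even in evaluation, which the old `self.training and self.use_batch_statistics` check could not express. A minimal sketch under the same assumptions as above:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.context as context
    from mindspore import Tensor

    context.set_context(mode=context.PYNATIVE_MODE)

    bn = nn.BatchNorm2d(num_features=3, use_batch_statistics=True)
    bn.set_train(False)  # eval mode, but batch statistics are still used
    x = Tensor(np.random.randn(4, 3, 8, 8).astype(np.float32))
    y = bn(x)  # normalized with the current batch's mean and variance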
@@ -379,8 +387,10 @@ class GlobalBatchNorm(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch data. If false, use
+            the tracked moving mean and moving variance. If None, training uses the mean and variance of the
+            current batch data and updates the moving statistics, and evaluation uses the moving mean and
+            variance. Default: None.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -402,7 +412,7 @@ class GlobalBatchNorm(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(GlobalBatchNorm, self).__init__(num_features,
                                               eps,