forked from mindspore-Ecosystem/mindspore
!1136 fix bn train&eval loss problem
Merge pull request !1136 from JichenZhao/bn_train_eval_loss_issue
commit 93a5201210
@@ -43,7 +43,7 @@ class _BatchNorm(Cell):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(_BatchNorm, self).__init__()
         if num_features < 1:
@@ -147,7 +147,11 @@ class _BatchNorm(Cell):
         return y
 
     def construct(self, x):
-        if self.training and self.use_batch_statistics:
+        if self.use_batch_statistics is None:
+            flag = self.training
+        else:
+            flag = self.use_batch_statistics
+        if flag:
             if self.is_ge_backend and self.is_global:
                 axes, re_shape = _shape_infer(F.shape(x), self.num_features)
                 y = self._global_sync(x, axes, re_shape)
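The hunk above is the core of the fix: use_batch_statistics is now a three-state flag. An explicit True/False always wins, while None (the new default) defers to the cell's train/eval mode, so the same network picks batch statistics during training and the tracked moving statistics during evaluation. A minimal standalone sketch of that resolution logic (the helper name is hypothetical, not part of MindSpore):

    def resolve_use_batch_statistics(training, use_batch_statistics):
        # None defers to the cell's train/eval mode; an explicit bool overrides it.
        if use_batch_statistics is None:
            return training
        return use_batch_statistics

    # Training with the default uses batch statistics, eval uses moving statistics.
    assert resolve_use_batch_statistics(True, None) is True
    assert resolve_use_batch_statistics(False, None) is False
    # An explicit setting wins regardless of mode.
    assert resolve_use_batch_statistics(False, True) is True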
@@ -236,8 +240,10 @@ class BatchNorm1d(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
+            use the mean value and variance value of specified value. If None, training process will use the mean and
+            variance of current batch data and track the running mean and variance, eval process will use the running
+            mean and variance. Default: None.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
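Code that relied on the old default (batch statistics even in eval mode) can pin the previous behavior by passing the flag explicitly. A rough sketch, assuming the public mindspore.nn API and a 2-D (N, C) input for BatchNorm1d; shapes and values are illustrative only:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # use_batch_statistics=True pins batch statistics, matching the old default.
    bn = nn.BatchNorm1d(num_features=4, use_batch_statistics=True)
    bn.set_train(False)  # even in eval mode, batch statistics are used

    x = Tensor(np.random.randn(8, 4).astype(np.float32))
    y = bn(x)
    print(y.shape)  # (8, 4)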
@@ -259,7 +265,7 @@ class BatchNorm1d(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm1d, self).__init__(num_features,
                                           eps,
                                           momentum,
@@ -307,8 +313,10 @@ class BatchNorm2d(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
+            use the mean value and variance value of specified value. If None, training process will use the mean and
+            variance of current batch data and track the running mean and variance, eval process will use the running
+            mean and variance. Default: None.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
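With the new default of None, which statistics BatchNorm2d uses is driven entirely by set_train, so a model no longer has to touch the flag when switching between training and evaluation. A hedged sketch assuming the public mindspore.nn API; shapes are illustrative:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    bn = nn.BatchNorm2d(num_features=3)  # use_batch_statistics defaults to None
    x = Tensor(np.random.randn(2, 3, 4, 4).astype(np.float32))

    bn.set_train(True)   # training: batch statistics, moving statistics are updated
    y_train = bn(x)

    bn.set_train(False)  # eval: the tracked moving statistics are used
    y_eval = bn(x)

    print(y_train.shape, y_eval.shape)  # (2, 3, 4, 4) (2, 3, 4, 4)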
@@ -330,7 +338,7 @@ class BatchNorm2d(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True):
+                 use_batch_statistics=None):
         super(BatchNorm2d, self).__init__(num_features,
                                           eps,
                                           momentum,
@@ -379,8 +387,10 @@ class GlobalBatchNorm(_BatchNorm):
         moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
             The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
             'he_uniform', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data, else use
-            the mean value and variance value of specified value. Default: True.
+        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
+            use the mean value and variance value of specified value. If None, training process will use the mean and
+            variance of current batch data and track the running mean and variance, eval process will use the running
+            mean and variance. Default: None.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@@ -402,7 +412,7 @@ class GlobalBatchNorm(_BatchNorm):
                  beta_init='zeros',
                  moving_mean_init='zeros',
                  moving_var_init='ones',
-                 use_batch_statistics=True,
+                 use_batch_statistics=None,
                  device_num_each_group=1):
         super(GlobalBatchNorm, self).__init__(num_features,
                                               eps,