From 6f1c12e4d7d8bdbf724f814a1e241eac7d8f09c7 Mon Sep 17 00:00:00 2001 From: tom__chen Date: Thu, 22 Apr 2021 15:40:25 -0400 Subject: [PATCH] remove use_batch_statistics from InstanceNorm API --- .../gpu/nn/instance_norm_gpu_kernel.h | 27 ++++--------- mindspore/nn/layer/normalization.py | 39 +++++-------------- mindspore/ops/operations/nn_ops.py | 4 +- tests/st/ops/gpu/test_instancenorm2d.py | 2 +- 4 files changed, 20 insertions(+), 52 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h index 631deec3f22..60809ab376b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/instance_norm_gpu_kernel.h @@ -38,7 +38,6 @@ class InstanceNormGpuKernel : public GpuKernel { workspace_size_(0), mode_(CUDNN_BATCHNORM_SPATIAL), bn_ops_(CUDNN_BATCHNORM_OPS_BN), - is_training_(true), epsilon_(10e-5), exp_avg_factor_(0.1), is_null_input_(false), @@ -89,21 +88,13 @@ class InstanceNormGpuKernel : public GpuKernel { const float alpha = 1; const float beta = 0; float *reserve_addr = nullptr; - if (is_training_) { - CHECK_CUDNN_RET_WITH_EXCEPT( - kernel_node_, - cudnnBatchNormalizationForwardTrainingEx( - handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr, - scale_bias_mean_var_desc_, ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr, - save_variance_addr, nullptr, workspace_addr, workspace_size_, reserve_addr, 0), - "Kernel launch failed"); - } else { - CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, - cudnnBatchNormalizationForwardInference( - handle_, mode_, &alpha, &beta, x_desc_, x_addr, y_desc_, y_addr, - scale_bias_mean_var_desc_, ws_gamma, ws_beta, ws_mean, ws_var, epsilon_), - "Kernel launch failed"); - } + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnBatchNormalizationForwardTrainingEx( + handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr, scale_bias_mean_var_desc_, + ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr, save_variance_addr, nullptr, + workspace_addr, workspace_size_, reserve_addr, 0), + "Kernel launch failed"); return true; } @@ -114,8 +105,7 @@ class InstanceNormGpuKernel : public GpuKernel { bn_ops_ = CUDNN_BATCHNORM_OPS_BN; InitResource(); - is_training_ = GetAttr(kernel_node, "is_training"); - mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL; + mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; epsilon_ = GetAttr(kernel_node, "epsilon"); exp_avg_factor_ = GetAttr(kernel_node, "momentum"); @@ -220,7 +210,6 @@ class InstanceNormGpuKernel : public GpuKernel { size_t workspace_size_; cudnnBatchNormMode_t mode_; cudnnBatchNormOps_t bn_ops_; - bool is_training_; double epsilon_; double exp_avg_factor_; bool is_null_input_; diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 0befa723071..54869e864e7 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -894,11 +894,7 @@ class InstanceNorm2d(Cell): \gamma and \beta are learnable parameter vectors of size num_features if affine is True. The standard-deviation is calculated via the biased estimator. - By default, this layer uses instance statistics computed from input data in both training and evaluation modes. 
-
-    If use_batch_statistics is set to True, it means training phases, and this layer keeps running estimates of its
-    computed mean and variance, which are then used for normalization during evaluation. The running estimates are
-    kept with a default momentum of 0.1.
+    This layer uses instance statistics computed from input data in both training and evaluation modes.
 
     InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each
     channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data.
@@ -918,12 +914,6 @@ class InstanceNorm2d(Cell):
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
-        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
-        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
-            use the mean value and variance value of specified value. Default: True.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
@@ -940,12 +930,12 @@ class InstanceNorm2d(Cell):
         TypeError: If `eps` is not a float.
         TypeError: If `momentum` is not a float.
         TypeError: If `affine` is not a bool.
-        TypeError: If the type of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is not same, or if
-            the initialized element type is not float32.
+        TypeError: If the type of `gamma_init`/`beta_init` is not the same, or if the initialized element type is not
+            float32.
         ValueError: If `num_features` is less than 1.
         ValueError: If `momentum` is not in range [0, 1].
-        KeyError: If any of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is str and the homonymous
-            class inheriting from `Initializer` not exists.
+        KeyError: If any of `gamma_init`/`beta_init` is str and the homonymous class inheriting from `Initializer`
+            does not exist.
Examples: >>> import mindspore @@ -966,31 +956,24 @@ class InstanceNorm2d(Cell): momentum=0.1, affine=True, gamma_init='ones', - beta_init='zeros', - moving_mean_init='zeros', - moving_var_init='ones', - use_batch_statistics=True): + beta_init='zeros'): super(InstanceNorm2d, self).__init__() validator.check_value_type('num_features', num_features, [int], self.cls_name) validator.check_value_type('eps', eps, [float], self.cls_name) validator.check_value_type('momentum', momentum, [float], self.cls_name) validator.check_value_type('affine', affine, [bool], self.cls_name) - args_input = {"gamma_init": gamma_init, "beta_init": beta_init, - "moving_mean_init": moving_mean_init, "moving_var_init": moving_var_init} + args_input = {"gamma_init": gamma_init, "beta_init": beta_init} self.check_types_valid(args_input, 'InstanceNorm2d') if num_features < 1: raise ValueError("num_features must be at least 1") if momentum < 0 or momentum > 1: raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) - self.use_batch_statistics = use_batch_statistics self.num_features = num_features self.eps = eps self.input_dims = '2d' - self.moving_mean = Parameter(initializer( - moving_mean_init, num_features), name="mean", requires_grad=False) - self.moving_variance = Parameter(initializer( - moving_var_init, num_features), name="variance", requires_grad=False) + self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False) + self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False) self.gamma = Parameter(initializer( gamma_init, num_features), name="gamma", requires_grad=affine) self.beta = Parameter(initializer( @@ -998,9 +981,7 @@ class InstanceNorm2d(Cell): self.shape = P.Shape() self.momentum = momentum - self.instance_bn = P.InstanceNorm(is_training=self.use_batch_statistics, - epsilon=self.eps, - momentum=self.momentum) + self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum) def _check_data_dim(self, x): raise NotImplementedError diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 9697b75e638..b1cff415933 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -867,7 +867,6 @@ class InstanceNorm(PrimitiveWithInfer): where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. Args: - is_training (bool): Is training or inference. Default: True. epsilon (float): A small value added for numerical stability. Default: 1e-5. momentum (float): The hyper parameter to compute moving average for running_mean and running_var (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`). 
@@ -934,10 +933,9 @@ class InstanceNorm(PrimitiveWithInfer): ) @prim_attr_register - def __init__(self, is_training=True, epsilon=1e-5, momentum=0.1): + def __init__(self, epsilon=1e-5, momentum=0.1): self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'], outputs=['y', 'save_mean', 'save_variance']) - self.is_training = validator.check_bool(is_training, self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) self._update_parameter = True diff --git a/tests/st/ops/gpu/test_instancenorm2d.py b/tests/st/ops/gpu/test_instancenorm2d.py index bdf9fa548db..1cf8c304de4 100644 --- a/tests/st/ops/gpu/test_instancenorm2d.py +++ b/tests/st/ops/gpu/test_instancenorm2d.py @@ -49,7 +49,7 @@ class Net(nn.Cell): def test_InstanceNorm2d_fp32(): x_np = np.random.randn(3, 3, 2, 2).astype(np.float32) bn_instance_comp = Net(3 * 3) - bn_instance_op = nn.InstanceNorm2d(3, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5) + bn_instance_op = nn.InstanceNorm2d(3, gamma_init=0.5, beta_init=0.5) comp_out = bn_instance_comp(Tensor(x_np)) op_out = bn_instance_op(Tensor(x_np)) assert np.allclose(comp_out.asnumpy(), op_out.asnumpy())
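
---

A minimal usage sketch of the InstanceNorm2d API after this patch, for trying the change
locally. The layer name, parameter names, and defaults come from the diff above; the device
setting is an assumption (the kernel touched above is GPU-only), and the input shape and 0.5
initializers are illustrative values borrowed from the updated test, not part of this change.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    # Assumption: a MindSpore GPU build, since the modified kernel is the GPU one.
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    # use_batch_statistics, moving_mean_init, and moving_var_init are gone from the
    # constructor; instance statistics are now always computed from the input, in
    # training and evaluation alike (the GPU kernel always takes the cuDNN training path).
    net = nn.InstanceNorm2d(3, eps=1e-5, momentum=0.1, gamma_init=0.5, beta_init=0.5)

    x = Tensor(np.random.randn(3, 3, 2, 2).astype(np.float32))
    y = net(x)
    print(y.shape)  # (3, 3, 2, 2)

Callers that previously passed use_batch_statistics, as the old test did, simply drop the
argument; P.InstanceNorm likewise loses is_training and keeps only epsilon and momentum.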