remove use_batch_statistics from InstanceNorm API

tom__chen 2021-04-22 15:40:25 -04:00
parent d0978f709a
commit 6f1c12e4d7
4 changed files with 20 additions and 52 deletions

View File

@@ -38,7 +38,6 @@ class InstanceNormGpuKernel : public GpuKernel {
         workspace_size_(0),
         mode_(CUDNN_BATCHNORM_SPATIAL),
         bn_ops_(CUDNN_BATCHNORM_OPS_BN),
-        is_training_(true),
         epsilon_(10e-5),
         exp_avg_factor_(0.1),
         is_null_input_(false),
@@ -89,21 +88,13 @@ class InstanceNormGpuKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
     float *reserve_addr = nullptr;
-    if (is_training_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnBatchNormalizationForwardTrainingEx(
-          handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr,
-          scale_bias_mean_var_desc_, ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr,
-          save_variance_addr, nullptr, workspace_addr, workspace_size_, reserve_addr, 0),
-        "Kernel launch failed");
-    } else {
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                  cudnnBatchNormalizationForwardInference(
-                                    handle_, mode_, &alpha, &beta, x_desc_, x_addr, y_desc_, y_addr,
-                                    scale_bias_mean_var_desc_, ws_gamma, ws_beta, ws_mean, ws_var, epsilon_),
-                                  "Kernel launch failed");
-    }
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnBatchNormalizationForwardTrainingEx(
+        handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr, scale_bias_mean_var_desc_,
+        ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr, save_variance_addr, nullptr,
+        workspace_addr, workspace_size_, reserve_addr, 0),
+      "Kernel launch failed");
     return true;
   }
@@ -114,8 +105,7 @@ class InstanceNormGpuKernel : public GpuKernel {
     bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
     InitResource();
-    is_training_ = GetAttr<bool>(kernel_node, "is_training");
-    mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
+    mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");
     exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");
@@ -220,7 +210,6 @@ class InstanceNormGpuKernel : public GpuKernel {
   size_t workspace_size_;
   cudnnBatchNormMode_t mode_;
   cudnnBatchNormOps_t bn_ops_;
-  bool is_training_;
   double epsilon_;
   double exp_avg_factor_;
   bool is_null_input_;
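
With the inference branch gone, the kernel always takes the cuDNN training path, which matches the semantics of instance normalization: statistics are computed from the current input in both training and evaluation. As a reference for what the kernel computes per sample and per channel, here is a minimal NumPy sketch (illustrative only; instance_norm_ref is not part of this codebase):

import numpy as np

def instance_norm_ref(x, gamma, beta, eps=1e-5):
    # x has shape (N, C, H, W); mean/variance are taken over the spatial
    # axes only, so every sample and channel is normalized independently.
    # np.var defaults to the biased estimator, matching the biased-estimator
    # note in the layer docstring in the next file.
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)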

View File

@@ -894,11 +894,7 @@ class InstanceNorm2d(Cell):
     \gamma and \beta are learnable parameter vectors of size num_features if affine is True. The standard-deviation
     is calculated via the biased estimator.

-    By default, this layer uses instance statistics computed from input data in both training and evaluation modes.
-    If use_batch_statistics is set to True, it means training phases, and this layer keeps running estimates of its
-    computed mean and variance, which are then used for normalization during evaluation. The running estimates are
-    kept with a default momentum of 0.1.
+    This layer uses instance statistics computed from input data in both training and evaluation modes.

     InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each
     channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data.
@@ -918,12 +914,6 @@ class InstanceNorm2d(Cell):
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
-        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
-        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
-            use the mean value and variance value of specified value. Default: True.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
@@ -940,12 +930,12 @@ class InstanceNorm2d(Cell):
         TypeError: If `eps` is not a float.
         TypeError: If `momentum` is not a float.
         TypeError: If `affine` is not a bool.
-        TypeError: If the type of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is not same, or if
-            the initialized element type is not float32.
+        TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
+            float32.
         ValueError: If `num_features` is less than 1.
         ValueError: If `momentum` is not in range [0, 1].
-        KeyError: If any of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is str and the homonymous
-            class inheriting from `Initializer` not exists.
+        KeyError: If any of `gamma_init`/`beta_init` is str and the homonymous class inheriting from `Initializer` not
+            exists.

     Examples:
         >>> import mindspore
@@ -966,31 +956,24 @@
                  momentum=0.1,
                  affine=True,
                  gamma_init='ones',
-                 beta_init='zeros',
-                 moving_mean_init='zeros',
-                 moving_var_init='ones',
-                 use_batch_statistics=True):
+                 beta_init='zeros'):
         super(InstanceNorm2d, self).__init__()
         validator.check_value_type('num_features', num_features, [int], self.cls_name)
         validator.check_value_type('eps', eps, [float], self.cls_name)
         validator.check_value_type('momentum', momentum, [float], self.cls_name)
         validator.check_value_type('affine', affine, [bool], self.cls_name)
-        args_input = {"gamma_init": gamma_init, "beta_init": beta_init,
-                      "moving_mean_init": moving_mean_init, "moving_var_init": moving_var_init}
+        args_input = {"gamma_init": gamma_init, "beta_init": beta_init}
        self.check_types_valid(args_input, 'InstanceNorm2d')
         if num_features < 1:
             raise ValueError("num_features must be at least 1")
         if momentum < 0 or momentum > 1:
             raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
-        self.use_batch_statistics = use_batch_statistics
         self.num_features = num_features
         self.eps = eps
         self.input_dims = '2d'
-        self.moving_mean = Parameter(initializer(
-            moving_mean_init, num_features), name="mean", requires_grad=False)
-        self.moving_variance = Parameter(initializer(
-            moving_var_init, num_features), name="variance", requires_grad=False)
+        self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
+        self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
         self.gamma = Parameter(initializer(
             gamma_init, num_features), name="gamma", requires_grad=affine)
         self.beta = Parameter(initializer(
@@ -998,9 +981,7 @@ class InstanceNorm2d(Cell):
         self.shape = P.Shape()
         self.momentum = momentum
-        self.instance_bn = P.InstanceNorm(is_training=self.use_batch_statistics,
-                                          epsilon=self.eps,
-                                          momentum=self.momentum)
+        self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)

     def _check_data_dim(self, x):
         raise NotImplementedError
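
After this change the layer is constructed without use_batch_statistics or the moving_* initializers; the running estimates are created internally with fixed 'zeros'/'ones' initializers. A minimal usage sketch of the simplified constructor (assuming a GPU backend, since the kernel touched by this commit is the GPU implementation):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# Construct with the remaining arguments only; input shape and dtype follow
# the Inputs section of the docstring above.
net = nn.InstanceNorm2d(num_features=3, eps=1e-5, momentum=0.1)
x = Tensor(np.random.randn(2, 3, 4, 4).astype(np.float32))
y = net(x)  # normalized output with the same shape, (2, 3, 4, 4)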

View File

@@ -867,7 +867,6 @@ class InstanceNorm(PrimitiveWithInfer):
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

     Args:
-        is_training (bool): Is training or inference. Default: True.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
             (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
@@ -934,10 +933,9 @@ class InstanceNorm(PrimitiveWithInfer):
     )

     @prim_attr_register
-    def __init__(self, is_training=True, epsilon=1e-5, momentum=0.1):
+    def __init__(self, epsilon=1e-5, momentum=0.1):
         self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'],
                                 outputs=['y', 'save_mean', 'save_variance'])
-        self.is_training = validator.check_bool(is_training, self.name)
         self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
         self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
         self._update_parameter = True
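
The momentum argument keeps its docstring semantics; a quick worked instance of the moving-average formula with the default momentum of 0.1 (plain Python, for illustration only):

momentum = 0.1
running_mean, current_mean = 0.0, 2.0
new_running_mean = momentum * running_mean + (1 - momentum) * current_mean
print(new_running_mean)  # 0.1 * 0.0 + 0.9 * 2.0 = 1.8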

View File

@@ -49,7 +49,7 @@ class Net(nn.Cell):
 def test_InstanceNorm2d_fp32():
     x_np = np.random.randn(3, 3, 2, 2).astype(np.float32)
     bn_instance_comp = Net(3 * 3)
-    bn_instance_op = nn.InstanceNorm2d(3, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)
+    bn_instance_op = nn.InstanceNorm2d(3, gamma_init=0.5, beta_init=0.5)
     comp_out = bn_instance_comp(Tensor(x_np))
     op_out = bn_instance_op(Tensor(x_np))
     assert np.allclose(comp_out.asnumpy(), op_out.asnumpy())