forked from mindspore-Ecosystem/mindspore
remove use_batch_statistics from InstanceNorm API
This commit is contained in:
parent d0978f709a
commit 6f1c12e4d7
@@ -38,7 +38,6 @@ class InstanceNormGpuKernel : public GpuKernel {
       workspace_size_(0),
       mode_(CUDNN_BATCHNORM_SPATIAL),
       bn_ops_(CUDNN_BATCHNORM_OPS_BN),
-      is_training_(true),
       epsilon_(10e-5),
       exp_avg_factor_(0.1),
       is_null_input_(false),
@@ -89,21 +88,13 @@ class InstanceNormGpuKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
     float *reserve_addr = nullptr;
-    if (is_training_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnBatchNormalizationForwardTrainingEx(
-          handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr,
-          scale_bias_mean_var_desc_, ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr,
-          save_variance_addr, nullptr, workspace_addr, workspace_size_, reserve_addr, 0),
-        "Kernel launch failed");
-    } else {
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                  cudnnBatchNormalizationForwardInference(
-                                    handle_, mode_, &alpha, &beta, x_desc_, x_addr, y_desc_, y_addr,
-                                    scale_bias_mean_var_desc_, ws_gamma, ws_beta, ws_mean, ws_var, epsilon_),
-                                  "Kernel launch failed");
-    }
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnBatchNormalizationForwardTrainingEx(
+        handle_, mode_, bn_ops_, &alpha, &beta, x_desc_, x_addr, z_desc_, z, y_desc_, y_addr, scale_bias_mean_var_desc_,
+        ws_gamma, ws_beta, exp_avg_factor_, ws_mean, ws_var, epsilon_, save_mean_addr, save_variance_addr, nullptr,
+        workspace_addr, workspace_size_, reserve_addr, 0),
+      "Kernel launch failed");
     return true;
   }

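Note: with the is_training_ branch removed, the kernel unconditionally takes the cuDNN "training" path, which computes normalization statistics from the current input. That is the defining property of instance norm, so the inference path (which applies stored running estimates) no longer has a use here. A minimal NumPy sketch of what the call computes per (sample, channel) slice, with illustrative names, assuming NCHW layout:

    import numpy as np

    def instance_norm_ref(x, gamma, beta, eps=1e-5):
        # Normalize each (n, c) slice over its H*W elements.
        mean = x.mean(axis=(2, 3), keepdims=True)  # per-sample, per-channel mean
        var = x.var(axis=(2, 3), keepdims=True)    # biased variance estimator
        x_hat = (x - mean) / np.sqrt(var + eps)
        return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)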
@@ -114,8 +105,7 @@ class InstanceNormGpuKernel : public GpuKernel {
     bn_ops_ = CUDNN_BATCHNORM_OPS_BN;

     InitResource();
-    is_training_ = GetAttr<bool>(kernel_node, "is_training");
-    mode_ = is_training_ ? CUDNN_BATCHNORM_SPATIAL_PERSISTENT : CUDNN_BATCHNORM_SPATIAL;
+    mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     epsilon_ = GetAttr<float>(kernel_node, "epsilon");
     exp_avg_factor_ = GetAttr<float>(kernel_node, "momentum");

@@ -220,7 +210,6 @@ class InstanceNormGpuKernel : public GpuKernel {
   size_t workspace_size_;
   cudnnBatchNormMode_t mode_;
   cudnnBatchNormOps_t bn_ops_;
-  bool is_training_;
   double epsilon_;
   double exp_avg_factor_;
   bool is_null_input_;
@@ -894,11 +894,7 @@ class InstanceNorm2d(Cell):
     \gamma and \beta are learnable parameter vectors of size num_features if affine is True. The standard-deviation
     is calculated via the biased estimator.

-    By default, this layer uses instance statistics computed from input data in both training and evaluation modes.
-
-    If use_batch_statistics is set to True, it means training phases, and this layer keeps running estimates of its
-    computed mean and variance, which are then used for normalization during evaluation. The running estimates are
-    kept with a default momentum of 0.1.
+    This layer uses instance statistics computed from input data in both training and evaluation modes.

     InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each
     channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data.
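Note: the docstring's contrast boils down to which axes the statistics reduce over. A hedged NumPy sketch (names illustrative):

    import numpy as np

    x = np.random.randn(4, 3, 8, 8).astype(np.float32)  # (N, C, H, W)

    inst_mean = x.mean(axis=(2, 3))      # InstanceNorm2d: one mean per (sample, channel), shape (4, 3)
    batch_mean = x.mean(axis=(0, 2, 3))  # BatchNorm2d: one mean per channel across the batch, shape (3,)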
@@ -918,12 +914,6 @@ class InstanceNorm2d(Cell):
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
             The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
-        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'zeros'.
-        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
-            The values of str refer to the function `initializer` including 'zeros', 'ones', etc. Default: 'ones'.
-        use_batch_statistics (bool): If true, use the mean value and variance value of current batch data. If false,
-            use the mean value and variance value of specified value. Default: True.

     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
@@ -940,12 +930,12 @@ class InstanceNorm2d(Cell):
         TypeError: If `eps` is not a float.
         TypeError: If `momentum` is not a float.
         TypeError: If `affine` is not a bool.
-        TypeError: If the type of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is not same, or if
-            the initialized element type is not float32.
+        TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
+            float32.
         ValueError: If `num_features` is less than 1.
         ValueError: If `momentum` is not in range [0, 1].
-        KeyError: If any of `gamma_init`/`beta_init`/`moving_mean_init`/`moving_var_init` is str and the homonymous
-            class inheriting from `Initializer` not exists.
+        KeyError: If any of `gamma_init`/`beta_init` is str and the homonymous class inheriting from `Initializer` not
+            exists.

     Examples:
         >>> import mindspore
@@ -966,31 +956,24 @@ class InstanceNorm2d(Cell):
                  momentum=0.1,
                  affine=True,
                  gamma_init='ones',
-                 beta_init='zeros',
-                 moving_mean_init='zeros',
-                 moving_var_init='ones',
-                 use_batch_statistics=True):
+                 beta_init='zeros'):
         super(InstanceNorm2d, self).__init__()
         validator.check_value_type('num_features', num_features, [int], self.cls_name)
         validator.check_value_type('eps', eps, [float], self.cls_name)
         validator.check_value_type('momentum', momentum, [float], self.cls_name)
         validator.check_value_type('affine', affine, [bool], self.cls_name)
-        args_input = {"gamma_init": gamma_init, "beta_init": beta_init,
-                      "moving_mean_init": moving_mean_init, "moving_var_init": moving_var_init}
+        args_input = {"gamma_init": gamma_init, "beta_init": beta_init}
         self.check_types_valid(args_input, 'InstanceNorm2d')
         if num_features < 1:
             raise ValueError("num_features must be at least 1")

         if momentum < 0 or momentum > 1:
             raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
-        self.use_batch_statistics = use_batch_statistics
         self.num_features = num_features
         self.eps = eps
         self.input_dims = '2d'
-        self.moving_mean = Parameter(initializer(
-            moving_mean_init, num_features), name="mean", requires_grad=False)
-        self.moving_variance = Parameter(initializer(
-            moving_var_init, num_features), name="variance", requires_grad=False)
+        self.moving_mean = Parameter(initializer('zeros', num_features), name="mean", requires_grad=False)
+        self.moving_variance = Parameter(initializer('ones', num_features), name="variance", requires_grad=False)
         self.gamma = Parameter(initializer(
             gamma_init, num_features), name="gamma", requires_grad=affine)
         self.beta = Parameter(initializer(
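Note: for callers, the visible effect of this hunk is the narrowed constructor. Presumably code that passed the removed keywords now raises TypeError; a hedged before/after sketch:

    import mindspore.nn as nn

    # Before this commit (now a TypeError):
    # net = nn.InstanceNorm2d(3, use_batch_statistics=True, moving_mean_init='zeros')

    # After this commit, only these arguments remain:
    net = nn.InstanceNorm2d(3, eps=1e-5, momentum=0.1, affine=True,
                            gamma_init='ones', beta_init='zeros')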
@@ -998,9 +981,7 @@ class InstanceNorm2d(Cell):

         self.shape = P.Shape()
         self.momentum = momentum
-        self.instance_bn = P.InstanceNorm(is_training=self.use_batch_statistics,
-                                          epsilon=self.eps,
-                                          momentum=self.momentum)
+        self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)

     def _check_data_dim(self, x):
         raise NotImplementedError
@@ -867,7 +867,6 @@ class InstanceNorm(PrimitiveWithInfer):
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.

     Args:
-        is_training (bool): Is training or inference. Default: True.
         epsilon (float): A small value added for numerical stability. Default: 1e-5.
         momentum (float): The hyper parameter to compute moving average for running_mean and running_var
             (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
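Note: the momentum update in the Args section is a plain exponential moving average. A quick numeric sketch (values made up):

    momentum = 0.1
    running_mean, current_mean = 0.0, 2.0
    new_running_mean = momentum * running_mean + (1 - momentum) * current_mean  # 0.1*0.0 + 0.9*2.0 = 1.8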
@@ -934,10 +933,9 @@ class InstanceNorm(PrimitiveWithInfer):
     )

     @prim_attr_register
-    def __init__(self, is_training=True, epsilon=1e-5, momentum=0.1):
+    def __init__(self, epsilon=1e-5, momentum=0.1):
         self.init_prim_io_names(inputs=['x', 'gamma', 'beta', 'mean', 'variance'],
                                 outputs=['y', 'save_mean', 'save_variance'])
-        self.is_training = validator.check_bool(is_training, self.name)
         self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
         self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
         self._update_parameter = True
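Note: after this hunk the primitive takes only epsilon and momentum. A minimal construction sketch, assuming the usual `operations as P` import and the five inputs named in init_prim_io_names:

    from mindspore.ops import operations as P

    instance_bn = P.InstanceNorm(epsilon=1e-5, momentum=0.1)
    # invoked as: y, save_mean, save_variance = instance_bn(x, gamma, beta, mean, variance)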
@@ -49,7 +49,7 @@ class Net(nn.Cell):
 def test_InstanceNorm2d_fp32():
     x_np = np.random.randn(3, 3, 2, 2).astype(np.float32)
     bn_instance_comp = Net(3 * 3)
-    bn_instance_op = nn.InstanceNorm2d(3, use_batch_statistics=True, gamma_init=0.5, beta_init=0.5)
+    bn_instance_op = nn.InstanceNorm2d(3, gamma_init=0.5, beta_init=0.5)
     comp_out = bn_instance_comp(Tensor(x_np))
     op_out = bn_instance_op(Tensor(x_np))
     assert np.allclose(comp_out.asnumpy(), op_out.asnumpy())
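Note: Net(3 * 3) suggests the comparison network relies on the classic reshape identity: instance norm over (N, C, H, W) equals batch norm over the same data viewed as (1, N*C, H, W). A hedged NumPy check of that identity, independent of MindSpore:

    import numpy as np

    x = np.random.randn(3, 3, 2, 2).astype(np.float32)
    eps = 1e-5

    # Instance norm: statistics per (sample, channel) slice.
    inst = (x - x.mean((2, 3), keepdims=True)) / np.sqrt(x.var((2, 3), keepdims=True) + eps)

    # Batch norm on the reshaped view: N*C "channels", batch of one.
    xr = x.reshape(1, 9, 2, 2)
    bn = (xr - xr.mean((0, 2, 3), keepdims=True)) / np.sqrt(xr.var((0, 2, 3), keepdims=True) + eps)

    assert np.allclose(inst, bn.reshape(x.shape), atol=1e-5)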