forked from mindspore-Ecosystem/mindspore
!2194 fix FakeQuantPerLayer/FakeQuantPerLayerGrad symmetric=True calculation error bug
Merge pull request !2194 from 王东旭/master
This commit is contained in:
commit
74c3e15675
|
@ -214,7 +214,7 @@ class BatchNormFoldCell(Cell):
|
|||
Batch normalization folded.
|
||||
|
||||
Args:
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.9.
|
||||
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
|
||||
float32 else 1e-3. Default: 1e-5.
|
||||
freeze_bn (int): Delay in steps at which computation switches from regular batch
|
||||
|
@ -280,6 +280,7 @@ class FakeQuantWithMinMax(Cell):
|
|||
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization by layer or channel. Default: False.
|
||||
channel_axis (int): Quantization by channel axis. Default: 1.
|
||||
out_channels (int): declarate the min and max channel size, Default: 1.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
|
@ -391,17 +392,17 @@ class Conv2dBatchNormQuant(Cell):
|
|||
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
|
||||
padding: (int): Implicit paddings on both sides of the input. Default: 0.
|
||||
eps (int): Parameters for BatchNormal. Default: 1e-5.
|
||||
momentum (int): Parameters for BatchNormal op. Default: 0.9.
|
||||
momentum (int): Parameters for BatchNormal op. Default: 0.997.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
convolution kernel. Default: 'None'.
|
||||
convolution kernel. Default: 'normal'.
|
||||
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
beta vector. Default: 'None'.
|
||||
beta vector. Default: 'zeros'.
|
||||
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
gamma vector. Default: 'None'.
|
||||
gamma vector. Default: 'ones'.
|
||||
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
mean vector. Default: 'None'.
|
||||
mean vector. Default: 'zeros'.
|
||||
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
variance vector. Default: 'None'.
|
||||
variance vector. Default: 'ones'.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
|
||||
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
|
||||
|
@ -434,11 +435,11 @@ class Conv2dBatchNormQuant(Cell):
|
|||
group=1,
|
||||
eps=1e-5,
|
||||
momentum=0.997,
|
||||
weight_init=None,
|
||||
beta_init=None,
|
||||
gamma_init=None,
|
||||
mean_init=None,
|
||||
var_init=None,
|
||||
weight_init='normal',
|
||||
beta_init='zeros',
|
||||
gamma_init='ones',
|
||||
mean_init='zeros',
|
||||
var_init='ones',
|
||||
quant_delay=0,
|
||||
freeze_bn=100000,
|
||||
fake=True,
|
||||
|
@ -477,8 +478,7 @@ class Conv2dBatchNormQuant(Cell):
|
|||
pad=padding,
|
||||
stride=self.stride,
|
||||
dilation=self.dilation)
|
||||
if weight_init is None:
|
||||
weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
|
||||
weight_shape = [1, in_channels, *self.kernel_size]
|
||||
channel_axis = 1
|
||||
else:
|
||||
self.conv = P.Conv2D(out_channel=out_channels,
|
||||
|
@ -488,24 +488,16 @@ class Conv2dBatchNormQuant(Cell):
|
|||
stride=self.stride,
|
||||
dilation=self.dilation,
|
||||
group=group)
|
||||
if weight_init is None:
|
||||
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
|
||||
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
||||
channel_axis = 0
|
||||
self.weight = Parameter(weight_init, name='weight')
|
||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
|
||||
# initialize batchnorm Parameter
|
||||
if gamma_init is None:
|
||||
gamma_init = initializer('ones', [out_channels])
|
||||
self.gamma = Parameter(gamma_init, name='gamma')
|
||||
if beta_init is None:
|
||||
beta_init = initializer('zeros', [out_channels])
|
||||
self.beta = Parameter(beta_init, name='beta')
|
||||
if mean_init is None:
|
||||
mean_init = initializer('zeros', [out_channels])
|
||||
self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False)
|
||||
if var_init is None:
|
||||
var_init = initializer('ones', [out_channels])
|
||||
self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False)
|
||||
self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
|
||||
self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
|
||||
self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
|
||||
self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
|
||||
requires_grad=False)
|
||||
|
||||
# initialize fake ops
|
||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||
|
@ -588,8 +580,8 @@ class Conv2dQuant(Cell):
|
|||
divisible by the number of groups. Default: 1.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||
Default: None.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None.
|
||||
Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
|
@ -619,8 +611,8 @@ class Conv2dQuant(Cell):
|
|||
dilation=1,
|
||||
group=1,
|
||||
has_bias=False,
|
||||
weight_init=None,
|
||||
bias_init=None,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
quant_delay=0,
|
||||
num_bits=8,
|
||||
per_channel=False,
|
||||
|
@ -641,15 +633,14 @@ class Conv2dQuant(Cell):
|
|||
self.group = group
|
||||
self.quant_delay = quant_delay
|
||||
|
||||
if weight_init is None:
|
||||
weight_init = initializer(
|
||||
'normal', [out_channels, in_channels // group, *self.kernel_size])
|
||||
self.weight = Parameter(weight_init, name='weight')
|
||||
if bias_init is None:
|
||||
bias_init = initializer('zeros', [out_channels])
|
||||
if has_bias:
|
||||
self.bias = Parameter(bias_init, name='bias')
|
||||
self.bias_add = P.BiasAdd()
|
||||
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
|
||||
self.bias_add = P.BiasAdd()
|
||||
if check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.conv = P.Conv2D(out_channel=self.out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
|
@ -738,8 +729,8 @@ class DenseQuant(Cell):
|
|||
self.has_bias = check_bool(has_bias)
|
||||
|
||||
if isinstance(weight_init, Tensor):
|
||||
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
|
||||
weight_init.shape[1] != in_channels:
|
||||
if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
|
||||
weight_init.shape()[1] != in_channels:
|
||||
raise ValueError("weight_init shape error")
|
||||
|
||||
self.weight = Parameter(initializer(
|
||||
|
@ -747,7 +738,7 @@ class DenseQuant(Cell):
|
|||
|
||||
if self.has_bias:
|
||||
if isinstance(bias_init, Tensor):
|
||||
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
|
||||
if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
|
||||
raise ValueError("bias_init shape error")
|
||||
|
||||
self.bias = Parameter(initializer(
|
||||
|
|
|
@ -65,7 +65,6 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
|
|||
momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
|
||||
kernel_name="batchnorm_fold"):
|
||||
"""batchnorm_fold TBE op"""
|
||||
momentum = 1.0 - momentum
|
||||
util.check_kernel_name(kernel_name)
|
||||
data_format = data_format.upper()
|
||||
if data_format != "NCHW":
|
||||
|
@ -120,13 +119,12 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
|
|||
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
|
||||
mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
|
||||
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
|
||||
|
||||
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
|
||||
if num == 1:
|
||||
batch_var_scaler = 0.0
|
||||
else:
|
||||
batch_var_scaler = float(num) / (num - 1)
|
||||
batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
|
||||
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
|
||||
batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
|
||||
|
||||
factor = 1.0 - momentum
|
||||
factor_reverse = momentum
|
||||
|
@ -134,7 +132,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
|
|||
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
|
||||
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
|
||||
|
||||
var_mul = te.lang.cce.vmuls(batch_variance, factor)
|
||||
var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
|
||||
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
|
||||
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)
|
||||
|
||||
|
|
|
@ -50,15 +50,16 @@ def _fake_quant_per_layer_tbe():
|
|||
|
||||
|
||||
@fusion_manager.register("fake_quant_per_layer")
|
||||
def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max,
|
||||
def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric,
|
||||
kernel_name="fake_quant_per_layer"):
|
||||
"""FakeQuantPerLayer"""
|
||||
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||
shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
|
||||
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
|
||||
quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
|
||||
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
|
||||
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
|
||||
if symmetric:
|
||||
max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
|
||||
min_val = te.lang.cce.vmuls(max_val, -1.)
|
||||
|
||||
# CalNudge(NudgeMinMax)
|
||||
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
|
||||
|
@ -119,12 +120,8 @@ def fake_quant_per_layer(x, min_val, max_val, y,
|
|||
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||
|
||||
if symmetric:
|
||||
quant_min = 0 - 2 ** (num_bits - 1)
|
||||
quant_max = 2 ** (num_bits - 1) - 1
|
||||
else:
|
||||
quant_min = 0
|
||||
quant_max = 2 ** num_bits - 1
|
||||
quant_min = 0
|
||||
quant_max = 2 ** num_bits - 1
|
||||
if narrow_range:
|
||||
quant_min = quant_min + 1
|
||||
|
||||
|
@ -132,7 +129,7 @@ def fake_quant_per_layer(x, min_val, max_val, y,
|
|||
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||
res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
|
||||
quant_min, quant_max, kernel_name)
|
||||
quant_min, quant_max, symmetric, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
|
|
@ -78,7 +78,7 @@ def _fake_quant_per_layer_grad_tbe():
|
|||
|
||||
|
||||
@fusion_manager.register("fake_quant_per_layer_grad")
|
||||
def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
|
||||
def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, symmetric,
|
||||
kernel_name="fake_quant_per_layer_grad"):
|
||||
"""FakeQuantPerLayerGrad"""
|
||||
shape = te.lang.cce.util.shape_to_list(x.shape)
|
||||
|
@ -88,6 +88,10 @@ def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quan
|
|||
quant_min = te.lang.cce.broadcast(quant_min, shape_min)
|
||||
quant_max = te.lang.cce.broadcast(quant_max, shape_min)
|
||||
|
||||
if symmetric:
|
||||
max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
|
||||
min_val = te.lang.cce.vmuls(max_val, -1.)
|
||||
|
||||
# CalNudge(NudgeMinMax)
|
||||
scale = te.lang.cce.vdiv(te.lang.cce.vsub(
|
||||
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
|
||||
|
@ -142,12 +146,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
|
|||
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
|
||||
shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
|
||||
|
||||
if symmetric:
|
||||
quant_min = 0 - 2 ** (num_bits - 1)
|
||||
quant_max = 2 ** (num_bits - 1) - 1
|
||||
else:
|
||||
quant_min = 0
|
||||
quant_max = 2 ** num_bits - 1
|
||||
quant_min = 0
|
||||
quant_max = 2 ** num_bits - 1
|
||||
if narrow_range:
|
||||
quant_min = quant_min + 1
|
||||
|
||||
|
@ -155,8 +155,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
|
|||
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
|
||||
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
|
||||
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
|
||||
res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min,
|
||||
quant_max, kernel_name)
|
||||
res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data,
|
||||
quant_min, quant_max, symmetric, kernel_name)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res)
|
||||
|
|
|
@ -58,7 +58,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
|||
BiasAdd, Conv2D,
|
||||
DepthwiseConv2dNative,
|
||||
DropoutDoMask, DropoutGrad, Dropout,
|
||||
DropoutGenMask, Flatten, FusedBatchNorm,
|
||||
DropoutGenMask, Flatten, FusedBatchNorm, BNTrainingReduce, BNTrainingUpdate,
|
||||
Gelu, Elu,
|
||||
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss,
|
||||
LogSoftmax,
|
||||
|
@ -76,7 +76,6 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
|||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
|
||||
from . import _quant_ops
|
||||
from ._quant_ops import *
|
||||
from .thor_ops import *
|
||||
|
||||
|
@ -101,6 +100,9 @@ __all__ = [
|
|||
'Conv2D',
|
||||
'Flatten',
|
||||
'MaxPoolWithArgmax',
|
||||
'FusedBatchNorm',
|
||||
'BNTrainingReduce',
|
||||
'BNTrainingUpdate',
|
||||
'BatchNorm',
|
||||
'MaxPool',
|
||||
'TopK',
|
||||
|
@ -311,5 +313,4 @@ __all__ = [
|
|||
"InTopK"
|
||||
]
|
||||
|
||||
__all__.extend(_quant_ops.__all__)
|
||||
__all__.sort()
|
||||
|
|
|
@ -36,7 +36,6 @@ __all__ = ["FakeQuantPerLayer",
|
|||
"BatchNormFold2Grad",
|
||||
"BatchNormFoldD",
|
||||
"BatchNormFoldGradD",
|
||||
"BNTrainingReduce",
|
||||
"BatchNormFold2_D",
|
||||
"BatchNormFold2GradD",
|
||||
"BatchNormFold2GradReduce",
|
||||
|
@ -334,7 +333,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
Batch normalization folded.
|
||||
|
||||
Args:
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.9.
|
||||
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
|
||||
float32 else 1e-3. Default: 1e-5.
|
||||
is_training (bool): In training mode set True, else set False. Default: True.
|
||||
|
@ -366,7 +365,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
channel_axis = 1
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""init batch norm fold layer"""
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||
|
@ -731,32 +730,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
|
|||
return x_type
|
||||
|
||||
|
||||
class BNTrainingReduce(PrimitiveWithInfer):
|
||||
"""
|
||||
reduce sum at axis [0, 2, 3].
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
|
||||
Outputs:
|
||||
- **x_sum** (Tensor) - Tensor has the same shape as x.
|
||||
- **x_square_sum** (Tensor) - Tensor has the same shape as x.
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""init _BNTrainingReduce layer"""
|
||||
self.init_prim_io_names(inputs=['x'],
|
||||
outputs=['x_sum', 'x_square_sum'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
return [x_shape[1]], [x_shape[1]]
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
return x_type, x_type
|
||||
|
||||
|
||||
class BatchNormFold2_D(PrimitiveWithInfer):
|
||||
"""
|
||||
Scale the bias with a correction factor to the long term statistics
|
||||
|
|
|
@ -585,6 +585,50 @@ class FusedBatchNorm(Primitive):
|
|||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
|
||||
|
||||
class BNTrainingReduce(PrimitiveWithInfer):
|
||||
"""
|
||||
reduce sum at axis [0, 2, 3].
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
|
||||
|
||||
Outputs:
|
||||
- **sum** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
- **square_sum** (Tensor) - Tensor of shape :math:`(C,)`.
|
||||
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
|
||||
return ([x_shape[1]], [x_shape[1]])
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
return (x_type, x_type)
|
||||
|
||||
|
||||
class BNTrainingUpdate(PrimitiveWithInfer):
|
||||
"""
|
||||
primitive operator of bn_training_update's register and info descriptor
|
||||
"""
|
||||
@prim_attr_register
|
||||
def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
|
||||
self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
|
||||
outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
|
||||
#self.isRef = validator.check_integer('isRef', isRef, [0, 1], Rel.IN)
|
||||
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, 'BNTrainingUpdate')
|
||||
self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, 'BNTrainingUpdate')
|
||||
|
||||
def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
|
||||
return (x, variance, variance, variance, variance)
|
||||
|
||||
def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
|
||||
return (x, variance, variance, variance, variance)
|
||||
|
||||
|
||||
class BatchNorm(PrimitiveWithInfer):
|
||||
r"""
|
||||
Batch Normalization for input data and updated parameters.
|
||||
|
|
|
@ -28,7 +28,7 @@ context.set_context(device_target='GPU')
|
|||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
self.op = P.BatchNormFold(freeze_bn=10)
|
||||
self.op = P.BatchNormFold(momentum=0.9, freeze_bn=10)
|
||||
|
||||
@ms_function
|
||||
def construct(self, x, mean, variance, current_step):
|
||||
|
@ -40,8 +40,8 @@ def np_result(x, mean, var, momentum, epsilon):
|
|||
np_mean = x.mean(axis=(0, 2, 3))
|
||||
np_var = x.var(axis=(0, 2, 3))
|
||||
n = x.shape[0] * x.shape[2] * x.shape[3]
|
||||
mean_update = momentum * np_mean + (1 - momentum) * mean
|
||||
var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var
|
||||
mean_update = (1 - momentum) * np_mean + momentum * mean
|
||||
var_update = (1 - momentum) * np_var * n / (n - 1) + momentum * var
|
||||
np_var = np.sqrt(np_var + epsilon)
|
||||
delay_mean = mean.copy()
|
||||
delay_std = np.sqrt(var + epsilon)
|
||||
|
|
Loading…
Reference in New Issue