diff --git a/mindspore/python/mindspore/nn/layer/quant.py b/mindspore/python/mindspore/nn/layer/quant.py
index 345d2857544..00c68190da6 100644
--- a/mindspore/python/mindspore/nn/layer/quant.py
+++ b/mindspore/python/mindspore/nn/layer/quant.py
@@ -759,8 +759,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                          requires_grad=False)

         # initialize fake ops
-        self.fake_quant_weight = quant_config.weight(ema=False,
-                                                     channel_axis=channel_axis,
+        self.fake_quant_weight = quant_config.weight(channel_axis=channel_axis,
                                                      num_channels=out_channels)
         self.freeze_bn = False
         self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps,
@@ -1025,8 +1024,7 @@ class Conv2dBnFoldQuant(Cell):
                                          requires_grad=False)

         # initialize fake ops
-        self.fake_quant_weight = quant_config.weight(ema=False,
-                                                     channel_axis=channel_axis,
+        self.fake_quant_weight = quant_config.weight(channel_axis=channel_axis,
                                                      num_channels=out_channels)
         self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
         self.correct_mul = Q.CorrectionMul(channel_axis)
@@ -1257,8 +1255,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
         weight_shape = [out_channels, in_channels // group, *self.kernel_size]
         channel_axis = 0
         self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
-        self.fake_quant_weight = quant_config.weight(ema=False,
-                                                     channel_axis=channel_axis,
+        self.fake_quant_weight = quant_config.weight(channel_axis=channel_axis,
                                                      num_channels=out_channels,
                                                      quant_dtype=quant_dtype)
         self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
@@ -1402,8 +1399,7 @@ class Conv2dQuant(Cell):
                              dilation=self.dilation,
                              group=self.group)
         channel_axis = 0
-        self.fake_quant_weight = quant_config.weight(ema=False,
-                                                     channel_axis=channel_axis,
+        self.fake_quant_weight = quant_config.weight(channel_axis=channel_axis,
                                                      num_channels=out_channels)

     @classmethod
@@ -1555,8 +1551,7 @@ class DenseQuant(Cell):
                             f"but got {activation}.")
         self.activation_flag = self.activation is not None

-        self.fake_quant_weight = quant_config.weight(ema=False,
-                                                     channel_axis=0,
+        self.fake_quant_weight = quant_config.weight(channel_axis=0,
                                                      num_channels=out_channels)

     @classmethod
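
Context for reviewers: `quant_config.weight` is a partially-applied fake-quant constructor, and `ema` appears to already default to `False` in `FakeQuantWithMinMaxObserver`, so repeating `ema=False` at every call site was redundant. Below is a minimal sketch of the pattern this diff relies on; the `QuantConfig` construction and the observer class here are simplified stand-ins, not the real MindSpore signatures:

```python
from collections import namedtuple
from functools import partial

QuantConfig = namedtuple("QuantConfig", ["weight", "activation"])

class FakeQuantWithMinMaxObserver:
    """Simplified stand-in for the observer; assumes ema defaults to False,
    matching the default the call sites were spelling out explicitly."""
    def __init__(self, ema=False, channel_axis=1, num_channels=1):
        self.ema = ema
        self.channel_axis = channel_axis
        self.num_channels = num_channels

# Weight and activation factories, as a quant config would typically bind them.
quant_config = QuantConfig(weight=partial(FakeQuantWithMinMaxObserver),
                           activation=partial(FakeQuantWithMinMaxObserver, ema=True))

# After this change, call sites rely on the default instead of restating it:
fake_quant_weight = quant_config.weight(channel_axis=0, num_channels=64)
assert fake_quant_weight.ema is False
```

The trade-off of dropping the explicit argument is the usual one for partially-applied factories: if the observer's default for `ema` ever changed, these five call sites would silently pick up the new behavior.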