From a84affffd7ac08afc9d433bdb2b19b7b3fc3d787 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Fri, 16 Oct 2020 11:20:20 +0800 Subject: [PATCH] add QuantConfig & modify quant cells --- mindspore/nn/layer/quant.py | 458 +++++------------- mindspore/train/quant/quant.py | 129 +++-- .../models/resnet_quant_manual.py | 31 +- .../resnet50_quant/resnet_quant_manual.py | 41 +- 4 files changed, 231 insertions(+), 428 deletions(-) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 22d49598a53..17ae5bcd52c 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -15,6 +15,7 @@ """Quantization aware training.""" from functools import partial +from collections import namedtuple import numpy as np from mindspore import nn import mindspore.common.dtype as mstype @@ -34,7 +35,7 @@ from ...ops.operations import _quant_ops as Q __all__ = [ 'Conv2dBnAct', 'DenseBnAct', - 'FakeQuantWithMinMax', + 'FakeQuantWithMinMaxObserver', 'Conv2dBnFoldQuant', 'Conv2dBnWithoutFoldQuant', 'Conv2dQuant', @@ -422,14 +423,14 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): symmetric=False, narrow_range=False, quant_delay=0): - """Initialize FakeQuantWithMinMax layer""" + """Initialize FakeQuantWithMinMaxObserver""" super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel, symmetric=symmetric, narrow_range=narrow_range, num_channels=num_channels) Validator.check_type("min_init", min_init, [int, float]) Validator.check_type("max_init", max_init, [int, float]) Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) - Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE) + Validator.check_non_negative_int(quant_delay, 'quant_delay') self.min_init = min_init self.max_init = max_init self.quant_dtype = quant_dtype @@ -498,119 +499,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): return out -class FakeQuantWithMinMax(Cell): - r""" - Quantization aware op. This OP provides the fake quantization observer function on data with min and max. +QuantConfig = namedtuple("QuantConfig", ['weight', 'activation']) - Args: - min_init (int, float): The initialized min value. Default: -6. - max_init (int, float): The initialized max value. Default: 6. - ema (bool): The exponential Moving Average algorithm updates min and max. Default: False. - ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - channel_axis (int): Quantization by channel axis. Default: 1. - num_channels (int): declarate the min and max channel size, Default: 1. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. - - Inputs: - - **x** (Tensor) - The input of FakeQuantWithMinMax. - - Outputs: - Tensor, with the same type and shape as the `x`. 
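The renamed FakeQuantWithMinMaxObserver keeps the same call pattern as the cell it replaces: construct it with min/max init values and apply it to a tensor. A minimal usage sketch, not part of the patch, assuming the cell is exposed through mindspore.nn as the model_zoo files later in this patch assume:

    import numpy as np
    import mindspore
    from mindspore import Tensor
    from mindspore.nn import FakeQuantWithMinMaxObserver

    # Per-layer observer with EMA-updated min/max and the default int8 settings.
    fake_quant = FakeQuantWithMinMaxObserver(min_init=-6, max_init=6, ema=True, ema_decay=0.999)
    input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
    result = fake_quant(input_x)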
- - Examples: - >>> fake_quant = FakeQuantWithMinMax() - >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) - >>> result = fake_quant(input_x) - """ - - def __init__(self, - min_init=-6, - max_init=6, - ema=False, - ema_decay=0.999, - per_channel=False, - channel_axis=1, - num_channels=1, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): - """Initialize FakeQuantWithMinMax layer""" - super(FakeQuantWithMinMax, self).__init__() - Validator.check_type("min_init", min_init, [int, float]) - Validator.check_type("max_init", max_init, [int, float]) - Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT) - Validator.check_non_negative_int(quant_delay, 'quant_delay') - self.min_init = min_init - self.max_init = max_init - self.num_bits = num_bits - self.ema = ema - self.ema_decay = ema_decay - self.per_channel = per_channel - self.num_channels = num_channels - self.channel_axis = channel_axis - self.quant_delay = quant_delay - self.symmetric = symmetric - self.narrow_range = narrow_range - self.is_ascend = context.get_context('device_target') == "Ascend" - - # init tensor min and max for fake quant op - if self.per_channel: - min_array = np.array([self.min_init] * self.num_channels).astype(np.float32) - max_array = np.array([self.max_init] * self.num_channels).astype(np.float32) - else: - min_array = np.array([self.min_init]).astype(np.float32) - max_array = np.array([self.max_init]).astype(np.float32) - self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) - self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) - - # init fake quant relative op - if self.per_channel: - quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) - ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis) - else: - quant_fun = Q.FakeQuantPerLayer - ema_fun = Q.MinMaxUpdatePerLayer - - self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay) - if self.is_ascend: - self.fake_quant_train = quant_fun(num_bits=self.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - quant_delay=self.quant_delay) - self.fake_quant_infer = self.fake_quant_train - else: - quant_fun = partial(quant_fun, - ema=self.ema, - ema_decay=ema_decay, - num_bits=self.num_bits, - symmetric=self.symmetric, - narrow_range=self.narrow_range, - quant_delay=self.quant_delay) - self.fake_quant_train = quant_fun(training=True) - self.fake_quant_infer = quant_fun(training=False) - - def extend_repr(self): - s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ - 'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range, - self.ema, self.ema_decay, self.per_channel, - self.channel_axis, self.num_channels, self.quant_delay, - self.min_init, self.max_init) - return s - - def construct(self, x): - if self.training: - min_up, max_up = self.ema_update(x, self.minq, self.maxq) - P.Assign()(self.minq, min_up) - P.Assign()(self.maxq, max_up) - out = self.fake_quant_train(x, self.minq, self.maxq) - else: - out = self.fake_quant_infer(x, self.minq, self.maxq) - return out +quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver) class Conv2dBnFoldQuant(Cell): @@ -641,12 +532,9 @@ class Conv2dBnFoldQuant(Cell): mean vector. Default: 'zeros'. var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the variance vector. Default: 'ones'. 
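The QuantConfig namedtuple introduced in this hunk simply pairs two observer constructors; every quant cell below calls quant_config.weight(...) and quant_config.activation(...) to build its fake-quant ops, and quant_config_default wires both slots to FakeQuantWithMinMaxObserver. A sketch of assembling a non-default config, mirroring the partial_init calls that get_quant_config makes later in this patch (partial_init comes from the observer base class and is not shown in this diff; QuantConfig is a module-level name in mindspore/nn/layer/quant.py rather than part of __all__):

    from mindspore.compression.common import QuantDtype
    from mindspore.nn.layer.quant import FakeQuantWithMinMaxObserver, QuantConfig

    # Pre-bind the observer arguments; the quant cells supply the remaining ones
    # (min_init, max_init, channel_axis, num_channels, ...) when they instantiate it.
    weight_observer = FakeQuantWithMinMaxObserver.partial_init(
        quant_delay=0, quant_dtype=QuantDtype.INT8, per_channel=True,
        symmetric=True, narrow_range=False)
    act_observer = FakeQuantWithMinMaxObserver.partial_init(
        quant_delay=0, quant_dtype=QuantDtype.INT8, per_channel=False,
        symmetric=False, narrow_range=False)
    my_quant_config = QuantConfig(weight=weight_observer, activation=act_observer)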
- fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMax op. Default: True. - per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): The Quantization delay parameters according to the global step. Default: 0. + fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. freeze_bn (int): The quantization freeze BatchNormal op is according to the global step. Default: 100000. Inputs: @@ -680,11 +568,8 @@ class Conv2dBnFoldQuant(Cell): mean_init='zeros', var_init='ones', fake=True, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0, + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8, freeze_bn=100000): """Initialize Conv2dBnFoldQuant layer""" super(Conv2dBnFoldQuant, self).__init__() @@ -699,13 +584,10 @@ class Conv2dBnFoldQuant(Cell): self.eps = eps self.momentum = momentum self.has_bias = has_bias - self.quant_delay = quant_delay self.freeze_bn = freeze_bn self.fake = fake - self.num_bits = num_bits - self.per_channel = per_channel - self.symmetric = symmetric - self.narrow_range = narrow_range + self.quant_config = quant_config + self.quant_dtype = quant_dtype self.is_gpu = context.get_context('device_target') == "GPU" # initialize convolution op and Parameter @@ -745,16 +627,12 @@ class Conv2dBnFoldQuant(Cell): requires_grad=False) # initialize fake ops - self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, + self.fake_quant_weight = quant_config.weight(min_init=-6, max_init=6, ema=False, - per_channel=per_channel, channel_axis=channel_axis, num_channels=out_channels, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + quant_dtype=quant_dtype) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) self.correct_mul = Q.CorrectionMul(channel_axis) if context.get_context('device_target') == "Ascend": @@ -777,7 +655,7 @@ class Conv2dBnFoldQuant(Cell): self.pad_mode, self.padding, self.dilation, self.group, self.fake, self.freeze_bn, self.momentum, - self.quant_delay) + self.fake_quant_weight.quant_delay) return s def construct(self, x): @@ -836,11 +714,8 @@ class Conv2dBnWithoutFoldQuant(Cell): weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. - per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. 
Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -868,11 +743,8 @@ class Conv2dBnWithoutFoldQuant(Cell): momentum=0.997, weight_init='normal', bias_init='zeros', - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(Conv2dBnWithoutFoldQuant, self).__init__() if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size) @@ -886,7 +758,6 @@ class Conv2dBnWithoutFoldQuant(Cell): self.pad_mode = pad_mode self.padding = padding self.group = group - self.quant_delay = quant_delay self.bias_add = P.BiasAdd() if Validator.check_bool(has_bias): @@ -917,16 +788,12 @@ class Conv2dBnWithoutFoldQuant(Cell): weight_shape = [out_channels, in_channels // group, *self.kernel_size] channel_axis = 0 self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') - self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, + self.fake_quant_weight = quant_config.weight(min_init=-6, max_init=6, ema=False, - per_channel=per_channel, channel_axis=channel_axis, num_channels=out_channels, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + quant_dtype=quant_dtype) self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum) def construct(self, x): @@ -942,7 +809,7 @@ class Conv2dBnWithoutFoldQuant(Cell): 'pad_mode={}, padding={}, dilation={}, group={}, ' \ 'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding, self.dilation, self.group, - self.has_bias, self.quant_delay) + self.has_bias, self.fake_quant_weight.quant_delay) return s @@ -966,11 +833,8 @@ class Conv2dQuant(Cell): weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. Default: 'normal'. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. - per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. 
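The same substitution runs through all the quant cells below: the per_channel, num_bits, symmetric, narrow_range and quant_delay arguments are folded into quant_config and quant_dtype. An illustrative construction of Conv2dQuant under the new interface (quant_config_default is the module-level default defined earlier in this patch; channel counts and input shape are arbitrary):

    import numpy as np
    import mindspore
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.compression.common import QuantDtype
    from mindspore.nn.layer.quant import quant_config_default

    # Old style, removed by this patch:
    #   nn.Conv2dQuant(1, 6, kernel_size=3, per_channel=False, num_bits=8)
    # New style: observer types come from quant_config, the bit width from quant_dtype.
    conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=3,
                                  quant_config=quant_config_default,
                                  quant_dtype=QuantDtype.INT8)
    input_x = Tensor(np.random.rand(1, 1, 16, 16), mindspore.float32)
    result = conv2d_quant(input_x)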
@@ -996,11 +860,8 @@ class Conv2dQuant(Cell): has_bias=False, weight_init='normal', bias_init='zeros', - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(Conv2dQuant, self).__init__() if isinstance(kernel_size, int): self.kernel_size = (kernel_size, kernel_size) @@ -1014,7 +875,6 @@ class Conv2dQuant(Cell): self.pad_mode = pad_mode self.padding = padding self.group = group - self.quant_delay = quant_delay weight_shape = [out_channels, in_channels // group, *self.kernel_size] self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') @@ -1033,16 +893,12 @@ class Conv2dQuant(Cell): stride=self.stride, dilation=self.dilation, group=self.group) - self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, + self.fake_quant_weight = quant_config.weight(min_init=-6, max_init=6, ema=False, - per_channel=per_channel, channel_axis=0, num_channels=out_channels, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + quant_dtype=quant_dtype) def construct(self, x): weight = self.fake_quant_weight(self.weight) @@ -1056,7 +912,7 @@ class Conv2dQuant(Cell): 'pad_mode={}, padding={}, dilation={}, group={}, ' \ 'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding, self.dilation, self.group, - self.has_bias, self.quant_delay) + self.has_bias, self.fake_quant_weight.quant_delay) return s @@ -1075,11 +931,8 @@ class DenseQuant(Cell): same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None. - per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. 
@@ -1093,19 +946,15 @@ class DenseQuant(Cell): >>> result = dense_quant(input_x) """ - def __init__( - self, - in_channels, - out_channels, - weight_init='normal', - bias_init='zeros', - has_bias=True, - activation=None, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + def __init__(self, + in_channels, + out_channels, + weight_init='normal', + bias_init='zeros', + has_bias=True, + activation=None, + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(DenseQuant, self).__init__() self.in_channels = Validator.check_positive_int(in_channels) self.out_channels = Validator.check_positive_int(out_channels) @@ -1132,16 +981,12 @@ class DenseQuant(Cell): self.activation = get_activation(activation) self.activation_flag = self.activation is not None - self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, + self.fake_quant_weight = quant_config.weight(min_init=-6, max_init=6, ema=False, - per_channel=per_channel, channel_axis=0, num_channels=out_channels, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + quant_dtype=quant_dtype) def construct(self, x): """Use operators to construct the Dense layer.""" @@ -1179,16 +1024,13 @@ class ActQuant(_QuantActivation): Quantization aware training activation function. Add the fake quant op to the end of activation op, by which the output of activation op will be truncated. - Please check `FakeQuantWithMinMax` for more details. + Please check `FakeQuantWithMinMaxObserver` or other observer for more details. Args: activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global steps. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - The input of ReLU6Quant. @@ -1205,21 +1047,14 @@ class ActQuant(_QuantActivation): def __init__(self, activation, ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(ActQuant, self).__init__() - self.fake_quant_act = FakeQuantWithMinMax(min_init=0, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + self.fake_quant_act = quant_config.activation(min_init=-6, + max_init=6, + ema=False, + ema_decay=ema_decay, + quant_dtype=quant_dtype) self.act = activation def construct(self, x): @@ -1240,11 +1075,8 @@ class LeakyReLUQuant(_QuantActivation): Args: activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. 
- symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - The input of LeakyReLUQuant. @@ -1261,30 +1093,19 @@ class LeakyReLUQuant(_QuantActivation): def __init__(self, activation, ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(LeakyReLUQuant, self).__init__() - self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) - self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + self.fake_quant_act_before = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) + self.fake_quant_act_after = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) if issubclass(activation.__class__, nn.LeakyReLU): self.act = activation else: @@ -1309,11 +1130,8 @@ class HSwishQuant(_QuantActivation): Args: activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - The input of HSwishQuant. 
@@ -1330,30 +1148,19 @@ class HSwishQuant(_QuantActivation): def __init__(self, activation, ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(HSwishQuant, self).__init__() - self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) - self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + self.fake_quant_act_before = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) + self.fake_quant_act_after = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) if issubclass(activation.__class__, nn.HSwish): self.act = activation else: @@ -1378,11 +1185,8 @@ class HSigmoidQuant(_QuantActivation): Args: activation (Cell): Activation cell class. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - The input of HSigmoidQuant. @@ -1399,30 +1203,19 @@ class HSigmoidQuant(_QuantActivation): def __init__(self, activation, ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(HSigmoidQuant, self).__init__() - self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) - self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + self.fake_quant_act_before = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) + self.fake_quant_act_after = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) if issubclass(activation.__class__, nn.HSigmoid): self.act = activation else: @@ -1446,11 +1239,8 @@ class TensorAddQuant(Cell): Args: ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. 
- symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - The input of TensorAddQuant. @@ -1467,21 +1257,14 @@ class TensorAddQuant(Cell): def __init__(self, ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(TensorAddQuant, self).__init__() - self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + self.fake_quant_act = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) self.add = P.TensorAdd() def construct(self, x1, x2): @@ -1498,11 +1281,8 @@ class MulQuant(Cell): Args: ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. - per_channel (bool): Quantization granularity based on layer or on channel. Default: False. - num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8. - symmetric (bool): The quantization algorithm is symmetric or not. Default: False. - narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False. - quant_delay (int): Quantization delay parameters according to the global step. Default: 0. + quant_config (QuantConfig): Configs the oberser type of weight and activation. Default: quant_config_default. + quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8. Inputs: - **x** (Tensor) - The input of MulQuant. @@ -1510,25 +1290,23 @@ class MulQuant(Cell): Outputs: Tensor, with the same type and shape as the `x`. 
+ Examples: + >>> mul_quant = nn.MulQuant() + >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32) + >>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32) + >>> result = mul_quant(input_x, input_y) """ def __init__(self, ema_decay=0.999, - per_channel=False, - num_bits=8, - symmetric=False, - narrow_range=False, - quant_delay=0): + quant_config=quant_config_default, + quant_dtype=QuantDtype.INT8): super(MulQuant, self).__init__() - self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - ema_decay=ema_decay, - per_channel=per_channel, - num_bits=num_bits, - symmetric=symmetric, - narrow_range=narrow_range, - quant_delay=quant_delay) + self.fake_quant_act = quant_config.activation(min_init=-6, + max_init=6, + ema=True, + ema_decay=ema_decay, + quant_dtype=quant_dtype) self.mul = P.Mul() def construct(self, x1, x2): diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index f9b6aa2d761..6fcf6d76097 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -27,6 +27,7 @@ from ...common import Tensor from ...common import dtype as mstype from ...common.api import _executor from ...nn.layer import quant +from ...compression.common import QuantDtype from ...ops import functional as F from ...ops import operations as P from ...ops.operations import _inner_ops as inner @@ -41,6 +42,46 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant, nn.HSwish: quant.HSwishQuant} +def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver), + quant_delay=(0, 0), + quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), + per_channel=(False, False), + symmetric=(False, False), + narrow_range=(False, False) + ): + r""" + Configs the oberser type of weights and data flow with quant params. + + Args: + quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent + weights and second element represent data flow. + Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver) + quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during + eval. The first element represent weights and second element represent data flow. Default: (0, 0) + quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first + element represent weights and second element represent data flow. + Default: (QuantDtype.INT8, QuantDtype.INT8) + per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` + then base on per channel otherwise base on per layer. The first element represent weights + and second element represent data flow. Default: (False, False) + symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on + symmetric otherwise base on asymmetric. The first element represent weights and second + element represent data flow. Default: (False, False) + narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not. + The first element represents weights and the second element represents data flow. Default: (False, False) + + Returns: + QuantConfig, Contains the oberser type of weight and activation. 
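get_quant_config is the intended front end for building a QuantConfig from the flat per-tensor settings: each tuple's first element configures the weight observer and the second configures the activation (data-flow) observer. A usage sketch matching the call the resnet_quant_manual.py files make later in this patch:

    from mindspore.compression.common import QuantDtype
    from mindspore.train.quant import quant

    # Per-channel symmetric fake-quant for weights, per-layer asymmetric for activations.
    quant_config = quant.get_quant_config(per_channel=(True, False),
                                          symmetric=(True, False),
                                          quant_dtype=(QuantDtype.INT8, QuantDtype.INT8))
    # quant_config.weight and quant_config.activation are observer constructors that
    # the quant cells (Conv2dBnFoldQuant, DenseQuant, ActQuant, ...) call internally.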
+ """ + weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0], + per_channel=per_channel[0], symmetric=symmetric[0], + narrow_range=narrow_range[0]) + act_observer = quant_observer[0].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1], + per_channel=per_channel[-1], symmetric=symmetric[-1], + narrow_range=narrow_range[-1]) + return quant.QuantConfig(weight=weight_observer, activation=act_observer) + + class _AddFakeQuantInput(nn.Cell): """ Add FakeQuant OP at input of the network. Only support one input case. @@ -48,7 +89,8 @@ class _AddFakeQuantInput(nn.Cell): def __init__(self, network, quant_delay=0): super(_AddFakeQuantInput, self).__init__(auto_prefix=False) - self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True) + self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6, + quant_delay=quant_delay, ema=True) self.fake_quant_input.update_parameters_name('fake_quant_input.') self.network = network @@ -66,14 +108,14 @@ class _AddFakeQuantAfterSubCell(nn.Cell): def __init__(self, subcell, **kwargs): super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) self.subcell = subcell - self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6, - max_init=6, - ema=True, - num_bits=kwargs["num_bits"], - quant_delay=kwargs["quant_delay"], - per_channel=kwargs["per_channel"], - symmetric=kwargs["symmetric"], - narrow_range=kwargs["narrow_range"]) + self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6, + max_init=6, + ema=True, + quant_dtype=kwargs["quant_dtype"], + quant_delay=kwargs["quant_delay"], + per_channel=kwargs["per_channel"], + symmetric=kwargs["symmetric"], + narrow_range=kwargs["narrow_range"]) def construct(self, *data): output = self.subcell(*data) @@ -93,8 +135,8 @@ class ConvertToQuantNetwork: self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay") self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold") self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn") - self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit") - self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit") + self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype) + self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype) self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel") self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel") self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric") @@ -103,6 +145,11 @@ class ConvertToQuantNetwork: self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range") self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, quant.DenseBnAct: self._convert_dense} + self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"], + quant_dtype=kwargs["quant_dtype"], + per_channel=kwargs["per_channel"], + symmetric=kwargs["symmetric"], + narrow_range=kwargs["narrow_range"]) def _convert_op_name(self, name): pattern = re.compile(r'([A-Z]{1})') @@ -149,7 +196,7 @@ class ConvertToQuantNetwork: for name, prim_op in add_list: prefix = name add_quant = _AddFakeQuantAfterSubCell(prim_op, - num_bits=self.act_bits, + quant_dtype=self.act_dtype, quant_delay=self.act_qdelay, 
per_channel=self.act_channel, symmetric=self.act_symmetric, @@ -180,15 +227,12 @@ class ConvertToQuantNetwork: group=conv_inner.group, eps=bn_inner.eps, momentum=bn_inner.momentum, - quant_delay=self.weight_qdelay, - freeze_bn=self.freeze_bn, - per_channel=self.weight_channel, - num_bits=self.weight_bits, - fake=True, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range, has_bias=conv_inner.has_bias, - bias_init=conv_inner.bias_init) + bias_init=conv_inner.bias_init, + freeze_bn=self.freeze_bn, + quant_config=self.quant_config, + quant_dtype=self.weight_dtype, + fake=True) # change original network BatchNormal OP parameters to quant network conv_inner.gamma = subcell.batchnorm.gamma conv_inner.beta = subcell.batchnorm.beta @@ -209,13 +253,10 @@ class ConvertToQuantNetwork: group=conv_inner.group, eps=bn_inner.eps, momentum=bn_inner.momentum, - quant_delay=self.weight_qdelay, - per_channel=self.weight_channel, - num_bits=self.weight_bits, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range, has_bias=conv_inner.has_bias, - bias_init=conv_inner.bias_init) + bias_init=conv_inner.bias_init, + quant_config=self.quant_config, + quant_dtype=self.weight_dtype) # change original network BatchNormal OP parameters to quant network conv_inner.batchnorm.gamma = subcell.batchnorm.gamma conv_inner.batchnorm.beta = subcell.batchnorm.beta @@ -234,11 +275,8 @@ class ConvertToQuantNetwork: dilation=conv_inner.dilation, group=conv_inner.group, has_bias=conv_inner.has_bias, - quant_delay=self.weight_qdelay, - per_channel=self.weight_channel, - num_bits=self.weight_bits, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range) + quant_config=self.quant_config, + quant_dtype=self.weight_dtype) # change original network Conv2D OP parameters to quant network conv_inner.weight = subcell.conv.weight if subcell.conv.has_bias: @@ -249,7 +287,7 @@ class ConvertToQuantNetwork: elif subcell.after_fake: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, - num_bits=self.act_bits, + quant_dtype=self.act_dtype, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, @@ -264,11 +302,8 @@ class ConvertToQuantNetwork: dense_inner = quant.DenseQuant(dense_inner.in_channels, dense_inner.out_channels, has_bias=dense_inner.has_bias, - num_bits=self.weight_bits, - quant_delay=self.weight_qdelay, - per_channel=self.weight_channel, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range) + quant_config=self.quant_config, + quant_dtype=self.weight_dtype) # change original network Dense OP parameters to quant network dense_inner.weight = subcell.dense.weight if subcell.dense.has_bias: @@ -279,7 +314,7 @@ class ConvertToQuantNetwork: elif subcell.after_fake: subcell.has_act = True subcell.activation = _AddFakeQuantAfterSubCell(F.identity, - num_bits=self.act_bits, + quant_dtype=self.act_dtype, quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, @@ -291,11 +326,8 @@ class ConvertToQuantNetwork: if act_class not in _ACTIVATION_MAP: raise ValueError("Unsupported activation in auto quant: ", act_class) return _ACTIVATION_MAP[act_class](activation=activation, - num_bits=self.act_bits, - quant_delay=self.act_qdelay, - per_channel=self.act_channel, - symmetric=self.act_symmetric, - narrow_range=self.act_range) + quant_config=self.quant_config, + quant_dtype=self.act_dtype) class ExportToQuantInferNetwork: @@ -523,7 +555,7 @@ def convert_quant_network(network, bn_fold=True, freeze_bn=10000000, 
quant_delay=(0, 0), - num_bits=(8, 8), + quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), per_channel=(False, False), symmetric=(False, False), narrow_range=(False, False) @@ -537,8 +569,9 @@ def convert_quant_network(network, freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7. quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during eval. The first element represent weights and second element represent data flow. Default: (0, 0) - num_bits (int, list or tuple): Number of bits to use for quantize weights and activations. The first - element represent weights and second element represent data flow. Default: (8, 8) + quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first + element represent weights and second element represent data flow. + Default: (QuantDtype.INT8, QuantDtype.INT8) per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` then base on per channel otherwise base on per layer. The first element represent weights and second element represent data flow. Default: (False, False) @@ -561,7 +594,7 @@ def convert_quant_network(network, return value quant_delay = convert2list("quant delay", quant_delay) - num_bits = convert2list("num bits", num_bits) + quant_dtype = convert2list("quant dtype", quant_dtype) per_channel = convert2list("per channel", per_channel) symmetric = convert2list("symmetric", symmetric) narrow_range = convert2list("narrow range", narrow_range) @@ -573,7 +606,7 @@ def convert_quant_network(network, quant_delay=quant_delay, bn_fold=bn_fold, freeze_bn=freeze_bn, - num_bits=num_bits, + quant_dtype=quant_dtype, per_channel=per_channel, symmetric=symmetric, narrow_range=narrow_range) diff --git a/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py b/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py index 8957ca9322b..63da3cfaa8d 100644 --- a/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py +++ b/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py @@ -17,12 +17,14 @@ import numpy as np import mindspore.nn as nn from mindspore.ops import operations as P from mindspore import Tensor -from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant +from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant +from mindspore.train.quant import quant _ema_decay = 0.999 _symmetric = True _fake = True _per_channel = True +_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) def _weight_variable(shape, factor=0.01): @@ -89,7 +91,7 @@ class ConvBNReLU(nn.Cell): super(ConvBNReLU, self).__init__() padding = (kernel_size - 1) // 2 conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, - group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric) + group=groups, fake=_fake, quant_config=_quant_config) layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] self.features = nn.SequentialCell(layers) @@ -124,13 +126,12 @@ class ResidualBlock(nn.Cell): channel = out_channel // self.expansion self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) - self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, 
per_channel=_per_channel, - symmetric=_symmetric, + self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, + quant_config=_quant_config, kernel_size=1, stride=1, pad_mode='same', padding=0), - FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False) + FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, - per_channel=_per_channel, - symmetric=_symmetric, + quant_config=_quant_config, kernel_size=1, stride=1, pad_mode='same', padding=0) @@ -142,16 +143,15 @@ class ResidualBlock(nn.Cell): if self.down_sample: self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, - per_channel=_per_channel, - symmetric=_symmetric, + quant_config=_quant_config, kernel_size=1, stride=stride, pad_mode='same', padding=0), - FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, - symmetric=False) + FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, + symmetric=False) ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, fake=_fake, - per_channel=_per_channel, - symmetric=_symmetric, + quant_config=\ + _quant_config, kernel_size=1, stride=stride, pad_mode='same', @@ -235,9 +235,8 @@ class ResNet(nn.Cell): self.mean = P.ReduceMean(keep_dims=True) self.flatten = nn.Flatten() - self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel, - symmetric=_symmetric) - self.output_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay) + self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config) + self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay) def _make_layer(self, block, layer_num, in_channel, out_channel, stride): """ diff --git a/tests/st/quantization/resnet50_quant/resnet_quant_manual.py b/tests/st/quantization/resnet50_quant/resnet_quant_manual.py index 0298971c03f..9bb6cbea0fb 100644 --- a/tests/st/quantization/resnet50_quant/resnet_quant_manual.py +++ b/tests/st/quantization/resnet50_quant/resnet_quant_manual.py @@ -13,20 +13,19 @@ # limitations under the License. 
# ============================================================================ """ResNet.""" - import numpy as np - import mindspore.nn as nn import mindspore.common.initializer as weight_init -from mindspore import Tensor from mindspore.ops import operations as P -from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant - +from mindspore import Tensor +from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant +from mindspore.train.quant import quant _ema_decay = 0.999 _symmetric = True _fake = True _per_channel = True +_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) def _weight_variable(shape, factor=0.01): @@ -93,7 +92,7 @@ class ConvBNReLU(nn.Cell): super(ConvBNReLU, self).__init__() padding = (kernel_size - 1) // 2 conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, - group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric) + group=groups, fake=_fake, quant_config=_quant_config) layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] self.features = nn.SequentialCell(layers) @@ -128,14 +127,12 @@ class ResidualBlock(nn.Cell): channel = out_channel // self.expansion self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) - self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel, - symmetric=_symmetric, + self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, + quant_config=_quant_config, kernel_size=1, stride=1, pad_mode='same', padding=0), - FakeQuantWithMinMax( - ema=True, ema_decay=_ema_decay, symmetric=False) + FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, - per_channel=_per_channel, - symmetric=_symmetric, + quant_config=_quant_config, kernel_size=1, stride=1, pad_mode='same', padding=0) @@ -147,16 +144,15 @@ class ResidualBlock(nn.Cell): if self.down_sample: self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, - per_channel=_per_channel, - symmetric=_symmetric, + quant_config=_quant_config, kernel_size=1, stride=stride, pad_mode='same', padding=0), - FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, - symmetric=False) + FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, + symmetric=False) ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, fake=_fake, - per_channel=_per_channel, - symmetric=_symmetric, + quant_config=\ + _quant_config, kernel_size=1, stride=stride, pad_mode='same', @@ -212,8 +208,7 @@ class ResNet(nn.Cell): super(ResNet, self).__init__() if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: - raise ValueError( - "the length of layer_num, in_channels, out_channels list must be 4!") + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") @@ -241,10 +236,8 @@ class ResNet(nn.Cell): self.mean = P.ReduceMean(keep_dims=True) self.flatten = nn.Flatten() - self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel, - symmetric=_symmetric) - self.output_fake = nn.FakeQuantWithMinMax( - ema=True, ema_decay=_ema_decay) + 
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config) + self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay) # init weights self._initialize_weights()
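Taken together, the manual-quantization pattern used by both resnet_quant_manual.py files reduces to building one QuantConfig up front and handing it to every quant cell. A condensed sketch of that pattern (channel and class counts are illustrative only):

    import mindspore.nn as nn
    from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
    from mindspore.train.quant import quant

    _ema_decay = 0.999
    _quant_config = quant.get_quant_config(per_channel=(True, False), symmetric=(True, False))

    # Conv + folded BN with per-channel weight fake-quant, followed by an output
    # observer, as in the ResidualBlock.conv3 branch above.
    block = nn.SequentialCell([
        Conv2dBatchNormQuant(64, 256, kernel_size=1, stride=1, pad_mode='same', padding=0,
                             fake=True, quant_config=_quant_config),
        FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False),
    ])
    # Quantized classifier head, as in ResNet.end_point above.
    end_point = nn.DenseQuant(256, 10, has_bias=True, quant_config=_quant_config)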