add QuantConfig & modify quant cells

yuchaojie 2020-10-16 11:20:20 +08:00
parent 2b58af0e9d
commit a84affffd7
4 changed files with 231 additions and 428 deletions

View File

@ -15,6 +15,7 @@
"""Quantization aware training."""
from functools import partial
from collections import namedtuple
import numpy as np
from mindspore import nn
import mindspore.common.dtype as mstype
@ -34,7 +35,7 @@ from ...ops.operations import _quant_ops as Q
__all__ = [
'Conv2dBnAct',
'DenseBnAct',
'FakeQuantWithMinMax',
'FakeQuantWithMinMaxObserver',
'Conv2dBnFoldQuant',
'Conv2dBnWithoutFoldQuant',
'Conv2dQuant',
@ -422,14 +423,14 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
symmetric=False,
narrow_range=False,
quant_delay=0):
"""Initialize FakeQuantWithMinMax layer"""
"""Initialize FakeQuantWithMinMaxObserver"""
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
symmetric=symmetric, narrow_range=narrow_range,
num_channels=num_channels)
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init
self.max_init = max_init
self.quant_dtype = quant_dtype
@ -498,119 +499,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
return out
class FakeQuantWithMinMax(Cell):
r"""
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.
Args:
min_init (int, float): The initialized min value. Default: -6.
max_init (int, float): The initialized max value. Default: 6.
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
num_channels (int): declarate the min and max channel size, Default: 1.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> fake_quant = FakeQuantWithMinMax()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
"""
def __init__(self,
min_init=-6,
max_init=6,
ema=False,
ema_decay=0.999,
per_channel=False,
channel_axis=1,
num_channels=1,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
"""Initialize FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
Validator.check_type("min_init", min_init, [int, float])
Validator.check_type("max_init", max_init, [int, float])
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
Validator.check_non_negative_int(quant_delay, 'quant_delay')
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
self.ema = ema
self.ema_decay = ema_decay
self.per_channel = per_channel
self.num_channels = num_channels
self.channel_axis = channel_axis
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.is_ascend = context.get_context('device_target') == "Ascend"
# init tensor min and max for fake quant op
if self.per_channel:
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else:
min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
# init fake quant relative op
if self.per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else:
quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.MinMaxUpdatePerLayer
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend:
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_infer = self.fake_quant_train
else:
quant_fun = partial(quant_fun,
ema=self.ema,
ema_decay=ema_decay,
num_bits=self.num_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range,
quant_delay=self.quant_delay)
self.fake_quant_train = quant_fun(training=True)
self.fake_quant_infer = quant_fun(training=False)
def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.num_channels, self.quant_delay,
self.min_init, self.max_init)
return s
def construct(self, x):
if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else:
out = self.fake_quant_infer(x, self.minq, self.maxq)
return out
QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])
quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)
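The QuantConfig namedtuple only carries the two observer constructors; each quant cell below calls quant_config.weight(...) and quant_config.activation(...) itself when it builds its fake-quant nodes. A minimal sketch of assembling a custom config, assuming the exports shown in this file; functools.partial is used here purely for illustration, whereas the in-tree helper get_quant_config added later in this commit pre-binds options through the observer's partial_init:

from collections import namedtuple
from functools import partial

from mindspore import nn

QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])

# Pre-bind observer options; the quant cell supplies the per-layer arguments
# (min_init/max_init, channel_axis, num_channels, quant_dtype) when it
# instantiates the observers during its own construction.
my_config = QuantConfig(
    weight=partial(nn.FakeQuantWithMinMaxObserver, per_channel=True, symmetric=True),
    activation=nn.FakeQuantWithMinMaxObserver)

conv = nn.Conv2dQuant(3, 16, kernel_size=3, quant_config=my_config)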
class Conv2dBnFoldQuant(Cell):
@ -641,12 +532,9 @@ class Conv2dBnFoldQuant(Cell):
mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'ones'.
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMax op. Default: True.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): The Quantization delay parameters according to the global step. Default: 0.
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
freeze_bn (int): The quantization freeze BatchNormal op is according to the global step. Default: 100000.
Inputs:
@ -680,11 +568,8 @@ class Conv2dBnFoldQuant(Cell):
mean_init='zeros',
var_init='ones',
fake=True,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0,
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8,
freeze_bn=100000):
"""Initialize Conv2dBnFoldQuant layer"""
super(Conv2dBnFoldQuant, self).__init__()
@ -699,13 +584,10 @@ class Conv2dBnFoldQuant(Cell):
self.eps = eps
self.momentum = momentum
self.has_bias = has_bias
self.quant_delay = quant_delay
self.freeze_bn = freeze_bn
self.fake = fake
self.num_bits = num_bits
self.per_channel = per_channel
self.symmetric = symmetric
self.narrow_range = narrow_range
self.quant_config = quant_config
self.quant_dtype = quant_dtype
self.is_gpu = context.get_context('device_target') == "GPU"
# initialize convolution op and Parameter
@ -745,16 +627,12 @@ class Conv2dBnFoldQuant(Cell):
requires_grad=False)
# initialize fake ops
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = Q.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend":
@ -777,7 +655,7 @@ class Conv2dBnFoldQuant(Cell):
self.pad_mode, self.padding, self.dilation,
self.group,
self.fake, self.freeze_bn, self.momentum,
self.quant_delay)
self.fake_quant_weight.quant_delay)
return s
def construct(self, x):
@ -836,11 +714,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -868,11 +743,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
momentum=0.997,
weight_init='normal',
bias_init='zeros',
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(Conv2dBnWithoutFoldQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
@ -886,7 +758,6 @@ class Conv2dBnWithoutFoldQuant(Cell):
self.pad_mode = pad_mode
self.padding = padding
self.group = group
self.quant_delay = quant_delay
self.bias_add = P.BiasAdd()
if Validator.check_bool(has_bias):
@ -917,16 +788,12 @@ class Conv2dBnWithoutFoldQuant(Cell):
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
channel_axis = 0
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
def construct(self, x):
@ -942,7 +809,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
self.has_bias, self.fake_quant_weight.quant_delay)
return s
@ -966,11 +833,8 @@ class Conv2dQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -996,11 +860,8 @@ class Conv2dQuant(Cell):
has_bias=False,
weight_init='normal',
bias_init='zeros',
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size)
@ -1014,7 +875,6 @@ class Conv2dQuant(Cell):
self.pad_mode = pad_mode
self.padding = padding
self.group = group
self.quant_delay = quant_delay
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
@ -1033,16 +893,12 @@ class Conv2dQuant(Cell):
stride=self.stride,
dilation=self.dilation,
group=self.group)
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
def construct(self, x):
weight = self.fake_quant_weight(self.weight)
@ -1056,7 +912,7 @@ class Conv2dQuant(Cell):
'pad_mode={}, padding={}, dilation={}, group={}, ' \
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.quant_delay)
self.has_bias, self.fake_quant_weight.quant_delay)
return s
@ -1075,11 +931,8 @@ class DenseQuant(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -1093,19 +946,15 @@ class DenseQuant(Cell):
>>> result = dense_quant(input_x)
"""
def __init__(
self,
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(DenseQuant, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
@ -1132,16 +981,12 @@ class DenseQuant(Cell):
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_weight = quant_config.weight(min_init=-6,
max_init=6,
ema=False,
per_channel=per_channel,
channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
def construct(self, x):
"""Use operators to construct the Dense layer."""
@ -1179,16 +1024,13 @@ class ActQuant(_QuantActivation):
Quantization aware training activation function.
Add the fake quant op to the end of activation op, by which the output of activation op will be truncated.
Please check `FakeQuantWithMinMax` for more details.
Please check `FakeQuantWithMinMaxObserver` or other observers for more details.
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global steps. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - The input of ReLU6Quant.
@ -1205,21 +1047,14 @@ class ActQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(ActQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema=False,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
self.act = activation
def construct(self, x):
@ -1240,11 +1075,8 @@ class LeakyReLUQuant(_QuantActivation):
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - The input of LeakyReLUQuant.
@ -1261,30 +1093,19 @@ class LeakyReLUQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(LeakyReLUQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_act_before = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
quant_dtype=quant_dtype)
self.fake_quant_act_after = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
if issubclass(activation.__class__, nn.LeakyReLU):
self.act = activation
else:
@ -1309,11 +1130,8 @@ class HSwishQuant(_QuantActivation):
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - The input of HSwishQuant.
@ -1330,30 +1148,19 @@ class HSwishQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(HSwishQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_act_before = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
quant_dtype=quant_dtype)
self.fake_quant_act_after = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
if issubclass(activation.__class__, nn.HSwish):
self.act = activation
else:
@ -1378,11 +1185,8 @@ class HSigmoidQuant(_QuantActivation):
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - The input of HSigmoidQuant.
@ -1399,30 +1203,19 @@ class HSigmoidQuant(_QuantActivation):
def __init__(self,
activation,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_act_before = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
quant_dtype=quant_dtype)
self.fake_quant_act_after = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
if issubclass(activation.__class__, nn.HSigmoid):
self.act = activation
else:
@ -1446,11 +1239,8 @@ class TensorAddQuant(Cell):
Args:
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - The input of TensorAddQuant.
@ -1467,21 +1257,14 @@ class TensorAddQuant(Cell):
def __init__(self,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(TensorAddQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
self.add = P.TensorAdd()
def construct(self, x1, x2):
@ -1498,11 +1281,8 @@ class MulQuant(Cell):
Args:
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
Inputs:
- **x** (Tensor) - The input of MulQuant.
@ -1510,25 +1290,23 @@ class MulQuant(Cell):
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> mul_quant = nn.MulQuant()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
>>> result = mul_quant(input_x, input_y)
"""
def __init__(self,
ema_decay=0.999,
per_channel=False,
num_bits=8,
symmetric=False,
narrow_range=False,
quant_delay=0):
quant_config=quant_config_default,
quant_dtype=QuantDtype.INT8):
super(MulQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
self.fake_quant_act = quant_config.activation(min_init=-6,
max_init=6,
ema=True,
ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric,
narrow_range=narrow_range,
quant_delay=quant_delay)
quant_dtype=quant_dtype)
self.mul = P.Mul()
def construct(self, x1, x2):
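For the activation-style cells (ActQuant, TensorAddQuant, MulQuant and friends) only the activation observer of the config is used. A short usage sketch under the defaults above; the input values are illustrative and mirror the docstring examples in this file:

import numpy as np
import mindspore
from mindspore import Tensor, nn

add_quant = nn.TensorAddQuant()      # picks up quant_config_default
act_quant = nn.ActQuant(nn.ReLU())   # fake-quantizes the ReLU output

x1 = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
x2 = Tensor(np.ones((2, 3)), mindspore.float32)
out = act_quant(add_quant(x1, x2))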

View File

@ -27,6 +27,7 @@ from ...common import Tensor
from ...common import dtype as mstype
from ...common.api import _executor
from ...nn.layer import quant
from ...compression.common import QuantDtype
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations import _inner_ops as inner
@ -41,6 +42,46 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
nn.HSwish: quant.HSwishQuant}
def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
quant_delay=(0, 0),
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
per_channel=(False, False),
symmetric=(False, False),
narrow_range=(False, False)
):
r"""
Configures the observer types of weights and data flow with quantization parameters.
Args:
quant_observer (Observer, list or tuple): The observer type used to do quantization. The first element
represents weights and the second element represents data flow.
Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represents weights and the second element represents data flow. Default: (0, 0)
quant_dtype (QuantDtype, list or tuple): Datatype used to quantize weights and activations. The first
element represents weights and the second element represents data flow.
Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`,
quantization is applied per channel, otherwise per layer. The first element represents weights
and the second element represents data flow. Default: (False, False)
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True`, the
quantization is symmetric, otherwise asymmetric. The first element represents weights and the second
element represents data flow. Default: (False, False)
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
The first element represents weights and the second element represents data flow. Default: (False, False)
Returns:
QuantConfig, contains the observer types of weight and activation.
"""
weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
per_channel=per_channel[0], symmetric=symmetric[0],
narrow_range=narrow_range[0])
act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
per_channel=per_channel[-1], symmetric=symmetric[-1],
narrow_range=narrow_range[-1])
return quant.QuantConfig(weight=weight_observer, activation=act_observer)
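get_quant_config is what turns the old flat keyword arguments into a QuantConfig: the first entry of each tuple configures the weight observer and the second the activation observer, both pre-bound via the observer's partial_init. A hedged sketch of calling it directly, mirroring the test networks later in this commit; the QuantDtype import path is inferred from the relative import above:

from mindspore import nn
from mindspore.compression.common import QuantDtype
from mindspore.train.quant import quant

# Per-channel, symmetric fake-quant for weights; per-layer, asymmetric for activations.
quant_config = quant.get_quant_config(per_channel=(True, False),
                                      symmetric=(True, False),
                                      quant_dtype=(QuantDtype.INT8, QuantDtype.INT8))

conv = nn.Conv2dBnFoldQuant(3, 16, kernel_size=3, quant_config=quant_config)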
class _AddFakeQuantInput(nn.Cell):
"""
Add FakeQuant OP at input of the network. Only support one input case.
@ -48,7 +89,8 @@ class _AddFakeQuantInput(nn.Cell):
def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input.')
self.network = network
@ -66,10 +108,10 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
def __init__(self, subcell, **kwargs):
super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
self.subcell = subcell
self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6,
self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
max_init=6,
ema=True,
num_bits=kwargs["num_bits"],
quant_dtype=kwargs["quant_dtype"],
quant_delay=kwargs["quant_delay"],
per_channel=kwargs["per_channel"],
symmetric=kwargs["symmetric"],
@ -93,8 +135,8 @@ class ConvertToQuantNetwork:
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit")
self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit")
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
@ -103,6 +145,11 @@ class ConvertToQuantNetwork:
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense}
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
quant_dtype=kwargs["quant_dtype"],
per_channel=kwargs["per_channel"],
symmetric=kwargs["symmetric"],
narrow_range=kwargs["narrow_range"])
def _convert_op_name(self, name):
pattern = re.compile(r'([A-Z]{1})')
@ -149,7 +196,7 @@ class ConvertToQuantNetwork:
for name, prim_op in add_list:
prefix = name
add_quant = _AddFakeQuantAfterSubCell(prim_op,
num_bits=self.act_bits,
quant_dtype=self.act_dtype,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
@ -180,15 +227,12 @@ class ConvertToQuantNetwork:
group=conv_inner.group,
eps=bn_inner.eps,
momentum=bn_inner.momentum,
quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
fake=True,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range,
has_bias=conv_inner.has_bias,
bias_init=conv_inner.bias_init)
bias_init=conv_inner.bias_init,
freeze_bn=self.freeze_bn,
quant_config=self.quant_config,
quant_dtype=self.weight_dtype,
fake=True)
# change original network BatchNormal OP parameters to quant network
conv_inner.gamma = subcell.batchnorm.gamma
conv_inner.beta = subcell.batchnorm.beta
@ -209,13 +253,10 @@ class ConvertToQuantNetwork:
group=conv_inner.group,
eps=bn_inner.eps,
momentum=bn_inner.momentum,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range,
has_bias=conv_inner.has_bias,
bias_init=conv_inner.bias_init)
bias_init=conv_inner.bias_init,
quant_config=self.quant_config,
quant_dtype=self.weight_dtype)
# change original network BatchNormal OP parameters to quant network
conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
conv_inner.batchnorm.beta = subcell.batchnorm.beta
@ -234,11 +275,8 @@ class ConvertToQuantNetwork:
dilation=conv_inner.dilation,
group=conv_inner.group,
has_bias=conv_inner.has_bias,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
quant_config=self.quant_config,
quant_dtype=self.weight_dtype)
# change original network Conv2D OP parameters to quant network
conv_inner.weight = subcell.conv.weight
if subcell.conv.has_bias:
@ -249,7 +287,7 @@ class ConvertToQuantNetwork:
elif subcell.after_fake:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_dtype=self.act_dtype,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
@ -264,11 +302,8 @@ class ConvertToQuantNetwork:
dense_inner = quant.DenseQuant(dense_inner.in_channels,
dense_inner.out_channels,
has_bias=dense_inner.has_bias,
num_bits=self.weight_bits,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
quant_config=self.quant_config,
quant_dtype=self.weight_dtype)
# change original network Dense OP parameters to quant network
dense_inner.weight = subcell.dense.weight
if subcell.dense.has_bias:
@ -279,7 +314,7 @@ class ConvertToQuantNetwork:
elif subcell.after_fake:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_dtype=self.act_dtype,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
@ -291,11 +326,8 @@ class ConvertToQuantNetwork:
if act_class not in _ACTIVATION_MAP:
raise ValueError("Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](activation=activation,
num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
quant_config=self.quant_config,
quant_dtype=self.act_dtype)
class ExportToQuantInferNetwork:
@ -523,7 +555,7 @@ def convert_quant_network(network,
bn_fold=True,
freeze_bn=10000000,
quant_delay=(0, 0),
num_bits=(8, 8),
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
per_channel=(False, False),
symmetric=(False, False),
narrow_range=(False, False)
@ -537,8 +569,9 @@ def convert_quant_network(network,
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
num_bits (int, list or tuple): Number of bits to use for quantize weights and activations. The first
element represent weights and second element represent data flow. Default: (8, 8)
quant_dtype (QuantDtype, list or tuple): Datatype used to quantize weights and activations. The first
element represents weights and the second element represents data flow.
Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: (False, False)
@ -561,7 +594,7 @@ def convert_quant_network(network,
return value
quant_delay = convert2list("quant delay", quant_delay)
num_bits = convert2list("num bits", num_bits)
quant_dtype = convert2list("quant dtype", quant_dtype)
per_channel = convert2list("per channel", per_channel)
symmetric = convert2list("symmetric", symmetric)
narrow_range = convert2list("narrow range", narrow_range)
@ -573,7 +606,7 @@ def convert_quant_network(network,
quant_delay=quant_delay,
bn_fold=bn_fold,
freeze_bn=freeze_bn,
num_bits=num_bits,
quant_dtype=quant_dtype,
per_channel=per_channel,
symmetric=symmetric,
narrow_range=narrow_range)
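For callers of convert_quant_network, the visible change is that num_bits=(8, 8) is replaced by quant_dtype=(QuantDtype.INT8, QuantDtype.INT8); internally the kwargs are folded into a QuantConfig via get_quant_config. A brief before/after sketch, assuming `net` is a float32 network built from Conv2dBnAct/DenseBnAct cells:

from mindspore.compression.common import QuantDtype
from mindspore.train.quant import quant

# before this commit:
#   quantized_net = quant.convert_quant_network(net, num_bits=(8, 8), per_channel=(True, False))
quantized_net = quant.convert_quant_network(net,
                                            quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
                                            per_channel=(True, False),
                                            symmetric=(True, False))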

View File

@ -17,12 +17,14 @@ import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
from mindspore.train.quant import quant
_ema_decay = 0.999
_symmetric = True
_fake = True
_per_channel = True
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
def _weight_variable(shape, factor=0.01):
@ -89,7 +91,7 @@ class ConvBNReLU(nn.Cell):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
group=groups, fake=_fake, quant_config=_quant_config)
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers)
@ -124,13 +126,12 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
symmetric=_symmetric,
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config,
kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=1,
pad_mode='same', padding=0)
@ -142,16 +143,15 @@ class ResidualBlock(nn.Cell):
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=stride,
pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=\
_quant_config,
kernel_size=1,
stride=stride,
pad_mode='same',
@ -235,9 +235,8 @@ class ResNet(nn.Cell):
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)
self.output_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
"""

View File

@ -13,20 +13,19 @@
# limitations under the License.
# ============================================================================
"""ResNet."""
import numpy as np
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant
from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
from mindspore.train.quant import quant
_ema_decay = 0.999
_symmetric = True
_fake = True
_per_channel = True
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
def _weight_variable(shape, factor=0.01):
@ -93,7 +92,7 @@ class ConvBNReLU(nn.Cell):
super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
group=groups, fake=_fake, quant_config=_quant_config)
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers)
@ -128,14 +127,12 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
symmetric=_symmetric,
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config,
kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMax(
ema=True, ema_decay=_ema_decay, symmetric=False)
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=1,
pad_mode='same', padding=0)
@ -147,16 +144,15 @@ class ResidualBlock(nn.Cell):
if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=_quant_config,
kernel_size=1, stride=stride,
pad_mode='same', padding=0),
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
fake=_fake,
per_channel=_per_channel,
symmetric=_symmetric,
quant_config=\
_quant_config,
kernel_size=1,
stride=stride,
pad_mode='same',
@ -212,8 +208,7 @@ class ResNet(nn.Cell):
super(ResNet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError(
"the length of layer_num, in_channels, out_channels list must be 4!")
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
@ -241,10 +236,8 @@ class ResNet(nn.Cell):
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
symmetric=_symmetric)
self.output_fake = nn.FakeQuantWithMinMax(
ema=True, ema_decay=_ema_decay)
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)
# init weights
self._initialize_weights()