forked from mindspore-Ecosystem/mindspore
add QuantConfig & modify quant cells
This commit is contained in:
parent 2b58af0e9d
commit a84affffd7
|
@ -15,6 +15,7 @@
|
|||
"""Quantization aware training."""
|
||||
|
||||
from functools import partial
|
||||
from collections import namedtuple
|
||||
import numpy as np
|
||||
from mindspore import nn
|
||||
import mindspore.common.dtype as mstype
|
||||
|
@ -34,7 +35,7 @@ from ...ops.operations import _quant_ops as Q
|
|||
__all__ = [
|
||||
'Conv2dBnAct',
|
||||
'DenseBnAct',
|
||||
'FakeQuantWithMinMax',
|
||||
'FakeQuantWithMinMaxObserver',
|
||||
'Conv2dBnFoldQuant',
|
||||
'Conv2dBnWithoutFoldQuant',
|
||||
'Conv2dQuant',
|
||||
|
@ -422,14 +423,14 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
|
|||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
"""Initialize FakeQuantWithMinMax layer"""
|
||||
"""Initialize FakeQuantWithMinMaxObserver"""
|
||||
super(FakeQuantWithMinMaxObserver, self).__init__(quant_dtype=quant_dtype, per_channel=per_channel,
|
||||
symmetric=symmetric, narrow_range=narrow_range,
|
||||
num_channels=num_channels)
|
||||
Validator.check_type("min_init", min_init, [int, float])
|
||||
Validator.check_type("max_init", max_init, [int, float])
|
||||
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
|
||||
Validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
|
||||
Validator.check_non_negative_int(quant_delay, 'quant_delay')
|
||||
self.min_init = min_init
|
||||
self.max_init = max_init
|
||||
self.quant_dtype = quant_dtype
|
||||
|
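For reference, a minimal stand-alone usage sketch of the observer whose constructor is shown above; the input values are illustrative and mirror the existing FakeQuantWithMinMax example further below.

# Sketch: an EMA-updated min/max observer used directly as a fake-quant cell.
fake_quant = FakeQuantWithMinMaxObserver(min_init=-6, max_init=6, ema=True)
input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
result = fake_quant(input_x)  # same shape and dtype as input_x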
@ -498,119 +499,9 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver):
|
|||
return out
|
||||
|
||||
|
||||
class FakeQuantWithMinMax(Cell):
|
||||
r"""
|
||||
Quantization aware op. This OP provides the fake quantization observer function on data with min and max.
|
||||
QuantConfig = namedtuple("QuantConfig", ['weight', 'activation'])
|
||||
|
||||
Args:
|
||||
min_init (int, float): The initialized min value. Default: -6.
|
||||
max_init (int, float): The initialized max value. Default: 6.
|
||||
ema (bool): The exponential Moving Average algorithm updates min and max. Default: False.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
channel_axis (int): Quantization by channel axis. Default: 1.
|
||||
num_channels (int): Declares the channel size of the min and max arrays. Default: 1.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of FakeQuantWithMinMax.
|
||||
|
||||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Examples:
|
||||
>>> fake_quant = FakeQuantWithMinMax()
|
||||
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> result = fake_quant(input_x)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
channel_axis=1,
|
||||
num_channels=1,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
"""Initialize FakeQuantWithMinMax layer"""
|
||||
super(FakeQuantWithMinMax, self).__init__()
|
||||
Validator.check_type("min_init", min_init, [int, float])
|
||||
Validator.check_type("max_init", max_init, [int, float])
|
||||
Validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
|
||||
Validator.check_non_negative_int(quant_delay, 'quant_delay')
|
||||
self.min_init = min_init
|
||||
self.max_init = max_init
|
||||
self.num_bits = num_bits
|
||||
self.ema = ema
|
||||
self.ema_decay = ema_decay
|
||||
self.per_channel = per_channel
|
||||
self.num_channels = num_channels
|
||||
self.channel_axis = channel_axis
|
||||
self.quant_delay = quant_delay
|
||||
self.symmetric = symmetric
|
||||
self.narrow_range = narrow_range
|
||||
self.is_ascend = context.get_context('device_target') == "Ascend"
|
||||
|
||||
# init tensor min and max for fake quant op
|
||||
if self.per_channel:
|
||||
min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
|
||||
max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
|
||||
else:
|
||||
min_array = np.array([self.min_init]).astype(np.float32)
|
||||
max_array = np.array([self.max_init]).astype(np.float32)
|
||||
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
|
||||
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
|
||||
|
||||
# init fake quant relative op
|
||||
if self.per_channel:
|
||||
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
|
||||
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
|
||||
else:
|
||||
quant_fun = Q.FakeQuantPerLayer
|
||||
ema_fun = Q.MinMaxUpdatePerLayer
|
||||
|
||||
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
|
||||
if self.is_ascend:
|
||||
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
quant_delay=self.quant_delay)
|
||||
self.fake_quant_infer = self.fake_quant_train
|
||||
else:
|
||||
quant_fun = partial(quant_fun,
|
||||
ema=self.ema,
|
||||
ema_decay=ema_decay,
|
||||
num_bits=self.num_bits,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
quant_delay=self.quant_delay)
|
||||
self.fake_quant_train = quant_fun(training=True)
|
||||
self.fake_quant_infer = quant_fun(training=False)
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
|
||||
'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
|
||||
self.ema, self.ema_decay, self.per_channel,
|
||||
self.channel_axis, self.num_channels, self.quant_delay,
|
||||
self.min_init, self.max_init)
|
||||
return s
|
||||
|
||||
def construct(self, x):
|
||||
if self.training:
|
||||
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
|
||||
P.Assign()(self.minq, min_up)
|
||||
P.Assign()(self.maxq, max_up)
|
||||
out = self.fake_quant_train(x, self.minq, self.maxq)
|
||||
else:
|
||||
out = self.fake_quant_infer(x, self.minq, self.maxq)
|
||||
return out

quant_config_default = QuantConfig(weight=FakeQuantWithMinMaxObserver, activation=FakeQuantWithMinMaxObserver)

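For context, QuantConfig is a plain two-field namedtuple: the quant cells below call quant_config.weight and quant_config.activation to build their fake-quant ops. A sketch of a custom config follows; my_config is an illustrative name, and partial_init is assumed to pre-bind observer keyword arguments, as get_quant_config does later in this diff.

# Sketch: per-channel symmetric weight observers, per-layer activation observers.
my_config = QuantConfig(
    weight=FakeQuantWithMinMaxObserver.partial_init(per_channel=True, symmetric=True),
    activation=FakeQuantWithMinMaxObserver.partial_init(per_channel=False))
conv = Conv2dQuant(1, 6, kernel_size=5, quant_config=my_config)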
class Conv2dBnFoldQuant(Cell):
|
||||
|
@ -641,12 +532,9 @@ class Conv2dBnFoldQuant(Cell):
|
|||
mean vector. Default: 'zeros'.
|
||||
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
variance vector. Default: 'ones'.
|
||||
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMax op. Default: True.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): The Quantization delay parameters according to the global step. Default: 0.
|
||||
fake (bool): Whether Conv2dBnFoldQuant Cell adds FakeQuantWithMinMaxObserver. Default: True.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
freeze_bn (int): The quantization freeze BatchNormal op is according to the global step. Default: 100000.
|
||||
|
||||
Inputs:
|
||||
|
@ -680,11 +568,8 @@ class Conv2dBnFoldQuant(Cell):
|
|||
mean_init='zeros',
|
||||
var_init='ones',
|
||||
fake=True,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0,
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8,
|
||||
freeze_bn=100000):
|
||||
"""Initialize Conv2dBnFoldQuant layer"""
|
||||
super(Conv2dBnFoldQuant, self).__init__()
|
||||
|
@ -699,13 +584,10 @@ class Conv2dBnFoldQuant(Cell):
|
|||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
self.has_bias = has_bias
|
||||
self.quant_delay = quant_delay
|
||||
self.freeze_bn = freeze_bn
|
||||
self.fake = fake
|
||||
self.num_bits = num_bits
|
||||
self.per_channel = per_channel
|
||||
self.symmetric = symmetric
|
||||
self.narrow_range = narrow_range
|
||||
self.quant_config = quant_config
|
||||
self.quant_dtype = quant_dtype
|
||||
self.is_gpu = context.get_context('device_target') == "GPU"
|
||||
|
||||
# initialize convolution op and Parameter
|
||||
|
@ -745,16 +627,12 @@ class Conv2dBnFoldQuant(Cell):
|
|||
requires_grad=False)
|
||||
|
||||
# initialize fake ops
|
||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||
self.fake_quant_weight = quant_config.weight(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
per_channel=per_channel,
|
||||
channel_axis=channel_axis,
|
||||
num_channels=out_channels,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
quant_dtype=quant_dtype)
|
||||
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
|
||||
self.correct_mul = Q.CorrectionMul(channel_axis)
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
|
@ -777,7 +655,7 @@ class Conv2dBnFoldQuant(Cell):
|
|||
self.pad_mode, self.padding, self.dilation,
|
||||
self.group,
|
||||
self.fake, self.freeze_bn, self.momentum,
|
||||
self.quant_delay)
|
||||
self.fake_quant_weight.quant_delay)
|
||||
return s
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -836,11 +714,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||
Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
@ -868,11 +743,8 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|||
momentum=0.997,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(Conv2dBnWithoutFoldQuant, self).__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
self.kernel_size = (kernel_size, kernel_size)
|
||||
|
@ -886,7 +758,6 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|||
self.pad_mode = pad_mode
|
||||
self.padding = padding
|
||||
self.group = group
|
||||
self.quant_delay = quant_delay
|
||||
|
||||
self.bias_add = P.BiasAdd()
|
||||
if Validator.check_bool(has_bias):
|
||||
|
@ -917,16 +788,12 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|||
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
||||
channel_axis = 0
|
||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||
self.fake_quant_weight = quant_config.weight(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
per_channel=per_channel,
|
||||
channel_axis=channel_axis,
|
||||
num_channels=out_channels,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
quant_dtype=quant_dtype)
|
||||
self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -942,7 +809,7 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
|||
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
||||
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
||||
self.pad_mode, self.padding, self.dilation, self.group,
|
||||
self.has_bias, self.quant_delay)
|
||||
self.has_bias, self.fake_quant_weight.quant_delay)
|
||||
return s
|
||||
|
||||
|
||||
|
@ -966,11 +833,8 @@ class Conv2dQuant(Cell):
|
|||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||
Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
@ -996,11 +860,8 @@ class Conv2dQuant(Cell):
|
|||
has_bias=False,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(Conv2dQuant, self).__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
self.kernel_size = (kernel_size, kernel_size)
|
||||
|
@ -1014,7 +875,6 @@ class Conv2dQuant(Cell):
|
|||
self.pad_mode = pad_mode
|
||||
self.padding = padding
|
||||
self.group = group
|
||||
self.quant_delay = quant_delay
|
||||
|
||||
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
|
@ -1033,16 +893,12 @@ class Conv2dQuant(Cell):
|
|||
stride=self.stride,
|
||||
dilation=self.dilation,
|
||||
group=self.group)
|
||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||
self.fake_quant_weight = quant_config.weight(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
per_channel=per_channel,
|
||||
channel_axis=0,
|
||||
num_channels=out_channels,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
quant_dtype=quant_dtype)
|
||||
|
||||
def construct(self, x):
|
||||
weight = self.fake_quant_weight(self.weight)
|
||||
|
@ -1056,7 +912,7 @@ class Conv2dQuant(Cell):
|
|||
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
||||
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
||||
self.pad_mode, self.padding, self.dilation, self.group,
|
||||
self.has_bias, self.quant_delay)
|
||||
self.has_bias, self.fake_quant_weight.quant_delay)
|
||||
return s
|
||||
|
||||
|
||||
|
@ -1075,11 +931,8 @@ class DenseQuant(Cell):
|
|||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
@ -1093,19 +946,15 @@ class DenseQuant(Cell):
|
|||
>>> result = dense_quant(input_x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True,
|
||||
activation=None,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True,
|
||||
activation=None,
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(DenseQuant, self).__init__()
|
||||
self.in_channels = Validator.check_positive_int(in_channels)
|
||||
self.out_channels = Validator.check_positive_int(out_channels)
|
||||
|
@ -1132,16 +981,12 @@ class DenseQuant(Cell):
|
|||
|
||||
self.activation = get_activation(activation)
|
||||
self.activation_flag = self.activation is not None
|
||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||
self.fake_quant_weight = quant_config.weight(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
per_channel=per_channel,
|
||||
channel_axis=0,
|
||||
num_channels=out_channels,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
quant_dtype=quant_dtype)
|
||||
|
||||
def construct(self, x):
|
||||
"""Use operators to construct the Dense layer."""
|
||||
|
@ -1179,16 +1024,13 @@ class ActQuant(_QuantActivation):
|
|||
Quantization aware training activation function.
|
||||
|
||||
Add the fake quant op to the end of the activation op, so that the output of the activation op is truncated.
|
||||
Please check `FakeQuantWithMinMax` for more details.
|
||||
Please check `FakeQuantWithMinMaxObserver` or other observer for more details.
|
||||
|
||||
Args:
|
||||
activation (Cell): Activation cell class.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global steps. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of ActQuant.
|
||||
|
@ -1205,21 +1047,14 @@ class ActQuant(_QuantActivation):
|
|||
def __init__(self,
|
||||
activation,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(ActQuant, self).__init__()
|
||||
self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
self.act = activation
|
||||
|
||||
def construct(self, x):
|
||||
|
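As the ResNet scripts later in this diff illustrate, ActQuant simply wraps an existing activation cell and fake-quantizes its output with the observer taken from quant_config; a minimal sketch with illustrative input values:

# Sketch: wrap a float ReLU so its output passes through the activation observer.
act = ActQuant(nn.ReLU())
input_x = Tensor(np.array([[1.0, -2.0], [0.5, 3.0]]), mindspore.float32)
result = act(input_x)  # same shape and dtype as input_x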
@ -1240,11 +1075,8 @@ class LeakyReLUQuant(_QuantActivation):
|
|||
Args:
|
||||
activation (Cell): Activation cell class.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of LeakyReLUQuant.
|
||||
|
@ -1261,30 +1093,19 @@ class LeakyReLUQuant(_QuantActivation):
|
|||
def __init__(self,
|
||||
activation,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(LeakyReLUQuant, self).__init__()
|
||||
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act_before = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
self.fake_quant_act_after = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
if issubclass(activation.__class__, nn.LeakyReLU):
|
||||
self.act = activation
|
||||
else:
|
||||
|
@ -1309,11 +1130,8 @@ class HSwishQuant(_QuantActivation):
|
|||
Args:
|
||||
activation (Cell): Activation cell class.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of HSwishQuant.
|
||||
|
@ -1330,30 +1148,19 @@ class HSwishQuant(_QuantActivation):
|
|||
def __init__(self,
|
||||
activation,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(HSwishQuant, self).__init__()
|
||||
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act_before = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
self.fake_quant_act_after = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
if issubclass(activation.__class__, nn.HSwish):
|
||||
self.act = activation
|
||||
else:
|
||||
|
@ -1378,11 +1185,8 @@ class HSigmoidQuant(_QuantActivation):
|
|||
Args:
|
||||
activation (Cell): Activation cell class.
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of HSigmoidQuant.
|
||||
|
@ -1399,30 +1203,19 @@ class HSigmoidQuant(_QuantActivation):
|
|||
def __init__(self,
|
||||
activation,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(HSigmoidQuant, self).__init__()
|
||||
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act_before = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
self.fake_quant_act_after = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
if issubclass(activation.__class__, nn.HSigmoid):
|
||||
self.act = activation
|
||||
else:
|
||||
|
@ -1446,11 +1239,8 @@ class TensorAddQuant(Cell):
|
|||
|
||||
Args:
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of TensorAddQuant.
|
||||
|
@ -1467,21 +1257,14 @@ class TensorAddQuant(Cell):
|
|||
|
||||
def __init__(self,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(TensorAddQuant, self).__init__()
|
||||
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x1, x2):
|
||||
|
@ -1498,11 +1281,8 @@ class MulQuant(Cell):
|
|||
|
||||
Args:
|
||||
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
|
||||
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
|
||||
num_bits (int): The bit number of quantization, supporting 4 and 8bits. Default: 8.
|
||||
symmetric (bool): The quantization algorithm is symmetric or not. Default: False.
|
||||
narrow_range (bool): The quantization algorithm uses narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according to the global step. Default: 0.
|
||||
quant_config (QuantConfig): Configures the observer types of weight and activation. Default: quant_config_default.
|
||||
quant_dtype (QuantDtype): Specifies the FakeQuant datatype. Default: QuantDtype.INT8.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input of MulQuant.
|
||||
|
@ -1510,25 +1290,23 @@ class MulQuant(Cell):
|
|||
Outputs:
|
||||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Examples:
|
||||
>>> mul_quant = nn.MulQuant()
|
||||
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> input_y = Tensor(np.random.randint(-2, 2, (2, 3)), mindspore.float32)
|
||||
>>> result = mul_quant(input_x, input_y)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ema_decay=0.999,
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(MulQuant, self).__init__()
|
||||
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
per_channel=per_channel,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.fake_quant_act = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
self.mul = P.Mul()
|
||||
|
||||
def construct(self, x1, x2):
|
||||
|
|
|
@ -27,6 +27,7 @@ from ...common import Tensor
|
|||
from ...common import dtype as mstype
|
||||
from ...common.api import _executor
|
||||
from ...nn.layer import quant
|
||||
from ...compression.common import QuantDtype
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops.operations import _inner_ops as inner
|
||||
|
@ -41,6 +42,46 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
|
|||
nn.HSwish: quant.HSwishQuant}
|
||||
|
||||
|
||||
def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
quant_delay=(0, 0),
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
per_channel=(False, False),
symmetric=(False, False),
narrow_range=(False, False)
):
r"""
Configures the observer types of weights and data flow with the quantization parameters.

Args:
quant_observer (Observer, list or tuple): The observer type used to do quantization. The first element represents
weights and the second element represents data flow.
Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represents weights and the second element represents data flow. Default: (0, 0)
quant_dtype (QuantDtype, list or tuple): Datatype used to quantize weights and activations. The first
element represents weights and the second element represents data flow.
Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`,
quantization is applied per channel, otherwise per layer. The first element represents weights
and the second element represents data flow. Default: (False, False)
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True`, symmetric
quantization is used, otherwise asymmetric quantization is used. The first element represents weights and
the second element represents data flow. Default: (False, False)
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
The first element represents weights and the second element represents data flow. Default: (False, False)

Returns:
QuantConfig, contains the observer types of weight and activation.
"""
weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
per_channel=per_channel[0], symmetric=symmetric[0],
narrow_range=narrow_range[0])
act_observer = quant_observer[0].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
per_channel=per_channel[-1], symmetric=symmetric[-1],
narrow_range=narrow_range[-1])
return quant.QuantConfig(weight=weight_observer, activation=act_observer)

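A short usage sketch of the helper above, mirroring how the ResNet scripts at the end of this diff consume it; the convolution shape parameters are illustrative:

# Build a QuantConfig with per-channel, symmetric weight observers and
# per-layer activation observers, then pass it to a quant cell instead of the
# old per_channel/num_bits/symmetric keyword arguments.
quant_config = get_quant_config(per_channel=(True, False), symmetric=(True, False))
conv = quant.Conv2dBnFoldQuant(3, 64, kernel_size=7, stride=2,
                               quant_config=quant_config,
                               quant_dtype=QuantDtype.INT8)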
class _AddFakeQuantInput(nn.Cell):
|
||||
"""
|
||||
Add the FakeQuant OP at the input of the network. Only the single-input case is supported.
|
||||
|
@ -48,7 +89,8 @@ class _AddFakeQuantInput(nn.Cell):
|
|||
|
||||
def __init__(self, network, quant_delay=0):
|
||||
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
|
||||
self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
|
||||
quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_input.update_parameters_name('fake_quant_input.')
|
||||
self.network = network
|
||||
|
||||
|
@ -66,14 +108,14 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
|
|||
def __init__(self, subcell, **kwargs):
|
||||
super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
|
||||
self.subcell = subcell
|
||||
self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
num_bits=kwargs["num_bits"],
|
||||
quant_delay=kwargs["quant_delay"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
quant_dtype=kwargs["quant_dtype"],
|
||||
quant_delay=kwargs["quant_delay"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
|
||||
def construct(self, *data):
|
||||
output = self.subcell(*data)
|
||||
|
@ -93,8 +135,8 @@ class ConvertToQuantNetwork:
|
|||
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
|
||||
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
|
||||
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
|
||||
self.weight_bits = Validator.check_non_negative_int(kwargs["num_bits"][0], "weights bit")
|
||||
self.act_bits = Validator.check_int(kwargs["num_bits"][-1], 0, Rel.GE, "activations bit")
|
||||
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
|
||||
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
|
||||
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
|
||||
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
|
||||
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
|
||||
|
@ -103,6 +145,11 @@ class ConvertToQuantNetwork:
|
|||
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
|
||||
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
||||
quant.DenseBnAct: self._convert_dense}
|
||||
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
|
||||
quant_dtype=kwargs["quant_dtype"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
|
||||
def _convert_op_name(self, name):
|
||||
pattern = re.compile(r'([A-Z]{1})')
|
||||
|
@ -149,7 +196,7 @@ class ConvertToQuantNetwork:
|
|||
for name, prim_op in add_list:
|
||||
prefix = name
|
||||
add_quant = _AddFakeQuantAfterSubCell(prim_op,
|
||||
num_bits=self.act_bits,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
|
@ -180,15 +227,12 @@ class ConvertToQuantNetwork:
|
|||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
quant_delay=self.weight_qdelay,
|
||||
freeze_bn=self.freeze_bn,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
fake=True,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range,
|
||||
has_bias=conv_inner.has_bias,
|
||||
bias_init=conv_inner.bias_init)
|
||||
bias_init=conv_inner.bias_init,
|
||||
freeze_bn=self.freeze_bn,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype,
|
||||
fake=True)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.beta = subcell.batchnorm.beta
|
||||
|
@ -209,13 +253,10 @@ class ConvertToQuantNetwork:
|
|||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
quant_delay=self.weight_qdelay,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range,
|
||||
has_bias=conv_inner.has_bias,
|
||||
bias_init=conv_inner.bias_init)
|
||||
bias_init=conv_inner.bias_init,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.batchnorm.beta = subcell.batchnorm.beta
|
||||
|
@ -234,11 +275,8 @@ class ConvertToQuantNetwork:
|
|||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
has_bias=conv_inner.has_bias,
|
||||
quant_delay=self.weight_qdelay,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network Conv2D OP parameters to quant network
|
||||
conv_inner.weight = subcell.conv.weight
|
||||
if subcell.conv.has_bias:
|
||||
|
@ -249,7 +287,7 @@ class ConvertToQuantNetwork:
|
|||
elif subcell.after_fake:
|
||||
subcell.has_act = True
|
||||
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
|
||||
num_bits=self.act_bits,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
|
@ -264,11 +302,8 @@ class ConvertToQuantNetwork:
|
|||
dense_inner = quant.DenseQuant(dense_inner.in_channels,
|
||||
dense_inner.out_channels,
|
||||
has_bias=dense_inner.has_bias,
|
||||
num_bits=self.weight_bits,
|
||||
quant_delay=self.weight_qdelay,
|
||||
per_channel=self.weight_channel,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network Dense OP parameters to quant network
|
||||
dense_inner.weight = subcell.dense.weight
|
||||
if subcell.dense.has_bias:
|
||||
|
@ -279,7 +314,7 @@ class ConvertToQuantNetwork:
|
|||
elif subcell.after_fake:
|
||||
subcell.has_act = True
|
||||
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
|
||||
num_bits=self.act_bits,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
|
@ -291,11 +326,8 @@ class ConvertToQuantNetwork:
|
|||
if act_class not in _ACTIVATION_MAP:
|
||||
raise ValueError("Unsupported activation in auto quant: ", act_class)
|
||||
return _ACTIVATION_MAP[act_class](activation=activation,
|
||||
num_bits=self.act_bits,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.act_dtype)
|
||||
|
||||
|
||||
class ExportToQuantInferNetwork:
|
||||
|
@ -523,7 +555,7 @@ def convert_quant_network(network,
|
|||
bn_fold=True,
|
||||
freeze_bn=10000000,
|
||||
quant_delay=(0, 0),
|
||||
num_bits=(8, 8),
|
||||
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
|
||||
per_channel=(False, False),
|
||||
symmetric=(False, False),
|
||||
narrow_range=(False, False)
|
||||
|
@ -537,8 +569,9 @@ def convert_quant_network(network,
freeze_bn (int): Number of steps after which the BatchNorm OP parameters use the total mean and variance. Default: 1e7.
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represents weights and the second element represents data flow. Default: (0, 0)
num_bits (int, list or tuple): Number of bits used to quantize weights and activations. The first
element represents weights and the second element represents data flow. Default: (8, 8)
quant_dtype (QuantDtype, list or tuple): Datatype used to quantize weights and activations. The first
element represents weights and the second element represents data flow.
Default: (QuantDtype.INT8, QuantDtype.INT8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`,
quantization is applied per channel, otherwise per layer. The first element represents weights
and the second element represents data flow. Default: (False, False)
|
||||
|
@ -561,7 +594,7 @@ def convert_quant_network(network,
|
|||
return value
|
||||
|
||||
quant_delay = convert2list("quant delay", quant_delay)
|
||||
num_bits = convert2list("num bits", num_bits)
|
||||
quant_dtype = convert2list("quant dtype", quant_dtype)
|
||||
per_channel = convert2list("per channel", per_channel)
|
||||
symmetric = convert2list("symmetric", symmetric)
|
||||
narrow_range = convert2list("narrow range", narrow_range)
|
||||
|
@ -573,7 +606,7 @@ def convert_quant_network(network,
|
|||
quant_delay=quant_delay,
|
||||
bn_fold=bn_fold,
|
||||
freeze_bn=freeze_bn,
|
||||
num_bits=num_bits,
|
||||
quant_dtype=quant_dtype,
|
||||
per_channel=per_channel,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
|
|
|
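For reference, a hedged sketch of calling the converter with the tuple-style (weight, activation) arguments documented above; `network` stands for any float model built from Conv2dBnAct/DenseBnAct cells, and the step counts are illustrative:

# Sketch: convert a float network into its quantization-aware counterpart.
quant_net = convert_quant_network(network,
                                  bn_fold=True,
                                  freeze_bn=10000,
                                  quant_delay=(900, 900),
                                  num_bits=(8, 8),
                                  per_channel=(True, False),
                                  symmetric=(True, False))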
@ -17,12 +17,14 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant
|
||||
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
|
||||
from mindspore.train.quant import quant
|
||||
|
||||
_ema_decay = 0.999
|
||||
_symmetric = True
|
||||
_fake = True
|
||||
_per_channel = True
|
||||
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
|
||||
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
|
@ -89,7 +91,7 @@ class ConvBNReLU(nn.Cell):
|
|||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
|
||||
group=groups, fake=_fake, quant_config=_quant_config)
|
||||
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
|
||||
self.features = nn.SequentialCell(layers)
|
||||
|
||||
|
@ -124,13 +126,12 @@ class ResidualBlock(nn.Cell):
|
|||
channel = out_channel // self.expansion
|
||||
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
|
||||
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
|
||||
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1, pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1,
|
||||
pad_mode='same', padding=0)
|
||||
|
||||
|
@ -142,16 +143,15 @@ class ResidualBlock(nn.Cell):
|
|||
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=stride,
|
||||
pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
|
||||
symmetric=False)
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
|
||||
symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
fake=_fake,
|
||||
per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
quant_config=\
|
||||
_quant_config,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
pad_mode='same',
|
||||
|
@ -235,9 +235,8 @@ class ResNet(nn.Cell):
|
|||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
|
||||
symmetric=_symmetric)
|
||||
self.output_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
|
||||
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
|
||||
self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)
|
||||
|
||||
def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
|
||||
"""
|
||||
|
|
|
@ -13,20 +13,19 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""ResNet."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
|
||||
from mindspore.train.quant import quant
|
||||
|
||||
_ema_decay = 0.999
|
||||
_symmetric = True
|
||||
_fake = True
|
||||
_per_channel = True
|
||||
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
|
||||
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
|
@ -93,7 +92,7 @@ class ConvBNReLU(nn.Cell):
|
|||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
|
||||
group=groups, fake=_fake, quant_config=_quant_config)
|
||||
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
|
||||
self.features = nn.SequentialCell(layers)
|
||||
|
||||
|
@ -128,14 +127,12 @@ class ResidualBlock(nn.Cell):
|
|||
channel = out_channel // self.expansion
|
||||
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
|
||||
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
|
||||
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1, pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMax(
|
||||
ema=True, ema_decay=_ema_decay, symmetric=False)
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1,
|
||||
pad_mode='same', padding=0)
|
||||
|
||||
|
@ -147,16 +144,15 @@ class ResidualBlock(nn.Cell):
|
|||
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=stride,
|
||||
pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
|
||||
symmetric=False)
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
|
||||
symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
fake=_fake,
|
||||
per_channel=_per_channel,
|
||||
symmetric=_symmetric,
|
||||
quant_config=\
|
||||
_quant_config,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
pad_mode='same',
|
||||
|
@ -212,8 +208,7 @@ class ResNet(nn.Cell):
|
|||
super(ResNet, self).__init__()
|
||||
|
||||
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
|
||||
raise ValueError(
|
||||
"the length of layer_num, in_channels, out_channels list must be 4!")
|
||||
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
|
||||
|
||||
self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
|
||||
|
@ -241,10 +236,8 @@ class ResNet(nn.Cell):
|
|||
|
||||
self.mean = P.ReduceMean(keep_dims=True)
|
||||
self.flatten = nn.Flatten()
|
||||
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel,
|
||||
symmetric=_symmetric)
|
||||
self.output_fake = nn.FakeQuantWithMinMax(
|
||||
ema=True, ema_decay=_ema_decay)
|
||||
self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, quant_config=_quant_config)
|
||||
self.output_fake = nn.FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay)
|
||||
|
||||
# init weights
|
||||
self._initialize_weights()