forked from mindspore-Ecosystem/mindspore
modify compression module notes
This commit is contained in:
parent
b8fbabae34
commit
bf7cd2ceaa
|
@ -17,3 +17,6 @@ Compression common module.
|
|||
"""
|
||||
|
||||
from .constant import *
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(constant.__all__)
|
||||
|
|
|
@ -24,7 +24,7 @@ __all__ = ["QuantDtype"]
|
|||
@enum.unique
|
||||
class QuantDtype(enum.Enum):
|
||||
"""
|
||||
For type switch
|
||||
An enum for quant datatype, contains `INT2`~`INT8`, `UINT2`~`UINT8`.
|
||||
"""
|
||||
INT2 = "INT2"
|
||||
INT3 = "INT3"
|
||||
|
@ -42,20 +42,42 @@ class QuantDtype(enum.Enum):
|
|||
UINT7 = "UINT7"
|
||||
UINT8 = "UINT8"
|
||||
|
||||
FLOAT16 = "FLOAT16"
|
||||
FLOAT32 = "FLOAT32"
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}"
|
||||
|
||||
@staticmethod
|
||||
def is_signed(dtype):
|
||||
"""
|
||||
Get whether the quant datatype is signed.
|
||||
|
||||
Args:
|
||||
dtype (QuantDtype): quant datatype.
|
||||
|
||||
Returns:
|
||||
bool, whether the input quant datatype is signed.
|
||||
|
||||
Examples:
|
||||
>>> quant_dtype = QuantDtype.INT8
|
||||
>>> is_signed = QuantDtype.is_signed(quant_dtype)
|
||||
"""
|
||||
return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5,
|
||||
QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8]
|
||||
|
||||
@staticmethod
|
||||
def switch_signed(dtype):
|
||||
"""switch signed"""
|
||||
"""
|
||||
Swicth the signed state of the input quant datatype.
|
||||
|
||||
Args:
|
||||
dtype (QuantDtype): quant datatype.
|
||||
|
||||
Returns:
|
||||
QuantDtype, quant datatype with opposite signed state as the input.
|
||||
|
||||
Examples:
|
||||
>>> quant_dtype = QuantDtype.INT8
|
||||
>>> quant_dtype = QuantDtype.switch_signed(quant_dtype)
|
||||
"""
|
||||
type_map = {
|
||||
QuantDtype.INT2: QuantDtype.UINT2,
|
||||
QuantDtype.INT3: QuantDtype.UINT3,
|
||||
|
@ -75,11 +97,20 @@ class QuantDtype(enum.Enum):
|
|||
return type_map[dtype]
|
||||
|
||||
@DynamicClassAttribute
|
||||
def value(self):
|
||||
def _value(self):
|
||||
"""The value of the Enum member."""
|
||||
return int(re.search(r"(\d+)", self._value_).group(1))
|
||||
|
||||
@DynamicClassAttribute
|
||||
def num_bits(self):
|
||||
"""The num_bits of the Enum member."""
|
||||
return self.value
|
||||
"""
|
||||
Get the num bits of the QuantDtype member.
|
||||
|
||||
Returns:
|
||||
int, the num bits of the QuantDtype member
|
||||
|
||||
Examples:
|
||||
>>> quant_dtype = QuantDtype.INT8
|
||||
>>> num_bits = quant_dtype.num_bits
|
||||
"""
|
||||
return self._value
|
||||
|
|
|
@ -19,3 +19,8 @@ Compression quant module.
|
|||
from .quantizer import *
|
||||
from .qat import *
|
||||
from .quant_utils import *
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(qat.__all__)
|
||||
__all__.extend(quantizer.__all__)
|
||||
__all__.extend(quant_utils.__all__)
|
||||
|
|
|
@ -125,34 +125,6 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
|
|||
return output
|
||||
|
||||
|
||||
class ConvertToQuantNetwork:
|
||||
"""
|
||||
Convert network to quantization aware network
|
||||
"""
|
||||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
|
||||
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
|
||||
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
|
||||
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
|
||||
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
|
||||
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
|
||||
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
|
||||
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
|
||||
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
|
||||
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
|
||||
self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
|
||||
self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
|
||||
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
|
||||
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
||||
quant.DenseBnAct: self._convert_dense}
|
||||
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
|
||||
quant_dtype=kwargs["quant_dtype"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
|
||||
class QuantizationAwareTraining(Quantizer):
|
||||
r"""
|
||||
Quantizer for quantization aware training.
|
||||
|
@ -175,6 +147,39 @@ class QuantizationAwareTraining(Quantizer):
|
|||
The first element represents weights and the second element represents data flow. Default: (False, False)
|
||||
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
|
||||
support QAT. Default: OptimizeOption.QAT
|
||||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self, num_class=10, channel=1):
|
||||
>>> super(LeNet5, self).__init__()
|
||||
>>> self.type = "fusion"
|
||||
>>> self.num_class = num_class
|
||||
>>>
|
||||
>>> # change `nn.Conv2d` to `nn.Conv2dBnAct`
|
||||
>>> self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
|
||||
>>> self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
|
||||
>>> # change `nn.Dense` to `nn.DenseBnAct`
|
||||
>>> self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
|
||||
>>> self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
|
||||
>>> self.fc3 = nn.DenseBnAct(84, self.num_class)
|
||||
>>>
|
||||
>>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
>>> self.flatten = nn.Flatten()
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> x = self.conv1(x)
|
||||
>>> x = self.max_pool2d(x)
|
||||
>>> x = self.conv2(x)
|
||||
>>> x = self.max_pool2d(x)
|
||||
>>> x = self.flatten(x)
|
||||
>>> x = self.fc1(x)
|
||||
>>> x = self.fc2(x)
|
||||
>>> x = self.fc3(x)
|
||||
>>> return x
|
||||
>>>
|
||||
>>> net = Net()
|
||||
>>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
|
||||
>>> net_qat = quantizer.quantize(net)
|
||||
"""
|
||||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
|
@ -230,6 +235,17 @@ class QuantizationAwareTraining(Quantizer):
|
|||
return name_new
|
||||
|
||||
def quantize(self, network):
|
||||
"""
|
||||
Quant API to convert input network to a quantization aware training network
|
||||
|
||||
Args:
|
||||
network (Cell): network to be quantized.
|
||||
|
||||
Examples:
|
||||
>>> net = Net()
|
||||
>>> quantizer = QuantizationAwareTraining()
|
||||
>>> net_qat = quantizer.quantize(net)
|
||||
"""
|
||||
support_device = ["Ascend", "GPU"]
|
||||
if context.get_context('device_target') not in support_device:
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
|
|
@ -267,13 +267,13 @@ def without_fold_batchnorm(weight, cell_quant):
|
|||
|
||||
|
||||
def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
|
||||
"""
|
||||
load fp32 model parameters to quantization model.
|
||||
r"""
|
||||
Load fp32 model parameters into quantization model.
|
||||
|
||||
Args:
|
||||
quant_model: quantization model.
|
||||
params_dict: f32 param.
|
||||
quant_new_params:parameters that exist in quantative network but not in unquantative network.
|
||||
params_dict: parameter dict that stores fp32 parameters.
|
||||
quant_new_params: parameters that exist in quantative network but not in unquantative network.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
|
|
@ -19,12 +19,12 @@ from enum import Enum
|
|||
|
||||
from ..._checkparam import Validator
|
||||
|
||||
__all__ = ["OptimizeOption", "Quantizer"]
|
||||
__all__ = ["OptimizeOption"]
|
||||
|
||||
|
||||
class OptimizeOption(Enum):
|
||||
"""
|
||||
An enum for the model quantization optimize option.
|
||||
r"""
|
||||
An enum for the model quantization optimize option, currently only support `QAT`.
|
||||
"""
|
||||
# using quantization aware training
|
||||
QAT = "QAT"
|
||||
|
|
Loading…
Reference in New Issue