!7677 modify compression module notes

Merge pull request !7677 from yuchaojie/quant2
This commit is contained in:
mindspore-ci-bot 2020-10-23 17:25:21 +08:00 committed by Gitee
commit cf06b22f3d
6 changed files with 98 additions and 43 deletions

View File

@ -17,3 +17,6 @@ Compression common module.
"""
from .constant import *
__all__ = []
__all__.extend(constant.__all__)

View File

@ -24,7 +24,7 @@ __all__ = ["QuantDtype"]
@enum.unique
class QuantDtype(enum.Enum):
"""
For type switch
An enum for quant datatype, contains `INT2`~`INT8`, `UINT2`~`UINT8`.
"""
INT2 = "INT2"
INT3 = "INT3"
@ -42,20 +42,42 @@ class QuantDtype(enum.Enum):
UINT7 = "UINT7"
UINT8 = "UINT8"
FLOAT16 = "FLOAT16"
FLOAT32 = "FLOAT32"
def __str__(self):
return f"{self.name}"
@staticmethod
def is_signed(dtype):
"""
Get whether the quant datatype is signed.
Args:
dtype (QuantDtype): quant datatype.
Returns:
bool, whether the input quant datatype is signed.
Examples:
>>> quant_dtype = QuantDtype.INT8
>>> is_signed = QuantDtype.is_signed(quant_dtype)
"""
return dtype in [QuantDtype.INT2, QuantDtype.INT3, QuantDtype.INT4, QuantDtype.INT5,
QuantDtype.INT6, QuantDtype.INT7, QuantDtype.INT8]
@staticmethod
def switch_signed(dtype):
"""switch signed"""
"""
Swicth the signed state of the input quant datatype.
Args:
dtype (QuantDtype): quant datatype.
Returns:
QuantDtype, quant datatype with opposite signed state as the input.
Examples:
>>> quant_dtype = QuantDtype.INT8
>>> quant_dtype = QuantDtype.switch_signed(quant_dtype)
"""
type_map = {
QuantDtype.INT2: QuantDtype.UINT2,
QuantDtype.INT3: QuantDtype.UINT3,
@ -75,11 +97,20 @@ class QuantDtype(enum.Enum):
return type_map[dtype]
@DynamicClassAttribute
def value(self):
def _value(self):
"""The value of the Enum member."""
return int(re.search(r"(\d+)", self._value_).group(1))
@DynamicClassAttribute
def num_bits(self):
"""The num_bits of the Enum member."""
return self.value
"""
Get the num bits of the QuantDtype member.
Returns:
int, the num bits of the QuantDtype member
Examples:
>>> quant_dtype = QuantDtype.INT8
>>> num_bits = quant_dtype.num_bits
"""
return self._value

View File

@ -19,3 +19,8 @@ Compression quant module.
from .quantizer import *
from .qat import *
from .quant_utils import *
__all__ = []
__all__.extend(qat.__all__)
__all__.extend(quantizer.__all__)
__all__.extend(quant_utils.__all__)

View File

@ -125,34 +125,6 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
return output
class ConvertToQuantNetwork:
"""
Convert network to quantization aware network
"""
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self, **kwargs):
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense}
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
quant_dtype=kwargs["quant_dtype"],
per_channel=kwargs["per_channel"],
symmetric=kwargs["symmetric"],
narrow_range=kwargs["narrow_range"])
class QuantizationAwareTraining(Quantizer):
r"""
Quantizer for quantization aware training.
@ -175,6 +147,39 @@ class QuantizationAwareTraining(Quantizer):
The first element represents weights and the second element represents data flow. Default: (False, False)
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
support QAT. Default: OptimizeOption.QAT
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self, num_class=10, channel=1):
>>> super(LeNet5, self).__init__()
>>> self.type = "fusion"
>>> self.num_class = num_class
>>>
>>> # change `nn.Conv2d` to `nn.Conv2dBnAct`
>>> self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
>>> self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
>>> # change `nn.Dense` to `nn.DenseBnAct`
>>> self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
>>> self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
>>> self.fc3 = nn.DenseBnAct(84, self.num_class)
>>>
>>> self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
>>> self.flatten = nn.Flatten()
>>>
>>> def construct(self, x):
>>> x = self.conv1(x)
>>> x = self.max_pool2d(x)
>>> x = self.conv2(x)
>>> x = self.max_pool2d(x)
>>> x = self.flatten(x)
>>> x = self.fc1(x)
>>> x = self.fc2(x)
>>> x = self.fc3(x)
>>> return x
>>>
>>> net = Net()
>>> quantizer = QuantizationAwareTraining(bn_fold=False, per_channel=[True, False], symmetric=[True, False])
>>> net_qat = quantizer.quantize(net)
"""
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
@ -230,6 +235,17 @@ class QuantizationAwareTraining(Quantizer):
return name_new
def quantize(self, network):
"""
Quant API to convert input network to a quantization aware training network
Args:
network (Cell): network to be quantized.
Examples:
>>> net = Net()
>>> quantizer = QuantizationAwareTraining()
>>> net_qat = quantizer.quantize(net)
"""
support_device = ["Ascend", "GPU"]
if context.get_context('device_target') not in support_device:
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))

View File

@ -267,13 +267,13 @@ def without_fold_batchnorm(weight, cell_quant):
def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
"""
load fp32 model parameters to quantization model.
r"""
Load fp32 model parameters into quantization model.
Args:
quant_model: quantization model.
params_dict: f32 param.
quant_new_params:parameters that exist in quantative network but not in unquantative network.
params_dict: parameter dict that stores fp32 parameters.
quant_new_params: parameters that exist in quantative network but not in unquantative network.
Returns:
None

View File

@ -19,12 +19,12 @@ from enum import Enum
from ..._checkparam import Validator
__all__ = ["OptimizeOption", "Quantizer"]
__all__ = ["OptimizeOption"]
class OptimizeOption(Enum):
"""
An enum for the model quantization optimize option.
r"""
An enum for the model quantization optimize option, currently only support `QAT`.
"""
# using quantization aware training
QAT = "QAT"