From c8ec34d63857748f03f0ed33ada3513ab2ada378 Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Mon, 19 Oct 2020 16:45:45 +0800 Subject: [PATCH] move train.quant to compression module & add QuantizationAwareTraining --- mindspore/compression/export/__init__.py | 17 + .../export/quant_export.py} | 451 ++---------------- .../{train => compression}/quant/__init__.py | 13 +- mindspore/compression/quant/qat.py | 406 ++++++++++++++++ .../quant/quant_utils.py | 3 + mindspore/compression/quant/quantizer.py | 52 ++ mindspore/train/serialization.py | 10 +- .../official/cv/lenet_quant/eval_quant.py | 10 +- model_zoo/official/cv/lenet_quant/export.py | 10 +- .../official/cv/lenet_quant/train_quant.py | 9 +- .../official/cv/mobilenetv2_quant/eval.py | 7 +- .../official/cv/mobilenetv2_quant/export.py | 7 +- .../official/cv/mobilenetv2_quant/train.py | 12 +- model_zoo/official/cv/resnet50_quant/eval.py | 21 +- .../models/resnet_quant_manual.py | 47 +- model_zoo/official/cv/resnet50_quant/train.py | 9 +- .../cv/yolov3_darknet53_quant/eval.py | 6 +- .../mindspore_hub_conf.py | 10 +- .../cv/yolov3_darknet53_quant/train.py | 6 +- .../lenet_quant/test_lenet_quant.py | 27 +- .../test_mobilenetv2_quant.py | 6 +- .../resnet50_quant/resnet_quant_manual.py | 47 +- .../resnet50_quant/test_resnet50_quant.py | 10 +- tests/ut/python/train/quant/test_quant.py | 25 +- 24 files changed, 689 insertions(+), 532 deletions(-) create mode 100644 mindspore/compression/export/__init__.py rename mindspore/{train/quant/quant.py => compression/export/quant_export.py} (51%) rename mindspore/{train => compression}/quant/__init__.py (55%) create mode 100644 mindspore/compression/quant/qat.py rename mindspore/{train => compression}/quant/quant_utils.py (99%) create mode 100644 mindspore/compression/quant/quantizer.py diff --git a/mindspore/compression/export/__init__.py b/mindspore/compression/export/__init__.py new file mode 100644 index 00000000000..48e59baa71a --- /dev/null +++ b/mindspore/compression/export/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Compression export module. +""" diff --git a/mindspore/train/quant/quant.py b/mindspore/compression/export/quant_export.py similarity index 51% rename from mindspore/train/quant/quant.py rename to mindspore/compression/export/quant_export.py index 4c564d95ce4..94e56bc9f62 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/compression/export/quant_export.py @@ -12,323 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""quantization aware.""" +"""Export for quantization.""" import copy -import re import numpy as np import mindspore.context as context from ... import log as logger from ... 
import nn, ops -from ..._checkparam import Validator, Rel +from ..._checkparam import Validator from ...common import Tensor from ...common import dtype as mstype from ...common.api import _executor from ...nn.layer import quant -from ...compression.common import QuantDtype -from ...ops import functional as F from ...ops import operations as P from ...ops.operations import _inner_ops as inner from ...train import serialization -from . import quant_utils - -_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant, - nn.ReLU6: quant.ActQuant, - nn.Sigmoid: quant.ActQuant, - nn.LeakyReLU: quant.LeakyReLUQuant, - nn.HSigmoid: quant.HSigmoidQuant, - nn.HSwish: quant.HSwishQuant} +from ..quant import quant_utils +from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell -def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver), - quant_delay=(0, 0), - quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), - per_channel=(False, False), - symmetric=(False, False), - narrow_range=(False, False) - ): - r""" - Configs the oberser type of weights and data flow with quant params. - - Args: - quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent - weights and second element represent data flow. - Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver) - quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during - eval. The first element represent weights and second element represent data flow. Default: (0, 0) - quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first - element represent weights and second element represent data flow. - Default: (QuantDtype.INT8, QuantDtype.INT8) - per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` - then base on per channel otherwise base on per layer. The first element represent weights - and second element represent data flow. Default: (False, False) - symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on - symmetric otherwise base on asymmetric. The first element represent weights and second - element represent data flow. Default: (False, False) - narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not. - The first element represents weights and the second element represents data flow. Default: (False, False) - - Returns: - QuantConfig, Contains the oberser type of weight and activation. - """ - weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0], - per_channel=per_channel[0], symmetric=symmetric[0], - narrow_range=narrow_range[0]) - act_observer = quant_observer[0].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1], - per_channel=per_channel[-1], symmetric=symmetric[-1], - narrow_range=narrow_range[-1]) - return quant.QuantConfig(weight=weight_observer, activation=act_observer) - - -class _AddFakeQuantInput(nn.Cell): - """ - Add FakeQuant OP at input of the network. Only support one input case. 
- """ - - def __init__(self, network, quant_delay=0): - super(_AddFakeQuantInput, self).__init__(auto_prefix=False) - self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6, - quant_delay=quant_delay, ema=True) - self.fake_quant_input.update_parameters_name('fake_quant_input.') - self.network = network - - def construct(self, data): - data = self.fake_quant_input(data) - output = self.network(data) - return output - - -class _AddFakeQuantAfterSubCell(nn.Cell): - """ - Add FakeQuant OP after of the sub Cell. - """ - - def __init__(self, subcell, **kwargs): - super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) - self.subcell = subcell - self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6, - max_init=6, - ema=True, - quant_dtype=kwargs["quant_dtype"], - quant_delay=kwargs["quant_delay"], - per_channel=kwargs["per_channel"], - symmetric=kwargs["symmetric"], - narrow_range=kwargs["narrow_range"]) - - def construct(self, *data): - output = self.subcell(*data) - output = self.fake_quant_act(output) - return output - - -class ConvertToQuantNetwork: - """ - Convert network to quantization aware network - """ - __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] - - def __init__(self, **kwargs): - self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) - self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay") - self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay") - self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold") - self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn") - self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype) - self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype) - self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel") - self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel") - self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric") - self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric") - self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range") - self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range") - self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, - quant.DenseBnAct: self._convert_dense} - self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"], - quant_dtype=kwargs["quant_dtype"], - per_channel=kwargs["per_channel"], - symmetric=kwargs["symmetric"], - narrow_range=kwargs["narrow_range"]) - - def _convert_op_name(self, name): - pattern = re.compile(r'([A-Z]{1})') - name_new = re.sub(pattern, r'_\1', name).lower() - if name_new[0] == '_': - name_new = name_new[1:] - return name_new - - def run(self): - self.network.update_cell_prefix() - network = self._convert_subcells2quant(self.network) - self.network.update_cell_type("quant") - return network - - def _convert_subcells2quant(self, network): - """ - convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell - """ - cells = network.name_cells() - change = False - for name in cells: - subcell = cells[name] - if subcell == network: - continue - elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)): - prefix = subcell.param_prefix - new_subcell = self._convert_method_map[type(subcell)](subcell) - 
new_subcell.update_parameters_name(prefix + '.') - network.insert_child_to_cell(name, new_subcell) - change = True - else: - self._convert_subcells2quant(subcell) - if isinstance(network, nn.SequentialCell) and change: - network.cell_list = list(network.cells()) - - # add FakeQuant OP after OP in while list - add_list = [] - for name in network.__dict__: - if name[0] == '_': - continue - attr = network.__dict__[name] - if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__: - add_list.append((name, attr)) - for name, prim_op in add_list: - prefix = name - add_quant = _AddFakeQuantAfterSubCell(prim_op, - quant_dtype=self.act_dtype, - quant_delay=self.act_qdelay, - per_channel=self.act_channel, - symmetric=self.act_symmetric, - narrow_range=self.act_range) - prefix = self._convert_op_name(prim_op.name) - if network.param_prefix: - prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) - add_quant.update_parameters_name(prefix + '.') - del network.__dict__[name] - network.insert_child_to_cell(name, add_quant) - return network - - def _convert_conv(self, subcell): - """ - convert Conv2d cell to quant cell - """ - conv_inner = subcell.conv - if subcell.has_bn: - if self.bn_fold: - bn_inner = subcell.batchnorm - conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - eps=bn_inner.eps, - momentum=bn_inner.momentum, - has_bias=conv_inner.has_bias, - bias_init=conv_inner.bias_init, - freeze_bn=self.freeze_bn, - quant_config=self.quant_config, - quant_dtype=self.weight_dtype, - fake=True) - # change original network BatchNormal OP parameters to quant network - conv_inner.gamma = subcell.batchnorm.gamma - conv_inner.beta = subcell.batchnorm.beta - conv_inner.moving_mean = subcell.batchnorm.moving_mean - conv_inner.moving_variance = subcell.batchnorm.moving_variance - del subcell.batchnorm - subcell.batchnorm = None - subcell.has_bn = False - else: - bn_inner = subcell.batchnorm - conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - eps=bn_inner.eps, - momentum=bn_inner.momentum, - has_bias=conv_inner.has_bias, - bias_init=conv_inner.bias_init, - quant_config=self.quant_config, - quant_dtype=self.weight_dtype) - # change original network BatchNormal OP parameters to quant network - conv_inner.batchnorm.gamma = subcell.batchnorm.gamma - conv_inner.batchnorm.beta = subcell.batchnorm.beta - conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean - conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance - del subcell.batchnorm - subcell.batchnorm = None - subcell.has_bn = False - else: - conv_inner = quant.Conv2dQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - has_bias=conv_inner.has_bias, - quant_config=self.quant_config, - quant_dtype=self.weight_dtype) - # change original network Conv2D OP parameters to quant network - conv_inner.weight = subcell.conv.weight - if subcell.conv.has_bias: - 
conv_inner.bias = subcell.conv.bias - subcell.conv = conv_inner - if subcell.has_act and subcell.activation is not None: - subcell.activation = self._convert_activation(subcell.activation) - elif subcell.after_fake: - subcell.has_act = True - subcell.activation = _AddFakeQuantAfterSubCell(F.identity, - quant_dtype=self.act_dtype, - quant_delay=self.act_qdelay, - per_channel=self.act_channel, - symmetric=self.act_symmetric, - narrow_range=self.act_range) - return subcell - - def _convert_dense(self, subcell): - """ - convert dense cell to combine dense cell - """ - dense_inner = subcell.dense - dense_inner = quant.DenseQuant(dense_inner.in_channels, - dense_inner.out_channels, - has_bias=dense_inner.has_bias, - quant_config=self.quant_config, - quant_dtype=self.weight_dtype) - # change original network Dense OP parameters to quant network - dense_inner.weight = subcell.dense.weight - if subcell.dense.has_bias: - dense_inner.bias = subcell.dense.bias - subcell.dense = dense_inner - if subcell.has_act and subcell.activation is not None: - subcell.activation = self._convert_activation(subcell.activation) - elif subcell.after_fake: - subcell.has_act = True - subcell.activation = _AddFakeQuantAfterSubCell(F.identity, - quant_dtype=self.act_dtype, - quant_delay=self.act_qdelay, - per_channel=self.act_channel, - symmetric=self.act_symmetric, - narrow_range=self.act_range) - return subcell - - def _convert_activation(self, activation): - act_class = activation.__class__ - if act_class not in _ACTIVATION_MAP: - raise ValueError("Unsupported activation in auto quant: ", act_class) - return _ACTIVATION_MAP[act_class](activation=activation, - quant_config=self.quant_config, - quant_dtype=self.act_dtype) - +__all__ = ["export", "manual_export"] class ExportToQuantInferNetwork: """ @@ -499,7 +204,7 @@ class ExportToQuantInferNetwork: change = True elif isinstance(subcell, _AddFakeQuantAfterSubCell): op = subcell.subcell - if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive): + if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive): if self.is_mindir: op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy())) op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy())) @@ -553,106 +258,6 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format) -def convert_quant_network(network, - bn_fold=True, - freeze_bn=10000000, - quant_delay=(0, 0), - quant_dtype=(QuantDtype.INT8, QuantDtype.INT8), - per_channel=(False, False), - symmetric=(False, False), - narrow_range=(False, False) - ): - r""" - Create quantization aware training network. - - Args: - network (Cell): Obtain a pipeline through network for saving graph summary. - bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True. - freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7. - quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during - eval. The first element represent weights and second element represent data flow. Default: (0, 0) - quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first - element represent weights and second element represent data flow. 
- Default: (QuantDtype.INT8, QuantDtype.INT8) - per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` - then base on per channel otherwise base on per layer. The first element represent weights - and second element represent data flow. Default: (False, False) - symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on - symmetric otherwise base on asymmetric. The first element represent weights and second - element represent data flow. Default: (False, False) - narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not. - The first element represents weights and the second element represents data flow. Default: (False, False) - - Returns: - Cell, Network which has change to quantization aware training network cell. - """ - support_device = ["Ascend", "GPU"] - - def convert2list(name, value): - if not isinstance(value, list) and not isinstance(value, tuple): - value = [value] - elif len(value) > 2: - raise ValueError("input `{}` len should less then 2".format(name)) - return value - - quant_delay = convert2list("quant delay", quant_delay) - quant_dtype = convert2list("quant dtype", quant_dtype) - per_channel = convert2list("per channel", per_channel) - symmetric = convert2list("symmetric", symmetric) - narrow_range = convert2list("narrow range", narrow_range) - - if context.get_context('device_target') not in support_device: - raise KeyError("Unsupported {} device target.".format(context.get_context('device_target'))) - - net = ConvertToQuantNetwork(network=network, - quant_delay=quant_delay, - bn_fold=bn_fold, - freeze_bn=freeze_bn, - quant_dtype=quant_dtype, - per_channel=per_channel, - symmetric=symmetric, - narrow_range=narrow_range) - return net.run() - -def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'): - """ - Manual exports MindSpore quantization predict model to deploy wiAIR and MINDIR. - - Args: - network (Cell): MindSpore network produced by `convert_quant_network`. - inputs (Tensor): Inputs of the `quantization aware training network`. - file_name (str): File name of model to export. - mean (int, float): Input data mean. Default: 127.5. - std_dev (int, float): Input data variance. Default: 127.5. - file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported - quantization aware model. Default: 'AIR'. - - - AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of - Ascend model. - - MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format - for MindSpore models. - Recommended suffix for output file is '.mindir'. 
-    """
-    supported_device = ["Ascend", "GPU"]
-    supported_formats = ['AIR', 'MINDIR']
-
-    mean = Validator.check_type("mean", mean, (int, float))
-    std_dev = Validator.check_type("std_dev", std_dev, (int, float))
-
-    if context.get_context('device_target') not in supported_device:
-        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
-
-    if file_format not in supported_formats:
-        raise ValueError('Illegal file format {}.'.format(file_format))
-
-    network.set_train(False)
-    if file_format == "MINDIR":
-        exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
-    else:
-        exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=False)
-    deploy_net = exporter.run()
-    serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
-
 class ExportManualQuantNetwork:
     """
-    Convert anual quantization aware network to infer network.
+    Convert manual quantization aware network to infer network.
@@ -713,7 +318,7 @@ class ExportManualQuantNetwork:
             elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
                                       quant.Conv2dQuant, quant.DenseQuant)):
                 network, change = self._convert_subcell(network, change, name, subcell, core=False)
-            elif isinstance(subcell, quant.FakeQuantWithMinMax) and self.upcell:
+            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver) and self.upcell:
                 np_type = mstype.dtype_to_nptype(self.data_type)
                 _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type)
                 self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq))
@@ -721,7 +326,7 @@ class ExportManualQuantNetwork:
                 network.insert_child_to_cell(self.upname, self.upcell)
             elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                 op = subcell.subcell
-                if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
+                if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
                     if self.is_mindir:
                         op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
                         op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
@@ -845,3 +450,43 @@ class ExportManualQuantNetwork:
         else:
             block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
         return block
+
+
+def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'):
+    """
+    Manually exports the MindSpore quantization predict model for deployment with AIR and MINDIR.
+
+    Args:
+        network (Cell): MindSpore network produced by `QuantizationAwareTraining`.
+        inputs (Tensor): Inputs of the `quantization aware training network`.
+        file_name (str): File name of the model to export.
+        mean (int, float): Input data mean. Default: 127.5.
+        std_dev (int, float): Input data variance. Default: 127.5.
+        file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for the exported
+            quantization aware model. Default: 'MINDIR'.
+
+            - AIR: Graph Engine Intermediate Representation. An intermediate representation format of
+              Ascend model.
+            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation
+              format for MindSpore models.
+              Recommended suffix for output file is '.mindir'.
+    """
+    supported_device = ["Ascend", "GPU"]
+    supported_formats = ['AIR', 'MINDIR']
+
+    mean = Validator.check_type("mean", mean, (int, float))
+    std_dev = Validator.check_type("std_dev", std_dev, (int, float))
+
+    if context.get_context('device_target') not in supported_device:
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
+
+    if file_format not in supported_formats:
+        raise ValueError('Illegal file format {}.'.format(file_format))
+
+    network.set_train(False)
+    if file_format == "MINDIR":
+        exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
+    else:
+        exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=False)
+    deploy_net = exporter.run()
+    serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
diff --git a/mindspore/train/quant/__init__.py b/mindspore/compression/quant/__init__.py
similarity index 55%
rename from mindspore/train/quant/__init__.py
rename to mindspore/compression/quant/__init__.py
index 449c0f706d4..29d50d92214 100644
--- a/mindspore/train/quant/__init__.py
+++ b/mindspore/compression/quant/__init__.py
@@ -13,14 +13,9 @@
 # limitations under the License.
 # ============================================================================
 """
-Quantization.
-
-User can use quantization aware to train a model. MindSpore supports quantization aware training,
-which models quantization errors in both the forward and backward passes using fake-quantization
-operations. Note that the entire computation is carried out in floating point. At the end of quantization
-aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
+Compression quant module.
 """
-from .quant import convert_quant_network, export, manual_export
-
-__all__ = ["convert_quant_network", "export", "manual_export"]
+from .quantizer import *
+from .qat import *
+from .quant_utils import *
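For downstream callers, the entry point moves from a free function to a Quantizer object. A minimal before/after sketch of a call site, mirroring the model_zoo updates later in this patch (`network` stands in for any fusion network):

# Removed by this patch:
#   from mindspore.train.quant import quant
#   network = quant.convert_quant_network(network, bn_fold=True,
#                                         per_channel=[True, False], symmetric=[True, False])

# Added by this patch (qat.py below):
from mindspore.compression.quant import QuantizationAwareTraining

quantizer = QuantizationAwareTraining(bn_fold=True,
                                      per_channel=[True, False],
                                      symmetric=[True, False])
network = quantizer.quantize(network)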
diff --git a/mindspore/compression/quant/qat.py b/mindspore/compression/quant/qat.py
new file mode 100644
index 00000000000..26a6cc6d17f
--- /dev/null
+++ b/mindspore/compression/quant/qat.py
@@ -0,0 +1,406 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Quantization aware training.
+
+Users can apply quantization aware training to train a model. MindSpore supports quantization aware
+training, which models quantization errors in both the forward and backward passes using
+fake-quantization operations. Note that the entire computation is carried out in floating point. At
+the end of quantization aware training, MindSpore provides conversion functions to convert the
+trained model into lower precision.
+"""
+
+import re
+
+import mindspore.context as context
+
+from ... import nn, ops
+from ..._checkparam import Validator, Rel
+from ...nn.layer import quant
+from ...ops import functional as F
+from ..common import QuantDtype
+from .quantizer import Quantizer, OptimizeOption
+
+
+__all__ = ["QuantizationAwareTraining"]
+
+
+_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
+                   nn.ReLU6: quant.ActQuant,
+                   nn.Sigmoid: quant.ActQuant,
+                   nn.LeakyReLU: quant.LeakyReLUQuant,
+                   nn.HSigmoid: quant.HSigmoidQuant,
+                   nn.HSwish: quant.HSwishQuant}
+
+
+def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
+                     quant_delay=(0, 0),
+                     quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
+                     per_channel=(False, False),
+                     symmetric=(False, False),
+                     narrow_range=(False, False)
+                     ):
+    r"""
+    Configures the observer type of weights and data flow with quant parameters.
+
+    Args:
+        quant_observer (Observer, list or tuple): The observer type to do quantization. The first element
+            represents weights and the second element represents data flow.
+            Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
+        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized
+            during eval. The first element represents weights and the second element represents data flow.
+            Default: (0, 0)
+        quant_dtype (QuantDtype, list or tuple): Datatype to use for quantizing weights and activations. The
+            first element represents weights and the second element represents data flow.
+            Default: (QuantDtype.INT8, QuantDtype.INT8)
+        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`,
+            then base on per channel, otherwise base on per layer. The first element represents weights
+            and the second element represents data flow. Default: (False, False)
+        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True`,
+            then base on symmetric, otherwise base on asymmetric. The first element represents weights and
+            the second element represents data flow. Default: (False, False)
+        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
+            The first element represents weights and the second element represents data flow.
+            Default: (False, False)
+
+    Returns:
+        QuantConfig, contains the observer types of weight and activation.
+    """
+    weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
+                                                     per_channel=per_channel[0], symmetric=symmetric[0],
+                                                     narrow_range=narrow_range[0])
+    act_observer = quant_observer[-1].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
+                                                   per_channel=per_channel[-1], symmetric=symmetric[-1],
+                                                   narrow_range=narrow_range[-1])
+    return quant.QuantConfig(weight=weight_observer, activation=act_observer)
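A minimal sketch of driving `get_quant_config` directly, as `resnet_quant_manual.py` does further down in this patch (per-channel symmetric weights, per-layer asymmetric activations; all defaults otherwise):

from mindspore.compression.quant import qat

quant_config = qat.get_quant_config(per_channel=(True, False), symmetric=(True, False))
# quant_config.weight / quant_config.activation are partially-initialized observer classes
# that quant cells such as Conv2dBnFoldQuant accept via their `quant_config` argument.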
+ """ + + def __init__(self, subcell, **kwargs): + super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False) + self.subcell = subcell + self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6, + max_init=6, + ema=True, + quant_dtype=kwargs["quant_dtype"], + quant_delay=kwargs["quant_delay"], + per_channel=kwargs["per_channel"], + symmetric=kwargs["symmetric"], + narrow_range=kwargs["narrow_range"]) + + def construct(self, *data): + output = self.subcell(*data) + output = self.fake_quant_act(output) + return output + + +class ConvertToQuantNetwork: + """ + Convert network to quantization aware network + """ + __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] + + def __init__(self, **kwargs): + self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,)) + self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay") + self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay") + self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold") + self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn") + self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype) + self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype) + self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel") + self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel") + self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric") + self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric") + self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range") + self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range") + self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv, + quant.DenseBnAct: self._convert_dense} + self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"], + quant_dtype=kwargs["quant_dtype"], + per_channel=kwargs["per_channel"], + symmetric=kwargs["symmetric"], + narrow_range=kwargs["narrow_range"]) + +class QuantizationAwareTraining(Quantizer): + r""" + Quantizer for quantization aware training. + + Args: + bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True. + freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7. + quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during + eval. The first element represent weights and second element represent data flow. Default: (0, 0) + quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first + element represent weights and second element represent data flow. + Default: (QuantDtype.INT8, QuantDtype.INT8) + per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` + then base on per channel otherwise base on per layer. The first element represent weights + and second element represent data flow. Default: (False, False) + symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on + symmetric otherwise base on asymmetric. The first element represent weights and second + element represent data flow. 
+
+
+class QuantizationAwareTraining(Quantizer):
+    r"""
+    Quantizer for quantization aware training.
+
+    Args:
+        bn_fold (bool): Flag to use bn fold ops for simulation inference operation. Default: True.
+        freeze_bn (int): Number of steps after which the BatchNorm OP parameters use the total mean
+            and variance. Default: 1e7.
+        quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized
+            during eval. The first element represents weights and the second element represents data flow.
+            Default: (0, 0)
+        quant_dtype (QuantDtype, list or tuple): Datatype to use for quantizing weights and activations. The
+            first element represents weights and the second element represents data flow.
+            Default: (QuantDtype.INT8, QuantDtype.INT8)
+        per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`,
+            then base on per channel, otherwise base on per layer. The first element represents weights
+            and the second element represents data flow. Default: (False, False)
+        symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True`,
+            then base on symmetric, otherwise base on asymmetric. The first element represents weights and
+            the second element represents data flow. Default: (False, False)
+        narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
+            The first element represents weights and the second element represents data flow.
+            Default: (False, False)
+        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options.
+            Currently only QAT is supported. Default: OptimizeOption.QAT
+    """
+    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
+
+    def __init__(self,
+                 bn_fold=True,
+                 freeze_bn=10000000,
+                 quant_delay=(0, 0),
+                 quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
+                 per_channel=(False, False),
+                 symmetric=(False, False),
+                 narrow_range=(False, False),
+                 optimize_option=OptimizeOption.QAT):
+        """Init for QuantizationAwareTraining quantizer."""
+        super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)
+
+        def convert2list(name, value):
+            if not isinstance(value, list) and not isinstance(value, tuple):
+                value = [value]
+            elif len(value) > 2:
+                raise ValueError("input `{}` length should not be greater than 2".format(name))
+            return value
+
+        quant_delay = convert2list("quant delay", quant_delay)
+        quant_dtype = convert2list("quant dtype", quant_dtype)
+        per_channel = convert2list("per channel", per_channel)
+        symmetric = convert2list("symmetric", symmetric)
+        narrow_range = convert2list("narrow range", narrow_range)
+
+        self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
+        self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
+        self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
+        self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
+        self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
+        self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
+        self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
+        self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
+        self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
+        self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
+        self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
+        self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
+        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
+                                    quant.DenseBnAct: self._convert_dense}
+        self.quant_config = get_quant_config(quant_delay=quant_delay,
+                                             quant_dtype=quant_dtype,
+                                             per_channel=per_channel,
+                                             symmetric=symmetric,
+                                             narrow_range=narrow_range)
+
+    def _convert_op_name(self, name):
+        pattern = re.compile(r'([A-Z]{1})')
+        name_new = re.sub(pattern, r'_\1', name).lower()
+        if name_new[0] == '_':
+            name_new = name_new[1:]
+        return name_new
+
+    def quantize(self, network):
+        """Quant API to convert the input network to a quantization aware training network."""
+        support_device = ["Ascend", "GPU"]
+        if context.get_context('device_target') not in support_device:
+            raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
+
+        if OptimizeOption.QAT in self.optimize_option:
+            network.update_cell_prefix()
+            network = self._convert_subcells2quant(network)
+            network.update_cell_type("quant")
+        return network
+
+    def _convert_subcells2quant(self, network):
+        """
+        Convert sub cells like `Conv2dBnAct` and `DenseBnAct` to quant cells.
+        """
+        cells = network.name_cells()
+        change = False
+        for name in cells:
+            subcell = cells[name]
+            if subcell == network:
+                continue
+            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
+                prefix = subcell.param_prefix
+                new_subcell = self._convert_method_map[type(subcell)](subcell)
+                new_subcell.update_parameters_name(prefix + '.')
+                network.insert_child_to_cell(name, new_subcell)
+                change = True
+            else:
+                self._convert_subcells2quant(subcell)
+        if isinstance(network, nn.SequentialCell) and change:
+            network.cell_list = list(network.cells())
+
+        # add FakeQuant OP after the OPs in the white list
+        add_list = []
+        for name in network.__dict__:
+            if name[0] == '_':
+                continue
+            attr = network.__dict__[name]
+            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
+                add_list.append((name, attr))
+        for name, prim_op in add_list:
+            prefix = name
+            add_quant = _AddFakeQuantAfterSubCell(prim_op,
+                                                  quant_dtype=self.act_dtype,
+                                                  quant_delay=self.act_qdelay,
+                                                  per_channel=self.act_channel,
+                                                  symmetric=self.act_symmetric,
+                                                  narrow_range=self.act_range)
+            prefix = self._convert_op_name(prim_op.name)
+            if network.param_prefix:
+                prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
+            add_quant.update_parameters_name(prefix + '.')
+            del network.__dict__[name]
+            network.insert_child_to_cell(name, add_quant)
+        return network
+
+    def _convert_conv(self, subcell):
+        """
+        Convert a Conv2d cell to a quant cell.
+        """
+        conv_inner = subcell.conv
+        if subcell.has_bn:
+            if self.bn_fold:
+                bn_inner = subcell.batchnorm
+                conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
+                                                     conv_inner.out_channels,
+                                                     kernel_size=conv_inner.kernel_size,
+                                                     stride=conv_inner.stride,
+                                                     pad_mode=conv_inner.pad_mode,
+                                                     padding=conv_inner.padding,
+                                                     dilation=conv_inner.dilation,
+                                                     group=conv_inner.group,
+                                                     eps=bn_inner.eps,
+                                                     momentum=bn_inner.momentum,
+                                                     has_bias=conv_inner.has_bias,
+                                                     bias_init=conv_inner.bias_init,
+                                                     freeze_bn=self.freeze_bn,
+                                                     quant_config=self.quant_config,
+                                                     quant_dtype=self.weight_dtype,
+                                                     fake=True)
+                # change original network BatchNorm OP parameters to quant network
+                conv_inner.gamma = subcell.batchnorm.gamma
+                conv_inner.beta = subcell.batchnorm.beta
+                conv_inner.moving_mean = subcell.batchnorm.moving_mean
+                conv_inner.moving_variance = subcell.batchnorm.moving_variance
+                del subcell.batchnorm
+                subcell.batchnorm = None
+                subcell.has_bn = False
+            else:
+                bn_inner = subcell.batchnorm
+                conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
+                                                            conv_inner.out_channels,
+                                                            kernel_size=conv_inner.kernel_size,
+                                                            stride=conv_inner.stride,
+                                                            pad_mode=conv_inner.pad_mode,
+                                                            padding=conv_inner.padding,
+                                                            dilation=conv_inner.dilation,
+                                                            group=conv_inner.group,
+                                                            eps=bn_inner.eps,
+                                                            momentum=bn_inner.momentum,
+                                                            has_bias=conv_inner.has_bias,
+                                                            bias_init=conv_inner.bias_init,
+                                                            quant_config=self.quant_config,
+                                                            quant_dtype=self.weight_dtype)
+                # change original network BatchNorm OP parameters to quant network
+                conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
+                conv_inner.batchnorm.beta = subcell.batchnorm.beta
+                conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
+                conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
+                del subcell.batchnorm
+                subcell.batchnorm = None
+                subcell.has_bn = False
+        else:
+            conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
+                                           conv_inner.out_channels,
+                                           kernel_size=conv_inner.kernel_size,
+                                           stride=conv_inner.stride,
+                                           pad_mode=conv_inner.pad_mode,
+                                           padding=conv_inner.padding,
+                                           dilation=conv_inner.dilation,
+                                           group=conv_inner.group,
+                                           has_bias=conv_inner.has_bias,
+                                           quant_config=self.quant_config,
+                                           quant_dtype=self.weight_dtype)
+        # change original network Conv2D OP parameters to quant network
+        conv_inner.weight = subcell.conv.weight
+        if subcell.conv.has_bias:
+            conv_inner.bias = subcell.conv.bias
+        subcell.conv = conv_inner
+        if subcell.has_act and subcell.activation is not None:
+            subcell.activation = self._convert_activation(subcell.activation)
+        elif subcell.after_fake:
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
+                                                           quant_dtype=self.act_dtype,
+                                                           quant_delay=self.act_qdelay,
+                                                           per_channel=self.act_channel,
+                                                           symmetric=self.act_symmetric,
+                                                           narrow_range=self.act_range)
+        return subcell
+
+    def _convert_dense(self, subcell):
+        """
+        Convert a dense cell to a combined dense cell.
+        """
+        dense_inner = subcell.dense
+        dense_inner = quant.DenseQuant(dense_inner.in_channels,
+                                       dense_inner.out_channels,
+                                       has_bias=dense_inner.has_bias,
+                                       quant_config=self.quant_config,
+                                       quant_dtype=self.weight_dtype)
+        # change original network Dense OP parameters to quant network
+        dense_inner.weight = subcell.dense.weight
+        if subcell.dense.has_bias:
+            dense_inner.bias = subcell.dense.bias
+        subcell.dense = dense_inner
+        if subcell.has_act and subcell.activation is not None:
+            subcell.activation = self._convert_activation(subcell.activation)
+        elif subcell.after_fake:
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
+                                                           quant_dtype=self.act_dtype,
+                                                           quant_delay=self.act_qdelay,
+                                                           per_channel=self.act_channel,
+                                                           symmetric=self.act_symmetric,
+                                                           narrow_range=self.act_range)
+        return subcell
+
+    def _convert_activation(self, activation):
+        act_class = activation.__class__
+        if act_class not in _ACTIVATION_MAP:
+            raise ValueError("Unsupported activation in auto quant: ", act_class)
+        return _ACTIVATION_MAP[act_class](activation=activation,
+                                          quant_config=self.quant_config,
+                                          quant_dtype=self.act_dtype)
diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/compression/quant/quant_utils.py
similarity index 99%
rename from mindspore/train/quant/quant_utils.py
rename to mindspore/compression/quant/quant_utils.py
index 21cb231fde9..8d0894088cd 100644
--- a/mindspore/train/quant/quant_utils.py
+++ b/mindspore/compression/quant/quant_utils.py
@@ -17,6 +17,9 @@
 
 import numpy as np
 
+__all__ = ["load_nonquant_param_into_quant_net"]
+
+
 def cal_quantization_params(input_min,
                             input_max,
                             data_type,
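`quant_utils` now publicly exports only `load_nonquant_param_into_quant_net`. A minimal sketch of the warm-start flow the model_zoo training scripts below follow (the checkpoint path is hypothetical; `network` is a fusion net already converted by `quantizer.quantize`):

from mindspore.train.serialization import load_checkpoint
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net

# Warm-start a quantization aware network from a float (non-quant) checkpoint.
param_dict = load_checkpoint("checkpoint_lenet.ckpt")  # hypothetical float checkpoint
load_nonquant_param_into_quant_net(network, param_dict)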
diff --git a/mindspore/compression/quant/quantizer.py b/mindspore/compression/quant/quantizer.py
new file mode 100644
index 00000000000..36a513c7c1d
--- /dev/null
+++ b/mindspore/compression/quant/quantizer.py
@@ -0,0 +1,52 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Base Class of Quantizer."""
+
+from abc import ABC, abstractmethod
+from enum import Enum
+
+__all__ = ["OptimizeOption", "Quantizer"]
+
+
+class OptimizeOption(Enum):
+    """
+    An enum for the model quantization optimize options.
+    """
+    # using quantization aware training
+    QAT = "QAT"
+
+    def __str__(self):
+        return self.value
+
+
+class Quantizer(ABC):
+    """
+    Base class of Quantizer. You can implement different kinds of quantizers to get different quantization results.
+
+    Note:
+        This class is an abstract class.
+
+    Args:
+        optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: None.
+    """
+    def __init__(self,
+                 optimize_option=None):
+        if not isinstance(optimize_option, list) and not isinstance(optimize_option, tuple):
+            optimize_option = [optimize_option]
+        self.optimize_option = optimize_option
+
+    @abstractmethod
+    def quantize(self, network):
+        """Quant API to convert the input network to a quantized network."""
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 2d1f7818c7a..2084927ccc1 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -30,7 +30,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.api import _executor
 from mindspore.common import dtype as mstype
 from mindspore._checkparam import check_input_data, Validator
-from mindspore.train.quant import quant
+from mindspore.compression.export import quant_export
 import mindspore.context as context
 
 __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
@@ -596,14 +596,14 @@ def _quant_export(network, *inputs, file_format, **kwargs):
     network.set_train(False)
     if file_format == "MINDIR":
         if quant_mode == 'MANUAL':
-            exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
+            exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
         else:
-            exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
+            exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
     else:
         if quant_mode == 'MANUAL':
-            exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
+            exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
         else:
-            exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
+            exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
     deploy_net = exporter.run()
     return deploy_net
diff --git a/model_zoo/official/cv/lenet_quant/eval_quant.py b/model_zoo/official/cv/lenet_quant/eval_quant.py
index fb44c01a916..86f7db60961 100644
--- a/model_zoo/official/cv/lenet_quant/eval_quant.py
+++ b/model_zoo/official/cv/lenet_quant/eval_quant.py
@@ -25,7 +25,7 @@ from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.train import Model
 from mindspore.nn.metrics import Accuracy
-from mindspore.train.quant import quant
+from mindspore.compression.quant import QuantizationAwareTraining
 from src.dataset import create_dataset
 from src.config import mnist_cfg as cfg
 from src.lenet_fusion import LeNet5 as LeNet5Fusion
@@ -47,8 +47,12 @@
     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)
     # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
-                                          per_channel=[True, False], symmetric=[True, False])
+    quantizer = QuantizationAwareTraining(quant_delay=0,
+                                          bn_fold=False,
+                                          freeze_bn=10000,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])
+    network = quantizer.quantize(network)
 
     # define loss
     net_loss = 
nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") diff --git a/model_zoo/official/cv/lenet_quant/export.py b/model_zoo/official/cv/lenet_quant/export.py index c250edc6abe..4f5d9381d2d 100644 --- a/model_zoo/official/cv/lenet_quant/export.py +++ b/model_zoo/official/cv/lenet_quant/export.py @@ -22,7 +22,7 @@ import numpy as np import mindspore from mindspore import Tensor from mindspore import context -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from src.config import mnist_cfg as cfg @@ -44,8 +44,12 @@ if __name__ == "__main__": # define fusion network network = LeNet5Fusion(cfg.num_classes) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, - per_channel=[True, False], symmetric=[True, False]) + quantizer = QuantizationAwareTraining(quant_delay=0, + bn_fold=False, + freeze_bn=10000, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # load quantization aware network checkpoint param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) diff --git a/model_zoo/official/cv/lenet_quant/train_quant.py b/model_zoo/official/cv/lenet_quant/train_quant.py index 9e43cbe58cb..77279be3cd3 100644 --- a/model_zoo/official/cv/lenet_quant/train_quant.py +++ b/model_zoo/official/cv/lenet_quant/train_quant.py @@ -26,8 +26,8 @@ from mindspore.train.serialization import load_checkpoint from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train import Model from mindspore.nn.metrics import Accuracy -from mindspore.train.quant import quant -from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net +from mindspore.compression.quant import QuantizationAwareTraining +from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.common import set_seed from src.dataset import create_dataset from src.config import mnist_cfg as cfg @@ -59,8 +59,11 @@ if __name__ == "__main__": load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False], + quantizer = QuantizationAwareTraining(quant_delay=900, + bn_fold=False, + per_channel=[True, False], symmetric=[True, False]) + network = quantizer.quantize(network) # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") diff --git a/model_zoo/official/cv/mobilenetv2_quant/eval.py b/model_zoo/official/cv/mobilenetv2_quant/eval.py index b0515e3f26b..76850fabeca 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/eval.py +++ b/model_zoo/official/cv/mobilenetv2_quant/eval.py @@ -21,7 +21,7 @@ from mindspore import context from mindspore import nn from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from src.mobilenetV2 import mobilenetV2 from src.dataset import create_dataset @@ -51,7 +51,10 @@ if __name__ == '__main__': # define fusion network network = mobilenetV2(num_classes=config_device_target.num_classes) # convert fusion network to quantization aware network - network = 
quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # define network loss loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') diff --git a/model_zoo/official/cv/mobilenetv2_quant/export.py b/model_zoo/official/cv/mobilenetv2_quant/export.py index 83d8ff3dad5..735673415fe 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/export.py +++ b/model_zoo/official/cv/mobilenetv2_quant/export.py @@ -21,7 +21,7 @@ import mindspore from mindspore import Tensor from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net, export -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from src.mobilenetV2 import mobilenetV2 from src.config import config_ascend_quant @@ -42,7 +42,10 @@ if __name__ == '__main__': # define fusion network network = mobilenetV2(num_classes=cfg.num_classes) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # load checkpoint param_dict = load_checkpoint(args_opt.checkpoint_path) load_param_into_net(network, param_dict) diff --git a/model_zoo/official/cv/mobilenetv2_quant/train.py b/model_zoo/official/cv/mobilenetv2_quant/train.py index 89ccb7c9e18..9d883cf4313 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/train.py +++ b/model_zoo/official/cv/mobilenetv2_quant/train.py @@ -26,8 +26,8 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.serialization import load_checkpoint from mindspore.communication.management import init, get_group_size, get_rank -from mindspore.train.quant import quant -from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net +from mindspore.compression.quant import QuantizationAwareTraining +from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.common import set_seed from src.dataset import create_dataset @@ -99,10 +99,10 @@ def train_on_ascend(): param_dict = load_checkpoint(args_opt.pre_trained) load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, - bn_fold=True, + quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + network = quantizer.quantize(network) # get learning rate lr = Tensor(get_lr(global_step=config.start_epoch * step_size, @@ -162,12 +162,12 @@ def train_on_gpu(): load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, - bn_fold=True, + quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], symmetric=[True, False], freeze_bn=1000000, quant_delay=step_size * 2) + network = quantizer.quantize(network) # get learning rate loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) diff --git a/model_zoo/official/cv/resnet50_quant/eval.py 
b/model_zoo/official/cv/resnet50_quant/eval.py index 76cfa345239..30a0b46bcab 100755 --- a/model_zoo/official/cv/resnet50_quant/eval.py +++ b/model_zoo/official/cv/resnet50_quant/eval.py @@ -26,7 +26,7 @@ from models.resnet_quant_manual import resnet50_quant #manually construct quanta from mindspore import context from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') @@ -43,12 +43,13 @@ if args_opt.device_target == "Ascend": if __name__ == '__main__': # define fusion network - net = resnet50_quant(class_num=config.class_num) + network = resnet50_quant(class_num=config.class_num) # convert fusion network to quantization aware network - net = quant.convert_quant_network(net, - bn_fold=True, - per_channel=[True, False], - symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) + # define network loss if not config.use_label_smooth: config.label_smooth_factor = 0.0 @@ -65,13 +66,13 @@ if __name__ == '__main__': # load checkpoint if args_opt.checkpoint_path: param_dict = load_checkpoint(args_opt.checkpoint_path) - not_load_param = load_param_into_net(net, param_dict) + not_load_param = load_param_into_net(network, param_dict) if not_load_param: - raise ValueError("Load param into net fail!") - net.set_train(False) + raise ValueError("Load param into network fail!") + network.set_train(False) # define model - model = Model(net, loss_fn=loss, metrics={'acc'}) + model = Model(network, loss_fn=loss, metrics={'acc'}) print("============== Starting Validation ==============") res = model.eval(dataset) diff --git a/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py b/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py index 63da3cfaa8d..a74dba847b5 100644 --- a/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py +++ b/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py @@ -17,14 +17,14 @@ import numpy as np import mindspore.nn as nn from mindspore.ops import operations as P from mindspore import Tensor -from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant -from mindspore.train.quant import quant +from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant +from mindspore.compression.quant import qat _ema_decay = 0.999 _symmetric = True _fake = True _per_channel = True -_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) +_quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) def _weight_variable(shape, factor=0.01): @@ -90,8 +90,8 @@ class ConvBNReLU(nn.Cell): def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): super(ConvBNReLU, self).__init__() padding = (kernel_size - 1) // 2 - conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, - group=groups, fake=_fake, quant_config=_quant_config) + conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, + group=groups, fake=_fake, quant_config=_quant_config) layers = [conv, 
nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
         self.features = nn.SequentialCell(layers)
 
@@ -126,14 +126,14 @@ class ResidualBlock(nn.Cell):
         channel = out_channel // self.expansion
         self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
         self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
-        self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
-                                                             quant_config=_quant_config,
-                                                             kernel_size=1, stride=1, pad_mode='same', padding=0),
+        self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
+                                                          quant_config=_quant_config,
+                                                          kernel_size=1, stride=1, pad_mode='same', padding=0),
                                         FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
                                                                     symmetric=False)
-                                        ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
-                                                                              quant_config=_quant_config,
-                                                                              kernel_size=1, stride=1,
-                                                                              pad_mode='same', padding=0)
+                                        ]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
+                                                                           quant_config=_quant_config,
+                                                                           kernel_size=1, stride=1,
+                                                                           pad_mode='same', padding=0)
 
         self.down_sample = False
 
@@ -142,20 +142,19 @@ class ResidualBlock(nn.Cell):
         self.down_sample_layer = None
 
         if self.down_sample:
-            self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
-                                                                             quant_config=_quant_config,
-                                                                             kernel_size=1, stride=stride,
-                                                                             pad_mode='same', padding=0),
+            self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel,
+                                                                          quant_config=_quant_config,
+                                                                          kernel_size=1, stride=stride,
+                                                                          pad_mode='same', padding=0),
                                                         FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
                                                                                     symmetric=False)
-                                                        ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
-                                                                                              fake=_fake,
-                                                                                              quant_config=\
-                                                                                              _quant_config,
-                                                                                              kernel_size=1,
-                                                                                              stride=stride,
-                                                                                              pad_mode='same',
-                                                                                              padding=0)
+                                                        ]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel,
+                                                                                           fake=_fake,
+                                                                                           quant_config=_quant_config,
+                                                                                           kernel_size=1,
+                                                                                           stride=stride,
+                                                                                           pad_mode='same',
+                                                                                           padding=0)
 
         self.add = nn.TensorAddQuant()
         self.relu = P.ReLU()
diff --git a/model_zoo/official/cv/resnet50_quant/train.py b/model_zoo/official/cv/resnet50_quant/train.py
index 5944eda925b..1c3cf16073f 100755
--- a/model_zoo/official/cv/resnet50_quant/train.py
+++ b/model_zoo/official/cv/resnet50_quant/train.py
@@ -25,8 +25,8 @@ from mindspore.context import ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint
-from mindspore.train.quant import quant
-from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
+from mindspore.compression.quant import QuantizationAwareTraining
+from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
 from mindspore.communication.management import init
 import mindspore.nn as nn
 import mindspore.common.initializer as weight_init
@@ -113,7 +113,10 @@ if __name__ == '__main__':
     step_size = dataset.get_dataset_size()
 
     # convert fusion network to quantization aware network
-    net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    quantizer = QuantizationAwareTraining(bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])
+    net = quantizer.quantize(net)
 
     # get learning rate
     lr = get_lr(lr_init=config.lr_init,
diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/eval.py
b/model_zoo/official/cv/yolov3_darknet53_quant/eval.py index be68e59e227..179b7d68e23 100644 --- a/model_zoo/official/cv/yolov3_darknet53_quant/eval.py +++ b/model_zoo/official/cv/yolov3_darknet53_quant/eval.py @@ -29,7 +29,7 @@ from mindspore.context import ParallelMode from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore as ms -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from src.yolo import YOLOV3DarkNet53 from src.logger import get_logger @@ -265,10 +265,10 @@ def test(): # convert fusion network to quantization aware network if config.quantization_aware: - network = quant.convert_quant_network(network, - bn_fold=True, + quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + network = quantizer.quantize(network) args.logger.info(args.pretrained) if os.path.isfile(args.pretrained): diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/mindspore_hub_conf.py b/model_zoo/official/cv/yolov3_darknet53_quant/mindspore_hub_conf.py index dfc33f87a68..658d81b15f3 100644 --- a/model_zoo/official/cv/yolov3_darknet53_quant/mindspore_hub_conf.py +++ b/model_zoo/official/cv/yolov3_darknet53_quant/mindspore_hub_conf.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """hub config.""" -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from src.yolo import YOLOV3DarkNet53 from src.config import ConfigYOLOV3DarkNet53 @@ -24,9 +24,9 @@ def create_network(name, *args, **kwargs): config = ConfigYOLOV3DarkNet53() # convert fusion network to quantization aware network if config.quantization_aware: - yolov3_darknet53_quant = quant.convert_quant_network(yolov3_darknet53_quant, - bn_fold=True, - per_channel=[True, False], - symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + yolov3_darknet53_quant = quantizer.quantize(yolov3_darknet53_quant) return yolov3_darknet53_quant raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/train.py b/model_zoo/official/cv/yolov3_darknet53_quant/train.py index 8dca14d4701..a4b9be26a09 100644 --- a/model_zoo/official/cv/yolov3_darknet53_quant/train.py +++ b/model_zoo/official/cv/yolov3_darknet53_quant/train.py @@ -27,7 +27,7 @@ from mindspore.communication.management import init, get_rank, get_group_size from mindspore.train.callback import ModelCheckpoint, RunContext from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig import mindspore as ms -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from mindspore.common import set_seed from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper @@ -168,10 +168,10 @@ def train(): config = ConfigYOLOV3DarkNet53() # convert fusion network to quantization aware network if config.quantization_aware: - network = quant.convert_quant_network(network, - bn_fold=True, + quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + network = quantizer.quantize(network) network = YoloWithLossCell(network) args.logger.info('finish get network') diff --git a/tests/st/quantization/lenet_quant/test_lenet_quant.py 
b/tests/st/quantization/lenet_quant/test_lenet_quant.py index 4dffcfccbee..6bde63a112d 100644 --- a/tests/st/quantization/lenet_quant/test_lenet_quant.py +++ b/tests/st/quantization/lenet_quant/test_lenet_quant.py @@ -26,8 +26,8 @@ from mindspore.nn.metrics import Accuracy from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train import Model -from mindspore.train.quant import quant -from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net +from mindspore.compression.quant import QuantizationAwareTraining +from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from dataset import create_dataset from config import nonquant_cfg, quant_cfg from lenet import LeNet5 @@ -73,8 +73,11 @@ def train_lenet_quant(): load_nonquant_param_into_quant_net(network, param_dict) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False], - symmetric=[False, False]) + quantizer = QuantizationAwareTraining(quant_delay=900, + bn_fold=False, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") @@ -103,8 +106,12 @@ def eval_quant(): # define fusion network network = LeNet5Fusion(cfg.num_classes) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, - per_channel=[True, False]) + quantizer = QuantizationAwareTraining(quant_delay=0, + bn_fold=False, + freeze_bn=10000, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # define loss net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") @@ -131,8 +138,12 @@ def export_lenet(): # define fusion network network = LeNet5Fusion(cfg.num_classes) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, - per_channel=[True, False], symmetric=[True, False]) + quantizer = QuantizationAwareTraining(quant_delay=0, + bn_fold=False, + freeze_bn=10000, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # export network inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) diff --git a/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py b/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py index e7c9c5e0bae..82655b1833c 100644 --- a/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py +++ b/tests/st/quantization/mobilenetv2_quant/test_mobilenetv2_quant.py @@ -23,7 +23,7 @@ from mindspore import context from mindspore import Tensor from mindspore import nn from mindspore.train.model import Model -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from mindspore.common import set_seed from dataset import create_dataset @@ -84,10 +84,10 @@ def test_mobilenetv2_quant(): step_size = dataset.get_dataset_size() # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, - bn_fold=True, + quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], symmetric=[True, False]) 
+ network = quantizer.quantize(network) # get learning rate lr = Tensor(get_lr(global_step=config.start_epoch * step_size, diff --git a/tests/st/quantization/resnet50_quant/resnet_quant_manual.py b/tests/st/quantization/resnet50_quant/resnet_quant_manual.py index 9bb6cbea0fb..20f70f4b621 100644 --- a/tests/st/quantization/resnet50_quant/resnet_quant_manual.py +++ b/tests/st/quantization/resnet50_quant/resnet_quant_manual.py @@ -18,14 +18,14 @@ import mindspore.nn as nn import mindspore.common.initializer as weight_init from mindspore.ops import operations as P from mindspore import Tensor -from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant -from mindspore.train.quant import quant +from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant +from mindspore.compression.quant import qat _ema_decay = 0.999 _symmetric = True _fake = True _per_channel = True -_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) +_quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) def _weight_variable(shape, factor=0.01): @@ -91,8 +91,8 @@ class ConvBNReLU(nn.Cell): def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): super(ConvBNReLU, self).__init__() padding = (kernel_size - 1) // 2 - conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, - group=groups, fake=_fake, quant_config=_quant_config) + conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, + group=groups, fake=_fake, quant_config=_quant_config) layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] self.features = nn.SequentialCell(layers) @@ -127,14 +127,14 @@ class ResidualBlock(nn.Cell): channel = out_channel // self.expansion self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) - self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, - quant_config=_quant_config, - kernel_size=1, stride=1, pad_mode='same', padding=0), + self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake, + quant_config=_quant_config, + kernel_size=1, stride=1, pad_mode='same', padding=0), FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) - ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, - quant_config=_quant_config, - kernel_size=1, stride=1, - pad_mode='same', padding=0) + ]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake, + quant_config=_quant_config, + kernel_size=1, stride=1, + pad_mode='same', padding=0) self.down_sample = False @@ -143,20 +143,19 @@ class ResidualBlock(nn.Cell): self.down_sample_layer = None if self.down_sample: - self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, - quant_config=_quant_config, - kernel_size=1, stride=stride, - pad_mode='same', padding=0), + self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel, + quant_config=_quant_config, + kernel_size=1, stride=stride, + pad_mode='same', padding=0), FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) - ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, - fake=_fake, - quant_config=\ - _quant_config, - kernel_size=1, - stride=stride, - pad_mode='same', - padding=0) + ]) if _fake else 
Conv2dBnFoldQuant(in_channel, out_channel, + fake=_fake, + quant_config=_quant_config, + kernel_size=1, + stride=stride, + pad_mode='same', + padding=0) self.add = nn.TensorAddQuant() self.relu = P.ReLU() diff --git a/tests/st/quantization/resnet50_quant/test_resnet50_quant.py b/tests/st/quantization/resnet50_quant/test_resnet50_quant.py index e2c060f8806..8b5d30add4a 100755 --- a/tests/st/quantization/resnet50_quant/test_resnet50_quant.py +++ b/tests/st/quantization/resnet50_quant/test_resnet50_quant.py @@ -22,7 +22,7 @@ from mindspore import context from mindspore import Tensor from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model -from mindspore.train.quant import quant +from mindspore.compression.quant import QuantizationAwareTraining from mindspore import set_seed from resnet_quant_manual import resnet50_quant @@ -89,10 +89,10 @@ def test_resnet50_quant(): step_size = dataset.get_dataset_size() # convert fusion network to quantization aware network - net = quant.convert_quant_network(net, - bn_fold=True, - per_channel=[True, False], - symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + net = quantizer.quantize(net) # get learning rate lr = Tensor(get_lr(lr_init=config.lr_init, diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py index 10e7250fcff..c8e46072fe5 100644 --- a/tests/ut/python/train/quant/test_quant.py +++ b/tests/ut/python/train/quant/test_quant.py @@ -19,7 +19,8 @@ import pytest import mindspore.context as context from mindspore import Tensor from mindspore import nn -from mindspore.train.quant import quant as qat +from mindspore.compression.quant import QuantizationAwareTraining +from mindspore.compression.export import quant_export from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") @@ -66,27 +67,35 @@ class LeNet5(nn.Cell): def test_qat_lenet(): img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) net = LeNet5() - net = qat.convert_quant_network( - net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + net = quantizer.quantize(net) # should load the checkpoint. mock here net.init_parameters_data() - qat.export(net, img, file_name="quant.pb") + quant_export.export(net, img, file_name="quant.pb") @pytest.mark.skip(reason="no `te.lang.cce` in ut env") def test_qat_mobile_per_channel_tf(): network = mobilenetV2(num_classes=1000) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) - network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[True, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # should load the checkpoint. 
mock here network.init_parameters_data() - qat.export(network, img, file_name="quant.pb") + quant_export.export(network, img, file_name="quant.pb") @pytest.mark.skip(reason="no `te.lang.cce` in ut env") def test_qat_mobile_per_channel_ff(): network = mobilenetV2(num_classes=1000) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) - network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False]) + quantizer = QuantizationAwareTraining(bn_fold=True, + per_channel=[False, False], + symmetric=[True, False]) + network = quantizer.quantize(network) # should load the checkpoint. mock here network.init_parameters_data() - qat.export(network, img, file_name="quant.pb") + quant_export.export(network, img, file_name="quant.pb")
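
Migration note for call sites outside this patch: the one-shot quant.convert_quant_network(...) call is replaced by a two-step pattern, construct a QuantizationAwareTraining object carrying the quantization options, then pass the fusion network through its quantize() method. The sketch below is illustrative only; the LeNetFusion cell, its layer sizes, and the variable names are stand-ins, assuming a network built from the standard Conv2dBnAct/DenseBnAct fusion cells:

    import mindspore.nn as nn
    from mindspore.compression.quant import QuantizationAwareTraining

    class LeNetFusion(nn.Cell):
        """Illustrative fusion network; QAT conversion targets Conv2dBnAct/DenseBnAct cells."""
        def __init__(self, num_classes=10):
            super(LeNetFusion, self).__init__()
            self.conv1 = nn.Conv2dBnAct(1, 6, 5, pad_mode='valid', has_bn=True, activation='relu')
            self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', has_bn=True, activation='relu')
            self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
            self.flatten = nn.Flatten()
            self.fc = nn.DenseBnAct(16 * 5 * 5, num_classes)

        def construct(self, x):
            x = self.max_pool(self.conv1(x))
            x = self.max_pool(self.conv2(x))
            return self.fc(self.flatten(x))

    network = LeNetFusion()
    # each pair is (weights, activations), mirroring the model zoo call sites
    quantizer = QuantizationAwareTraining(bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    network = quantizer.quantize(network)

The GPU path in mobilenetv2_quant/train.py shows that scheduling knobs such as freeze_bn=1000000 and quant_delay=step_size * 2 are simply extra keyword arguments to the same constructor.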
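
The scripts that resume from a float checkpoint (mobilenetv2_quant, resnet50_quant, the lenet ST test) all keep the same ordering after the migration: the non-quantized parameters are loaded into the still-float fusion network first, and quantize() runs afterwards. A minimal sketch, reusing the hypothetical LeNetFusion above and a placeholder checkpoint path:

    from mindspore.train.serialization import load_checkpoint
    from mindspore.compression.quant import QuantizationAwareTraining
    from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net

    network = LeNetFusion()                                  # float fusion network (sketch above)
    param_dict = load_checkpoint("lenet_fp32.ckpt")          # placeholder path, not from this patch
    load_nonquant_param_into_quant_net(network, param_dict)  # copy float params into the fusion net
    quantizer = QuantizationAwareTraining(quant_delay=900,
                                          bn_fold=False,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    network = quantizer.quantize(network)                    # fake-quant cells are inserted last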
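
For the manually quantized ResNet variants, the per-layer building blocks now come from qat.get_quant_config plus nn.Conv2dBnFoldQuant under its own name. A condensed sketch of the pattern resnet_quant_manual.py repeats; the channel counts are arbitrary and only the API shape is the point:

    import mindspore.nn as nn
    from mindspore.nn import Conv2dBnFoldQuant, FakeQuantWithMinMaxObserver
    from mindspore.compression.quant import qat

    _ema_decay = 0.999
    # pairs are (weights, activations), as in the _quant_config of resnet_quant_manual.py
    _quant_config = qat.get_quant_config(per_channel=(True, False), symmetric=(True, False))

    conv = Conv2dBnFoldQuant(3, 16, kernel_size=3, stride=1, pad_mode='pad', padding=1,
                             fake=True, quant_config=_quant_config)
    block = nn.SequentialCell([conv,
                               nn.ActQuant(nn.ReLU()),  # fake-quantized activation
                               FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
                                                           symmetric=False)])
    residual_add = nn.TensorAddQuant()  # quantization-aware add for residual connections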
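
Finally, export moves out of the old quant module into mindspore.compression.export.quant_export, as the UT changes show. A sketch of the call, continuing from the quantized network of the first sketch; the (1, 1, 32, 32) input shape matches LeNetFusion and the file name mirrors the tests:

    import numpy as np
    from mindspore import Tensor
    from mindspore.compression.export import quant_export

    # the tests mock checkpoint loading this way; real code would load trained parameters
    network.init_parameters_data()
    img = Tensor(np.ones((1, 1, 32, 32)).astype(np.float32))  # dummy input for shape inference
    quant_export.export(network, img, file_name="quant.pb")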