move train.quant to compression module & add QuantizationAwareTraining
This commit is contained in:
parent
ff85533155
commit
c8ec34d638
|
@ -0,0 +1,17 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Compression export module.
|
||||
"""
|
|
@ -12,323 +12,28 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""quantization aware."""
|
||||
"""Export for quantization."""
|
||||
|
||||
import copy
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
|
||||
from ... import log as logger
|
||||
from ... import nn, ops
|
||||
from ..._checkparam import Validator, Rel
|
||||
from ..._checkparam import Validator
|
||||
from ...common import Tensor
|
||||
from ...common import dtype as mstype
|
||||
from ...common.api import _executor
|
||||
from ...nn.layer import quant
|
||||
from ...compression.common import QuantDtype
|
||||
from ...ops import functional as F
|
||||
from ...ops import operations as P
|
||||
from ...ops.operations import _inner_ops as inner
|
||||
from ...train import serialization
|
||||
from . import quant_utils
|
||||
|
||||
_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
|
||||
nn.ReLU6: quant.ActQuant,
|
||||
nn.Sigmoid: quant.ActQuant,
|
||||
nn.LeakyReLU: quant.LeakyReLUQuant,
|
||||
nn.HSigmoid: quant.HSigmoidQuant,
|
||||
nn.HSwish: quant.HSwishQuant}
|
||||
from ..quant import quant_utils
|
||||
from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell
|
||||
|
||||
|
||||
def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
|
||||
quant_delay=(0, 0),
|
||||
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
|
||||
per_channel=(False, False),
|
||||
symmetric=(False, False),
|
||||
narrow_range=(False, False)
|
||||
):
|
||||
r"""
|
||||
Configs the oberser type of weights and data flow with quant params.
|
||||
|
||||
Args:
|
||||
quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent
|
||||
weights and second element represent data flow.
|
||||
Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
|
||||
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
|
||||
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
|
||||
quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
|
||||
element represent weights and second element represent data flow.
|
||||
Default: (QuantDtype.INT8, QuantDtype.INT8)
|
||||
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
|
||||
then base on per channel otherwise base on per layer. The first element represent weights
|
||||
and second element represent data flow. Default: (False, False)
|
||||
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
|
||||
symmetric otherwise base on asymmetric. The first element represent weights and second
|
||||
element represent data flow. Default: (False, False)
|
||||
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
|
||||
The first element represents weights and the second element represents data flow. Default: (False, False)
|
||||
|
||||
Returns:
|
||||
QuantConfig, Contains the oberser type of weight and activation.
|
||||
"""
|
||||
weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
|
||||
per_channel=per_channel[0], symmetric=symmetric[0],
|
||||
narrow_range=narrow_range[0])
|
||||
act_observer = quant_observer[0].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
|
||||
per_channel=per_channel[-1], symmetric=symmetric[-1],
|
||||
narrow_range=narrow_range[-1])
|
||||
return quant.QuantConfig(weight=weight_observer, activation=act_observer)
|
||||
|
||||
|
||||
class _AddFakeQuantInput(nn.Cell):
|
||||
"""
|
||||
Add FakeQuant OP at input of the network. Only support one input case.
|
||||
"""
|
||||
|
||||
def __init__(self, network, quant_delay=0):
|
||||
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
|
||||
self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
|
||||
quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_input.update_parameters_name('fake_quant_input.')
|
||||
self.network = network
|
||||
|
||||
def construct(self, data):
|
||||
data = self.fake_quant_input(data)
|
||||
output = self.network(data)
|
||||
return output
|
||||
|
||||
|
||||
class _AddFakeQuantAfterSubCell(nn.Cell):
|
||||
"""
|
||||
Add FakeQuant OP after of the sub Cell.
|
||||
"""
|
||||
|
||||
def __init__(self, subcell, **kwargs):
|
||||
super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
|
||||
self.subcell = subcell
|
||||
self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
quant_dtype=kwargs["quant_dtype"],
|
||||
quant_delay=kwargs["quant_delay"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
|
||||
def construct(self, *data):
|
||||
output = self.subcell(*data)
|
||||
output = self.fake_quant_act(output)
|
||||
return output
|
||||
|
||||
|
||||
class ConvertToQuantNetwork:
|
||||
"""
|
||||
Convert network to quantization aware network
|
||||
"""
|
||||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
|
||||
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
|
||||
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
|
||||
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
|
||||
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
|
||||
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
|
||||
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
|
||||
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
|
||||
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
|
||||
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
|
||||
self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
|
||||
self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
|
||||
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
|
||||
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
||||
quant.DenseBnAct: self._convert_dense}
|
||||
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
|
||||
quant_dtype=kwargs["quant_dtype"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
|
||||
def _convert_op_name(self, name):
|
||||
pattern = re.compile(r'([A-Z]{1})')
|
||||
name_new = re.sub(pattern, r'_\1', name).lower()
|
||||
if name_new[0] == '_':
|
||||
name_new = name_new[1:]
|
||||
return name_new
|
||||
|
||||
def run(self):
|
||||
self.network.update_cell_prefix()
|
||||
network = self._convert_subcells2quant(self.network)
|
||||
self.network.update_cell_type("quant")
|
||||
return network
|
||||
|
||||
def _convert_subcells2quant(self, network):
|
||||
"""
|
||||
convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
|
||||
"""
|
||||
cells = network.name_cells()
|
||||
change = False
|
||||
for name in cells:
|
||||
subcell = cells[name]
|
||||
if subcell == network:
|
||||
continue
|
||||
elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
|
||||
prefix = subcell.param_prefix
|
||||
new_subcell = self._convert_method_map[type(subcell)](subcell)
|
||||
new_subcell.update_parameters_name(prefix + '.')
|
||||
network.insert_child_to_cell(name, new_subcell)
|
||||
change = True
|
||||
else:
|
||||
self._convert_subcells2quant(subcell)
|
||||
if isinstance(network, nn.SequentialCell) and change:
|
||||
network.cell_list = list(network.cells())
|
||||
|
||||
# add FakeQuant OP after OP in while list
|
||||
add_list = []
|
||||
for name in network.__dict__:
|
||||
if name[0] == '_':
|
||||
continue
|
||||
attr = network.__dict__[name]
|
||||
if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
|
||||
add_list.append((name, attr))
|
||||
for name, prim_op in add_list:
|
||||
prefix = name
|
||||
add_quant = _AddFakeQuantAfterSubCell(prim_op,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
prefix = self._convert_op_name(prim_op.name)
|
||||
if network.param_prefix:
|
||||
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
|
||||
add_quant.update_parameters_name(prefix + '.')
|
||||
del network.__dict__[name]
|
||||
network.insert_child_to_cell(name, add_quant)
|
||||
return network
|
||||
|
||||
def _convert_conv(self, subcell):
|
||||
"""
|
||||
convert Conv2d cell to quant cell
|
||||
"""
|
||||
conv_inner = subcell.conv
|
||||
if subcell.has_bn:
|
||||
if self.bn_fold:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
has_bias=conv_inner.has_bias,
|
||||
bias_init=conv_inner.bias_init,
|
||||
freeze_bn=self.freeze_bn,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype,
|
||||
fake=True)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.beta = subcell.batchnorm.beta
|
||||
conv_inner.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
else:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
has_bias=conv_inner.has_bias,
|
||||
bias_init=conv_inner.bias_init,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.batchnorm.beta = subcell.batchnorm.beta
|
||||
conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
else:
|
||||
conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
has_bias=conv_inner.has_bias,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network Conv2D OP parameters to quant network
|
||||
conv_inner.weight = subcell.conv.weight
|
||||
if subcell.conv.has_bias:
|
||||
conv_inner.bias = subcell.conv.bias
|
||||
subcell.conv = conv_inner
|
||||
if subcell.has_act and subcell.activation is not None:
|
||||
subcell.activation = self._convert_activation(subcell.activation)
|
||||
elif subcell.after_fake:
|
||||
subcell.has_act = True
|
||||
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
return subcell
|
||||
|
||||
def _convert_dense(self, subcell):
|
||||
"""
|
||||
convert dense cell to combine dense cell
|
||||
"""
|
||||
dense_inner = subcell.dense
|
||||
dense_inner = quant.DenseQuant(dense_inner.in_channels,
|
||||
dense_inner.out_channels,
|
||||
has_bias=dense_inner.has_bias,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network Dense OP parameters to quant network
|
||||
dense_inner.weight = subcell.dense.weight
|
||||
if subcell.dense.has_bias:
|
||||
dense_inner.bias = subcell.dense.bias
|
||||
subcell.dense = dense_inner
|
||||
if subcell.has_act and subcell.activation is not None:
|
||||
subcell.activation = self._convert_activation(subcell.activation)
|
||||
elif subcell.after_fake:
|
||||
subcell.has_act = True
|
||||
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
return subcell
|
||||
|
||||
def _convert_activation(self, activation):
|
||||
act_class = activation.__class__
|
||||
if act_class not in _ACTIVATION_MAP:
|
||||
raise ValueError("Unsupported activation in auto quant: ", act_class)
|
||||
return _ACTIVATION_MAP[act_class](activation=activation,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.act_dtype)
|
||||
|
||||
__all__ = ["export", "manual_export"]
|
||||
|
||||
class ExportToQuantInferNetwork:
|
||||
"""
|
||||
|
@ -499,7 +204,7 @@ class ExportToQuantInferNetwork:
|
|||
change = True
|
||||
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
|
||||
op = subcell.subcell
|
||||
if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
|
||||
if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
|
||||
if self.is_mindir:
|
||||
op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
|
||||
op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
|
||||
|
@ -553,106 +258,6 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
|
|||
serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
|
||||
|
||||
|
||||
def convert_quant_network(network,
|
||||
bn_fold=True,
|
||||
freeze_bn=10000000,
|
||||
quant_delay=(0, 0),
|
||||
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
|
||||
per_channel=(False, False),
|
||||
symmetric=(False, False),
|
||||
narrow_range=(False, False)
|
||||
):
|
||||
r"""
|
||||
Create quantization aware training network.
|
||||
|
||||
Args:
|
||||
network (Cell): Obtain a pipeline through network for saving graph summary.
|
||||
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True.
|
||||
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
|
||||
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
|
||||
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
|
||||
quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
|
||||
element represent weights and second element represent data flow.
|
||||
Default: (QuantDtype.INT8, QuantDtype.INT8)
|
||||
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
|
||||
then base on per channel otherwise base on per layer. The first element represent weights
|
||||
and second element represent data flow. Default: (False, False)
|
||||
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
|
||||
symmetric otherwise base on asymmetric. The first element represent weights and second
|
||||
element represent data flow. Default: (False, False)
|
||||
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
|
||||
The first element represents weights and the second element represents data flow. Default: (False, False)
|
||||
|
||||
Returns:
|
||||
Cell, Network which has change to quantization aware training network cell.
|
||||
"""
|
||||
support_device = ["Ascend", "GPU"]
|
||||
|
||||
def convert2list(name, value):
|
||||
if not isinstance(value, list) and not isinstance(value, tuple):
|
||||
value = [value]
|
||||
elif len(value) > 2:
|
||||
raise ValueError("input `{}` len should less then 2".format(name))
|
||||
return value
|
||||
|
||||
quant_delay = convert2list("quant delay", quant_delay)
|
||||
quant_dtype = convert2list("quant dtype", quant_dtype)
|
||||
per_channel = convert2list("per channel", per_channel)
|
||||
symmetric = convert2list("symmetric", symmetric)
|
||||
narrow_range = convert2list("narrow range", narrow_range)
|
||||
|
||||
if context.get_context('device_target') not in support_device:
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
||||
net = ConvertToQuantNetwork(network=network,
|
||||
quant_delay=quant_delay,
|
||||
bn_fold=bn_fold,
|
||||
freeze_bn=freeze_bn,
|
||||
quant_dtype=quant_dtype,
|
||||
per_channel=per_channel,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
return net.run()
|
||||
|
||||
def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'):
|
||||
"""
|
||||
Manual exports MindSpore quantization predict model to deploy wiAIR and MINDIR.
|
||||
|
||||
Args:
|
||||
network (Cell): MindSpore network produced by `convert_quant_network`.
|
||||
inputs (Tensor): Inputs of the `quantization aware training network`.
|
||||
file_name (str): File name of model to export.
|
||||
mean (int, float): Input data mean. Default: 127.5.
|
||||
std_dev (int, float): Input data variance. Default: 127.5.
|
||||
file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
|
||||
quantization aware model. Default: 'AIR'.
|
||||
|
||||
- AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
||||
Ascend model.
|
||||
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
||||
for MindSpore models.
|
||||
Recommended suffix for output file is '.mindir'.
|
||||
"""
|
||||
supported_device = ["Ascend", "GPU"]
|
||||
supported_formats = ['AIR', 'MINDIR']
|
||||
|
||||
mean = Validator.check_type("mean", mean, (int, float))
|
||||
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
|
||||
|
||||
if context.get_context('device_target') not in supported_device:
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
||||
if file_format not in supported_formats:
|
||||
raise ValueError('Illegal file format {}.'.format(file_format))
|
||||
|
||||
network.set_train(False)
|
||||
if file_format == "MINDIR":
|
||||
exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
else:
|
||||
exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=False)
|
||||
deploy_net = exporter.run()
|
||||
serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
|
||||
|
||||
class ExportManualQuantNetwork:
|
||||
"""
|
||||
Convert anual quantization aware network to infer network.
|
||||
|
@ -713,7 +318,7 @@ class ExportManualQuantNetwork:
|
|||
elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
|
||||
quant.Conv2dQuant, quant.DenseQuant)):
|
||||
network, change = self._convert_subcell(network, change, name, subcell, core=False)
|
||||
elif isinstance(subcell, quant.FakeQuantWithMinMax) and self.upcell:
|
||||
elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver) and self.upcell:
|
||||
np_type = mstype.dtype_to_nptype(self.data_type)
|
||||
_, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type)
|
||||
self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq))
|
||||
|
@ -721,7 +326,7 @@ class ExportManualQuantNetwork:
|
|||
network.insert_child_to_cell(self.upname, self.upcell)
|
||||
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
|
||||
op = subcell.subcell
|
||||
if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
|
||||
if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
|
||||
if self.is_mindir:
|
||||
op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
|
||||
op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
|
||||
|
@ -845,3 +450,43 @@ class ExportManualQuantNetwork:
|
|||
else:
|
||||
block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
|
||||
return block
|
||||
|
||||
|
||||
def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'):
|
||||
"""
|
||||
Manual exports MindSpore quantization predict model to deploy wiAIR and MINDIR.
|
||||
|
||||
Args:
|
||||
network (Cell): MindSpore network produced by `convert_quant_network`.
|
||||
inputs (Tensor): Inputs of the `quantization aware training network`.
|
||||
file_name (str): File name of model to export.
|
||||
mean (int, float): Input data mean. Default: 127.5.
|
||||
std_dev (int, float): Input data variance. Default: 127.5.
|
||||
file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
|
||||
quantization aware model. Default: 'AIR'.
|
||||
|
||||
- AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
||||
Ascend model.
|
||||
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
||||
for MindSpore models.
|
||||
Recommended suffix for output file is '.mindir'.
|
||||
"""
|
||||
supported_device = ["Ascend", "GPU"]
|
||||
supported_formats = ['AIR', 'MINDIR']
|
||||
|
||||
mean = Validator.check_type("mean", mean, (int, float))
|
||||
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
|
||||
|
||||
if context.get_context('device_target') not in supported_device:
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
||||
if file_format not in supported_formats:
|
||||
raise ValueError('Illegal file format {}.'.format(file_format))
|
||||
|
||||
network.set_train(False)
|
||||
if file_format == "MINDIR":
|
||||
exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
else:
|
||||
exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=False)
|
||||
deploy_net = exporter.run()
|
||||
serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
|
|
@ -13,14 +13,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Quantization.
|
||||
|
||||
User can use quantization aware to train a model. MindSpore supports quantization aware training,
|
||||
which models quantization errors in both the forward and backward passes using fake-quantization
|
||||
operations. Note that the entire computation is carried out in floating point. At the end of quantization
|
||||
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
|
||||
Compression quant module.
|
||||
"""
|
||||
|
||||
from .quant import convert_quant_network, export, manual_export
|
||||
|
||||
__all__ = ["convert_quant_network", "export", "manual_export"]
|
||||
from .quantizer import *
|
||||
from .qat import *
|
||||
from .quant_utils import *
|
|
@ -0,0 +1,406 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Quantization aware training
|
||||
|
||||
User can use quantization aware to train a model. MindSpore supports quantization aware training,
|
||||
which models quantization errors in both the forward and backward passes using fake-quantization
|
||||
operations. Note that the entire computation is carried out in floating point. At the end of quantization
|
||||
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
import mindspore.context as context
|
||||
|
||||
from ... import nn, ops
|
||||
from ..._checkparam import Validator, Rel
|
||||
from ...nn.layer import quant
|
||||
from ...ops import functional as F
|
||||
from ..common import QuantDtype
|
||||
from .quantizer import Quantizer, OptimizeOption
|
||||
|
||||
|
||||
__all__ = ["QuantizationAwareTraining"]
|
||||
|
||||
|
||||
_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
|
||||
nn.ReLU6: quant.ActQuant,
|
||||
nn.Sigmoid: quant.ActQuant,
|
||||
nn.LeakyReLU: quant.LeakyReLUQuant,
|
||||
nn.HSigmoid: quant.HSigmoidQuant,
|
||||
nn.HSwish: quant.HSwishQuant}
|
||||
|
||||
|
||||
def get_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
|
||||
quant_delay=(0, 0),
|
||||
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
|
||||
per_channel=(False, False),
|
||||
symmetric=(False, False),
|
||||
narrow_range=(False, False)
|
||||
):
|
||||
r"""
|
||||
Configs the oberser type of weights and data flow with quant params.
|
||||
|
||||
Args:
|
||||
quant_observer (Observer, list or tuple): The oberser type to do quantization. The first element represent
|
||||
weights and second element represent data flow.
|
||||
Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
|
||||
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
|
||||
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
|
||||
quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
|
||||
element represent weights and second element represent data flow.
|
||||
Default: (QuantDtype.INT8, QuantDtype.INT8)
|
||||
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
|
||||
then base on per channel otherwise base on per layer. The first element represent weights
|
||||
and second element represent data flow. Default: (False, False)
|
||||
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
|
||||
symmetric otherwise base on asymmetric. The first element represent weights and second
|
||||
element represent data flow. Default: (False, False)
|
||||
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
|
||||
The first element represents weights and the second element represents data flow. Default: (False, False)
|
||||
|
||||
Returns:
|
||||
QuantConfig, Contains the oberser type of weight and activation.
|
||||
"""
|
||||
weight_observer = quant_observer[0].partial_init(quant_delay=quant_delay[0], quant_dtype=quant_dtype[0],
|
||||
per_channel=per_channel[0], symmetric=symmetric[0],
|
||||
narrow_range=narrow_range[0])
|
||||
act_observer = quant_observer[0].partial_init(quant_delay=quant_delay[-1], quant_dtype=quant_dtype[-1],
|
||||
per_channel=per_channel[-1], symmetric=symmetric[-1],
|
||||
narrow_range=narrow_range[-1])
|
||||
return quant.QuantConfig(weight=weight_observer, activation=act_observer)
|
||||
|
||||
|
||||
class _AddFakeQuantInput(nn.Cell):
|
||||
"""
|
||||
Add FakeQuant OP at input of the network. Only support one input case.
|
||||
"""
|
||||
|
||||
def __init__(self, network, quant_delay=0):
|
||||
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
|
||||
self.fake_quant_input = quant.FakeQuantWithMinMaxObserver(min_init=-6, max_init=6,
|
||||
quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_input.update_parameters_name('fake_quant_input.')
|
||||
self.network = network
|
||||
|
||||
def construct(self, data):
|
||||
data = self.fake_quant_input(data)
|
||||
output = self.network(data)
|
||||
return output
|
||||
|
||||
|
||||
class _AddFakeQuantAfterSubCell(nn.Cell):
|
||||
"""
|
||||
Add FakeQuant OP after of the sub Cell.
|
||||
"""
|
||||
|
||||
def __init__(self, subcell, **kwargs):
|
||||
super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
|
||||
self.subcell = subcell
|
||||
self.fake_quant_act = quant.FakeQuantWithMinMaxObserver(min_init=-6,
|
||||
max_init=6,
|
||||
ema=True,
|
||||
quant_dtype=kwargs["quant_dtype"],
|
||||
quant_delay=kwargs["quant_delay"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
|
||||
def construct(self, *data):
|
||||
output = self.subcell(*data)
|
||||
output = self.fake_quant_act(output)
|
||||
return output
|
||||
|
||||
|
||||
class ConvertToQuantNetwork:
|
||||
"""
|
||||
Convert network to quantization aware network
|
||||
"""
|
||||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.network = Validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
|
||||
self.weight_qdelay = Validator.check_non_negative_int(kwargs["quant_delay"][0], "quant delay")
|
||||
self.act_qdelay = Validator.check_int(kwargs["quant_delay"][-1], 0, Rel.GE, "quant delay")
|
||||
self.bn_fold = Validator.check_bool(kwargs["bn_fold"], "bn fold")
|
||||
self.freeze_bn = Validator.check_non_negative_int(kwargs["freeze_bn"], "freeze bn")
|
||||
self.weight_dtype = Validator.check_isinstance("weights dtype", kwargs["quant_dtype"][0], QuantDtype)
|
||||
self.act_dtype = Validator.check_isinstance("activations dtype", kwargs["quant_dtype"][-1], QuantDtype)
|
||||
self.weight_channel = Validator.check_bool(kwargs["per_channel"][0], "per channel")
|
||||
self.act_channel = Validator.check_bool(kwargs["per_channel"][-1], "per channel")
|
||||
self.weight_symmetric = Validator.check_bool(kwargs["symmetric"][0], "symmetric")
|
||||
self.act_symmetric = Validator.check_bool(kwargs["symmetric"][-1], "symmetric")
|
||||
self.weight_range = Validator.check_bool(kwargs["narrow_range"][0], "narrow range")
|
||||
self.act_range = Validator.check_bool(kwargs["narrow_range"][-1], "narrow range")
|
||||
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
||||
quant.DenseBnAct: self._convert_dense}
|
||||
self.quant_config = get_quant_config(quant_delay=kwargs["quant_delay"],
|
||||
quant_dtype=kwargs["quant_dtype"],
|
||||
per_channel=kwargs["per_channel"],
|
||||
symmetric=kwargs["symmetric"],
|
||||
narrow_range=kwargs["narrow_range"])
|
||||
|
||||
class QuantizationAwareTraining(Quantizer):
|
||||
r"""
|
||||
Quantizer for quantization aware training.
|
||||
|
||||
Args:
|
||||
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: True.
|
||||
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 1e7.
|
||||
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
|
||||
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
|
||||
quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
|
||||
element represent weights and second element represent data flow.
|
||||
Default: (QuantDtype.INT8, QuantDtype.INT8)
|
||||
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
|
||||
then base on per channel otherwise base on per layer. The first element represent weights
|
||||
and second element represent data flow. Default: (False, False)
|
||||
symmetric (bool, list or tuple): Whether the quantization algorithm is symmetric or not. If `True` then base on
|
||||
symmetric otherwise base on asymmetric. The first element represent weights and second
|
||||
element represent data flow. Default: (False, False)
|
||||
narrow_range (bool, list or tuple): Whether the quantization algorithm uses narrow range or not.
|
||||
The first element represents weights and the second element represents data flow. Default: (False, False)
|
||||
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options, currently only
|
||||
support QAT. Default: OptimizeOption.QAT
|
||||
"""
|
||||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self,
|
||||
bn_fold=True,
|
||||
freeze_bn=10000000,
|
||||
quant_delay=(0, 0),
|
||||
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
|
||||
per_channel=(False, False),
|
||||
symmetric=(False, False),
|
||||
narrow_range=(False, False),
|
||||
optimize_option=OptimizeOption.QAT):
|
||||
"""Init for QuantizationAwareTraining quantizer"""
|
||||
super(QuantizationAwareTraining, self).__init__(optimize_option=optimize_option)
|
||||
def convert2list(name, value):
|
||||
if not isinstance(value, list) and not isinstance(value, tuple):
|
||||
value = [value]
|
||||
elif len(value) > 2:
|
||||
raise ValueError("input `{}` len should less then 2".format(name))
|
||||
return value
|
||||
|
||||
quant_delay = convert2list("quant delay", quant_delay)
|
||||
quant_dtype = convert2list("quant dtype", quant_dtype)
|
||||
per_channel = convert2list("per channel", per_channel)
|
||||
symmetric = convert2list("symmetric", symmetric)
|
||||
narrow_range = convert2list("narrow range", narrow_range)
|
||||
|
||||
self.weight_qdelay = Validator.check_non_negative_int(quant_delay[0], "quant delay")
|
||||
self.act_qdelay = Validator.check_int(quant_delay[-1], 0, Rel.GE, "quant delay")
|
||||
self.bn_fold = Validator.check_bool(bn_fold, "bn fold")
|
||||
self.freeze_bn = Validator.check_non_negative_int(freeze_bn, "freeze bn")
|
||||
self.weight_dtype = Validator.check_isinstance("weights dtype", quant_dtype[0], QuantDtype)
|
||||
self.act_dtype = Validator.check_isinstance("activations dtype", quant_dtype[-1], QuantDtype)
|
||||
self.weight_channel = Validator.check_bool(per_channel[0], "per channel")
|
||||
self.act_channel = Validator.check_bool(per_channel[-1], "per channel")
|
||||
self.weight_symmetric = Validator.check_bool(symmetric[0], "symmetric")
|
||||
self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
|
||||
self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
|
||||
self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
|
||||
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
|
||||
quant.DenseBnAct: self._convert_dense}
|
||||
self.quant_config = get_quant_config(quant_delay=quant_delay,
|
||||
quant_dtype=quant_dtype,
|
||||
per_channel=per_channel,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
|
||||
def _convert_op_name(self, name):
|
||||
pattern = re.compile(r'([A-Z]{1})')
|
||||
name_new = re.sub(pattern, r'_\1', name).lower()
|
||||
if name_new[0] == '_':
|
||||
name_new = name_new[1:]
|
||||
return name_new
|
||||
|
||||
def quantize(self, network):
|
||||
support_device = ["Ascend", "GPU"]
|
||||
if context.get_context('device_target') not in support_device:
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
||||
if OptimizeOption.QAT in self.optimize_option:
|
||||
network.update_cell_prefix()
|
||||
network = self._convert_subcells2quant(network)
|
||||
network.update_cell_type("quant")
|
||||
return network
|
||||
|
||||
def _convert_subcells2quant(self, network):
|
||||
"""
|
||||
convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
|
||||
"""
|
||||
cells = network.name_cells()
|
||||
change = False
|
||||
for name in cells:
|
||||
subcell = cells[name]
|
||||
if subcell == network:
|
||||
continue
|
||||
elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
|
||||
prefix = subcell.param_prefix
|
||||
new_subcell = self._convert_method_map[type(subcell)](subcell)
|
||||
new_subcell.update_parameters_name(prefix + '.')
|
||||
network.insert_child_to_cell(name, new_subcell)
|
||||
change = True
|
||||
else:
|
||||
self._convert_subcells2quant(subcell)
|
||||
if isinstance(network, nn.SequentialCell) and change:
|
||||
network.cell_list = list(network.cells())
|
||||
|
||||
# add FakeQuant OP after OP in while list
|
||||
add_list = []
|
||||
for name in network.__dict__:
|
||||
if name[0] == '_':
|
||||
continue
|
||||
attr = network.__dict__[name]
|
||||
if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
|
||||
add_list.append((name, attr))
|
||||
for name, prim_op in add_list:
|
||||
prefix = name
|
||||
add_quant = _AddFakeQuantAfterSubCell(prim_op,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
prefix = self._convert_op_name(prim_op.name)
|
||||
if network.param_prefix:
|
||||
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
|
||||
add_quant.update_parameters_name(prefix + '.')
|
||||
del network.__dict__[name]
|
||||
network.insert_child_to_cell(name, add_quant)
|
||||
return network
|
||||
|
||||
def _convert_conv(self, subcell):
|
||||
"""
|
||||
convert Conv2d cell to quant cell
|
||||
"""
|
||||
conv_inner = subcell.conv
|
||||
if subcell.has_bn:
|
||||
if self.bn_fold:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
has_bias=conv_inner.has_bias,
|
||||
bias_init=conv_inner.bias_init,
|
||||
freeze_bn=self.freeze_bn,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype,
|
||||
fake=True)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.beta = subcell.batchnorm.beta
|
||||
conv_inner.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
else:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
has_bias=conv_inner.has_bias,
|
||||
bias_init=conv_inner.bias_init,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.batchnorm.beta = subcell.batchnorm.beta
|
||||
conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
else:
|
||||
conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
has_bias=conv_inner.has_bias,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network Conv2D OP parameters to quant network
|
||||
conv_inner.weight = subcell.conv.weight
|
||||
if subcell.conv.has_bias:
|
||||
conv_inner.bias = subcell.conv.bias
|
||||
subcell.conv = conv_inner
|
||||
if subcell.has_act and subcell.activation is not None:
|
||||
subcell.activation = self._convert_activation(subcell.activation)
|
||||
elif subcell.after_fake:
|
||||
subcell.has_act = True
|
||||
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
return subcell
|
||||
|
||||
def _convert_dense(self, subcell):
|
||||
"""
|
||||
convert dense cell to combine dense cell
|
||||
"""
|
||||
dense_inner = subcell.dense
|
||||
dense_inner = quant.DenseQuant(dense_inner.in_channels,
|
||||
dense_inner.out_channels,
|
||||
has_bias=dense_inner.has_bias,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.weight_dtype)
|
||||
# change original network Dense OP parameters to quant network
|
||||
dense_inner.weight = subcell.dense.weight
|
||||
if subcell.dense.has_bias:
|
||||
dense_inner.bias = subcell.dense.bias
|
||||
subcell.dense = dense_inner
|
||||
if subcell.has_act and subcell.activation is not None:
|
||||
subcell.activation = self._convert_activation(subcell.activation)
|
||||
elif subcell.after_fake:
|
||||
subcell.has_act = True
|
||||
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
|
||||
quant_dtype=self.act_dtype,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
return subcell
|
||||
|
||||
def _convert_activation(self, activation):
|
||||
act_class = activation.__class__
|
||||
if act_class not in _ACTIVATION_MAP:
|
||||
raise ValueError("Unsupported activation in auto quant: ", act_class)
|
||||
return _ACTIVATION_MAP[act_class](activation=activation,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.act_dtype)
|
|
@ -17,6 +17,9 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
__all__ = ["load_nonquant_param_into_quant_net"]
|
||||
|
||||
|
||||
def cal_quantization_params(input_min,
|
||||
input_max,
|
||||
data_type,
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Base Class of Quantizer."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
__all__ = ["OptimizeOption", "Quantizer"]
|
||||
|
||||
|
||||
class OptimizeOption(Enum):
|
||||
"""
|
||||
An enum for the model quantization optimize option.
|
||||
"""
|
||||
# using quantization aware training
|
||||
QAT = "QAT"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class Quantizer(ABC):
|
||||
"""
|
||||
Base class of Quantizer. You can implement different kind of quantizer to get different quantization result.
|
||||
|
||||
Notes:
|
||||
This class is an abstract class.
|
||||
|
||||
Args:
|
||||
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: None.
|
||||
"""
|
||||
def __init__(self,
|
||||
optimize_option=None):
|
||||
if not isinstance(optimize_option, list) and not isinstance(optimize_option, tuple):
|
||||
optimize_option = [optimize_option]
|
||||
self.optimize_option = optimize_option
|
||||
|
||||
@abstractmethod
|
||||
def quantize(self, network):
|
||||
pass
|
|
@ -30,7 +30,7 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.common.api import _executor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import check_input_data, Validator
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.export import quant_export
|
||||
import mindspore.context as context
|
||||
|
||||
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
|
||||
|
@ -596,14 +596,14 @@ def _quant_export(network, *inputs, file_format, **kwargs):
|
|||
network.set_train(False)
|
||||
if file_format == "MINDIR":
|
||||
if quant_mode == 'MANUAL':
|
||||
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
else:
|
||||
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
else:
|
||||
if quant_mode == 'MANUAL':
|
||||
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
|
||||
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
|
||||
else:
|
||||
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
|
||||
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
|
||||
deploy_net = exporter.run()
|
||||
return deploy_net
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ from mindspore import context
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from src.dataset import create_dataset
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
|
@ -47,8 +47,12 @@ if __name__ == "__main__":
|
|||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
|
||||
per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(quant_delay=0,
|
||||
bn_fold=False,
|
||||
freeze_bn=10000,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# define loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
|
|
@ -22,7 +22,7 @@ import numpy as np
|
|||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import mnist_cfg as cfg
|
||||
|
@ -44,8 +44,12 @@ if __name__ == "__main__":
|
|||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
|
||||
per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(quant_delay=0,
|
||||
bn_fold=False,
|
||||
freeze_bn=10000,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
|
|
@ -26,8 +26,8 @@ from mindspore.train.serialization import load_checkpoint
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.common import set_seed
|
||||
from src.dataset import create_dataset
|
||||
from src.config import mnist_cfg as cfg
|
||||
|
@ -59,8 +59,11 @@ if __name__ == "__main__":
|
|||
load_nonquant_param_into_quant_net(network, param_dict)
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
|
||||
quantizer = QuantizationAwareTraining(quant_delay=900,
|
||||
bn_fold=False,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# define network loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
|
|
@ -21,7 +21,7 @@ from mindspore import context
|
|||
from mindspore import nn
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
|
||||
from src.mobilenetV2 import mobilenetV2
|
||||
from src.dataset import create_dataset
|
||||
|
@ -51,7 +51,10 @@ if __name__ == '__main__':
|
|||
# define fusion network
|
||||
network = mobilenetV2(num_classes=config_device_target.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
# define network loss
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import mindspore
|
|||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
|
||||
from src.mobilenetV2 import mobilenetV2
|
||||
from src.config import config_ascend_quant
|
||||
|
@ -42,7 +42,10 @@ if __name__ == '__main__':
|
|||
# define fusion network
|
||||
network = mobilenetV2(num_classes=cfg.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
|
|
@ -26,8 +26,8 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.dataset import create_dataset
|
||||
|
@ -99,10 +99,10 @@ def train_on_ascend():
|
|||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
load_nonquant_param_into_quant_net(network, param_dict)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network,
|
||||
bn_fold=True,
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# get learning rate
|
||||
lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
|
||||
|
@ -162,12 +162,12 @@ def train_on_gpu():
|
|||
load_nonquant_param_into_quant_net(network, param_dict)
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network,
|
||||
bn_fold=True,
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False],
|
||||
freeze_bn=1000000,
|
||||
quant_delay=step_size * 2)
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# get learning rate
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
|
|
|
@ -26,7 +26,7 @@ from models.resnet_quant_manual import resnet50_quant #manually construct quanta
|
|||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
|
@ -43,12 +43,13 @@ if args_opt.device_target == "Ascend":
|
|||
|
||||
if __name__ == '__main__':
|
||||
# define fusion network
|
||||
net = resnet50_quant(class_num=config.class_num)
|
||||
network = resnet50_quant(class_num=config.class_num)
|
||||
# convert fusion network to quantization aware network
|
||||
net = quant.convert_quant_network(net,
|
||||
bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# define network loss
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
|
@ -65,13 +66,13 @@ if __name__ == '__main__':
|
|||
# load checkpoint
|
||||
if args_opt.checkpoint_path:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
not_load_param = load_param_into_net(net, param_dict)
|
||||
not_load_param = load_param_into_net(network, param_dict)
|
||||
if not_load_param:
|
||||
raise ValueError("Load param into net fail!")
|
||||
net.set_train(False)
|
||||
raise ValueError("Load param into network fail!")
|
||||
network.set_train(False)
|
||||
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'acc'})
|
||||
model = Model(network, loss_fn=loss, metrics={'acc'})
|
||||
|
||||
print("============== Starting Validation ==============")
|
||||
res = model.eval(dataset)
|
||||
|
|
|
@ -17,14 +17,14 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant
|
||||
from mindspore.compression.quant import qat
|
||||
|
||||
_ema_decay = 0.999
|
||||
_symmetric = True
|
||||
_fake = True
|
||||
_per_channel = True
|
||||
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
|
||||
_quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
|
||||
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
|
@ -90,8 +90,8 @@ class ConvBNReLU(nn.Cell):
|
|||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
group=groups, fake=_fake, quant_config=_quant_config)
|
||||
conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
group=groups, fake=_fake, quant_config=_quant_config)
|
||||
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
|
||||
self.features = nn.SequentialCell(layers)
|
||||
|
||||
|
@ -126,14 +126,14 @@ class ResidualBlock(nn.Cell):
|
|||
channel = out_channel // self.expansion
|
||||
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
|
||||
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
|
||||
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1, pad_mode='same', padding=0),
|
||||
self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1, pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1,
|
||||
pad_mode='same', padding=0)
|
||||
]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1,
|
||||
pad_mode='same', padding=0)
|
||||
|
||||
self.down_sample = False
|
||||
|
||||
|
@ -142,20 +142,19 @@ class ResidualBlock(nn.Cell):
|
|||
self.down_sample_layer = None
|
||||
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=stride,
|
||||
pad_mode='same', padding=0),
|
||||
self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=stride,
|
||||
pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
|
||||
symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
fake=_fake,
|
||||
quant_config=\
|
||||
_quant_config,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
pad_mode='same',
|
||||
padding=0)
|
||||
]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel,
|
||||
fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
pad_mode='same',
|
||||
padding=0)
|
||||
self.add = nn.TensorAddQuant()
|
||||
self.relu = P.ReLU()
|
||||
|
||||
|
|
|
@ -25,8 +25,8 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.communication.management import init
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
|
@ -113,7 +113,10 @@ if __name__ == '__main__':
|
|||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# get learning rate
|
||||
lr = get_lr(lr_init=config.lr_init,
|
||||
|
|
|
@ -29,7 +29,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
import mindspore as ms
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
|
||||
from src.yolo import YOLOV3DarkNet53
|
||||
from src.logger import get_logger
|
||||
|
@ -265,10 +265,10 @@ def test():
|
|||
|
||||
# convert fusion network to quantization aware network
|
||||
if config.quantization_aware:
|
||||
network = quant.convert_quant_network(network,
|
||||
bn_fold=True,
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
args.logger.info(args.pretrained)
|
||||
if os.path.isfile(args.pretrained):
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""hub config."""
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from src.yolo import YOLOV3DarkNet53
|
||||
from src.config import ConfigYOLOV3DarkNet53
|
||||
|
||||
|
@ -24,9 +24,9 @@ def create_network(name, *args, **kwargs):
|
|||
config = ConfigYOLOV3DarkNet53()
|
||||
# convert fusion network to quantization aware network
|
||||
if config.quantization_aware:
|
||||
yolov3_darknet53_quant = quant.convert_quant_network(yolov3_darknet53_quant,
|
||||
bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
yolov3_darknet53_quant = quantizer.quantize(yolov3_darknet53_quant)
|
||||
return yolov3_darknet53_quant
|
||||
raise NotImplementedError(f"{name} is not implemented in the repo")
|
||||
|
|
|
@ -27,7 +27,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
|
|||
from mindspore.train.callback import ModelCheckpoint, RunContext
|
||||
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
|
||||
import mindspore as ms
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
|
||||
|
@ -168,10 +168,10 @@ def train():
|
|||
config = ConfigYOLOV3DarkNet53()
|
||||
# convert fusion network to quantization aware network
|
||||
if config.quantization_aware:
|
||||
network = quant.convert_quant_network(network,
|
||||
bn_fold=True,
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
network = YoloWithLossCell(network)
|
||||
args.logger.info('finish get network')
|
||||
|
|
|
@ -26,8 +26,8 @@ from mindspore.nn.metrics import Accuracy
|
|||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from dataset import create_dataset
|
||||
from config import nonquant_cfg, quant_cfg
|
||||
from lenet import LeNet5
|
||||
|
@ -73,8 +73,11 @@ def train_lenet_quant():
|
|||
load_nonquant_param_into_quant_net(network, param_dict)
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False],
|
||||
symmetric=[False, False])
|
||||
quantizer = QuantizationAwareTraining(quant_delay=900,
|
||||
bn_fold=False,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# define network loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
@ -103,8 +106,12 @@ def eval_quant():
|
|||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
|
||||
per_channel=[True, False])
|
||||
quantizer = QuantizationAwareTraining(quant_delay=0,
|
||||
bn_fold=False,
|
||||
freeze_bn=10000,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# define loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
|
@ -131,8 +138,12 @@ def export_lenet():
|
|||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000,
|
||||
per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(quant_delay=0,
|
||||
bn_fold=False,
|
||||
freeze_bn=10000,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)
|
||||
|
|
|
@ -23,7 +23,7 @@ from mindspore import context
|
|||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from dataset import create_dataset
|
||||
|
@ -84,10 +84,10 @@ def test_mobilenetv2_quant():
|
|||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network,
|
||||
bn_fold=True,
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
|
||||
# get learning rate
|
||||
lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
|
||||
|
|
|
@ -18,14 +18,14 @@ import mindspore.nn as nn
|
|||
import mindspore.common.initializer as weight_init
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant
|
||||
from mindspore.compression.quant import qat
|
||||
|
||||
_ema_decay = 0.999
|
||||
_symmetric = True
|
||||
_fake = True
|
||||
_per_channel = True
|
||||
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
|
||||
_quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
|
||||
|
||||
|
||||
def _weight_variable(shape, factor=0.01):
|
||||
|
@ -91,8 +91,8 @@ class ConvBNReLU(nn.Cell):
|
|||
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
padding = (kernel_size - 1) // 2
|
||||
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
group=groups, fake=_fake, quant_config=_quant_config)
|
||||
conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
|
||||
group=groups, fake=_fake, quant_config=_quant_config)
|
||||
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
|
||||
self.features = nn.SequentialCell(layers)
|
||||
|
||||
|
@ -127,14 +127,14 @@ class ResidualBlock(nn.Cell):
|
|||
channel = out_channel // self.expansion
|
||||
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
|
||||
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
|
||||
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1, pad_mode='same', padding=0),
|
||||
self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1, pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1,
|
||||
pad_mode='same', padding=0)
|
||||
]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=1,
|
||||
pad_mode='same', padding=0)
|
||||
|
||||
self.down_sample = False
|
||||
|
||||
|
@ -143,20 +143,19 @@ class ResidualBlock(nn.Cell):
|
|||
self.down_sample_layer = None
|
||||
|
||||
if self.down_sample:
|
||||
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=stride,
|
||||
pad_mode='same', padding=0),
|
||||
self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1, stride=stride,
|
||||
pad_mode='same', padding=0),
|
||||
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
|
||||
symmetric=False)
|
||||
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
|
||||
fake=_fake,
|
||||
quant_config=\
|
||||
_quant_config,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
pad_mode='same',
|
||||
padding=0)
|
||||
]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel,
|
||||
fake=_fake,
|
||||
quant_config=_quant_config,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
pad_mode='same',
|
||||
padding=0)
|
||||
self.add = nn.TensorAddQuant()
|
||||
self.relu = P.ReLU()
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore import context
|
|||
from mindspore import Tensor
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore import set_seed
|
||||
|
||||
from resnet_quant_manual import resnet50_quant
|
||||
|
@ -89,10 +89,10 @@ def test_resnet50_quant():
|
|||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
net = quant.convert_quant_network(net,
|
||||
bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
net = quantizer.quantize(net)
|
||||
|
||||
# get learning rate
|
||||
lr = Tensor(get_lr(lr_init=config.lr_init,
|
||||
|
|
|
@ -19,7 +19,8 @@ import pytest
|
|||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore import nn
|
||||
from mindspore.train.quant import quant as qat
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.compression.export import quant_export
|
||||
from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
@ -66,27 +67,35 @@ class LeNet5(nn.Cell):
|
|||
def test_qat_lenet():
|
||||
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
|
||||
net = LeNet5()
|
||||
net = qat.convert_quant_network(
|
||||
net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
net = quantizer.quantize(net)
|
||||
# should load the checkpoint. mock here
|
||||
net.init_parameters_data()
|
||||
qat.export(net, img, file_name="quant.pb")
|
||||
quant_export.export(net, img, file_name="quant.pb")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
|
||||
def test_qat_mobile_per_channel_tf():
|
||||
network = mobilenetV2(num_classes=1000)
|
||||
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
|
||||
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[True, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
# should load the checkpoint. mock here
|
||||
network.init_parameters_data()
|
||||
qat.export(network, img, file_name="quant.pb")
|
||||
quant_export.export(network, img, file_name="quant.pb")
|
||||
|
||||
@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
|
||||
def test_qat_mobile_per_channel_ff():
|
||||
network = mobilenetV2(num_classes=1000)
|
||||
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
|
||||
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False])
|
||||
quantizer = QuantizationAwareTraining(bn_fold=True,
|
||||
per_channel=[False, False],
|
||||
symmetric=[True, False])
|
||||
network = quantizer.quantize(network)
|
||||
# should load the checkpoint. mock here
|
||||
network.init_parameters_data()
|
||||
qat.export(network, img, file_name="quant.pb")
|
||||
quant_export.export(network, img, file_name="quant.pb")
|
||||
|
|
Loading…
Reference in New Issue