forked from mindspore-Ecosystem/mindspore

add mobilenetV2 quant export

commit d383ade6f9 (parent ccad0cae3a)
@@ -289,7 +289,8 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchI
   MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
   std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
   auto filter = [](AnfNodePtr node) {
-    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
+    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
+             IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
   };
   std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
   auto is_quant_cnode = [](AnfNodePtr node) {

@@ -530,6 +530,7 @@ _activation = {
     'relu6': ReLU6,
     'tanh': Tanh,
     'gelu': GELU,
+    'elu': ELU,
     'sigmoid': Sigmoid,
     'prelu': PReLU,
     'leakyrelu': LeakyReLU,

@@ -17,6 +17,7 @@
 from functools import partial
 import numpy as np

+from mindspore import nn
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -41,8 +42,7 @@ __all__ = [
     'Conv2dBatchNormQuant',
     'Conv2dQuant',
     'DenseQuant',
-    'ReLUQuant',
-    'ReLU6Quant',
+    'ActQuant',
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
@@ -375,9 +375,10 @@ class FakeQuantWithMinMax(Cell):

     def extend_repr(self):
         s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
-            'quant_delay={}, min_init={}, max_init={}'.format(
-                self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
-                self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
+            'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
+                                                              self.ema, self.ema_decay, self.per_channel,
+                                                              self.channel_axis, self.num_channels, self.quant_delay,
+                                                              self.min_init, self.max_init)
         return s

     def construct(self, x):
@@ -540,10 +541,12 @@ class Conv2dBatchNormQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.fake, self.freeze_bn, self.momentum, self.quant_delay)
+            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
+                                                                        self.kernel_size, self.stride,
+                                                                        self.pad_mode, self.padding, self.dilation,
+                                                                        self.group,
+                                                                        self.fake, self.freeze_bn, self.momentum,
+                                                                        self.quant_delay)
         return s

     def construct(self, x):
@@ -685,10 +688,9 @@ class Conv2dQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'has_bias={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.has_bias, self.quant_delay)
+            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                                                 self.pad_mode, self.padding, self.dilation, self.group,
+                                                 self.has_bias, self.quant_delay)
         return s


@@ -799,76 +801,23 @@ class DenseQuant(Cell):

 class _QuantActivation(Cell):
     r"""
-    Base class for Quant activation function. Add Fake Quant OP after activation OP.
+    Base class for quantization aware training activation function. Add Fake Quant OP after activation OP.
     """

     def get_origin(self):
         raise NotImplementedError


-class ReLUQuant(_QuantActivation):
+class ActQuant(_QuantActivation):
     r"""
-    ReLUQuant activation function. Add Fake Quant OP after Relu OP.
+    Quantization aware training activation function.

-    For a more Detailed overview of ReLU op.
-
-    Args:
-        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-
-    Inputs:
-        - **x** (Tensor) - The input of ReLUQuant.
-
-    Outputs:
-        Tensor, with the same type and shape as the `x`.
-
-    Examples:
-        >>> relu_quant = nn.ReLUQuant()
-        >>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32)
-        >>> result = relu_quant(input_x)
-    """
-
-    def __init__(self,
-                 ema_decay=0.999,
-                 per_channel=False,
-                 num_bits=8,
-                 symmetric=False,
-                 narrow_range=False,
-                 quant_delay=0):
-        super(ReLUQuant, self).__init__()
-        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
-                                                  max_init=6,
-                                                  ema=True,
-                                                  ema_decay=ema_decay,
-                                                  per_channel=per_channel,
-                                                  num_bits=num_bits,
-                                                  symmetric=symmetric,
-                                                  narrow_range=narrow_range,
-                                                  quant_delay=quant_delay)
-        self.relu = P.ReLU()
-
-    def construct(self, x):
-        x = self.relu(x)
-        x = self.fake_quant_act(x)
-        return x
-
-    def get_origin(self):
-        return self.relu
-
-
-class ReLU6Quant(_QuantActivation):
-    r"""
-    ReLU6Quant activation function.
-
-    Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op
+    Add Fake Quant OP after activation. Not Recommand to used these cell for Fake Quant Op
     Will climp the max range of the activation and the relu6 do the same operation.
     For a more Detailed overview of ReLU6 op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -883,19 +832,20 @@ class ReLU6Quant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.

     Examples:
-        >>> relu6_quant = nn.ReLU6Quant(4, 1)
+        >>> act_quant = nn.ActQuant(4, 1)
         >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = relu6_quant(input_x)
+        >>> result = act_quant(input_x)
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
                  symmetric=False,
                  narrow_range=False,
                  quant_delay=0):
-        super(ReLU6Quant, self).__init__()
+        super(ActQuant, self).__init__()
         self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
                                                   max_init=6,
                                                   ema=True,
@@ -905,15 +855,15 @@ class ReLU6Quant(_QuantActivation):
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range,
                                                   quant_delay=quant_delay)
-        self.relu6 = P.ReLU6()
+        self.act = activation

     def construct(self, x):
-        x = self.relu6(x)
+        x = self.act(x)
         x = self.fake_quant_act(x)
         return x

     def get_origin(self):
-        return self.relu6
+        return self.act


 class HSwishQuant(_QuantActivation):
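
The pattern introduced above (run the wrapped activation, then fake-quantize its output) simulates quantization in floating point. The NumPy sketch below is only an illustration of that idea, with the 0/6 range taken from the `FakeQuantWithMinMax(min_init=0, max_init=6, ...)` call in the diff; the helper name `fake_quant` is made up here and this is not the MindSpore kernel.

import numpy as np

def fake_quant(x, x_min=0.0, x_max=6.0, num_bits=8):
    # Clamp to [x_min, x_max], round to 2**num_bits - 1 integer levels, then map back to float.
    quant_max = 2 ** num_bits - 1
    scale = (x_max - x_min) / quant_max
    zero_point = np.round(-x_min / scale)
    q = np.clip(np.round(x / scale) + zero_point, 0, quant_max)
    return (q - zero_point) * scale

x = np.array([[1.0, 2.0, -1.0], [-2.0, 0.0, -1.0]], np.float32)
y = fake_quant(np.maximum(x, 0.0))  # e.g. a ReLU-style activation followed by the fake-quant step
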
@@ -923,6 +873,7 @@ class HSwishQuant(_QuantActivation):
     For a more Detailed overview of HSwish op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -943,6 +894,7 @@ class HSwishQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -968,7 +920,10 @@ class HSwishQuant(_QuantActivation):
                                                         symmetric=symmetric,
                                                         narrow_range=narrow_range,
                                                         quant_delay=quant_delay)
-        self.act = P.HSwish()
+        if isinstance(activation, nn.HSwish):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSwish`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)
@@ -987,6 +942,7 @@ class HSigmoidQuant(_QuantActivation):
     For a more Detailed overview of HSigmoid op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -1007,6 +963,7 @@ class HSigmoidQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -1032,7 +989,10 @@ class HSigmoidQuant(_QuantActivation):
                                                         symmetric=symmetric,
                                                         narrow_range=narrow_range,
                                                         quant_delay=quant_delay)
-        self.act = P.HSigmoid()
+        if isinstance(activation, nn.HSwish):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSigmoid`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)

@@ -1004,6 +1004,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype, w_dtype):
         args = {'x': x_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x_dtype


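The int8-to-int32 inference above follows the usual integer-convolution convention: products of int8 values are accumulated in a wider type so they cannot overflow. A small NumPy illustration of that point (illustrative only, not the operator implementation):

import numpy as np

x = np.random.randint(-128, 128, size=(3, 3), dtype=np.int8)   # one 3x3 input patch
w = np.random.randint(-128, 128, size=(3, 3), dtype=np.int8)   # one 3x3 depthwise kernel
acc = np.sum(x.astype(np.int32) * w.astype(np.int32))          # widened accumulation for one output value
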
@@ -33,8 +33,10 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils

-_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
-                   nn.ReLU6: quant.ReLU6Quant,
+_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
+                   nn.ReLU6: quant.ActQuant,
+                   nn.LeakyReLU: quant.ActQuant,
+                   nn.Sigmoid: quant.ActQuant,
                    nn.HSigmoid: quant.HSigmoidQuant,
                    nn.HSwish: quant.HSwishQuant}

@@ -257,9 +259,9 @@ class ConvertToQuantNetwork:
     def _convert_activation(self, activation):
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
-            raise ValueError(
-                "Unsupported activation in auto quant: ", act_class)
-        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
+            raise ValueError("Unsupported activation in auto quant: ", act_class)
+        return _ACTIVATION_MAP[act_class](activation=act_class,
+                                          num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
                                           symmetric=self.act_symmetric,
@@ -317,7 +319,7 @@ class ExportToQuantInferNetwork:
             minq = self.all_parameters[minq_name]
             scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
         else:
-            logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
+            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None

         # Build the `Quant` `Dequant` op.
@@ -325,7 +327,7 @@ class ExportToQuantInferNetwork:
         quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
         sqrt_mode = False
         scale_deq = scale_a_out * scale_w
-        if scale_deq < 2 ** -14:
+        if (scale_deq < 2 ** -14).all():
             scale_deq = np.sqrt(scale_deq)
             sqrt_mode = True
         dequant_op = inner.AscendDequant(sqrt_mode)
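
With per-channel quantization, `scale_a_out` and `scale_w` are NumPy arrays rather than Python floats, so the threshold test has to be element-wise and reduced with `.all()`, which is what the change above does. A short sketch with made-up scale values:

import numpy as np

scale_a_out = np.array([3.1e-3, 2.7e-3])   # per-channel activation scales (illustrative values)
scale_w = np.array([1.5e-6, 0.9e-6])       # per-channel weight scales (illustrative values)
scale_deq = scale_a_out * scale_w          # an ndarray, so a bare `<` would yield an array, not a bool
sqrt_mode = False
if (scale_deq < 2 ** -14).all():
    scale_deq = np.sqrt(scale_deq)
    sqrt_mode = True
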
@@ -404,11 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'):
         file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
             - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
     """
+    supported_device = ["Ascend"]
     supported_formats = ['GEIR']
+
+    if context.get_context('device_target') not in supported_device:
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
+
     if file_format not in supported_formats:
         raise ValueError('Illegal file format {}.'.format(file_format))

     network.set_train(False)
+
     if file_format == 'GEIR':
         exporter = ExportToQuantInferNetwork(network, *inputs)
         deploy_net = exporter.run()

@@ -45,7 +45,7 @@ def cal_quantization_params(input_min,
         raise ValueError("input min shape should equal to input max.")
     if len(input_min.shape) > 1:
         raise ValueError("input min and max shape should be one dim.")
-    if input_min > input_max:
+    if (input_min > input_max).all():
         raise ValueError("input_min min should less than input max.")
     if (input_max == input_min).all():
         # scale = 1.0, zp = 0.0
@@ -85,9 +85,7 @@ def cal_quantization_params(input_min,
     return scale, zp


-def weight2int(data,
-               scale,
-               zero_point):
+def weight2int(data, scale, zero_point):
     r"""
     Calculate int8/uint8 weight from fp32. the formula is defined as:

@@ -103,12 +101,24 @@ def weight2int(data,
         weight (numpy.ndarray): The dimension of channel or 1.
     """
     if scale.shape != zero_point.shape:
-        raise ValueError("scale and zero_point should have the same shape.")
-    if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1)
-        zero_point = zero_point.reshape(1, -1)
+        raise ValueError("`scale` and `zero_point` should have the same shape.")
+    if scale.shape[0] < 0:
+        raise ValueError("`scale` and `zero_point` shape should greater than zero.")

-    return np.round((data/scale) + zero_point)
+    if scale.shape[0] == data.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(data.shape[1:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    elif scale.shape[0] == data.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(data.shape[2:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(data.shape))
+
+    return np.round((data / scale) + zero_point)


 def scale_zp_from_fack_quant_cell(cell, data_type):
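
The branches above broadcast the per-channel `scale`/`zero_point` along the channel axis of the weight, which sits on axis 0 for `Conv2d`/`Dense` and on axis 1 for `DepthwiseConv2d`. A standalone NumPy sketch of the same logic (the name `weight_to_int` and the sample shapes are made up for illustration):

import numpy as np

def weight_to_int(data, scale, zero_point):
    # Reshape the (num_channels,) vectors so they broadcast along the weight's channel axis.
    if scale.shape[0] == data.shape[0]:          # Conv2d / Dense: channel axis 0
        shape_list = [-1] + [1] * (data.ndim - 1)
    elif scale.shape[0] == data.shape[1]:        # DepthwiseConv2d: channel axis 1
        shape_list = [1, -1] + [1] * (data.ndim - 2)
    else:
        raise ValueError("Unsupported weight shape({})".format(data.shape))
    return np.round(data / scale.reshape(shape_list) + zero_point.reshape(shape_list))

weight = np.random.randn(1, 8, 3, 3).astype(np.float32)   # depthwise-style weight, 8 channels on axis 1
quantized = weight_to_int(weight, np.full(8, 0.02, np.float32), np.zeros(8, np.float32))
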
@@ -183,9 +193,20 @@ def fold_batchnorm(weight, cell_quant):
     beta = cell_quant.beta.data.asnumpy()
     epsilon = cell_quant.eps
     sigma = np.sqrt(variance + epsilon)
-    gamma = gamma.reshape(-1, 1, 1, 1)
-    sigma = sigma.reshape(-1, 1, 1, 1)
-    mean = mean.reshape(-1, 1, 1, 1)
-    weight = weight * gamma / sigma
+
+    if gamma.shape[0] == weight.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(weight.shape[1:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    elif gamma.shape[0] == weight.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(weight.shape[2:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(weight.shape))
+
+    weight = weight * _gamma / _sigma
     bias = beta - gamma * mean / sigma
     return weight, bias

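The folding above uses the standard identity w' = w * gamma / sqrt(var + eps) and b' = beta - gamma * mean / sqrt(var + eps), with the reshape selecting the weight's channel axis. A self-contained NumPy sketch of the `Conv2d` branch (all tensors below are random placeholders, not values from the commit):

import numpy as np

np.random.seed(0)
out_channels = 4
weight = np.random.randn(out_channels, 3, 3, 3).astype(np.float32)  # Conv2d weight, channel axis 0
gamma = np.random.rand(out_channels).astype(np.float32)
beta = np.random.rand(out_channels).astype(np.float32)
mean = np.random.rand(out_channels).astype(np.float32)
variance = np.random.rand(out_channels).astype(np.float32) + 0.5
epsilon = 1e-5

sigma = np.sqrt(variance + epsilon)
shape_list = [-1] + [1] * len(weight.shape[1:])                      # [-1, 1, 1, 1] for this layout
folded_weight = weight * gamma.reshape(shape_list) / sigma.reshape(shape_list)
folded_bias = beta - gamma * mean / sigma
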
@@ -0,0 +1,54 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Export MobilenetV2 on ImageNet"""
+
+import argparse
+import numpy as np
+
+import mindspore
+from mindspore import Tensor
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.quant import quant
+
+from src.mobilenetV2 import mobilenetV2
+from src.config import config_ascend
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--device_target', type=str, default=None, help='Run device target')
+args_opt = parser.parse_args()
+
+if __name__ == '__main__':
+    cfg = None
+    if args_opt.device_target == "Ascend":
+        cfg = config_ascend
+        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
+    else:
+        raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
+
+    # define fusion network
+    network = mobilenetV2(num_classes=cfg.num_classes)
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    # load checkpoint
+    param_dict = load_checkpoint(args_opt.checkpoint_path)
+    load_param_into_net(network, param_dict)
+
+    # export network
+    print("============== Starting export ==============")
+    inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
+    quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR')
+    print("============== End export ==============")
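
Assuming the new script is saved as export.py inside the MobileNetV2 model directory (the file path is not shown in this diff), it would be invoked along the lines of: python export.py --checkpoint_path=./mobilenetv2_quant.ckpt --device_target=Ascend, where the checkpoint path is only a placeholder for a quantization-aware training checkpoint.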