forked from mindspore-Ecosystem/mindspore

add mobilenetV2 quant export

commit d383ade6f9 (parent ccad0cae3a)
@@ -289,7 +289,8 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport
   MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
   std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
   auto filter = [](AnfNodePtr node) {
-    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
+    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
+             IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
   };
   std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
   auto is_quant_cnode = [](AnfNodePtr node) {
@@ -530,6 +530,7 @@ _activation = {
     'relu6': ReLU6,
     'tanh': Tanh,
     'gelu': GELU,
+    'elu': ELU,
     'sigmoid': Sigmoid,
     'prelu': PReLU,
     'leakyrelu': LeakyReLU,
@@ -17,6 +17,7 @@
 from functools import partial
 import numpy as np

+from mindspore import nn
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -41,8 +42,7 @@ __all__ = [
     'Conv2dBatchNormQuant',
     'Conv2dQuant',
     'DenseQuant',
-    'ReLUQuant',
-    'ReLU6Quant',
+    'ActQuant',
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
@@ -375,9 +375,10 @@ class FakeQuantWithMinMax(Cell):

     def extend_repr(self):
         s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
-            'quant_delay={}, min_init={}, max_init={}'.format(
-                self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
-                self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
+            'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
+                                                              self.ema, self.ema_decay, self.per_channel,
+                                                              self.channel_axis, self.num_channels, self.quant_delay,
+                                                              self.min_init, self.max_init)
         return s

     def construct(self, x):
@@ -540,10 +541,12 @@ class Conv2dBatchNormQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.fake, self.freeze_bn, self.momentum, self.quant_delay)
+            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
+                                                                        self.kernel_size, self.stride,
+                                                                        self.pad_mode, self.padding, self.dilation,
+                                                                        self.group,
+                                                                        self.fake, self.freeze_bn, self.momentum,
+                                                                        self.quant_delay)
         return s

     def construct(self, x):
@@ -685,10 +688,9 @@ class Conv2dQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'has_bias={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.has_bias, self.quant_delay)
+            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                                                 self.pad_mode, self.padding, self.dilation, self.group,
+                                                 self.has_bias, self.quant_delay)
         return s

@@ -799,76 +801,23 @@ class DenseQuant(Cell):

 class _QuantActivation(Cell):
     r"""
-    Base class for Quant activation function. Add Fake Quant OP after activation OP.
+    Base class for quantization aware training activation function. Add Fake Quant OP after activation OP.
     """

     def get_origin(self):
         raise NotImplementedError


-class ReLUQuant(_QuantActivation):
+class ActQuant(_QuantActivation):
     r"""
-    ReLUQuant activation function. Add Fake Quant OP after Relu OP.
+    Quantization aware training activation function.

-    For a more Detailed overview of ReLU op.
-
-    Args:
-        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-
-    Inputs:
-        - **x** (Tensor) - The input of ReLUQuant.
-
-    Outputs:
-        Tensor, with the same type and shape as the `x`.
-
-    Examples:
-        >>> relu_quant = nn.ReLUQuant()
-        >>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32)
-        >>> result = relu_quant(input_x)
-    """
-
-    def __init__(self,
-                 ema_decay=0.999,
-                 per_channel=False,
-                 num_bits=8,
-                 symmetric=False,
-                 narrow_range=False,
-                 quant_delay=0):
-        super(ReLUQuant, self).__init__()
-        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
-                                                  max_init=6,
-                                                  ema=True,
-                                                  ema_decay=ema_decay,
-                                                  per_channel=per_channel,
-                                                  num_bits=num_bits,
-                                                  symmetric=symmetric,
-                                                  narrow_range=narrow_range,
-                                                  quant_delay=quant_delay)
-        self.relu = P.ReLU()
-
-    def construct(self, x):
-        x = self.relu(x)
-        x = self.fake_quant_act(x)
-        return x
-
-    def get_origin(self):
-        return self.relu
-
-
-class ReLU6Quant(_QuantActivation):
-    r"""
-    ReLU6Quant activation function.
-
-    Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op
+    Add Fake Quant OP after activation. Not recommended to use this cell for the Fake Quant OP.
     It will clamp the max range of the activation; the ReLU6 op does the same operation.
     For a more detailed overview of the ReLU6 op.

     Args:
+        activation (Cell): Activation cell.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -883,19 +832,20 @@ class ReLU6Quant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.

     Examples:
-        >>> relu6_quant = nn.ReLU6Quant(4, 1)
+        >>> act_quant = nn.ActQuant(nn.ReLU6())
         >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = relu6_quant(input_x)
+        >>> result = act_quant(input_x)
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
                  symmetric=False,
                  narrow_range=False,
                  quant_delay=0):
-        super(ReLU6Quant, self).__init__()
+        super(ActQuant, self).__init__()
         self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
                                                   max_init=6,
                                                   ema=True,
@@ -905,15 +855,15 @@ class ReLU6Quant(_QuantActivation):
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range,
                                                   quant_delay=quant_delay)
-        self.relu6 = P.ReLU6()
+        self.act = activation

     def construct(self, x):
-        x = self.relu6(x)
+        x = self.act(x)
         x = self.fake_quant_act(x)
         return x

     def get_origin(self):
-        return self.relu6
+        return self.act


 class HSwishQuant(_QuantActivation):
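Note on the refactor above: ActQuant generalizes the deleted ReLUQuant/ReLU6Quant cells by wrapping whatever activation cell it is given. A minimal usage sketch, following the docstring example (the exact argument handling is an assumption based on the hunks above):

    import numpy as np
    import mindspore
    from mindspore import Tensor, nn

    act_quant = nn.ActQuant(nn.ReLU())        # wrap any supported activation cell
    input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
    result = act_quant(input_x)               # activation output, then fake quant
    origin = act_quant.get_origin()           # recover the wrapped activation for export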
@@ -923,6 +873,7 @@ class HSwishQuant(_QuantActivation):
     For a more detailed overview of the HSwish op.

     Args:
+        activation (Cell): Activation cell.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -943,6 +894,7 @@ class HSwishQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -968,7 +920,10 @@ class HSwishQuant(_QuantActivation):
                                                          symmetric=symmetric,
                                                          narrow_range=narrow_range,
                                                          quant_delay=quant_delay)
-        self.act = P.HSwish()
+        if isinstance(activation, nn.HSwish):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSwish`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)
@@ -987,6 +942,7 @@ class HSigmoidQuant(_QuantActivation):
     For a more detailed overview of the HSigmoid op.

     Args:
+        activation (Cell): Activation cell.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -1007,6 +963,7 @@ class HSigmoidQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -1032,7 +989,10 @@ class HSigmoidQuant(_QuantActivation):
                                                          symmetric=symmetric,
                                                          narrow_range=narrow_range,
                                                          quant_delay=quant_delay)
-        self.act = P.HSigmoid()
+        if isinstance(activation, nn.HSigmoid):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSigmoid`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)
@@ -1004,6 +1004,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype, w_dtype):
         args = {'x': x_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x_dtype


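Note: the int8 branch added to infer_dtype widens the output type because products of int8 operands overflow int8 and must accumulate in int32. A small numpy illustration of the effect (not MindSpore code):

    import numpy as np

    a = np.array([120], dtype=np.int8)
    w = np.array([120], dtype=np.int8)
    print(a * w)                                    # wraps around in int8: [64]
    print(a.astype(np.int32) * w.astype(np.int32))  # widened result: [14400]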
@@ -33,8 +33,10 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils

-_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
-                   nn.ReLU6: quant.ReLU6Quant,
+_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
+                   nn.ReLU6: quant.ActQuant,
+                   nn.LeakyReLU: quant.ActQuant,
+                   nn.Sigmoid: quant.ActQuant,
                    nn.HSigmoid: quant.HSigmoidQuant,
                    nn.HSwish: quant.HSwishQuant}

@@ -257,9 +259,9 @@ class ConvertToQuantNetwork:
     def _convert_activation(self, activation):
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
-            raise ValueError(
-                "Unsupported activation in auto quant: ", act_class)
-        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
+            raise ValueError("Unsupported activation in auto quant: ", act_class)
+        return _ACTIVATION_MAP[act_class](activation=activation,
+                                          num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
                                           symmetric=self.act_symmetric,
@@ -317,7 +319,7 @@ class ExportToQuantInferNetwork:
             minq = self.all_parameters[minq_name]
             scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
         else:
-            logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
+            logger.warning(f"Cannot find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None

         # Build the `Quant` `Dequant` op.
@@ -325,7 +327,7 @@ class ExportToQuantInferNetwork:
         quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
         sqrt_mode = False
         scale_deq = scale_a_out * scale_w
-        if scale_deq < 2 ** -14:
+        if (scale_deq < 2 ** -14).all():
             scale_deq = np.sqrt(scale_deq)
             sqrt_mode = True
         dequant_op = inner.AscendDequant(sqrt_mode)
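Note: with per-channel quantization, scale_deq is an ndarray, so the comparison yields a boolean array; .all() makes the branch explicit. Only when every channel's dequantization scale sits below 2 ** -14 is the square root stored, with sqrt_mode telling the dequant op to square it back. A plain numpy sketch with illustrative values:

    import numpy as np

    scale_deq = np.array([2.0 ** -20, 2.0 ** -18])   # per-channel dequant scales
    sqrt_mode = False
    if (scale_deq < 2 ** -14).all():
        scale_deq = np.sqrt(scale_deq)   # store sqrt(s); dequant recovers s as stored ** 2
        sqrt_mode = True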
@@ -404,11 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'):
         file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
             - GEIR: Graph Engine Intermediate Representation. An intermediate representation format of Ascend model.
     """
+    supported_device = ["Ascend"]
     supported_formats = ['GEIR']

+    if context.get_context('device_target') not in supported_device:
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
+
     if file_format not in supported_formats:
         raise ValueError('Illegal file format {}.'.format(file_format))

+    network.set_train(False)
+
     if file_format == 'GEIR':
         exporter = ExportToQuantInferNetwork(network, *inputs)
         deploy_net = exporter.run()
@@ -45,7 +45,7 @@ def cal_quantization_params(input_min,
        raise ValueError("input min shape should equal to input max.")
    if len(input_min.shape) > 1:
        raise ValueError("input min and max shape should be one dim.")
-    if input_min > input_max:
+    if (input_min > input_max).all():
        raise ValueError("input min should be less than input max.")
    if (input_max == input_min).all():
        # scale = 1.0, zp = 0.0
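For reference, a simplified sketch of the asymmetric scale/zero-point computation this function performs once the inputs validate (the real implementation also handles the symmetric and narrow-range variants):

    import numpy as np

    def scale_zp(input_min, input_max, num_bits=8):
        # standard asymmetric formula over the quantized range [0, 2**num_bits - 1]
        quant_max = 2 ** num_bits - 1
        scale = (input_max - input_min) / quant_max
        zp = np.round(-input_min / scale)
        return scale, zp

    scale, zp = scale_zp(np.array([-1.0]), np.array([1.0]))
    print(scale, zp)   # ~[0.00784] [128.]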
@@ -85,9 +85,7 @@ def cal_quantization_params(input_min,
     return scale, zp


-def weight2int(data,
-               scale,
-               zero_point):
+def weight2int(data, scale, zero_point):
     r"""
     Calculate int8/uint8 weight from fp32. The formula is defined as:

@@ -103,12 +101,24 @@ def weight2int(data,
         weight (numpy.ndarray): The dimension of channel or 1.
     """
     if scale.shape != zero_point.shape:
-        raise ValueError("scale and zero_point should have the same shape.")
-    if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1)
-        zero_point = zero_point.reshape(1, -1)
-
-    return np.round((data/scale) + zero_point)
+        raise ValueError("`scale` and `zero_point` should have the same shape.")
+    if scale.shape[0] == 0:
+        raise ValueError("`scale` and `zero_point` shape should be greater than zero.")
+
+    if scale.shape[0] == data.shape[0]:
+        # `Conv2d` or `Dense` op weight: per output channel, axis 0
+        shape_list = [-1] + [1] * len(data.shape[1:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    elif scale.shape[0] == data.shape[1]:
+        # `DepthwiseConv2d` op weight: per channel, axis 1
+        shape_list = [1, -1] + [1] * len(data.shape[2:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(data.shape))
+
+    return np.round((data / scale) + zero_point)


 def scale_zp_from_fack_quant_cell(cell, data_type):
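The reshape logic above exists so a per-channel scale vector broadcasts along the correct axis of the weight tensor: axis 0 for Conv2d/Dense, axis 1 for DepthwiseConv2d. A self-contained numpy illustration mirroring the hunk:

    import numpy as np

    def quantize_weight(data, scale, zero_point):
        if scale.shape[0] == data.shape[0]:       # Conv2d/Dense: out channels on axis 0
            shape_list = [-1] + [1] * len(data.shape[1:])
        elif scale.shape[0] == data.shape[1]:     # DepthwiseConv2d: channels on axis 1
            shape_list = [1, -1] + [1] * len(data.shape[2:])
        else:
            raise ValueError("Unsupported weight shape({})".format(data.shape))
        return np.round(data / scale.reshape(shape_list) + zero_point.reshape(shape_list))

    weight = np.random.randn(8, 3, 3, 3).astype(np.float32)     # (out, in, kh, kw)
    q = quantize_weight(weight, np.full(8, 0.02), np.zeros(8))  # one scale per out channel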
@@ -183,9 +193,20 @@ def fold_batchnorm(weight, cell_quant):
     beta = cell_quant.beta.data.asnumpy()
     epsilon = cell_quant.eps
     sigma = np.sqrt(variance + epsilon)
-    gamma = gamma.reshape(-1, 1, 1, 1)
-    sigma = sigma.reshape(-1, 1, 1, 1)
-    mean = mean.reshape(-1, 1, 1, 1)
-    weight = weight * gamma / sigma
+
+    if gamma.shape[0] == weight.shape[0]:
+        # `Conv2d` or `Dense` op weight: per output channel, axis 0
+        shape_list = [-1] + [1] * len(weight.shape[1:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    elif gamma.shape[0] == weight.shape[1]:
+        # `DepthwiseConv2d` op weight: per channel, axis 1
+        shape_list = [1, -1] + [1] * len(weight.shape[2:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(weight.shape))
+
+    weight = weight * _gamma / _sigma
     bias = beta - gamma * mean / sigma
     return weight, bias
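Background on the folding identity (standard batch-norm algebra, not specific to this diff): applying BN after a linear op, gamma * (W @ x - mean) / sigma + beta, equals a single linear op with W' = W * gamma / sigma per output channel and bias b' = beta - gamma * mean / sigma, which is what the hunk computes. A numpy check of the identity for the plain dense case:

    import numpy as np

    w = np.random.randn(4, 3)                     # (out, in) weight
    x = np.random.randn(3)
    gamma, beta = np.random.rand(4), np.random.rand(4)
    mean, sigma = np.random.randn(4), np.random.rand(4) + 1.0

    bn_out = gamma * (w @ x - mean) / sigma + beta
    folded_w = w * (gamma / sigma)[:, None]       # scale each output row
    folded_b = beta - gamma * mean / sigma
    assert np.allclose(bn_out, folded_w @ x + folded_b)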
@@ -0,0 +1,54 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Export MobilenetV2 on ImageNet"""
+
+import argparse
+import numpy as np
+
+import mindspore
+from mindspore import Tensor
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.train.quant import quant
+
+from src.mobilenetV2 import mobilenetV2
+from src.config import config_ascend
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--device_target', type=str, default=None, help='Run device target')
+args_opt = parser.parse_args()
+
+if __name__ == '__main__':
+    cfg = None
+    if args_opt.device_target == "Ascend":
+        cfg = config_ascend
+        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
+    else:
+        raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
+
+    # define fusion network
+    network = mobilenetV2(num_classes=cfg.num_classes)
+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
+    # load checkpoint
+    param_dict = load_checkpoint(args_opt.checkpoint_path)
+    load_param_into_net(network, param_dict)
+
+    # export network
+    print("============== Starting export ==============")
+    inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
+    quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR')
+    print("============== End export ==============")
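A typical invocation of the new export script (the script's filename is not shown in the diff, so export.py is an assumption; the checkpoint path is a placeholder):

    python export.py --checkpoint_path ./mobilenetv2_quant.ckpt --device_target Ascend

The script rebuilds the fusion network, converts it to a quantization aware network with the bn_fold/per_channel/symmetric settings shown above (which must match those used during quantization aware training for the checkpoint to load), loads the checkpoint, and exports a GEIR model named mobilenet_quant.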