!14330 Amend manual quant export and add resnet50_quant export.py

From: @zhang__sss
Reviewed-by: @zhoufeng54
Signed-off-by:
mindspore-ci-bot 2021-03-30 19:14:12 +08:00 committed by Gitee
commit f47767b361
4 changed files with 90 additions and 12 deletions

View File

@@ -228,10 +228,17 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
     __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
 
     def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
-        super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir)
+        super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir=is_mindir)
         self.upcell = None
         self.upname = None
 
+    def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
+        if self.is_mindir:
+            np_type = mstype.dtype_to_nptype(self.data_type)
+            _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type)
+            origin_op.add_prim_attr('output_maxq', Tensor(maxq))
+            origin_op.add_prim_attr('output_minq', Tensor(minq))
+
     def _convert_quant2deploy(self, network):
         """Convert network's all quant subcell to deploy subcell."""
         cells = network.name_cells()
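For context only (this note and snippet are editorial illustration, not part of the commit): the `output_minq` / `output_maxq` attributes attached by the new helper carry the output quantization range recorded by the fake-quant cell, which a deployment backend can turn into quantization parameters. A minimal numpy sketch of the standard 8-bit asymmetric mapping, assuming a generic helper rather than MindSpore's actual `quant_utils` implementation:

```python
# Generic 8-bit asymmetric quantization math (illustrative only; not
# MindSpore's quant_utils.scale_zp_max_min_from_fake_quant_cell).
import numpy as np

def scale_and_zero_point(min_val, max_val, num_bits=8):
    """Map a recorded float (min, max) activation range to scale / zero point."""
    qmin, qmax = 0, 2 ** num_bits - 1          # target integer range, e.g. uint8
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = int(np.round(qmin - min_val / scale))
    return scale, zero_point

# A range like (-6.0, 6.0) gives scale ~0.047 and zero point 128.
print(scale_and_zero_point(-6.0, 6.0))
```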
@@ -247,18 +254,31 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
             elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
                                       quant.Conv2dQuant, quant.DenseQuant)):
                 network, change = self._convert_subcell(network, change, name, subcell, core=False)
-            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver) and self.upcell:
-                np_type = mstype.dtype_to_nptype(self.data_type)
-                _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type)
-                self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq))
-                self.upcell.core_op.add_prim_attr('output_minq', Tensor(minq))
-                network.insert_child_to_cell(self.upname, self.upcell)
+            elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
+                if self.upcell:
+                    self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
+                activation = subcell.get_origin()
+                network.insert_child_to_cell(name, activation)
+                change = True
+            elif isinstance(subcell, nn.TensorAddQuant):
+                if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
+                    add_op = subcell.add.subcell
+                    subcell.__delattr__("add")
+                    subcell.__setattr__("add", add_op)
+                add_op = subcell.add
+                if add_op:
+                    self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
+                subcell.__delattr__("fake_quant_act")
+                subcell.__setattr__("fake_quant_act", P.identity())
+            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
+                if self.upcell:
+                    self._add_output_min_max_for_op(self.upcell.core_op, subcell)
+                network.__delattr__(name)
+                network.__setattr__(name, P.identity())
             elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                 op = subcell.subcell
                 if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
-                    if self.is_mindir:
-                        op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
-                        op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
+                    self._add_output_min_max_for_op(op, subcell.fake_quant_act)
                     network.__delattr__(name)
                     network.__setattr__(name, op)
                     change = True
@@ -271,15 +291,18 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
 
     def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
         """Convert subcell to ant subcell."""
+        new_subcell = None
         if core:
             cell_core = subcell.conv if conv else subcell.dense
             activation = subcell.activation
-            fake_quant_act = activation.fake_quant_act
+            if hasattr(activation, 'fake_quant_act'):
+                fake_quant_act = activation.fake_quant_act
+                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         else:
             cell_core = subcell
             activation = None
             fake_quant_act = None
-        new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
+            new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         if new_subcell:
             prefix = subcell.param_prefix
             new_subcell.update_parameters_name(prefix + '.')
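The conversion pass above rewrites the network in place: it walks `name_cells()` and swaps quant-aware child cells for deploy-time cells via `insert_child_to_cell`, `__delattr__`, and `__setattr__`. A toy sketch of that replacement mechanism (illustrative only; `TinyNet` is a made-up cell and this is not code from the commit):

```python
# Toy sketch of the cell-replacement pattern used by _convert_quant2deploy:
# walk name_cells() and swap a child cell via insert_child_to_cell.
import mindspore.nn as nn

class TinyNet(nn.Cell):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.act = nn.ReLU()

    def construct(self, x):
        return self.act(x)

net = TinyNet()
for name, subcell in list(net.name_cells().items()):
    if isinstance(subcell, nn.ReLU):
        # replace the child cell in place, keeping the parent's structure
        net.insert_child_to_cell(name, nn.LeakyReLU())

assert isinstance(net.act, nn.LeakyReLU)
```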

View File

@@ -87,6 +87,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
 │ ├──crossentropy.py # define the crossentropy of resnet50-quant
 ├── train.py # training script
 ├── eval.py # evaluation script
+├── export.py # export script
 ```

View File

@@ -95,6 +95,7 @@ The overall network architecture of ResNet-50 is as follows
 │ ├──crossentropy.py # define the cross entropy of ResNet-50-Quant
 ├── train.py # training script
 ├── eval.py # evaluation script
+├── export.py # export script
 ```

View File

@@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export Resnet50 on ImageNet"""

import argparse
import numpy as np

import mindspore
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.compression.quant import QuantizationAwareTraining

from models.resnet_quant_manual import resnet50_quant
from src.config import config_quant

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument('--device_target', type=str, default=None, help='Run device target')
args_opt = parser.parse_args()

if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
    # define fusion network
    network = resnet50_quant(class_num=config_quant.class_num)
    # convert fusion network to quantization aware network
    quantizer = QuantizationAwareTraining(bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    network = quantizer.quantize(network)
    # load checkpoint
    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        not_load_param = load_param_into_net(network, param_dict)
        if not_load_param:
            raise ValueError("Load param into network fail!")
    # export network
    print("============== Starting export ==============")
    inputs = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32)
    export(network, inputs, file_name="resnet50_quant", file_format=args_opt.file_format,
           quant_mode='MANUAL', mean=0., std_dev=48.106)
    print("============== End export ==============")