forked from OSSInnovation/mindspore
!14330 Amend manual quant export and add resnet50_quant export.py
From: @zhang__sss
Reviewed-by: @zhoufeng54
Signed-off-by:
commit f47767b361
@@ -228,10 +228,17 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
     __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
 
     def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
-        super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir)
+        super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir=is_mindir)
         self.upcell = None
         self.upname = None
 
+    def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
+        if self.is_mindir:
+            np_type = mstype.dtype_to_nptype(self.data_type)
+            _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type)
+            origin_op.add_prim_attr('output_maxq', Tensor(maxq))
+            origin_op.add_prim_attr('output_minq', Tensor(minq))
+
     def _convert_quant2deploy(self, network):
         """Convert network's all quant subcell to deploy subcell."""
         cells = network.name_cells()
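
Note: the `_add_output_min_max_for_op` helper added above records a fake-quant cell's output range on the preceding primitive as `output_maxq`/`output_minq` prim attrs, so the MindIR exporter can carry the quantization range with the op. Below is a minimal sketch of that attribute-attachment pattern; it is not part of the commit, and the range values are placeholders rather than values derived from a fake-quant cell.

```python
# Minimal sketch (not from the commit) of attaching a quantization range to a
# primitive via prim attrs; the range values are placeholders, whereas the
# exporter derives them from the fake-quant cell's learned min/max.
import numpy as np
from mindspore import Tensor, ops

origin_op = ops.Add()                    # stands in for the op feeding the fake-quant cell
maxq = np.array(127, dtype=np.int8)      # placeholder quantized maximum
minq = np.array(-128, dtype=np.int8)     # placeholder quantized minimum
origin_op.add_prim_attr('output_maxq', Tensor(maxq))
origin_op.add_prim_attr('output_minq', Tensor(minq))
print(origin_op.attrs['output_maxq'], origin_op.attrs['output_minq'])
```
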
@@ -247,18 +254,31 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
             elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
                                       quant.Conv2dQuant, quant.DenseQuant)):
                 network, change = self._convert_subcell(network, change, name, subcell, core=False)
-            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver) and self.upcell:
-                np_type = mstype.dtype_to_nptype(self.data_type)
-                _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type)
-                self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq))
-                self.upcell.core_op.add_prim_attr('output_minq', Tensor(minq))
-                network.insert_child_to_cell(self.upname, self.upcell)
+            elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
+                if self.upcell:
+                    self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
+                activation = subcell.get_origin()
+                network.insert_child_to_cell(name, activation)
+                change = True
+            elif isinstance(subcell, nn.TensorAddQuant):
+                if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
+                    add_op = subcell.add.subcell
+                    subcell.__delattr__("add")
+                    subcell.__setattr__("add", add_op)
+                add_op = subcell.add
+                if add_op:
+                    self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
+                subcell.__delattr__("fake_quant_act")
+                subcell.__setattr__("fake_quant_act", P.identity())
+            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
+                if self.upcell:
+                    self._add_output_min_max_for_op(self.upcell.core_op, subcell)
+                network.__delattr__(name)
+                network.__setattr__(name, P.identity())
             elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                 op = subcell.subcell
                 if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
-                    if self.is_mindir:
-                        op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
-                        op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
+                    self._add_output_min_max_for_op(op, subcell.fake_quant_act)
                     network.__delattr__(name)
                     network.__setattr__(name, op)
                     change = True
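
Note: several of the new branches above swap a quantization-aware sub-cell for its deploy-time replacement by name, using `Cell.insert_child_to_cell`. Below is a toy sketch of that replacement pattern, not from the commit; the cells used are arbitrary stand-ins, not the *Quant cells handled by the exporter.

```python
# Toy sketch (not from the commit) of replacing a named sub-cell in place with
# Cell.insert_child_to_cell, the mechanism the exporter uses to drop deploy
# cells into the converted network.
from mindspore import nn

class Block(nn.Cell):
    def __init__(self):
        super(Block, self).__init__()
        self.act = nn.ReLU()            # stands in for a *Quant sub-cell

    def construct(self, x):
        return self.act(x)

net = Block()
net.insert_child_to_cell("act", nn.LeakyReLU())   # overwrite the "act" child by name
print(type(net.act).__name__)                     # LeakyReLU
```
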
@@ -271,15 +291,18 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
 
     def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
         """Convert subcell to ant subcell."""
+        new_subcell = None
         if core:
             cell_core = subcell.conv if conv else subcell.dense
             activation = subcell.activation
-            fake_quant_act = activation.fake_quant_act
+            if hasattr(activation, 'fake_quant_act'):
+                fake_quant_act = activation.fake_quant_act
+                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         else:
             cell_core = subcell
             activation = None
             fake_quant_act = None
-        new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
+            new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         if new_subcell:
             prefix = subcell.param_prefix
             new_subcell.update_parameters_name(prefix + '.')
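
Note: `update_parameters_name` at the end of `_convert_subcell` re-prefixes the new block's parameters with the original sub-cell's `param_prefix`, so parameter names stay stable after conversion. Below is a small sketch of that call on an ordinary cell; it is not from the commit, and the prefix is made up.

```python
# Small sketch (not from the commit) of Cell.update_parameters_name: every
# parameter of the cell gets the given prefix prepended, which is how the
# converted block keeps the original sub-cell's naming scope.
from mindspore import nn

dense = nn.Dense(4, 2)
dense.update_parameters_name('layer1.fc.')        # hypothetical prefix
print([p.name for p in dense.get_parameters()])   # ['layer1.fc.weight', 'layer1.fc.bias']
```
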
@@ -87,6 +87,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
 │   ├──crossentropy.py    # define the crossentropy of resnet50-quant
 ├── train.py              # training script
 ├── eval.py               # evaluation script
+├── export.py             # export script
 
 ```
 
@@ -95,6 +95,7 @@ The overall network architecture of ResNet-50 is as follows:
 │   ├──crossentropy.py    # define the cross entropy of ResNet-50-Quant
 ├── train.py              # training script
 ├── eval.py               # evaluation script
+├── export.py             # export script
 
 ```
 
@@ -0,0 +1,53 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Export Resnet50 on ImageNet"""
+
+import argparse
+import numpy as np
+
+import mindspore
+from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
+from mindspore.compression.quant import QuantizationAwareTraining
+
+from models.resnet_quant_manual import resnet50_quant
+from src.config import config_quant
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
+parser.add_argument('--device_target', type=str, default=None, help='Run device target')
+args_opt = parser.parse_args()
+
+if __name__ == '__main__':
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
+    # define fusion network
+    network = resnet50_quant(class_num=config_quant.class_num)
+    # convert fusion network to quantization aware network
+    quantizer = QuantizationAwareTraining(bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])
+    network = quantizer.quantize(network)
+    # load checkpoint
+    if args_opt.checkpoint_path:
+        param_dict = load_checkpoint(args_opt.checkpoint_path)
+        not_load_param = load_param_into_net(network, param_dict)
+        if not_load_param:
+            raise ValueError("Load param into network fail!")
+    # export network
+    print("============== Starting export ==============")
+    inputs = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32)
+    export(network, inputs, file_name="resnet50_quant", file_format=args_opt.file_format,
+           quant_mode='MANUAL', mean=0., std_dev=48.106)
+    print("============== End export ==============")
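
Note: the new export.py above writes `resnet50_quant.mindir` (or an AIR file) next to the script. Below is a minimal sketch, not part of the commit, of loading the exported MINDIR back for inference; it assumes a MindSpore version where `mindspore.load` and `nn.GraphCell` are available (>= 1.1) and that export.py has already been run.

```python
# Minimal sketch (not from the commit) of running inference on the exported
# MINDIR file; assumes export.py has produced resnet50_quant.mindir.
import numpy as np
import mindspore
from mindspore import Tensor, nn

graph = mindspore.load("resnet50_quant.mindir")   # load the exported graph
net = nn.GraphCell(graph)                         # wrap it as a callable cell
dummy = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32)
print(net(dummy).shape)                           # (1, class_num)
```
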