From 518de99cde1746837a5550bb70e6e161bb309808 Mon Sep 17 00:00:00 2001 From: zhang__sss Date: Mon, 26 Apr 2021 11:02:06 +0800 Subject: [PATCH] quant export --- mindspore/compression/export/quant_export.py | 88 ++++--------------- mindspore/train/serialization.py | 47 +++++----- model_zoo/official/cv/lenet_quant/export.py | 2 +- .../cv/lenet_quant/src/lenet_quant.py | 17 ++-- .../official/cv/mobilenetv2_quant/export.py | 8 +- model_zoo/official/cv/resnet50_quant/eval.py | 9 +- .../official/cv/resnet50_quant/export.py | 11 +-- model_zoo/official/cv/resnet50_quant/train.py | 11 +-- .../cv/yolov3_darknet53_quant/export.py | 5 +- 9 files changed, 64 insertions(+), 134 deletions(-) diff --git a/mindspore/compression/export/quant_export.py b/mindspore/compression/export/quant_export.py index 96a9121bf11..472170fe832 100644 --- a/mindspore/compression/export/quant_export.py +++ b/mindspore/compression/export/quant_export.py @@ -30,18 +30,20 @@ from ..quant import quant_utils from ..quant.qat import QuantizationAwareTraining, _AddFakeQuantInput, _AddFakeQuantAfterSubCell -__all__ = ["ExportToQuantInferNetwork", "ExportManualQuantNetwork"] +__all__ = ["ExportToQuantInferNetwork"] class ExportToQuantInferNetwork: """ Convert quantization aware network to infer network. Args: - network (Cell): MindSpore network API `convert_quant_network`. + network (Cell): MindSpore quantization aware training network. inputs (Tensor): Input tensors of the `quantization aware training network`. - mean (int): Input data mean. Default: 127.5. - std_dev (int, float): Input data variance. Default: 127.5. - is_mindir (bool): Whether is MINDIR format. Default: False. + mean (int, float): The mean of input data after preprocessing, used for quantizing the first layer of network. + Default: 127.5. + std_dev (int, float): The variance of input data after preprocessing, used for quantizing the first layer + of network. Default: 127.5. + is_mindir (bool): Whether export MINDIR format. Default: False. Returns: Cell, Infer network. @@ -59,9 +61,11 @@ class ExportToQuantInferNetwork: self.mean = mean self.std_dev = std_dev self.is_mindir = is_mindir + self.upcell = None + self.upname = None def get_inputs_table(self, inputs): - """Get the support info for quant export.""" + """Get the input quantization parameters of quantization cell for quant export.""" phase_name = 'export_quant' graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False) self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id) @@ -151,7 +155,6 @@ class ExportToQuantInferNetwork: dequant_param = np.zeros(scale_length, dtype=np.uint64) for index in range(scale_length): dequant_param[index] += uint32_deq_scale[index] - scale_deq = Tensor(dequant_param, mstype.uint64) # get op if isinstance(cell_core, quant.DenseQuant): @@ -170,69 +173,8 @@ class ExportToQuantInferNetwork: block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) return block - def _convert_quant2deploy(self, network): - """Convert network's all quant subcell to deploy subcell.""" - cells = network.name_cells() - change = False - for name in cells: - subcell = cells[name] - if subcell == network: - continue - cell_core = None - fake_quant_act = None - activation = None - if isinstance(subcell, nn.Conv2dBnAct): - cell_core = subcell.conv - activation = subcell.activation - fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None - elif isinstance(subcell, nn.DenseBnAct): - cell_core = subcell.dense - activation = subcell.activation - fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None - if cell_core is not None: - new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) - if new_subcell: - prefix = subcell.param_prefix - new_subcell.update_parameters_name(prefix + '.') - network.insert_child_to_cell(name, new_subcell) - change = True - elif isinstance(subcell, _AddFakeQuantAfterSubCell): - op = subcell.subcell - if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive): - if self.is_mindir: - op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy())) - op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy())) - network.__delattr__(name) - network.__setattr__(name, op) - change = True - else: - self._convert_quant2deploy(subcell) - if isinstance(network, nn.SequentialCell) and change: - network.cell_list = list(network.cells()) - return network - -class ExportManualQuantNetwork(ExportToQuantInferNetwork): - """ - Convert manual quantization aware network to infer network. - - Args: - network (Cell): MindSpore network API `convert_quant_network`. - inputs (Tensor): Input tensors of the `quantization aware training network`. - mean (int): Input data mean. Default: 127.5. - std_dev (int, float): Input data variance. Default: 127.5. - is_mindir (bool): Whether is MINDIR format. Default: False. - - Returns: - Cell, Infer network. - """ - __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"] - - def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): - super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir=is_mindir) - self.upcell = None - self.upname = None - def _add_output_min_max_for_op(self, origin_op, fake_quant_cell): + """add output quant info for quant op for export mindir.""" if self.is_mindir: np_type = mstype.dtype_to_nptype(self.data_type) _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type) @@ -251,8 +193,8 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork): network, change = self._convert_subcell(network, change, name, subcell) elif isinstance(subcell, nn.DenseBnAct): network, change = self._convert_subcell(network, change, name, subcell, conv=False) - elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant, - quant.Conv2dQuant, quant.DenseQuant)): + elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnFoldQuantOneConv, + quant.Conv2dBnWithoutFoldQuant, quant.Conv2dQuant, quant.DenseQuant)): network, change = self._convert_subcell(network, change, name, subcell, core=False) elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"): if self.upcell: @@ -292,16 +234,16 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork): def _convert_subcell(self, network, change, name, subcell, core=True, conv=True): """Convert subcell to ant subcell.""" new_subcell = None + fake_quant_act = None if core: cell_core = subcell.conv if conv else subcell.dense activation = subcell.activation if hasattr(activation, 'fake_quant_act'): fake_quant_act = activation.fake_quant_act - new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) else: cell_core = subcell activation = None - fake_quant_act = None + if cell_core is not None and hasattr(cell_core, "fake_quant_weight"): new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act) if new_subcell: prefix = subcell.param_prefix diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 4b8025d2feb..191686d7950 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -596,9 +596,12 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): kwargs (dict): Configuration options dictionary. - - quant_mode: The mode of quant. - - mean: Input data mean. Default: 127.5. - - std_dev: Input data variance. Default: 127.5. + - quant_mode: If the network is quantization aware training network, the quant_mode should + be set to "QUANT", else the quant_mode should be set to "NONQUANT". + - mean: The mean of input data after preprocessing, used for quantizing the first layer of network. + Default: 127.5. + - std_dev: The variance of input data after preprocessing, used for quantizing the first layer of network. + Default: 127.5. """ logger.info("exporting model file:%s format:%s.", file_name, file_format) check_input_data(*inputs, data_class=Tensor) @@ -748,28 +751,38 @@ def _mindir_save_together(net_dict, model): return False return True +def quant_mode_manage(func): + """ + Inherit the quant_mode in old version. + """ + def warpper(network, *inputs, file_format, **kwargs): + if not kwargs.get('quant_mode', None): + return network + quant_mode = kwargs['quant_mode'] + if quant_mode in ('AUTO', 'MANUAL'): + kwargs['quant_mode'] = 'QUANT' + return func(network, *inputs, file_format=file_format, **kwargs) + return warpper +@quant_mode_manage def _quant_export(network, *inputs, file_format, **kwargs): """ Exports MindSpore quantization predict model to deploy with AIR and MINDIR. """ - if not kwargs.get('quant_mode', None): - return network - supported_device = ["Ascend", "GPU"] supported_formats = ['AIR', 'MINDIR'] - quant_mode_formats = ['AUTO', 'MANUAL'] + quant_mode_formats = ['QUANT', 'NONQUANT'] + quant_mode = kwargs['quant_mode'] + if quant_mode not in quant_mode_formats: + raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.') + if quant_mode == 'NONQUANT': + return network quant_net = copy.deepcopy(network) quant_net._create_time = int(time.time() * 1e9) mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean'] std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev'] - - quant_mode = kwargs['quant_mode'] - if quant_mode not in quant_mode_formats: - raise KeyError(f'Quant_mode input is wrong, Please choose the right mode of the quant_mode.') - mean = Validator.check_value_type("mean", mean, (int, float)) std_dev = Validator.check_value_type("std_dev", std_dev, (int, float)) @@ -781,15 +794,9 @@ def _quant_export(network, *inputs, file_format, **kwargs): quant_net.set_train(False) if file_format == "MINDIR": - if quant_mode == 'MANUAL': - exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) - else: - exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) + exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True) else: - if quant_mode == 'MANUAL': - exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs) - else: - exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs) + exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs) deploy_net = exporter.run() return deploy_net diff --git a/model_zoo/official/cv/lenet_quant/export.py b/model_zoo/official/cv/lenet_quant/export.py index a37c90f6c58..1608cdfc0f5 100644 --- a/model_zoo/official/cv/lenet_quant/export.py +++ b/model_zoo/official/cv/lenet_quant/export.py @@ -54,4 +54,4 @@ if __name__ == "__main__": # export network inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32) - export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='AUTO') + export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='QUANT') diff --git a/model_zoo/official/cv/lenet_quant/src/lenet_quant.py b/model_zoo/official/cv/lenet_quant/src/lenet_quant.py index d4e036f55e0..1fd6e7f3656 100644 --- a/model_zoo/official/cv/lenet_quant/src/lenet_quant.py +++ b/model_zoo/official/cv/lenet_quant/src/lenet_quant.py @@ -15,7 +15,8 @@ """Manual construct network for LeNet""" import mindspore.nn as nn - +from mindspore.compression.quant import create_quant_config +from mindspore.compression.common import QuantDtype class LeNet5(nn.Cell): """ @@ -34,14 +35,16 @@ class LeNet5(nn.Cell): def __init__(self, num_class=10, channel=1): super(LeNet5, self).__init__() self.num_class = num_class + self.qconfig = create_quant_config(per_channel=(True, False), symmetric=(True, False)) - self.conv1 = nn.Conv2dBnFoldQuant(channel, 6, 5, pad_mode='valid', per_channel=True, quant_delay=900) - self.conv2 = nn.Conv2dBnFoldQuant(6, 16, 5, pad_mode='valid', per_channel=True, quant_delay=900) - self.fc1 = nn.DenseQuant(16 * 5 * 5, 120, per_channel=True, quant_delay=900) - self.fc2 = nn.DenseQuant(120, 84, per_channel=True, quant_delay=900) - self.fc3 = nn.DenseQuant(84, self.num_class, per_channel=True, quant_delay=900) + self.conv1 = nn.Conv2dQuant(channel, 6, 5, pad_mode='valid', quant_config=self.qconfig, + quant_dtype=QuantDtype.INT8) + self.conv2 = nn.Conv2dQuant(6, 16, 5, pad_mode='valid', quant_config=self.qconfig, quant_dtype=QuantDtype.INT8) + self.fc1 = nn.DenseQuant(16 * 5 * 5, 120, quant_config=self.qconfig, quant_dtype=QuantDtype.INT8) + self.fc2 = nn.DenseQuant(120, 84, quant_config=self.qconfig, quant_dtype=QuantDtype.INT8) + self.fc3 = nn.DenseQuant(84, self.num_class, quant_config=self.qconfig, quant_dtype=QuantDtype.INT8) - self.relu = nn.ActQuant(nn.ReLU(), per_channel=False, quant_delay=900) + self.relu = nn.ActQuant(nn.ReLU(), quant_config=self.qconfig, quant_dtype=QuantDtype.INT8) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.flatten = nn.Flatten() diff --git a/model_zoo/official/cv/mobilenetv2_quant/export.py b/model_zoo/official/cv/mobilenetv2_quant/export.py index 622952532f7..0f2c28b8afa 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/export.py +++ b/model_zoo/official/cv/mobilenetv2_quant/export.py @@ -47,9 +47,7 @@ if __name__ == '__main__': # export network print("============== Starting export ==============") inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32) - if args_opt.file_format == 'MINDIR': - export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR', quant_mode='AUTO') - else: - export(network, inputs, file_name="mobilenet_quant", file_format='AIR', - quant_mode='AUTO', mean=0., std_dev=48.106) + export(network, inputs, file_name="mobilenetv2_quant", file_format=args_opt.file_format, + quant_mode='QUANT', mean=0., std_dev=48.106) + print("============== End export ==============") diff --git a/model_zoo/official/cv/resnet50_quant/eval.py b/model_zoo/official/cv/resnet50_quant/eval.py index 30a0b46bcab..ea8d3c11d3e 100755 --- a/model_zoo/official/cv/resnet50_quant/eval.py +++ b/model_zoo/official/cv/resnet50_quant/eval.py @@ -20,13 +20,11 @@ import argparse from src.config import config_quant from src.dataset import create_dataset from src.crossentropy import CrossEntropy -#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50 from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50 from mindspore import context from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.compression.quant import QuantizationAwareTraining parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') @@ -42,13 +40,8 @@ if args_opt.device_target == "Ascend": context.set_context(device_id=device_id) if __name__ == '__main__': - # define fusion network + # define manual quantization network network = resnet50_quant(class_num=config.class_num) - # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(bn_fold=True, - per_channel=[True, False], - symmetric=[True, False]) - network = quantizer.quantize(network) # define network loss if not config.use_label_smooth: diff --git a/model_zoo/official/cv/resnet50_quant/export.py b/model_zoo/official/cv/resnet50_quant/export.py index 81f2ef3caae..8424a125a6d 100644 --- a/model_zoo/official/cv/resnet50_quant/export.py +++ b/model_zoo/official/cv/resnet50_quant/export.py @@ -19,7 +19,6 @@ import numpy as np import mindspore from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export -from mindspore.compression.quant import QuantizationAwareTraining from models.resnet_quant_manual import resnet50_quant from src.config import config_quant @@ -32,13 +31,9 @@ args_opt = parser.parse_args() if __name__ == '__main__': context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False) - # define fusion network + # define manual quantization network network = resnet50_quant(class_num=config_quant.class_num) - # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(bn_fold=True, - per_channel=[True, False], - symmetric=[True, False]) - network = quantizer.quantize(network) + # load checkpoint if args_opt.checkpoint_path: param_dict = load_checkpoint(args_opt.checkpoint_path) @@ -49,5 +44,5 @@ if __name__ == '__main__': print("============== Starting export ==============") inputs = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32) export(network, inputs, file_name="resnet50_quant", file_format=args_opt.file_format, - quant_mode='MANUAL', mean=0., std_dev=48.106) + quant_mode='QUANT', mean=0., std_dev=48.106) print("============== End export ==============") diff --git a/model_zoo/official/cv/resnet50_quant/train.py b/model_zoo/official/cv/resnet50_quant/train.py index a5112066b98..f9349885f2d 100755 --- a/model_zoo/official/cv/resnet50_quant/train.py +++ b/model_zoo/official/cv/resnet50_quant/train.py @@ -25,14 +25,12 @@ from mindspore.context import ParallelMode from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.serialization import load_checkpoint -from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.communication.management import init import mindspore.nn as nn import mindspore.common.initializer as weight_init from mindspore.common import set_seed -#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50 from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50 from src.dataset import create_dataset from src.lr_generator import get_lr @@ -80,7 +78,7 @@ if __name__ == '__main__': parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, all_reduce_fusion_config=[107, 160]) - # define network + # define manual quantization network net = resnet50_quant(class_num=config.class_num) net.set_train(True) @@ -112,13 +110,6 @@ if __name__ == '__main__': target=args_opt.device_target) step_size = dataset.get_dataset_size() - # convert fusion network to quantization aware network - quantizer = QuantizationAwareTraining(bn_fold=True, - per_channel=[True, False], - symmetric=[True, False], - one_conv_fold=False) - net = quantizer.quantize(net) - # get learning rate lr = get_lr(lr_init=config.lr_init, lr_end=0.0, diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/export.py b/model_zoo/official/cv/yolov3_darknet53_quant/export.py index 440833afa43..09856164d2d 100644 --- a/model_zoo/official/cv/yolov3_darknet53_quant/export.py +++ b/model_zoo/official/cv/yolov3_darknet53_quant/export.py @@ -28,7 +28,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") parser.add_argument("--file_name", type=str, default="yolov3_darknet53_quant", help="output file name.") -parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', help='file format') +parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='MINDIR', help='file format') args = parser.parse_args() context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) @@ -50,4 +50,5 @@ if __name__ == "__main__": input_data = Tensor(np.zeros(shape), ms.float32) input_shape = Tensor(tuple(config.test_img_shape), ms.float32) - export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format) + export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format, + quant_mode='QUANT', mean=0., std_dev=48.106)