From 02228d342c190456dfa039d41000ea52e6dd9f27 Mon Sep 17 00:00:00 2001
From: Erpim
Date: Mon, 31 May 2021 15:22:22 +0800
Subject: [PATCH] Modify the lenet_quant ST and fix the neg_trunc bug for
 manual quantization

---
 mindspore/compression/quant/qat.py          |  7 --
 mindspore/nn/layer/quant.py                 | 16 +++-
 tests/st/quantization/lenet_quant/config.py | 13 ---
 tests/st/quantization/lenet_quant/lenet.py  | 79 -------------------
 .../lenet_quant/test_lenet_quant.py         | 53 +-------------
 5 files changed, 18 insertions(+), 150 deletions(-)
 delete mode 100644 tests/st/quantization/lenet_quant/lenet.py

diff --git a/mindspore/compression/quant/qat.py b/mindspore/compression/quant/qat.py
index efcee7cf2ba..e48c4d48d82 100644
--- a/mindspore/compression/quant/qat.py
+++ b/mindspore/compression/quant/qat.py
@@ -518,15 +518,8 @@ class QuantizationAwareTraining(Quantizer):
         """
         act_class = activation.__class__
         act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
-        neg_trunc_act_list = [nn.ReLU, nn.ReLU6]
         act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
 
-        if act_class in neg_trunc_act_list and OptimizeOption.LEARNED_SCALE in self.optimize_option:
-            self.quant_config = self.quant_config._replace(
-                activation=self.quant_config.activation.partial_init(neg_trunc=True, narrow_range=False))
-            return quant.ActQuant(activation=activation,
-                                  quant_config=self.quant_config,
-                                  quant_dtype=self.act_dtype)
         if act_class in act_list:
             return quant.ActQuant(activation=activation,
                                   quant_config=self.quant_config,
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index c1406829b77..5ce3963a195 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -30,6 +30,7 @@ import mindspore.context as context
 from .normalization import BatchNorm2d
 from .activation import get_activation
 from ..cell import Cell
+from ... import nn
 from ...ops.operations import _quant_ops as Q
 
 __all__ = [
@@ -1495,6 +1496,8 @@ class ActQuant(_QuantActivation):
                  quant_config=quant_config_default,
                  quant_dtype=QuantDtype.INT8):
         super(ActQuant, self).__init__()
+        act_class = activation.__class__
+        act_list = [nn.ReLU, nn.ReLU6]
         self.act = Validator.check_isinstance("activation", activation, Cell)
         self.fake_before = Validator.check_bool(fake_before, "fake_before")
         if self.fake_before:
@@ -1503,12 +1506,21 @@ class ActQuant(_QuantActivation):
                                                              ema=ema,
                                                              ema_decay=ema_decay,
                                                              quant_dtype=quant_dtype)
+        self.neg_trunc = False
+        self.narrow_range = False
+        preset_dict = quant_config.activation.p.keywords
+        if 'mode' in preset_dict and preset_dict['mode'] == "LEARNED_SCALE" and act_class in act_list:
+            self.neg_trunc = True
+        elif 'narrow_range' in preset_dict:
+            self.narrow_range = preset_dict['narrow_range']
+
         self.fake_quant_act = quant_config.activation(min_init=-6,
                                                       max_init=6,
                                                       ema=ema,
                                                       ema_decay=ema_decay,
-                                                      quant_dtype=quant_dtype)
-
+                                                      quant_dtype=quant_dtype,
+                                                      neg_trunc=self.neg_trunc,
+                                                      narrow_range=self.narrow_range)
     def construct(self, x):
         if self.fake_before:
             x = self.fake_quant_act_before(x)
diff --git a/tests/st/quantization/lenet_quant/config.py b/tests/st/quantization/lenet_quant/config.py
index 7c4f5a54b7e..1106edfa6da 100644
--- a/tests/st/quantization/lenet_quant/config.py
+++ b/tests/st/quantization/lenet_quant/config.py
@@ -18,19 +18,6 @@ network config setting, will be used in test_lenet_quant.py
 
 from easydict import EasyDict as edict
 
-nonquant_cfg = edict({
-    'num_classes': 10,
-    'lr': 0.01,
-    'momentum': 0.9,
-    'epoch_size': 10,
-    'batch_size': 32,
-    'buffer_size': 1000,
-    'image_height': 32,
-    'image_width': 32,
-    'save_checkpoint_steps': 1875,
-    'keep_checkpoint_max': 10,
-})
-
 quant_cfg = edict({
     'num_classes': 10,
     'lr': 0.01,
diff --git a/tests/st/quantization/lenet_quant/lenet.py b/tests/st/quantization/lenet_quant/lenet.py
deleted file mode 100644
index 42444100073..00000000000
--- a/tests/st/quantization/lenet_quant/lenet.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""LeNet."""
-import mindspore.nn as nn
-from mindspore.common.initializer import TruncatedNormal
-
-
-def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
-    """weight initial for conv layer"""
-    weight = weight_variable()
-    return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=kernel_size, stride=stride, padding=padding,
-                     weight_init=weight, has_bias=False, pad_mode="valid")
-
-
-def fc_with_initialize(input_channels, out_channels):
-    """weight initial for fc layer"""
-    weight = weight_variable()
-    bias = weight_variable()
-    return nn.Dense(input_channels, out_channels, weight, bias)
-
-
-def weight_variable():
-    """weight initial"""
-    return TruncatedNormal(0.02)
-
-
-class LeNet5(nn.Cell):
-    """
-    Lenet network
-
-    Args:
-        num_class (int): Num classes. Default: 10.
-
-    Returns:
-        Tensor, output tensor
-    Examples:
-        >>> LeNet(num_class=10)
-
-    """
-
-    def __init__(self, num_class=10, channel=1):
-        super(LeNet5, self).__init__()
-        self.num_class = num_class
-        self.conv1 = conv(channel, 6, 5)
-        self.conv2 = conv(6, 16, 5)
-        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
-        self.fc2 = fc_with_initialize(120, 84)
-        self.fc3 = fc_with_initialize(84, self.num_class)
-        self.relu = nn.ReLU()
-        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flatten = nn.Flatten()
-
-    def construct(self, x):
-        x = self.conv1(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.conv2(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.flatten(x)
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        x = self.relu(x)
-        x = self.fc3(x)
-        return x
diff --git a/tests/st/quantization/lenet_quant/test_lenet_quant.py b/tests/st/quantization/lenet_quant/test_lenet_quant.py
index f0d593a2230..cc2b81f788f 100644
--- a/tests/st/quantization/lenet_quant/test_lenet_quant.py
+++ b/tests/st/quantization/lenet_quant/test_lenet_quant.py
@@ -23,68 +23,25 @@ from mindspore import Tensor
 from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 from mindspore.nn.metrics import Accuracy
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore import load_checkpoint, load_param_into_net, export
 from mindspore.train import Model
 from mindspore.compression.quant import QuantizationAwareTraining
 from mindspore.compression.quant.quantizer import OptimizeOption
 from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
 from dataset import create_dataset
-from config import nonquant_cfg, quant_cfg
-from lenet import LeNet5
+from config import quant_cfg
 from lenet_fusion import LeNet5 as LeNet5Fusion
 import numpy as np
 
 device_target = 'GPU'
 data_path = "/home/workspace/mindspore_dataset/mnist"
-
-
-def train_lenet():
-    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
-    cfg = nonquant_cfg
-    ds_train = create_dataset(os.path.join(data_path, "train"),
-                              cfg.batch_size)
-
-    network = LeNet5(cfg.num_classes)
-    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
-    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="ckpt_lenet_noquant", config=config_ck)
-    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-
-    print("============== Starting Training Lenet==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
-                dataset_sink_mode=True)
-
-
-def eval_lenet():
-    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
-    cfg = nonquant_cfg
-    ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
-    ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
-    # define fusion network
-    network = LeNet5(cfg.num_classes)
-    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
-    # call back and monitor
-    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-    # load quantization aware network checkpoint
-    param_dict = load_checkpoint(ckpt_path)
-    not_load_param = load_param_into_net(network, param_dict)
-    if not_load_param:
-        raise ValueError("Load param into net fail!")
-
-    print("============== Starting Testing ==============")
-    acc = model.eval(ds_eval, dataset_sink_mode=True)
-    print("============== {} ==============".format(acc))
-
+lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"
 
 
 def train_lenet_quant(optim_option="QAT"):
     context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
     cfg = quant_cfg
-    ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
+    ckpt_path = lenet_ckpt_path
     ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
     step_size = ds_train.get_dataset_size()
@@ -211,8 +168,6 @@ def export_lenet(optim_option="QAT"):
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_lenet_quant():
-    train_lenet()
-    eval_lenet()
     train_lenet_quant()
     eval_quant()
     export_lenet()
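
Reviewer note: the quant.py hunks move the neg_trunc decision out of
QuantizationAwareTraining and into ActQuant.__init__ itself, so networks that
instantiate quant cells by hand (manual quantization) now get the one-sided
fake-quant too, which is the bug named in the subject. The sketch below is
illustrative only and not part of the patch: resolve_fake_quant_flags is a
hypothetical helper, and its dict argument stands in for the
quant_config.activation.p.keywords preset that ActQuant actually inspects.

    import mindspore.nn as nn

    def resolve_fake_quant_flags(preset_dict, act_class):
        # Mirrors the new ActQuant.__init__ logic: ReLU/ReLU6 outputs are
        # non-negative, so under the LEARNED_SCALE mode the fake-quant node
        # truncates the negative half of the range (neg_trunc) instead of
        # spending integer codes on values that can never occur.
        neg_trunc, narrow_range = False, False
        if preset_dict.get('mode') == "LEARNED_SCALE" and act_class in (nn.ReLU, nn.ReLU6):
            neg_trunc = True
        elif 'narrow_range' in preset_dict:
            narrow_range = preset_dict['narrow_range']
        return neg_trunc, narrow_range

    assert resolve_fake_quant_flags({'mode': "LEARNED_SCALE"}, nn.ReLU) == (True, False)
    assert resolve_fake_quant_flags({'narrow_range': True}, nn.Sigmoid) == (False, True)

With the decision made inside ActQuant, the neg_trunc_act_list branch deleted
from qat.py became redundant: the same activations are detected from the quant
config preset regardless of how the quant network was built.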