Modify the ST of lenet_quant and fix the neg_trunc bug for manual quantization

Erpim 2021-05-31 15:22:22 +08:00
parent 22e55a5193
commit 02228d342c
5 changed files with 18 additions and 150 deletions
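
Background for the fix: neg_trunc applies to activations that can never emit negative values (ReLU, ReLU6). With it, the fake quantizer drops the negative half of the range and spends every integer level on [0, max]; without it, a manually quantized network wastes half its levels on outputs that never occur. A minimal NumPy sketch of the idea (the fake_quant helper below is illustrative only, not MindSpore's implementation):

import numpy as np

def fake_quant(x, min_val, max_val, num_bits=8):
    """Affine fake quantization: round to integer levels, then dequantize."""
    qmax = 2 ** num_bits - 1
    scale = (max_val - min_val) / qmax
    zero_point = round(-min_val / scale)
    q = np.clip(np.round(x / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale

x = np.maximum(np.random.randn(8) * 2, 0)  # a ReLU output, never negative
y_trunc = fake_quant(x, 0.0, 6.0)   # neg_trunc: all 256 levels cover [0, 6]
y_full = fake_quant(x, -6.0, 6.0)   # without it, half the levels sit below 0, unused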

View File

@@ -518,15 +518,8 @@ class QuantizationAwareTraining(Quantizer):
         """
         act_class = activation.__class__
         act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
-        neg_trunc_act_list = [nn.ReLU, nn.ReLU6]
         act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
-        if act_class in neg_trunc_act_list and OptimizeOption.LEARNED_SCALE in self.optimize_option:
-            self.quant_config = self.quant_config._replace(
-                activation=self.quant_config.activation.partial_init(neg_trunc=True, narrow_range=False))
-            return quant.ActQuant(activation=activation,
-                                  quant_config=self.quant_config,
-                                  quant_dtype=self.act_dtype)
         if act_class in act_list:
             return quant.ActQuant(activation=activation,
                                   quant_config=self.quant_config,

View File

@@ -30,6 +30,7 @@ import mindspore.context as context
 from .normalization import BatchNorm2d
 from .activation import get_activation
 from ..cell import Cell
+from ... import nn
 from ...ops.operations import _quant_ops as Q
 __all__ = [
@@ -1495,6 +1496,8 @@ class ActQuant(_QuantActivation):
                  quant_config=quant_config_default,
                  quant_dtype=QuantDtype.INT8):
         super(ActQuant, self).__init__()
+        act_class = activation.__class__
+        act_list = [nn.ReLU, nn.ReLU6]
         self.act = Validator.check_isinstance("activation", activation, Cell)
         self.fake_before = Validator.check_bool(fake_before, "fake_before")
         if self.fake_before:
@@ -1503,12 +1506,21 @@ class ActQuant(_QuantActivation):
                                                                  ema=ema,
                                                                  ema_decay=ema_decay,
                                                                  quant_dtype=quant_dtype)
+        self.neg_trunc = False
+        self.narrow_range = False
+        preset_dict = quant_config.activation.p.keywords
+        if 'mode' in preset_dict and preset_dict['mode'] == "LEARNED_SCALE" and act_class in act_list:
+            self.neg_trunc = True
+        elif 'narrow_range' in preset_dict:
+            self.narrow_range = preset_dict['narrow_range']
         self.fake_quant_act = quant_config.activation(min_init=-6,
                                                       max_init=6,
                                                       ema=ema,
                                                       ema_decay=ema_decay,
-                                                      quant_dtype=quant_dtype)
+                                                      quant_dtype=quant_dtype,
+                                                      neg_trunc=self.neg_trunc,
+                                                      narrow_range=self.narrow_range)

     def construct(self, x):
         if self.fake_before:
             x = self.fake_quant_act_before(x)
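
With this change, ActQuant no longer relies on QuantizationAwareTraining to patch the config: it reads the partial-init keywords of quant_config.activation itself and enables neg_trunc when the mode is "LEARNED_SCALE" and the wrapped activation is ReLU or ReLU6, which is what fixes the manual-quantization path. A hedged usage sketch (assuming create_quant_config forwards mode into those keywords; verify the signature in your MindSpore version):

from mindspore import nn
from mindspore.compression.quant import create_quant_config
from mindspore.nn.layer.quant import ActQuant

# Hand-built config, no Quantizer involved; mode ends up in the
# partial-init keywords that ActQuant now inspects on its own.
quant_config = create_quant_config(mode="LEARNED_SCALE")

# ReLU is in the neg_trunc list, so the fake quantizer covers only [0, max].
act = ActQuant(activation=nn.ReLU(), quant_config=quant_config)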

View File

@@ -18,19 +18,6 @@ network config setting, will be used in test_lenet_quant.py
 from easydict import EasyDict as edict

-nonquant_cfg = edict({
-    'num_classes': 10,
-    'lr': 0.01,
-    'momentum': 0.9,
-    'epoch_size': 10,
-    'batch_size': 32,
-    'buffer_size': 1000,
-    'image_height': 32,
-    'image_width': 32,
-    'save_checkpoint_steps': 1875,
-    'keep_checkpoint_max': 10,
-})
 quant_cfg = edict({
     'num_classes': 10,
     'lr': 0.01,

View File

@@ -1,79 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""LeNet."""
-import mindspore.nn as nn
-from mindspore.common.initializer import TruncatedNormal
-
-
-def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
-    """weight initial for conv layer"""
-    weight = weight_variable()
-    return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=kernel_size, stride=stride, padding=padding,
-                     weight_init=weight, has_bias=False, pad_mode="valid")
-
-
-def fc_with_initialize(input_channels, out_channels):
-    """weight initial for fc layer"""
-    weight = weight_variable()
-    bias = weight_variable()
-    return nn.Dense(input_channels, out_channels, weight, bias)
-
-
-def weight_variable():
-    """weight initial"""
-    return TruncatedNormal(0.02)
-
-
-class LeNet5(nn.Cell):
-    """
-    Lenet network
-
-    Args:
-        num_class (int): Num classes. Default: 10.
-
-    Returns:
-        Tensor, output tensor
-
-    Examples:
-        >>> LeNet(num_class=10)
-    """
-    def __init__(self, num_class=10, channel=1):
-        super(LeNet5, self).__init__()
-        self.num_class = num_class
-        self.conv1 = conv(channel, 6, 5)
-        self.conv2 = conv(6, 16, 5)
-        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
-        self.fc2 = fc_with_initialize(120, 84)
-        self.fc3 = fc_with_initialize(84, self.num_class)
-        self.relu = nn.ReLU()
-        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flatten = nn.Flatten()
-
-    def construct(self, x):
-        x = self.conv1(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.conv2(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.flatten(x)
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        x = self.relu(x)
-        x = self.fc3(x)
-        return x

View File

@@ -23,68 +23,25 @@ from mindspore import Tensor
 from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 from mindspore.nn.metrics import Accuracy
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore import load_checkpoint, load_param_into_net, export
 from mindspore.train import Model
 from mindspore.compression.quant import QuantizationAwareTraining
 from mindspore.compression.quant.quantizer import OptimizeOption
 from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
 from dataset import create_dataset
-from config import nonquant_cfg, quant_cfg
-from lenet import LeNet5
+from config import quant_cfg
 from lenet_fusion import LeNet5 as LeNet5Fusion
 import numpy as np

 device_target = 'GPU'
 data_path = "/home/workspace/mindspore_dataset/mnist"

-def train_lenet():
-    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
-    cfg = nonquant_cfg
-    ds_train = create_dataset(os.path.join(data_path, "train"),
-                              cfg.batch_size)
-    network = LeNet5(cfg.num_classes)
-    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
-    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
-    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
-                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix="ckpt_lenet_noquant", config=config_ck)
-    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-    print("============== Starting Training Lenet==============")
-    model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
-                dataset_sink_mode=True)
-
-
-def eval_lenet():
-    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
-    cfg = nonquant_cfg
-    ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
-    ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
-    # define fusion network
-    network = LeNet5(cfg.num_classes)
-    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
-    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
-    # call back and monitor
-    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
-    # load quantization aware network checkpoint
-    param_dict = load_checkpoint(ckpt_path)
-    not_load_param = load_param_into_net(network, param_dict)
-    if not_load_param:
-        raise ValueError("Load param into net fail!")
-    print("============== Starting Testing ==============")
-    acc = model.eval(ds_eval, dataset_sink_mode=True)
-    print("============== {} ==============".format(acc))
-
+lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"

 def train_lenet_quant(optim_option="QAT"):
     context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
     cfg = quant_cfg
-    ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
+    ckpt_path = lenet_ckpt_path
     ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
     step_size = ds_train.get_dataset_size()
@@ -211,8 +168,6 @@ def export_lenet(optim_option="QAT"):
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_lenet_quant():
-    train_lenet()
-    eval_lenet()
     train_lenet_quant()
     eval_quant()
     export_lenet()
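
Net effect on the ST: the float train/eval stages (train_lenet, eval_lenet) are gone, and the quant stages start from the pre-trained checkpoint at lenet_ckpt_path. A condensed sketch of the remaining flow (the QuantizationAwareTraining arguments are illustrative, taken from the typical lenet_quant setup rather than from this diff):

def train_lenet_quant(optim_option="QAT"):
    context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
    cfg = quant_cfg
    ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)

    # Build the fusion network, then convert it to a quantization-aware one.
    network = LeNet5Fusion(cfg.num_classes)
    quantizer = QuantizationAwareTraining(quant_delay=900, bn_fold=False,
                                          per_channel=[True, False],
                                          symmetric=[True, False])
    network = quantizer.quantize(network)

    # Load the published float checkpoint instead of training one in the test.
    param_dict = load_checkpoint(lenet_ckpt_path)
    load_nonquant_param_into_quant_net(network, param_dict)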