!17569 Modified the ST of lenet_quant, and fix the neg_trunc bug for manual quantization
From: @erpim Reviewed-by: @zhang__sss,@zlq2020,@zh_qh Signed-off-by: @zlq2020
This commit is contained in:
commit
475386e338
|
@ -518,15 +518,8 @@ class QuantizationAwareTraining(Quantizer):
|
||||||
"""
|
"""
|
||||||
act_class = activation.__class__
|
act_class = activation.__class__
|
||||||
act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
|
act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
|
||||||
neg_trunc_act_list = [nn.ReLU, nn.ReLU6]
|
|
||||||
act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
|
act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
|
||||||
|
|
||||||
if act_class in neg_trunc_act_list and OptimizeOption.LEARNED_SCALE in self.optimize_option:
|
|
||||||
self.quant_config = self.quant_config._replace(
|
|
||||||
activation=self.quant_config.activation.partial_init(neg_trunc=True, narrow_range=False))
|
|
||||||
return quant.ActQuant(activation=activation,
|
|
||||||
quant_config=self.quant_config,
|
|
||||||
quant_dtype=self.act_dtype)
|
|
||||||
if act_class in act_list:
|
if act_class in act_list:
|
||||||
return quant.ActQuant(activation=activation,
|
return quant.ActQuant(activation=activation,
|
||||||
quant_config=self.quant_config,
|
quant_config=self.quant_config,
|
||||||
|
|
|
@ -29,6 +29,7 @@ import mindspore.context as context
|
||||||
from .normalization import BatchNorm2d
|
from .normalization import BatchNorm2d
|
||||||
from .activation import get_activation
|
from .activation import get_activation
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
|
from ... import nn
|
||||||
from ...ops.operations import _quant_ops as Q
|
from ...ops.operations import _quant_ops as Q
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -1494,6 +1495,8 @@ class ActQuant(_QuantActivation):
|
||||||
quant_config=quant_config_default,
|
quant_config=quant_config_default,
|
||||||
quant_dtype=QuantDtype.INT8):
|
quant_dtype=QuantDtype.INT8):
|
||||||
super(ActQuant, self).__init__()
|
super(ActQuant, self).__init__()
|
||||||
|
act_class = activation.__class__
|
||||||
|
act_list = [nn.ReLU, nn.ReLU6]
|
||||||
self.act = Validator.check_isinstance("activation", activation, Cell)
|
self.act = Validator.check_isinstance("activation", activation, Cell)
|
||||||
self.fake_before = Validator.check_bool(fake_before, "fake_before")
|
self.fake_before = Validator.check_bool(fake_before, "fake_before")
|
||||||
if self.fake_before:
|
if self.fake_before:
|
||||||
|
@ -1502,12 +1505,21 @@ class ActQuant(_QuantActivation):
|
||||||
ema=ema,
|
ema=ema,
|
||||||
ema_decay=ema_decay,
|
ema_decay=ema_decay,
|
||||||
quant_dtype=quant_dtype)
|
quant_dtype=quant_dtype)
|
||||||
|
self.neg_trunc = False
|
||||||
|
self.narrow_range = False
|
||||||
|
preset_dict = quant_config.activation.p.keywords
|
||||||
|
if 'mode' in preset_dict and preset_dict['mode'] == "LEARNED_SCALE" and act_class in act_list:
|
||||||
|
self.neg_trunc = True
|
||||||
|
elif 'narrow_range' in preset_dict:
|
||||||
|
self.narrow_range = preset_dict['narrow_range']
|
||||||
|
|
||||||
self.fake_quant_act = quant_config.activation(min_init=-6,
|
self.fake_quant_act = quant_config.activation(min_init=-6,
|
||||||
max_init=6,
|
max_init=6,
|
||||||
ema=ema,
|
ema=ema,
|
||||||
ema_decay=ema_decay,
|
ema_decay=ema_decay,
|
||||||
quant_dtype=quant_dtype)
|
quant_dtype=quant_dtype,
|
||||||
|
neg_trunc=self.neg_trunc,
|
||||||
|
narrow_range=self.narrow_range)
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
if self.fake_before:
|
if self.fake_before:
|
||||||
x = self.fake_quant_act_before(x)
|
x = self.fake_quant_act_before(x)
|
||||||
|
|
|
@ -18,19 +18,6 @@ network config setting, will be used in test_lenet_quant.py
|
||||||
|
|
||||||
from easydict import EasyDict as edict
|
from easydict import EasyDict as edict
|
||||||
|
|
||||||
nonquant_cfg = edict({
|
|
||||||
'num_classes': 10,
|
|
||||||
'lr': 0.01,
|
|
||||||
'momentum': 0.9,
|
|
||||||
'epoch_size': 10,
|
|
||||||
'batch_size': 32,
|
|
||||||
'buffer_size': 1000,
|
|
||||||
'image_height': 32,
|
|
||||||
'image_width': 32,
|
|
||||||
'save_checkpoint_steps': 1875,
|
|
||||||
'keep_checkpoint_max': 10,
|
|
||||||
})
|
|
||||||
|
|
||||||
quant_cfg = edict({
|
quant_cfg = edict({
|
||||||
'num_classes': 10,
|
'num_classes': 10,
|
||||||
'lr': 0.01,
|
'lr': 0.01,
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""LeNet."""
|
|
||||||
import mindspore.nn as nn
|
|
||||||
from mindspore.common.initializer import TruncatedNormal
|
|
||||||
|
|
||||||
|
|
||||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
|
||||||
"""weight initial for conv layer"""
|
|
||||||
weight = weight_variable()
|
|
||||||
return nn.Conv2d(in_channels, out_channels,
|
|
||||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
|
||||||
weight_init=weight, has_bias=False, pad_mode="valid")
|
|
||||||
|
|
||||||
|
|
||||||
def fc_with_initialize(input_channels, out_channels):
|
|
||||||
"""weight initial for fc layer"""
|
|
||||||
weight = weight_variable()
|
|
||||||
bias = weight_variable()
|
|
||||||
return nn.Dense(input_channels, out_channels, weight, bias)
|
|
||||||
|
|
||||||
|
|
||||||
def weight_variable():
|
|
||||||
"""weight initial"""
|
|
||||||
return TruncatedNormal(0.02)
|
|
||||||
|
|
||||||
|
|
||||||
class LeNet5(nn.Cell):
|
|
||||||
"""
|
|
||||||
Lenet network
|
|
||||||
|
|
||||||
Args:
|
|
||||||
num_class (int): Num classes. Default: 10.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor, output tensor
|
|
||||||
Examples:
|
|
||||||
>>> LeNet(num_class=10)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, num_class=10, channel=1):
|
|
||||||
super(LeNet5, self).__init__()
|
|
||||||
self.num_class = num_class
|
|
||||||
self.conv1 = conv(channel, 6, 5)
|
|
||||||
self.conv2 = conv(6, 16, 5)
|
|
||||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
|
||||||
self.fc2 = fc_with_initialize(120, 84)
|
|
||||||
self.fc3 = fc_with_initialize(84, self.num_class)
|
|
||||||
self.relu = nn.ReLU()
|
|
||||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
|
||||||
self.flatten = nn.Flatten()
|
|
||||||
|
|
||||||
def construct(self, x):
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.max_pool2d(x)
|
|
||||||
x = self.conv2(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.max_pool2d(x)
|
|
||||||
x = self.flatten(x)
|
|
||||||
x = self.fc1(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.fc2(x)
|
|
||||||
x = self.relu(x)
|
|
||||||
x = self.fc3(x)
|
|
||||||
return x
|
|
|
@ -23,68 +23,25 @@ from mindspore import Tensor
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.nn.metrics import Accuracy
|
from mindspore.nn.metrics import Accuracy
|
||||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||||
from mindspore import load_checkpoint, load_param_into_net, export
|
from mindspore import load_checkpoint, load_param_into_net, export
|
||||||
from mindspore.train import Model
|
from mindspore.train import Model
|
||||||
from mindspore.compression.quant import QuantizationAwareTraining
|
from mindspore.compression.quant import QuantizationAwareTraining
|
||||||
from mindspore.compression.quant.quantizer import OptimizeOption
|
from mindspore.compression.quant.quantizer import OptimizeOption
|
||||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||||
from dataset import create_dataset
|
from dataset import create_dataset
|
||||||
from config import nonquant_cfg, quant_cfg
|
from config import quant_cfg
|
||||||
from lenet import LeNet5
|
|
||||||
from lenet_fusion import LeNet5 as LeNet5Fusion
|
from lenet_fusion import LeNet5 as LeNet5Fusion
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
device_target = 'GPU'
|
device_target = 'GPU'
|
||||||
data_path = "/home/workspace/mindspore_dataset/mnist"
|
data_path = "/home/workspace/mindspore_dataset/mnist"
|
||||||
|
lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"
|
||||||
|
|
||||||
def train_lenet():
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
|
||||||
cfg = nonquant_cfg
|
|
||||||
ds_train = create_dataset(os.path.join(data_path, "train"),
|
|
||||||
cfg.batch_size)
|
|
||||||
|
|
||||||
network = LeNet5(cfg.num_classes)
|
|
||||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
|
||||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
|
||||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
|
||||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
|
||||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
|
||||||
ckpoint_cb = ModelCheckpoint(prefix="ckpt_lenet_noquant", config=config_ck)
|
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
|
||||||
|
|
||||||
print("============== Starting Training Lenet==============")
|
|
||||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
|
||||||
dataset_sink_mode=True)
|
|
||||||
|
|
||||||
|
|
||||||
def eval_lenet():
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
|
||||||
cfg = nonquant_cfg
|
|
||||||
ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
|
|
||||||
ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
|
|
||||||
# define fusion network
|
|
||||||
network = LeNet5(cfg.num_classes)
|
|
||||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
|
||||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
|
||||||
# call back and monitor
|
|
||||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
|
||||||
# load quantization aware network checkpoint
|
|
||||||
param_dict = load_checkpoint(ckpt_path)
|
|
||||||
not_load_param = load_param_into_net(network, param_dict)
|
|
||||||
if not_load_param:
|
|
||||||
raise ValueError("Load param into net fail!")
|
|
||||||
|
|
||||||
print("============== Starting Testing ==============")
|
|
||||||
acc = model.eval(ds_eval, dataset_sink_mode=True)
|
|
||||||
print("============== {} ==============".format(acc))
|
|
||||||
|
|
||||||
|
|
||||||
def train_lenet_quant(optim_option="QAT"):
|
def train_lenet_quant(optim_option="QAT"):
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||||
cfg = quant_cfg
|
cfg = quant_cfg
|
||||||
ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
|
ckpt_path = lenet_ckpt_path
|
||||||
ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
|
ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
|
||||||
step_size = ds_train.get_dataset_size()
|
step_size = ds_train.get_dataset_size()
|
||||||
|
|
||||||
|
@ -211,8 +168,6 @@ def export_lenet(optim_option="QAT"):
|
||||||
@pytest.mark.platform_x86_gpu_training
|
@pytest.mark.platform_x86_gpu_training
|
||||||
@pytest.mark.env_onecard
|
@pytest.mark.env_onecard
|
||||||
def test_lenet_quant():
|
def test_lenet_quant():
|
||||||
train_lenet()
|
|
||||||
eval_lenet()
|
|
||||||
train_lenet_quant()
|
train_lenet_quant()
|
||||||
eval_quant()
|
eval_quant()
|
||||||
export_lenet()
|
export_lenet()
|
||||||
|
|
Loading…
Reference in New Issue