forked from mindspore-Ecosystem/mindspore
Modified the ST of lenet_quant, and fix the neg_trunc bug for manual quantization
This commit is contained in:
parent
22e55a5193
commit
02228d342c
|
@ -518,15 +518,8 @@ class QuantizationAwareTraining(Quantizer):
|
|||
"""
|
||||
act_class = activation.__class__
|
||||
act_list = [nn.ReLU, nn.ReLU6, nn.Sigmoid]
|
||||
neg_trunc_act_list = [nn.ReLU, nn.ReLU6]
|
||||
act_list_with_fake_before = [nn.LeakyReLU, nn.HSigmoid, nn.HSwish]
|
||||
|
||||
if act_class in neg_trunc_act_list and OptimizeOption.LEARNED_SCALE in self.optimize_option:
|
||||
self.quant_config = self.quant_config._replace(
|
||||
activation=self.quant_config.activation.partial_init(neg_trunc=True, narrow_range=False))
|
||||
return quant.ActQuant(activation=activation,
|
||||
quant_config=self.quant_config,
|
||||
quant_dtype=self.act_dtype)
|
||||
if act_class in act_list:
|
||||
return quant.ActQuant(activation=activation,
|
||||
quant_config=self.quant_config,
|
||||
|
|
|
@ -30,6 +30,7 @@ import mindspore.context as context
|
|||
from .normalization import BatchNorm2d
|
||||
from .activation import get_activation
|
||||
from ..cell import Cell
|
||||
from ... import nn
|
||||
from ...ops.operations import _quant_ops as Q
|
||||
|
||||
__all__ = [
|
||||
|
@ -1495,6 +1496,8 @@ class ActQuant(_QuantActivation):
|
|||
quant_config=quant_config_default,
|
||||
quant_dtype=QuantDtype.INT8):
|
||||
super(ActQuant, self).__init__()
|
||||
act_class = activation.__class__
|
||||
act_list = [nn.ReLU, nn.ReLU6]
|
||||
self.act = Validator.check_isinstance("activation", activation, Cell)
|
||||
self.fake_before = Validator.check_bool(fake_before, "fake_before")
|
||||
if self.fake_before:
|
||||
|
@ -1503,12 +1506,21 @@ class ActQuant(_QuantActivation):
|
|||
ema=ema,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
self.neg_trunc = False
|
||||
self.narrow_range = False
|
||||
preset_dict = quant_config.activation.p.keywords
|
||||
if 'mode' in preset_dict and preset_dict['mode'] == "LEARNED_SCALE" and act_class in act_list:
|
||||
self.neg_trunc = True
|
||||
elif 'narrow_range' in preset_dict:
|
||||
self.narrow_range = preset_dict['narrow_range']
|
||||
|
||||
self.fake_quant_act = quant_config.activation(min_init=-6,
|
||||
max_init=6,
|
||||
ema=ema,
|
||||
ema_decay=ema_decay,
|
||||
quant_dtype=quant_dtype)
|
||||
|
||||
quant_dtype=quant_dtype,
|
||||
neg_trunc=self.neg_trunc,
|
||||
narrow_range=self.narrow_range)
|
||||
def construct(self, x):
|
||||
if self.fake_before:
|
||||
x = self.fake_quant_act_before(x)
|
||||
|
|
|
@ -18,19 +18,6 @@ network config setting, will be used in test_lenet_quant.py
|
|||
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
nonquant_cfg = edict({
|
||||
'num_classes': 10,
|
||||
'lr': 0.01,
|
||||
'momentum': 0.9,
|
||||
'epoch_size': 10,
|
||||
'batch_size': 32,
|
||||
'buffer_size': 1000,
|
||||
'image_height': 32,
|
||||
'image_width': 32,
|
||||
'save_checkpoint_steps': 1875,
|
||||
'keep_checkpoint_max': 10,
|
||||
})
|
||||
|
||||
quant_cfg = edict({
|
||||
'num_classes': 10,
|
||||
'lr': 0.01,
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""LeNet."""
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
"""weight initial for conv layer"""
|
||||
weight = weight_variable()
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride, padding=padding,
|
||||
weight_init=weight, has_bias=False, pad_mode="valid")
|
||||
|
||||
|
||||
def fc_with_initialize(input_channels, out_channels):
|
||||
"""weight initial for fc layer"""
|
||||
weight = weight_variable()
|
||||
bias = weight_variable()
|
||||
return nn.Dense(input_channels, out_channels, weight, bias)
|
||||
|
||||
|
||||
def weight_variable():
|
||||
"""weight initial"""
|
||||
return TruncatedNormal(0.02)
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
"""
|
||||
Lenet network
|
||||
|
||||
Args:
|
||||
num_class (int): Num classes. Default: 10.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor
|
||||
Examples:
|
||||
>>> LeNet(num_class=10)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_class=10, channel=1):
|
||||
super(LeNet5, self).__init__()
|
||||
self.num_class = num_class
|
||||
self.conv1 = conv(channel, 6, 5)
|
||||
self.conv2 = conv(6, 16, 5)
|
||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||
self.fc2 = fc_with_initialize(120, 84)
|
||||
self.fc3 = fc_with_initialize(84, self.num_class)
|
||||
self.relu = nn.ReLU()
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flatten = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flatten(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
|
@ -23,68 +23,25 @@ from mindspore import Tensor
|
|||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore import load_checkpoint, load_param_into_net, export
|
||||
from mindspore.train import Model
|
||||
from mindspore.compression.quant import QuantizationAwareTraining
|
||||
from mindspore.compression.quant.quantizer import OptimizeOption
|
||||
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
|
||||
from dataset import create_dataset
|
||||
from config import nonquant_cfg, quant_cfg
|
||||
from lenet import LeNet5
|
||||
from config import quant_cfg
|
||||
from lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
import numpy as np
|
||||
|
||||
device_target = 'GPU'
|
||||
data_path = "/home/workspace/mindspore_dataset/mnist"
|
||||
|
||||
|
||||
def train_lenet():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
cfg = nonquant_cfg
|
||||
ds_train = create_dataset(os.path.join(data_path, "train"),
|
||||
cfg.batch_size)
|
||||
|
||||
network = LeNet5(cfg.num_classes)
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="ckpt_lenet_noquant", config=config_ck)
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
print("============== Starting Training Lenet==============")
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
|
||||
dataset_sink_mode=True)
|
||||
|
||||
|
||||
def eval_lenet():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
cfg = nonquant_cfg
|
||||
ds_eval = create_dataset(os.path.join(data_path, "test"), cfg.batch_size, 1)
|
||||
ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
|
||||
# define fusion network
|
||||
network = LeNet5(cfg.num_classes)
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
# call back and monitor
|
||||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
not_load_param = load_param_into_net(network, param_dict)
|
||||
if not_load_param:
|
||||
raise ValueError("Load param into net fail!")
|
||||
|
||||
print("============== Starting Testing ==============")
|
||||
acc = model.eval(ds_eval, dataset_sink_mode=True)
|
||||
print("============== {} ==============".format(acc))
|
||||
|
||||
lenet_ckpt_path = "/home/workspace/mindspore_dataset/checkpoint/lenet/ckpt_lenet_noquant-10_1875.ckpt"
|
||||
|
||||
def train_lenet_quant(optim_option="QAT"):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
|
||||
cfg = quant_cfg
|
||||
ckpt_path = './ckpt_lenet_noquant-10_1875.ckpt'
|
||||
ckpt_path = lenet_ckpt_path
|
||||
ds_train = create_dataset(os.path.join(data_path, "train"), cfg.batch_size, 1)
|
||||
step_size = ds_train.get_dataset_size()
|
||||
|
||||
|
@ -211,8 +168,6 @@ def export_lenet(optim_option="QAT"):
|
|||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_lenet_quant():
|
||||
train_lenet()
|
||||
eval_lenet()
|
||||
train_lenet_quant()
|
||||
eval_quant()
|
||||
export_lenet()
|
||||
|
|
Loading…
Reference in New Issue