forked from mindspore-Ecosystem/mindspore
bug fix in auto create quant graph in master
This commit is contained in:
parent
cf6dd99ed7
commit
9b7a426c6b
|
@ -1193,9 +1193,9 @@ class QuantBlock(Cell):
|
|||
self.dequant = dequant_op
|
||||
self.dequant_scale = dequant_scale
|
||||
self.bias = bias
|
||||
self.has_bias = bias is None
|
||||
self.has_bias = bias is not None
|
||||
self.activation = activation
|
||||
self.has_act = activation is None
|
||||
self.has_act = activation is not None
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
def construct(self, x):
|
||||
|
|
|
@ -86,7 +86,7 @@ class LossMonitor(Callback):
|
|||
|
||||
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
||||
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
|
||||
"loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
|
||||
"loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}ms]".format(
|
||||
cb_params.cur_epoch_num, cb_params.epoch_num,
|
||||
cur_step_in_epoch, int(cb_params.batch_num),
|
||||
step_loss, np.mean(self.losses),
|
||||
|
|
|
@ -33,7 +33,6 @@ from ...ops.operations import _inner_ops as inner
|
|||
from ...train import serialization
|
||||
from . import quant_utils
|
||||
|
||||
|
||||
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
|
||||
nn.ReLU6: quant.ReLU6Quant,
|
||||
nn.HSigmoid: quant.HSigmoidQuant,
|
||||
|
@ -178,7 +177,6 @@ class ConvertToQuantNetwork:
|
|||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=1 - bn_inner.momentum,
|
||||
quant_delay=self.weight_qdelay,
|
||||
freeze_bn=self.freeze_bn,
|
||||
per_channel=self.weight_channel,
|
||||
|
@ -268,16 +266,16 @@ class ConvertToQuantNetwork:
|
|||
narrow_range=self.act_range)
|
||||
|
||||
|
||||
class ExportQuantNetworkDeploy:
|
||||
class ExportToQuantInferNetwork:
|
||||
"""
|
||||
Convert quantization aware network to deploy network.
|
||||
Convert quantization aware network to infer network.
|
||||
|
||||
Args:
|
||||
network (Cell): MindSpore network produced by `convert_quant_network`.
|
||||
inputs (Tensor): Inputs of the `network`.
|
||||
network (Cell): MindSpore network API `convert_quant_network`.
|
||||
inputs (Tensor): Input tensors of the `quantization aware training network`.
|
||||
|
||||
Returns:
|
||||
Cell, converted network.
|
||||
Cell, GEIR backend Infer network.
|
||||
"""
|
||||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
|
@ -287,7 +285,7 @@ class ExportQuantNetworkDeploy:
|
|||
network = validator.check_isinstance('network', network, (nn.Cell,))
|
||||
self.data_type = mstype.int8
|
||||
self.network = copy.deepcopy(network)
|
||||
self.all_paramters = {p.name: p for p in self.network.get_parameters()}
|
||||
self.all_parameters = {p.name: p for p in self.network.get_parameters()}
|
||||
self.get_inputs_table(inputs)
|
||||
|
||||
def get_inputs_table(self, inputs):
|
||||
|
@ -315,8 +313,8 @@ class ExportQuantNetworkDeploy:
|
|||
info = self.quant_info_table.get(w_minq_name, None)
|
||||
if info:
|
||||
fack_quant_a_in_op, minq_name = info
|
||||
maxq = self.all_paramters[minq_name[:-4] + "maxq"]
|
||||
minq = self.all_paramters[minq_name]
|
||||
maxq = self.all_parameters[minq_name[:-4] + "maxq"]
|
||||
minq = self.all_parameters[minq_name]
|
||||
scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
|
||||
else:
|
||||
logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
|
||||
|
@ -357,7 +355,7 @@ class ExportQuantNetworkDeploy:
|
|||
return block
|
||||
|
||||
def _convert_quant2deploy(self, network):
|
||||
"""Convet network's all quant subcell to deploy subcell."""
|
||||
"""Convert network's all quant subcell to deploy subcell."""
|
||||
cells = network.name_cells()
|
||||
change = False
|
||||
for name in cells:
|
||||
|
@ -395,18 +393,26 @@ class ExportQuantNetworkDeploy:
|
|||
return network
|
||||
|
||||
|
||||
def export_geir(network, *inputs, file_name):
|
||||
def export(network, *inputs, file_name, file_format='GEIR'):
|
||||
"""
|
||||
Exports MindSpore quant predict model to deploy with GEIR.
|
||||
Exports MindSpore quantization predict model to deploy with GEIR.
|
||||
|
||||
Args:
|
||||
network (Cell): MindSpore network produced by `convert_quant_network`.
|
||||
inputs (Tensor): Inputs of the `network`.
|
||||
inputs (Tensor): Inputs of the `quantization aware training network`.
|
||||
file_name (str): File name of model to export.
|
||||
file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
|
||||
- GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
|
||||
"""
|
||||
exporter = ExportQuantNetworkDeploy(network, *inputs)
|
||||
deploy_net = exporter.run()
|
||||
serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")
|
||||
supported_formats = ['GEIR']
|
||||
|
||||
if file_format not in supported_formats:
|
||||
raise ValueError('Illegal file format {}.'.format(file_format))
|
||||
|
||||
if file_format == 'GEIR':
|
||||
exporter = ExportToQuantInferNetwork(network, *inputs)
|
||||
deploy_net = exporter.run()
|
||||
serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
|
||||
|
||||
|
||||
def convert_quant_network(network,
|
||||
|
@ -443,6 +449,7 @@ def convert_quant_network(network,
|
|||
Cell, Network which has change to quantization aware training network cell.
|
||||
"""
|
||||
support_device = ["Ascend", "GPU"]
|
||||
|
||||
def convert2list(name, value):
|
||||
if not isinstance(value, list) and not isinstance(value, tuple):
|
||||
value = [value]
|
||||
|
@ -457,7 +464,7 @@ def convert_quant_network(network,
|
|||
narrow_range = convert2list("narrow range", narrow_range)
|
||||
|
||||
if context.get_context('device_target') not in support_device:
|
||||
raise KeyError("Not support {} backend.".format(context.get_context('device_target')))
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
||||
net = ConvertToQuantNetwork(network=network,
|
||||
quant_delay=quant_delay,
|
||||
|
|
|
@ -160,7 +160,10 @@ def load_checkpoint(ckpt_file_name, net=None):
|
|||
if not isinstance(ckpt_file_name, str):
|
||||
raise ValueError("The ckpt_file_name must be string.")
|
||||
|
||||
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
|
||||
if not os.path.exists(ckpt_file_name):
|
||||
raise ValueError("The checkpoint file is not exist.")
|
||||
|
||||
if ckpt_file_name[-5:] != ".ckpt":
|
||||
raise ValueError("Please input the correct checkpoint file name.")
|
||||
|
||||
if os.path.getsize(ckpt_file_name) == 0:
|
||||
|
|
|
@ -57,7 +57,7 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
# load check point into network
|
||||
param_dict = load_checkpoint(args.ckpt_path, network.type)
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
print("============== Starting Testing ==============")
|
||||
|
|
|
@ -49,7 +49,7 @@ if __name__ == "__main__":
|
|||
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert fusion netwrok to quantization aware network
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
# define loss
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
export quantization aware training network to infer `GEIR` backend.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.train.quant import quant
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import mnist_cfg as cfg
|
||||
from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
|
||||
help='path where the dataset is saved')
|
||||
parser.add_argument('--ckpt_path', type=str, default="",
|
||||
help='if mode is test, must provide path where the trained ckpt file')
|
||||
parser.add_argument('--dataset_sink_mode', type=bool, default=True,
|
||||
help='dataset_sink_mode is False or True')
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR')
|
|
@ -22,7 +22,7 @@ import os
|
|||
import argparse
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from src.dataset import create_dataset
|
||||
|
@ -54,7 +54,6 @@ if __name__ == "__main__":
|
|||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
|
||||
# call back and monitor
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||
|
@ -63,6 +62,6 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
|
||||
dataset_sink_mode=args.dataset_sink_mode)
|
||||
print("============== End Training ==============")
|
||||
|
|
|
@ -23,7 +23,7 @@ import argparse
|
|||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
|
||||
from mindspore.train import Model
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.train.quant import quant
|
||||
|
@ -51,20 +51,19 @@ if __name__ == "__main__":
|
|||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path, network.type)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
# define network loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
# define network optimization
|
||||
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
|
||||
|
||||
# call back and monitor
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||
|
@ -73,6 +72,6 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
|
||||
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
|
||||
dataset_sink_mode=args.dataset_sink_mode)
|
||||
print("============== End Training ==============")
|
||||
|
|
Loading…
Reference in New Issue