!6246 fix export quant network bug

Merge pull request !6246 from changzherui/mod_quant_export
This commit is contained in:
mindspore-ci-bot 2020-09-18 10:52:31 +08:00 committed by Gitee
commit e04ada1148
5 changed files with 12 additions and 15 deletions

View File

@ -1446,7 +1446,6 @@ class QuantMindirBlock(Cell):
if isinstance(activation, ReLU):
self.activation = None
self.has_act = False
self.bias_add = P.BiasAdd()
def construct(self, x):
if self.has_bias:

View File

@ -361,12 +361,12 @@ class ExportToQuantInferNetwork:
param_dict["symmetric"] = fake_quant_a_out.symmetric
if self.is_mindir:
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
quant_utils.scale_zp_max_min_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
quant_utils.scale_zp_max_min_from_fack_quant_cell(fake_quant_a_out, np_type)
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
else:
scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
scale_w, zp_w = quant_utils.scale_zp_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
scale_a_out, _ = quant_utils.scale_zp_from_fake_quant_cell(fake_quant_a_out, np_type)
info = self.quant_info_table.get(w_minq_name, None)
if info:
fack_quant_a_in_op, minq_name = info
@ -432,7 +432,8 @@ class ExportToQuantInferNetwork:
op_core = cell_core.conv
weight = Tensor(weight, self.data_type)
weight_b = Tensor(weight_b)
bias_b = Tensor(bias_b, mstype.float32)
if bias_b is not None:
bias_b = Tensor(bias_b, mstype.float32)
if self.is_mindir:
block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
else:

View File

@ -122,7 +122,7 @@ def weight2int(data, scale, zero_point):
return np.round((data / scale) + zero_point)
def scale_zp_from_fack_quant_cell(cell, data_type):
def scale_zp_from_fake_quant_cell(cell, data_type):
r"""
Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
@ -146,7 +146,7 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
return scale, zp
def scale_zp_max_min_from_fack_quant_cell(cell, data_type):
def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
"""Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
minq = cell.minq.data.asnumpy()
maxq = cell.maxq.data.asnumpy()

View File

@ -175,10 +175,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
param_list = []
for (key, value) in param_dict.items():
each_param = {"name": key}
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
param_data = Tensor(value.data)
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
# which should be combined before saving

View File

@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant
from src.mobilenetV2 import mobilenetV2
from src.config import config_ascend
from src.config import config_ascend_quant
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
@ -34,7 +34,7 @@ args_opt = parser.parse_args()
if __name__ == '__main__':
cfg = None
if args_opt.device_target == "Ascend":
cfg = config_ascend
cfg = config_ascend_quant
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
else:
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
@ -50,5 +50,5 @@ if __name__ == '__main__':
# export network
print("============== Starting export ==============")
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
quant.export(network, inputs, file_name="mobilenet_quant", file_format='AIR')
quant.export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR')
print("============== End export ==============")