forked from mindspore-Ecosystem/mindspore
!6246 modify export quant network bug
Merge pull request !6246 from changzherui/mod_quant_export
This commit is contained in:
commit
e04ada1148
|
@ -1446,7 +1446,6 @@ class QuantMindirBlock(Cell):
|
|||
if isinstance(activation, ReLU):
|
||||
self.activation = None
|
||||
self.has_act = False
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
def construct(self, x):
|
||||
if self.has_bias:
|
||||
|
|
|
@ -361,12 +361,12 @@ class ExportToQuantInferNetwork:
|
|||
param_dict["symmetric"] = fake_quant_a_out.symmetric
|
||||
if self.is_mindir:
|
||||
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fack_quant_cell(fake_quant_a_out, np_type)
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||
else:
|
||||
scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
|
||||
scale_w, zp_w = quant_utils.scale_zp_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
scale_a_out, _ = quant_utils.scale_zp_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||
info = self.quant_info_table.get(w_minq_name, None)
|
||||
if info:
|
||||
fack_quant_a_in_op, minq_name = info
|
||||
|
@ -432,7 +432,8 @@ class ExportToQuantInferNetwork:
|
|||
op_core = cell_core.conv
|
||||
weight = Tensor(weight, self.data_type)
|
||||
weight_b = Tensor(weight_b)
|
||||
bias_b = Tensor(bias_b, mstype.float32)
|
||||
if bias_b is not None:
|
||||
bias_b = Tensor(bias_b, mstype.float32)
|
||||
if self.is_mindir:
|
||||
block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
|
||||
else:
|
||||
|
|
|
@ -122,7 +122,7 @@ def weight2int(data, scale, zero_point):
|
|||
return np.round((data / scale) + zero_point)
|
||||
|
||||
|
||||
def scale_zp_from_fack_quant_cell(cell, data_type):
|
||||
def scale_zp_from_fake_quant_cell(cell, data_type):
|
||||
r"""
|
||||
Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
|
||||
|
||||
|
@ -146,7 +146,7 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
|
|||
return scale, zp
|
||||
|
||||
|
||||
def scale_zp_max_min_from_fack_quant_cell(cell, data_type):
|
||||
def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
|
||||
"""Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
|
||||
minq = cell.minq.data.asnumpy()
|
||||
maxq = cell.maxq.data.asnumpy()
|
||||
|
|
|
@ -175,10 +175,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
|||
param_list = []
|
||||
for (key, value) in param_dict.items():
|
||||
each_param = {"name": key}
|
||||
if isinstance(value.data, Tensor):
|
||||
param_data = value.data
|
||||
else:
|
||||
param_data = Tensor(value.data)
|
||||
param_data = Tensor(value.data)
|
||||
|
||||
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
|
||||
# which should be combined before saving
|
||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|||
from mindspore.train.quant import quant
|
||||
|
||||
from src.mobilenetV2 import mobilenetV2
|
||||
from src.config import config_ascend
|
||||
from src.config import config_ascend_quant
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
|
@ -34,7 +34,7 @@ args_opt = parser.parse_args()
|
|||
if __name__ == '__main__':
|
||||
cfg = None
|
||||
if args_opt.device_target == "Ascend":
|
||||
cfg = config_ascend
|
||||
cfg = config_ascend_quant
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
else:
|
||||
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
|
||||
|
@ -50,5 +50,5 @@ if __name__ == '__main__':
|
|||
# export network
|
||||
print("============== Starting export ==============")
|
||||
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
quant.export(network, inputs, file_name="mobilenet_quant", file_format='AIR')
|
||||
quant.export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR')
|
||||
print("============== End export ==============")
|
||||
|
|
Loading…
Reference in New Issue