forked from mindspore-Ecosystem/mindspore
mod quant export mindir bug
This commit is contained in:
parent
ce7a7b0fc7
commit
4275ecb11b
|
@ -1440,7 +1440,6 @@ class QuantMindirBlock(Cell):
|
||||||
if isinstance(activation, ReLU):
|
if isinstance(activation, ReLU):
|
||||||
self.activation = None
|
self.activation = None
|
||||||
self.has_act = False
|
self.has_act = False
|
||||||
self.bias_add = P.BiasAdd()
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
|
|
|
@ -361,12 +361,12 @@ class ExportToQuantInferNetwork:
|
||||||
param_dict["symmetric"] = fake_quant_a_out.symmetric
|
param_dict["symmetric"] = fake_quant_a_out.symmetric
|
||||||
if self.is_mindir:
|
if self.is_mindir:
|
||||||
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
||||||
quant_utils.scale_zp_max_min_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
|
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||||
scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
||||||
quant_utils.scale_zp_max_min_from_fack_quant_cell(fake_quant_a_out, np_type)
|
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||||
else:
|
else:
|
||||||
scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
|
scale_w, zp_w = quant_utils.scale_zp_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||||
scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
|
scale_a_out, _ = quant_utils.scale_zp_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||||
info = self.quant_info_table.get(w_minq_name, None)
|
info = self.quant_info_table.get(w_minq_name, None)
|
||||||
if info:
|
if info:
|
||||||
fack_quant_a_in_op, minq_name = info
|
fack_quant_a_in_op, minq_name = info
|
||||||
|
@ -432,7 +432,8 @@ class ExportToQuantInferNetwork:
|
||||||
op_core = cell_core.conv
|
op_core = cell_core.conv
|
||||||
weight = Tensor(weight, self.data_type)
|
weight = Tensor(weight, self.data_type)
|
||||||
weight_b = Tensor(weight_b)
|
weight_b = Tensor(weight_b)
|
||||||
bias_b = Tensor(bias_b, mstype.float32)
|
if bias_b is not None:
|
||||||
|
bias_b = Tensor(bias_b, mstype.float32)
|
||||||
if self.is_mindir:
|
if self.is_mindir:
|
||||||
block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
|
block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -122,7 +122,7 @@ def weight2int(data, scale, zero_point):
|
||||||
return np.round((data / scale) + zero_point)
|
return np.round((data / scale) + zero_point)
|
||||||
|
|
||||||
|
|
||||||
def scale_zp_from_fack_quant_cell(cell, data_type):
|
def scale_zp_from_fake_quant_cell(cell, data_type):
|
||||||
r"""
|
r"""
|
||||||
Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
|
Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
|
||||||
return scale, zp
|
return scale, zp
|
||||||
|
|
||||||
|
|
||||||
def scale_zp_max_min_from_fack_quant_cell(cell, data_type):
|
def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
|
||||||
"""Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
|
"""Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
|
||||||
minq = cell.minq.data.asnumpy()
|
minq = cell.minq.data.asnumpy()
|
||||||
maxq = cell.maxq.data.asnumpy()
|
maxq = cell.maxq.data.asnumpy()
|
||||||
|
|
|
@ -170,10 +170,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
||||||
param_list = []
|
param_list = []
|
||||||
for (key, value) in param_dict.items():
|
for (key, value) in param_dict.items():
|
||||||
each_param = {"name": key}
|
each_param = {"name": key}
|
||||||
if isinstance(value.data, Tensor):
|
param_data = Tensor(value.data)
|
||||||
param_data = value.data
|
|
||||||
else:
|
|
||||||
param_data = Tensor(value.data)
|
|
||||||
|
|
||||||
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
|
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
|
||||||
# which should be combined before saving
|
# which should be combined before saving
|
||||||
|
|
|
@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from mindspore.train.quant import quant
|
from mindspore.train.quant import quant
|
||||||
|
|
||||||
from src.mobilenetV2 import mobilenetV2
|
from src.mobilenetV2 import mobilenetV2
|
||||||
from src.config import config_ascend
|
from src.config import config_ascend_quant
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Image classification')
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||||
|
@ -34,7 +34,7 @@ args_opt = parser.parse_args()
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
cfg = None
|
cfg = None
|
||||||
if args_opt.device_target == "Ascend":
|
if args_opt.device_target == "Ascend":
|
||||||
cfg = config_ascend
|
cfg = config_ascend_quant
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
|
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
|
||||||
|
@ -50,5 +50,5 @@ if __name__ == '__main__':
|
||||||
# export network
|
# export network
|
||||||
print("============== Starting export ==============")
|
print("============== Starting export ==============")
|
||||||
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
|
inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||||
quant.export(network, inputs, file_name="mobilenet_quant", file_format='AIR')
|
quant.export(network, inputs, file_name="mobilenet_quant", file_format='MINDIR')
|
||||||
print("============== End export ==============")
|
print("============== End export ==============")
|
||||||
|
|
Loading…
Reference in New Issue