forked from mindspore-Ecosystem/mindspore
mode_export_v3
This commit is contained in:
parent
9b2b062642
commit
9607778f01
|
@ -391,14 +391,16 @@ class ExportToQuantInferNetwork:
|
||||||
|
|
||||||
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
||||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||||
_, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
if fake_quant_a_out is not None:
|
||||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
|
_, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
||||||
|
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||||
|
|
||||||
info = self.quant_info_table.get(w_minq_name, None)
|
info = self.quant_info_table.get(w_minq_name, None)
|
||||||
if info:
|
if info:
|
||||||
fake_quant_a_in_op, minq_name = info
|
fake_quant_a_in_op, minq_name = info
|
||||||
if minq_name == 'input':
|
if minq_name == 'input':
|
||||||
scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
|
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||||
|
self.input_scale, self.input_zero_point, 'None', 'None'
|
||||||
else:
|
else:
|
||||||
maxq = self.all_parameters[minq_name[:-4] + "maxq"]
|
maxq = self.all_parameters[minq_name[:-4] + "maxq"]
|
||||||
minq = self.all_parameters[minq_name]
|
minq = self.all_parameters[minq_name]
|
||||||
|
@ -483,11 +485,11 @@ class ExportToQuantInferNetwork:
|
||||||
if isinstance(subcell, quant.Conv2dBnAct):
|
if isinstance(subcell, quant.Conv2dBnAct):
|
||||||
cell_core = subcell.conv
|
cell_core = subcell.conv
|
||||||
activation = subcell.activation
|
activation = subcell.activation
|
||||||
fake_quant_act = activation.fake_quant_act
|
fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
|
||||||
elif isinstance(subcell, quant.DenseBnAct):
|
elif isinstance(subcell, quant.DenseBnAct):
|
||||||
cell_core = subcell.dense
|
cell_core = subcell.dense
|
||||||
activation = subcell.activation
|
activation = subcell.activation
|
||||||
fake_quant_act = activation.fake_quant_act
|
fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
|
||||||
if cell_core is not None:
|
if cell_core is not None:
|
||||||
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
|
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
|
||||||
if new_subcell:
|
if new_subcell:
|
||||||
|
|
|
@ -519,7 +519,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
|
||||||
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
logger.info("exporting model file:%s format:%s.", file_name, file_format)
|
||||||
check_input_data(*inputs, data_class=Tensor)
|
check_input_data(*inputs, data_class=Tensor)
|
||||||
|
|
||||||
net = _quant_export(net, *inputs, file_format='AIR', **kwargs)
|
net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
|
||||||
_export(net, file_name, file_format, *inputs)
|
_export(net, file_name, file_format, *inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@ -566,7 +566,7 @@ def _export(net, file_name, file_format, *inputs):
|
||||||
net.set_train(mode=True)
|
net.set_train(mode=True)
|
||||||
|
|
||||||
|
|
||||||
def _quant_export(network, *inputs, file_format='AIR', **kwargs):
|
def _quant_export(network, *inputs, file_format, **kwargs):
|
||||||
"""
|
"""
|
||||||
Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
|
Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue