forked from mindspore-Ecosystem/mindspore
!7189 mode_export_v2
Merge pull request !7189 from baiyangfan/mode_export_v2
This commit is contained in:
commit
86729985df
|
@ -1369,7 +1369,8 @@ class QuantBlock(Cell):
|
|||
def construct(self, x):
|
||||
x = self.quant(x)
|
||||
if self.has_bias:
|
||||
x = self.core_op(x, self.weight, self.bias)
|
||||
x = self.core_op(x, self.weight)
|
||||
x = self.bias_add(x, self.bias)
|
||||
else:
|
||||
x = self.core_op(x, self.weight)
|
||||
x = self.dequant(x, self.dequant_scale)
|
||||
|
@ -1412,6 +1413,7 @@ class QuantMindirBlock(Cell):
|
|||
self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
|
||||
self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
|
||||
self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
|
||||
if param_dict["output_maxq"] is not None:
|
||||
self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
|
||||
self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
|
||||
self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
|
||||
|
@ -1419,24 +1421,27 @@ class QuantMindirBlock(Cell):
|
|||
self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
|
||||
self.core_op.add_prim_attr("num_bits", Tensor(8))
|
||||
self.core_op.add_prim_attr("narrow_range", Tensor(False))
|
||||
if param_dict["input_maxq"] is not None:
|
||||
self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
|
||||
self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
|
||||
else:
|
||||
if param_dict["input_maxq"] == 'None':
|
||||
self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
|
||||
self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
|
||||
elif param_dict["input_maxq"] is not None:
|
||||
self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
|
||||
self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
|
||||
|
||||
self.weight = weight
|
||||
self.bias = bias
|
||||
self.has_bias = bias is not None
|
||||
self.activation = activation
|
||||
self.has_act = activation is not None
|
||||
self.bias_add = P.BiasAdd()
|
||||
if isinstance(activation, ReLU):
|
||||
self.activation = None
|
||||
self.has_act = False
|
||||
|
||||
def construct(self, x):
|
||||
if self.has_bias:
|
||||
x = self.core_op(x, self.weight, self.bias)
|
||||
x = self.core_op(x, self.weight)
|
||||
x = self.bias_add(x, self.bias)
|
||||
else:
|
||||
x = self.core_op(x, self.weight)
|
||||
return x
|
||||
|
|
|
@ -355,7 +355,7 @@ class ExportToQuantInferNetwork:
|
|||
param_dict["input_minq"] = None
|
||||
param_dict["mean"] = self.mean
|
||||
param_dict["std_dev"] = self.std_dev
|
||||
param_dict["symmetric"] = fake_quant_a_out.symmetric
|
||||
param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric
|
||||
|
||||
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
|
@ -578,3 +578,235 @@ def convert_quant_network(network,
|
|||
symmetric=symmetric,
|
||||
narrow_range=narrow_range)
|
||||
return net.run()
|
||||
|
||||
def manual_export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='MINDIR'):
|
||||
"""
|
||||
Manual exports MindSpore quantization predict model to deploy wiAIR and MINDIR.
|
||||
|
||||
Args:
|
||||
network (Cell): MindSpore network produced by `convert_quant_network`.
|
||||
inputs (Tensor): Inputs of the `quantization aware training network`.
|
||||
file_name (str): File name of model to export.
|
||||
mean (int, float): Input data mean. Default: 127.5.
|
||||
std_dev (int, float): Input data variance. Default: 127.5.
|
||||
file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for exported
|
||||
quantization aware model. Default: 'AIR'.
|
||||
|
||||
- AIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
|
||||
Ascend model.
|
||||
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
|
||||
for MindSpore models.
|
||||
Recommended suffix for output file is '.mindir'.
|
||||
"""
|
||||
supported_device = ["Ascend", "GPU"]
|
||||
supported_formats = ['AIR', 'MINDIR']
|
||||
|
||||
mean = Validator.check_type("mean", mean, (int, float))
|
||||
std_dev = Validator.check_type("std_dev", std_dev, (int, float))
|
||||
|
||||
if context.get_context('device_target') not in supported_device:
|
||||
raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
|
||||
|
||||
if file_format not in supported_formats:
|
||||
raise ValueError('Illegal file format {}.'.format(file_format))
|
||||
|
||||
network.set_train(False)
|
||||
if file_format == "MINDIR":
|
||||
exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
|
||||
else:
|
||||
exporter = ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=False)
|
||||
deploy_net = exporter.run()
|
||||
serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
|
||||
|
||||
class ExportManualQuantNetwork:
|
||||
"""
|
||||
Convert anual quantization aware network to infer network.
|
||||
|
||||
Args:
|
||||
network (Cell): MindSpore network API `convert_quant_network`.
|
||||
inputs (Tensor): Input tensors of the `quantization aware training network`.
|
||||
mean (int): Input data mean. Default: 127.5.
|
||||
std_dev (int, float): Input data variance. Default: 127.5.
|
||||
is_mindir (bool): Whether is MINDIR format. Default: False.
|
||||
|
||||
Returns:
|
||||
Cell, Infer network.
|
||||
"""
|
||||
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self, network, mean, std_dev, *inputs, is_mindir):
|
||||
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
||||
self.input_scale = 1 / std_dev
|
||||
self.input_zero_point = round(mean)
|
||||
self.data_type = mstype.int8
|
||||
self.network = copy.deepcopy(network)
|
||||
self.all_parameters = {p.name: p for p in self.network.get_parameters()}
|
||||
self.get_inputs_table(inputs)
|
||||
self.mean = mean
|
||||
self.std_dev = std_dev
|
||||
self.is_mindir = is_mindir
|
||||
self.upcell = None
|
||||
self.upname = None
|
||||
|
||||
def get_inputs_table(self, inputs):
|
||||
"""Get the support info for quant export."""
|
||||
phase_name = 'export_quant'
|
||||
graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
|
||||
self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)
|
||||
|
||||
def run(self):
|
||||
"""Start to convert."""
|
||||
self.network.update_cell_prefix()
|
||||
network = self.network
|
||||
if isinstance(network, _AddFakeQuantInput):
|
||||
network = network.network
|
||||
network = self._convert_manual_network(network)
|
||||
return network
|
||||
|
||||
def _convert_manual_network(self, network):
|
||||
"""Convert network's all quant subcell to deploy subcell."""
|
||||
cells = network.name_cells()
|
||||
change = False
|
||||
for name in cells:
|
||||
subcell = cells[name]
|
||||
if subcell == network:
|
||||
continue
|
||||
if isinstance(subcell, quant.Conv2dBnAct):
|
||||
network, change = self._convert_subcell(network, change, name, subcell)
|
||||
elif isinstance(subcell, quant.DenseBnAct):
|
||||
network, change = self._convert_subcell(network, change, name, subcell, conv=False)
|
||||
elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
|
||||
quant.Conv2dQuant, quant.DenseQuant)):
|
||||
network, change = self._convert_subcell(network, change, name, subcell, core=False)
|
||||
elif isinstance(subcell, quant.FakeQuantWithMinMax) and self.upcell:
|
||||
np_type = mstype.dtype_to_nptype(self.data_type)
|
||||
_, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type)
|
||||
self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq))
|
||||
self.upcell.core_op.add_prim_attr('output_minq', Tensor(minq))
|
||||
network.insert_child_to_cell(self.upname, self.upcell)
|
||||
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
|
||||
op = subcell.subcell
|
||||
if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
|
||||
if self.is_mindir:
|
||||
op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
|
||||
op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
|
||||
network.__delattr__(name)
|
||||
network.__setattr__(name, op)
|
||||
change = True
|
||||
else:
|
||||
self.upcell, self.upname = None, None
|
||||
self._convert_manual_network(subcell)
|
||||
if isinstance(network, nn.SequentialCell) and change:
|
||||
network.cell_list = list(network.cells())
|
||||
return network
|
||||
|
||||
def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
|
||||
"""Convert subcell to ant subcell."""
|
||||
if core:
|
||||
cell_core = subcell.conv if conv else subcell.dense
|
||||
activation = subcell.activation
|
||||
fake_quant_act = activation.fake_quant_act
|
||||
else:
|
||||
cell_core = subcell
|
||||
activation = None
|
||||
fake_quant_act = None
|
||||
new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
|
||||
if new_subcell:
|
||||
prefix = subcell.param_prefix
|
||||
new_subcell.update_parameters_name(prefix + '.')
|
||||
self.upcell = None if core else new_subcell
|
||||
self.upname = None if core else name
|
||||
network.insert_child_to_cell(name, new_subcell)
|
||||
change = True
|
||||
return network, change
|
||||
|
||||
def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
|
||||
"""convet network's quant subcell to deploy subcell"""
|
||||
w_minq_name = cell_core.fake_quant_weight.minq.name
|
||||
np_type = mstype.dtype_to_nptype(self.data_type)
|
||||
param_dict = dict()
|
||||
param_dict["filter_maxq"] = None
|
||||
param_dict["filter_minq"] = None
|
||||
param_dict["output_maxq"] = None
|
||||
param_dict["output_minq"] = None
|
||||
param_dict["input_maxq"] = None
|
||||
param_dict["input_minq"] = None
|
||||
param_dict["mean"] = self.mean
|
||||
param_dict["std_dev"] = self.std_dev
|
||||
param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric
|
||||
|
||||
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
if fake_quant_a_out is not None:
|
||||
_, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||
|
||||
info = self.quant_info_table.get(w_minq_name, None)
|
||||
if info:
|
||||
fack_quant_a_in_op, minq_name = info
|
||||
if minq_name == 'input':
|
||||
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||
self.input_scale, self.input_zero_point, 'None', 'None'
|
||||
else:
|
||||
maxq = self.all_parameters[minq_name[:-4] + "maxq"]
|
||||
minq = self.all_parameters[minq_name]
|
||||
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type)
|
||||
else:
|
||||
# skip quant layer
|
||||
scale_a_in, zp_a_in = 1, 0
|
||||
|
||||
# Build the `Quant` `Dequant` op.
|
||||
# Quant only support perlayer version. Need check here.
|
||||
quant_op = inner.Quant(1 / float(scale_a_in), float(zp_a_in))
|
||||
scale_deq = scale_a_in * scale_w
|
||||
dequant_op = inner.Dequant()
|
||||
|
||||
if isinstance(activation, _AddFakeQuantAfterSubCell):
|
||||
activation = activation.subcell
|
||||
elif hasattr(activation, "get_origin"):
|
||||
activation = activation.get_origin()
|
||||
|
||||
# get the `weight` and `bias`
|
||||
weight = cell_core.weight.data.asnumpy()
|
||||
bias = None
|
||||
if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
|
||||
if cell_core.has_bias:
|
||||
bias = cell_core.bias.data.asnumpy()
|
||||
elif isinstance(cell_core, quant.Conv2dBnFoldQuant):
|
||||
weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
|
||||
elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
|
||||
weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
|
||||
weight_b = weight
|
||||
bias_b = bias
|
||||
# apply the quant
|
||||
fake_quant_weight_op = cell_core.fake_quant_weight.fake_quant_infer
|
||||
weight = quant_utils.weight2int(weight, scale_w, zp_w, np_type, fake_quant_weight_op.num_bits,
|
||||
fake_quant_weight_op.narrow_range)
|
||||
if bias is not None:
|
||||
bias = Tensor(bias / scale_a_in / scale_w, mstype.int32)
|
||||
|
||||
float32_deq_scale = scale_deq.astype(np.float32)
|
||||
uint32_deq_scale = np.frombuffer(float32_deq_scale, np.uint32)
|
||||
scale_length = scale_deq.size # channel
|
||||
dequant_param = np.zeros(scale_length, dtype=np.uint64)
|
||||
for index in range(scale_length):
|
||||
dequant_param[index] += uint32_deq_scale[index]
|
||||
|
||||
scale_deq = Tensor(dequant_param, mstype.uint64)
|
||||
# get op
|
||||
if isinstance(cell_core, quant.DenseQuant):
|
||||
op_core = P.MatMul()
|
||||
weight = np.transpose(weight)
|
||||
weight_b = np.transpose(weight_b)
|
||||
else:
|
||||
op_core = cell_core.conv
|
||||
weight = Tensor(weight, self.data_type)
|
||||
weight_b = Tensor(weight_b)
|
||||
if bias_b is not None:
|
||||
bias_b = Tensor(bias_b, mstype.float32)
|
||||
if self.is_mindir:
|
||||
block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
|
||||
else:
|
||||
block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
|
||||
return block
|
||||
|
|
Loading…
Reference in New Issue