forked from OSSInnovation/mindspore
mindir_mode
This commit is contained in:
parent
b993ea0288
commit
2b07d7ffb3
|
@ -342,6 +342,14 @@ class ExportToQuantInferNetwork:
|
|||
network = self._convert_quant2deploy(network)
|
||||
return network
|
||||
|
||||
def statistic_weight(self, weight):
|
||||
out_nums = np.shape(weight)[0]
|
||||
sta_metric = np.zeros((out_nums, 2), dtype=np.float32)
|
||||
for num in range(out_nums):
|
||||
sta_metric[num, 0] = np.min(weight[num])
|
||||
sta_metric[num, 1] = np.max(weight[num])
|
||||
return np.mean(sta_metric[:, 1]).tolist(), np.mean(sta_metric[:, 0]).tolist()
|
||||
|
||||
def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
|
||||
"""convet network's quant subcell to deploy subcell"""
|
||||
# Calculate the scale and zero point
|
||||
|
@ -357,14 +365,12 @@ class ExportToQuantInferNetwork:
|
|||
param_dict["mean"] = self.mean
|
||||
param_dict["std_dev"] = self.std_dev
|
||||
param_dict["symmetric"] = fake_quant_a_out.symmetric
|
||||
if self.is_mindir:
|
||||
scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||
else:
|
||||
scale_w, zp_w = quant_utils.scale_zp_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
scale_a_out, _ = quant_utils.scale_zp_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||
|
||||
scale_w, zp_w, _, _ = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
|
||||
scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
|
||||
|
||||
info = self.quant_info_table.get(w_minq_name, None)
|
||||
if info:
|
||||
fack_quant_a_in_op, minq_name = info
|
||||
|
@ -403,6 +409,8 @@ class ExportToQuantInferNetwork:
|
|||
weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
|
||||
elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
|
||||
weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
|
||||
if self.is_mindir:
|
||||
param_dict["filter_maxq"], param_dict["filter_minq"] = self.statistic_weight(weight)
|
||||
weight_b = weight
|
||||
bias_b = bias
|
||||
# apply the quant
|
||||
|
@ -467,6 +475,9 @@ class ExportToQuantInferNetwork:
|
|||
elif isinstance(subcell, _AddFakeQuantAfterSubCell):
|
||||
op = subcell.subcell
|
||||
if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
|
||||
if self.is_mindir:
|
||||
op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
|
||||
op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
|
||||
network.__delattr__(name)
|
||||
network.__setattr__(name, op)
|
||||
change = True
|
||||
|
|
|
@ -120,31 +120,6 @@ def weight2int(data, scale, zero_point):
|
|||
|
||||
return np.round((data / scale) + zero_point)
|
||||
|
||||
|
||||
def scale_zp_from_fake_quant_cell(cell, data_type):
|
||||
r"""
|
||||
Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
|
||||
|
||||
Args:
|
||||
cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
|
||||
data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
|
||||
|
||||
Returns:
|
||||
scale (numpy.ndarray): quantization param.
|
||||
zero point (numpy.ndarray): quantization param.
|
||||
"""
|
||||
minq = cell.minq.data.asnumpy()
|
||||
maxq = cell.maxq.data.asnumpy()
|
||||
op = cell.fake_quant_infer
|
||||
|
||||
scale, zp = cal_quantization_params(
|
||||
minq, maxq, data_type,
|
||||
num_bits=op.num_bits,
|
||||
symmetric=op.symmetric,
|
||||
narrow_range=op.narrow_range)
|
||||
return scale, zp
|
||||
|
||||
|
||||
def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
|
||||
"""Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
|
||||
minq = cell.minq.data.asnumpy()
|
||||
|
|
Loading…
Reference in New Issue