From b4a9dcc59b0245ab142492371394e654c72b43f7 Mon Sep 17 00:00:00 2001 From: zhang__sss Date: Thu, 15 Jul 2021 14:31:23 +0800 Subject: [PATCH] code clean --- mindspore/compression/export/quant_export.py | 25 +++++++++++-------- .../fake_learned_scale_quant_perlayer_grad.py | 1 + 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/mindspore/compression/export/quant_export.py b/mindspore/compression/export/quant_export.py index aca23103941..4c6a5468fc0 100644 --- a/mindspore/compression/export/quant_export.py +++ b/mindspore/compression/export/quant_export.py @@ -203,6 +203,7 @@ class ExportToQuantInferNetwork: network = Validator.check_isinstance('network', network, (nn.Cell,)) self.data_type = mstype.int8 self.network = copy.deepcopy(network) + self.network_bk = copy.deepcopy(network) self.get_inputs_table(inputs) self.mean = mean self.std_dev = std_dev @@ -217,7 +218,6 @@ class ExportToQuantInferNetwork: def run(self): """Start to convert.""" - self.network_bk = copy.deepcopy(self.network) self.network.update_cell_prefix() network = self.network if isinstance(network, _AddFakeQuantInput): @@ -255,6 +255,19 @@ class ExportToQuantInferNetwork: block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation) return block + def _get_input_quant_param(self, minq_name, np_type, param_dict): + """get input quant parameter for quant block""" + fake_quant_a_in_prefix = minq_name[:-5] + cells = self.network_bk.cells_and_names() + for cell in cells: + if cell[0].endswith(fake_quant_a_in_prefix): + fake_quant_a_in = cell[1] + break + scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ + quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type) + param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range + return scale_a_in, zp_a_in + def __get_quant_param(self, cell_core, fake_quant_a_out): """get parameter for quant block""" w_minq_name = cell_core.fake_quant_weight.minq.name @@ -291,15 +304,7 @@ class ExportToQuantInferNetwork: scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ (1 / self.std_dev), round(self.mean), 'None', 'None' else: - fake_quant_a_in_prefix = minq_name[:-5] - cells = self.network_bk.cells_and_names() - for cell in cells: - if cell[0].endswith(fake_quant_a_in_prefix): - fake_quant_a_in = cell[1] - break - scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \ - quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type) - param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range + scale_a_in, zp_a_in = self._get_input_quant_param(minq_name, np_type, param_dict) else: # skip quant layer scale_a_in, zp_a_in = 1.0, 0.0 diff --git a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py index b42fb25b97f..fcc134d35e5 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py @@ -177,6 +177,7 @@ def fake_learned_scale_quant_perlayer_grad_d_param(input_x, alpha, quant_max, quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype) return dout_data, input_data, alpha_data, quant_max_data + @util.check_input_type(dict, dict, dict, dict, dict, dict, bool, str) def fake_learned_scale_quant_perlayer_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc, kernel_name="fake_learned_scale_quant_perlayer_grad_d"):