!20337 warnning clean for quant

Merge pull request !20337 from zhang_sss/warnning_clean_2
This commit is contained in:
i-robot 2021-07-16 08:58:24 +00:00 committed by Gitee
commit f4957f8eb1
2 changed files with 16 additions and 10 deletions

View File

@ -203,6 +203,7 @@ class ExportToQuantInferNetwork:
network = Validator.check_isinstance('network', network, (nn.Cell,))
self.data_type = mstype.int8
self.network = copy.deepcopy(network)
self.network_bk = copy.deepcopy(network)
self.get_inputs_table(inputs)
self.mean = mean
self.std_dev = std_dev
@ -217,7 +218,6 @@ class ExportToQuantInferNetwork:
def run(self):
"""Start to convert."""
self.network_bk = copy.deepcopy(self.network)
self.network.update_cell_prefix()
network = self.network
if isinstance(network, _AddFakeQuantInput):
@ -255,6 +255,19 @@ class ExportToQuantInferNetwork:
block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
return block
def _get_input_quant_param(self, minq_name, np_type, param_dict):
"""get input quant parameter for quant block"""
fake_quant_a_in_prefix = minq_name[:-5]
cells = self.network_bk.cells_and_names()
for cell in cells:
if cell[0].endswith(fake_quant_a_in_prefix):
fake_quant_a_in = cell[1]
break
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
return scale_a_in, zp_a_in
def __get_quant_param(self, cell_core, fake_quant_a_out):
"""get parameter for quant block"""
w_minq_name = cell_core.fake_quant_weight.minq.name
@ -291,15 +304,7 @@ class ExportToQuantInferNetwork:
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
(1 / self.std_dev), round(self.mean), 'None', 'None'
else:
fake_quant_a_in_prefix = minq_name[:-5]
cells = self.network_bk.cells_and_names()
for cell in cells:
if cell[0].endswith(fake_quant_a_in_prefix):
fake_quant_a_in = cell[1]
break
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
scale_a_in, zp_a_in = self._get_input_quant_param(minq_name, np_type, param_dict)
else:
# skip quant layer
scale_a_in, zp_a_in = 1.0, 0.0

View File

@ -177,6 +177,7 @@ def fake_learned_scale_quant_perlayer_grad_d_param(input_x, alpha, quant_max,
quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype)
return dout_data, input_data, alpha_data, quant_max_data
@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, str)
def fake_learned_scale_quant_perlayer_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc,
kernel_name="fake_learned_scale_quant_perlayer_grad_d"):