forked from mindspore-Ecosystem/mindspore
!20337 warnning clean for quant
Merge pull request !20337 from zhang_sss/warnning_clean_2
This commit is contained in:
commit
f4957f8eb1
|
@ -203,6 +203,7 @@ class ExportToQuantInferNetwork:
|
|||
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
||||
self.data_type = mstype.int8
|
||||
self.network = copy.deepcopy(network)
|
||||
self.network_bk = copy.deepcopy(network)
|
||||
self.get_inputs_table(inputs)
|
||||
self.mean = mean
|
||||
self.std_dev = std_dev
|
||||
|
@ -217,7 +218,6 @@ class ExportToQuantInferNetwork:
|
|||
|
||||
def run(self):
|
||||
"""Start to convert."""
|
||||
self.network_bk = copy.deepcopy(self.network)
|
||||
self.network.update_cell_prefix()
|
||||
network = self.network
|
||||
if isinstance(network, _AddFakeQuantInput):
|
||||
|
@ -255,6 +255,19 @@ class ExportToQuantInferNetwork:
|
|||
block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
|
||||
return block
|
||||
|
||||
def _get_input_quant_param(self, minq_name, np_type, param_dict):
|
||||
"""get input quant parameter for quant block"""
|
||||
fake_quant_a_in_prefix = minq_name[:-5]
|
||||
cells = self.network_bk.cells_and_names()
|
||||
for cell in cells:
|
||||
if cell[0].endswith(fake_quant_a_in_prefix):
|
||||
fake_quant_a_in = cell[1]
|
||||
break
|
||||
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
|
||||
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
|
||||
return scale_a_in, zp_a_in
|
||||
|
||||
def __get_quant_param(self, cell_core, fake_quant_a_out):
|
||||
"""get parameter for quant block"""
|
||||
w_minq_name = cell_core.fake_quant_weight.minq.name
|
||||
|
@ -291,15 +304,7 @@ class ExportToQuantInferNetwork:
|
|||
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||
(1 / self.std_dev), round(self.mean), 'None', 'None'
|
||||
else:
|
||||
fake_quant_a_in_prefix = minq_name[:-5]
|
||||
cells = self.network_bk.cells_and_names()
|
||||
for cell in cells:
|
||||
if cell[0].endswith(fake_quant_a_in_prefix):
|
||||
fake_quant_a_in = cell[1]
|
||||
break
|
||||
scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
|
||||
quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
|
||||
param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
|
||||
scale_a_in, zp_a_in = self._get_input_quant_param(minq_name, np_type, param_dict)
|
||||
else:
|
||||
# skip quant layer
|
||||
scale_a_in, zp_a_in = 1.0, 0.0
|
||||
|
|
|
@ -177,6 +177,7 @@ def fake_learned_scale_quant_perlayer_grad_d_param(input_x, alpha, quant_max,
|
|||
quant_max_data = tvm.placeholder(quant_max_shape, name="quant_max_data", dtype=quant_max_dtype)
|
||||
return dout_data, input_data, alpha_data, quant_max_data
|
||||
|
||||
|
||||
@util.check_input_type(dict, dict, dict, dict, dict, dict, bool, str)
|
||||
def fake_learned_scale_quant_perlayer_grad_d(dout, input_x, alpha, quant_max, dx, dalpha, neg_trunc,
|
||||
kernel_name="fake_learned_scale_quant_perlayer_grad_d"):
|
||||
|
|
Loading…
Reference in New Issue