diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index 46b3cd1934f..cb4cb39e664 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -154,7 +154,9 @@ class ConvertToQuantNetwork:
                                               per_channel=self.act_channel,
                                               symmetric=self.act_symmetric,
                                               narrow_range=self.act_range)
-        prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
+        prefix = self._convert_op_name(prim_op.name)
+        if network.param_prefix:
+            prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
         add_quant.update_parameters_name(prefix + '.')
         del network.__dict__[name]
         network.insert_child_to_cell(name, add_quant)
diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py
index c9e6ac92e1f..c4a8004012a 100644
--- a/mindspore/train/quant/quant_utils.py
+++ b/mindspore/train/quant/quant_utils.py
@@ -125,7 +125,7 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
     """
     minq = cell.minq.data.asnumpy()
     maxq = cell.maxq.data.asnumpy()
-    op = cell.fake_quant
+    op = cell.fake_quant_infer
     scale, zp = cal_quantization_params(
         minq, maxq, data_type,
diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py
index c9398be4560..54563d86eb7 100644
--- a/tests/ut/python/train/quant/test_quant.py
+++ b/tests/ut/python/train/quant/test_quant.py
@@ -67,7 +67,7 @@ def test_qat_lenet():
     img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
     net = LeNet5()
     net = qat.convert_quant_network(
-        net, quant_delay=0, bn_fold=False, freeze_bn=10000, num_bits=8)
+        net, freeze_bn=10000, num_bits=8)
     # should load the checkpoint. mock here
     for param in net.get_parameters():
         param.init_data()
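
For clarity, a minimal standalone sketch of the prefix fallback added to quant.py above. The helper name build_prefix and the 'TensorAdd' op name are hypothetical stand-ins for illustration, not MindSpore API; op_name plays the role of self._convert_op_name(prim_op.name).

    def build_prefix(param_prefix, op_name):
        """Mimic the patched logic: skip the join when param_prefix is empty."""
        prefix = op_name
        if param_prefix:
            prefix = '.'.join([param_prefix, op_name])
        return prefix

    # Top-level cell with an empty param_prefix: no more leading dot.
    assert build_prefix('', 'TensorAdd') == 'TensorAdd'       # old code yielded '.TensorAdd'
    # Nested cell: behavior unchanged.
    assert build_prefix('layer1', 'TensorAdd') == 'layer1.TensorAdd'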