forked from mindspore-Ecosystem/mindspore
!2512 [quant]The top level add op prefix_name check error
Merge pull request !2512 from vlne-v1/I1LJMR-quant-the-top-level-add-op-prefix_name-check-error
This commit is contained in:
commit
363a232cbc
|
@ -154,7 +154,9 @@ class ConvertToQuantNetwork:
|
||||||
per_channel=self.act_channel,
|
per_channel=self.act_channel,
|
||||||
symmetric=self.act_symmetric,
|
symmetric=self.act_symmetric,
|
||||||
narrow_range=self.act_range)
|
narrow_range=self.act_range)
|
||||||
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
|
prefix = self._convert_op_name(prim_op.name)
|
||||||
|
if network.param_prefix:
|
||||||
|
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
|
||||||
add_quant.update_parameters_name(prefix + '.')
|
add_quant.update_parameters_name(prefix + '.')
|
||||||
del network.__dict__[name]
|
del network.__dict__[name]
|
||||||
network.insert_child_to_cell(name, add_quant)
|
network.insert_child_to_cell(name, add_quant)
|
||||||
|
|
|
@ -125,7 +125,7 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
|
||||||
"""
|
"""
|
||||||
minq = cell.minq.data.asnumpy()
|
minq = cell.minq.data.asnumpy()
|
||||||
maxq = cell.maxq.data.asnumpy()
|
maxq = cell.maxq.data.asnumpy()
|
||||||
op = cell.fake_quant
|
op = cell.fake_quant_infer
|
||||||
|
|
||||||
scale, zp = cal_quantization_params(
|
scale, zp = cal_quantization_params(
|
||||||
minq, maxq, data_type,
|
minq, maxq, data_type,
|
||||||
|
|
|
@ -67,7 +67,7 @@ def test_qat_lenet():
|
||||||
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
|
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
|
||||||
net = LeNet5()
|
net = LeNet5()
|
||||||
net = qat.convert_quant_network(
|
net = qat.convert_quant_network(
|
||||||
net, quant_delay=0, bn_fold=False, freeze_bn=10000, num_bits=8)
|
net, freeze_bn=10000, num_bits=8)
|
||||||
# should load the checkpoint. mock here
|
# should load the checkpoint. mock here
|
||||||
for param in net.get_parameters():
|
for param in net.get_parameters():
|
||||||
param.init_data()
|
param.init_data()
|
||||||
|
|
Loading…
Reference in New Issue