forked from mindspore-Ecosystem/mindspore
fix quant model export mindir failed bug
This commit is contained in:
parent
dcec68bcc7
commit
275f247946
|
@ -48,7 +48,6 @@ class ExportToQuantInferNetwork:
|
||||||
Returns:
|
Returns:
|
||||||
Cell, Infer network.
|
Cell, Infer network.
|
||||||
"""
|
"""
|
||||||
__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
|
|
||||||
|
|
||||||
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
|
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
|
||||||
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
||||||
|
|
|
@ -336,9 +336,8 @@ class QuantizationAwareTraining(Quantizer):
|
||||||
symmetric=self.act_symmetric,
|
symmetric=self.act_symmetric,
|
||||||
narrow_range=self.act_range,
|
narrow_range=self.act_range,
|
||||||
optimize_option=self.optimize_option)
|
optimize_option=self.optimize_option)
|
||||||
prefix = self._convert_op_name(prim_op.name)
|
|
||||||
if network.param_prefix:
|
if network.param_prefix:
|
||||||
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
|
prefix = '.'.join([network.param_prefix, prefix])
|
||||||
add_quant.update_parameters_name(prefix + '.')
|
add_quant.update_parameters_name(prefix + '.')
|
||||||
del network.__dict__[name]
|
del network.__dict__[name]
|
||||||
network.insert_child_to_cell(name, add_quant)
|
network.insert_child_to_cell(name, add_quant)
|
||||||
|
|
|
@ -30,10 +30,9 @@ parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||||
choices=['Ascend', 'GPU'],
|
choices=['Ascend', 'GPU'],
|
||||||
help='device where the code will be implemented (default: Ascend)')
|
help='device where the code will be implemented (default: Ascend)')
|
||||||
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
|
|
||||||
help='path where the dataset is saved')
|
|
||||||
parser.add_argument('--ckpt_path', type=str, default="",
|
parser.add_argument('--ckpt_path', type=str, default="",
|
||||||
help='if mode is test, must provide path where the trained ckpt file')
|
help='if mode is test, must provide path where the trained ckpt file')
|
||||||
|
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -54,4 +53,4 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# export network
|
# export network
|
||||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
|
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||||
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='QUANT')
|
export(network, inputs, file_name="lenet_quant", file_format=args.file_format, quant_mode='QUANT')
|
||||||
|
|
Loading…
Reference in New Issue