fix quant model export mindir failed bug

This commit is contained in:
Erpim 2021-05-18 19:41:54 +08:00
parent dcec68bcc7
commit 275f247946
3 changed files with 3 additions and 6 deletions

View File

@ -48,7 +48,6 @@ class ExportToQuantInferNetwork:
Returns: Returns:
Cell, Infer network. Cell, Infer network.
""" """
__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False): def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
network = Validator.check_isinstance('network', network, (nn.Cell,)) network = Validator.check_isinstance('network', network, (nn.Cell,))

View File

@ -336,9 +336,8 @@ class QuantizationAwareTraining(Quantizer):
symmetric=self.act_symmetric, symmetric=self.act_symmetric,
narrow_range=self.act_range, narrow_range=self.act_range,
optimize_option=self.optimize_option) optimize_option=self.optimize_option)
prefix = self._convert_op_name(prim_op.name)
if network.param_prefix: if network.param_prefix:
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)]) prefix = '.'.join([network.param_prefix, prefix])
add_quant.update_parameters_name(prefix + '.') add_quant.update_parameters_name(prefix + '.')
del network.__dict__[name] del network.__dict__[name]
network.insert_child_to_cell(name, add_quant) network.insert_child_to_cell(name, add_quant)

View File

@ -30,10 +30,9 @@ parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="", parser.add_argument('--ckpt_path', type=str, default="",
help='if mode is test, must provide path where the trained ckpt file') help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
args = parser.parse_args() args = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
@ -54,4 +53,4 @@ if __name__ == "__main__":
# export network # export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32) inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='QUANT') export(network, inputs, file_name="lenet_quant", file_format=args.file_format, quant_mode='QUANT')