forked from mindspore-Ecosystem/mindspore
fix quant model export mindir failed bug
This commit is contained in:
parent
dcec68bcc7
commit
275f247946
|
@ -48,7 +48,6 @@ class ExportToQuantInferNetwork:
|
|||
Returns:
|
||||
Cell, Infer network.
|
||||
"""
|
||||
__quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]
|
||||
|
||||
def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
|
||||
network = Validator.check_isinstance('network', network, (nn.Cell,))
|
||||
|
|
|
@ -336,9 +336,8 @@ class QuantizationAwareTraining(Quantizer):
|
|||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range,
|
||||
optimize_option=self.optimize_option)
|
||||
prefix = self._convert_op_name(prim_op.name)
|
||||
if network.param_prefix:
|
||||
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
|
||||
prefix = '.'.join([network.param_prefix, prefix])
|
||||
add_quant.update_parameters_name(prefix + '.')
|
||||
del network.__dict__[name]
|
||||
network.insert_child_to_cell(name, add_quant)
|
||||
|
|
|
@ -30,10 +30,9 @@ parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
|||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
|
||||
help='path where the dataset is saved')
|
||||
parser.add_argument('--ckpt_path', type=str, default="",
|
||||
help='if mode is test, must provide path where the trained ckpt file')
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
|
||||
args = parser.parse_args()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -54,4 +53,4 @@ if __name__ == "__main__":
|
|||
|
||||
# export network
|
||||
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
|
||||
export(network, inputs, file_name="lenet_quant", file_format='MINDIR', quant_mode='QUANT')
|
||||
export(network, inputs, file_name="lenet_quant", file_format=args.file_format, quant_mode='QUANT')
|
||||
|
|
Loading…
Reference in New Issue