diff --git a/docs/api/api_python/mindspore.rst b/docs/api/api_python/mindspore.rst index f652fd29de8..8bf082bd6f0 100644 --- a/docs/api/api_python/mindspore.rst +++ b/docs/api/api_python/mindspore.rst @@ -115,17 +115,18 @@ mindspore .. mscnautosummary:: :toctree: mindspore - mindspore.save_checkpoint - mindspore.load_checkpoint - mindspore.load_param_into_net + mindspore.async_ckpt_thread_status + mindspore.build_searched_strategy + mindspore.convert_model mindspore.export mindspore.load - mindspore.parse_print - mindspore.build_searched_strategy - mindspore.merge_sliced_parameter + mindspore.load_checkpoint mindspore.load_distributed_checkpoint - mindspore.async_ckpt_thread_status + mindspore.load_param_into_net + mindspore.merge_sliced_parameter + mindspore.parse_print mindspore.restore_group_info_list + mindspore.save_checkpoint 调试调优 ---------- diff --git a/docs/api/api_python/mindspore/mindspore.convert_model.rst b/docs/api/api_python/mindspore/mindspore.convert_model.rst new file mode 100644 index 00000000000..27d8702d9a9 --- /dev/null +++ b/docs/api/api_python/mindspore/mindspore.convert_model.rst @@ -0,0 +1,21 @@ +mindspore.convert_model +======================= + +.. py:function:: mindspore.convert_model(mindir_file, convert_file, file_format) + + 将MindIR模型转化为其他格式的模型文件。当前版本仅支持转化成ONNX模型。 + + .. 
note:: + 这是一个实验特性,未来API可能会发生变化。 + + **参数:** + + - **mindir_file** (str) - MindIR模型文件名称。 + - **convert_file** (str) - 转化后的模型文件名称。 + - **file_format** (str) - 需要转化的文件格式,当前版本仅支持"ONNX"。 + + **异常:** + + - **TypeError** - `mindir_file` 参数不是str类型。 + - **TypeError** - `convert_file` 参数不是str类型。 + - **ValueError** - `file_format` 参数的值不是"ONNX"。 diff --git a/docs/api/api_python_en/mindspore.rst b/docs/api/api_python_en/mindspore.rst index 69dbb5b8325..1694a9084f4 100644 --- a/docs/api/api_python_en/mindspore.rst +++ b/docs/api/api_python_en/mindspore.rst @@ -236,17 +236,18 @@ Serialization :nosignatures: :template: classtemplate.rst - mindspore.save_checkpoint - mindspore.load_checkpoint - mindspore.load_param_into_net + mindspore.async_ckpt_thread_status + mindspore.build_searched_strategy + mindspore.convert_model mindspore.export mindspore.load - mindspore.parse_print - mindspore.build_searched_strategy - mindspore.merge_sliced_parameter + mindspore.load_checkpoint mindspore.load_distributed_checkpoint - mindspore.async_ckpt_thread_status + mindspore.load_param_into_net + mindspore.merge_sliced_parameter + mindspore.parse_print mindspore.restore_group_info_list + mindspore.save_checkpoint JIT --- diff --git a/mindspore/python/mindspore/_checkparam.py b/mindspore/python/mindspore/_checkparam.py index 781356f73a7..010aa005e9f 100644 --- a/mindspore/python/mindspore/_checkparam.py +++ b/mindspore/python/mindspore/_checkparam.py @@ -483,7 +483,7 @@ class Validator: """Check whether file name is legitimate.""" if not isinstance(target, str): prim_name = f"For '{prim_name}', the" if prim_name else "The" - raise ValueError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target))) + raise TypeError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target))) if target.endswith("\\") or target.endswith("/"): prim_name = f"For '{prim_name}', the" if prim_name else "The" raise ValueError(f"{prim_name} '{target}' cannot be a directory path.") 
diff --git a/mindspore/python/mindspore/train/__init__.py b/mindspore/python/mindspore/train/__init__.py index f557f0eb40f..4432d14d68c 100644 --- a/mindspore/python/mindspore/train/__init__.py +++ b/mindspore/python/mindspore/train/__init__.py @@ -24,7 +24,7 @@ from .amp import build_train_network from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\ build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, async_ckpt_thread_status,\ - restore_group_info_list + restore_group_info_list, convert_model from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, CheckpointConfig, \ RunContext, LearningRateScheduler, SummaryLandscape, FederatedLearningManager, History, LambdaCallback, \ ReduceLROnPlateau, EarlyStopping @@ -35,7 +35,7 @@ from .train_thor import ConvertNetUtils, ConvertModelUtils __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter", - "load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list"] + "load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list", "convert_model"] __all__.extend(callback.__all__) __all__.extend(summary.__all__) __all__.extend(train_thor.__all__) diff --git a/mindspore/python/mindspore/train/_utils.py b/mindspore/python/mindspore/train/_utils.py index 951ba053fea..ca749ee44a2 100644 --- a/mindspore/python/mindspore/train/_utils.py +++ b/mindspore/python/mindspore/train/_utils.py @@ -22,6 +22,7 @@ from mindspore.common.tensor import Tensor from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype from 
mindspore.common import dtype as mstype from mindspore import log as logger +from mindspore._checkparam import Validator from mindspore.common.api import _cell_graph_executor from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model from mindspore.train.checkpoint_pb2 import Checkpoint @@ -238,7 +239,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): Returns: Object, proto object. """ - + Validator.check_file_name_by_regular(file_name) + file_name = os.path.realpath(file_name) if proto_format == "MINDIR": model = mindir_model() elif proto_format == "CKPT": @@ -253,7 +255,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False): pb_content = f.read() model.ParseFromString(pb_content) except BaseException as e: - logger.critical("Failed to read the file `%s`, please check the correct of the file.", file_name) + logger.critical(f"Failed to parse the file: {file_name} as format: {proto_format}," + f" please check the correct file and format.") raise ValueError(e.__str__()) finally: pass diff --git a/mindspore/python/mindspore/train/callback/_history.py b/mindspore/python/mindspore/train/callback/_history.py index 186c8e7be02..ddbd09b3c52 100644 --- a/mindspore/python/mindspore/train/callback/_history.py +++ b/mindspore/python/mindspore/train/callback/_history.py @@ -52,6 +52,7 @@ class History(Callback): def __init__(self): super(History, self).__init__() self.history = {} + self.epoch = None def begin(self, run_context): """ diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 5aaaa9a9c59..549934c23ce 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -44,6 +44,7 @@ from mindspore.common.api import _cell_graph_executor as _executor from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor +from 
mindspore.common.initializer import One from mindspore.communication.management import get_rank, get_group_size from mindspore.compression.export import quant_export from mindspore.parallel._cell_wrapper import get_allgather_cell @@ -51,8 +52,10 @@ from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_ from mindspore.parallel._tensor import _reshape_param_data from mindspore.parallel._tensor import _reshape_param_data_with_weight from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices +from mindspore.train._utils import read_proto from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file + tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16, "Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64, "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64, @@ -62,6 +65,10 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64, "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"} +mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16, + 5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16, + 11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64} + _ckpt_mutex = Lock() # unit is KB @@ -162,7 +169,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params): exist_ckpt_file_list = [] if os.path.exists(checkpoint_dir): for exist_ckpt_name in os.listdir(checkpoint_dir): - file_prefix = model_name + "_iteration_" + file_prefix = f"{model_name}_iteration_" if exist_ckpt_name.startswith(file_prefix): exist_ckpt_file_list.append(exist_ckpt_name) @@ -1113,7 +1120,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs): model.ParseFromString(mindir_stream) if 
kwargs.get('dataset'): - check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset) + check_input_data(kwargs.get('dataset'), data_class=mindspore.dataset.Dataset) dataset = kwargs.get('dataset') _save_dataset_to_mindir(model, dataset) @@ -1871,3 +1878,85 @@ def _calculation_net_size(net): data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024 return data_total + + +def _get_mindir_inputs(file_name): + """ + Get MindIR file's inputs. + + Note: + 1. Parsing encrypted MindIR file is not supported. + 2. Parsing dynamic shape MindIR file is not supported. + + Args: + file_name (str): MindIR file name. + + Returns: + Tensor, list(Tensor) or None, the input of MindIR file (None if the model has no input). + + Raises: + TypeError: If the parameter file_name is not `str`. + RuntimeError: MindIR's input is not tensor type or has no dims. + + Examples: + >>> input_tensor = _get_mindir_inputs("lenet.mindir") + """ + Validator.check_file_name_by_regular(file_name) + file_name = os.path.realpath(file_name) + model = read_proto(file_name) + input_tensor = [] + + for ele_input in model.graph.input: + input_shape = [] + if not hasattr(ele_input, "tensor") or not ele_input.tensor or not hasattr(ele_input.tensor[0], "dims"): + raise RuntimeError("MindIR's inputs has no tensor or tensor has no dims, please check MindIR file.") + + for ele_shape in ele_input.tensor[0].dims: + input_shape.append(ele_shape) + if -1 in input_shape: + raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.") + + mindir_type = ele_input.tensor[0].data_type + if mindir_type not in mindir_to_tensor_type: + raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.") + + input_type = mindir_to_tensor_type.get(mindir_type) + input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One())) + + if not input_tensor: + logger.warning("The MindIR model has no input, return None.") + return None + return input_tensor[0] if len(input_tensor) == 1 else input_tensor + + +def 
convert_model(mindir_file, convert_file, file_format): + """ + Convert MindIR model to other format model. Current version only supports conversion to "ONNX" format. + + Note: + This is an experimental function that is subject to change or deletion. + + Args: + mindir_file (str): MindIR file name. + convert_file (str): Convert model file name. + file_format (str): Convert model's format, current version only supports "ONNX". + + Raises: + TypeError: If the parameter `mindir_file` is not `str`. + TypeError: If the parameter `convert_file` is not `str`. + ValueError: If the parameter `file_format` is not "ONNX". + + Examples: + >>> convert_model("lenet.mindir", "lenet.onnx", "ONNX") + """ + Validator.check_file_name_by_regular(mindir_file) + Validator.check_file_name_by_regular(convert_file) + if file_format != "ONNX": + raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.") + net_input = _get_mindir_inputs(mindir_file) + graph = load(mindir_file) + net = nn.GraphCell(graph) + if net_input is None: + net_input = [] + net_inputs = [net_input] if isinstance(net_input, Tensor) else net_input + export(net, *net_inputs, file_name=convert_file, file_format=file_format) diff --git a/tests/ut/python/utils/test_export.py b/tests/ut/python/utils/test_export.py index 60b02433bff..5d0b67d5a47 100644 --- a/tests/ut/python/utils/test_export.py +++ b/tests/ut/python/utils/test_export.py @@ -17,6 +17,7 @@ import os from io import BytesIO import numpy as np +import mindspore import mindspore.nn as nn import mindspore.dataset as ds import mindspore.dataset.vision as CV @@ -29,7 +30,7 @@ from mindspore.common.initializer import TruncatedNormal from mindspore.common.parameter import ParameterTuple from mindspore.ops import operations as P from mindspore.ops import composite as C -from mindspore.train.serialization import export +from mindspore.train.serialization import export, _get_mindir_inputs, convert_model def weight_variable(): @@ -108,6 
+109,16 @@ class LeNet5(nn.Cell): return x +class InputNet1(nn.Cell): + def construct(self, x): + return x + + +class InputNet2(nn.Cell): + def construct(self, x, y): + return x, y + + class WithLossCell(nn.Cell): def __init__(self, network): super(WithLossCell, self).__init__(auto_prefix=False) @@ -160,6 +171,77 @@ def test_export_lenet_grad_mindir(): os.remove(verify_name) +def test_get_mindir_inputs1(): + """ + Feature: Get MindIR input. + Description: Test get mindir input. + Expectation: Successfully + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + net = InputNet1() + input1 = Tensor(np.zeros([32, 10]).astype(np.float32)) + file_name = "input1.mindir" + export(net, input1, file_name=file_name, file_format='MINDIR') + input_tensor = _get_mindir_inputs(file_name) + assert os.path.exists(file_name) + assert input_tensor.shape == (32, 10) + assert input_tensor.dtype == mindspore.float32 + os.remove(file_name) + + +def test_get_mindir_inputs2(): + """ + Feature: Get MindIR input. + Description: Test get mindir input. + Expectation: Successfully + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + net = InputNet2() + input1 = Tensor(np.zeros(1).astype(np.float16)) + input2 = Tensor(np.zeros([10, 20]), dtype=mstype.int32) + file_name = "input2.mindir" + export(net, input1, input2, file_name=file_name, file_format='MINDIR') + input_tensor = _get_mindir_inputs(file_name) + assert os.path.exists(file_name) + assert len(input_tensor) == 2 + assert input_tensor[0].shape == (1,) + assert input_tensor[0].dtype == mindspore.float16 + assert input_tensor[1].shape == (10, 20) + assert input_tensor[1].dtype == mindspore.int32 + os.remove(file_name) + + +def test_convert_model(): + """ + Feature: Convert mindir to onnx. + Description: Test convert. 
+ Expectation: Successfully + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + net1 = InputNet1() + input1 = Tensor(np.ones([1, 32, 32]).astype(np.float32)) + mindir_name1 = "lenet1.mindir" + export(net1, input1, file_name=mindir_name1, file_format='MINDIR') + onnx_name1 = "lenet1.onnx" + convert_model(mindir_name1, onnx_name1, "ONNX") + assert os.path.exists(mindir_name1) + assert os.path.exists(onnx_name1) + os.remove(mindir_name1) + os.remove(onnx_name1) + + net2 = InputNet2() + input1 = Tensor(np.ones(32).astype(np.float32)) + input2 = Tensor(np.ones([32, 32]).astype(np.float32)) + mindir_name2 = "lenet2.mindir" + export(net2, input1, input2, file_name=mindir_name2, file_format='MINDIR') + onnx_name2 = "lenet2.onnx" + convert_model(mindir_name2, onnx_name2, "ONNX") + assert os.path.exists(mindir_name2) + assert os.path.exists(onnx_name2) + os.remove(mindir_name2) + os.remove(onnx_name2) + + def test_export_lenet_with_dataset(): """ Feature: Export LeNet with data preprocess to MindIR