!35394 add get_mindir_input api

Merge pull request !35394 from changzherui/add_get_mindir_input
This commit is contained in:
i-robot 2022-06-10 08:28:35 +00:00 committed by Gitee
commit 58efca88db
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 220 additions and 22 deletions

View File

@ -115,17 +115,18 @@ mindspore
.. mscnautosummary::
:toctree: mindspore
mindspore.save_checkpoint
mindspore.load_checkpoint
mindspore.load_param_into_net
mindspore.async_ckpt_thread_status
mindspore.build_searched_strategy
mindspore.convert_model
mindspore.export
mindspore.load
mindspore.parse_print
mindspore.build_searched_strategy
mindspore.merge_sliced_parameter
mindspore.load_checkpoint
mindspore.load_distributed_checkpoint
mindspore.async_ckpt_thread_status
mindspore.load_param_into_net
mindspore.merge_sliced_parameter
mindspore.parse_print
mindspore.restore_group_info_list
mindspore.save_checkpoint
调试调优
----------

View File

@ -0,0 +1,21 @@
mindspore.conver_model
======================
.. py:class:: mindspore.convert_model(mindir_file, convert_file, file_format)
将MindIR模型转化为其他格式的模型文件。当前版本仅支持转化成ONNX模型。
.. note::
这是一个实验特性未来API可能会发生的变化。
**参数:**
- **mindir_file** (str) - MindIR模型文件名称。
- **convert_file** (str) - 转化后的模型文件名称。
- **file_format** (str) - 需要转化的文件格式,当前版本仅支持"ONNX"。
**异常:**
- **TypeError** - `mindir_file` 参数不是str类型。
- **TypeError** - `convert_file` 参数不是str类型。
- **ValueError** - `file_format` 参数的值不是"ONNX"。

View File

@ -236,17 +236,18 @@ Serialization
:nosignatures:
:template: classtemplate.rst
mindspore.save_checkpoint
mindspore.load_checkpoint
mindspore.load_param_into_net
mindspore.async_ckpt_thread_status
mindspore.build_searched_strategy
mindspore.convert_model
mindspore.export
mindspore.load
mindspore.parse_print
mindspore.build_searched_strategy
mindspore.merge_sliced_parameter
mindspore.load_checkpoint
mindspore.load_distributed_checkpoint
mindspore.async_ckpt_thread_status
mindspore.load_param_into_net
mindspore.merge_sliced_parameter
mindspore.parse_print
mindspore.restore_group_info_list
mindspore.save_checkpoint
JIT
---

View File

@ -483,7 +483,7 @@ class Validator:
"""Check whether file name is legitimate."""
if not isinstance(target, str):
prim_name = f"For '{prim_name}', the" if prim_name else "The"
raise ValueError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
raise TypeError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
if target.endswith("\\") or target.endswith("/"):
prim_name = f"For '{prim_name}', the" if prim_name else "The"
raise ValueError(f"{prim_name} '{target}' cannot be a directory path.")

View File

@ -24,7 +24,7 @@ from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, async_ckpt_thread_status,\
restore_group_info_list
restore_group_info_list, convert_model
from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, CheckpointConfig, \
RunContext, LearningRateScheduler, SummaryLandscape, FederatedLearningManager, History, LambdaCallback, \
ReduceLROnPlateau, EarlyStopping
@ -35,7 +35,7 @@ from .train_thor import ConvertNetUtils, ConvertModelUtils
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
"load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list"]
"load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list", "convert_model"]
__all__.extend(callback.__all__)
__all__.extend(summary.__all__)
__all__.extend(train_thor.__all__)

View File

@ -22,6 +22,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore._checkparam import Validator
from mindspore.common.api import _cell_graph_executor
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.checkpoint_pb2 import Checkpoint
@ -238,7 +239,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
Returns:
Object, proto object.
"""
Validator.check_file_name_by_regular(file_name)
file_name = os.path.realpath(file_name)
if proto_format == "MINDIR":
model = mindir_model()
elif proto_format == "CKPT":
@ -253,7 +255,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
pb_content = f.read()
model.ParseFromString(pb_content)
except BaseException as e:
logger.critical("Failed to read the file `%s`, please check the correct of the file.", file_name)
logger.critical(f"Failed to phase the file: {file_name} as format: {proto_format},"
f" please check the correct file and format.")
raise ValueError(e.__str__())
finally:
pass

View File

@ -52,6 +52,7 @@ class History(Callback):
def __init__(self):
super(History, self).__init__()
self.history = {}
self.epoch = None
def begin(self, run_context):
"""

View File

@ -44,6 +44,7 @@ from mindspore.common.api import _cell_graph_executor as _executor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import One
from mindspore.communication.management import get_rank, get_group_size
from mindspore.compression.export import quant_export
from mindspore.parallel._cell_wrapper import get_allgather_cell
@ -51,8 +52,10 @@ from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_
from mindspore.parallel._tensor import _reshape_param_data
from mindspore.parallel._tensor import _reshape_param_data_with_weight
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
from mindspore.train._utils import read_proto
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
"Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
"Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
@ -62,6 +65,10 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
"Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
_ckpt_mutex = Lock()
# unit is KB
@ -162,7 +169,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
exist_ckpt_file_list = []
if os.path.exists(checkpoint_dir):
for exist_ckpt_name in os.listdir(checkpoint_dir):
file_prefix = model_name + "_iteration_"
file_prefix = os.path.join(model_name, "_iteration_")
if exist_ckpt_name.startswith(file_prefix):
exist_ckpt_file_list.append(exist_ckpt_name)
@ -1113,7 +1120,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
model.ParseFromString(mindir_stream)
if kwargs.get('dataset'):
check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)
check_input_data(kwargs.get('dataset'), data_class=mindspore.dataset.Dataset)
dataset = kwargs.get('dataset')
_save_dataset_to_mindir(model, dataset)
@ -1871,3 +1878,85 @@ def _calculation_net_size(net):
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
return data_total
def _get_mindir_inputs(file_name):
"""
Get MindIR file's inputs.
Note:
1. Parsing encrypted MindIR file is not supported.
2. Parsing dynamic shape MindIR file is not supported.
Args:
file_name (str): MindIR file name.
Returns:
Tensor, list(Tensor), the input of MindIR file.
Raises:
TypeError: If the parameter file_name is not `str`.
RuntimeError: MindIR's input is not tensor type or has no dims.
Examples:
>>> input_tensor = get_mindir_inputs("lenet.mindir")
"""
Validator.check_file_name_by_regular(file_name)
file_name = os.path.realpath(file_name)
model = read_proto(file_name)
input_tensor = []
for ele_input in model.graph.input:
input_shape = []
if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"):
raise RuntimeError("MindIR's inputs has no tensor or tensor has no dims, please check MindIR file.")
for ele_shape in ele_input.tensor[0].dims:
input_shape.append(ele_shape)
if -1 in input_shape:
raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.")
mindir_type = ele_input.tensor[0].data_type
if mindir_type not in mindir_to_tensor_type:
raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.")
input_type = mindir_to_tensor_type.get(mindir_type)
input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One()))
if not input_tensor:
logger.warning("The MindIR model has no input, return None.")
return None
return input_tensor[0] if len(input_tensor) == 1 else input_tensor
def convert_model(mindir_file, convert_file, file_format):
"""
Convert mindir model to other format model. Current version only support convert to "ONNX" format.
Note:
This is an experimental function that is subject to change or deletion.
Args:
mindir_file (str): MindIR file name.
convert_file (str): Convert model file name.
file_format (str): Convert model's format, current version only supports "ONNX".
Raises:
TypeError: If the parameter `mindir_file` is not `str`.
TypeError: If the parameter `convert_file` is not `str`.
ValueError: If the parameter `file_format` is not "ONNX".
Examples:
>>> convert_model("lenet.mindir", "lenet.onnx", "ONNX")
"""
Validator.check_file_name_by_regular(mindir_file)
Validator.check_file_name_by_regular(convert_file)
if file_format != "ONNX":
raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
net_input = _get_mindir_inputs(mindir_file)
graph = load(mindir_file)
net = nn.GraphCell(graph)
if isinstance(net_input, Tensor):
export(net, net_input, file_name=convert_file, file_format=file_format)
else:
export(net, *net_input, file_name=convert_file, file_format=file_format)

View File

@ -17,6 +17,7 @@ import os
from io import BytesIO
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision as CV
@ -29,7 +30,7 @@ from mindspore.common.initializer import TruncatedNormal
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.train.serialization import export
from mindspore.train.serialization import export, _get_mindir_inputs, convert_model
def weight_variable():
@ -108,6 +109,16 @@ class LeNet5(nn.Cell):
return x
class InputNet1(nn.Cell):
def construct(self, x):
return x
class InputNet2(nn.Cell):
def construct(self, x, y):
return x, y
class WithLossCell(nn.Cell):
def __init__(self, network):
super(WithLossCell, self).__init__(auto_prefix=False)
@ -160,6 +171,77 @@ def test_export_lenet_grad_mindir():
os.remove(verify_name)
def test_get_mindir_inputs1():
"""
Feature: Get MindIR input.
Description: Test get mindir input.
Expectation: Successfully
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = InputNet1()
input1 = Tensor(np.zeros([32, 10]).astype(np.float32))
file_name = "input1.mindir"
export(net, input1, file_name=file_name, file_format='MINDIR')
input_tensor = _get_mindir_inputs(file_name)
assert os.path.exists(file_name)
assert input_tensor.shape == (32, 10)
assert input_tensor.dtype == mindspore.float32
os.remove(file_name)
def test_get_mindir_inputs2():
"""
Feature: Get MindIR input.
Description: Test get mindir input.
Expectation: Successfully
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = InputNet2()
input1 = Tensor(np.zeros(1).astype(np.float16))
input2 = Tensor(np.zeros([10, 20]), dtype=mstype.int32)
file_name = "input2.mindir"
export(net, input1, input2, file_name=file_name, file_format='MINDIR')
input_tensor = _get_mindir_inputs(file_name)
assert os.path.exists(file_name)
assert len(input_tensor) == 2
assert input_tensor[0].shape == (1,)
assert input_tensor[0].dtype == mindspore.float16
assert input_tensor[1].shape == (10, 20)
assert input_tensor[1].dtype == mindspore.int32
os.remove(file_name)
def test_convert_model():
"""
Feature: Convert mindir to onnx.
Description: Test convert.
Expectation: Successfully
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net1 = InputNet1()
input1 = Tensor(np.ones([1, 32, 32]).astype(np.float32))
mindir_name1 = "lenet1.mindir"
export(net1, input1, file_name=mindir_name1, file_format='MINDIR')
onnx_name1 = "lenet1.onnx"
convert_model(mindir_name1, onnx_name1, "ONNX")
assert os.path.exists(mindir_name1)
assert os.path.exists(onnx_name1)
os.remove(mindir_name1)
os.remove(onnx_name1)
net2 = InputNet2()
input1 = Tensor(np.ones(32).astype(np.float32))
input2 = Tensor(np.ones([32, 32]).astype(np.float32))
mindir_name2 = "lenet2.mindir"
export(net2, input1, input2, file_name=mindir_name2, file_format='MINDIR')
onnx_name2 = "lenet2.onnx"
convert_model(mindir_name2, onnx_name2, "ONNX")
assert os.path.exists(mindir_name2)
assert os.path.exists(onnx_name2)
os.remove(mindir_name2)
os.remove(onnx_name2)
def test_export_lenet_with_dataset():
"""
Feature: Export LeNet with data preprocess to MindIR