forked from mindspore-Ecosystem/mindspore
!35394 add get_mindir_input api
Merge pull request !35394 from changzherui/add_get_mindir_input
This commit is contained in:
commit
58efca88db
|
@ -115,17 +115,18 @@ mindspore
|
|||
.. mscnautosummary::
|
||||
:toctree: mindspore
|
||||
|
||||
mindspore.save_checkpoint
|
||||
mindspore.load_checkpoint
|
||||
mindspore.load_param_into_net
|
||||
mindspore.async_ckpt_thread_status
|
||||
mindspore.build_searched_strategy
|
||||
mindspore.convert_model
|
||||
mindspore.export
|
||||
mindspore.load
|
||||
mindspore.parse_print
|
||||
mindspore.build_searched_strategy
|
||||
mindspore.merge_sliced_parameter
|
||||
mindspore.load_checkpoint
|
||||
mindspore.load_distributed_checkpoint
|
||||
mindspore.async_ckpt_thread_status
|
||||
mindspore.load_param_into_net
|
||||
mindspore.merge_sliced_parameter
|
||||
mindspore.parse_print
|
||||
mindspore.restore_group_info_list
|
||||
mindspore.save_checkpoint
|
||||
|
||||
调试调优
|
||||
----------
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.conver_model
|
||||
======================
|
||||
|
||||
.. py:class:: mindspore.convert_model(mindir_file, convert_file, file_format)
|
||||
|
||||
将MindIR模型转化为其他格式的模型文件。当前版本仅支持转化成ONNX模型。
|
||||
|
||||
.. note::
|
||||
这是一个实验特性,未来API可能会发生的变化。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **mindir_file** (str) - MindIR模型文件名称。
|
||||
- **convert_file** (str) - 转化后的模型文件名称。
|
||||
- **file_format** (str) - 需要转化的文件格式,当前版本仅支持"ONNX"。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - `mindir_file` 参数不是str类型。
|
||||
- **TypeError** - `convert_file` 参数不是str类型。
|
||||
- **ValueError** - `file_format` 参数的值不是"ONNX"。
|
|
@ -236,17 +236,18 @@ Serialization
|
|||
:nosignatures:
|
||||
:template: classtemplate.rst
|
||||
|
||||
mindspore.save_checkpoint
|
||||
mindspore.load_checkpoint
|
||||
mindspore.load_param_into_net
|
||||
mindspore.async_ckpt_thread_status
|
||||
mindspore.build_searched_strategy
|
||||
mindspore.convert_model
|
||||
mindspore.export
|
||||
mindspore.load
|
||||
mindspore.parse_print
|
||||
mindspore.build_searched_strategy
|
||||
mindspore.merge_sliced_parameter
|
||||
mindspore.load_checkpoint
|
||||
mindspore.load_distributed_checkpoint
|
||||
mindspore.async_ckpt_thread_status
|
||||
mindspore.load_param_into_net
|
||||
mindspore.merge_sliced_parameter
|
||||
mindspore.parse_print
|
||||
mindspore.restore_group_info_list
|
||||
mindspore.save_checkpoint
|
||||
|
||||
JIT
|
||||
---
|
||||
|
|
|
@ -483,7 +483,7 @@ class Validator:
|
|||
"""Check whether file name is legitimate."""
|
||||
if not isinstance(target, str):
|
||||
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
||||
raise ValueError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
|
||||
raise TypeError("{} '{}' must be string, but got {}.".format(prim_name, target, type(target)))
|
||||
if target.endswith("\\") or target.endswith("/"):
|
||||
prim_name = f"For '{prim_name}', the" if prim_name else "The"
|
||||
raise ValueError(f"{prim_name} '{target}' cannot be a directory path.")
|
||||
|
|
|
@ -24,7 +24,7 @@ from .amp import build_train_network
|
|||
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
|
||||
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
|
||||
build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint, async_ckpt_thread_status,\
|
||||
restore_group_info_list
|
||||
restore_group_info_list, convert_model
|
||||
from .callback import Callback, LossMonitor, TimeMonitor, ModelCheckpoint, SummaryCollector, CheckpointConfig, \
|
||||
RunContext, LearningRateScheduler, SummaryLandscape, FederatedLearningManager, History, LambdaCallback, \
|
||||
ReduceLROnPlateau, EarlyStopping
|
||||
|
@ -35,7 +35,7 @@ from .train_thor import ConvertNetUtils, ConvertModelUtils
|
|||
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
|
||||
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
|
||||
"load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
|
||||
"load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list"]
|
||||
"load_distributed_checkpoint", "async_ckpt_thread_status", "restore_group_info_list", "convert_model"]
|
||||
__all__.extend(callback.__all__)
|
||||
__all__.extend(summary.__all__)
|
||||
__all__.extend(train_thor.__all__)
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
|
||||
from mindspore.train.checkpoint_pb2 import Checkpoint
|
||||
|
@ -238,7 +239,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
|
|||
Returns:
|
||||
Object, proto object.
|
||||
"""
|
||||
|
||||
Validator.check_file_name_by_regular(file_name)
|
||||
file_name = os.path.realpath(file_name)
|
||||
if proto_format == "MINDIR":
|
||||
model = mindir_model()
|
||||
elif proto_format == "CKPT":
|
||||
|
@ -253,7 +255,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
|
|||
pb_content = f.read()
|
||||
model.ParseFromString(pb_content)
|
||||
except BaseException as e:
|
||||
logger.critical("Failed to read the file `%s`, please check the correct of the file.", file_name)
|
||||
logger.critical(f"Failed to phase the file: {file_name} as format: {proto_format},"
|
||||
f" please check the correct file and format.")
|
||||
raise ValueError(e.__str__())
|
||||
finally:
|
||||
pass
|
||||
|
|
|
@ -52,6 +52,7 @@ class History(Callback):
|
|||
def __init__(self):
|
||||
super(History, self).__init__()
|
||||
self.history = {}
|
||||
self.epoch = None
|
||||
|
||||
def begin(self, run_context):
|
||||
"""
|
||||
|
|
|
@ -44,6 +44,7 @@ from mindspore.common.api import _cell_graph_executor as _executor
|
|||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.initializer import One
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
from mindspore.compression.export import quant_export
|
||||
from mindspore.parallel._cell_wrapper import get_allgather_cell
|
||||
|
@ -51,8 +52,10 @@ from mindspore.parallel._tensor import _load_tensor, _get_tensor_strategy, _get_
|
|||
from mindspore.parallel._tensor import _reshape_param_data
|
||||
from mindspore.parallel._tensor import _reshape_param_data_with_weight
|
||||
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
|
||||
from mindspore.train._utils import read_proto
|
||||
from .._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file
|
||||
|
||||
|
||||
tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype.int16, "UInt16": mstype.uint16,
|
||||
"Int32": mstype.int32, "UInt32": mstype.uint32, "Int64": mstype.int64, "UInt64": mstype.uint64,
|
||||
"Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
|
||||
|
@ -62,6 +65,10 @@ tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UIn
|
|||
"Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
|
||||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}
|
||||
|
||||
mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4: mstype.uint16,
|
||||
5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
|
||||
11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}
|
||||
|
||||
_ckpt_mutex = Lock()
|
||||
|
||||
# unit is KB
|
||||
|
@ -162,7 +169,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):
|
|||
exist_ckpt_file_list = []
|
||||
if os.path.exists(checkpoint_dir):
|
||||
for exist_ckpt_name in os.listdir(checkpoint_dir):
|
||||
file_prefix = model_name + "_iteration_"
|
||||
file_prefix = os.path.join(model_name, "_iteration_")
|
||||
if exist_ckpt_name.startswith(file_prefix):
|
||||
exist_ckpt_file_list.append(exist_ckpt_name)
|
||||
|
||||
|
@ -1113,7 +1120,7 @@ def _save_mindir(net, file_name, *inputs, **kwargs):
|
|||
model.ParseFromString(mindir_stream)
|
||||
|
||||
if kwargs.get('dataset'):
|
||||
check_input_data(kwargs['dataset'], data_class=mindspore.dataset.Dataset)
|
||||
check_input_data(kwargs.get('dataset'), data_class=mindspore.dataset.Dataset)
|
||||
dataset = kwargs.get('dataset')
|
||||
_save_dataset_to_mindir(model, dataset)
|
||||
|
||||
|
@ -1871,3 +1878,85 @@ def _calculation_net_size(net):
|
|||
data_total += sys.getsizeof(net_dict[name].data.asnumpy().tobytes()) / 1024
|
||||
|
||||
return data_total
|
||||
|
||||
|
||||
def _get_mindir_inputs(file_name):
|
||||
"""
|
||||
Get MindIR file's inputs.
|
||||
|
||||
Note:
|
||||
1. Parsing encrypted MindIR file is not supported.
|
||||
2. Parsing dynamic shape MindIR file is not supported.
|
||||
|
||||
Args:
|
||||
file_name (str): MindIR file name.
|
||||
|
||||
Returns:
|
||||
Tensor, list(Tensor), the input of MindIR file.
|
||||
|
||||
Raises:
|
||||
TypeError: If the parameter file_name is not `str`.
|
||||
RuntimeError: MindIR's input is not tensor type or has no dims.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = get_mindir_inputs("lenet.mindir")
|
||||
"""
|
||||
Validator.check_file_name_by_regular(file_name)
|
||||
file_name = os.path.realpath(file_name)
|
||||
model = read_proto(file_name)
|
||||
input_tensor = []
|
||||
|
||||
for ele_input in model.graph.input:
|
||||
input_shape = []
|
||||
if not hasattr(ele_input, "tensor") or not hasattr(ele_input.tensor[0], "dims"):
|
||||
raise RuntimeError("MindIR's inputs has no tensor or tensor has no dims, please check MindIR file.")
|
||||
|
||||
for ele_shape in ele_input.tensor[0].dims:
|
||||
input_shape.append(ele_shape)
|
||||
if -1 in input_shape:
|
||||
raise RuntimeError(f"MindIR input's shape is: {input_shape}, dynamic shape is not supported.")
|
||||
|
||||
mindir_type = ele_input.tensor[0].data_type
|
||||
if mindir_type not in mindir_to_tensor_type:
|
||||
raise RuntimeError(f"MindIR input's type: {mindir_type} is not supported.")
|
||||
|
||||
input_type = mindir_to_tensor_type.get(mindir_type)
|
||||
input_tensor.append(Tensor(shape=input_shape, dtype=input_type, init=One()))
|
||||
|
||||
if not input_tensor:
|
||||
logger.warning("The MindIR model has no input, return None.")
|
||||
return None
|
||||
return input_tensor[0] if len(input_tensor) == 1 else input_tensor
|
||||
|
||||
|
||||
def convert_model(mindir_file, convert_file, file_format):
|
||||
"""
|
||||
Convert mindir model to other format model. Current version only support convert to "ONNX" format.
|
||||
|
||||
Note:
|
||||
This is an experimental function that is subject to change or deletion.
|
||||
|
||||
Args:
|
||||
mindir_file (str): MindIR file name.
|
||||
convert_file (str): Convert model file name.
|
||||
file_format (str): Convert model's format, current version only supports "ONNX".
|
||||
|
||||
Raises:
|
||||
TypeError: If the parameter `mindir_file` is not `str`.
|
||||
TypeError: If the parameter `convert_file` is not `str`.
|
||||
ValueError: If the parameter `file_format` is not "ONNX".
|
||||
|
||||
Examples:
|
||||
>>> convert_model("lenet.mindir", "lenet.onnx", "ONNX")
|
||||
"""
|
||||
Validator.check_file_name_by_regular(mindir_file)
|
||||
Validator.check_file_name_by_regular(convert_file)
|
||||
if file_format != "ONNX":
|
||||
raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
|
||||
net_input = _get_mindir_inputs(mindir_file)
|
||||
graph = load(mindir_file)
|
||||
net = nn.GraphCell(graph)
|
||||
if isinstance(net_input, Tensor):
|
||||
export(net, net_input, file_name=convert_file, file_format=file_format)
|
||||
else:
|
||||
export(net, *net_input, file_name=convert_file, file_format=file_format)
|
||||
|
|
|
@ -17,6 +17,7 @@ import os
|
|||
from io import BytesIO
|
||||
import numpy as np
|
||||
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision as CV
|
||||
|
@ -29,7 +30,7 @@ from mindspore.common.initializer import TruncatedNormal
|
|||
from mindspore.common.parameter import ParameterTuple
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.train.serialization import export
|
||||
from mindspore.train.serialization import export, _get_mindir_inputs, convert_model
|
||||
|
||||
|
||||
def weight_variable():
|
||||
|
@ -108,6 +109,16 @@ class LeNet5(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
class InputNet1(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class InputNet2(nn.Cell):
|
||||
def construct(self, x, y):
|
||||
return x, y
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
|
@ -160,6 +171,77 @@ def test_export_lenet_grad_mindir():
|
|||
os.remove(verify_name)
|
||||
|
||||
|
||||
def test_get_mindir_inputs1():
|
||||
"""
|
||||
Feature: Get MindIR input.
|
||||
Description: Test get mindir input.
|
||||
Expectation: Successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = InputNet1()
|
||||
input1 = Tensor(np.zeros([32, 10]).astype(np.float32))
|
||||
file_name = "input1.mindir"
|
||||
export(net, input1, file_name=file_name, file_format='MINDIR')
|
||||
input_tensor = _get_mindir_inputs(file_name)
|
||||
assert os.path.exists(file_name)
|
||||
assert input_tensor.shape == (32, 10)
|
||||
assert input_tensor.dtype == mindspore.float32
|
||||
os.remove(file_name)
|
||||
|
||||
|
||||
def test_get_mindir_inputs2():
|
||||
"""
|
||||
Feature: Get MindIR input.
|
||||
Description: Test get mindir input.
|
||||
Expectation: Successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net = InputNet2()
|
||||
input1 = Tensor(np.zeros(1).astype(np.float16))
|
||||
input2 = Tensor(np.zeros([10, 20]), dtype=mstype.int32)
|
||||
file_name = "input2.mindir"
|
||||
export(net, input1, input2, file_name=file_name, file_format='MINDIR')
|
||||
input_tensor = _get_mindir_inputs(file_name)
|
||||
assert os.path.exists(file_name)
|
||||
assert len(input_tensor) == 2
|
||||
assert input_tensor[0].shape == (1,)
|
||||
assert input_tensor[0].dtype == mindspore.float16
|
||||
assert input_tensor[1].shape == (10, 20)
|
||||
assert input_tensor[1].dtype == mindspore.int32
|
||||
os.remove(file_name)
|
||||
|
||||
|
||||
def test_convert_model():
|
||||
"""
|
||||
Feature: Convert mindir to onnx.
|
||||
Description: Test convert.
|
||||
Expectation: Successfully
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
net1 = InputNet1()
|
||||
input1 = Tensor(np.ones([1, 32, 32]).astype(np.float32))
|
||||
mindir_name1 = "lenet1.mindir"
|
||||
export(net1, input1, file_name=mindir_name1, file_format='MINDIR')
|
||||
onnx_name1 = "lenet1.onnx"
|
||||
convert_model(mindir_name1, onnx_name1, "ONNX")
|
||||
assert os.path.exists(mindir_name1)
|
||||
assert os.path.exists(onnx_name1)
|
||||
os.remove(mindir_name1)
|
||||
os.remove(onnx_name1)
|
||||
|
||||
net2 = InputNet2()
|
||||
input1 = Tensor(np.ones(32).astype(np.float32))
|
||||
input2 = Tensor(np.ones([32, 32]).astype(np.float32))
|
||||
mindir_name2 = "lenet2.mindir"
|
||||
export(net2, input1, input2, file_name=mindir_name2, file_format='MINDIR')
|
||||
onnx_name2 = "lenet2.onnx"
|
||||
convert_model(mindir_name2, onnx_name2, "ONNX")
|
||||
assert os.path.exists(mindir_name2)
|
||||
assert os.path.exists(onnx_name2)
|
||||
os.remove(mindir_name2)
|
||||
os.remove(onnx_name2)
|
||||
|
||||
|
||||
def test_export_lenet_with_dataset():
|
||||
"""
|
||||
Feature: Export LeNet with data preprocess to MindIR
|
||||
|
|
Loading…
Reference in New Issue