modify convert model
This commit is contained in:
parent
4f679c0cb7
commit
35f72a0cd6
|
@ -22,6 +22,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import Validator
|
||||
from mindspore.common.api import _cell_graph_executor
|
||||
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
|
||||
from mindspore.train.checkpoint_pb2 import Checkpoint
|
||||
|
@ -236,7 +237,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
|
|||
Returns:
|
||||
Object, proto object.
|
||||
"""
|
||||
|
||||
Validator.check_file_name_by_regular(file_name)
|
||||
file_name = os.path.realpath(file_name)
|
||||
if proto_format == "MINDIR":
|
||||
model = mindir_model()
|
||||
elif proto_format == "CKPT":
|
||||
|
@ -251,7 +253,9 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
|
|||
pb_content = f.read()
|
||||
model.ParseFromString(pb_content)
|
||||
except BaseException as e:
|
||||
logger.critical("Failed to read the file `%s`, please check the correct of the file.", file_name)
|
||||
|
||||
logger.critical(f"Failed to phase the file: {file_name} as format: {proto_format},"
|
||||
f" please check the correct file and format.")
|
||||
raise ValueError(e.__str__())
|
||||
finally:
|
||||
pass
|
||||
|
|
|
@ -1856,7 +1856,7 @@ def convert_model(mindir_file, convert_file, file_format):
|
|||
Args:
|
||||
mindir_file (str): MindIR file name.
|
||||
convert_file (str): Convert model file name.
|
||||
file_format (str): Convert model's format, current version only supports 'ONNX'.
|
||||
file_format (str): Convert model's format, current version only supports "ONNX".
|
||||
|
||||
Raises:
|
||||
ValueError: If the parameter `mindir_file` is not `str`.
|
||||
|
@ -1868,10 +1868,8 @@ def convert_model(mindir_file, convert_file, file_format):
|
|||
"""
|
||||
Validator.check_file_name_by_regular(mindir_file)
|
||||
Validator.check_file_name_by_regular(convert_file)
|
||||
supported_formats = ["ONNX"]
|
||||
if file_format not in supported_formats:
|
||||
raise ValueError(f"For 'convert_model', 'file_format' must be one of {supported_formats},"
|
||||
f"but got {file_format}.")
|
||||
if file_format != "ONNX":
|
||||
raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
|
||||
net_input = _get_mindir_inputs(mindir_file)
|
||||
graph = load(mindir_file)
|
||||
net = nn.GraphCell(graph)
|
||||
|
|
Loading…
Reference in New Issue