modify convert model

This commit is contained in:
changzherui 2022-06-08 21:44:05 +08:00
parent 4f679c0cb7
commit 35f72a0cd6
2 changed files with 9 additions and 7 deletions

View File

@ -22,6 +22,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype, pytype_to_dtype
from mindspore.common import dtype as mstype
from mindspore import log as logger
from mindspore._checkparam import Validator
from mindspore.common.api import _cell_graph_executor
from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
from mindspore.train.checkpoint_pb2 import Checkpoint
@ -236,7 +237,8 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
Returns:
Object, proto object.
"""
Validator.check_file_name_by_regular(file_name)
file_name = os.path.realpath(file_name)
if proto_format == "MINDIR":
model = mindir_model()
elif proto_format == "CKPT":
@ -251,7 +253,9 @@ def read_proto(file_name, proto_format="MINDIR", display_data=False):
pb_content = f.read()
model.ParseFromString(pb_content)
except BaseException as e:
logger.critical("Failed to read the file `%s`, please check the correct of the file.", file_name)
logger.critical(f"Failed to phase the file: {file_name} as format: {proto_format},"
f" please check the correct file and format.")
raise ValueError(e.__str__())
finally:
pass

View File

@ -1856,7 +1856,7 @@ def convert_model(mindir_file, convert_file, file_format):
Args:
mindir_file (str): MindIR file name.
convert_file (str): Convert model file name.
file_format (str): Convert model's format, current version only supports 'ONNX'.
file_format (str): Convert model's format, current version only supports "ONNX".
Raises:
ValueError: If the parameter `mindir_file` is not `str`.
@ -1868,10 +1868,8 @@ def convert_model(mindir_file, convert_file, file_format):
"""
Validator.check_file_name_by_regular(mindir_file)
Validator.check_file_name_by_regular(convert_file)
supported_formats = ["ONNX"]
if file_format not in supported_formats:
raise ValueError(f"For 'convert_model', 'file_format' must be one of {supported_formats},"
f"but got {file_format}.")
if file_format != "ONNX":
raise ValueError(f"For 'convert_model', 'file_format' must be 'ONNX', but got {file_format}.")
net_input = _get_mindir_inputs(mindir_file)
graph = load(mindir_file)
net = nn.GraphCell(graph)