forked from mindspore-Ecosystem/mindspore
modify export file name check
This commit is contained in:
parent
c5b9971427
commit
f552d84a94
|
@ -431,7 +431,7 @@ class Validator:
|
|||
def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
|
||||
"""Check whether file name is legitimate."""
|
||||
if reg is None:
|
||||
reg = r"^[0-9a-zA-Z\_\-\.\/\\]+$"
|
||||
reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
|
||||
if re.match(reg, target, flag) is None:
|
||||
prim_name = f'in `{prim_name}`' if prim_name else ""
|
||||
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(
|
||||
|
|
|
@ -93,7 +93,7 @@ def _update_param(param, new_param):
|
|||
if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
|
||||
if param.data.shape != (1,) and param.data.shape != ():
|
||||
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
|
||||
msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)."
|
||||
msg = ("Net parameters {} shape({}) is not (1,), inconsistent with parameter_dict's(scalar)."
|
||||
.format(param.name, param.data.shape))
|
||||
raise RuntimeError(msg)
|
||||
param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))
|
||||
|
@ -244,31 +244,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
||||
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
||||
"""
|
||||
if not isinstance(ckpt_file_name, str):
|
||||
raise ValueError("The ckpt_file_name must be string.")
|
||||
|
||||
if not os.path.exists(ckpt_file_name):
|
||||
raise ValueError("The checkpoint file is not exist.")
|
||||
|
||||
if ckpt_file_name[-5:] != ".ckpt":
|
||||
raise ValueError("Please input the correct checkpoint file name.")
|
||||
|
||||
if os.path.getsize(ckpt_file_name) == 0:
|
||||
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
|
||||
|
||||
if filter_prefix is not None:
|
||||
if not isinstance(filter_prefix, (str, list, tuple)):
|
||||
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
|
||||
f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
|
||||
if isinstance(filter_prefix, str):
|
||||
filter_prefix = (filter_prefix,)
|
||||
if not filter_prefix:
|
||||
raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
|
||||
for index, prefix in enumerate(filter_prefix):
|
||||
if not isinstance(prefix, str):
|
||||
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
|
||||
f"but got {str(type(prefix))} at index {index}.")
|
||||
|
||||
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
|
||||
logger.info("Execute the process of loading checkpoint files.")
|
||||
checkpoint_list = Checkpoint()
|
||||
|
||||
|
@ -297,7 +273,6 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
param_data = np.concatenate((param_data_list), axis=0)
|
||||
param_data_list.clear()
|
||||
dims = element.tensor.dims
|
||||
|
||||
if dims == [0]:
|
||||
if 'Float' in data_type:
|
||||
param_data = float(param_data[0])
|
||||
|
@ -328,6 +303,32 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
return parameter_dict
|
||||
|
||||
|
||||
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
||||
"""Check function load_checkpoint's parameter."""
|
||||
if not isinstance(ckpt_file_name, str):
|
||||
raise ValueError("The ckpt_file_name must be string.")
|
||||
|
||||
if not os.path.exists(ckpt_file_name):
|
||||
raise ValueError("The checkpoint file is not exist.")
|
||||
|
||||
if ckpt_file_name[-5:] != ".ckpt":
|
||||
raise ValueError("Please input the correct checkpoint file name.")
|
||||
|
||||
if filter_prefix is not None:
|
||||
if not isinstance(filter_prefix, (str, list, tuple)):
|
||||
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
|
||||
f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
|
||||
if isinstance(filter_prefix, str):
|
||||
filter_prefix = (filter_prefix,)
|
||||
if not filter_prefix:
|
||||
raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
|
||||
for index, prefix in enumerate(filter_prefix):
|
||||
if not isinstance(prefix, str):
|
||||
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
|
||||
f"but got {str(type(prefix))} at index {index}.")
|
||||
return ckpt_file_name, filter_prefix
|
||||
|
||||
|
||||
def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||
"""
|
||||
Loads parameters into network.
|
||||
|
@ -560,13 +561,15 @@ def _export(net, file_name, file_format, *inputs):
|
|||
if file_format == 'AIR':
|
||||
phase_name = 'export.air'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
||||
file_name += ".air"
|
||||
if not file_name.endswith('.air'):
|
||||
file_name += ".air"
|
||||
_executor.export(file_name, graph_id)
|
||||
elif file_format == 'ONNX':
|
||||
phase_name = 'export.onnx'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||
onnx_stream = _executor._get_func_graph_proto(net, graph_id)
|
||||
file_name += ".onnx"
|
||||
if not file_name.endswith('.onnx'):
|
||||
file_name += ".onnx"
|
||||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
f.write(onnx_stream)
|
||||
|
@ -574,7 +577,8 @@ def _export(net, file_name, file_format, *inputs):
|
|||
phase_name = 'export.mindir'
|
||||
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
||||
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
||||
file_name += ".mindir"
|
||||
if not file_name.endswith('.mindir'):
|
||||
file_name += ".mindir"
|
||||
with open(file_name, 'wb') as f:
|
||||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
f.write(onnx_stream)
|
||||
|
|
Loading…
Reference in New Issue