modify export file name check

This commit is contained in:
changzherui 2021-01-25 19:40:02 +08:00
parent c5b9971427
commit f552d84a94
2 changed files with 35 additions and 31 deletions

View File

@ -431,7 +431,7 @@ class Validator:
def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
"""Check whether file name is legitimate."""
if reg is None:
reg = r"^[0-9a-zA-Z\_\-\.\/\\]+$"
reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$"
if re.match(reg, target, flag) is None:
prim_name = f'in `{prim_name}`' if prim_name else ""
raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format(

View File

@ -93,7 +93,7 @@ def _update_param(param, new_param):
if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
if param.data.shape != (1,) and param.data.shape != ():
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)."
msg = ("Net parameters {} shape({}) is not (1,), inconsistent with parameter_dict's(scalar)."
.format(param.name, param.data.shape))
raise RuntimeError(msg)
param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))
@ -244,31 +244,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
"""
if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.")
if not os.path.exists(ckpt_file_name):
raise ValueError("The checkpoint file is not exist.")
if ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.")
if os.path.getsize(ckpt_file_name) == 0:
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
if filter_prefix is not None:
if not isinstance(filter_prefix, (str, list, tuple)):
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
if isinstance(filter_prefix, str):
filter_prefix = (filter_prefix,)
if not filter_prefix:
raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
for index, prefix in enumerate(filter_prefix):
if not isinstance(prefix, str):
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
f"but got {str(type(prefix))} at index {index}.")
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
logger.info("Execute the process of loading checkpoint files.")
checkpoint_list = Checkpoint()
@ -297,7 +273,6 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
param_data = np.concatenate((param_data_list), axis=0)
param_data_list.clear()
dims = element.tensor.dims
if dims == [0]:
if 'Float' in data_type:
param_data = float(param_data[0])
@ -328,6 +303,32 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
return parameter_dict
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
"""Check function load_checkpoint's parameter."""
if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.")
if not os.path.exists(ckpt_file_name):
raise ValueError("The checkpoint file is not exist.")
if ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.")
if filter_prefix is not None:
if not isinstance(filter_prefix, (str, list, tuple)):
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
if isinstance(filter_prefix, str):
filter_prefix = (filter_prefix,)
if not filter_prefix:
raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
for index, prefix in enumerate(filter_prefix):
if not isinstance(prefix, str):
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
f"but got {str(type(prefix))} at index {index}.")
return ckpt_file_name, filter_prefix
def load_param_into_net(net, parameter_dict, strict_load=False):
"""
Loads parameters into network.
@ -560,13 +561,15 @@ def _export(net, file_name, file_format, *inputs):
if file_format == 'AIR':
phase_name = 'export.air'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
file_name += ".air"
if not file_name.endswith('.air'):
file_name += ".air"
_executor.export(file_name, graph_id)
elif file_format == 'ONNX':
phase_name = 'export.onnx'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(net, graph_id)
file_name += ".onnx"
if not file_name.endswith('.onnx'):
file_name += ".onnx"
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream)
@ -574,7 +577,8 @@ def _export(net, file_name, file_format, *inputs):
phase_name = 'export.mindir'
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
file_name += ".mindir"
if not file_name.endswith('.mindir'):
file_name += ".mindir"
with open(file_name, 'wb') as f:
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(onnx_stream)