diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 028084ccca0..8d0f6b72bf3 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -431,7 +431,7 @@ class Validator: def check_file_name_by_regular(target, reg=None, flag=re.ASCII, prim_name=None): """Check whether file name is legitimate.""" if reg is None: - reg = r"^[0-9a-zA-Z\_\-\.\/\\]+$" + reg = r"^[0-9a-zA-Z\_\-\.\:\/\\]+$" if re.match(reg, target, flag) is None: prim_name = f'in `{prim_name}`' if prim_name else "" raise ValueError("'{}' {} is illegal, it should be match regular'{}' by flags'{}'".format( diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 3a3f878f267..cb2fccf265e 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -93,7 +93,7 @@ def _update_param(param, new_param): if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor): if param.data.shape != (1,) and param.data.shape != (): logger.error("Failed to combine the net and the parameters for param %s.", param.name) - msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)." + msg = ("Net parameters {} shape({}) is not (1,), inconsistent with parameter_dict's(scalar)." .format(param.name, param.data.shape)) raise RuntimeError(msg) param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype)) @@ -244,31 +244,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") """ - if not isinstance(ckpt_file_name, str): - raise ValueError("The ckpt_file_name must be string.") - - if not os.path.exists(ckpt_file_name): - raise ValueError("The checkpoint file is not exist.") - - if ckpt_file_name[-5:] != ".ckpt": - raise ValueError("Please input the correct checkpoint file name.") - - if os.path.getsize(ckpt_file_name) == 0: - raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.") - - if filter_prefix is not None: - if not isinstance(filter_prefix, (str, list, tuple)): - raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] " - f"when filter_prefix is not None, but got {str(type(filter_prefix))}.") - if isinstance(filter_prefix, str): - filter_prefix = (filter_prefix,) - if not filter_prefix: - raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.") - for index, prefix in enumerate(filter_prefix): - if not isinstance(prefix, str): - raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], " - f"but got {str(type(prefix))} at index {index}.") - + ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix) logger.info("Execute the process of loading checkpoint files.") checkpoint_list = Checkpoint() @@ -297,7 +273,6 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N param_data = np.concatenate((param_data_list), axis=0) param_data_list.clear() dims = element.tensor.dims - if dims == [0]: if 'Float' in data_type: param_data = float(param_data[0]) @@ -328,6 +303,32 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N return parameter_dict +def _check_checkpoint_param(ckpt_file_name, filter_prefix=None): + """Check function load_checkpoint's parameter.""" + if not isinstance(ckpt_file_name, str): + raise ValueError("The ckpt_file_name must be string.") + + if not os.path.exists(ckpt_file_name): + raise ValueError("The checkpoint file is not exist.") + + if ckpt_file_name[-5:] != ".ckpt": + raise ValueError("Please input the correct checkpoint file name.") + + if filter_prefix is not None: + if not isinstance(filter_prefix, (str, list, tuple)): + raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] " + f"when filter_prefix is not None, but got {str(type(filter_prefix))}.") + if isinstance(filter_prefix, str): + filter_prefix = (filter_prefix,) + if not filter_prefix: + raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.") + for index, prefix in enumerate(filter_prefix): + if not isinstance(prefix, str): + raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], " + f"but got {str(type(prefix))} at index {index}.") + return ckpt_file_name, filter_prefix + + def load_param_into_net(net, parameter_dict, strict_load=False): """ Loads parameters into network. @@ -560,13 +561,15 @@ def _export(net, file_name, file_format, *inputs): if file_format == 'AIR': phase_name = 'export.air' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name) - file_name += ".air" + if not file_name.endswith('.air'): + file_name += ".air" _executor.export(file_name, graph_id) elif file_format == 'ONNX': phase_name = 'export.onnx' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(net, graph_id) - file_name += ".onnx" + if not file_name.endswith('.onnx'): + file_name += ".onnx" with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream) @@ -574,7 +577,8 @@ def _export(net, file_name, file_format, *inputs): phase_name = 'export.mindir' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') - file_name += ".mindir" + if not file_name.endswith('.mindir'): + file_name += ".mindir" with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) f.write(onnx_stream)