From deacc65ffdd043c3500cf6d4694c4e3617136b0c Mon Sep 17 00:00:00 2001 From: changzherui Date: Fri, 29 Apr 2022 00:14:30 +0800 Subject: [PATCH] add specify_prefix for load_checkpoint --- .../mindspore/mindspore.load_checkpoint.rst | 9 +- .../python/mindspore/train/serialization.py | 157 ++++++++++-------- tests/ut/python/utils/test_serialize.py | 106 +++++++++++- 3 files changed, 204 insertions(+), 68 deletions(-) diff --git a/docs/api/api_python/mindspore/mindspore.load_checkpoint.rst b/docs/api/api_python/mindspore/mindspore.load_checkpoint.rst index 47f53574221..80460c73220 100644 --- a/docs/api/api_python/mindspore/mindspore.load_checkpoint.rst +++ b/docs/api/api_python/mindspore/mindspore.load_checkpoint.rst @@ -1,10 +1,14 @@ mindspore.load_checkpoint ========================== -.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM") +.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM", specify_prefix=None) 加载checkpoint文件。 + .. note:: + - `specify_prefix` 和 `filter_prefix`的功能相互之间没有影响。 + - 如果发现没有参数被成功加载,将会报ValueError. + **参数:** - **ckpt_file_name** (str) – checkpoint的文件名称。 @@ -13,6 +17,7 @@ mindspore.load_checkpoint - **filter_prefix** (Union[str, list[str], tuple[str]]) – 以 `filter_prefix` 开头的参数将不会被加载。默认值:None。 - **dec_key** (Union[None, bytes]) – 用于解密的字节类型密钥,如果值为None,则不需要解密。默认值:None。 - **dec_mode** (str) – 该参数仅当 `dec_key` 不为None时有效。指定解密模式,目前支持“AES-GCM”和“AES-CBC”。默认值:“AES-GCM”。 + - **specify_prefix** (Union[str, list[str], tuple[str]]) – 以 `specify_prefix` 开头的参数将会被加载。默认值:None。 **返回:** @@ -21,3 +26,5 @@ mindspore.load_checkpoint **异常:** - **ValueError** – checkpoint文件格式不正确。 + - **ValueError** – 没有一个参数被成功加载。 + - **ValueError** `specify_prefix` 或者 `filter_prefix` 的数据类型不正确。 diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 5566ba5581f..3943a5615d6 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -347,15 +347,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, logger.info("Saving checkpoint process is finished.") -def _check_param_prefix(filter_prefix, param_name): - """Checks whether the prefix of parameter name matches the given filter_prefix.""" - for prefix in filter_prefix: - if param_name.find(prefix) == 0 \ - and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")): - return True - return False - - def _check_append_dict(append_dict): """Check the argument append_dict for save_checkpoint.""" if append_dict is None: @@ -437,10 +428,15 @@ def load(file_name, **kwargs): return graph -def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"): +def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, + dec_key=None, dec_mode="AES-GCM", specify_prefix=None): """ Load checkpoint info from a specified file. + Note: + 1. `specify_prefix` and `filter_prefix` do not affect each other. + 2. If none of the parameters are loaded from checkpoint file, it will throw ValueError. + Args: ckpt_file_name (str): Checkpoint file name. net (Cell): The network where the parameters will be loaded. Default: None @@ -454,51 +450,38 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N is not required. Default: None. dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. + specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the specify_prefix + will be loaded. Default: None. Returns: Dict, key is parameter name, value is a Parameter. Raises: ValueError: Checkpoint file's format is incorrect. + ValueError: Parameter's dict is None after load checkpoint file. + ValueError: The type of `specify_prefix` or `filter_prefix` is incorrect. Examples: >>> from mindspore import load_checkpoint >>> >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" - >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") + >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1", specify_prefix="conv", ) >>> print(param_dict["conv2.weight"]) Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True) """ - ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix) + ckpt_file_name = _check_ckpt_file_name(ckpt_file_name) + specify_prefix = _check_prefix(specify_prefix) + filter_prefix = _check_prefix(filter_prefix) dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) logger.info("Execute the process of loading checkpoint files.") - checkpoint_list = Checkpoint() - - try: - if dec_key is None: - with open(ckpt_file_name, "rb") as f: - pb_content = f.read() - else: - pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) - if pb_content is None: - raise ValueError("For 'load_checkpoint', Failed to decrypt the checkpoint file.") - checkpoint_list.ParseFromString(pb_content) - except BaseException as e: - if _is_cipher_file(ckpt_file_name): - logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the " - "correct 'dec_key'.", ckpt_file_name) - else: - logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please " - "check the correct of the file.", ckpt_file_name) - raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have " - "permission to read it.".format(ckpt_file_name)) + checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode) parameter_dict = {} try: param_data_list = [] for element_id, element in enumerate(checkpoint_list.value): - if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag): + if not _whether_load_param(specify_prefix, filter_prefix, element.tag): continue data = element.tensor.tensor_content data_type = element.tensor.tensor_type @@ -516,15 +499,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N param_data = float(param_data[0]) elif 'Int' in data_type: param_data = int(param_data[0]) - parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) - elif dims == [1]: - parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) - else: - param_dim = [] - for dim in dims: - param_dim.append(dim) - param_value = param_data.reshape(param_dim) - parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) + elif dims != [1]: + param_data = param_data.reshape(list(dims)) + + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) logger.info("Loading checkpoint files process is finished.") @@ -543,8 +521,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N return parameter_dict -def _check_checkpoint_param(ckpt_file_name, filter_prefix=None): - """Check function load_checkpoint's parameter.""" +def _check_ckpt_file_name(ckpt_file_name): + """Check function load_checkpoint's cket_file_name.""" if not isinstance(ckpt_file_name, str): raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, " "but got {}.".format(type(ckpt_file_name))) @@ -558,22 +536,71 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None): raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check " "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name)) - if filter_prefix is not None: - if not isinstance(filter_prefix, (str, list, tuple)): - raise TypeError("For 'load_checkpoint', the type of 'filter_prefix' must be string, " - "list[string] or tuple[string] when 'filter_prefix' is not None, but " - f"got {str(type(filter_prefix))}.") - if isinstance(filter_prefix, str): - filter_prefix = (filter_prefix,) - if not filter_prefix: - raise ValueError("For 'load_checkpoint', the argument 'filter_prefix' can't be empty when " - "'filter_prefix' is list or tuple.") - for index, prefix in enumerate(filter_prefix): - if not isinstance(prefix, str): - raise TypeError("For 'load_checkpoint', when 'filter_prefix' is list or tuple, " - "the element in 'filter_prefix' must be string, but got " - f"{str(type(prefix))} at index {index}.") - return ckpt_file_name, filter_prefix + return ckpt_file_name + + +def _check_prefix(prefix): + """Check the correctness of the parameters.""" + if prefix is None: + return prefix + if not isinstance(prefix, (str, list, tuple)): + raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, " + "list[string] or tuple[string], but got {}.".format(str(type(prefix)))) + if isinstance(prefix, str): + prefix = (prefix,) + if not prefix: + raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when" + " 'filter_prefix' is list or tuple.") + for index, pre in enumerate(prefix): + if not isinstance(pre, str): + raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, " + "the element in it must be string, but got " + f"{str(type(pre))} at index {index}.") + if pre == "": + raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' " + "can't include ''.") + return prefix + + +def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode): + """Parse checkpoint protobuf.""" + checkpoint_list = Checkpoint() + try: + if dec_key is None: + with open(ckpt_file_name, "rb") as f: + pb_content = f.read() + else: + pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode) + if pb_content is None: + raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.") + checkpoint_list.ParseFromString(pb_content) + except BaseException as e: + if _is_cipher_file(ckpt_file_name): + logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the " + "correct 'dec_key'.", ckpt_file_name) + else: + logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please " + "check the correct of the file.", ckpt_file_name) + raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have " + "permission to read it.".format(ckpt_file_name)) + return checkpoint_list + + +def _whether_load_param(specify_prefix, filter_prefix, param_name): + """Checks whether the load the parameter after `specify_prefix` or `filter_prefix`.""" + whether_load = True + if specify_prefix: + whether_load = False + for prefix in specify_prefix: + if param_name.startswith(prefix): + whether_load = True + break + if filter_prefix: + for prefix in filter_prefix: + if param_name.startswith(prefix): + whether_load = False + break + return whether_load def load_param_into_net(net, parameter_dict, strict_load=False): @@ -1607,11 +1634,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= param_not_in_ckpt.append(param.name) continue - param_rank = rank_list[param.name][0] - skip_merge_split = rank_list[param.name][1] - shard_stride = train_strategy[param.name][4] - if train_strategy[param.name][5]: - shard_size = ckpt_file_len / shard_stride / train_strategy[param.name][5] + param_rank = rank_list.get(param.name)[0] + skip_merge_split = rank_list.get(param.name)[1] + shard_stride = train_strategy.get(param.name)[4] + if train_strategy.get(param.name)[5]: + shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5] else: shard_size = 0 for rank in param_rank: diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index db11245d044..5a6beef2d2c 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -37,9 +37,19 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load from tests.security_utils import security_off_wrap from ..ut_filter import non_graph_engine -class Net(nn.Cell): - """Net definition.""" +class Net(nn.Cell): + """ + Net definition. + parameter name : + conv1.weight + bn1.moving_mean + bn1.moving_variance + bn1.gamma + bn1.beta + fc.weight + fc.bias + """ def __init__(self, num_classes=10): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros") @@ -342,6 +352,98 @@ def test_load_checkpoint_empty_file(): load_checkpoint("empty.ckpt") +def test_load_checkpoint_error_param(): + """ + Feature: Load checkpoint. + Description: Load checkpoint with error param. + Expectation: Raise value error for error param. + """ + context.set_context(mode=context.GRAPH_MODE) + net = Net(10) + ckpt_file = "check_name.ckpt" + save_checkpoint(net, ckpt_file) + with pytest.raises(TypeError): + load_checkpoint(ckpt_file, specify_prefix=123) + with pytest.raises(ValueError): + load_checkpoint(ckpt_file, filter_prefix="") + if os.path.exists(ckpt_file): + os.remove(ckpt_file) + + +def test_load_checkpoint_error_load(): + """ + Feature: Load checkpoint. + Description: Load checkpoint with empty parameter dict. + Expectation: Raise value error for error load. + """ + context.set_context(mode=context.GRAPH_MODE) + net = Net(10) + ckpt_file = "check_name.ckpt" + save_checkpoint(net, ckpt_file) + with pytest.raises(ValueError): + load_checkpoint(ckpt_file, specify_prefix="123") + if os.path.exists(ckpt_file): + os.remove(ckpt_file) + + +def test_load_checkpoint_specify_prefix(): + """ + Feature: Load checkpoint. + Description: Load checkpoint with param `specify_prefix`. + Expectation: Correct loaded checkpoint file. + """ + context.set_context(mode=context.GRAPH_MODE) + net = Net(10) + ckpt_file = "specify_prefix.ckpt" + save_checkpoint(net, ckpt_file) + param_dict = load_checkpoint(ckpt_file, specify_prefix="bn") + assert len(param_dict) == 4 + param_dict = load_checkpoint(ckpt_file, specify_prefix="fc") + assert len(param_dict) == 2 + param_dict = load_checkpoint(ckpt_file, specify_prefix=["fc", "bn"]) + assert len(param_dict) == 6 + if os.path.exists(ckpt_file): + os.remove(ckpt_file) + + +def test_load_checkpoint_filter_prefix(): + """ + Feature: Load checkpoint. + Description: Load checkpoint with param `filter_prefix`. + Expectation: Correct loaded checkpoint file. + """ + context.set_context(mode=context.GRAPH_MODE) + net = Net(10) + ckpt_file = "filter_prefix.ckpt" + save_checkpoint(net, ckpt_file) + param_dict = load_checkpoint(ckpt_file, filter_prefix="fc") + assert len(param_dict) == 5 + param_dict = load_checkpoint(ckpt_file, filter_prefix="bn") + assert len(param_dict) == 3 + param_dict = load_checkpoint(ckpt_file, filter_prefix=["bn", "fc"]) + assert len(param_dict) == 1 + if os.path.exists(ckpt_file): + os.remove(ckpt_file) + + +def test_load_checkpoint_specify_filter_prefix(): + """ + Feature: Load checkpoint. + Description: Load checkpoint with param `filter_prefix` and `specify_prefix`. + Expectation: Correct loaded checkpoint file. + """ + context.set_context(mode=context.GRAPH_MODE) + net = Net(10) + ckpt_file = "specify_filter_prefix.ckpt" + save_checkpoint(net, ckpt_file) + param_dict = load_checkpoint(ckpt_file, specify_prefix="bn", filter_prefix="bn1.moving") + assert len(param_dict) == 2 + param_dict = load_checkpoint(ckpt_file, specify_prefix=["bn", "fc"], filter_prefix="fc.weight") + assert len(param_dict) == 5 + if os.path.exists(ckpt_file): + os.remove(ckpt_file) + + def test_save_and_load_checkpoint_for_network_with_encryption(): """ test save and checkpoint for network with encryption""" context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")