!33744 add specify_prefix for load_checkpoint

Merge pull request !33744 from changzherui/add_specify_prefix
This commit is contained in:
i-robot 2022-05-09 07:56:52 +00:00 committed by Gitee
commit ac58831c9a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 204 additions and 68 deletions

View File

@ -1,10 +1,14 @@
mindspore.load_checkpoint mindspore.load_checkpoint
========================== ==========================
.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM") .. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM", specify_prefix=None)
加载checkpoint文件。 加载checkpoint文件。
.. note::
- `specify_prefix` 和 `filter_prefix` 的功能相互之间没有影响。
- 如果发现没有参数被成功加载，将会报ValueError。
**参数:** **参数:**
- **ckpt_file_name** (str) checkpoint的文件名称。 - **ckpt_file_name** (str) checkpoint的文件名称。
@ -13,6 +17,7 @@ mindspore.load_checkpoint
- **filter_prefix** (Union[str, list[str], tuple[str]]) `filter_prefix` 开头的参数将不会被加载。默认值None。 - **filter_prefix** (Union[str, list[str], tuple[str]]) `filter_prefix` 开头的参数将不会被加载。默认值None。
- **dec_key** (Union[None, bytes]) 用于解密的字节类型密钥如果值为None则不需要解密。默认值None。 - **dec_key** (Union[None, bytes]) 用于解密的字节类型密钥如果值为None则不需要解密。默认值None。
- **dec_mode** (str) 该参数仅当 `dec_key` 不为None时有效。指定解密模式目前支持“AES-GCM”和“AES-CBC”。默认值“AES-GCM”。 - **dec_mode** (str) 该参数仅当 `dec_key` 不为None时有效。指定解密模式目前支持“AES-GCM”和“AES-CBC”。默认值“AES-GCM”。
- **specify_prefix** (Union[str, list[str], tuple[str]]) - 以 `specify_prefix` 开头的参数将会被加载。默认值：None。
**返回:** **返回:**
@ -21,3 +26,5 @@ mindspore.load_checkpoint
**异常:** **异常:**
- **ValueError** checkpoint文件格式不正确。 - **ValueError** checkpoint文件格式不正确。
- **ValueError** 没有一个参数被成功加载。
- **ValueError** `specify_prefix` 或者 `filter_prefix` 的数据类型不正确。

View File

@ -347,15 +347,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
logger.info("Saving checkpoint process is finished.") logger.info("Saving checkpoint process is finished.")
def _check_param_prefix(filter_prefix, param_name):
"""Checks whether the prefix of parameter name matches the given filter_prefix."""
for prefix in filter_prefix:
if param_name.find(prefix) == 0 \
and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
return True
return False
def _check_append_dict(append_dict): def _check_append_dict(append_dict):
"""Check the argument append_dict for save_checkpoint.""" """Check the argument append_dict for save_checkpoint."""
if append_dict is None: if append_dict is None:
@ -437,10 +428,15 @@ def load(file_name, **kwargs):
return graph return graph
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"): def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
""" """
Load checkpoint info from a specified file. Load checkpoint info from a specified file.
Note:
1. `specify_prefix` and `filter_prefix` do not affect each other.
2. If none of the parameters are loaded from checkpoint file, it will throw ValueError.
Args: Args:
ckpt_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
net (Cell): The network where the parameters will be loaded. Default: None net (Cell): The network where the parameters will be loaded. Default: None
@ -454,51 +450,38 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
is not required. Default: None. is not required. Default: None.
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'. mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the specify_prefix
will be loaded. Default: None.
Returns: Returns:
Dict, key is parameter name, value is a Parameter. Dict, key is parameter name, value is a Parameter.
Raises: Raises:
ValueError: Checkpoint file's format is incorrect. ValueError: Checkpoint file's format is incorrect.
ValueError: Parameter's dict is None after load checkpoint file.
ValueError: The type of `specify_prefix` or `filter_prefix` is incorrect.
Examples: Examples:
>>> from mindspore import load_checkpoint >>> from mindspore import load_checkpoint
>>> >>>
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1", specify_prefix="conv", )
>>> print(param_dict["conv2.weight"]) >>> print(param_dict["conv2.weight"])
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True) Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
""" """
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix) ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
specify_prefix = _check_prefix(specify_prefix)
filter_prefix = _check_prefix(filter_prefix)
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes)) dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str) dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
logger.info("Execute the process of loading checkpoint files.") logger.info("Execute the process of loading checkpoint files.")
checkpoint_list = Checkpoint() checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
try:
if dec_key is None:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
else:
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
if pb_content is None:
raise ValueError("For 'load_checkpoint', Failed to decrypt the checkpoint file.")
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
if _is_cipher_file(ckpt_file_name):
logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
"correct 'dec_key'.", ckpt_file_name)
else:
logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
"check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
"permission to read it.".format(ckpt_file_name))
parameter_dict = {} parameter_dict = {}
try: try:
param_data_list = [] param_data_list = []
for element_id, element in enumerate(checkpoint_list.value): for element_id, element in enumerate(checkpoint_list.value):
if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag): if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
continue continue
data = element.tensor.tensor_content data = element.tensor.tensor_content
data_type = element.tensor.tensor_type data_type = element.tensor.tensor_type
@ -516,15 +499,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
param_data = float(param_data[0]) param_data = float(param_data[0])
elif 'Int' in data_type: elif 'Int' in data_type:
param_data = int(param_data[0]) param_data = int(param_data[0])
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) elif dims != [1]:
elif dims == [1]: param_data = param_data.reshape(list(dims))
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
else: parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
param_dim = []
for dim in dims:
param_dim.append(dim)
param_value = param_data.reshape(param_dim)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
logger.info("Loading checkpoint files process is finished.") logger.info("Loading checkpoint files process is finished.")
@ -543,8 +521,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
return parameter_dict return parameter_dict
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None): def _check_ckpt_file_name(ckpt_file_name):
"""Check function load_checkpoint's parameter.""" """Check function load_checkpoint's cket_file_name."""
if not isinstance(ckpt_file_name, str): if not isinstance(ckpt_file_name, str):
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, " raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
"but got {}.".format(type(ckpt_file_name))) "but got {}.".format(type(ckpt_file_name)))
@ -558,22 +536,71 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check " raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name)) "whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
if filter_prefix is not None: return ckpt_file_name
if not isinstance(filter_prefix, (str, list, tuple)):
raise TypeError("For 'load_checkpoint', the type of 'filter_prefix' must be string, "
"list[string] or tuple[string] when 'filter_prefix' is not None, but " def _check_prefix(prefix):
f"got {str(type(filter_prefix))}.") """Check the correctness of the parameters."""
if isinstance(filter_prefix, str): if prefix is None:
filter_prefix = (filter_prefix,) return prefix
if not filter_prefix: if not isinstance(prefix, (str, list, tuple)):
raise ValueError("For 'load_checkpoint', the argument 'filter_prefix' can't be empty when " raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, "
"'filter_prefix' is list or tuple.") "list[string] or tuple[string], but got {}.".format(str(type(prefix))))
for index, prefix in enumerate(filter_prefix): if isinstance(prefix, str):
if not isinstance(prefix, str): prefix = (prefix,)
raise TypeError("For 'load_checkpoint', when 'filter_prefix' is list or tuple, " if not prefix:
"the element in 'filter_prefix' must be string, but got " raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when"
f"{str(type(prefix))} at index {index}.") " 'filter_prefix' is list or tuple.")
return ckpt_file_name, filter_prefix for index, pre in enumerate(prefix):
if not isinstance(pre, str):
raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, "
"the element in it must be string, but got "
f"{str(type(pre))} at index {index}.")
if pre == "":
raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' "
"can't include ''.")
return prefix
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
    """
    Read a checkpoint file and parse its content into a `Checkpoint` protobuf message.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to read.
        dec_key (Union[None, bytes]): Decryption key. When None the file is read as raw
            bytes; otherwise the content is obtained through `_decrypt`.
        dec_mode (str): Decryption mode forwarded to `_decrypt`.

    Returns:
        Checkpoint, the message populated from the file content.

    Raises:
        ValueError: The file could not be read, decrypted or parsed. The original
            exception is attached as the cause.
    """
    checkpoint_list = Checkpoint()
    try:
        if dec_key is None:
            with open(ckpt_file_name, "rb") as f:
                pb_content = f.read()
        else:
            pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
            if pb_content is None:
                raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
        checkpoint_list.ParseFromString(pb_content)
    # Catch Exception rather than BaseException so that KeyboardInterrupt and SystemExit
    # propagate unchanged instead of being reported as a checkpoint read failure.
    except Exception as e:
        if _is_cipher_file(ckpt_file_name):
            logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
                            "correct 'dec_key'.", ckpt_file_name)
        else:
            logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
                            "check the correct of the file.", ckpt_file_name)
        # Chain the original error so the root cause stays visible in the traceback.
        raise ValueError(str(e) + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
                                  "permission to read it.".format(ckpt_file_name)) from e
    return checkpoint_list
def _whether_load_param(specify_prefix, filter_prefix, param_name):
"""Checks whether the load the parameter after `specify_prefix` or `filter_prefix`."""
whether_load = True
if specify_prefix:
whether_load = False
for prefix in specify_prefix:
if param_name.startswith(prefix):
whether_load = True
break
if filter_prefix:
for prefix in filter_prefix:
if param_name.startswith(prefix):
whether_load = False
break
return whether_load
def load_param_into_net(net, parameter_dict, strict_load=False): def load_param_into_net(net, parameter_dict, strict_load=False):
@ -1607,11 +1634,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
param_not_in_ckpt.append(param.name) param_not_in_ckpt.append(param.name)
continue continue
param_rank = rank_list[param.name][0] param_rank = rank_list.get(param.name)[0]
skip_merge_split = rank_list[param.name][1] skip_merge_split = rank_list.get(param.name)[1]
shard_stride = train_strategy[param.name][4] shard_stride = train_strategy.get(param.name)[4]
if train_strategy[param.name][5]: if train_strategy.get(param.name)[5]:
shard_size = ckpt_file_len / shard_stride / train_strategy[param.name][5] shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
else: else:
shard_size = 0 shard_size = 0
for rank in param_rank: for rank in param_rank:

View File

@ -37,9 +37,19 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load
from tests.security_utils import security_off_wrap from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
class Net(nn.Cell):
"""Net definition."""
class Net(nn.Cell):
"""
Net definition.
parameter name :
conv1.weight
bn1.moving_mean
bn1.moving_variance
bn1.gamma
bn1.beta
fc.weight
fc.bias
"""
def __init__(self, num_classes=10): def __init__(self, num_classes=10):
super(Net, self).__init__() super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros") self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
@ -342,6 +352,98 @@ def test_load_checkpoint_empty_file():
load_checkpoint("empty.ckpt") load_checkpoint("empty.ckpt")
def test_load_checkpoint_error_param():
    """
    Feature: Load checkpoint.
    Description: Load checkpoint with invalid `specify_prefix` / `filter_prefix` arguments.
    Expectation: Raise TypeError for a non-string prefix and ValueError for an empty-string prefix.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = Net(10)
    ckpt_file = "check_name.ckpt"
    try:
        save_checkpoint(net, ckpt_file)
        with pytest.raises(TypeError):
            load_checkpoint(ckpt_file, specify_prefix=123)
        with pytest.raises(ValueError):
            load_checkpoint(ckpt_file, filter_prefix="")
    finally:
        # Remove the temporary checkpoint even when an expectation above fails.
        if os.path.exists(ckpt_file):
            os.remove(ckpt_file)
def test_load_checkpoint_error_load():
    """
    Feature: Load checkpoint.
    Description: Load checkpoint with a `specify_prefix` value of "123".
    Expectation: Raise ValueError from `load_checkpoint`.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = Net(10)
    ckpt_file = "check_name.ckpt"
    try:
        save_checkpoint(net, ckpt_file)
        with pytest.raises(ValueError):
            load_checkpoint(ckpt_file, specify_prefix="123")
    finally:
        # Remove the temporary checkpoint even when the expectation above fails.
        if os.path.exists(ckpt_file):
            os.remove(ckpt_file)
def test_load_checkpoint_specify_prefix():
    """
    Feature: Load checkpoint.
    Description: Load checkpoint with param `specify_prefix`.
    Expectation: Only parameters whose names start with a given prefix are loaded.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = Net(10)
    ckpt_file = "specify_prefix.ckpt"
    try:
        save_checkpoint(net, ckpt_file)
        param_dict = load_checkpoint(ckpt_file, specify_prefix="bn")
        assert len(param_dict) == 4
        param_dict = load_checkpoint(ckpt_file, specify_prefix="fc")
        assert len(param_dict) == 2
        # A list of prefixes selects the union of the matching parameters.
        param_dict = load_checkpoint(ckpt_file, specify_prefix=["fc", "bn"])
        assert len(param_dict) == 6
    finally:
        # Remove the temporary checkpoint even when an assertion above fails.
        if os.path.exists(ckpt_file):
            os.remove(ckpt_file)
def test_load_checkpoint_filter_prefix():
    """
    Feature: Load checkpoint.
    Description: Load checkpoint with param `filter_prefix`.
    Expectation: Parameters whose names start with a given prefix are left out of the result.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = Net(10)
    ckpt_file = "filter_prefix.ckpt"
    try:
        save_checkpoint(net, ckpt_file)
        param_dict = load_checkpoint(ckpt_file, filter_prefix="fc")
        assert len(param_dict) == 5
        param_dict = load_checkpoint(ckpt_file, filter_prefix="bn")
        assert len(param_dict) == 3
        # A list of prefixes excludes the union of the matching parameters.
        param_dict = load_checkpoint(ckpt_file, filter_prefix=["bn", "fc"])
        assert len(param_dict) == 1
    finally:
        # Remove the temporary checkpoint even when an assertion above fails.
        if os.path.exists(ckpt_file):
            os.remove(ckpt_file)
def test_load_checkpoint_specify_filter_prefix():
    """
    Feature: Load checkpoint.
    Description: Load checkpoint with param `filter_prefix` and `specify_prefix`.
    Expectation: Parameters selected by `specify_prefix` are loaded except those matching `filter_prefix`.
    """
    context.set_context(mode=context.GRAPH_MODE)
    net = Net(10)
    ckpt_file = "specify_filter_prefix.ckpt"
    try:
        save_checkpoint(net, ckpt_file)
        param_dict = load_checkpoint(ckpt_file, specify_prefix="bn", filter_prefix="bn1.moving")
        assert len(param_dict) == 2
        param_dict = load_checkpoint(ckpt_file, specify_prefix=["bn", "fc"], filter_prefix="fc.weight")
        assert len(param_dict) == 5
    finally:
        # Remove the temporary checkpoint even when an assertion above fails.
        if os.path.exists(ckpt_file):
            os.remove(ckpt_file)
def test_save_and_load_checkpoint_for_network_with_encryption(): def test_save_and_load_checkpoint_for_network_with_encryption():
""" test save and checkpoint for network with encryption""" """ test save and checkpoint for network with encryption"""
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")