add specify_prefix for load_checkpoint

This commit is contained in:
changzherui 2022-04-29 00:14:30 +08:00
parent f8cc3c01fb
commit deacc65ffd
3 changed files with 204 additions and 68 deletions

View File

@ -1,10 +1,14 @@
mindspore.load_checkpoint
==========================
.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM")
.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM", specify_prefix=None)
加载checkpoint文件。
.. note::
- `specify_prefix``filter_prefix`的功能相互之间没有影响。
- 如果发现没有参数被成功加载将会报ValueError.
**参数:**
- **ckpt_file_name** (str) checkpoint的文件名称。
@ -13,6 +17,7 @@ mindspore.load_checkpoint
- **filter_prefix** (Union[str, list[str], tuple[str]]) `filter_prefix` 开头的参数将不会被加载。默认值None。
- **dec_key** (Union[None, bytes]) 用于解密的字节类型密钥如果值为None则不需要解密。默认值None。
- **dec_mode** (str) 该参数仅当 `dec_key` 不为None时有效。指定解密模式目前支持“AES-GCM”和“AES-CBC”。默认值“AES-GCM”。
- **specify_prefix** (Union[str, list[str], tuple[str]]) `specify_prefix` 开头的参数将会被加载。默认值None。
**返回:**
@ -21,3 +26,5 @@ mindspore.load_checkpoint
**异常:**
- **ValueError** checkpoint文件格式不正确。
- **ValueError** 没有一个参数被成功加载。
- **ValueError** `specify_prefix` 或者 `filter_prefix` 的数据类型不正确。

View File

@ -347,15 +347,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
logger.info("Saving checkpoint process is finished.")
def _check_param_prefix(filter_prefix, param_name):
"""Checks whether the prefix of parameter name matches the given filter_prefix."""
for prefix in filter_prefix:
if param_name.find(prefix) == 0 \
and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
return True
return False
def _check_append_dict(append_dict):
"""Check the argument append_dict for save_checkpoint."""
if append_dict is None:
@ -437,10 +428,15 @@ def load(file_name, **kwargs):
return graph
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
"""
Load checkpoint info from a specified file.
Note:
1. `specify_prefix` and `filter_prefix` do not affect each other.
2. If none of the parameters are loaded from checkpoint file, it will throw ValueError.
Args:
ckpt_file_name (str): Checkpoint file name.
net (Cell): The network where the parameters will be loaded. Default: None
@ -454,51 +450,38 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
is not required. Default: None.
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the specify_prefix
will be loaded. Default: None.
Returns:
Dict, key is parameter name, value is a Parameter.
Raises:
ValueError: Checkpoint file's format is incorrect.
ValueError: Parameter's dict is None after load checkpoint file.
ValueError: The type of `specify_prefix` or `filter_prefix` is incorrect.
Examples:
>>> from mindspore import load_checkpoint
>>>
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1", specify_prefix="conv", )
>>> print(param_dict["conv2.weight"])
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
"""
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
specify_prefix = _check_prefix(specify_prefix)
filter_prefix = _check_prefix(filter_prefix)
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
logger.info("Execute the process of loading checkpoint files.")
checkpoint_list = Checkpoint()
try:
if dec_key is None:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
else:
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
if pb_content is None:
raise ValueError("For 'load_checkpoint', Failed to decrypt the checkpoint file.")
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
if _is_cipher_file(ckpt_file_name):
logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
"correct 'dec_key'.", ckpt_file_name)
else:
logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
"check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
"permission to read it.".format(ckpt_file_name))
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
parameter_dict = {}
try:
param_data_list = []
for element_id, element in enumerate(checkpoint_list.value):
if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
continue
data = element.tensor.tensor_content
data_type = element.tensor.tensor_type
@ -516,15 +499,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
param_data = float(param_data[0])
elif 'Int' in data_type:
param_data = int(param_data[0])
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
elif dims == [1]:
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
else:
param_dim = []
for dim in dims:
param_dim.append(dim)
param_value = param_data.reshape(param_dim)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
elif dims != [1]:
param_data = param_data.reshape(list(dims))
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
logger.info("Loading checkpoint files process is finished.")
@ -543,8 +521,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
return parameter_dict
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
"""Check function load_checkpoint's parameter."""
def _check_ckpt_file_name(ckpt_file_name):
"""Check function load_checkpoint's cket_file_name."""
if not isinstance(ckpt_file_name, str):
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
"but got {}.".format(type(ckpt_file_name)))
@ -558,22 +536,71 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
if filter_prefix is not None:
if not isinstance(filter_prefix, (str, list, tuple)):
raise TypeError("For 'load_checkpoint', the type of 'filter_prefix' must be string, "
"list[string] or tuple[string] when 'filter_prefix' is not None, but "
f"got {str(type(filter_prefix))}.")
if isinstance(filter_prefix, str):
filter_prefix = (filter_prefix,)
if not filter_prefix:
raise ValueError("For 'load_checkpoint', the argument 'filter_prefix' can't be empty when "
"'filter_prefix' is list or tuple.")
for index, prefix in enumerate(filter_prefix):
if not isinstance(prefix, str):
raise TypeError("For 'load_checkpoint', when 'filter_prefix' is list or tuple, "
"the element in 'filter_prefix' must be string, but got "
f"{str(type(prefix))} at index {index}.")
return ckpt_file_name, filter_prefix
return ckpt_file_name
def _check_prefix(prefix):
"""Check the correctness of the parameters."""
if prefix is None:
return prefix
if not isinstance(prefix, (str, list, tuple)):
raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, "
"list[string] or tuple[string], but got {}.".format(str(type(prefix))))
if isinstance(prefix, str):
prefix = (prefix,)
if not prefix:
raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when"
" 'filter_prefix' is list or tuple.")
for index, pre in enumerate(prefix):
if not isinstance(pre, str):
raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, "
"the element in it must be string, but got "
f"{str(type(pre))} at index {index}.")
if pre == "":
raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' "
"can't include ''.")
return prefix
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
"""Parse checkpoint protobuf."""
checkpoint_list = Checkpoint()
try:
if dec_key is None:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
else:
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
if pb_content is None:
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
if _is_cipher_file(ckpt_file_name):
logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
"correct 'dec_key'.", ckpt_file_name)
else:
logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
"check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
"permission to read it.".format(ckpt_file_name))
return checkpoint_list
def _whether_load_param(specify_prefix, filter_prefix, param_name):
"""Checks whether the load the parameter after `specify_prefix` or `filter_prefix`."""
whether_load = True
if specify_prefix:
whether_load = False
for prefix in specify_prefix:
if param_name.startswith(prefix):
whether_load = True
break
if filter_prefix:
for prefix in filter_prefix:
if param_name.startswith(prefix):
whether_load = False
break
return whether_load
def load_param_into_net(net, parameter_dict, strict_load=False):
@ -1607,11 +1634,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
param_not_in_ckpt.append(param.name)
continue
param_rank = rank_list[param.name][0]
skip_merge_split = rank_list[param.name][1]
shard_stride = train_strategy[param.name][4]
if train_strategy[param.name][5]:
shard_size = ckpt_file_len / shard_stride / train_strategy[param.name][5]
param_rank = rank_list.get(param.name)[0]
skip_merge_split = rank_list.get(param.name)[1]
shard_stride = train_strategy.get(param.name)[4]
if train_strategy.get(param.name)[5]:
shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
else:
shard_size = 0
for rank in param_rank:

View File

@ -37,9 +37,19 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load
from tests.security_utils import security_off_wrap
from ..ut_filter import non_graph_engine
class Net(nn.Cell):
"""Net definition."""
class Net(nn.Cell):
"""
Net definition.
parameter name :
conv1.weight
bn1.moving_mean
bn1.moving_variance
bn1.gamma
bn1.beta
fc.weight
fc.bias
"""
def __init__(self, num_classes=10):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
@ -342,6 +352,98 @@ def test_load_checkpoint_empty_file():
load_checkpoint("empty.ckpt")
def test_load_checkpoint_error_param():
"""
Feature: Load checkpoint.
Description: Load checkpoint with error param.
Expectation: Raise value error for error param.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net(10)
ckpt_file = "check_name.ckpt"
save_checkpoint(net, ckpt_file)
with pytest.raises(TypeError):
load_checkpoint(ckpt_file, specify_prefix=123)
with pytest.raises(ValueError):
load_checkpoint(ckpt_file, filter_prefix="")
if os.path.exists(ckpt_file):
os.remove(ckpt_file)
def test_load_checkpoint_error_load():
"""
Feature: Load checkpoint.
Description: Load checkpoint with empty parameter dict.
Expectation: Raise value error for error load.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net(10)
ckpt_file = "check_name.ckpt"
save_checkpoint(net, ckpt_file)
with pytest.raises(ValueError):
load_checkpoint(ckpt_file, specify_prefix="123")
if os.path.exists(ckpt_file):
os.remove(ckpt_file)
def test_load_checkpoint_specify_prefix():
"""
Feature: Load checkpoint.
Description: Load checkpoint with param `specify_prefix`.
Expectation: Correct loaded checkpoint file.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net(10)
ckpt_file = "specify_prefix.ckpt"
save_checkpoint(net, ckpt_file)
param_dict = load_checkpoint(ckpt_file, specify_prefix="bn")
assert len(param_dict) == 4
param_dict = load_checkpoint(ckpt_file, specify_prefix="fc")
assert len(param_dict) == 2
param_dict = load_checkpoint(ckpt_file, specify_prefix=["fc", "bn"])
assert len(param_dict) == 6
if os.path.exists(ckpt_file):
os.remove(ckpt_file)
def test_load_checkpoint_filter_prefix():
"""
Feature: Load checkpoint.
Description: Load checkpoint with param `filter_prefix`.
Expectation: Correct loaded checkpoint file.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net(10)
ckpt_file = "filter_prefix.ckpt"
save_checkpoint(net, ckpt_file)
param_dict = load_checkpoint(ckpt_file, filter_prefix="fc")
assert len(param_dict) == 5
param_dict = load_checkpoint(ckpt_file, filter_prefix="bn")
assert len(param_dict) == 3
param_dict = load_checkpoint(ckpt_file, filter_prefix=["bn", "fc"])
assert len(param_dict) == 1
if os.path.exists(ckpt_file):
os.remove(ckpt_file)
def test_load_checkpoint_specify_filter_prefix():
"""
Feature: Load checkpoint.
Description: Load checkpoint with param `filter_prefix` and `specify_prefix`.
Expectation: Correct loaded checkpoint file.
"""
context.set_context(mode=context.GRAPH_MODE)
net = Net(10)
ckpt_file = "specify_filter_prefix.ckpt"
save_checkpoint(net, ckpt_file)
param_dict = load_checkpoint(ckpt_file, specify_prefix="bn", filter_prefix="bn1.moving")
assert len(param_dict) == 2
param_dict = load_checkpoint(ckpt_file, specify_prefix=["bn", "fc"], filter_prefix="fc.weight")
assert len(param_dict) == 5
if os.path.exists(ckpt_file):
os.remove(ckpt_file)
def test_save_and_load_checkpoint_for_network_with_encryption():
""" test save and checkpoint for network with encryption"""
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")