forked from mindspore-Ecosystem/mindspore
add specify_prefix for load_checkpoint
This commit is contained in:
parent
f8cc3c01fb
commit
deacc65ffd
|
@ -1,10 +1,14 @@
|
|||
mindspore.load_checkpoint
|
||||
==========================
|
||||
|
||||
.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM")
|
||||
.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM", specify_prefix=None)
|
||||
|
||||
加载checkpoint文件。
|
||||
|
||||
.. note::
|
||||
- `specify_prefix` 和 `filter_prefix`的功能相互之间没有影响。
|
||||
- 如果发现没有参数被成功加载,将会报ValueError.
|
||||
|
||||
**参数:**
|
||||
|
||||
- **ckpt_file_name** (str) – checkpoint的文件名称。
|
||||
|
@ -13,6 +17,7 @@ mindspore.load_checkpoint
|
|||
- **filter_prefix** (Union[str, list[str], tuple[str]]) – 以 `filter_prefix` 开头的参数将不会被加载。默认值:None。
|
||||
- **dec_key** (Union[None, bytes]) – 用于解密的字节类型密钥,如果值为None,则不需要解密。默认值:None。
|
||||
- **dec_mode** (str) – 该参数仅当 `dec_key` 不为None时有效。指定解密模式,目前支持“AES-GCM”和“AES-CBC”。默认值:“AES-GCM”。
|
||||
- **specify_prefix** (Union[str, list[str], tuple[str]]) – 以 `specify_prefix` 开头的参数将会被加载。默认值:None。
|
||||
|
||||
**返回:**
|
||||
|
||||
|
@ -21,3 +26,5 @@ mindspore.load_checkpoint
|
|||
**异常:**
|
||||
|
||||
- **ValueError** – checkpoint文件格式不正确。
|
||||
- **ValueError** – 没有一个参数被成功加载。
|
||||
- **ValueError** `specify_prefix` 或者 `filter_prefix` 的数据类型不正确。
|
||||
|
|
|
@ -347,15 +347,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
|||
logger.info("Saving checkpoint process is finished.")
|
||||
|
||||
|
||||
def _check_param_prefix(filter_prefix, param_name):
|
||||
"""Checks whether the prefix of parameter name matches the given filter_prefix."""
|
||||
for prefix in filter_prefix:
|
||||
if param_name.find(prefix) == 0 \
|
||||
and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _check_append_dict(append_dict):
|
||||
"""Check the argument append_dict for save_checkpoint."""
|
||||
if append_dict is None:
|
||||
|
@ -437,10 +428,15 @@ def load(file_name, **kwargs):
|
|||
return graph
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
|
||||
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
||||
dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
|
||||
"""
|
||||
Load checkpoint info from a specified file.
|
||||
|
||||
Note:
|
||||
1. `specify_prefix` and `filter_prefix` do not affect each other.
|
||||
2. If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
||||
|
||||
Args:
|
||||
ckpt_file_name (str): Checkpoint file name.
|
||||
net (Cell): The network where the parameters will be loaded. Default: None
|
||||
|
@ -454,51 +450,38 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
is not required. Default: None.
|
||||
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
||||
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
|
||||
specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the specify_prefix
|
||||
will be loaded. Default: None.
|
||||
|
||||
Returns:
|
||||
Dict, key is parameter name, value is a Parameter.
|
||||
|
||||
Raises:
|
||||
ValueError: Checkpoint file's format is incorrect.
|
||||
ValueError: Parameter's dict is None after load checkpoint file.
|
||||
ValueError: The type of `specify_prefix` or `filter_prefix` is incorrect.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore import load_checkpoint
|
||||
>>>
|
||||
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
||||
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
||||
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1", specify_prefix="conv", )
|
||||
>>> print(param_dict["conv2.weight"])
|
||||
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
||||
"""
|
||||
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
|
||||
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
||||
specify_prefix = _check_prefix(specify_prefix)
|
||||
filter_prefix = _check_prefix(filter_prefix)
|
||||
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
||||
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
||||
logger.info("Execute the process of loading checkpoint files.")
|
||||
checkpoint_list = Checkpoint()
|
||||
|
||||
try:
|
||||
if dec_key is None:
|
||||
with open(ckpt_file_name, "rb") as f:
|
||||
pb_content = f.read()
|
||||
else:
|
||||
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
||||
if pb_content is None:
|
||||
raise ValueError("For 'load_checkpoint', Failed to decrypt the checkpoint file.")
|
||||
checkpoint_list.ParseFromString(pb_content)
|
||||
except BaseException as e:
|
||||
if _is_cipher_file(ckpt_file_name):
|
||||
logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
|
||||
"correct 'dec_key'.", ckpt_file_name)
|
||||
else:
|
||||
logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
|
||||
"check the correct of the file.", ckpt_file_name)
|
||||
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
|
||||
"permission to read it.".format(ckpt_file_name))
|
||||
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
|
||||
|
||||
parameter_dict = {}
|
||||
try:
|
||||
param_data_list = []
|
||||
for element_id, element in enumerate(checkpoint_list.value):
|
||||
if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
|
||||
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
||||
continue
|
||||
data = element.tensor.tensor_content
|
||||
data_type = element.tensor.tensor_type
|
||||
|
@ -516,15 +499,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
param_data = float(param_data[0])
|
||||
elif 'Int' in data_type:
|
||||
param_data = int(param_data[0])
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||
elif dims == [1]:
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||
else:
|
||||
param_dim = []
|
||||
for dim in dims:
|
||||
param_dim.append(dim)
|
||||
param_value = param_data.reshape(param_dim)
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
||||
elif dims != [1]:
|
||||
param_data = param_data.reshape(list(dims))
|
||||
|
||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||
|
||||
logger.info("Loading checkpoint files process is finished.")
|
||||
|
||||
|
@ -543,8 +521,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|||
return parameter_dict
|
||||
|
||||
|
||||
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
||||
"""Check function load_checkpoint's parameter."""
|
||||
def _check_ckpt_file_name(ckpt_file_name):
|
||||
"""Check function load_checkpoint's cket_file_name."""
|
||||
if not isinstance(ckpt_file_name, str):
|
||||
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
||||
"but got {}.".format(type(ckpt_file_name)))
|
||||
|
@ -558,22 +536,71 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
|||
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
|
||||
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
|
||||
|
||||
if filter_prefix is not None:
|
||||
if not isinstance(filter_prefix, (str, list, tuple)):
|
||||
raise TypeError("For 'load_checkpoint', the type of 'filter_prefix' must be string, "
|
||||
"list[string] or tuple[string] when 'filter_prefix' is not None, but "
|
||||
f"got {str(type(filter_prefix))}.")
|
||||
if isinstance(filter_prefix, str):
|
||||
filter_prefix = (filter_prefix,)
|
||||
if not filter_prefix:
|
||||
raise ValueError("For 'load_checkpoint', the argument 'filter_prefix' can't be empty when "
|
||||
"'filter_prefix' is list or tuple.")
|
||||
for index, prefix in enumerate(filter_prefix):
|
||||
if not isinstance(prefix, str):
|
||||
raise TypeError("For 'load_checkpoint', when 'filter_prefix' is list or tuple, "
|
||||
"the element in 'filter_prefix' must be string, but got "
|
||||
f"{str(type(prefix))} at index {index}.")
|
||||
return ckpt_file_name, filter_prefix
|
||||
return ckpt_file_name
|
||||
|
||||
|
||||
def _check_prefix(prefix):
|
||||
"""Check the correctness of the parameters."""
|
||||
if prefix is None:
|
||||
return prefix
|
||||
if not isinstance(prefix, (str, list, tuple)):
|
||||
raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, "
|
||||
"list[string] or tuple[string], but got {}.".format(str(type(prefix))))
|
||||
if isinstance(prefix, str):
|
||||
prefix = (prefix,)
|
||||
if not prefix:
|
||||
raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when"
|
||||
" 'filter_prefix' is list or tuple.")
|
||||
for index, pre in enumerate(prefix):
|
||||
if not isinstance(pre, str):
|
||||
raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, "
|
||||
"the element in it must be string, but got "
|
||||
f"{str(type(pre))} at index {index}.")
|
||||
if pre == "":
|
||||
raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' "
|
||||
"can't include ''.")
|
||||
return prefix
|
||||
|
||||
|
||||
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
|
||||
"""Parse checkpoint protobuf."""
|
||||
checkpoint_list = Checkpoint()
|
||||
try:
|
||||
if dec_key is None:
|
||||
with open(ckpt_file_name, "rb") as f:
|
||||
pb_content = f.read()
|
||||
else:
|
||||
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
||||
if pb_content is None:
|
||||
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
|
||||
checkpoint_list.ParseFromString(pb_content)
|
||||
except BaseException as e:
|
||||
if _is_cipher_file(ckpt_file_name):
|
||||
logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
|
||||
"correct 'dec_key'.", ckpt_file_name)
|
||||
else:
|
||||
logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
|
||||
"check the correct of the file.", ckpt_file_name)
|
||||
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
|
||||
"permission to read it.".format(ckpt_file_name))
|
||||
return checkpoint_list
|
||||
|
||||
|
||||
def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
||||
"""Checks whether the load the parameter after `specify_prefix` or `filter_prefix`."""
|
||||
whether_load = True
|
||||
if specify_prefix:
|
||||
whether_load = False
|
||||
for prefix in specify_prefix:
|
||||
if param_name.startswith(prefix):
|
||||
whether_load = True
|
||||
break
|
||||
if filter_prefix:
|
||||
for prefix in filter_prefix:
|
||||
if param_name.startswith(prefix):
|
||||
whether_load = False
|
||||
break
|
||||
return whether_load
|
||||
|
||||
|
||||
def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||
|
@ -1607,11 +1634,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|||
param_not_in_ckpt.append(param.name)
|
||||
continue
|
||||
|
||||
param_rank = rank_list[param.name][0]
|
||||
skip_merge_split = rank_list[param.name][1]
|
||||
shard_stride = train_strategy[param.name][4]
|
||||
if train_strategy[param.name][5]:
|
||||
shard_size = ckpt_file_len / shard_stride / train_strategy[param.name][5]
|
||||
param_rank = rank_list.get(param.name)[0]
|
||||
skip_merge_split = rank_list.get(param.name)[1]
|
||||
shard_stride = train_strategy.get(param.name)[4]
|
||||
if train_strategy.get(param.name)[5]:
|
||||
shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
|
||||
else:
|
||||
shard_size = 0
|
||||
for rank in param_rank:
|
||||
|
|
|
@ -37,9 +37,19 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load
|
|||
from tests.security_utils import security_off_wrap
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""Net definition."""
|
||||
|
||||
class Net(nn.Cell):
|
||||
"""
|
||||
Net definition.
|
||||
parameter name :
|
||||
conv1.weight
|
||||
bn1.moving_mean
|
||||
bn1.moving_variance
|
||||
bn1.gamma
|
||||
bn1.beta
|
||||
fc.weight
|
||||
fc.bias
|
||||
"""
|
||||
def __init__(self, num_classes=10):
|
||||
super(Net, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
|
||||
|
@ -342,6 +352,98 @@ def test_load_checkpoint_empty_file():
|
|||
load_checkpoint("empty.ckpt")
|
||||
|
||||
|
||||
def test_load_checkpoint_error_param():
|
||||
"""
|
||||
Feature: Load checkpoint.
|
||||
Description: Load checkpoint with error param.
|
||||
Expectation: Raise value error for error param.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
ckpt_file = "check_name.ckpt"
|
||||
save_checkpoint(net, ckpt_file)
|
||||
with pytest.raises(TypeError):
|
||||
load_checkpoint(ckpt_file, specify_prefix=123)
|
||||
with pytest.raises(ValueError):
|
||||
load_checkpoint(ckpt_file, filter_prefix="")
|
||||
if os.path.exists(ckpt_file):
|
||||
os.remove(ckpt_file)
|
||||
|
||||
|
||||
def test_load_checkpoint_error_load():
|
||||
"""
|
||||
Feature: Load checkpoint.
|
||||
Description: Load checkpoint with empty parameter dict.
|
||||
Expectation: Raise value error for error load.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
ckpt_file = "check_name.ckpt"
|
||||
save_checkpoint(net, ckpt_file)
|
||||
with pytest.raises(ValueError):
|
||||
load_checkpoint(ckpt_file, specify_prefix="123")
|
||||
if os.path.exists(ckpt_file):
|
||||
os.remove(ckpt_file)
|
||||
|
||||
|
||||
def test_load_checkpoint_specify_prefix():
|
||||
"""
|
||||
Feature: Load checkpoint.
|
||||
Description: Load checkpoint with param `specify_prefix`.
|
||||
Expectation: Correct loaded checkpoint file.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
ckpt_file = "specify_prefix.ckpt"
|
||||
save_checkpoint(net, ckpt_file)
|
||||
param_dict = load_checkpoint(ckpt_file, specify_prefix="bn")
|
||||
assert len(param_dict) == 4
|
||||
param_dict = load_checkpoint(ckpt_file, specify_prefix="fc")
|
||||
assert len(param_dict) == 2
|
||||
param_dict = load_checkpoint(ckpt_file, specify_prefix=["fc", "bn"])
|
||||
assert len(param_dict) == 6
|
||||
if os.path.exists(ckpt_file):
|
||||
os.remove(ckpt_file)
|
||||
|
||||
|
||||
def test_load_checkpoint_filter_prefix():
|
||||
"""
|
||||
Feature: Load checkpoint.
|
||||
Description: Load checkpoint with param `filter_prefix`.
|
||||
Expectation: Correct loaded checkpoint file.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
ckpt_file = "filter_prefix.ckpt"
|
||||
save_checkpoint(net, ckpt_file)
|
||||
param_dict = load_checkpoint(ckpt_file, filter_prefix="fc")
|
||||
assert len(param_dict) == 5
|
||||
param_dict = load_checkpoint(ckpt_file, filter_prefix="bn")
|
||||
assert len(param_dict) == 3
|
||||
param_dict = load_checkpoint(ckpt_file, filter_prefix=["bn", "fc"])
|
||||
assert len(param_dict) == 1
|
||||
if os.path.exists(ckpt_file):
|
||||
os.remove(ckpt_file)
|
||||
|
||||
|
||||
def test_load_checkpoint_specify_filter_prefix():
|
||||
"""
|
||||
Feature: Load checkpoint.
|
||||
Description: Load checkpoint with param `filter_prefix` and `specify_prefix`.
|
||||
Expectation: Correct loaded checkpoint file.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net(10)
|
||||
ckpt_file = "specify_filter_prefix.ckpt"
|
||||
save_checkpoint(net, ckpt_file)
|
||||
param_dict = load_checkpoint(ckpt_file, specify_prefix="bn", filter_prefix="bn1.moving")
|
||||
assert len(param_dict) == 2
|
||||
param_dict = load_checkpoint(ckpt_file, specify_prefix=["bn", "fc"], filter_prefix="fc.weight")
|
||||
assert len(param_dict) == 5
|
||||
if os.path.exists(ckpt_file):
|
||||
os.remove(ckpt_file)
|
||||
|
||||
|
||||
def test_save_and_load_checkpoint_for_network_with_encryption():
|
||||
""" test save and checkpoint for network with encryption"""
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
|
|
Loading…
Reference in New Issue