!33744 add specify_prefix for load_checkpoint
Merge pull request !33744 from changzherui/add_specify_prefix
This commit is contained in:
commit
ac58831c9a
|
@ -1,10 +1,14 @@
|
||||||
mindspore.load_checkpoint
|
mindspore.load_checkpoint
|
||||||
==========================
|
==========================
|
||||||
|
|
||||||
.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM")
|
.. py:function:: mindspore.load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM", specify_prefix=None)
|
||||||
|
|
||||||
加载checkpoint文件。
|
加载checkpoint文件。
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
- `specify_prefix` 和 `filter_prefix`的功能相互之间没有影响。
|
||||||
|
- 如果发现没有参数被成功加载,将会报ValueError.
|
||||||
|
|
||||||
**参数:**
|
**参数:**
|
||||||
|
|
||||||
- **ckpt_file_name** (str) – checkpoint的文件名称。
|
- **ckpt_file_name** (str) – checkpoint的文件名称。
|
||||||
|
@ -13,6 +17,7 @@ mindspore.load_checkpoint
|
||||||
- **filter_prefix** (Union[str, list[str], tuple[str]]) – 以 `filter_prefix` 开头的参数将不会被加载。默认值:None。
|
- **filter_prefix** (Union[str, list[str], tuple[str]]) – 以 `filter_prefix` 开头的参数将不会被加载。默认值:None。
|
||||||
- **dec_key** (Union[None, bytes]) – 用于解密的字节类型密钥,如果值为None,则不需要解密。默认值:None。
|
- **dec_key** (Union[None, bytes]) – 用于解密的字节类型密钥,如果值为None,则不需要解密。默认值:None。
|
||||||
- **dec_mode** (str) – 该参数仅当 `dec_key` 不为None时有效。指定解密模式,目前支持“AES-GCM”和“AES-CBC”。默认值:“AES-GCM”。
|
- **dec_mode** (str) – 该参数仅当 `dec_key` 不为None时有效。指定解密模式,目前支持“AES-GCM”和“AES-CBC”。默认值:“AES-GCM”。
|
||||||
|
- **specify_prefix** (Union[str, list[str], tuple[str]]) – 以 `specify_prefix` 开头的参数将会被加载。默认值:None。
|
||||||
|
|
||||||
**返回:**
|
**返回:**
|
||||||
|
|
||||||
|
@ -21,3 +26,5 @@ mindspore.load_checkpoint
|
||||||
**异常:**
|
**异常:**
|
||||||
|
|
||||||
- **ValueError** – checkpoint文件格式不正确。
|
- **ValueError** – checkpoint文件格式不正确。
|
||||||
|
- **ValueError** – 没有一个参数被成功加载。
|
||||||
|
- **ValueError** `specify_prefix` 或者 `filter_prefix` 的数据类型不正确。
|
||||||
|
|
|
@ -347,15 +347,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
||||||
logger.info("Saving checkpoint process is finished.")
|
logger.info("Saving checkpoint process is finished.")
|
||||||
|
|
||||||
|
|
||||||
def _check_param_prefix(filter_prefix, param_name):
|
|
||||||
"""Checks whether the prefix of parameter name matches the given filter_prefix."""
|
|
||||||
for prefix in filter_prefix:
|
|
||||||
if param_name.find(prefix) == 0 \
|
|
||||||
and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _check_append_dict(append_dict):
|
def _check_append_dict(append_dict):
|
||||||
"""Check the argument append_dict for save_checkpoint."""
|
"""Check the argument append_dict for save_checkpoint."""
|
||||||
if append_dict is None:
|
if append_dict is None:
|
||||||
|
@ -437,10 +428,15 @@ def load(file_name, **kwargs):
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None, dec_key=None, dec_mode="AES-GCM"):
|
def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None,
|
||||||
|
dec_key=None, dec_mode="AES-GCM", specify_prefix=None):
|
||||||
"""
|
"""
|
||||||
Load checkpoint info from a specified file.
|
Load checkpoint info from a specified file.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. `specify_prefix` and `filter_prefix` do not affect each other.
|
||||||
|
2. If none of the parameters are loaded from checkpoint file, it will throw ValueError.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ckpt_file_name (str): Checkpoint file name.
|
ckpt_file_name (str): Checkpoint file name.
|
||||||
net (Cell): The network where the parameters will be loaded. Default: None
|
net (Cell): The network where the parameters will be loaded. Default: None
|
||||||
|
@ -454,51 +450,38 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
is not required. Default: None.
|
is not required. Default: None.
|
||||||
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption
|
||||||
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
|
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.
|
||||||
|
specify_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the specify_prefix
|
||||||
|
will be loaded. Default: None.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict, key is parameter name, value is a Parameter.
|
Dict, key is parameter name, value is a Parameter.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: Checkpoint file's format is incorrect.
|
ValueError: Checkpoint file's format is incorrect.
|
||||||
|
ValueError: Parameter's dict is None after load checkpoint file.
|
||||||
|
ValueError: The type of `specify_prefix` or `filter_prefix` is incorrect.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> from mindspore import load_checkpoint
|
>>> from mindspore import load_checkpoint
|
||||||
>>>
|
>>>
|
||||||
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
||||||
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1", specify_prefix="conv", )
|
||||||
>>> print(param_dict["conv2.weight"])
|
>>> print(param_dict["conv2.weight"])
|
||||||
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
Parameter (name=conv2.weight, shape=(16, 6, 5, 5), dtype=Float32, requires_grad=True)
|
||||||
"""
|
"""
|
||||||
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
|
ckpt_file_name = _check_ckpt_file_name(ckpt_file_name)
|
||||||
|
specify_prefix = _check_prefix(specify_prefix)
|
||||||
|
filter_prefix = _check_prefix(filter_prefix)
|
||||||
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
|
||||||
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
|
||||||
logger.info("Execute the process of loading checkpoint files.")
|
logger.info("Execute the process of loading checkpoint files.")
|
||||||
checkpoint_list = Checkpoint()
|
checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode)
|
||||||
|
|
||||||
try:
|
|
||||||
if dec_key is None:
|
|
||||||
with open(ckpt_file_name, "rb") as f:
|
|
||||||
pb_content = f.read()
|
|
||||||
else:
|
|
||||||
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
|
||||||
if pb_content is None:
|
|
||||||
raise ValueError("For 'load_checkpoint', Failed to decrypt the checkpoint file.")
|
|
||||||
checkpoint_list.ParseFromString(pb_content)
|
|
||||||
except BaseException as e:
|
|
||||||
if _is_cipher_file(ckpt_file_name):
|
|
||||||
logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
|
|
||||||
"correct 'dec_key'.", ckpt_file_name)
|
|
||||||
else:
|
|
||||||
logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
|
|
||||||
"check the correct of the file.", ckpt_file_name)
|
|
||||||
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
|
|
||||||
"permission to read it.".format(ckpt_file_name))
|
|
||||||
|
|
||||||
parameter_dict = {}
|
parameter_dict = {}
|
||||||
try:
|
try:
|
||||||
param_data_list = []
|
param_data_list = []
|
||||||
for element_id, element in enumerate(checkpoint_list.value):
|
for element_id, element in enumerate(checkpoint_list.value):
|
||||||
if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
|
if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
|
||||||
continue
|
continue
|
||||||
data = element.tensor.tensor_content
|
data = element.tensor.tensor_content
|
||||||
data_type = element.tensor.tensor_type
|
data_type = element.tensor.tensor_type
|
||||||
|
@ -516,15 +499,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
param_data = float(param_data[0])
|
param_data = float(param_data[0])
|
||||||
elif 'Int' in data_type:
|
elif 'Int' in data_type:
|
||||||
param_data = int(param_data[0])
|
param_data = int(param_data[0])
|
||||||
|
elif dims != [1]:
|
||||||
|
param_data = param_data.reshape(list(dims))
|
||||||
|
|
||||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
||||||
elif dims == [1]:
|
|
||||||
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
|
|
||||||
else:
|
|
||||||
param_dim = []
|
|
||||||
for dim in dims:
|
|
||||||
param_dim.append(dim)
|
|
||||||
param_value = param_data.reshape(param_dim)
|
|
||||||
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
|
||||||
|
|
||||||
logger.info("Loading checkpoint files process is finished.")
|
logger.info("Loading checkpoint files process is finished.")
|
||||||
|
|
||||||
|
@ -543,8 +521,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
||||||
return parameter_dict
|
return parameter_dict
|
||||||
|
|
||||||
|
|
||||||
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
def _check_ckpt_file_name(ckpt_file_name):
|
||||||
"""Check function load_checkpoint's parameter."""
|
"""Check function load_checkpoint's cket_file_name."""
|
||||||
if not isinstance(ckpt_file_name, str):
|
if not isinstance(ckpt_file_name, str):
|
||||||
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
|
||||||
"but got {}.".format(type(ckpt_file_name)))
|
"but got {}.".format(type(ckpt_file_name)))
|
||||||
|
@ -558,22 +536,71 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
||||||
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
|
raise ValueError("For 'load_checkpoint', the checkpoint file: {} does not exist, please check "
|
||||||
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
|
"whether the 'ckpt_file_name' is correct.".format(ckpt_file_name))
|
||||||
|
|
||||||
if filter_prefix is not None:
|
return ckpt_file_name
|
||||||
if not isinstance(filter_prefix, (str, list, tuple)):
|
|
||||||
raise TypeError("For 'load_checkpoint', the type of 'filter_prefix' must be string, "
|
|
||||||
"list[string] or tuple[string] when 'filter_prefix' is not None, but "
|
def _check_prefix(prefix):
|
||||||
f"got {str(type(filter_prefix))}.")
|
"""Check the correctness of the parameters."""
|
||||||
if isinstance(filter_prefix, str):
|
if prefix is None:
|
||||||
filter_prefix = (filter_prefix,)
|
return prefix
|
||||||
if not filter_prefix:
|
if not isinstance(prefix, (str, list, tuple)):
|
||||||
raise ValueError("For 'load_checkpoint', the argument 'filter_prefix' can't be empty when "
|
raise TypeError("For 'load_checkpoint', the type of 'specify_prefix' or 'filter_prefix' must be string, "
|
||||||
"'filter_prefix' is list or tuple.")
|
"list[string] or tuple[string], but got {}.".format(str(type(prefix))))
|
||||||
for index, prefix in enumerate(filter_prefix):
|
if isinstance(prefix, str):
|
||||||
if not isinstance(prefix, str):
|
prefix = (prefix,)
|
||||||
raise TypeError("For 'load_checkpoint', when 'filter_prefix' is list or tuple, "
|
if not prefix:
|
||||||
"the element in 'filter_prefix' must be string, but got "
|
raise ValueError("For 'load_checkpoint', the argument 'specify_prefix' or 'filter_prefix' can't be empty when"
|
||||||
f"{str(type(prefix))} at index {index}.")
|
" 'filter_prefix' is list or tuple.")
|
||||||
return ckpt_file_name, filter_prefix
|
for index, pre in enumerate(prefix):
|
||||||
|
if not isinstance(pre, str):
|
||||||
|
raise TypeError("For 'load_checkpoint', when 'specify_prefix' or 'filter_prefix' is list or tuple, "
|
||||||
|
"the element in it must be string, but got "
|
||||||
|
f"{str(type(pre))} at index {index}.")
|
||||||
|
if pre == "":
|
||||||
|
raise ValueError("For 'load_checkpoint', the value of 'specify_prefix' or 'filter_prefix' "
|
||||||
|
"can't include ''.")
|
||||||
|
return prefix
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode):
|
||||||
|
"""Parse checkpoint protobuf."""
|
||||||
|
checkpoint_list = Checkpoint()
|
||||||
|
try:
|
||||||
|
if dec_key is None:
|
||||||
|
with open(ckpt_file_name, "rb") as f:
|
||||||
|
pb_content = f.read()
|
||||||
|
else:
|
||||||
|
pb_content = _decrypt(ckpt_file_name, dec_key, len(dec_key), dec_mode)
|
||||||
|
if pb_content is None:
|
||||||
|
raise ValueError("For 'load_checkpoint', failed to decrypt the checkpoint file.")
|
||||||
|
checkpoint_list.ParseFromString(pb_content)
|
||||||
|
except BaseException as e:
|
||||||
|
if _is_cipher_file(ckpt_file_name):
|
||||||
|
logger.critical("Failed to read the checkpoint file '%s'. The file may be encrypted, please pass in the "
|
||||||
|
"correct 'dec_key'.", ckpt_file_name)
|
||||||
|
else:
|
||||||
|
logger.critical("Failed to read the checkpoint file '%s' , may not have permission to read it, please "
|
||||||
|
"check the correct of the file.", ckpt_file_name)
|
||||||
|
raise ValueError(e.__str__() + "\nFor 'load_checkpoint', failed to read the checkpoint file {}, may not have "
|
||||||
|
"permission to read it.".format(ckpt_file_name))
|
||||||
|
return checkpoint_list
|
||||||
|
|
||||||
|
|
||||||
|
def _whether_load_param(specify_prefix, filter_prefix, param_name):
|
||||||
|
"""Checks whether the load the parameter after `specify_prefix` or `filter_prefix`."""
|
||||||
|
whether_load = True
|
||||||
|
if specify_prefix:
|
||||||
|
whether_load = False
|
||||||
|
for prefix in specify_prefix:
|
||||||
|
if param_name.startswith(prefix):
|
||||||
|
whether_load = True
|
||||||
|
break
|
||||||
|
if filter_prefix:
|
||||||
|
for prefix in filter_prefix:
|
||||||
|
if param_name.startswith(prefix):
|
||||||
|
whether_load = False
|
||||||
|
break
|
||||||
|
return whether_load
|
||||||
|
|
||||||
|
|
||||||
def load_param_into_net(net, parameter_dict, strict_load=False):
|
def load_param_into_net(net, parameter_dict, strict_load=False):
|
||||||
|
@ -1607,11 +1634,11 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
||||||
param_not_in_ckpt.append(param.name)
|
param_not_in_ckpt.append(param.name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
param_rank = rank_list[param.name][0]
|
param_rank = rank_list.get(param.name)[0]
|
||||||
skip_merge_split = rank_list[param.name][1]
|
skip_merge_split = rank_list.get(param.name)[1]
|
||||||
shard_stride = train_strategy[param.name][4]
|
shard_stride = train_strategy.get(param.name)[4]
|
||||||
if train_strategy[param.name][5]:
|
if train_strategy.get(param.name)[5]:
|
||||||
shard_size = ckpt_file_len / shard_stride / train_strategy[param.name][5]
|
shard_size = ckpt_file_len / shard_stride / train_strategy.get(param.name)[5]
|
||||||
else:
|
else:
|
||||||
shard_size = 0
|
shard_size = 0
|
||||||
for rank in param_rank:
|
for rank in param_rank:
|
||||||
|
|
|
@ -37,9 +37,19 @@ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load
|
||||||
from tests.security_utils import security_off_wrap
|
from tests.security_utils import security_off_wrap
|
||||||
from ..ut_filter import non_graph_engine
|
from ..ut_filter import non_graph_engine
|
||||||
|
|
||||||
class Net(nn.Cell):
|
|
||||||
"""Net definition."""
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
"""
|
||||||
|
Net definition.
|
||||||
|
parameter name :
|
||||||
|
conv1.weight
|
||||||
|
bn1.moving_mean
|
||||||
|
bn1.moving_variance
|
||||||
|
bn1.gamma
|
||||||
|
bn1.beta
|
||||||
|
fc.weight
|
||||||
|
fc.bias
|
||||||
|
"""
|
||||||
def __init__(self, num_classes=10):
|
def __init__(self, num_classes=10):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
|
||||||
|
@ -342,6 +352,98 @@ def test_load_checkpoint_empty_file():
|
||||||
load_checkpoint("empty.ckpt")
|
load_checkpoint("empty.ckpt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_checkpoint_error_param():
|
||||||
|
"""
|
||||||
|
Feature: Load checkpoint.
|
||||||
|
Description: Load checkpoint with error param.
|
||||||
|
Expectation: Raise value error for error param.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = Net(10)
|
||||||
|
ckpt_file = "check_name.ckpt"
|
||||||
|
save_checkpoint(net, ckpt_file)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
load_checkpoint(ckpt_file, specify_prefix=123)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
load_checkpoint(ckpt_file, filter_prefix="")
|
||||||
|
if os.path.exists(ckpt_file):
|
||||||
|
os.remove(ckpt_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_checkpoint_error_load():
|
||||||
|
"""
|
||||||
|
Feature: Load checkpoint.
|
||||||
|
Description: Load checkpoint with empty parameter dict.
|
||||||
|
Expectation: Raise value error for error load.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = Net(10)
|
||||||
|
ckpt_file = "check_name.ckpt"
|
||||||
|
save_checkpoint(net, ckpt_file)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
load_checkpoint(ckpt_file, specify_prefix="123")
|
||||||
|
if os.path.exists(ckpt_file):
|
||||||
|
os.remove(ckpt_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_checkpoint_specify_prefix():
|
||||||
|
"""
|
||||||
|
Feature: Load checkpoint.
|
||||||
|
Description: Load checkpoint with param `specify_prefix`.
|
||||||
|
Expectation: Correct loaded checkpoint file.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = Net(10)
|
||||||
|
ckpt_file = "specify_prefix.ckpt"
|
||||||
|
save_checkpoint(net, ckpt_file)
|
||||||
|
param_dict = load_checkpoint(ckpt_file, specify_prefix="bn")
|
||||||
|
assert len(param_dict) == 4
|
||||||
|
param_dict = load_checkpoint(ckpt_file, specify_prefix="fc")
|
||||||
|
assert len(param_dict) == 2
|
||||||
|
param_dict = load_checkpoint(ckpt_file, specify_prefix=["fc", "bn"])
|
||||||
|
assert len(param_dict) == 6
|
||||||
|
if os.path.exists(ckpt_file):
|
||||||
|
os.remove(ckpt_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_checkpoint_filter_prefix():
|
||||||
|
"""
|
||||||
|
Feature: Load checkpoint.
|
||||||
|
Description: Load checkpoint with param `filter_prefix`.
|
||||||
|
Expectation: Correct loaded checkpoint file.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = Net(10)
|
||||||
|
ckpt_file = "filter_prefix.ckpt"
|
||||||
|
save_checkpoint(net, ckpt_file)
|
||||||
|
param_dict = load_checkpoint(ckpt_file, filter_prefix="fc")
|
||||||
|
assert len(param_dict) == 5
|
||||||
|
param_dict = load_checkpoint(ckpt_file, filter_prefix="bn")
|
||||||
|
assert len(param_dict) == 3
|
||||||
|
param_dict = load_checkpoint(ckpt_file, filter_prefix=["bn", "fc"])
|
||||||
|
assert len(param_dict) == 1
|
||||||
|
if os.path.exists(ckpt_file):
|
||||||
|
os.remove(ckpt_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_checkpoint_specify_filter_prefix():
|
||||||
|
"""
|
||||||
|
Feature: Load checkpoint.
|
||||||
|
Description: Load checkpoint with param `filter_prefix` and `specify_prefix`.
|
||||||
|
Expectation: Correct loaded checkpoint file.
|
||||||
|
"""
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
net = Net(10)
|
||||||
|
ckpt_file = "specify_filter_prefix.ckpt"
|
||||||
|
save_checkpoint(net, ckpt_file)
|
||||||
|
param_dict = load_checkpoint(ckpt_file, specify_prefix="bn", filter_prefix="bn1.moving")
|
||||||
|
assert len(param_dict) == 2
|
||||||
|
param_dict = load_checkpoint(ckpt_file, specify_prefix=["bn", "fc"], filter_prefix="fc.weight")
|
||||||
|
assert len(param_dict) == 5
|
||||||
|
if os.path.exists(ckpt_file):
|
||||||
|
os.remove(ckpt_file)
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_checkpoint_for_network_with_encryption():
|
def test_save_and_load_checkpoint_for_network_with_encryption():
|
||||||
""" test save and checkpoint for network with encryption"""
|
""" test save and checkpoint for network with encryption"""
|
||||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||||
|
|
Loading…
Reference in New Issue