!7401 custom ckpt save and load

Merge pull request !7401 from caozhou/custom_ckpt_save_and_load
Author: mindspore-ci-bot, 2020-10-20 20:31:46 +08:00 (committed by Gitee)
Commit: ba6023b87d
2 changed files with 75 additions and 13 deletions

Changed file 1 of 2:

@@ -21,6 +21,7 @@ import time
 import threading
 import mindspore.context as context
 from mindspore import log as logger
+from mindspore import nn
 from mindspore._checkparam import Validator
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
@@ -88,13 +89,36 @@ class CheckpointConfig:
         integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
             Default: True. Integrated save function is only supported in automatic parallel scene, not supported
             in manual parallel.
-        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
+        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
+        saved_network (Cell): Network to be saved in checkpoint file. Default: None.
 
     Raises:
         ValueError: If the input_param is None or 0.
 
     Examples:
-        >>> config = CheckpointConfig()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+        >>>         self.bn = nn.BatchNorm2d(64)
+        >>>         self.relu = nn.ReLU()
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.fc = nn.Dense(64*224*224, 12)
+        >>>
+        >>>     def construct(self, x):
+        >>>         x = self.conv(x)
+        >>>         x = self.bn(x)
+        >>>         x = self.relu(x)
+        >>>         x = self.flatten(x)
+        >>>         out = self.fc(x)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
+        >>> dataset = get_dataset()
+        >>> config = CheckpointConfig(saved_network=net)
+        >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
+        >>> model.train(10, dataset, callbacks=ckpoint_cb)
     """
@@ -104,7 +128,8 @@ class CheckpointConfig:
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
                  integrated_save=True,
-                 async_save=False):
+                 async_save=False,
+                 saved_network=None):
         if save_checkpoint_steps is not None:
             save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
@@ -115,6 +140,9 @@ class CheckpointConfig:
         if keep_checkpoint_per_n_minutes is not None:
             keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)
 
+        if saved_network is not None and not isinstance(saved_network, nn.Cell):
+            raise TypeError(f"The type of saved_network must be None or Cell, but got {str(type(saved_network))}.")
+
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
             raise ValueError("The input_param can't be all None or 0")
@@ -134,6 +162,7 @@ class CheckpointConfig:
         self._integrated_save = Validator.check_bool(integrated_save)
         self._async_save = Validator.check_bool(async_save)
+        self._saved_network = saved_network
 
     @property
     def save_checkpoint_steps(self):
@@ -165,12 +194,18 @@ class CheckpointConfig:
         """Get the value of _async_save."""
         return self._async_save
 
+    @property
+    def saved_network(self):
+        """Get the value of _saved_network"""
+        return self._saved_network
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
-        checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
-                             'save_checkpoint_seconds': self._save_checkpoint_seconds,
-                             'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
+        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
+                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
+                             'keep_checkpoint_max': self.keep_checkpoint_max,
+                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
+                             'saved_network': self.saved_network}
         return checkpoint_policy
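
get_checkpoint_policy now reads through the public properties and carries the extra key; a quick illustration (values other than 'saved_network' depend on constructor defaults not shown in this diff):

    from mindspore.train.callback import CheckpointConfig

    config = CheckpointConfig(keep_checkpoint_max=5)
    policy = config.get_checkpoint_policy()
    print(policy['saved_network'])   # None unless saved_network was passed
    print(sorted(policy.keys()))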
@@ -306,7 +341,8 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
+            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
+            save_checkpoint(network, cur_file, self._config.integrated_save,
                             self._config.async_save)
 
             self._latest_ckpt_file_name = cur_file
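
One observable consequence of the redirect above, as a sketch (the file name follows the usual prefix-epoch_step pattern and is assumed here): a checkpoint written with saved_network=net holds only that Cell's parameters, so optimizer accumulators such as Momentum's moments.* never appear in the loaded dict.

    from mindspore.train.serialization import load_checkpoint

    param_dict = load_checkpoint("ck_prefix-10_1875.ckpt")  # assumed file from the docstring example
    assert not any(name.startswith("moments.") for name in param_dict)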

Changed file 2 of 2:

@@ -225,7 +225,16 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
     logger.info("Save checkpoint process finish.")
 
 
-def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
+def _check_param_prefix(filter_prefix, param_name):
+    """Checks whether the prefix of parameter name matches the given filter_prefix."""
+    for prefix in filter_prefix:
+        if param_name.find(prefix) == 0 \
+                and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
+            return True
+    return False
+
+
+def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
     """
     Loads checkpoint info from a specified file.
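
To make the matching rule concrete, a small self-check; the helper body is copied verbatim from the hunk above, and the parameter names are made up:

    def _check_param_prefix(filter_prefix, param_name):
        """Copy of the helper above, for illustration only."""
        for prefix in filter_prefix:
            if param_name.find(prefix) == 0 \
                    and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
                return True
        return False

    assert _check_param_prefix(("moments",), "moments.fc.weight")       # dotted child matches
    assert _check_param_prefix(("moments",), "moments")                 # exact name matches
    assert not _check_param_prefix(("moments",), "moments2.fc.weight")  # plain substring does not
    assert _check_param_prefix(("fc.",), "fc.weight")                   # trailing dot matches anything it starts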
@@ -234,6 +243,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
         net (Cell): Cell network. Default: None
         strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
             in the param_dict into net with the same suffix. Default: False
+        filter_prefix (Union[str, list[str], tuple[str]]): Parameter with the filter prefix will not be loaded.
+            Default: None.
 
     Returns:
         Dict, key is parameter name, value is a Parameter.
@@ -253,6 +264,19 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
     if os.path.getsize(ckpt_file_name) == 0:
         raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
 
+    if filter_prefix is not None:
+        if not isinstance(filter_prefix, (str, list, tuple)):
+            raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
+                            f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
+        if isinstance(filter_prefix, str):
+            filter_prefix = (filter_prefix,)
+        if not filter_prefix:
+            raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
+        for index, prefix in enumerate(filter_prefix):
+            if not isinstance(prefix, str):
+                raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
+                                f"but got {str(type(prefix))} at index {index}.")
+
     logger.info("Execute load checkpoint process.")
     checkpoint_list = Checkpoint()
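
Taken together, the normalization and checks above give load_checkpoint a forgiving front door; illustrative calls ("net.ckpt" is a placeholder for an existing checkpoint file):

    from mindspore.train.serialization import load_checkpoint

    # A bare string is wrapped into a one-element tuple, so these two are equivalent:
    load_checkpoint("net.ckpt", filter_prefix="moments")
    load_checkpoint("net.ckpt", filter_prefix=("moments",))

    # Rejected by the checks above, before any parameter is parsed:
    # load_checkpoint("net.ckpt", filter_prefix=[])      -> ValueError (empty list/tuple)
    # load_checkpoint("net.ckpt", filter_prefix=[0, 1])  -> TypeError (non-str element)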
@@ -266,9 +290,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
     parameter_dict = {}
     try:
-        element_id = 0
         param_data_list = []
-        for element in checkpoint_list.value:
+        for element_id, element in enumerate(checkpoint_list.value):
+            if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
+                continue
             data = element.tensor.tensor_content
             data_type = element.tensor.tensor_type
             np_type = tensor_to_np_type[data_type]
@@ -296,14 +321,15 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
                 param_value = param_data.reshape(param_dim)
                 parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
-            element_id += 1
 
         logger.info("Load checkpoint process finish.")
     except BaseException as e:
         logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
         raise RuntimeError(e.__str__())
 
+    if not parameter_dict:
+        raise ValueError(f"The loaded parameter dict is empty after filtering, please check filter_prefix.")
+
     if net is not None:
         load_param_into_net(net, parameter_dict, strict_load)
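
Putting both halves of the change together, a closing sketch (net is assumed to be a Cell such as the docstring's example network, and the .ckpt name is assumed): train-time state is filtered out at load, and the new empty-dict guard means a filter_prefix that matches everything fails loudly instead of silently loading nothing.

    from mindspore.train.serialization import load_checkpoint, load_param_into_net

    # "moments" and "global_step" name optimizer/train-state parameters here; adjust to your model.
    param_dict = load_checkpoint("net-10_1875.ckpt", filter_prefix=["moments", "global_step"])
    load_param_into_net(net, param_dict)  # only the surviving (non-filtered) weights are restored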