!7401 custom ckpt save and load
Merge pull request !7401 from caozhou/custom_ckpt_save_and_load
commit ba6023b87d
mindspore/train/callback/_checkpoint.py
@@ -21,6 +21,7 @@ import time
 import threading
 import mindspore.context as context
 from mindspore import log as logger
+from mindspore import nn
 from mindspore._checkparam import Validator
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
@@ -88,13 +89,36 @@ class CheckpointConfig:
         integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
             Default: True. Integrated save function is only supported in automatic parallel scene, not supported
             in manual parallel.
-        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
+        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
+        saved_network (Cell): Network to be saved in checkpoint file. Default: None.
 
     Raises:
         ValueError: If the input_param is None or 0.
 
     Examples:
-        >>> config = CheckpointConfig()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+        >>>         self.bn = nn.BatchNorm2d(64)
+        >>>         self.relu = nn.ReLU()
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.fc = nn.Dense(64*224*224, 12)
+        >>>
+        >>>     def construct(self, x):
+        >>>         x = self.conv(x)
+        >>>         x = self.bn(x)
+        >>>         x = self.relu(x)
+        >>>         x = self.flatten(x)
+        >>>         out = self.fc(x)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
+        >>> dataset = get_dataset()
+        >>> config = CheckpointConfig(saved_network=net)
         >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
         >>> model.train(10, dataset, callbacks=ckpoint_cb)
     """
@@ -104,7 +128,8 @@ class CheckpointConfig:
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
                  integrated_save=True,
-                 async_save=False):
+                 async_save=False,
+                 saved_network=None):
 
         if save_checkpoint_steps is not None:
             save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
@@ -115,6 +140,9 @@ class CheckpointConfig:
         if keep_checkpoint_per_n_minutes is not None:
             keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)
 
+        if saved_network is not None and not isinstance(saved_network, nn.Cell):
+            raise TypeError(f"The type of saved_network must be None or Cell, but got {str(type(saved_network))}.")
+
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
             raise ValueError("The input_param can't be all None or 0")
@@ -134,6 +162,7 @@ class CheckpointConfig:
 
         self._integrated_save = Validator.check_bool(integrated_save)
         self._async_save = Validator.check_bool(async_save)
+        self._saved_network = saved_network
 
     @property
     def save_checkpoint_steps(self):
@@ -165,12 +194,18 @@ class CheckpointConfig:
         """Get the value of _async_save."""
         return self._async_save
 
+    @property
+    def saved_network(self):
+        """Get the value of _saved_network"""
+        return self._saved_network
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
-        checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
-                             'save_checkpoint_seconds': self._save_checkpoint_seconds,
-                             'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
+        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
+                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
+                             'keep_checkpoint_max': self.keep_checkpoint_max,
+                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
+                             'saved_network': self.saved_network}
 
         return checkpoint_policy
 
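Note that the policy dict now reads through the public properties rather than the private attributes, so the new 'saved_network' entry rides along with the existing keys. A minimal sketch of the resulting key set (the constructor call and values are illustrative; only the keys are taken from the code above):

    from mindspore.train.callback import CheckpointConfig

    config = CheckpointConfig()
    policy = config.get_checkpoint_policy()
    print(sorted(policy.keys()))
    # ['keep_checkpoint_max', 'keep_checkpoint_per_n_minutes',
    #  'save_checkpoint_seconds', 'save_checkpoint_steps', 'saved_network']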
@@ -306,7 +341,8 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
+            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
+            save_checkpoint(network, cur_file, self._config.integrated_save,
                             self._config.async_save)
 
             self._latest_ckpt_file_name = cur_file
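The selection logic above means the checkpoint can come from a network other than the one being trained: with saved_network set, the callback saves that network's parameters instead of cb_params.train_network (the full training graph including wrappers). A hedged usage sketch, reusing the Net and get_dataset placeholders from the docstring example above:

    from mindspore import nn
    from mindspore.train.model import Model
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    net = Net()  # bare inference network from the docstring example
    model = Model(net,
                  loss_fn=nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean"),
                  optimizer=nn.Momentum(net.trainable_params(), 0.01, 0.9))

    # With saved_network=net the checkpoints contain the bare network's
    # parameters; without it, the wrapped training network is saved.
    config = CheckpointConfig(saved_network=net)
    ckpt_cb = ModelCheckpoint(prefix="ck_prefix", directory="./", config=config)
    model.train(10, get_dataset(), callbacks=ckpt_cb)  # get_dataset() is a placeholder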
mindspore/train/serialization.py
@@ -225,7 +225,16 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
     logger.info("Save checkpoint process finish.")
 
 
-def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
+def _check_param_prefix(filter_prefix, param_name):
+    """Checks whether the prefix of parameter name matches the given filter_prefix."""
+    for prefix in filter_prefix:
+        if param_name.find(prefix) == 0 \
+                and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
+            return True
+    return False
+
+
+def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
     """
     Loads checkpoint info from a specified file.
 
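The helper's matching rule is stricter than a plain str.startswith: a prefix only counts if it covers whole dot-separated name segments. A self-contained sketch of the behavior (the function is private, so it is restated here from the diff for illustration):

    def _check_param_prefix(filter_prefix, param_name):
        """Restated from the diff above: match only on whole dotted-name segments."""
        for prefix in filter_prefix:
            if param_name.find(prefix) == 0 \
                    and (param_name == prefix or param_name[len(prefix)] == "." or (prefix and prefix[-1] == ".")):
                return True
        return False

    assert _check_param_prefix(("conv",), "conv.weight")        # boundary is a dot
    assert _check_param_prefix(("conv",), "conv")               # exact match
    assert not _check_param_prefix(("conv",), "conv2d.weight")  # "conv2d" is a different segment
    assert _check_param_prefix(("conv.",), "conv.weight")       # trailing dot in the prefix also matches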
@@ -234,6 +243,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
         net (Cell): Cell network. Default: None
         strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
             in the param_dict into net with the same suffix. Default: False
+        filter_prefix (Union[str, list[str], tuple[str]]): Parameter with the filter prefix will not be loaded.
+            Default: None.
 
     Returns:
         Dict, key is parameter name, value is a Parameter.
@@ -253,6 +264,19 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
     if os.path.getsize(ckpt_file_name) == 0:
         raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
 
+    if filter_prefix is not None:
+        if not isinstance(filter_prefix, (str, list, tuple)):
+            raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
+                            f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
+        if isinstance(filter_prefix, str):
+            filter_prefix = (filter_prefix,)
+        if not filter_prefix:
+            raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
+        for index, prefix in enumerate(filter_prefix):
+            if not isinstance(prefix, str):
+                raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
+                                f"but got {str(type(prefix))} at index {index}.")
+
     logger.info("Execute load checkpoint process.")
     checkpoint_list = Checkpoint()
 
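A hedged usage sketch for the new argument (the file name and prefixes are illustrative; MindSpore optimizers such as Momentum store their accumulators as Parameters prefixed with "moments"):

    from mindspore.train.serialization import load_checkpoint, load_param_into_net

    # With net=None, load_checkpoint just returns the filtered parameter dict,
    # which can then be loaded into an already-built inference network.
    param_dict = load_checkpoint("resnet.ckpt", filter_prefix=["moments", "global_step"])
    # load_param_into_net(eval_net, param_dict)

    # A single string is also accepted; it is normalized to a one-element tuple.
    param_dict = load_checkpoint("resnet.ckpt", filter_prefix="moments")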
@@ -266,9 +290,10 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
 
     parameter_dict = {}
     try:
-        element_id = 0
         param_data_list = []
-        for element in checkpoint_list.value:
+        for element_id, element in enumerate(checkpoint_list.value):
+            if filter_prefix is not None and _check_param_prefix(filter_prefix, element.tag):
+                continue
             data = element.tensor.tensor_content
             data_type = element.tensor.tensor_type
             np_type = tensor_to_np_type[data_type]
@@ -296,14 +321,15 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False):
             param_value = param_data.reshape(param_dim)
             parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
 
-            element_id += 1
-
         logger.info("Load checkpoint process finish.")
 
     except BaseException as e:
         logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
         raise RuntimeError(e.__str__())
 
+    if not parameter_dict:
+        raise ValueError(f"The loaded parameter dict is empty after filtering, please check filter_prefix.")
+
     if net is not None:
         load_param_into_net(net, parameter_dict, strict_load)
 
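Taken together, the two halves of this commit enable a round trip like the following sketch ("net-10_1875.ckpt" assumes the usual prefix-epoch_step checkpoint naming, and every object here is a placeholder from the earlier examples):

    # Save: checkpoint only the bare network instead of the training wrapper.
    config = CheckpointConfig(saved_network=net)
    model.train(10, get_dataset(),
                callbacks=ModelCheckpoint(prefix="net", directory="./", config=config))

    # Load: additionally drop any optimizer state by prefix.
    param_dict = load_checkpoint("./net-10_1875.ckpt", net=eval_net, filter_prefix="moments")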