diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index 70bb83c6461..4fef64d85e6 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -32,6 +32,7 @@ from ...common.tensor import Tensor _cur_dir = os.getcwd() _save_dir = _cur_dir +_info_list = ["epoch_num", "step_num"] def _chg_ckpt_file_name_if_same_exist(directory, prefix): @@ -82,6 +83,8 @@ class CheckpointConfig: async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False. saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation with the network in training, the initial value of saved_network will be saved. Default: None. + append_info (List): The information save to checkpoint file. Support "epoch_num"、"step_num"、and dict. + The key of dict must be str, the value of dict must be one of int float and bool. Default: None. enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption is not required. Default: None. enc_mode (str): This parameter is valid only when enc_key is not set to None. 
@staticmethod
def _handle_append_info(append_info):
    """
    Normalize the user-supplied ``append_info`` list into a flat dict.

    Args:
        append_info (Union[list, None]): Items to append to the checkpoint. Each element is either a str
            found in ``_info_list`` or a dict (at most one dict is allowed) whose keys are str and whose
            values are int, float or bool.

    Returns:
        Union[dict, None], None when ``append_info`` is None or an empty list. Otherwise a dict in which
        "epoch_num" / "step_num" (when requested) are initialised to 0 and the entries of the user dict
        are copied as given.

    Raises:
        TypeError: If ``append_info`` is not a list, an element is neither str nor dict, a str element is
            not in ``_info_list``, more than one dict is given, or a dict entry has an unsupported key or
            value type.
    """
    if append_info is None or append_info == []:
        return None
    if not isinstance(append_info, list):
        raise TypeError(f"The type of append_info must be list, but got {type(append_info)}.")

    handled_info = {}
    # The counters are seeded before the user dict is merged, so a dict entry
    # with the same key overrides the initial 0.
    for counter_name in ("epoch_num", "step_num"):
        if counter_name in append_info:
            handled_info[counter_name] = 0

    seen_dict = False
    for element in append_info:
        if isinstance(element, str):
            if element not in _info_list:
                raise TypeError(f"The str element of append_info must be in {_info_list}, but got {element}.")
            continue
        if not isinstance(element, dict):
            raise TypeError(f"The type of append_info element must be str or dict, but got {type(element)}.")
        if seen_dict:
            raise TypeError("The append_info must contain at most one dict.")
        seen_dict = True
        for key, value in element.items():
            if not isinstance(key, str) or not isinstance(value, (int, float, bool)):
                raise TypeError("The type of dict in append_info must be key: str, value: int, float or bool, "
                                f"but got key: {type(key)}, value: {type(value)}.")
            handled_info[key] = value

    return handled_info
modify_time = os.path.getmtime(ck_file) if cur_time - modify_time < 60 * minutes: - movs.append(ck_file) + del_list.append(ck_file) if modify_time < oldest_time: oldest_time = modify_time oldest_file = ck_file - for mv_file in movs: + for mv_file in del_list: if mv_file == oldest_file: continue self.remove_ckpoint_file(mv_file) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 3a6fb27fe67..ae3c92e6296 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -14,15 +14,16 @@ # ============================================================================ """Model and parameters serialization.""" import os - import sys import stat import math import shutil import time import copy +import threading from threading import Thread, Lock from collections import defaultdict + import numpy as np import mindspore.nn as nn @@ -189,7 +190,8 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): raise e -def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"): +def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, + async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"): """ Saves checkpoint info to a specified file. @@ -201,6 +203,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False + append_dict (dict): Additional information that needs to be saved. The key of dict must be str, + the value of dict must be one of int float and bool. Default: None enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption is not required. 
def _check_append_dict(append_dict):
    """
    Validate the ``append_dict`` argument of ``save_checkpoint``.

    Args:
        append_dict (Union[dict, None]): Additional information to be saved. Every key must be str and
            every value must be int, float or bool.

    Returns:
        Union[dict, None], the argument itself, unchanged, once it has passed validation.

    Raises:
        TypeError: If ``append_dict`` is neither None nor a dict, or if any key is not str, or if any
            value is not int, float or bool.
    """
    if append_dict is None:
        return None
    if not isinstance(append_dict, dict):
        raise TypeError(f"The type of append_dict must be dict, but got {type(append_dict)}.")
    for key, value in append_dict.items():
        if not isinstance(key, str) or not isinstance(value, (int, float, bool)):
            # Report the offending entry's types so the caller can locate it.
            raise TypeError("The type of element in append_dict must be key: str, value: int, float or bool, "
                            f"but got key: {type(key)}, value: {type(value)}.")
    return append_dict
def async_ckpt_thread_status():
    """
    Get async save checkpoint thread status.

    Returns:
        bool, True if a thread named "asyn_save_ckpt" is currently alive, i.e. the asynchronous save
        checkpoint thread is running. False if no such thread is executing.
    """
    # Threads are matched by name. The `name` attribute is used because the
    # `getName()` accessor is deprecated since Python 3.10.
    return any(thr.name == "asyn_save_ckpt" for thr in threading.enumerate())
os.remove('./parameters.ckpt') + + ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') + save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict) + + def test_load_checkpoint_error_filename(): ckpt_file_name = 1 with pytest.raises(ValueError): @@ -130,7 +154,7 @@ def test_load_checkpoint(): ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') par_dict = load_checkpoint(ckpt_file_name) - assert len(par_dict) == 3 + assert len(par_dict) == 6 assert par_dict['param_test'].name == 'param_test' assert par_dict['param_test'].data.dtype == mstype.float32 assert par_dict['param_test'].data.shape == (1, 3, 224, 224)