forked from mindspore-Ecosystem/mindspore
!17761 add save ckpt info
Merge pull request !17761 from changzherui/add_ckpt_info
This commit is contained in:
commit
3a13dde14d
|
@ -32,6 +32,7 @@ from ...common.tensor import Tensor
|
|||
|
||||
_cur_dir = os.getcwd()
|
||||
_save_dir = _cur_dir
|
||||
_info_list = ["epoch_num", "step_num"]
|
||||
|
||||
|
||||
def _chg_ckpt_file_name_if_same_exist(directory, prefix):
|
||||
|
@ -82,6 +83,8 @@ class CheckpointConfig:
|
|||
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
|
||||
saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
|
||||
with the network in training, the initial value of saved_network will be saved. Default: None.
|
||||
append_info (list): The information to save to the checkpoint file. Supports "epoch_num", "step_num", and dict.
|
||||
The key of dict must be str, the value of dict must be one of int, float and bool. Default: None.
|
||||
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
|
||||
is not required. Default: None.
|
||||
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
|
||||
|
@ -131,6 +134,7 @@ class CheckpointConfig:
|
|||
integrated_save=True,
|
||||
async_save=False,
|
||||
saved_network=None,
|
||||
append_info=None,
|
||||
enc_key=None,
|
||||
enc_mode='AES-GCM'):
|
||||
|
||||
|
@ -166,6 +170,7 @@ class CheckpointConfig:
|
|||
self._integrated_save = Validator.check_bool(integrated_save)
|
||||
self._async_save = Validator.check_bool(async_save)
|
||||
self._saved_network = saved_network
|
||||
self._append_dict = self._handle_append_info(append_info)
|
||||
self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
||||
self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
||||
|
||||
|
@ -214,6 +219,11 @@ class CheckpointConfig:
|
|||
"""Get the value of _enc_mode"""
|
||||
return self._enc_mode
|
||||
|
||||
@property
|
||||
def append_dict(self):
|
||||
"""Get the value of append_dict."""
|
||||
return self._append_dict
|
||||
|
||||
def get_checkpoint_policy(self):
|
||||
"""Get the policy of checkpoint."""
|
||||
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
|
||||
|
@ -224,6 +234,36 @@ class CheckpointConfig:
|
|||
|
||||
return checkpoint_policy
|
||||
|
||||
@staticmethod
def _handle_append_info(append_info):
    """Validate append_info and normalize it into a dict for checkpoint saving.

    Args:
        append_info (list): Elements must be the strings "epoch_num" or
            "step_num", and/or at most one dict whose keys are str and
            whose values are int, float or bool.

    Returns:
        dict or None: Normalized info dict ("epoch_num"/"step_num" are
        initialized to 0), or None when there is nothing to append.

    Raises:
        TypeError: If append_info or any of its elements has an invalid
            type, or if more than one dict element is given.
    """
    if append_info is None or append_info == []:
        return None
    if not isinstance(append_info, list):
        raise TypeError(f"The type of append_info must be list, but got {str(type(append_info))}.")
    handle_append_info = {}
    # Known counters are pre-initialized; the training callback fills them in later.
    if "epoch_num" in append_info:
        handle_append_info["epoch_num"] = 0
    if "step_num" in append_info:
        handle_append_info["step_num"] = 0
    dict_num = 0
    for element in append_info:
        if not isinstance(element, (str, dict)):
            raise TypeError(f"The type of append_info element must be str or dict, "
                            f"but got {str(type(element))}.")
        if isinstance(element, str) and element not in _info_list:
            raise TypeError(f"The str element of append_info must be in {_info_list}, but got {element}.")
        if isinstance(element, dict):
            dict_num += 1
            # Only a single user dict is allowed so keys cannot silently collide.
            if dict_num > 1:
                raise TypeError("The append_info can contain only one dict element.")
            for key, value in element.items():
                if isinstance(key, str) and isinstance(value, (int, float, bool)):
                    handle_append_info[key] = value
                else:
                    raise TypeError("The dict in append_info must map str keys to "
                                    "int, float or bool values.")

    return handle_append_info
|
||||
|
||||
|
||||
class ModelCheckpoint(Callback):
|
||||
"""
|
||||
|
@ -273,6 +313,9 @@ class ModelCheckpoint(Callback):
|
|||
# get existing checkpoint files
|
||||
self._manager = CheckpointManager()
|
||||
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
|
||||
self._append_dict = self._config.append_dict or {}
|
||||
self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
|
||||
self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
|
||||
self._graph_saved = False
|
||||
self._need_flush_from_cache = True
|
||||
|
||||
|
@ -370,10 +413,13 @@ class ModelCheckpoint(Callback):
|
|||
if context.get_context("enable_ge"):
|
||||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
if "epoch_num" in self._append_dict:
|
||||
self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
|
||||
if "step_num" in self._append_dict:
|
||||
self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
|
||||
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
|
||||
save_checkpoint(network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save, self._config.enc_key, self._config.enc_mode)
|
||||
save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
|
||||
self._append_dict, self._config.enc_key, self._config.enc_mode)
|
||||
|
||||
self._latest_ckpt_file_name = cur_file
|
||||
|
||||
|
@ -442,19 +488,19 @@ class CheckpointManager:
|
|||
|
||||
def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
|
||||
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
|
||||
movs = []
|
||||
del_list = []
|
||||
oldest_file = ''
|
||||
oldest_time = cur_time
|
||||
for ck_file in self._ckpoint_filelist:
|
||||
modify_time = os.path.getmtime(ck_file)
|
||||
if cur_time - modify_time < 60 * minutes:
|
||||
movs.append(ck_file)
|
||||
del_list.append(ck_file)
|
||||
|
||||
if modify_time < oldest_time:
|
||||
oldest_time = modify_time
|
||||
oldest_file = ck_file
|
||||
|
||||
for mv_file in movs:
|
||||
for mv_file in del_list:
|
||||
if mv_file == oldest_file:
|
||||
continue
|
||||
self.remove_ckpoint_file(mv_file)
|
||||
|
|
|
@ -14,15 +14,16 @@
|
|||
# ============================================================================
|
||||
"""Model and parameters serialization."""
|
||||
import os
|
||||
|
||||
import sys
|
||||
import stat
|
||||
import math
|
||||
import shutil
|
||||
import time
|
||||
import copy
|
||||
import threading
|
||||
from threading import Thread, Lock
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
@ -189,7 +190,8 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
|
|||
raise e
|
||||
|
||||
|
||||
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"):
|
||||
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
|
||||
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
|
||||
"""
|
||||
Saves checkpoint info to a specified file.
|
||||
|
||||
|
@ -201,6 +203,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
|||
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
|
||||
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
|
||||
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
|
||||
append_dict (dict): Additional information that needs to be saved. The key of dict must be str,
|
||||
the value of dict must be one of int float and bool. Default: None
|
||||
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
|
||||
is not required. Default: None.
|
||||
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
|
||||
|
@ -221,6 +225,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
|||
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
|
||||
integrated_save = Validator.check_bool(integrated_save)
|
||||
async_save = Validator.check_bool(async_save)
|
||||
append_dict = _check_append_dict(append_dict)
|
||||
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
|
||||
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
|
||||
|
||||
|
@ -245,6 +250,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
|
|||
param_list.append(each_param)
|
||||
save_obj = param_list
|
||||
|
||||
if append_dict:
|
||||
append_info_list = []
|
||||
for k_name, value in append_dict.items():
|
||||
append_info_list.append({"name": k_name, "data": Tensor(value)})
|
||||
save_obj.extend(append_info_list)
|
||||
|
||||
data_list = {}
|
||||
with _ckpt_mutex:
|
||||
for param in save_obj:
|
||||
|
@ -282,6 +293,17 @@ def _check_param_prefix(filter_prefix, param_name):
|
|||
return False
|
||||
|
||||
|
||||
def _check_append_dict(append_dict):
    """Validate the append_dict argument of save_checkpoint.

    Args:
        append_dict (dict or None): Extra information to store in the
            checkpoint file; keys must be str, values int, float or bool.

    Returns:
        dict or None: The validated append_dict, unchanged.

    Raises:
        TypeError: If append_dict is not a dict or contains invalid
            key/value types.
    """
    if append_dict is None:
        return append_dict
    if not isinstance(append_dict, dict):
        raise TypeError(f"The type of append_dict must be dict, but got {str(type(append_dict))}.")
    if not all(isinstance(key, str) for key in append_dict) or \
            not all(isinstance(value, (int, float, bool)) for value in append_dict.values()):
        raise TypeError("The type of element in append_dict must be key: str, "
                        "value: int, float or bool.")
    return append_dict
|
||||
|
||||
|
||||
def load(file_name):
|
||||
"""
|
||||
Load MindIR.
|
||||
|
@ -456,8 +478,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
|
|||
Args:
|
||||
net (Cell): Cell network.
|
||||
parameter_dict (dict): Parameter dictionary.
|
||||
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
|
||||
in the param_dict into net with the same suffix. Default: False
|
||||
strict_load (bool): Whether to strict load the parameter into net. False: if some parameters in the net
|
||||
not loaded, it will remove some parameter's prefix name continue to load. Default: False
|
||||
|
||||
Raises:
|
||||
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
|
||||
|
@ -1270,6 +1292,18 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
|
|||
load_param_into_net(network, param_dict)
|
||||
|
||||
|
||||
def async_ckpt_thread_status():
    """
    Get async save checkpoint thread status.

    Returns:
        bool: True if the asynchronous save checkpoint thread is running,
        False otherwise.
    """
    # Thread.getName() is deprecated in favor of the `name` attribute, and
    # any() short-circuits instead of building a throwaway list.
    return any(thr.name == "asyn_save_ckpt" for thr in threading.enumerate())
|
||||
|
||||
|
||||
def _check_predict_strategy(predict_strategy):
|
||||
"""Check predict strategy."""
|
||||
def _check_int_list(arg):
|
||||
|
|
|
@ -120,6 +120,30 @@ def test_save_checkpoint_for_list():
|
|||
save_checkpoint(parameter_list, ckpt_file_name)
|
||||
|
||||
|
||||
def test_save_checkpoint_for_list_append_info():
    """Test save_checkpoint for a parameter list together with append info."""
    # Build the parameter list with dict literals instead of piecewise assignment.
    parameter_list = [
        {'name': "param_test",
         'data': Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)},
        {'name': "param",
         'data': Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)},
        {'name': "new_param",
         'data': Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)},
    ]
    append_dict = {"lr": 0.01, "epoch": 20, "train": True}
    # Remove a stale checkpoint (it may be read-only) so the save starts clean.
    if os.path.exists('./parameters.ckpt'):
        os.chmod('./parameters.ckpt', stat.S_IWRITE)
        os.remove('./parameters.ckpt')

    ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
    save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict)
|
||||
|
||||
|
||||
def test_load_checkpoint_error_filename():
|
||||
ckpt_file_name = 1
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -130,7 +154,7 @@ def test_load_checkpoint():
|
|||
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
|
||||
par_dict = load_checkpoint(ckpt_file_name)
|
||||
|
||||
assert len(par_dict) == 3
|
||||
assert len(par_dict) == 6
|
||||
assert par_dict['param_test'].name == 'param_test'
|
||||
assert par_dict['param_test'].data.dtype == mstype.float32
|
||||
assert par_dict['param_test'].data.shape == (1, 3, 224, 224)
|
||||
|
|
Loading…
Reference in New Issue