!17761 add save ckpt info

Merge pull request !17761 from changzherui/add_ckpt_info
i-robot 2021-06-10 11:01:18 +08:00 committed by Gitee
commit 3a13dde14d
3 changed files with 115 additions and 11 deletions


@ -32,6 +32,7 @@ from ...common.tensor import Tensor
_cur_dir = os.getcwd()
_save_dir = _cur_dir
_info_list = ["epoch_num", "step_num"]
def _chg_ckpt_file_name_if_same_exist(directory, prefix):
@ -82,6 +83,8 @@ class CheckpointConfig:
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
saved_network (Cell): Network to be saved in checkpoint file. If the saved_network has no relation
with the network in training, the initial value of saved_network will be saved. Default: None.
append_info (list): The information to save to the checkpoint file. Supports "epoch_num", "step_num" and dict.
The keys of the dict must be str, and the values must be one of int, float and bool. Default: None.
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
is not required. Default: None.
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
@ -131,6 +134,7 @@ class CheckpointConfig:
integrated_save=True,
async_save=False,
saved_network=None,
append_info=None,
enc_key=None,
enc_mode='AES-GCM'):
@ -166,6 +170,7 @@ class CheckpointConfig:
self._integrated_save = Validator.check_bool(integrated_save)
self._async_save = Validator.check_bool(async_save)
self._saved_network = saved_network
self._append_dict = self._handle_append_info(append_info)
self._enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
self._enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
@ -214,6 +219,11 @@ class CheckpointConfig:
"""Get the value of _enc_mode"""
return self._enc_mode
@property
def append_dict(self):
"""Get the value of append_dict."""
return self._append_dict
def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
@ -224,6 +234,36 @@ class CheckpointConfig:
return checkpoint_policy
@staticmethod
def _handle_append_info(append_info):
"""Handle ckpt append info."""
if append_info is None or append_info == []:
return None
if not isinstance(append_info, list):
raise TypeError(f"The type of append_info must list, but got {str(type(append_info))}.")
handle_append_info = {}
if "epoch_num" in append_info:
handle_append_info["epoch_num"] = 0
if "step_num" in append_info:
handle_append_info["step_num"] = 0
dict_num = 0
for element in append_info:
if not isinstance(element, str) and not isinstance(element, dict):
raise TypeError(f"The type of append_info element must be str or dict, but got {str(type(element))}.")
if isinstance(element, str) and element not in _info_list:
raise TypeError(f"The type of append_info element must be in {_info_list}, but got {element}.")
if isinstance(element, dict):
dict_num += 1
if dict_num > 1:
raise TypeError(f"The element of append_info must has only one dict.")
for key, value in element.items():
if isinstance(key, str) and isinstance(value, (int, float, bool)):
handle_append_info[key] = value
else:
raise TypeError(f"The type of dict in append_info must be key: str, value: int or float.")
return handle_append_info
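As a hedged usage sketch of the new append_info option (the parameter names and the append_dict property follow this diff; the concrete values such as the learning rate are illustrative only), the list accepted by CheckpointConfig is normalized by _handle_append_info into a single dict:

from mindspore.train.callback import CheckpointConfig

# Reserved strings plus at most one user dict; dict keys must be str, values int/float/bool.
config = CheckpointConfig(save_checkpoint_steps=100,
                          append_info=["epoch_num", "step_num", {"lr": 0.01, "use_amp": True}])
# The normalized result starts the counters at 0; ModelCheckpoint fills them in at save time:
# {"epoch_num": 0, "step_num": 0, "lr": 0.01, "use_amp": True}
print(config.append_dict)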
class ModelCheckpoint(Callback):
"""
@ -273,6 +313,9 @@ class ModelCheckpoint(Callback):
# get existing checkpoint files
self._manager = CheckpointManager()
self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
self._append_dict = self._config.append_dict or {}
self._append_epoch_num = self._append_dict["epoch_num"] if "epoch_num" in self._append_dict else 0
self._append_step_num = self._append_dict["step_num"] if "step_num" in self._append_dict else 0
self._graph_saved = False
self._need_flush_from_cache = True
@ -370,10 +413,13 @@ class ModelCheckpoint(Callback):
if context.get_context("enable_ge"):
set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
if "epoch_num" in self._append_dict:
self._append_dict["epoch_num"] = self._append_epoch_num + cb_params.cur_epoch_num
if "step_num" in self._append_dict:
self._append_dict["step_num"] = self._append_step_num + cb_params.cur_step_num
network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
save_checkpoint(network, cur_file, self._config.integrated_save,
self._config.async_save, self._config.enc_key, self._config.enc_mode)
save_checkpoint(network, cur_file, self._config.integrated_save, self._config.async_save,
self._append_dict, self._config.enc_key, self._config.enc_mode)
self._latest_ckpt_file_name = cur_file
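A minimal training-side sketch of how the counters above get populated (model and dataset are hypothetical placeholders, not part of this change):

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config = CheckpointConfig(save_checkpoint_steps=200, append_info=["epoch_num", "step_num"])
ckpt_cb = ModelCheckpoint(prefix="net", directory="./ckpt", config=config)
# Each checkpoint written during training also carries the current epoch_num / step_num
# taken from cb_params, as shown in the hunk above.
model.train(10, dataset, callbacks=[ckpt_cb])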
@ -442,19 +488,19 @@ class CheckpointManager:
def keep_one_ckpoint_per_minutes(self, minutes, cur_time):
"""Only keep the latest one ckpt file per minutes, remove other files generated in [last_time, cur_time]."""
movs = []
del_list = []
oldest_file = ''
oldest_time = cur_time
for ck_file in self._ckpoint_filelist:
modify_time = os.path.getmtime(ck_file)
if cur_time - modify_time < 60 * minutes:
movs.append(ck_file)
del_list.append(ck_file)
if modify_time < oldest_time:
oldest_time = modify_time
oldest_file = ck_file
for mv_file in movs:
for mv_file in del_list:
if mv_file == oldest_file:
continue
self.remove_ckpoint_file(mv_file)


@ -14,15 +14,16 @@
# ============================================================================
"""Model and parameters serialization."""
import os
import sys
import stat
import math
import shutil
import time
import copy
import threading
from threading import Thread, Lock
from collections import defaultdict
import numpy as np
import mindspore.nn as nn
@ -189,7 +190,8 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"):
raise e
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False, enc_key=None, enc_mode="AES-GCM"):
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM"):
"""
Saves checkpoint info to a specified file.
@ -201,6 +203,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
append_dict (dict): Additional information that needs to be saved. The keys of the dict must be str,
and the values must be one of int, float and bool. Default: None.
enc_key (Union[None, bytes]): Byte type key used for encryption. If the value is None, the encryption
is not required. Default: None.
enc_mode (str): This parameter is valid only when enc_key is not set to None. Specifies the encryption
@ -221,6 +225,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
integrated_save = Validator.check_bool(integrated_save)
async_save = Validator.check_bool(async_save)
append_dict = _check_append_dict(append_dict)
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
@ -245,6 +250,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
param_list.append(each_param)
save_obj = param_list
if append_dict:
append_info_list = []
for k_name, value in append_dict.items():
append_info_list.append({"name": k_name, "data": Tensor(value)})
save_obj.extend(append_info_list)
data_list = {}
with _ckpt_mutex:
for param in save_obj:
@ -282,6 +293,17 @@ def _check_param_prefix(filter_prefix, param_name):
return False
def _check_append_dict(append_dict):
if append_dict is None:
return append_dict
if not isinstance(append_dict, dict):
raise TypeError(f"The type of append_dict must dict, but got {str(type(append_dict))}.")
if not all(isinstance(ele, str) for ele in append_dict.keys()) or \
not all(isinstance(ele, (int, float, bool)) for ele in append_dict.values()):
raise TypeError(f"The type of element in append_dict must be key: str, value: int or float.")
return append_dict
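For the functional interface, a hedged sketch of passing append_dict directly to save_checkpoint (it mirrors the new signature above; the parameter contents are made up for illustration):

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint

param_list = [{"name": "conv1.weight",
               "data": Tensor(np.zeros((8, 3, 3, 3)), dtype=mstype.float32)}]
# Keys must be str and values int, float or bool, as _check_append_dict enforces.
save_checkpoint(param_list, "net_with_info.ckpt",
                append_dict={"lr": 0.01, "epoch": 20, "train": True})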
def load(file_name):
"""
Load MindIR.
@ -456,8 +478,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
Args:
net (Cell): Cell network.
parameter_dict (dict): Parameter dictionary.
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False
strict_load (bool): Whether to load the parameters into the net strictly. If False, when some parameters
in the net are not loaded, the parameter name prefix is removed and loading continues. Default: False.
Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.
@ -1270,6 +1292,18 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
load_param_into_net(network, param_dict)
def async_ckpt_thread_status():
"""
Get the status of the asynchronous save checkpoint thread.
Returns:
True, the asynchronous save checkpoint thread is still running.
False, the asynchronous save checkpoint thread is not running.
"""
thr_list = threading.enumerate()
return any(thr.getName() == "asyn_save_ckpt" for thr in thr_list)
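A hedged usage sketch of the new status query, e.g. to wait until an asynchronous save has flushed to disk (net is a placeholder Cell; the polling interval is arbitrary):

import time
from mindspore.train.serialization import save_checkpoint, async_ckpt_thread_status

save_checkpoint(net, "async.ckpt", async_save=True)
# Block until the background "asyn_save_ckpt" thread finishes writing the file.
while async_ckpt_thread_status():
    time.sleep(0.1)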
def _check_predict_strategy(predict_strategy):
"""Check predict strategy."""
def _check_int_list(arg):


@ -120,6 +120,30 @@ def test_save_checkpoint_for_list():
save_checkpoint(parameter_list, ckpt_file_name)
def test_save_checkpoint_for_list_append_info():
""" test save_checkpoint for list append info"""
parameter_list = []
one_param = {}
param1 = {}
param2 = {}
one_param['name'] = "param_test"
one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
param1['name'] = "param"
param1['data'] = Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)
param2['name'] = "new_param"
param2['data'] = Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)
parameter_list.append(one_param)
parameter_list.append(param1)
parameter_list.append(param2)
append_dict = {"lr": 0.01, "epoch": 20, "train": True}
if os.path.exists('./parameters.ckpt'):
os.chmod('./parameters.ckpt', stat.S_IWRITE)
os.remove('./parameters.ckpt')
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
save_checkpoint(parameter_list, ckpt_file_name, append_dict=append_dict)
def test_load_checkpoint_error_filename():
ckpt_file_name = 1
with pytest.raises(ValueError):
@ -130,7 +154,7 @@ def test_load_checkpoint():
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
par_dict = load_checkpoint(ckpt_file_name)
assert len(par_dict) == 3
assert len(par_dict) == 6
assert par_dict['param_test'].name == 'param_test'
assert par_dict['param_test'].data.dtype == mstype.float32
assert par_dict['param_test'].data.shape == (1, 3, 224, 224)
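Consistent with the updated assertion above (three original parameters plus three appended entries), a hedged sketch of reading the appended values back, assuming the test has already written ./parameters.ckpt:

from mindspore.train.serialization import load_checkpoint

par_dict = load_checkpoint("./parameters.ckpt")
assert len(par_dict) == 6
# Appended values come back as ordinary parameters holding scalar tensors.
print(float(par_dict["lr"].data.asnumpy()))    # 0.01
print(int(par_dict["epoch"].data.asnumpy()))   # 20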