Asynchronous save checkpoint
This commit is contained in:
parent
c99cc0dfa1
commit
d45abc5f54
|
@ -15,7 +15,6 @@
|
|||
"""Checkpoint related classes and functions."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import time
|
||||
|
||||
|
@ -86,6 +85,7 @@ class CheckpointConfig:
|
|||
Can't be used with keep_checkpoint_max at the same time.
|
||||
integrated_save (bool): Whether to perform an integrated save in the automatic model parallel scene. Default: True.
|
||||
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
|
||||
async_save (bool): Whether to save the checkpoint to file asynchronously. Default: False.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input_param is None or 0.
|
||||
|
@ -100,7 +100,8 @@ class CheckpointConfig:
|
|||
save_checkpoint_seconds=0,
|
||||
keep_checkpoint_max=5,
|
||||
keep_checkpoint_per_n_minutes=0,
|
||||
integrated_save=True):
|
||||
integrated_save=True,
|
||||
async_save=False):
|
||||
|
||||
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
||||
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
||||
|
@ -129,6 +130,7 @@ class CheckpointConfig:
|
|||
self._keep_checkpoint_max = 1
|
||||
|
||||
self._integrated_save = check_bool(integrated_save)
|
||||
self._async_save = check_bool(async_save)
|
||||
|
||||
@property
|
||||
def save_checkpoint_steps(self):
|
||||
|
@ -155,6 +157,11 @@ class CheckpointConfig:
|
|||
"""Get the value of _integrated_save."""
|
||||
return self._integrated_save
|
||||
|
||||
@property
|
||||
def async_save(self):
|
||||
"""Get the value of _async_save."""
|
||||
return self._async_save
|
||||
|
||||
def get_checkpoint_policy(self):
|
||||
"""Get the policy of checkpoint."""
|
||||
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
|
||||
|
@ -282,8 +289,6 @@ class ModelCheckpoint(Callback):
|
|||
global _save_dir
|
||||
_save_dir = self._directory
|
||||
cur_file = os.path.join(self._directory, cur_ckpoint_file)
|
||||
tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
|
||||
gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
|
||||
self._last_time_for_keep = time.time()
|
||||
self._last_triggered_step = cb_params.cur_step_num
|
||||
|
||||
|
@ -291,10 +296,9 @@ class ModelCheckpoint(Callback):
|
|||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
|
||||
_exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save)
|
||||
|
||||
if os.path.exists(gen_file):
|
||||
shutil.move(gen_file, cur_file)
|
||||
self._latest_ckpt_file_name = cur_file
|
||||
|
||||
@property
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
"""Model and parameters serialization."""
|
||||
import os
|
||||
import stat
|
||||
from threading import Thread, Lock
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
|
@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
|||
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
||||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
||||
|
||||
_ckpt_mutex = Lock()
|
||||
|
||||
def _special_process_par(par, new_par):
|
||||
"""
|
||||
|
@ -101,7 +103,29 @@ def _update_param(param, new_param):
|
|||
param.set_parameter_data(type(param.data)(new_param.data))
|
||||
|
||||
|
||||
def save_checkpoint(parameter_list, ckpt_file_name):
|
||||
def _exec_save(ckpt_file_name, data_list):
    """
    Serialize the collected parameter data and write it to a checkpoint file.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        data_list (dict): Maps each parameter name to a sequence whose first three
            items are the dims, the tensor type string and the flattened data.
            The data item must provide `tobytes()` (a numpy array when built by
            `save_checkpoint`).

    Raises:
        RuntimeError: Failed to save the Checkpoint file.
    """
    checkpoint_list = Checkpoint()

    try:
        # Hold the module-level lock for the whole build-and-write sequence so that
        # concurrent saves (this function may be run as a Thread target) do not interleave.
        with _ckpt_mutex:
            for name, value in data_list.items():
                dims, tensor_type, data = value[0], value[1], value[2]
                param_value = checkpoint_list.value.add()
                param_value.tag = name
                param_tensor = param_value.tensor
                param_tensor.dims.extend(dims)
                param_tensor.tensor_type = tensor_type
                # `tobytes()` returns the same bytes as the deprecated `tostring()` alias.
                param_tensor.tensor_content = data.tobytes()

            # A previous save leaves the file owner-read-only (see the chmod below);
            # restore write permission so saving to the same name again does not
            # fail when the file is reopened for writing.
            if os.path.exists(ckpt_file_name):
                os.chmod(ckpt_file_name, stat.S_IRUSR | stat.S_IWUSR)

            with open(ckpt_file_name, "wb") as f:
                f.write(checkpoint_list.SerializeToString())
            # Leave the finished checkpoint read-only for the owner.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    # NOTE(review): BaseException is kept so the set of failures reported as
    # RuntimeError is unchanged; consider narrowing to Exception.
    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Chain the original exception so its traceback is not lost.
        raise RuntimeError(str(e)) from e
|
||||
|
||||
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
|
||||
"""
|
||||
Saves checkpoint info to a specified file.
|
||||
|
||||
|
@ -109,37 +133,37 @@ def save_checkpoint(parameter_list, ckpt_file_name):
|
|||
parameter_list (list): Parameters list, each element is a dict
|
||||
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
||||
ckpt_file_name (str): Checkpoint file name.
|
||||
async_save (bool): Whether to save the checkpoint to file asynchronously. Default: False.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Failed to save the Checkpoint file.
|
||||
"""
|
||||
logger.info("Execute save checkpoint process.")
|
||||
checkpoint_list = Checkpoint()
|
||||
|
||||
try:
|
||||
data_list = {}
|
||||
with _ckpt_mutex:
|
||||
for param in parameter_list:
|
||||
param_value = checkpoint_list.value.add()
|
||||
param_value.tag = param["name"]
|
||||
param_tensor = param_value.tensor
|
||||
key = param["name"]
|
||||
data_list[key] = []
|
||||
if isinstance(param["data"], Parameter):
|
||||
param["data"].init_data()
|
||||
param_data = param["data"].asnumpy().reshape(-1)
|
||||
param_tensor.tensor_content = param_data.tostring()
|
||||
param_tensor.tensor_type = str(param["data"].dtype)
|
||||
|
||||
dims = []
|
||||
if param['data'].shape == ():
|
||||
param_tensor.dims.append(0)
|
||||
dims.append(0)
|
||||
else:
|
||||
for dim in param['data'].shape:
|
||||
param_tensor.dims.append(dim)
|
||||
dims.append(dim)
|
||||
data_list[key].append(dims)
|
||||
tensor_type = str(param["data"].dtype)
|
||||
data_list[key].append(tensor_type)
|
||||
data = param["data"].asnumpy().reshape(-1)
|
||||
data_list[key].append(data)
|
||||
|
||||
with open(ckpt_file_name, "wb") as f:
|
||||
f.write(checkpoint_list.SerializeToString())
|
||||
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
|
||||
raise RuntimeError(e.__str__())
|
||||
if async_save:
|
||||
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
|
||||
thr.start()
|
||||
else:
|
||||
_exec_save(ckpt_file_name, data_list)
|
||||
logger.info("Save checkpoint process finish.")
|
||||
|
||||
|
||||
|
@ -305,7 +329,7 @@ def _save_graph(network, file_name):
|
|||
os.chmod(file_name, stat.S_IRUSR)
|
||||
|
||||
|
||||
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
||||
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
|
||||
"""
|
||||
Saves checkpoint for 'ms' backend.
|
||||
|
||||
|
@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
|||
train_network (Network): The train network for training.
|
||||
ckpt_file_name (str): The name of checkpoint file.
|
||||
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
|
||||
async_save (bool): Whether to save the checkpoint to file asynchronously. Default: False.
|
||||
"""
|
||||
|
||||
param_dict = {}
|
||||
|
@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
|||
each_param["data"] = param_data
|
||||
param_list.append(each_param)
|
||||
|
||||
save_checkpoint(param_list, ckpt_file_name)
|
||||
save_checkpoint(param_list, ckpt_file_name, async_save)
|
||||
|
||||
|
||||
def _get_merged_param_data(net, param_name, param_data):
|
||||
|
|
Loading…
Reference in New Issue