forked from mindspore-Ecosystem/mindspore
!2447 asyn save checkpoint to file
Merge pull request !2447 from changzherui/asyn_ckpt_r0.3
commit 3e3cbbba0f
@@ -16,7 +16,6 @@
 import os
 import stat
-import shutil
 import time
 import numpy as np
@@ -625,8 +624,6 @@ class ModelCheckpoint(Callback):
             global _save_dir
             _save_dir = self._directory
             cur_file = os.path.join(self._directory, cur_ckpoint_file)
-            tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
-            gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
             self._last_time_for_keep = time.time()
             self._last_triggered_step = cb_params.cur_step_num
@@ -634,10 +631,8 @@ class ModelCheckpoint(Callback):
                 _set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()

-            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
+            _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save)

-            if os.path.exists(gen_file):
-                shutil.move(gen_file, cur_file)
             self._latest_ckpt_file_name = cur_file

     @property
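Note (illustrative, not part of the diff): before this change the callback wrote each checkpoint to a per-process temp file named "<pid>-parameters.ckpt" and then moved it over the final name with shutil.move; it now passes the final path (cur_file) straight to _exec_save_checkpoint, with the asynchronous write handled inside the serialization module instead. A minimal standalone sketch of the retired write-then-move pattern, with illustrative names:

    import os
    import shutil

    def save_via_temp_file(directory, final_name, write_checkpoint):
        """Write to a pid-scoped temp file first, then move it over the final path."""
        tmp_path = os.path.join(directory, str(os.getpid()) + "-" + 'parameters.ckpt')
        write_checkpoint(tmp_path)          # produce the checkpoint file at the temp path
        if os.path.exists(tmp_path):
            shutil.move(tmp_path, os.path.join(directory, final_name))

The move step meant a half-written checkpoint was never visible under the final name; the new code trades that for writing the final file directly (and, when requested, on a background thread).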
@@ -84,12 +84,12 @@ class DatasetHelper:
 class _DatasetIter:
     """Base iter for dataset help"""
     def __init__(self, dataset):
         self.loop_size = 1
-        if not hasattr(dataset, '__loop_size__'):
-            self.loop_size = dataset.get_dataset_size()
-        else:
-            self.loop_size = dataset.__loop_size__
-
         if not hasattr(dataset, '__ME_INITED__'):
+            if not hasattr(dataset, '__loop_size__'):
+                self.loop_size = dataset.get_dataset_size()
+            else:
+                self.loop_size = dataset.__loop_size__
             dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
             dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
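Note (illustrative, not part of the diff): with the block moved inside the guard, loop_size is only derived from the dataset when the transfer dataset is first created; a dataset that already carries __ME_INITED__ keeps the default of 1. A minimal sketch of the same hasattr-based one-time initialization idiom, with made-up names:

    class _OncePerDataset:
        """Derive per-dataset state on first sight only, caching a marker on the object."""
        def __init__(self, dataset):
            self.loop_size = 1
            if not hasattr(dataset, '_initialized'):
                # First encounter: compute the loop size and remember that we did.
                self.loop_size = getattr(dataset, '_loop_size', 10)
                dataset._initialized = True

    class FakeDataset:
        _loop_size = 5

    ds = FakeDataset()
    print(_OncePerDataset(ds).loop_size)   # 5: first time, derived from the dataset
    print(_OncePerDataset(ds).loop_size)   # 1: already initialized, default kept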
@@ -15,6 +15,7 @@
 """Model and parameters serialization."""
 import os
 import stat
+from threading import Thread
 import numpy as np

 import mindspore.nn as nn
@@ -96,7 +97,23 @@ def _update_param(param, new_param):
     param.set_parameter_data(type(param.data)(new_param.data))


-def save_checkpoint(parameter_list, ckpoint_file_name):
+def asyn_thread(fun):
+    def wrapper(*args, **kwargs):
+        thr = Thread(target=fun, args=args, kwargs=kwargs)
+        thr.start()
+    return wrapper
+
+
+@asyn_thread
+def asyn_save_fun(ckpoint_file_name, checkpoint_list):
+    logger.info("Asynchronous execute save checkpoint into file.")
+    with open(ckpoint_file_name, "wb") as f:
+        f.write(checkpoint_list.SerializeToString())
+        os.chmod(ckpoint_file_name, stat.S_IRUSR)
+    logger.info("Asynchronous save checkpoint into file process finish.")
+
+
+def save_checkpoint(parameter_list, ckpoint_file_name, asyn_exec=False):
     """
     Saves checkpoint info to a specified file.
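Note (illustrative, not part of the diff): asyn_thread is a plain decorator that runs the wrapped function on a new threading.Thread and returns immediately, so save_checkpoint can hand the file write off without blocking training. A self-contained sketch of the same pattern (the names and the returned thread handle are illustrative additions):

    import time
    from threading import Thread

    def run_in_background(fun):
        """Decorator: call the wrapped function on a background thread and return at once."""
        def wrapper(*args, **kwargs):
            thr = Thread(target=fun, args=args, kwargs=kwargs)
            thr.start()
            return thr                 # unlike asyn_thread above, expose the handle so callers can join()
        return wrapper

    @run_in_background
    def slow_write(path, payload):
        time.sleep(1)                  # stand-in for serializing a large checkpoint
        with open(path, "wb") as f:
            f.write(payload)

    handle = slow_write("demo.bin", b"\x00" * 16)   # returns immediately
    handle.join()                                   # optional: wait for the write to complete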
@@ -104,6 +121,7 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
         parameter_list (list): Parameters list, each element is a dict
                                like {"name":xx, "type":xx, "shape":xx, "data":xx}.
         ckpoint_file_name (str): Checkpoint file name.
+        asyn_exec (bool): Whether asynchronous execute save checkpoint into file.

     Raises:
         RuntimeError: Failed to save the Checkpoint file.
@@ -127,10 +145,12 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
             else:
                 for dim in param['data'].shape():
                     param_tensor.dims.append(dim)

-        with open(ckpoint_file_name, "wb") as f:
-            f.write(checkpoint_list.SerializeToString())
-            os.chmod(ckpoint_file_name, stat.S_IRUSR)
+        if asyn_exec:
+            asyn_save_fun(ckpoint_file_name, checkpoint_list)
+        else:
+            with open(ckpoint_file_name, "wb") as f:
+                f.write(checkpoint_list.SerializeToString())
+                os.chmod(ckpoint_file_name, stat.S_IRUSR)

     except BaseException as e:
         logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
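Note (illustrative, not part of the diff): building checkpoint_list from parameter_list is still synchronous; only the final SerializeToString, the file write, and the chmod move to the background thread when asyn_exec is set. A hedged call sketch, assuming param_list already holds dicts in the documented {"name", "type", "shape", "data"} format (the file name is illustrative):

    # Default: blocks until the checkpoint file is fully written (previous behaviour).
    save_checkpoint(param_list, "resnet50-1_32.ckpt")

    # Opt in to the non-blocking write path added in this change.
    save_checkpoint(param_list, "resnet50-1_32.ckpt", asyn_exec=True)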
@@ -298,7 +318,7 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)


-def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
+def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True, asyn_save=False):
     """
     Saves checkpoint for 'ms' backend.
@@ -329,7 +349,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
         each_param["data"] = param_data
         param_list.append(each_param)

-    save_checkpoint(param_list, ckpoint_file_name)
+    save_checkpoint(param_list, ckpoint_file_name, asyn_save)


 def _get_merged_param_data(net, param_name, param_data):
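Note (illustrative, not part of the diff): asyn_save is simply forwarded to save_checkpoint as its asyn_exec argument, so a graph-level caller opts into the non-blocking write at the top of the chain. A hedged call sketch (argument values are illustrative):

    # forwards as save_checkpoint(param_list, ckpoint_file_name, asyn_save)
    _exec_save_checkpoint(cb_params.train_network, "model.ckpt", integrated_save=True, asyn_save=True)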