forked from mindspore-Ecosystem/mindspore
!5482 modify save_checkpoint
Merge pull request !5482 from liuyang/md_save_checkpoint
This commit is contained in:
commit
4ec343961e
|
@ -23,7 +23,7 @@ import mindspore.context as context
|
|||
from mindspore import log as logger
|
||||
from mindspore._checkparam import check_bool, check_int_non_negative
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
|
||||
from mindspore.train.serialization import save_checkpoint, _save_graph
|
||||
from ._callback import Callback, set_cur_net
|
||||
|
||||
|
||||
|
@ -306,8 +306,8 @@ class ModelCheckpoint(Callback):
|
|||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
_exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save)
|
||||
save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
|
||||
self._config.async_save)
|
||||
|
||||
self._latest_ckpt_file_name = cur_file
|
||||
|
||||
|
|
|
@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list):
|
|||
raise RuntimeError(e.__str__())
|
||||
|
||||
|
||||
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
|
||||
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
|
||||
"""
|
||||
Saves checkpoint info to a specified file.
|
||||
|
||||
Args:
|
||||
parameter_list (list): Parameters list, each element is a dictionary
|
||||
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
||||
save_obj (nn.Cell or list): The train network for training or parameters list(each element is a dictionary,
|
||||
like {"name":xx, "type":xx, "shape":xx, "data":xx}.)
|
||||
ckpt_file_name (str): Checkpoint file name.
|
||||
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
|
||||
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
|
||||
|
||||
Raises:
|
||||
TypeError: If the parameter save_obj is not nn.Cell or list type.
|
||||
RuntimeError: Failed to save the Checkpoint file.
|
||||
"""
|
||||
|
||||
if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
|
||||
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
|
||||
|
||||
logger.info("Execute save checkpoint process.")
|
||||
|
||||
if isinstance(save_obj, nn.Cell):
|
||||
save_obj.init_parameters_data()
|
||||
param_dict = {}
|
||||
for _, param in save_obj.parameters_and_names():
|
||||
param_dict[param.name] = param
|
||||
param_list = []
|
||||
for (key, value) in param_dict.items():
|
||||
each_param = {"name": key}
|
||||
if isinstance(value.data, Tensor):
|
||||
param_data = value.data
|
||||
else:
|
||||
param_data = Tensor(value.data)
|
||||
|
||||
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
|
||||
# which should be combined before saving
|
||||
if integrated_save and key in save_obj.parameter_layout_dict:
|
||||
param_data = _get_merged_param_data(save_obj, key, param_data)
|
||||
|
||||
each_param["data"] = param_data
|
||||
param_list.append(each_param)
|
||||
save_obj = param_list
|
||||
|
||||
data_list = {}
|
||||
with _ckpt_mutex:
|
||||
for param in parameter_list:
|
||||
for param in save_obj:
|
||||
key = param["name"]
|
||||
data_list[key] = []
|
||||
if isinstance(param["data"], Parameter):
|
||||
|
@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
|
|||
thr.start()
|
||||
else:
|
||||
_exec_save(ckpt_file_name, data_list)
|
||||
|
||||
logger.info("Save checkpoint process finish.")
|
||||
|
||||
|
||||
|
@ -354,39 +383,6 @@ def _save_graph(network, file_name):
|
|||
os.chmod(file_name, stat.S_IRUSR)
|
||||
|
||||
|
||||
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
|
||||
"""
|
||||
Saves checkpoint for 'ms' backend.
|
||||
|
||||
Args:
|
||||
train_network (Network): The train network for training.
|
||||
ckpt_file_name (str): The name of checkpoint file.
|
||||
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
|
||||
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
|
||||
"""
|
||||
train_network.init_parameters_data()
|
||||
param_dict = {}
|
||||
for _, param in train_network.parameters_and_names():
|
||||
param_dict[param.name] = param
|
||||
param_list = []
|
||||
for (key, value) in param_dict.items():
|
||||
each_param = {"name": key}
|
||||
if isinstance(value.data, Tensor):
|
||||
param_data = value.data
|
||||
else:
|
||||
param_data = Tensor(value.data)
|
||||
|
||||
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
|
||||
# which should be combined before saving
|
||||
if integrated_save and key in train_network.parameter_layout_dict:
|
||||
param_data = _get_merged_param_data(train_network, key, param_data)
|
||||
|
||||
each_param["data"] = param_data
|
||||
param_list.append(each_param)
|
||||
|
||||
save_checkpoint(param_list, ckpt_file_name, async_save)
|
||||
|
||||
|
||||
def _get_merged_param_data(net, param_name, param_data):
|
||||
"""
|
||||
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
|
||||
|
|
|
@ -18,7 +18,7 @@ import os
|
|||
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint
|
||||
|
||||
from src.config import GatConfig
|
||||
from src.dataset import load_and_process
|
||||
|
@ -98,7 +98,7 @@ def train():
|
|||
val_loss_model = eval_loss
|
||||
if os.path.exists("ckpts/gat.ckpt"):
|
||||
os.remove("ckpts/gat.ckpt")
|
||||
_exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
|
||||
save_checkpoint(train_net.network, "ckpts/gat.ckpt")
|
||||
val_acc_max = np.max((val_acc_max, eval_acc))
|
||||
val_loss_min = np.min((val_loss_min, eval_loss))
|
||||
curr_step = 0
|
||||
|
|
|
@ -20,7 +20,7 @@ import numpy as np
|
|||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.train.callback import Callback
|
||||
from mindspore.train.serialization import _exec_save_checkpoint
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
|
||||
from .assessment_method import Accuracy
|
||||
|
@ -53,9 +53,9 @@ class ModelSaveCkpt(Callback):
|
|||
self.save_ckpt_step))
|
||||
if os.path.exists(path):
|
||||
os.remove(path)
|
||||
_exec_save_checkpoint(self.network, os.path.join(self.output_dir,
|
||||
"tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
|
||||
self.save_ckpt_step)))
|
||||
save_checkpoint(self.network, os.path.join(self.output_dir,
|
||||
"tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
|
||||
self.save_ckpt_step)))
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
|
@ -113,7 +113,7 @@ class EvalCallBack(Callback):
|
|||
eval_model_ckpt_file = "eval_model.ckpt"
|
||||
if os.path.exists(eval_model_ckpt_file):
|
||||
os.remove(eval_model_ckpt_file)
|
||||
_exec_save_checkpoint(self.network, eval_model_ckpt_file)
|
||||
save_checkpoint(self.network, eval_model_ckpt_file)
|
||||
|
||||
class BertLearningRate(LearningRateSchedule):
|
||||
"""
|
||||
|
|
|
@ -31,7 +31,7 @@ from mindspore.nn.optim.momentum import Momentum
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.train.callback import _CheckpointManager
|
||||
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
|
||||
_exec_save_checkpoint, export, _save_graph
|
||||
export, _save_graph
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
|
||||
|
@ -95,8 +95,8 @@ def test_save_graph():
|
|||
os.remove(output_file)
|
||||
|
||||
|
||||
def test_save_checkpoint():
|
||||
""" test_save_checkpoint """
|
||||
def test_save_checkpoint_for_list():
|
||||
""" test save_checkpoint for list"""
|
||||
parameter_list = []
|
||||
one_param = {}
|
||||
param1 = {}
|
||||
|
@ -280,14 +280,15 @@ def test_load_param_into_net():
|
|||
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
|
||||
|
||||
|
||||
def test_exec_save_checkpoint():
|
||||
def test_save_checkpoint_for_network():
|
||||
""" test save_checkpoint for network"""
|
||||
net = Net()
|
||||
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
|
||||
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
|
||||
|
||||
loss_net = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(loss_net, opt)
|
||||
_exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
|
||||
save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
|
||||
|
||||
load_checkpoint("new_ckpt.ckpt")
|
||||
|
||||
|
|
Loading…
Reference in New Issue