!5482 modify save_checkpoint
Merge pull request !5482 from liuyang/md_save_checkpoint
Commit: 4ec343961e

This change folds the internal helper _exec_save_checkpoint into the public
save_checkpoint API: save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
async_save=False) now accepts either an nn.Cell or a list of parameter
dictionaries. All in-tree callers (the ModelCheckpoint callback, the GAT and
TinyBERT model_zoo scripts, and the serialization unit tests) are migrated.
@@ -23,7 +23,7 @@ import mindspore.context as context
 from mindspore import log as logger
 from mindspore._checkparam import check_bool, check_int_non_negative
 from mindspore.train._utils import _make_directory
-from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
+from mindspore.train.serialization import save_checkpoint, _save_graph
 from ._callback import Callback, set_cur_net
 
 
@@ -306,8 +306,8 @@ class ModelCheckpoint(Callback):
             set_cur_net(cb_params.train_network)
             cb_params.train_network.exec_checkpoint_graph()
 
-        _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
-                              self._config.async_save)
+        save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
+                        self._config.async_save)
 
         self._latest_ckpt_file_name = cur_file
 
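The ModelCheckpoint callback's public surface is unchanged here; only its
internal save call moves to the public save_checkpoint. A minimal usage sketch
(not part of this commit; assumes a compiled `model` and a `train_dataset`
already exist):

    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    # integrated_save and async_save land in self._config and are forwarded
    # to save_checkpoint by the hunk above
    config = CheckpointConfig(save_checkpoint_steps=100,
                              integrated_save=True, async_save=True)
    ckpt_cb = ModelCheckpoint(prefix="net", directory="./ckpt", config=config)
    # model and train_dataset are assumed to be defined elsewhere
    model.train(1, train_dataset, callbacks=[ckpt_cb])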
@@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list):
         raise RuntimeError(e.__str__())
 
 
-def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
+def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
     """
     Saves checkpoint info to a specified file.
 
     Args:
-        parameter_list (list): Parameters list, each element is a dictionary
-            like {"name":xx, "type":xx, "shape":xx, "data":xx}.
+        save_obj (nn.Cell or list): The train network for training or a parameters list
+            (each element is a dictionary like {"name":xx, "type":xx, "shape":xx, "data":xx}).
         ckpt_file_name (str): Checkpoint file name.
+        integrated_save (bool): Whether to perform integrated save in the automatic model
+            parallel scene. Default: True.
         async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
 
     Raises:
+        TypeError: If the parameter save_obj is not nn.Cell or list type.
         RuntimeError: Failed to save the Checkpoint file.
     """
+    if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
+        raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
+
     logger.info("Execute save checkpoint process.")
 
+    if isinstance(save_obj, nn.Cell):
+        save_obj.init_parameters_data()
+        param_dict = {}
+        for _, param in save_obj.parameters_and_names():
+            param_dict[param.name] = param
+        param_list = []
+        for (key, value) in param_dict.items():
+            each_param = {"name": key}
+            if isinstance(value.data, Tensor):
+                param_data = value.data
+            else:
+                param_data = Tensor(value.data)
+
+            # in the automatic model parallel scenario, some parameters were split
+            # across the devices, and they should be combined before saving
+            if integrated_save and key in save_obj.parameter_layout_dict:
+                param_data = _get_merged_param_data(save_obj, key, param_data)
+
+            each_param["data"] = param_data
+            param_list.append(each_param)
+        save_obj = param_list
+
     data_list = {}
     with _ckpt_mutex:
-        for param in parameter_list:
+        for param in save_obj:
             key = param["name"]
             data_list[key] = []
             if isinstance(param["data"], Parameter):
@@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
         thr.start()
     else:
         _exec_save(ckpt_file_name, data_list)
 
     logger.info("Save checkpoint process finish.")
+
 
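For the list form that the loop above consumes, a short sketch (not from this
commit; the parameter name and values are invented, and only the "name" and
"data" keys are shown since those are the fields the save loop reads):

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.serialization import save_checkpoint

    # one dict per parameter; "name" and "data" follow the documented layout
    # (parameter name and shape here are made up for illustration)
    param_list = [{"name": "conv1.weight",
                   "data": Tensor(np.ones((1, 3, 3, 3), np.float32))}]
    save_checkpoint(param_list, "params.ckpt")                    # blocking write
    save_checkpoint(param_list, "params_async.ckpt", async_save=True)  # background thread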
@@ -354,39 +383,6 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IRUSR)
 
 
-def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
-    """
-    Saves checkpoint for 'ms' backend.
-
-    Args:
-        train_network (Network): The train network for training.
-        ckpt_file_name (str): The name of checkpoint file.
-        integrated_save (bool): Whether to integrated save in automatic model parallel scene.
-        async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
-    """
-    train_network.init_parameters_data()
-    param_dict = {}
-    for _, param in train_network.parameters_and_names():
-        param_dict[param.name] = param
-    param_list = []
-    for (key, value) in param_dict.items():
-        each_param = {"name": key}
-        if isinstance(value.data, Tensor):
-            param_data = value.data
-        else:
-            param_data = Tensor(value.data)
-
-        # in automatic model parallel scenario, some parameters were spliteds to all the devices,
-        # which should be combined before saving
-        if integrated_save and key in train_network.parameter_layout_dict:
-            param_data = _get_merged_param_data(train_network, key, param_data)
-
-        each_param["data"] = param_data
-        param_list.append(each_param)
-
-    save_checkpoint(param_list, ckpt_file_name, async_save)
-
-
 def _get_merged_param_data(net, param_name, param_data):
     """
     Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
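Downstream code that imported the removed helper migrates one-for-one. A sketch
under the assumption that `network` is an already-constructed nn.Cell:

    from mindspore.train.serialization import save_checkpoint

    # before: _exec_save_checkpoint(network, "net.ckpt", integrated_save=True, async_save=False)
    # after:  the public entry point, with the same keyword semantics
    save_checkpoint(network, "net.ckpt", integrated_save=True, async_save=False)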
@@ -18,7 +18,7 @@ import os
 
 import numpy as np
 import mindspore.context as context
-from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint
+from mindspore.train.serialization import save_checkpoint, load_checkpoint
 
 from src.config import GatConfig
 from src.dataset import load_and_process
@@ -98,7 +98,7 @@ def train():
                 val_loss_model = eval_loss
                 if os.path.exists("ckpts/gat.ckpt"):
                     os.remove("ckpts/gat.ckpt")
-                _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
+                save_checkpoint(train_net.network, "ckpts/gat.ckpt")
             val_acc_max = np.max((val_acc_max, eval_acc))
             val_loss_min = np.min((val_loss_min, eval_loss))
             curr_step = 0
@@ -20,7 +20,7 @@ import numpy as np
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 from mindspore.train.callback import Callback
-from mindspore.train.serialization import _exec_save_checkpoint
+from mindspore.train.serialization import save_checkpoint
 from mindspore.ops import operations as P
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
 from .assessment_method import Accuracy
@@ -53,9 +53,9 @@ class ModelSaveCkpt(Callback):
                                                                   self.save_ckpt_step))
             if os.path.exists(path):
                 os.remove(path)
-            _exec_save_checkpoint(self.network, os.path.join(self.output_dir,
-                                                             "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
-                                                                                           self.save_ckpt_step)))
+            save_checkpoint(self.network, os.path.join(self.output_dir,
+                                                       "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
+                                                                                     self.save_ckpt_step)))
 
 class LossCallBack(Callback):
     """
@@ -113,7 +113,7 @@ class EvalCallBack(Callback):
             eval_model_ckpt_file = "eval_model.ckpt"
             if os.path.exists(eval_model_ckpt_file):
                 os.remove(eval_model_ckpt_file)
-            _exec_save_checkpoint(self.network, eval_model_ckpt_file)
+            save_checkpoint(self.network, eval_model_ckpt_file)
 
 class BertLearningRate(LearningRateSchedule):
     """
@@ -31,7 +31,7 @@ from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
 from mindspore.train.callback import _CheckpointManager
 from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
-    _exec_save_checkpoint, export, _save_graph
+    export, _save_graph
 from ..ut_filter import non_graph_engine
 
 context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
@@ -95,8 +95,8 @@ def test_save_graph():
     os.remove(output_file)
 
 
-def test_save_checkpoint():
-    """ test_save_checkpoint """
+def test_save_checkpoint_for_list():
+    """ test save_checkpoint for list"""
     parameter_list = []
     one_param = {}
     param1 = {}
@@ -280,14 +280,15 @@ def test_load_param_into_net():
     assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
 
 
-def test_exec_save_checkpoint():
+def test_save_checkpoint_for_network():
+    """ test save_checkpoint for network"""
     net = Net()
     loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
     opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
 
     loss_net = WithLossCell(net, loss)
     train_network = TrainOneStepCell(loss_net, opt)
-    _exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
+    save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
 
     load_checkpoint("new_ckpt.ckpt")
 