!5482 modify save_checkpoint

Merge pull request !5482 from liuyang/md_save_checkpoint
This commit is contained in:
mindspore-ci-bot 2020-09-02 14:20:48 +08:00 committed by Gitee
commit 4ec343961e
5 changed files with 49 additions and 52 deletions

View File

@ -23,7 +23,7 @@ import mindspore.context as context
from mindspore import log as logger from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore.train._utils import _make_directory from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph from mindspore.train.serialization import save_checkpoint, _save_graph
from ._callback import Callback, set_cur_net from ._callback import Callback, set_cur_net
@ -306,8 +306,8 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
self._config.async_save) self._config.async_save)
self._latest_ckpt_file_name = cur_file self._latest_ckpt_file_name = cur_file

View File

@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list):
raise RuntimeError(e.__str__()) raise RuntimeError(e.__str__())
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
""" """
Saves checkpoint info to a specified file. Saves checkpoint info to a specified file.
Args: Args:
parameter_list (list): Parameters list, each element is a dictionary save_obj (nn.Cell or list): The train network for training or parameters list(each element is a dictionary,
like {"name":xx, "type":xx, "shape":xx, "data":xx}. like {"name":xx, "type":xx, "shape":xx, "data":xx}.)
ckpt_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
Raises: Raises:
TypeError: If the parameter save_obj is not nn.Cell or list type.
RuntimeError: Failed to save the Checkpoint file. RuntimeError: Failed to save the Checkpoint file.
""" """
if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
logger.info("Execute save checkpoint process.") logger.info("Execute save checkpoint process.")
if isinstance(save_obj, nn.Cell):
save_obj.init_parameters_data()
param_dict = {}
for _, param in save_obj.parameters_and_names():
param_dict[param.name] = param
param_list = []
for (key, value) in param_dict.items():
each_param = {"name": key}
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
# which should be combined before saving
if integrated_save and key in save_obj.parameter_layout_dict:
param_data = _get_merged_param_data(save_obj, key, param_data)
each_param["data"] = param_data
param_list.append(each_param)
save_obj = param_list
data_list = {} data_list = {}
with _ckpt_mutex: with _ckpt_mutex:
for param in parameter_list: for param in save_obj:
key = param["name"] key = param["name"]
data_list[key] = [] data_list[key] = []
if isinstance(param["data"], Parameter): if isinstance(param["data"], Parameter):
@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
thr.start() thr.start()
else: else:
_exec_save(ckpt_file_name, data_list) _exec_save(ckpt_file_name, data_list)
logger.info("Save checkpoint process finish.") logger.info("Save checkpoint process finish.")
@ -354,39 +383,6 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IRUSR) os.chmod(file_name, stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
"""
Saves checkpoint for 'ms' backend.
Args:
train_network (Network): The train network for training.
ckpt_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
"""
train_network.init_parameters_data()
param_dict = {}
for _, param in train_network.parameters_and_names():
param_dict[param.name] = param
param_list = []
for (key, value) in param_dict.items():
each_param = {"name": key}
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
# which should be combined before saving
if integrated_save and key in train_network.parameter_layout_dict:
param_data = _get_merged_param_data(train_network, key, param_data)
each_param["data"] = param_data
param_list.append(each_param)
save_checkpoint(param_list, ckpt_file_name, async_save)
def _get_merged_param_data(net, param_name, param_data): def _get_merged_param_data(net, param_name, param_data):
""" """
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map. Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.

View File

@ -18,7 +18,7 @@ import os
import numpy as np import numpy as np
import mindspore.context as context import mindspore.context as context
from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint from mindspore.train.serialization import save_checkpoint, load_checkpoint
from src.config import GatConfig from src.config import GatConfig
from src.dataset import load_and_process from src.dataset import load_and_process
@ -98,7 +98,7 @@ def train():
val_loss_model = eval_loss val_loss_model = eval_loss
if os.path.exists("ckpts/gat.ckpt"): if os.path.exists("ckpts/gat.ckpt"):
os.remove("ckpts/gat.ckpt") os.remove("ckpts/gat.ckpt")
_exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt") save_checkpoint(train_net.network, "ckpts/gat.ckpt")
val_acc_max = np.max((val_acc_max, eval_acc)) val_acc_max = np.max((val_acc_max, eval_acc))
val_loss_min = np.min((val_loss_min, eval_loss)) val_loss_min = np.min((val_loss_min, eval_loss))
curr_step = 0 curr_step = 0

View File

@ -20,7 +20,7 @@ import numpy as np
from mindspore import Tensor from mindspore import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.train.serialization import _exec_save_checkpoint from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from .assessment_method import Accuracy from .assessment_method import Accuracy
@ -53,9 +53,9 @@ class ModelSaveCkpt(Callback):
self.save_ckpt_step)) self.save_ckpt_step))
if os.path.exists(path): if os.path.exists(path):
os.remove(path) os.remove(path)
_exec_save_checkpoint(self.network, os.path.join(self.output_dir, save_checkpoint(self.network, os.path.join(self.output_dir,
"tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num), "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
self.save_ckpt_step))) self.save_ckpt_step)))
class LossCallBack(Callback): class LossCallBack(Callback):
""" """
@ -113,7 +113,7 @@ class EvalCallBack(Callback):
eval_model_ckpt_file = "eval_model.ckpt" eval_model_ckpt_file = "eval_model.ckpt"
if os.path.exists(eval_model_ckpt_file): if os.path.exists(eval_model_ckpt_file):
os.remove(eval_model_ckpt_file) os.remove(eval_model_ckpt_file)
_exec_save_checkpoint(self.network, eval_model_ckpt_file) save_checkpoint(self.network, eval_model_ckpt_file)
class BertLearningRate(LearningRateSchedule): class BertLearningRate(LearningRateSchedule):
""" """

View File

@ -31,7 +31,7 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.callback import _CheckpointManager from mindspore.train.callback import _CheckpointManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \ from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
_exec_save_checkpoint, export, _save_graph export, _save_graph
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
@ -95,8 +95,8 @@ def test_save_graph():
os.remove(output_file) os.remove(output_file)
def test_save_checkpoint(): def test_save_checkpoint_for_list():
""" test_save_checkpoint """ """ test save_checkpoint for list"""
parameter_list = [] parameter_list = []
one_param = {} one_param = {}
param1 = {} param1 = {}
@ -280,14 +280,15 @@ def test_load_param_into_net():
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1 assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
def test_exec_save_checkpoint(): def test_save_checkpoint_for_network():
""" test save_checkpoint for network"""
net = Net() net = Net()
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024) opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
loss_net = WithLossCell(net, loss) loss_net = WithLossCell(net, loss)
train_network = TrainOneStepCell(loss_net, opt) train_network = TrainOneStepCell(loss_net, opt)
_exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt") save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
load_checkpoint("new_ckpt.ckpt") load_checkpoint("new_ckpt.ckpt")