!122 Support to config whether to save integeated checkpoint, in auto model parallel scene
Merge pull request !122 from WeibiaoYu/master
This commit is contained in:
commit
a24297f547
|
@ -374,9 +374,6 @@ class _Executor:
|
|||
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
|
||||
obj.load_parameter_slice(params)
|
||||
|
||||
if _get_parallel_mode() in ["hybrid_parallel"]:
|
||||
obj.parameter_layout_dict = self._build_parameter_layout(obj)
|
||||
|
||||
# the following GE init process is not needed when use vm or ms backend
|
||||
if enable_ge:
|
||||
# decide whether to sink based on whether the inputs is virtual or not
|
||||
|
@ -449,38 +446,6 @@ class _Executor:
|
|||
return self._exec_pip(obj, *args, phase=phase_real)
|
||||
raise KeyError('{} graph is not exist.'.format(phase_real))
|
||||
|
||||
def _build_parameter_layout(self, obj):
|
||||
"""
|
||||
Build parameter layout, for layerwise_parallel parameter.
|
||||
|
||||
Args:
|
||||
obj (Function or Cell): The function or cell instance need to be compiled.
|
||||
|
||||
Returns:
|
||||
Dictionary, parameter layout info.
|
||||
"""
|
||||
parameter_layout_dict = {}
|
||||
layerwise_parallel_parameters = []
|
||||
for key in obj.parameters_dict():
|
||||
if obj.parameters_dict()[key].layerwise_parallel is True:
|
||||
layerwise_parallel_parameters.append(key)
|
||||
|
||||
if not layerwise_parallel_parameters:
|
||||
return parameter_layout_dict
|
||||
|
||||
from ..communication.management import get_group_size
|
||||
group_size = [get_group_size()]
|
||||
for key in layerwise_parallel_parameters:
|
||||
tensor_map = [0]
|
||||
shape = obj.parameters_dict()[key].data.shape()
|
||||
for x in range(len(shape)): # dim 0 set 0, others set -1
|
||||
if x:
|
||||
tensor_map.append(-1)
|
||||
layout = [group_size, tensor_map]
|
||||
parameter_layout_dict[key] = layout
|
||||
|
||||
return parameter_layout_dict
|
||||
|
||||
def del_net_res(self, net_id):
|
||||
self._executor.del_net_res(net_id)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ import mindspore.context as context
|
|||
from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import check_int_non_negative
|
||||
from mindspore._checkparam import check_int_non_negative, check_bool
|
||||
from mindspore.common.tensor import Tensor
|
||||
from .summary.summary_record import _cache_summary_tensor_data
|
||||
|
||||
|
@ -150,6 +150,8 @@ class CheckpointConfig:
|
|||
keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
|
||||
keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
|
||||
Can't be used with keep_checkpoint_max at the same time.
|
||||
integrated_save (bool): Whether to intergrated save in automatic model parall scene. Default: True.
|
||||
Integrated save function is only supported in automatic parall scene, not supported in manual parallel.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input_param is None or 0.
|
||||
|
@ -163,7 +165,8 @@ class CheckpointConfig:
|
|||
save_checkpoint_steps=1,
|
||||
save_checkpoint_seconds=0,
|
||||
keep_checkpoint_max=5,
|
||||
keep_checkpoint_per_n_minutes=0):
|
||||
keep_checkpoint_per_n_minutes=0,
|
||||
integrated_save=True):
|
||||
|
||||
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
||||
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
||||
|
@ -191,6 +194,8 @@ class CheckpointConfig:
|
|||
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
|
||||
self._keep_checkpoint_max = 1
|
||||
|
||||
self._integrated_save = check_bool(integrated_save)
|
||||
|
||||
@property
|
||||
def save_checkpoint_steps(self):
|
||||
"""Get the value of _save_checkpoint_steps."""
|
||||
|
@ -211,6 +216,11 @@ class CheckpointConfig:
|
|||
"""Get the value of _keep_checkpoint_per_n_minutes."""
|
||||
return self._keep_checkpoint_per_n_minutes
|
||||
|
||||
@property
|
||||
def integrated_save(self):
|
||||
"""Get the value of _integrated_save."""
|
||||
return self._integrated_save
|
||||
|
||||
def get_checkpoint_policy(self):
|
||||
"""Get the policy of checkpoint."""
|
||||
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
|
||||
|
@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
|
|||
_set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
_exec_save_checkpoint(cb_params.train_network, gen_file)
|
||||
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
|
||||
|
||||
if os.path.exists(gen_file):
|
||||
shutil.move(gen_file, cur_file)
|
||||
|
|
|
@ -279,13 +279,14 @@ def _save_graph(network, file_name):
|
|||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
|
||||
|
||||
def _exec_save_checkpoint(train_network, ckpoint_file_name):
|
||||
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
|
||||
"""
|
||||
Saves checkpoint for 'ms' backend.
|
||||
|
||||
Args:
|
||||
train_network (Network): The train network for training.
|
||||
ckpoint_file_name (str): The name of checkpoint file.
|
||||
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
|
||||
"""
|
||||
|
||||
param_dict = {}
|
||||
|
@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name):
|
|||
else:
|
||||
param_data = Tensor(value.data)
|
||||
|
||||
# in model parallel scenario, some parameters were spliteds to all the devices,
|
||||
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
|
||||
# which should be combined before saving
|
||||
if key in train_network.parameter_layout_dict:
|
||||
if integrated_save and key in train_network.parameter_layout_dict:
|
||||
param_data = _get_merged_param_data(train_network, key, param_data)
|
||||
|
||||
each_param["data"] = param_data
|
||||
|
|
|
@ -308,10 +308,10 @@ def test_RunContext():
|
|||
def test_Checkpoint_Config():
|
||||
"""Test CheckpointConfig all None or 0."""
|
||||
with pytest.raises(ValueError):
|
||||
CheckpointConfig(0, 0, 0, 0)
|
||||
CheckpointConfig(0, 0, 0, 0, True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
CheckpointConfig(0, None, 0, 0)
|
||||
CheckpointConfig(0, None, 0, 0, True)
|
||||
|
||||
|
||||
def test_step_end_save_graph():
|
||||
|
|
Loading…
Reference in New Issue