forked from mindspore-Ecosystem/mindspore
!122 Support configuring whether to save an integrated checkpoint in the auto model parallel scene
Merge pull request !122 from WeibiaoYu/master
commit a24297f547
@@ -374,9 +374,6 @@ class _Executor:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
             obj.load_parameter_slice(params)
 
-        if _get_parallel_mode() in ["hybrid_parallel"]:
-            obj.parameter_layout_dict = self._build_parameter_layout(obj)
-
         # the following GE init process is not needed when use vm or ms backend
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
@@ -449,38 +446,6 @@ class _Executor:
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
 
-    def _build_parameter_layout(self, obj):
-        """
-        Build parameter layout, for layerwise_parallel parameter.
-
-        Args:
-            obj (Function or Cell): The function or cell instance need to be compiled.
-
-        Returns:
-            Dictionary, parameter layout info.
-        """
-        parameter_layout_dict = {}
-        layerwise_parallel_parameters = []
-        for key in obj.parameters_dict():
-            if obj.parameters_dict()[key].layerwise_parallel is True:
-                layerwise_parallel_parameters.append(key)
-
-        if not layerwise_parallel_parameters:
-            return parameter_layout_dict
-
-        from ..communication.management import get_group_size
-        group_size = [get_group_size()]
-        for key in layerwise_parallel_parameters:
-            tensor_map = [0]
-            shape = obj.parameters_dict()[key].data.shape()
-            for x in range(len(shape)):  # dim 0 set 0, others set -1
-                if x:
-                    tensor_map.append(-1)
-            layout = [group_size, tensor_map]
-            parameter_layout_dict[key] = layout
-
-        return parameter_layout_dict
-
     def del_net_res(self, net_id):
         self._executor.del_net_res(net_id)
 
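For reference, the removed _build_parameter_layout produced, for every layerwise_parallel parameter, a layout of the form [[group_size], tensor_map]: dimension 0 is split across the device group and every other dimension is replicated. Below is a standalone sketch of that construction; build_layerwise_layout and its arguments are illustrative names, and the explicit group_size argument stands in for get_group_size().

def build_layerwise_layout(param_shapes, group_size):
    """Sketch of the removed _build_parameter_layout logic (illustrative names)."""
    parameter_layout_dict = {}
    for key, shape in param_shapes.items():
        # dim 0 is split across the device group (tensor_map entry 0),
        # every other dimension is replicated (tensor_map entry -1)
        tensor_map = [0] + [-1] * (len(shape) - 1)
        parameter_layout_dict[key] = [[group_size], tensor_map]
    return parameter_layout_dict

# Example: a (1024, 256) weight on 8 devices maps to [[8], [0, -1]].
print(build_layerwise_layout({"fc.weight": (1024, 256)}, 8))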
@@ -24,7 +24,7 @@ import mindspore.context as context
 from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
 from mindspore.train._utils import _make_directory
 from mindspore import log as logger
-from mindspore._checkparam import check_int_non_negative
+from mindspore._checkparam import check_int_non_negative, check_bool
 from mindspore.common.tensor import Tensor
 from .summary.summary_record import _cache_summary_tensor_data
 
@@ -150,6 +150,8 @@ class CheckpointConfig:
         keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
         keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
             Can't be used with keep_checkpoint_max at the same time.
+        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene. Default: True.
+            Integrated save is only supported in the automatic parallel scene, not in manual parallel.
 
     Raises:
         ValueError: If the input_param is None or 0.
@@ -163,7 +165,8 @@ class CheckpointConfig:
                  save_checkpoint_steps=1,
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
-                 keep_checkpoint_per_n_minutes=0):
+                 keep_checkpoint_per_n_minutes=0,
+                 integrated_save=True):
 
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -191,6 +194,8 @@ class CheckpointConfig:
         if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
             self._keep_checkpoint_max = 1
 
+        self._integrated_save = check_bool(integrated_save)
+
     @property
     def save_checkpoint_steps(self):
         """Get the value of _save_checkpoint_steps."""
@@ -211,6 +216,11 @@ class CheckpointConfig:
         """Get the value of _keep_checkpoint_per_n_minutes."""
         return self._keep_checkpoint_per_n_minutes
 
+    @property
+    def integrated_save(self):
+        """Get the value of _integrated_save."""
+        return self._integrated_save
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
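Taken together, the new keyword, the check_bool validation and the integrated_save property let callers opt out of merging split parameters at save time. A minimal usage sketch, assuming the usual mindspore.train.callback import path and ModelCheckpoint keyword names (prefix, directory, config), which are not shown in this diff:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# integrated_save=False keeps each device's parameter slices as-is instead of
# merging them into one integrated checkpoint (auto model parallel scene only).
config = CheckpointConfig(save_checkpoint_steps=100,
                          keep_checkpoint_max=5,
                          integrated_save=False)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./ckpt", config=config)  # keyword names assumed
# ckpt_cb is then passed in the callback list of Model.train(...).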
@@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
                 _set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            _exec_save_checkpoint(cb_params.train_network, gen_file)
+            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
 
             if os.path.exists(gen_file):
                 shutil.move(gen_file, cur_file)
@@ -279,13 +279,14 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
 
 
-def _exec_save_checkpoint(train_network, ckpoint_file_name):
+def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
     """
     Saves checkpoint for 'ms' backend.
 
     Args:
         train_network (Network): The train network for training.
         ckpoint_file_name (str): The name of checkpoint file.
+        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene.
     """
 
     param_dict = {}
@@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name):
         else:
             param_data = Tensor(value.data)
 
-        # in model parallel scenario, some parameters were spliteds to all the devices,
+        # in the automatic model parallel scenario, some parameters are split across all the devices,
         # which should be combined before saving
-        if key in train_network.parameter_layout_dict:
+        if integrated_save and key in train_network.parameter_layout_dict:
             param_data = _get_merged_param_data(train_network, key, param_data)
 
         each_param["data"] = param_data
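To make the effect of the flag concrete, here is a simplified, self-contained sketch of the gate above; collect_checkpoint_params and merge_fn are invented names for illustration and do not reproduce the actual _exec_save_checkpoint internals:

def collect_checkpoint_params(parameters, parameter_layout_dict, merge_fn, integrated_save=True):
    """Illustrative: which tensors end up in the checkpoint dict under each setting.

    parameters: dict of name -> local (possibly sliced) tensor on this device.
    parameter_layout_dict: names of parameters that were split across devices.
    merge_fn: callable that gathers the full tensor from all devices.
    """
    param_dict = {}
    for key, param_data in parameters.items():
        # Only split parameters are merged, and only when integrated_save is True;
        # otherwise every device keeps and saves its own slice.
        if integrated_save and key in parameter_layout_dict:
            param_data = merge_fn(key, param_data)
        param_dict[key] = param_data
    return param_dict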
@@ -308,10 +308,10 @@ def test_RunContext():
 def test_Checkpoint_Config():
     """Test CheckpointConfig all None or 0."""
     with pytest.raises(ValueError):
-        CheckpointConfig(0, 0, 0, 0)
+        CheckpointConfig(0, 0, 0, 0, True)
 
     with pytest.raises(ValueError):
-        CheckpointConfig(0, None, 0, 0)
+        CheckpointConfig(0, None, 0, 0, True)
 
 
 def test_step_end_save_graph():
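A companion test for the new keyword could look like the following sketch; it is not part of this change and relies only on the defaults and the integrated_save property added above:

def test_checkpoint_config_integrated_save():
    """Sketch: integrated_save defaults to True and is exposed via the property."""
    assert CheckpointConfig(save_checkpoint_steps=1).integrated_save is True
    assert CheckpointConfig(save_checkpoint_steps=1, integrated_save=False).integrated_save is False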