forked from mindspore-Ecosystem/mindspore
modify grad accu and comm fusion api
parent 1d505ebad3
commit d19d42ee44

@@ -288,20 +288,20 @@ class Parameter(Tensor_):
    @property
    def comm_fusion(self):
        """Get the fusion type for communication operators corresponding to this parameter."""
        """
        Get and set the fusion type (int) for communication operators corresponding to this parameter.

        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
        gradients aggregation are inserted automatically. Set the fusion type for communication operators generated
        for this parameter. The value of fusion must be greater than or equal to 0. When the value of fusion is 0,
        operators will not be fused together.

        Only `Ascend` and `Graph` mode is supported.
        """
        return self.param_info.comm_fusion

    @comm_fusion.setter
    def comm_fusion(self, comm_fusion_):
        """
        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
        gradients aggregation are inserted automatically. Set the fusion type for communication operators generated
        for this parameter. Only `Ascend` and `Graph` mode is supported.

        Args:
            comm_fusion_ (int): The value of fusion must be greater than or equal to 0.
                When the value of fusion is 0, operators will not be fused together.
        """
        if context.get_context("mode") == context.PYNATIVE_MODE and "auto_parallel" in _get_parallel_mode():
            raise RuntimeError("`comm_fusion` does not support PYNATIVE_MODE")
        Validator.check_non_negative_int(comm_fusion_)

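A minimal usage sketch of the `comm_fusion` property changed above. The parameter name and shape are made up for illustration, and per the docstring a graph-mode, `Ascend`, semi-auto/auto parallel setup is assumed to be configured elsewhere; the read-back relies on the value being stored on `param_info` as the getter shows.

import numpy as np
import mindspore.context as context
from mindspore import Parameter, Tensor

context.set_context(mode=context.GRAPH_MODE)

weight = Parameter(Tensor(np.ones((2, 2), np.float32)), name="fc.weight")
weight.comm_fusion = 2      # put this parameter's communication operators in fusion group 2
print(weight.comm_fusion)   # reads back the value stored in param_info
weight.comm_fusion = 0      # 0 means the operators are not fused together
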
@@ -344,7 +344,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                 auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 all_reduce_fusion_config=list, pipeline_stages=int)
                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
def set_auto_parallel_context(**kwargs):
    r"""
    Set auto parallel context, which is valid only for Ascend and GPU target.

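A hedged sketch of what registering `grad_accumulation_step=int` in the `args_type_check` decorator buys: a value of the wrong type should be rejected before the auto parallel context is touched. The exception type (`TypeError`) is assumed from the decorator's role as a keyword-argument type checker.

import mindspore.context as context

context.set_auto_parallel_context(grad_accumulation_step=4)        # int: accepted
try:
    context.set_auto_parallel_context(grad_accumulation_step="4")  # str: rejected by the type check
except TypeError as err:
    print(err)
context.reset_auto_parallel_context()
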
@@ -371,6 +371,7 @@ def set_auto_parallel_context(**kwargs):
    all_reduce_fusion_config    strategy_ckpt_save_file
    enable_parallel_optimizer   full_batch
    \                           pipeline_stages
    \                           grad_accumulation_step
    =========================== ===========================

    Args:

@@ -420,6 +421,8 @@ def set_auto_parallel_context(**kwargs):
            the devices are distributed along the pipeline. The total devices will be divided into
            'pipeline_stages' stages. This currently can only be used when
            the parallel mode semi_auto_parallel is enabled. Default: 1.
        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
            This should be a positive int. Default: 1.

    Raises:
        ValueError: If the input key is not an attribute of the auto parallel context.

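An illustrative call combining the new `grad_accumulation_step` key with the existing keys documented above. Device and communication initialization (e.g. `init()` on an Ascend cluster) is assumed to happen elsewhere, and reading the value back through `get_auto_parallel_context` assumes the new key is exposed there like the other keys.

import mindspore.context as context

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                  pipeline_stages=1,
                                  grad_accumulation_step=4)  # accumulate gradients over 4 steps
print(context.get_auto_parallel_context("grad_accumulation_step"))
context.reset_auto_parallel_context()  # restore the defaults afterwards
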
@@ -18,7 +18,7 @@ import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore.parallel._ps_context import _is_role_pserver
from mindspore._c_expression import AutoParallelContext
from mindspore._checkparam import args_type_check
from mindspore._checkparam import args_type_check, Validator

_MAX_GROUP_NAME_LEN = 127
_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"

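A small sketch of the `Validator` helpers this hunk starts importing; `check_positive_int` and `check_non_negative_int` are assumed to raise `ValueError` for out-of-range integers, matching how this commit uses them.

from mindspore._checkparam import Validator

Validator.check_positive_int(1)        # ok: grad_accumulation_step must be a positive int
Validator.check_non_negative_int(0)    # ok: comm_fusion may be 0 (fusion disabled)
try:
    Validator.check_positive_int(0)    # 0 is not a valid accumulation step
except ValueError as err:
    print(err)
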
@@ -257,6 +257,7 @@ class _AutoParallelContext:
            grad_accumulation_step (int): The grad accumulation step.
        """
        self.check_context_handle()
        Validator.check_positive_int(grad_accumulation_step)
        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)

    def get_grad_accumulation_step(self):

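A sketch of how the new setter/getter pair is reached internally; the module-level `auto_parallel_context()` accessor is assumed to return the `_AutoParallelContext` singleton shown in this file.

from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
ctx.set_grad_accumulation_step(4)        # goes through Validator.check_positive_int
print(ctx.get_grad_accumulation_step())  # expected to print 4
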