forked from mindspore-Ecosystem/mindspore
!189 Integrate two allreduce fusion set interfaces into one
Merge pull request !189 from yao_yf/parallel_interface_organize
This commit is contained in:
commit
ecc168c72f
|
@ -15,9 +15,7 @@
|
|||
"""
|
||||
This interface is ONLY used in Auto-parallel procedure.
|
||||
"""
|
||||
from .dp_allreduce_fusion import set_fusion_strategy_by_idx, set_fusion_strategy_by_size
|
||||
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
|
||||
set_algo_parameters
|
||||
|
||||
__all__ = ["set_fusion_strategy_by_idx", "set_fusion_strategy_by_size", "get_algo_parameters",
|
||||
"reset_algo_parameters", "set_algo_parameters"]
|
||||
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
# ============================================================================
|
||||
"""Context of auto parallel"""
|
||||
import threading
|
||||
import mindspore.context as context
|
||||
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
|
||||
from mindspore._c_expression import AutoParallelContext
|
||||
from mindspore._extends.pynative_helper import args_type_check
|
||||
|
||||
|
@ -219,13 +221,15 @@ class _AutoParallelContext:
|
|||
indices (list): Indices list.
|
||||
|
||||
Raises:
|
||||
ValueError: If type of indices item is not int.
|
||||
TypeError: If type of indices item is not int.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
for index in indices:
|
||||
if not isinstance(index, int):
|
||||
raise TypeError('indices has invalid value')
|
||||
return self._context_handle.set_all_reduce_fusion_split_indices(indices)
|
||||
self._context_handle.set_all_reduce_fusion_split_indices(indices)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
_set_fusion_strategy_by_idx(indices)
|
||||
|
||||
def get_all_reduce_fusion_split_indices(self):
|
||||
"""Get allreduce fusion split indices."""
|
||||
|
@ -240,13 +244,15 @@ class _AutoParallelContext:
|
|||
sizes (list): Sizes list.
|
||||
|
||||
Raises:
|
||||
ValueError: If type of sizes item is not int.
|
||||
TypeError: If type of sizes item is not int.
|
||||
"""
|
||||
self.check_context_handle()
|
||||
for size in sizes:
|
||||
if not isinstance(size, int):
|
||||
raise TypeError('sizes has invalid value')
|
||||
return self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
|
||||
self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
_set_fusion_strategy_by_size(sizes)
|
||||
|
||||
def get_all_reduce_fusion_split_sizes(self):
|
||||
"""Get allreduce fusion split sizes."""
|
||||
|
|
|
@ -43,7 +43,7 @@ def _c_array(ctype, values):
|
|||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
||||
def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
||||
"""
|
||||
A function set gradient segment strategy according to the index list.
|
||||
|
||||
|
@ -100,7 +100,7 @@ def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
|||
raise RuntimeError('Allreduce split error')
|
||||
|
||||
|
||||
def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
|
||||
def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
|
||||
"""
|
||||
A function set gradient segment strategy according to the data size percentage list.
|
||||
|
Loading…
Reference in New Issue