forked from mindspore-Ecosystem/mindspore
Integrate two allreduce fusion set interfaces into one
This commit is contained in:
parent
8357383111
commit
6fdcc24585
|
@ -15,9 +15,7 @@
|
||||||
"""
|
"""
|
||||||
This interface is ONLY used in Auto-parallel procedure.
|
This interface is ONLY used in Auto-parallel procedure.
|
||||||
"""
|
"""
|
||||||
from .dp_allreduce_fusion import set_fusion_strategy_by_idx, set_fusion_strategy_by_size
|
|
||||||
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
|
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
|
||||||
set_algo_parameters
|
set_algo_parameters
|
||||||
|
|
||||||
__all__ = ["set_fusion_strategy_by_idx", "set_fusion_strategy_by_size", "get_algo_parameters",
|
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
|
||||||
"reset_algo_parameters", "set_algo_parameters"]
|
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""Context of auto parallel"""
|
"""Context of auto parallel"""
|
||||||
import threading
|
import threading
|
||||||
|
import mindspore.context as context
|
||||||
|
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
|
||||||
from mindspore._c_expression import AutoParallelContext
|
from mindspore._c_expression import AutoParallelContext
|
||||||
from mindspore._extends.pynative_helper import args_type_check
|
from mindspore._extends.pynative_helper import args_type_check
|
||||||
|
|
||||||
|
@ -219,13 +221,15 @@ class _AutoParallelContext:
|
||||||
indices (list): Indices list.
|
indices (list): Indices list.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If type of indices item is not int.
|
TypeError: If type of indices item is not int.
|
||||||
"""
|
"""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
for index in indices:
|
for index in indices:
|
||||||
if not isinstance(index, int):
|
if not isinstance(index, int):
|
||||||
raise TypeError('indices has invalid value')
|
raise TypeError('indices has invalid value')
|
||||||
return self._context_handle.set_all_reduce_fusion_split_indices(indices)
|
self._context_handle.set_all_reduce_fusion_split_indices(indices)
|
||||||
|
if context.get_context("device_target") == "Ascend":
|
||||||
|
_set_fusion_strategy_by_idx(indices)
|
||||||
|
|
||||||
def get_all_reduce_fusion_split_indices(self):
|
def get_all_reduce_fusion_split_indices(self):
|
||||||
"""Get allreduce fusion split indices."""
|
"""Get allreduce fusion split indices."""
|
||||||
|
@ -240,13 +244,15 @@ class _AutoParallelContext:
|
||||||
sizes (list): Sizes list.
|
sizes (list): Sizes list.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If type of sizes item is not int.
|
TypeError: If type of sizes item is not int.
|
||||||
"""
|
"""
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
for size in sizes:
|
for size in sizes:
|
||||||
if not isinstance(size, int):
|
if not isinstance(size, int):
|
||||||
raise TypeError('sizes has invalid value')
|
raise TypeError('sizes has invalid value')
|
||||||
return self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
|
self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
|
||||||
|
if context.get_context("device_target") == "Ascend":
|
||||||
|
_set_fusion_strategy_by_size(sizes)
|
||||||
|
|
||||||
def get_all_reduce_fusion_split_sizes(self):
|
def get_all_reduce_fusion_split_sizes(self):
|
||||||
"""Get allreduce fusion split sizes."""
|
"""Get allreduce fusion split sizes."""
|
||||||
|
|
|
@ -43,7 +43,7 @@ def _c_array(ctype, values):
|
||||||
return (ctype * len(values))(*values)
|
return (ctype * len(values))(*values)
|
||||||
|
|
||||||
|
|
||||||
def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
||||||
"""
|
"""
|
||||||
A function set gradient segment strategy according to the index list.
|
A function set gradient segment strategy according to the index list.
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
|
||||||
raise RuntimeError('Allreduce split error')
|
raise RuntimeError('Allreduce split error')
|
||||||
|
|
||||||
|
|
||||||
def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
|
def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
|
||||||
"""
|
"""
|
||||||
A function set gradient segment strategy according to the data size percentage list.
|
A function set gradient segment strategy according to the data size percentage list.
|
||||||
|
|
Loading…
Reference in New Issue