forked from mindspore-Ecosystem/mindspore
!3562 add nccl default allreduce fusion group
Merge pull request !3562 from yuchaojie/add_nccl_default_allreduce_fusion_group
This commit is contained in: commit 51891b751d
@@ -20,6 +20,8 @@ from mindspore._c_expression import AutoParallelContext
 from mindspore._checkparam import args_type_check
 
 _MAX_GROUP_NAME_LEN = 127
+_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
+_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
 
 
 class _AutoParallelContext:
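Note: the two constants added above let the fusion APIs fall back to a backend-appropriate world group when the caller passes no group name. A minimal sketch of that fallback as a standalone helper (the name _resolve_fusion_group is hypothetical; the constants and the device_target check mirror the diff):

from mindspore import context

_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"

def _resolve_fusion_group(group=""):
    """Pick the default allreduce fusion group for the current backend."""
    if group == "":
        # Ascend targets use the HCCL world group; other targets (e.g. GPU) use NCCL.
        if context.get_context("device_target") == "Ascend":
            return _DEFAULT_HCCL_FUSION_GROUP_NAME
        return _DEFAULT_NCCL_FUSION_GROUP_NAME
    return group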
@@ -267,7 +269,7 @@
         self.check_context_handle()
         return self._context_handle.get_parameter_broadcast_is_set()
 
-    def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"):
+    def set_all_reduce_fusion_split_indices(self, indices, group=""):
         """
         Set allreduce fusion strategy by parameters indices.
 
@@ -294,11 +296,17 @@
         else:
             raise TypeError('Group must be a python str')
 
+        if group == "":
+            if context.get_context("device_target") == "Ascend":
+                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
+            else:
+                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
+
         self._context_handle.set_all_reduce_fusion_split_indices(indices, group)
         if context.get_context("device_target") == "Ascend":
             _set_fusion_strategy_by_idx(indices)
 
-    def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"):
+    def get_all_reduce_fusion_split_indices(self, group=""):
         """
         Get allreduce fusion split indices.
 
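Note: with the default changed from the hard-coded HCCL name to an empty string, the same call now works on both Ascend and GPU. A hedged usage sketch, assuming the module-level auto_parallel_context() accessor defined in this file (the split indices are illustrative):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Fuse gradients of parameters [0, 20), [20, 60) and [60, end) into three allreduce ops.
# Leaving group unset resolves to hccl_world_groupsum1 on Ascend and
# nccl_world_groupsum1 elsewhere, so the script no longer hard-codes a backend.
auto_parallel_context().set_all_reduce_fusion_split_indices([20, 60])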
@@ -318,9 +326,15 @@
                 raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
         else:
             raise TypeError('Group must be a python str')
+
+        if group == "":
+            if context.get_context("device_target") == "Ascend":
+                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
+            else:
+                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
         return self._context_handle.get_all_reduce_fusion_split_indices(group)
 
-    def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"):
+    def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
         """
         Set allreduce fusion strategy by parameters data sizes.
 
@@ -347,11 +361,17 @@
         else:
             raise TypeError('Group must be a python str')
 
+        if group == "":
+            if context.get_context("device_target") == "Ascend":
+                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
+            else:
+                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
+
         self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group)
         if context.get_context("device_target") == "Ascend":
             _set_fusion_strategy_by_size(sizes)
 
-    def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"):
+    def get_all_reduce_fusion_split_sizes(self, group=""):
         """
         Get allreduce fusion split sizes.
 
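Note: the size-based variant splits fusion by accumulated parameter data size rather than by parameter index. A rough, purely illustrative sketch of such a threshold scheme (not the backend logic, which lives behind _set_fusion_strategy_by_size; names and units are assumptions):

def bucket_by_size(param_sizes, split_sizes):
    """Group consecutive parameters into fusion buckets, closing a bucket once
    its accumulated size reaches the next threshold in split_sizes."""
    buckets, current, acc, i = [], [], 0, 0
    for idx, size in enumerate(param_sizes):
        current.append(idx)
        acc += size
        if i < len(split_sizes) and acc >= split_sizes[i]:
            buckets.append(current)
            current, acc, i = [], 0, i + 1
    if current:
        buckets.append(current)
    return buckets

# bucket_by_size([4, 4, 8, 16, 2], [8, 20]) -> [[0, 1], [2, 3], [4]]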
@@ -371,6 +391,12 @@
                 raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
         else:
             raise TypeError('Group must be a python str')
+
+        if group == "":
+            if context.get_context("device_target") == "Ascend":
+                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
+            else:
+                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
         return self._context_handle.get_all_reduce_fusion_split_sizes(group)
 
     def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
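Note: the getter resolves an empty group name the same way as the setter, so a set/get pair on one backend reads back from the same group. A small round-trip sketch under the same auto_parallel_context() assumption as above (the sizes are illustrative):

from mindspore.parallel._auto_parallel_context import auto_parallel_context

ctx = auto_parallel_context()
ctx.set_all_reduce_fusion_split_sizes([256, 512])   # group="" -> backend default group
print(ctx.get_all_reduce_fusion_split_sizes())      # reads from the same default group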