forked from mindspore-Ecosystem/mindspore
set all reduce fusion default group
This commit is contained in:
parent
2e9206e8bb
commit
503dd297c5
|
@ -245,7 +245,7 @@ class _AutoParallelContext:
|
||||||
self.check_context_handle()
|
self.check_context_handle()
|
||||||
return self._context_handle.get_parameter_broadcast_is_set()
|
return self._context_handle.get_parameter_broadcast_is_set()
|
||||||
|
|
||||||
def set_all_reduce_fusion_split_indices(self, indices, group=""):
|
def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"):
|
||||||
"""
|
"""
|
||||||
Set allreduce fusion strategy by parameters indices.
|
Set allreduce fusion strategy by parameters indices.
|
||||||
|
|
||||||
|
@ -279,7 +279,7 @@ class _AutoParallelContext:
|
||||||
else:
|
else:
|
||||||
_set_fusion_strategy_by_idx(indices, group)
|
_set_fusion_strategy_by_idx(indices, group)
|
||||||
|
|
||||||
def get_all_reduce_fusion_split_indices(self, group=""):
|
def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"):
|
||||||
"""
|
"""
|
||||||
Get allreduce fusion split indices.
|
Get allreduce fusion split indices.
|
||||||
|
|
||||||
|
@ -301,7 +301,7 @@ class _AutoParallelContext:
|
||||||
raise TypeError('Group must be a python str')
|
raise TypeError('Group must be a python str')
|
||||||
return self._context_handle.get_all_reduce_fusion_split_indices(group)
|
return self._context_handle.get_all_reduce_fusion_split_indices(group)
|
||||||
|
|
||||||
def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
|
def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"):
|
||||||
"""
|
"""
|
||||||
Set allreduce fusion strategy by parameters data sizes.
|
Set allreduce fusion strategy by parameters data sizes.
|
||||||
|
|
||||||
|
@ -335,7 +335,7 @@ class _AutoParallelContext:
|
||||||
else:
|
else:
|
||||||
_set_fusion_strategy_by_size(sizes, group)
|
_set_fusion_strategy_by_size(sizes, group)
|
||||||
|
|
||||||
def get_all_reduce_fusion_split_sizes(self, group=""):
|
def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"):
|
||||||
"""
|
"""
|
||||||
Get allreduce fusion split sizes.
|
Get allreduce fusion split sizes.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue