diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index aee47858cd0..02190290379 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -245,7 +245,7 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_parameter_broadcast_is_set() - def set_all_reduce_fusion_split_indices(self, indices, group=""): + def set_all_reduce_fusion_split_indices(self, indices, group="hccl_world_groupsum1"): """ Set allreduce fusion strategy by parameters indices. @@ -279,7 +279,7 @@ class _AutoParallelContext: else: _set_fusion_strategy_by_idx(indices, group) - def get_all_reduce_fusion_split_indices(self, group=""): + def get_all_reduce_fusion_split_indices(self, group="hccl_world_groupsum1"): """ Get allreduce fusion split indices. @@ -301,7 +301,7 @@ class _AutoParallelContext: raise TypeError('Group must be a python str') return self._context_handle.get_all_reduce_fusion_split_indices(group) - def set_all_reduce_fusion_split_sizes(self, sizes, group=""): + def set_all_reduce_fusion_split_sizes(self, sizes, group="hccl_world_groupsum1"): """ Set allreduce fusion strategy by parameters data sizes. @@ -335,7 +335,7 @@ class _AutoParallelContext: else: _set_fusion_strategy_by_size(sizes, group) - def get_all_reduce_fusion_split_sizes(self, group=""): + def get_all_reduce_fusion_split_sizes(self, group="hccl_world_groupsum1"): """ Get allreduce fusion split sizes.