diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index 9ba7efd60f8..69383eaa86d 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -113,20 +113,28 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck strategy_ckpt_save_file_ = strategy_ckpt_save_file; } -void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector indices) { - all_reduce_fusion_split_indices_ = indices; +void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group) { + all_reduce_fusion_split_indices_[group] = indices; } -const std::vector ParallelContext::all_reduce_fusion_split_indices() const { - return all_reduce_fusion_split_indices_; +const std::vector ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const { + auto iter = all_reduce_fusion_split_indices_.find(group); + if (iter != all_reduce_fusion_split_indices_.end()) { + return iter->second; + } + return {}; } -void ParallelContext::set_all_reduce_fusion_split_sizes(const std::vector sizes) { - all_reduce_fusion_split_sizes_ = sizes; +void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group) { + all_reduce_fusion_split_sizes_[group] = sizes; } -const std::vector ParallelContext::all_reduce_fusion_split_sizes() const { - return all_reduce_fusion_split_sizes_; +const std::vector ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const { + auto iter = all_reduce_fusion_split_sizes_.find(group); + if (iter != all_reduce_fusion_split_sizes_.end()) { + return iter->second; + } + return {}; } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index 0e007c92c64..3e750c17606 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -76,10 +77,10 @@ class ParallelContext { bool global_rank_is_set() const { return global_rank_is_set_; } bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; } - void set_all_reduce_fusion_split_indices(const std::vector indices); - const std::vector all_reduce_fusion_split_indices() const; - void set_all_reduce_fusion_split_sizes(const std::vector sizes); - const std::vector all_reduce_fusion_split_sizes() const; + void SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group); + const std::vector GetAllReduceFusionSplitIndices(const std::string &group) const; + void SetAllReduceFusionSplitSizes(const std::vector sizes, const std::string &group); + const std::vector GetAllReduceFusionSplitSizes(const std::string &group) const; void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { enable_all_reduce_fusion_ = enable_all_reduce_fusion; } @@ -108,8 +109,8 @@ class ParallelContext { bool global_rank_is_set_; bool parameter_broadcast_is_set_; bool enable_all_reduce_fusion_; - std::vector all_reduce_fusion_split_indices_; - std::vector all_reduce_fusion_split_sizes_; + std::map> all_reduce_fusion_split_indices_; + std::map> all_reduce_fusion_split_sizes_; std::string strategy_ckpt_load_file_; std::string strategy_ckpt_save_file_; }; diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 7c663291c04..31ebadf29e0 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -159,13 +159,13 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") - .def("set_all_reduce_fusion_split_indices", &ParallelContext::set_all_reduce_fusion_split_indices, + .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, "Set all reduce fusion split indices.") - .def("get_all_reduce_fusion_split_indices", &ParallelContext::all_reduce_fusion_split_indices, + .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices, "Get all reduce fusion split indices.") - .def("set_all_reduce_fusion_split_sizes", &ParallelContext::set_all_reduce_fusion_split_sizes, + .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes, "Set all reduce fusion split sizes.") - .def("get_all_reduce_fusion_split_sizes", &ParallelContext::all_reduce_fusion_split_sizes, + .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes, "Get all reduce fusion split sizes.") .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion, "Set enable/disable all reduce fusion.") diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc index 4bcd488f691..fc878dd8811 100644 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc @@ -92,7 +92,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) { } // namespace bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, - std::vector *segment_index) const { + std::vector *segment_index, const std::string &group) const { MS_EXCEPTION_IF_NULL(segment_num); MS_EXCEPTION_IF_NULL(segment_index); size_t communication_op_node_size = communication_op_info.communication_op_nodes.size(); @@ -100,7 +100,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic auto parallel_context = parallel::ParallelContext::GetInstance(); MS_EXCEPTION_IF_NULL(parallel_context); - const std::vector split_indices = parallel_context->all_reduce_fusion_split_indices(); + const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); size_t segments = 0; if (split_indices.size() != 0) { @@ -255,7 +255,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { } size_t segment_num = 0; std::vector segment_index; - if (GetSplitSegments(it.second, &segment_num, &segment_index)) { + if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) { if (DoFusion(func_graph, it.second, segment_num, segment_index)) { changed = true; } diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h index e98da1f0ccd..e01d1816164 100644 --- a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h +++ b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h @@ -46,7 +46,7 @@ class CommunicationOpFusion : public Pass { const CommunicationOpInfo &communication_op_info, size_t start_index, size_t end_index) const; bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, - std::vector *segment_index) const; + std::vector *segment_index, const std::string &group) const; std::string op_name_; size_t groups_ = 1; }; diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index f3f8d443e9e..aee47858cd0 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -19,6 +19,8 @@ from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, from mindspore._c_expression import AutoParallelContext from mindspore._checkparam import args_type_check +_MAX_GROUP_NAME_LEN = 127 + class _AutoParallelContext: """ @@ -243,51 +245,117 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_parameter_broadcast_is_set() - def set_all_reduce_fusion_split_indices(self, indices): + def set_all_reduce_fusion_split_indices(self, indices, group=""): """ Set allreduce fusion strategy by parameters indices. Args: indices (list): Indices list. + group (str): The hccl communication group. Raises: TypeError: If type of indices item is not int. + TypeError: If group is not a python str. """ self.check_context_handle() - for index in indices: - if not isinstance(index, int): - raise TypeError('indices has invalid value') - self._context_handle.set_all_reduce_fusion_split_indices(indices) + if isinstance(indices, (list)): + for index in indices: + if not isinstance(index, int): + raise TypeError('indices has invalid value') + else: + raise TypeError('indices must be a python list') + + if isinstance(group, (str)): + group_len = len(group) + if group_len > _MAX_GROUP_NAME_LEN: + raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') + else: + raise TypeError('Group must be a python str') + + self._context_handle.set_all_reduce_fusion_split_indices(indices, group) if context.get_context("device_target") == "Ascend": - _set_fusion_strategy_by_idx(indices) + if group == "": + _set_fusion_strategy_by_idx(indices) + else: + _set_fusion_strategy_by_idx(indices, group) - def get_all_reduce_fusion_split_indices(self): - """Get allreduce fusion split indices.""" + def get_all_reduce_fusion_split_indices(self, group=""): + """ + Get allreduce fusion split indices. + + Args: + group (str): The hccl communication group. + + Returns: + Return split sizes list according to the group. + + Raises: + TypeError: If group is not a python str. + """ self.check_context_handle() - return self._context_handle.get_all_reduce_fusion_split_indices() + if isinstance(group, (str)): + group_len = len(group) + if group_len > _MAX_GROUP_NAME_LEN: + raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') + else: + raise TypeError('Group must be a python str') + return self._context_handle.get_all_reduce_fusion_split_indices(group) - def set_all_reduce_fusion_split_sizes(self, sizes): + def set_all_reduce_fusion_split_sizes(self, sizes, group=""): """ Set allreduce fusion strategy by parameters data sizes. Args: sizes (list): Sizes list. + group (str): The hccl communication group. Raises: TypeError: If type of sizes item is not int. + TypeError: If group is not a python str. """ self.check_context_handle() - for size in sizes: - if not isinstance(size, int): - raise TypeError('sizes has invalid value') - self._context_handle.set_all_reduce_fusion_split_sizes(sizes) - if context.get_context("device_target") == "Ascend": - _set_fusion_strategy_by_size(sizes) + if isinstance(sizes, (list)): + for size in sizes: + if not isinstance(size, int): + raise TypeError('sizes has invalid value') + else: + raise TypeError('sizes must be a python list') - def get_all_reduce_fusion_split_sizes(self): - """Get allreduce fusion split sizes.""" + if isinstance(group, (str)): + group_len = len(group) + if group_len > _MAX_GROUP_NAME_LEN: + raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') + else: + raise TypeError('Group must be a python str') + + self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) + if context.get_context("device_target") == "Ascend": + if group == "": + _set_fusion_strategy_by_size(sizes) + else: + _set_fusion_strategy_by_size(sizes, group) + + def get_all_reduce_fusion_split_sizes(self, group=""): + """ + Get allreduce fusion split sizes. + + Args: + group (str): The hccl communication group. + + Returns: + Return split sizes list according to the group. + + Raises: + TypeError: If group is not a python str. + """ self.check_context_handle() - return self._context_handle.get_all_reduce_fusion_split_sizes() + if isinstance(group, (str)): + group_len = len(group) + if group_len > _MAX_GROUP_NAME_LEN: + raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') + else: + raise TypeError('Group must be a python str') + return self._context_handle.get_all_reduce_fusion_split_sizes(group) def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): """