diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index a2d82525e9a..0416c308305 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -21,7 +21,7 @@
 #include "pre_activate/ascend/ir_fission/bn_grad_split.h"
 #include "pre_activate/ascend/ir_fusion/fused_batch_norm_fusion.h"
 #include "pre_activate/ascend/ir_fission/layer_norm_grad_split.h"
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 #include "pre_activate/ascend/ir_fusion/square_sum_fusion.h"
 #include "pre_activate/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h"
 #include "pre_activate/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h"
@@ -261,6 +261,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto other_pm = std::make_shared<PassManager>("other_pm");
   other_pm->AddPass(std::make_shared<AllReduceFusion>());
+  other_pm->AddPass(std::make_shared<AllGatherFusion>());
   other_pm->AddPass(std::make_shared<BufferFusion>());
   other_pm->AddPass(std::make_shared<GetitemTuple>());
   other_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
diff --git a/mindspore/ccsrc/pre_activate/pass/allreduce_fusion.h b/mindspore/ccsrc/pre_activate/pass/allreduce_fusion.h
deleted file mode 100644
index e443767e43d..00000000000
--- a/mindspore/ccsrc/pre_activate/pass/allreduce_fusion.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/**
- * Copyright 2019 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
-#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
-#include <vector>
-
-#include "pre_activate/common/pass.h"
-#include "ir/func_graph.h"
-#include "ir/anf.h"
-
-namespace mindspore {
-namespace opt {
-struct AllReduceInfo_t {
-  std::vector<CNodePtr> allreduce_node;
-  std::vector<float> input_grad_size;
-  std::vector<float> input_grad_time;
-};
-
-class AllReduceFusion : public Pass {
- public:
-  explicit AllReduceFusion(size_t groups = 1) : Pass("all_reduce_fusion"), groups_(groups) {}
-  ~AllReduceFusion() override = default;
-  bool Run(const FuncGraphPtr &graph) override;
-
- private:
-  bool DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info, size_t segment_num,
-                const std::vector<size_t> &segment_index) const;
-  AnfNodePtr CreateFusedAllReduce(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
-                                  size_t start_index, size_t end_index) const;
-  bool GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
-                        std::vector<size_t> *segment_index) const;
-  size_t groups_ = 1;
-};
-}  // namespace opt
-}  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_ALLREDUCE_FUSION_H_
diff --git a/mindspore/ccsrc/pre_activate/pass/allreduce_fusion.cc b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc
similarity index 62%
rename from mindspore/ccsrc/pre_activate/pass/allreduce_fusion.cc
rename to mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc
index 70a8974ecaf..4bcd488f691 100644
--- a/mindspore/ccsrc/pre_activate/pass/allreduce_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.cc
@@ -13,14 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 
 #include <memory>
-#include <string>
 #include <unordered_map>
 #include <vector>
 
-#include "utils/utils.h"
 #include "utils/graph_utils.h"
 #include "operator/ops.h"
 #include "device/kernel_info.h"
@@ -31,9 +29,12 @@
 namespace mindspore {
 namespace opt {
 namespace {
-kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllReduceInfo_t &allreduce_node_info, size_t start_index,
+constexpr auto kAttrDefaultGroup = "default_group";
+constexpr auto kAttrDefaultOp = "default_op";
+
+kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &communication_op_info, size_t start_index,
                                                    size_t end_index) {
-  if (end_index >= allreduce_node_info.allreduce_node.size()) {
+  if (end_index >= communication_op_info.communication_op_nodes.size()) {
     MS_LOG(EXCEPTION) << "end index out of vector size";
   }
   std::vector<std::string> inputs_device_format;
@@ -43,7 +44,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllRed
   std::vector<std::vector<size_t>> outputs_shape;
   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
   for (size_t idx = start_index; idx <= end_index; ++idx) {
-    auto cnode = allreduce_node_info.allreduce_node[idx];
+    auto cnode = communication_op_info.communication_op_nodes[idx];
     MS_EXCEPTION_IF_NULL(cnode);
     for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) {
       inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index));
@@ -64,14 +65,38 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AllRed
   builder.SetOutputsDeviceType(outputs_device_type);
   return builder.Build();
 }
+
+std::string GetFusionGroupKey(const AnfNodePtr &node) {
+  auto primitive = AnfAlgo::GetCNodePrimitive(node);
+  MS_EXCEPTION_IF_NULL(primitive);
+  ValuePtr attr_fusion = primitive->GetAttr(kAttrFusion);
+  if (attr_fusion == nullptr) {
+    return "";
+  }
+  int fusion = GetValue<int>(attr_fusion);
+  if (fusion == 0) {
+    return "";
+  }
+  std::string group = kAttrDefaultGroup;
+  ValuePtr attr_group = primitive->GetAttr(kAttrGroup);
+  if (attr_group != nullptr) {
+    group = GetValue<std::string>(attr_group);
+  }
+  std::string op = kAttrDefaultOp;
+  ValuePtr attr_op = primitive->GetAttr(kAttrOp);
+  if (attr_op != nullptr) {
+    op = GetValue<std::string>(attr_op);
+  }
+  return group + op + std::to_string(fusion);
+}
 }  // namespace
 
-bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_info, size_t *segment_num,
-                                       std::vector<size_t> *segment_index) const {
+bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
+                                             std::vector<size_t> *segment_index) const {
   MS_EXCEPTION_IF_NULL(segment_num);
   MS_EXCEPTION_IF_NULL(segment_index);
-  size_t allreduce_node_size = allreduce_node_info.allreduce_node.size();
-  MS_LOG(INFO) << "graph all reduce node size " << allreduce_node_size;
+  size_t communication_op_node_size = communication_op_info.communication_op_nodes.size();
+  MS_LOG(INFO) << "graph " << op_name_ << " node size " << communication_op_node_size;
   auto parallel_context = parallel::ParallelContext::GetInstance();
   MS_EXCEPTION_IF_NULL(parallel_context);
 
@@ -82,30 +107,31 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
     uint32_t last_index = 0;
     for (size_t i = 0; i < split_indices.size(); ++i) {
       uint32_t index = split_indices[i];
-      if (index <= last_index || index >= allreduce_node_size) {
-        MS_LOG(EXCEPTION) << "invalid allreduce split index " << i << " " << index;
+      if (index <= last_index || index >= communication_op_node_size) {
+        MS_LOG(EXCEPTION) << "invalid " << op_name_ << " split index " << i << " " << index;
       }
       segment_index->push_back(index);
       last_index = index;
       segments++;
     }
-    if (last_index != allreduce_node_size - 1) {
-      segment_index->push_back(allreduce_node_size - 1);
+    if (last_index != communication_op_node_size - 1) {
+      segment_index->push_back(communication_op_node_size - 1);
       segments++;
     }
   } else {
     segments = groups_;
     for (size_t i = 0; i < segments - 1; ++i) {
-      segment_index->push_back((i + 1) * (allreduce_node_size / segments) - 1);
+      segment_index->push_back((i + 1) * (communication_op_node_size / segments) - 1);
     }
-    segment_index->push_back(allreduce_node_size - 1);
+    segment_index->push_back(communication_op_node_size - 1);
   }
 
-  if (segments >= allreduce_node_size) {
-    MS_LOG(INFO) << "fusion not changed: segment_num=" << segments << ", allreduce_node_size=" << allreduce_node_size;
+  if (segments >= communication_op_node_size) {
+    MS_LOG(INFO) << "fusion not changed: segment_num=" << segments
+                 << ", communication_op_node_size=" << communication_op_node_size;
     return false;
   }
-  if (segment_index->at(segments - 1) != allreduce_node_size - 1) {
+  if (segment_index->at(segments - 1) != communication_op_node_size - 1) {
     MS_LOG(EXCEPTION) << "the last segment index is invalid.";
   }
   for (size_t i = 0; i < segments - 1; ++i) {
@@ -118,19 +144,19 @@ bool AllReduceFusion::GetSplitSegments(const AllReduceInfo_t &allreduce_node_inf
   return true;
 }
 
-AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
-                                                 const AllReduceInfo_t &allreduce_node_info, size_t start_index,
-                                                 size_t end_index) const {
+AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
+                                                             const CommunicationOpInfo &communication_op_info,
+                                                             size_t start_index, size_t end_index) const {
   MS_EXCEPTION_IF_NULL(func_graph);
-  auto prim = std::make_shared<Primitive>(kAllReduceOpName);
+  auto prim = std::make_shared<Primitive>(op_name_);
   MS_EXCEPTION_IF_NULL(prim);
   std::vector<AnfNodePtr> fusion_inputs = {NewValueNode(prim)};
   // get all inputs of current segment
-  if (end_index >= allreduce_node_info.allreduce_node.size()) {
+  if (end_index >= communication_op_info.communication_op_nodes.size()) {
     MS_LOG(EXCEPTION) << "end index out of vector size";
   }
   for (size_t idx = start_index; idx <= end_index; ++idx) {
-    auto cnode = allreduce_node_info.allreduce_node[idx];
+    auto cnode = communication_op_info.communication_op_nodes[idx];
     MS_EXCEPTION_IF_NULL(cnode);
     fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
   }
@@ -141,14 +167,14 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
   fused_node->set_kernel_info(kernel_info);
   AbstractBasePtrList abstract_list;
   for (size_t idx = start_index; idx <= end_index; ++idx) {
-    auto cnode = allreduce_node_info.allreduce_node[idx];
+    auto cnode = communication_op_info.communication_op_nodes[idx];
     MS_EXCEPTION_IF_NULL(cnode);
     AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
     AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
     AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
     abstract_list.push_back(cnode->abstract());
   }
-  auto kernel_build_info = GenerateKernelBuildInfo(allreduce_node_info, start_index, end_index);
+  auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, fused_node.get());
   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
   MS_EXCEPTION_IF_NULL(abstract_tuple);
@@ -156,8 +182,8 @@ AnfNodePtr AllReduceFusion::CreateFusedAllReduce(const FuncGraphPtr &func_graph,
   return fused_node;
 }
 
-bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceInfo_t &allreduce_node_info,
-                               size_t segment_num, const std::vector<size_t> &segment_index) const {
+bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info,
+                                     size_t segment_num, const std::vector<size_t> &segment_index) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   auto manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -169,12 +195,13 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
       start_index = end_index + 1;
       continue;
     }
-    AnfNodePtr new_allreduce = CreateFusedAllReduce(func_graph, allreduce_node_info, start_index, end_index);
-    // replace old allreduce with new allreduce
+    AnfNodePtr new_communication_op =
+      CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
+    // replace old communication op with new communication op
     for (auto idx = start_index; idx <= end_index; ++idx) {
       std::vector<AnfNodePtr> tuple_getitem_input;
       tuple_getitem_input.push_back(NewValueNode(prim::kPrimTupleGetItem));
-      tuple_getitem_input.push_back(new_allreduce);
+      tuple_getitem_input.push_back(new_communication_op);
       auto index = NewValueNode(SizeToInt(idx - start_index));
       MS_EXCEPTION_IF_NULL(index);
       auto imm = std::make_shared<Int32Imm>(idx - start_index);
@@ -185,10 +212,10 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
       tuple_getitem_input.push_back(index);
       AnfNodePtr tuple_getitem = func_graph->NewCNode(tuple_getitem_input);
       MS_EXCEPTION_IF_NULL(tuple_getitem);
-      auto allreduce_node_item = allreduce_node_info.allreduce_node.at(idx);
-      MS_EXCEPTION_IF_NULL(allreduce_node_item);
-      tuple_getitem->set_abstract(allreduce_node_item->abstract());
-      if (!manager->Replace(allreduce_node_item, tuple_getitem)) {
+      auto communication_op_node_item = communication_op_info.communication_op_nodes.at(idx);
+      MS_EXCEPTION_IF_NULL(communication_op_node_item);
+      tuple_getitem->set_abstract(communication_op_node_item->abstract());
+      if (!manager->Replace(communication_op_node_item, tuple_getitem)) {
         MS_LOG(EXCEPTION) << "manager replace node failed";
       }
     }
@@ -198,29 +225,24 @@ bool AllReduceFusion::DoFusion(const FuncGraphPtr &func_graph, const AllReduceIn
   return changed;
 }
 
-bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
+bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
   const float input_grad_size_num = 0.0;
   const float input_grad_time_num = 0.0;
   // divide candidate fusion groups with same (group,op,fusion) attrs, fusion==0 means not fusion
-  std::unordered_map<std::string, AllReduceInfo_t> candidate_groups;
+  std::unordered_map<std::string, CommunicationOpInfo> candidate_groups;
   std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
   for (auto &node : node_list) {
-    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) {
-      auto primitive = AnfAlgo::GetCNodePrimitive(node);
-      MS_EXCEPTION_IF_NULL(primitive);
-      int fusion = GetValue<int>(primitive->GetAttr("fusion"));
-      if (fusion == 0) {
+    if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == op_name_) {
+      std::string key = GetFusionGroupKey(node);
+      if (key.empty()) {
         continue;
       }
-      std::string group = GetValue<std::string>(primitive->GetAttr("group"));
-      std::string op = GetValue<std::string>(primitive->GetAttr("op"));
-      std::string key = group + op + std::to_string(fusion);
       if (candidate_groups.find(key) == candidate_groups.end()) {
-        AllReduceInfo_t allreduce_node_info;
-        candidate_groups[key] = allreduce_node_info;
+        CommunicationOpInfo communication_op_info;
+        candidate_groups[key] = communication_op_info;
       }
-      candidate_groups[key].allreduce_node.push_back(node->cast<CNodePtr>());
+      candidate_groups[key].communication_op_nodes.push_back(node->cast<CNodePtr>());
       candidate_groups[key].input_grad_size.push_back(input_grad_size_num);
       candidate_groups[key].input_grad_time.push_back(input_grad_time_num);
     }
@@ -228,7 +250,7 @@ bool AllReduceFusion::Run(const FuncGraphPtr &func_graph) {
   // split candidate group to segments according to _group class member
   bool changed = false;
   for (auto &it : candidate_groups) {
-    if (it.second.allreduce_node.size() <= 1) {
+    if (it.second.communication_op_nodes.size() <= 1) {
       continue;
     }
     size_t segment_num = 0;
diff --git a/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h
new file mode 100644
index 00000000000..af8b557d5fc
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/pass/communication_op_fusion.h
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "pre_activate/common/pass.h"
+#include "ir/func_graph.h"
+#include "ir/anf.h"
+#include "utils/utils.h"
+
+namespace mindspore {
+namespace opt {
+struct CommunicationOpInfo {
+  std::vector<CNodePtr> communication_op_nodes;
+  std::vector<float> input_grad_size;
+  std::vector<float> input_grad_time;
+};
+
+class CommunicationOpFusion : public Pass {
+ public:
+  explicit CommunicationOpFusion(const std::string &name, std::string op_name, size_t groups = 1)
+      : Pass(name), op_name_(std::move(op_name)), groups_(groups) {}
+  ~CommunicationOpFusion() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  bool DoFusion(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t segment_num,
+                const std::vector<size_t> &segment_index) const;
+  AnfNodePtr CreateFusedCommunicationOp(const FuncGraphPtr &func_graph,
+                                        const CommunicationOpInfo &communication_op_info, size_t start_index,
+                                        size_t end_index) const;
+  bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
+                        std::vector<size_t> *segment_index) const;
+  std::string op_name_;
+  size_t groups_ = 1;
+};
+
+class AllReduceFusion : public CommunicationOpFusion {
+ public:
+  explicit AllReduceFusion(size_t groups = 1) : CommunicationOpFusion("all_reduce_fusion", kAllReduceOpName, groups) {}
+  ~AllReduceFusion() override = default;
+};
+
+class AllGatherFusion : public CommunicationOpFusion {
+ public:
+  explicit AllGatherFusion(size_t groups = 1) : CommunicationOpFusion("all_gather_fusion", kAllGatherOpName, groups) {}
+  ~AllGatherFusion() override = default;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_PASS_COMMUNICATION_OP_FUSION_H_
diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc
index f5e8c44231d..4a9506913ca 100644
--- a/mindspore/ccsrc/session/gpu_session.cc
+++ b/mindspore/ccsrc/session/gpu_session.cc
@@ -20,7 +20,7 @@
 #include "device/gpu/gpu_stream_assign.h"
 #include "pre_activate/common/optimizer.h"
 #include "pre_activate/common/pass_manager.h"
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 #include "device/kernel_runtime_manager.h"
 #include "predict/predict.h"
 #include "common/utils.h"
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 6829a7e888e..40c1b05377b 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -155,6 +155,9 @@ constexpr auto kAttrOutputUsedNum = "output_used_num";
 constexpr auto kAttrHasBias = "has_bias";
 constexpr auto kAttrN = "n";
 constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active";
+constexpr auto kAttrFusion = "fusion";
+constexpr auto kAttrGroup = "group";
+constexpr auto kAttrOp = "op";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
diff --git a/tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc b/tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc
index d5f2fa636dd..7f3b9d4c9de 100644
--- a/tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/common/ir_fusion/allreduce_fusion_test.cc
@@ -20,7 +20,7 @@
 #include "ir/manager.h"
 #include "debug/anf_ir_dump.h"
 #include "session/anf_runtime_algorithm.h"
-#include "pre_activate/pass/allreduce_fusion.h"
+#include "pre_activate/pass/communication_op_fusion.h"
 #include "pre_activate/common/optimizer.h"
 #include "device/kernel_info.h"
 #include "pre_activate/common/pass_manager.h"
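For reviewers, a minimal usage sketch (not part of the diff): it shows how the generalized pass is consumed after this change, mirroring the ascend_backend_optimization.cc hunk above. GraphOptimizer, PassManager, and the two pass classes come from the headers touched in this diff; the helper name RunCommunicationOpFusion and the pass-manager name "comm_fusion_pm" are invented for the example.

```cpp
// Hypothetical usage sketch. Both subclasses reuse CommunicationOpFusion::Run:
// each one only fixes the pass name and the target op name, so AllReduce and
// AllGather nodes are segmented and fused by the same
// GetSplitSegments/CreateFusedCommunicationOp machinery.
#include <memory>

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass_manager.h"
#include "pre_activate/pass/communication_op_fusion.h"

namespace mindspore {
namespace opt {
void RunCommunicationOpFusion(const FuncGraphPtr &graph) {
  auto optimizer = std::make_shared<GraphOptimizer>();
  auto pm = std::make_shared<PassManager>("comm_fusion_pm");
  pm->AddPass(std::make_shared<AllReduceFusion>());  // fuses kAllReduceOpName nodes
  pm->AddPass(std::make_shared<AllGatherFusion>());  // fuses kAllGatherOpName nodes
  optimizer->AddPassManager(pm);
  (void)optimizer->Optimize(graph);
}
}  // namespace opt
}  // namespace mindspore
```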