diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
index 0cb70c42c2b..b273244bc8f 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
@@ -102,6 +102,7 @@
 #include "backend/optimizer/ascend/format_type/remove_internal_output.h"
 #include "backend/optimizer/ascend/ir_fission/concat_fission.h"
 #include "backend/optimizer/ascend/ir_fission/pack_fission.h"
+#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
 #include "utils/context/ms_context.h"
 #include "utils/config_manager.h"
 #include "debug/anf_ir_dump.h"
@@ -341,6 +342,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   auto other_pm = std::make_shared<PassManager>("other_pm");
   other_pm->AddPass(std::make_shared<AllReduceFusion>());
   other_pm->AddPass(std::make_shared<AllGatherFusion>());
+  other_pm->AddPass(std::make_shared<ConcatOutputsForAllGather>());
   other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
   other_pm->AddPass(std::make_shared<BroadcastFusion>());
   other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>());
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc
new file mode 100644
index 00000000000..2ca18ec7e9a
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.cc
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h"
+#include <string>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
+void AddOutputs(const AnfNodePtr &node, int rank_size) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto origin_abstract = node->abstract();
+  MS_EXCEPTION_IF_NULL(origin_abstract);
+  auto tuple_abstract = origin_abstract->cast<abstract::AbstractTuplePtr>();
+  MS_EXCEPTION_IF_NULL(tuple_abstract);
+  auto &origin_abstracts = tuple_abstract->elements();
+  AbstractBasePtrList abstract_list;
+  std::vector<TypeId> outputs_device_type;
+  std::vector<std::string> outputs_device_format;
+  for (int i = 0; i < rank_size; ++i) {
+    for (size_t j = 0; j < origin_abstracts.size(); ++j) {
+      abstract_list.push_back(origin_abstracts[j]);
+      outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(node, j));
+      outputs_device_format.push_back(AnfAlgo::GetOutputFormat(node, j));
+    }
+  }
+  // Update abstract
+  auto new_abstracts = std::make_shared<abstract::AbstractTuple>(abstract_list);
+  node->set_abstract(new_abstracts);
+  // Update kernel build info
+  auto builder =
+    std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
+  builder->SetOutputsDeviceType(outputs_device_type);
+  builder->SetOutputsFormat(outputs_device_format);
+  AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
+}
+}  // namespace
+
+AnfNodePtr ConcatOutputsForAllGather::InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                            const std::vector<AnfNodePtr> &new_tuple_getitems,
+                                                            int rank_size) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> make_tuple_inputs;
+  size_t inputs_size = AnfAlgo::GetInputTensorNum(node);
+  for (size_t i = 0; i < inputs_size; ++i) {
+    for (size_t j = 0, idx = i; j < IntToSize(rank_size); ++j, idx += inputs_size) {
+      std::vector<AnfNodePtr> concat_inputs{NewValueNode(std::make_shared<Primitive>(prim::kPrimConcat->name()))};
+      concat_inputs.push_back(new_tuple_getitems[idx]);
+      auto concat = func_graph->NewCNode(concat_inputs);
+      MS_EXCEPTION_IF_NULL(concat);
+      MS_EXCEPTION_IF_NULL(new_tuple_getitems[idx]);
+      concat->set_abstract(new_tuple_getitems[idx]->abstract());
+      AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(0), concat);
+      AnfAlgo::SetNodeAttr(kAttrInputNums, MakeValue(rank_size), concat);
+      std::vector<int> dyn_input_size{rank_size};
+      AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_size), concat);
+      kernel_select_->SelectKernel(concat);
+      make_tuple_inputs.push_back(concat);
+    }
+  }
+  auto make_tuple = func_graph->NewCNode(make_tuple_inputs);
+  return make_tuple;
+}
+
+const BaseRef ConcatOutputsForAllGather::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  auto prim = std::make_shared<Primitive>(kAllGatherOpName);
+  return VectorRef({prim, Xs});
+}
+
+const AnfNodePtr ConcatOutputsForAllGather::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                    const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!AnfAlgo::HasNodeAttr(kAttrFusion, cnode) || !AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
+    return nullptr;
+  }
+  auto fusion = AnfAlgo::GetNodeAttr<int>(cnode, kAttrFusion);
+  if (fusion <= 0) {
+    return nullptr;
+  }
+  auto rank_size = AnfAlgo::GetNodeAttr<int>(node, kAttrRankSize);
+  AddOutputs(node, rank_size);
+  std::vector<AnfNodePtr> new_outputs;
+  CreateMultipleOutputsOfAnfNode(func_graph, node, AnfAlgo::GetOutputTensorNum(node), &new_outputs);
+  return InsertConcatForOutput(func_graph, node, new_outputs, rank_size);
+}
+}  // namespace opt
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h
new file mode 100644
index 00000000000..7b4d0d5427d
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
+
+#include <memory>
+#include <vector>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class ConcatOutputsForAllGather : public PatternProcessPass {
+ public:
+  explicit ConcatOutputsForAllGather(bool multigraph = true)
+      : PatternProcessPass("concat_outputs_for_all_gather", multigraph),
+        kernel_select_(std::make_shared<KernelSelect>()) {}
+  ~ConcatOutputsForAllGather() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr InsertConcatForOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                   const std::vector<AnfNodePtr> &new_tuple_getitems, int rank_size) const;
+  KernelSelectPtr kernel_select_;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_CONCAT_OUTPUTS_FOR_ALLGATHER_H_
diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
index d8f94c4a2cb..024f69e8437 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
@@ -188,9 +188,13 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
   MS_EXCEPTION_IF_NULL(abstract_tuple);
   fused_node->set_abstract(abstract_tuple);
-  AnfAlgo::CopyNodeAttr("fusion", communication_op_info.communication_op_nodes[end_index], fused_node);
-  AnfAlgo::CopyNodeAttr("op", communication_op_info.communication_op_nodes[end_index], fused_node);
-  AnfAlgo::CopyNodeAttr("group", communication_op_info.communication_op_nodes[end_index], fused_node);
+  auto final_node = communication_op_info.communication_op_nodes[end_index];
+  AnfAlgo::CopyNodeAttr(kAttrFusion, final_node, fused_node);
+  AnfAlgo::CopyNodeAttr(kAttrOp, final_node, fused_node);
+  AnfAlgo::CopyNodeAttr(kAttrGroup, final_node, fused_node);
+  if (AnfAlgo::HasNodeAttr(kAttrRankSize, final_node)) {
+    AnfAlgo::CopyNodeAttr(kAttrRankSize, final_node, fused_node);
+  }
   return fused_node;
 }
 
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 1ce981e2e3a..240392dcfba 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -250,6 +250,7 @@ constexpr auto kAttrChildGraph = "child_graph";
 constexpr auto kAttrInputNums = "inputNums";
 constexpr auto kAttrT = "T";
 constexpr auto kAttrNum = "num";
+constexpr auto kAttrRankSize = "rank_size";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";