From 9068799ca65599aea8fa5e0257f821569c482b8e Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Thu, 5 May 2022 21:14:42 +0800 Subject: [PATCH] add BatchMatmulEltwiseFusionPass --- mindspore/ccsrc/include/common/utils/utils.h | 4 + .../device/ascend/hal/device/lic_manager.cc | 2 + .../device/ascend/hal/device/lic_manager.h | 2 + .../optimizer/ascend_backend_optimization.cc | 4 +- .../batchmatmul_dropoutdomaskv3_fusion_pass.h | 4 +- .../batchmatmul_eltwise_fusion_pass.cc | 132 ++++++++++++++++++ ...ss.h => batchmatmul_eltwise_fusion_pass.h} | 23 +-- .../batchmatmul_fusedmuladd_fusion_pass.cc | 58 -------- .../matmul_dropoutdomaskv3_add_fusion_pass.h | 4 +- 9 files changed, 160 insertions(+), 73 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.cc rename mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/{batchmatmul_fusedmuladd_fusion_pass.h => batchmatmul_eltwise_fusion_pass.h} (65%) delete mode 100644 mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h index 5d1272b62ed..3e65069c94c 100644 --- a/mindspore/ccsrc/include/common/utils/utils.h +++ b/mindspore/ccsrc/include/common/utils/utils.h @@ -144,6 +144,7 @@ constexpr auto kSqrtOpName = "Sqrt"; constexpr auto kRsqrtOpName = "Rsqrt"; constexpr auto kRsqrtGradOpName = "RsqrtGrad"; constexpr auto kErfOpName = "Erf"; +constexpr auto kDivOpName = "Div"; constexpr auto kRealDivOpName = "RealDiv"; constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR"; constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay"; @@ -177,6 +178,8 @@ constexpr auto kSendOpName = "StreamSend"; constexpr auto kRecvOpName = "StreamRecv"; constexpr auto kRpcSendOpName = "RpcSend"; constexpr auto kRpcRecvOpName = "RpcRecv"; +constexpr auto kReluOpName = "ReLU"; +constexpr auto kReluGradOpName = "ReluGrad"; constexpr auto kReluV2OpName = "ReLUV2"; constexpr auto kReluGradV2OpName = "ReluGradV2"; constexpr auto kAddOpName = "Add"; @@ -335,6 +338,7 @@ constexpr auto kSubAndFilterOpName = "SubAndFilter"; constexpr auto kPadAndShiftOpName = "PadAndShift"; constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits"; constexpr auto kOneHotOpName = "OneHot"; +constexpr auto kSigmoidOpName = "Sigmoid"; constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits"; constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler"; constexpr auto kLogSoftmaxGradOpName = "LogSoftmaxGrad"; diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc index b71c2be27e9..09c03a3fb11 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.cc @@ -44,7 +44,9 @@ static const std::map kPassCodeMap = { {std::to_string(39), OptPassEnum::ClipByNormNoDivSquareSumFusion}, {std::to_string(42), OptPassEnum::MulAddNPass}, {std::to_string(43), OptPassEnum::Resnet50DbnDwFusionPass}, + {std::to_string(44), OptPassEnum::BatchMatMulDropOutDoMaskV3DFusionPass}, {std::to_string(45), OptPassEnum::MatmulConfusiontransposeUbFusion}, + {std::to_string(46), OptPassEnum::MatMulDropOutDoMaskV3DFusionPass}, {std::to_string(47), OptPassEnum::TbeBatchMatmulElementWiseFusionPass}, }; diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.h b/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.h index 00a4b375b46..8221d70ef3c 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.h +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/lic_manager.h @@ -43,7 +43,9 @@ enum class OptPassEnum { ConvBnReduceFusionPass, MulAddNPass, Resnet50DbnDwFusionPass, + BatchMatMulDropOutDoMaskV3DFusionPass, MatmulConfusiontransposeUbFusion, + MatMulDropOutDoMaskV3DFusionPass, TbeBatchMatmulElementWiseFusionPass, Invalid, }; diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc index 4becd92e6f0..1f100f9a10b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ascend_backend_optimization.cc @@ -110,9 +110,9 @@ #include "plugin/device/ascend/optimizer/buffer_fusion/matmul_eltwise_fusion_pass.h" #include "plugin/device/ascend/optimizer/buffer_fusion/matmul_confusiontranspose_fusion_pass.h" #include "plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h" -#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h" #include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h" #include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_reducesum_fusion_pass.h" +#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h" #include "plugin/device/ascend/optimizer/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h" #include "plugin/device/ascend/optimizer/buffer_fusion/bnupdate_eltwise_fusion_pass.h" #include "plugin/device/ascend/optimizer/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h" @@ -561,7 +561,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptrAddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) { - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); } ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h index 80a0774e6ce..f876111acfe 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h @@ -32,7 +32,9 @@ namespace opt { class BatchMatmulDropoutDoMaskV3FusionPass : public FusionBasePass { public: explicit BatchMatmulDropoutDoMaskV3FusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("BatchMatmulDropoutDoMaskV3FusionPass", std::move(idAllocator)) {} + : FusionBasePass("BatchMatmulDropoutDoMaskV3FusionPass", std::move(idAllocator)) { + PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatMulDropOutDoMaskV3DFusionPass); + } ~BatchMatmulDropoutDoMaskV3FusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.cc new file mode 100644 index 00000000000..2e01ec331cc --- /dev/null +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h" +#include +#include +#include "kernel/kernel_fusion.h" +#include "backend/common/session/anf_runtime_algorithm.h" +#include "include/common/utils/anfalgo.h" +#include "mindspore/core/ops/core_ops.h" +#include "common/graph_kernel/graph_kernel_flags.h" +#include "backend/common/optimizer/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +namespace { +constexpr auto kAttrNoFusion = "no_fusion"; + +CNodePtr FindInputNode(const CNodePtr &cnode, const string &node_type, const kernel::FusionType &fusion_type) { + auto input_num = common::AnfAlgo::GetInputTensorNum(cnode); + for (size_t i = 1; i <= input_num; ++i) { + auto input = cnode->input(i); + if (input->isa() && common::AnfAlgo::GetCNodeName(input) == node_type && + AnfAlgo::GetFusionType(input) == fusion_type) { + return input->cast(); + } + } + return nullptr; +} +} // namespace + +bool BatchMatmulEltwiseFusionPass::MatchPattern1(const CNodePtr &eltwise1, mindspore::HashSet *record) { + // bmm - eltwise - eltwise1 + const std::set kElem1TypeList = {kAddOpName, kReluOpName, kFusedMulAddOpName}; + if (kElem1TypeList.find(common::AnfAlgo::GetCNodeName(eltwise1)) == kElem1TypeList.end()) { + return false; + } + + auto input_num = common::AnfAlgo::GetInputTensorNum(eltwise1); + for (size_t i = 1; i <= input_num; ++i) { + auto eltwise1_input = eltwise1->input(i); + if (eltwise1_input->isa() && MatchPattern2(eltwise1_input->cast(), record)) { + record->insert(eltwise1); + return true; + } + } + return false; +} + +bool BatchMatmulEltwiseFusionPass::MatchPattern2(const CNodePtr &eltwise, mindspore::HashSet *record) { + // bmm - eltwise + const std::set kElemTypeList = {kFusedMulAddOpName, kAddOpName, kDivOpName, + kRealDivOpName, kReluOpName, kReluGradOpName}; + if (kElemTypeList.find(common::AnfAlgo::GetCNodeName(eltwise)) == kElemTypeList.end()) { + return false; + } + + CNodePtr bmm = FindInputNode(eltwise, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL); + if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm) || common::AnfAlgo::GetBooleanAttr(bmm, kAttrNoFusion)) { + return false; + } + + record->insert(eltwise); + record->insert(bmm); + return true; +} + +bool BatchMatmulEltwiseFusionPass::MatchPattern3(const CNodePtr &eltwise, mindspore::HashSet *record) { + // bmm - eltwise1(mul) - eltwise2(sigmoid) - eltwise(mul) + if (common::AnfAlgo::GetCNodeName(eltwise) != kMulOpName) { + return false; + } + + CNodePtr eltwise2 = FindInputNode(eltwise, kSigmoidOpName, kernel::FusionType::ELEMWISE); + if (eltwise2 == nullptr) { + return false; + } + + CNodePtr eltwise1 = FindInputNode(eltwise2, kMulOpName, kernel::FusionType::ELEMWISE); + if (eltwise1 == nullptr) { + return false; + } + + CNodePtr bmm = FindInputNode(eltwise1, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL); + if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm)) { + return false; + } + + record->insert(eltwise); + record->insert(eltwise2); + record->insert(eltwise1); + record->insert(bmm); + return true; +} + +void BatchMatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + + const auto &node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + mindspore::HashSet record; + if (MatchPattern1(cnode, &record) || MatchPattern2(cnode, &record) || MatchPattern3(cnode, &record)) { + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h similarity index 65% rename from mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h rename to mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h index 9614b24cf5f..465cd02b0bb 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2021 Huawei Technologies Co., Ltd + * Copyright 2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_ -#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_ #include "utils/hash_set.h" #include "plugin/device/ascend/optimizer/buffer_fusion/fusion_base_pass.h" @@ -27,20 +27,21 @@ namespace mindspore { namespace opt { -class BatchMatmulFusedMulAddFusionPass : public FusionBasePass { +class BatchMatmulEltwiseFusionPass : public FusionBasePass { public: - explicit BatchMatmulFusedMulAddFusionPass(const FusionIdAllocatorPtr &idAllocator) - : FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) { - PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass); + explicit BatchMatmulEltwiseFusionPass(const FusionIdAllocatorPtr &idAllocator) + : FusionBasePass("BatchMatmulEltwiseFusionPass", idAllocator) { + PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::TbeBatchMatmulElementWiseFusionPass); } - ~BatchMatmulFusedMulAddFusionPass() override = default; + ~BatchMatmulEltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; private: - void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); + bool MatchPattern1(const CNodePtr &eltwise1, mindspore::HashSet *record); + bool MatchPattern2(const CNodePtr &eltwise, mindspore::HashSet *record); + bool MatchPattern3(const CNodePtr &eltwise, mindspore::HashSet *record); }; } // namespace opt } // namespace mindspore -#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_ +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc deleted file mode 100644 index b7c3bdf51ec..00000000000 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.cc +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h" -#include "kernel/kernel_fusion.h" -#include "backend/common/session/anf_runtime_algorithm.h" -#include "include/common/utils/anfalgo.h" -#include "mindspore/core/ops/core_ops.h" -#include "utils/ms_context.h" -#include "backend/common/optimizer/fusion_id_allocator.h" - -namespace mindspore { -namespace opt { -void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(candidate_fusion); - auto batch_matmul = cnode->input(kIndex2); - MS_EXCEPTION_IF_NULL(batch_matmul); - if (batch_matmul->isa() && common::AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) { - mindspore::HashSet record{cnode, batch_matmul}; - candidate_fusion->push_back(record); - SetRecordFusionId(record); - } -} - -void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(candidate_fusion); - const auto &node_list = TopoSort(kernel_graph.get_return()); - for (auto &node : node_list) { - MS_EXCEPTION_IF_NULL(node); - if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || - common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - - if (common::AnfAlgo::GetCNodeName(cnode) == kFusedMulAddOpName) { - MatchBatchMatmulFusedMulAdd(cnode, kernel_graph, candidate_fusion); - } - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h index 3cd40a2a0ea..071a2eaefb2 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h @@ -30,7 +30,9 @@ namespace opt { class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass { public: explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator) - : FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {} + : FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) { + PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatMulDropOutDoMaskV3DFusionPass); + } ~MatmulDropoutDoMaskV3AddFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;