diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc index 74ef83dcf4a..017c475eb4a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.cc @@ -27,15 +27,15 @@ namespace mindspore { namespace opt { -void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, - const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion) { +void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(candidate_fusion); auto manager = kernel_graph.manager(); MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(relu_input); - auto getitem = relu_input->cast(); + MS_EXCEPTION_IF_NULL(eltwise_input); + auto getitem = eltwise_input->cast(); MS_EXCEPTION_IF_NULL(getitem); auto bnupdate = getitem->input(1); MS_EXCEPTION_IF_NULL(bnupdate); @@ -68,10 +68,11 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) { auto eltwise_input = cnode->input(1); if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) { - MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); + MatchBnupdateDoubleOutputEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion); } } } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h index 9ca88959de7..b9284f424b2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h @@ -39,8 +39,8 @@ class BnupdateEltwiseFusionPass : public FusionBasePass { void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; private: - void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - FusedNodeRecord *candidate_fusion); + void MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input, + const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h index 024ce416e30..c78e93cd97c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h @@ -33,6 +33,7 @@ const int8_t MAX_ELTWISE_NUM = 3; const int8_t MIN_ELTWISE_SIZE = 2; const int8_t ELTWISE_INPUT_SIZE = 2; const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3; +const int8_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2; const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3; const int8_t CONV_QUART_IN_INPUT_SIZE = 5; const int8_t ELTWISE_USE = 1;