From c4bbf5a282103d412f966734a7306cf86de477a6 Mon Sep 17 00:00:00 2001 From: yujianfeng Date: Thu, 3 Sep 2020 14:27:33 +0800 Subject: [PATCH] Fix bnupdate_eltwise_eltwise ub fusion pass --- .../bnupdate_eltwise_eltwise_fusion_pass.cc | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc index 129c6e1f599..715fed7d793 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.cc @@ -27,6 +27,20 @@ namespace mindspore { namespace opt { +namespace { +constexpr size_t kEltwiseInputSize = 2; +constexpr size_t kEltwiseOutputSize = 2; +bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) { + if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) { + return true; + } + if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) { + return true; + } + return false; +} +} // namespace + void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { @@ -74,8 +88,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) { auto eltwise_input = cnode->input(1); + MS_EXCEPTION_IF_NULL(eltwise_input); if (eltwise_input->isa() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) { MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion); }