diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc index 2308ae6dabe..147505b0e68 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc @@ -73,6 +73,12 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A return nullptr; } + std::vector shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0); + std::vector shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1); + if (shape1 != shape2) { + return nullptr; + } + auto prim = std::make_shared(kFusedAddReluGradV2Name); MS_EXCEPTION_IF_NULL(prim); std::vector inputs = {NewValueNode(prim), x1, x2, mask}; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc index 1cea7134c48..d9a7d4f5957 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc @@ -72,6 +72,12 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo return nullptr; } + std::vector shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0); + std::vector shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1); + if (shape1 != shape2) { + return nullptr; + } + auto prim = std::make_shared(kFusedAddReluV2Name); MS_EXCEPTION_IF_NULL(prim); std::vector inputs = {NewValueNode(prim), x1, x2};