!14913 add relu fusion check

From: @wilfchen
Reviewed-by: @cristoval,@limingqi107
Signed-off-by: @limingqi107
This commit is contained in:
mindspore-ci-bot 2021-04-12 09:27:02 +08:00 committed by Gitee
commit 888a2565bc
2 changed files with 12 additions and 0 deletions

View File

@ -73,6 +73,12 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A
return nullptr;
}
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0);
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1);
if (shape1 != shape2) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask};

View File

@ -72,6 +72,12 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
return nullptr;
}
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0);
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1);
if (shape1 != shape2) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kFusedAddReluV2Name);
MS_EXCEPTION_IF_NULL(prim);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2};