forked from mindspore-Ecosystem/mindspore
!14913 add relu fusion check
From: @wilfchen Reviewed-by: @cristoval,@limingqi107 Signed-off-by: @limingqi107
This commit is contained in:
commit
888a2565bc
|
@ -73,6 +73,12 @@ const AnfNodePtr AddReluGradV2Fusion::Process(const FuncGraphPtr &graph, const A
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0);
|
||||
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1);
|
||||
if (shape1 != shape2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedAddReluGradV2Name);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2, mask};
|
||||
|
|
|
@ -72,6 +72,12 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<size_t> shape1 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 0);
|
||||
std::vector<size_t> shape2 = AnfAlgo::GetPrevNodeOutputInferShape(tensor_add, 1);
|
||||
if (shape1 != shape2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kFusedAddReluV2Name);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), x1, x2};
|
||||
|
|
Loading…
Reference in New Issue