!14643 [Grad Accumulation]Fix transdata-depend-load can't eliminate.
From: @linqingke Reviewed-by: @xu-yfei,@guoqi1024 Signed-off-by: @xu-yfei
This commit is contained in:
commit
da9b8c0f46
|
@ -472,7 +472,7 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
|
|||
auto out_node = output.first;
|
||||
auto name = AnfAlgo::GetCNodeName(out_node);
|
||||
if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
|
||||
name == prim::kPrimTupleGetItem->name()) {
|
||||
name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
|
||||
auto result = IsNotRealUsedByOthers(graph, out_node);
|
||||
if (!result) {
|
||||
return result;
|
||||
|
|
|
@ -52,11 +52,12 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno
|
|||
|
||||
CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name()) {
|
||||
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name() &&
|
||||
AnfAlgo::GetCNodeName(cnode) != prim::kPrimLoad->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
|
||||
if (!AnfAlgo::CheckPrimitiveType(virtual_input_op, prim::kPrimUpdateState)) {
|
||||
if (!HasAbstractMonad(virtual_input_op)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);
|
||||
|
|
Loading…
Reference in New Issue