From 20ea704dc5419b46b336599e2486b3be824a629b Mon Sep 17 00:00:00 2001 From: linqingke Date: Tue, 6 Apr 2021 09:30:00 +0800 Subject: [PATCH] fix transdata-depend-load can't eliminate. --- mindspore/ccsrc/backend/optimizer/common/helper.cc | 2 +- .../ccsrc/backend/optimizer/pass/optimize_dependence.cc | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 36e5ae14a91..387899599b3 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -472,7 +472,7 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { auto out_node = output.first; auto name = AnfAlgo::GetCNodeName(out_node); if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() || - name == prim::kPrimTupleGetItem->name()) { + name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) { auto result = IsNotRealUsedByOthers(graph, out_node); if (!result) { return result; diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc index 58af92d8c12..d032e77a396 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc @@ -52,11 +52,12 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name()) { + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name() && + AnfAlgo::GetCNodeName(cnode) != prim::kPrimLoad->name()) { return nullptr; } auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex); - if (!AnfAlgo::CheckPrimitiveType(virtual_input_op, prim::kPrimUpdateState)) { + if (!HasAbstractMonad(virtual_input_op)) { return nullptr; } auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);