!14643 [Grad Accumulation]Fix transdata-depend-load can't eliminate.

From: @linqingke
Reviewed-by: @xu-yfei,@guoqi1024
Signed-off-by: @xu-yfei
This commit is contained in:
mindspore-ci-bot 2021-04-06 20:57:11 +08:00 committed by Gitee
commit da9b8c0f46
2 changed files with 4 additions and 3 deletions

View File

@ -472,7 +472,7 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) {
auto out_node = output.first;
auto name = AnfAlgo::GetCNodeName(out_node);
if (name == prim::kPrimDepend->name() || name == prim::kPrimMakeTuple->name() ||
name == prim::kPrimTupleGetItem->name()) {
name == prim::kPrimTupleGetItem->name() || name == prim::kPrimLoad->name()) {
auto result = IsNotRealUsedByOthers(graph, out_node);
if (!result) {
return result;

View File

@ -52,11 +52,12 @@ CNodePtr CreateNewDependNode(const FuncGraphPtr &func_graph, const CNodePtr &cno
CNodePtr CheckIsolatedVirtualNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name()) {
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDepend->name() &&
AnfAlgo::GetCNodeName(cnode) != prim::kPrimLoad->name()) {
return nullptr;
}
auto virtual_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependVirtualInputIndex);
if (!AnfAlgo::CheckPrimitiveType(virtual_input_op, prim::kPrimUpdateState)) {
if (!HasAbstractMonad(virtual_input_op)) {
return nullptr;
}
auto real_input_op = AnfAlgo::GetInputNode(cnode, kIsolatedDependRealInputIndex);