diff --git a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc index 012135f3444..6206c77f31a 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/control_node_scheduler.cc @@ -69,9 +69,16 @@ void FetchRealDependNodeByAutoMonad(const AnfNodePtr &node, std::set return; } + const mindspore::HashSet recursion_prims = { + prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple}; if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimLoad)) { FetchRealDependNodeByAutoMonad(real_inputs[kDependAttachNodeIndex], depend_nodes); + // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input + // node in this scene. + if (IsOneOfPrimitiveCNode(real_inputs[kRealInputIndexInDepend], recursion_prims)) { + FetchRealDependNodeByAutoMonad(real_inputs[kRealInputIndexInDepend], depend_nodes); + } } else if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimUpdateState)) { for (size_t i = kUpdateStateRealInput; i < real_inputs.size(); ++i) { FetchRealDependNodeByAutoMonad(real_inputs[i], depend_nodes);