From 99779473b9442ac1c6a875c952bdb1dbbf238fd8 Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Mon, 7 Nov 2022 14:24:09 +0800 Subject: [PATCH] Fix addr check in subgraph zero copy. --- .../hal/device/tasksink/rtmodel_zero_copy.cc | 11 ++++- .../actor/control_flow/exit_actor.cc | 2 +- .../graph_scheduler/graph_scheduler.cc | 41 ++++++++++--------- .../runtime/graph_scheduler/graph_scheduler.h | 13 +++--- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc index 8c3a6f1176a..e13390e3707 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc @@ -614,7 +614,16 @@ bool RtModelZeroCopy::UpdateTaskArgs(const session::KernelGraph &graph, void *st return false; } - MS_LOG(INFO) << "Check rtMode valid " << ((rtStreamSynchronize(stream) == RT_ERROR_NONE) && CheckRtModelValid(graph)); + if (rtStreamSynchronize(stream) != RT_ERROR_NONE) { + MS_LOG(WARNING) << "Sync stream for graph:" << graph.ToString() << " failed."; + return true; + } + + // If the zero copy in graph mode is enabled, the input and output addr in task may not be same as addr in graph, + // so skip the addr check. + if (!graph.has_flag(kFlagEnableZeroCopyInGraph)) { + MS_LOG(INFO) << "Check rtMode valid " << (CheckRtModelValid(graph)); + } return true; } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc index 4f775cc9c65..9be94188242 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/control_flow/exit_actor.cc @@ -144,7 +144,7 @@ void ExitActor::IncreaseDynamicRefCounts(OpContext *const context) for (size_t i = 0; i < input_device_tensors_.size(); ++i) { if ((input_device_tensors_[i] != nullptr) && (input_device_tensors_[i]->dynamic_ref_count() == 0) && (device_contexts_[i] != nullptr)) { - MS_LOG(WARNING) << GetAID().Name() << " input index:" << i << " has no user and free the memory."; + MS_LOG(INFO) << GetAID().Name() << " input index:" << i << " has no user and free the memory."; device_contexts_[i]->device_res_manager_->FreeMemory(input_device_tensors_[i]); } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc index e2f788de6a5..2a908ce4067 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc @@ -1284,13 +1284,13 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De namespace { void GetAllUInputByCNode(const CNodePtr &cnode, - mindspore::HashMap> *cnode_to_u_inputs) { + mindspore::HashMap> *cnode_to_monad_inputs) { MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(cnode_to_u_inputs); - if (cnode_to_u_inputs->find(cnode) != cnode_to_u_inputs->end()) { + MS_EXCEPTION_IF_NULL(cnode_to_monad_inputs); + if (cnode_to_monad_inputs->find(cnode) != cnode_to_monad_inputs->end()) { return; } - (*cnode_to_u_inputs)[cnode] = {}; + (*cnode_to_monad_inputs)[cnode] = {}; for (const auto &input : cnode->inputs()) { MS_EXCEPTION_IF_NULL(input); if (!input->isa()) { @@ -1299,27 +1299,28 @@ void GetAllUInputByCNode(const CNodePtr &cnode, const auto &cinput = input->cast(); MS_EXCEPTION_IF_NULL(cinput); if (common::AnfAlgo::GetCNodeName(cinput) == kUpdateStateOpName) { - (*cnode_to_u_inputs)[cnode].emplace(cinput); + (*cnode_to_monad_inputs)[cnode].emplace(cinput); } - GetAllUInputByCNode(cinput, cnode_to_u_inputs); - (*cnode_to_u_inputs)[cnode].insert((*cnode_to_u_inputs)[cinput].begin(), (*cnode_to_u_inputs)[cinput].end()); + GetAllUInputByCNode(cinput, cnode_to_monad_inputs); + (*cnode_to_monad_inputs)[cnode].insert((*cnode_to_monad_inputs)[cinput].begin(), + (*cnode_to_monad_inputs)[cinput].end()); } } void GetAllCNodeUInputByGraph(const KernelGraphPtr &graph, - mindspore::HashMap> *cnode_to_u_inputs) { + mindspore::HashMap> *cnode_to_monad_inputs) { MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(cnode_to_u_inputs); + MS_EXCEPTION_IF_NULL(cnode_to_monad_inputs); for (const auto &kernel : graph->execution_order()) { MS_EXCEPTION_IF_NULL(kernel); - GetAllUInputByCNode(kernel, cnode_to_u_inputs); + GetAllUInputByCNode(kernel, cnode_to_monad_inputs); } } // Check if the first input of update state should be linked, if the other inputs of update state has depend the first // input, it would not be linked. bool IsNeedLinkForFirstInput(const CNodePtr &cnode, - const mindspore::HashMap> &cnode_to_u_inputs) { + const mindspore::HashMap> &cnode_to_monad_inputs) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().size() <= kUpdateStateStateInput) { MS_LOG(EXCEPTION) << "Invalid update state node:" << cnode->DebugString(); @@ -1328,8 +1329,8 @@ bool IsNeedLinkForFirstInput(const CNodePtr &cnode, MS_EXCEPTION_IF_NULL(u_input); for (size_t i = kUpdateStateRealInput; i < cnode->inputs().size(); ++i) { MS_EXCEPTION_IF_NULL(cnode->input(i)); - const auto &iter = cnode_to_u_inputs.find(cnode->input(i)); - if (iter != cnode_to_u_inputs.end() && iter->second.find(u_input) != iter->second.end()) { + const auto &iter = cnode_to_monad_inputs.find(cnode->input(i)); + if (iter != cnode_to_monad_inputs.end() && iter->second.find(u_input) != iter->second.end()) { return false; } } @@ -1406,9 +1407,9 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, MS_EXCEPTION_IF_NULL(communication_nodes); // Collect all the depend updatestate nodes of the kernels for linking control arrow. - mindspore::HashMap> cnode_to_u_inputs; + mindspore::HashMap> cnode_to_monad_inputs; MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " start."; - GetAllCNodeUInputByGraph(graph, &cnode_to_u_inputs); + GetAllCNodeUInputByGraph(graph, &cnode_to_monad_inputs); MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " end."; auto &execution_order = graph->execution_order(); @@ -1432,7 +1433,7 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph, // Link the control arrows of kernel actor by the auto monad, the inputs include monad node. if (SchedulerHelper::HasMonadControl(input_node, graph)) { LinkControlArrowByAutoMonad(kernel_actor, input_node, graph, graph_compiler_info.control_node_parser_, - cnode_to_u_inputs); + cnode_to_monad_inputs); } if (HasAbstractMonad(input_node)) { (void)auto_monad_actors->emplace_back(kernel_actor); @@ -1693,7 +1694,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, void GraphScheduler::LinkControlArrowByAutoMonad( AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph, const ControlNodeParserPtr &parser, - const mindspore::HashMap> &cnode_to_u_inputs) { + const mindspore::HashMap> &cnode_to_monad_inputs) { MS_EXCEPTION_IF_NULL(to_actor); MS_EXCEPTION_IF_NULL(from_node); MS_EXCEPTION_IF_NULL(graph); @@ -1712,7 +1713,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad( if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) { MS_EXCEPTION_IF_NULL(input_cnode); for (size_t i = 1; i < input_cnode->inputs().size(); ++i) { - LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser, cnode_to_u_inputs); + LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser, cnode_to_monad_inputs); } return; } @@ -1732,7 +1733,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad( real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend)); real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex)); } else if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) { - if (IsNeedLinkForFirstInput(input_cnode, cnode_to_u_inputs) && + if (IsNeedLinkForFirstInput(input_cnode, cnode_to_monad_inputs) && input_cnode->inputs().size() > kUpdateStateStateInput) { // If all other inputs of the update state do not depend on the first input, we need to link control arrow // for the first input. @@ -1780,7 +1781,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad( // The monad node and make tuple node need recursion. if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) { - LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser, cnode_to_u_inputs); + LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser, cnode_to_monad_inputs); continue; } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h index 92e2f9381c4..d228492b3e2 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.h @@ -168,12 +168,13 @@ class BACKEND_EXPORT GraphScheduler { const KernelWithIndex &to_kernel_with_input_idx); // 2. The processing of linking control arrows. - // The parameter cnode_to_u_inputs contains all the update states that each cnode in the graph depends on. When - // processing the first input of update state, the map is used to check whether it is necessary to link control - // arrow for the first input of update state. - void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph, - const ControlNodeParserPtr &parser = nullptr, - const mindspore::HashMap> &cnode_to_u_inputs = {}); + // The parameter cnode_to_monad_inputs contains all the update states that each cnode in the graph depends on. When + // processing the first input of update state, the map is used to check whether it is necessary to link control arrow + // for the first input of update state. + void LinkControlArrowByAutoMonad( + AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph, + const ControlNodeParserPtr &parser = nullptr, + const mindspore::HashMap> &cnode_to_monad_inputs = {}); // The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node. void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node, const KernelGraphPtr &graph) const;