!45210 Fix addr check in subgraph zero copy.

Merge pull request !45210 from gaoyong10/dynamic_shape_05
This commit is contained in:
i-robot 2022-11-07 15:43:55 +00:00 committed by Gitee
commit f4199c045e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 39 additions and 28 deletions

View File

@ -614,7 +614,16 @@ bool RtModelZeroCopy::UpdateTaskArgs(const session::KernelGraph &graph, void *st
return false;
}
MS_LOG(INFO) << "Check rtMode valid " << ((rtStreamSynchronize(stream) == RT_ERROR_NONE) && CheckRtModelValid(graph));
if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
MS_LOG(WARNING) << "Sync stream for graph:" << graph.ToString() << " failed.";
return true;
}
// If the zero copy in graph mode is enabled, the input and output addr in task may not be same as addr in graph,
// so skip the addr check.
if (!graph.has_flag(kFlagEnableZeroCopyInGraph)) {
MS_LOG(INFO) << "Check rtMode valid " << (CheckRtModelValid(graph));
}
return true;
}

View File

@ -144,7 +144,7 @@ void ExitActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context)
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
if ((input_device_tensors_[i] != nullptr) && (input_device_tensors_[i]->dynamic_ref_count() == 0) &&
(device_contexts_[i] != nullptr)) {
MS_LOG(WARNING) << GetAID().Name() << " input index:" << i << " has no user and free the memory.";
MS_LOG(INFO) << GetAID().Name() << " input index:" << i << " has no user and free the memory.";
device_contexts_[i]->device_res_manager_->FreeMemory(input_device_tensors_[i]);
}
}

View File

@ -1284,13 +1284,13 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
namespace {
void GetAllUInputByCNode(const CNodePtr &cnode,
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_u_inputs) {
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_monad_inputs) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(cnode_to_u_inputs);
if (cnode_to_u_inputs->find(cnode) != cnode_to_u_inputs->end()) {
MS_EXCEPTION_IF_NULL(cnode_to_monad_inputs);
if (cnode_to_monad_inputs->find(cnode) != cnode_to_monad_inputs->end()) {
return;
}
(*cnode_to_u_inputs)[cnode] = {};
(*cnode_to_monad_inputs)[cnode] = {};
for (const auto &input : cnode->inputs()) {
MS_EXCEPTION_IF_NULL(input);
if (!input->isa<CNode>()) {
@ -1299,27 +1299,28 @@ void GetAllUInputByCNode(const CNodePtr &cnode,
const auto &cinput = input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cinput);
if (common::AnfAlgo::GetCNodeName(cinput) == kUpdateStateOpName) {
(*cnode_to_u_inputs)[cnode].emplace(cinput);
(*cnode_to_monad_inputs)[cnode].emplace(cinput);
}
GetAllUInputByCNode(cinput, cnode_to_u_inputs);
(*cnode_to_u_inputs)[cnode].insert((*cnode_to_u_inputs)[cinput].begin(), (*cnode_to_u_inputs)[cinput].end());
GetAllUInputByCNode(cinput, cnode_to_monad_inputs);
(*cnode_to_monad_inputs)[cnode].insert((*cnode_to_monad_inputs)[cinput].begin(),
(*cnode_to_monad_inputs)[cinput].end());
}
}
void GetAllCNodeUInputByGraph(const KernelGraphPtr &graph,
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_u_inputs) {
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_monad_inputs) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(cnode_to_u_inputs);
MS_EXCEPTION_IF_NULL(cnode_to_monad_inputs);
for (const auto &kernel : graph->execution_order()) {
MS_EXCEPTION_IF_NULL(kernel);
GetAllUInputByCNode(kernel, cnode_to_u_inputs);
GetAllUInputByCNode(kernel, cnode_to_monad_inputs);
}
}
// Check if the first input of update state should be linked, if the other inputs of update state has depend the first
// input, it would not be linked.
bool IsNeedLinkForFirstInput(const CNodePtr &cnode,
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_u_inputs) {
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_monad_inputs) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() <= kUpdateStateStateInput) {
MS_LOG(EXCEPTION) << "Invalid update state node:" << cnode->DebugString();
@ -1328,8 +1329,8 @@ bool IsNeedLinkForFirstInput(const CNodePtr &cnode,
MS_EXCEPTION_IF_NULL(u_input);
for (size_t i = kUpdateStateRealInput; i < cnode->inputs().size(); ++i) {
MS_EXCEPTION_IF_NULL(cnode->input(i));
const auto &iter = cnode_to_u_inputs.find(cnode->input(i));
if (iter != cnode_to_u_inputs.end() && iter->second.find(u_input) != iter->second.end()) {
const auto &iter = cnode_to_monad_inputs.find(cnode->input(i));
if (iter != cnode_to_monad_inputs.end() && iter->second.find(u_input) != iter->second.end()) {
return false;
}
}
@ -1406,9 +1407,9 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
MS_EXCEPTION_IF_NULL(communication_nodes);
// Collect all the depend updatestate nodes of the kernels for linking control arrow.
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> cnode_to_u_inputs;
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> cnode_to_monad_inputs;
MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " start.";
GetAllCNodeUInputByGraph(graph, &cnode_to_u_inputs);
GetAllCNodeUInputByGraph(graph, &cnode_to_monad_inputs);
MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " end.";
auto &execution_order = graph->execution_order();
@ -1432,7 +1433,7 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
if (SchedulerHelper::HasMonadControl(input_node, graph)) {
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph, graph_compiler_info.control_node_parser_,
cnode_to_u_inputs);
cnode_to_monad_inputs);
}
if (HasAbstractMonad(input_node)) {
(void)auto_monad_actors->emplace_back(kernel_actor);
@ -1693,7 +1694,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
void GraphScheduler::LinkControlArrowByAutoMonad(
AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph, const ControlNodeParserPtr &parser,
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_u_inputs) {
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_monad_inputs) {
MS_EXCEPTION_IF_NULL(to_actor);
MS_EXCEPTION_IF_NULL(from_node);
MS_EXCEPTION_IF_NULL(graph);
@ -1712,7 +1713,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(
if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(input_cnode);
for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser, cnode_to_u_inputs);
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser, cnode_to_monad_inputs);
}
return;
}
@ -1732,7 +1733,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(
real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
} else if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
if (IsNeedLinkForFirstInput(input_cnode, cnode_to_u_inputs) &&
if (IsNeedLinkForFirstInput(input_cnode, cnode_to_monad_inputs) &&
input_cnode->inputs().size() > kUpdateStateStateInput) {
// If all other inputs of the update state do not depend on the first input, we need to link control arrow
// for the first input.
@ -1780,7 +1781,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(
// The monad node and make tuple node need recursion.
if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser, cnode_to_u_inputs);
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser, cnode_to_monad_inputs);
continue;
}

View File

@ -168,12 +168,13 @@ class BACKEND_EXPORT GraphScheduler {
const KernelWithIndex &to_kernel_with_input_idx);
// 2. The processing of linking control arrows.
// The parameter cnode_to_u_inputs contains all the update states that each cnode in the graph depends on. When
// processing the first input of update state, the map is used to check whether it is necessary to link control
// arrow for the first input of update state.
void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
const ControlNodeParserPtr &parser = nullptr,
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_u_inputs = {});
// The parameter cnode_to_monad_inputs contains all the update states that each cnode in the graph depends on. When
// processing the first input of update state, the map is used to check whether it is necessary to link control arrow
// for the first input of update state.
void LinkControlArrowByAutoMonad(
AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
const ControlNodeParserPtr &parser = nullptr,
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_monad_inputs = {});
// The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node,
const KernelGraphPtr &graph) const;