!45210 Fix addr check in subgraph zero copy.
Merge pull request !45210 from gaoyong10/dynamic_shape_05
This commit is contained in:
commit
f4199c045e
|
@ -614,7 +614,16 @@ bool RtModelZeroCopy::UpdateTaskArgs(const session::KernelGraph &graph, void *st
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Check rtMode valid " << ((rtStreamSynchronize(stream) == RT_ERROR_NONE) && CheckRtModelValid(graph));
|
if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(WARNING) << "Sync stream for graph:" << graph.ToString() << " failed.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the zero copy in graph mode is enabled, the input and output addr in task may not be same as addr in graph,
|
||||||
|
// so skip the addr check.
|
||||||
|
if (!graph.has_flag(kFlagEnableZeroCopyInGraph)) {
|
||||||
|
MS_LOG(INFO) << "Check rtMode valid " << (CheckRtModelValid(graph));
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -144,7 +144,7 @@ void ExitActor::IncreaseDynamicRefCounts(OpContext<DeviceTensor> *const context)
|
||||||
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
|
for (size_t i = 0; i < input_device_tensors_.size(); ++i) {
|
||||||
if ((input_device_tensors_[i] != nullptr) && (input_device_tensors_[i]->dynamic_ref_count() == 0) &&
|
if ((input_device_tensors_[i] != nullptr) && (input_device_tensors_[i]->dynamic_ref_count() == 0) &&
|
||||||
(device_contexts_[i] != nullptr)) {
|
(device_contexts_[i] != nullptr)) {
|
||||||
MS_LOG(WARNING) << GetAID().Name() << " input index:" << i << " has no user and free the memory.";
|
MS_LOG(INFO) << GetAID().Name() << " input index:" << i << " has no user and free the memory.";
|
||||||
device_contexts_[i]->device_res_manager_->FreeMemory(input_device_tensors_[i]);
|
device_contexts_[i]->device_res_manager_->FreeMemory(input_device_tensors_[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1284,13 +1284,13 @@ KernelActorPtr GraphScheduler::GenerateRpcActor(const CNodePtr &kernel, const De
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void GetAllUInputByCNode(const CNodePtr &cnode,
|
void GetAllUInputByCNode(const CNodePtr &cnode,
|
||||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_u_inputs) {
|
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_monad_inputs) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
MS_EXCEPTION_IF_NULL(cnode_to_u_inputs);
|
MS_EXCEPTION_IF_NULL(cnode_to_monad_inputs);
|
||||||
if (cnode_to_u_inputs->find(cnode) != cnode_to_u_inputs->end()) {
|
if (cnode_to_monad_inputs->find(cnode) != cnode_to_monad_inputs->end()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
(*cnode_to_u_inputs)[cnode] = {};
|
(*cnode_to_monad_inputs)[cnode] = {};
|
||||||
for (const auto &input : cnode->inputs()) {
|
for (const auto &input : cnode->inputs()) {
|
||||||
MS_EXCEPTION_IF_NULL(input);
|
MS_EXCEPTION_IF_NULL(input);
|
||||||
if (!input->isa<CNode>()) {
|
if (!input->isa<CNode>()) {
|
||||||
|
@ -1299,27 +1299,28 @@ void GetAllUInputByCNode(const CNodePtr &cnode,
|
||||||
const auto &cinput = input->cast<CNodePtr>();
|
const auto &cinput = input->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cinput);
|
MS_EXCEPTION_IF_NULL(cinput);
|
||||||
if (common::AnfAlgo::GetCNodeName(cinput) == kUpdateStateOpName) {
|
if (common::AnfAlgo::GetCNodeName(cinput) == kUpdateStateOpName) {
|
||||||
(*cnode_to_u_inputs)[cnode].emplace(cinput);
|
(*cnode_to_monad_inputs)[cnode].emplace(cinput);
|
||||||
}
|
}
|
||||||
GetAllUInputByCNode(cinput, cnode_to_u_inputs);
|
GetAllUInputByCNode(cinput, cnode_to_monad_inputs);
|
||||||
(*cnode_to_u_inputs)[cnode].insert((*cnode_to_u_inputs)[cinput].begin(), (*cnode_to_u_inputs)[cinput].end());
|
(*cnode_to_monad_inputs)[cnode].insert((*cnode_to_monad_inputs)[cinput].begin(),
|
||||||
|
(*cnode_to_monad_inputs)[cinput].end());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GetAllCNodeUInputByGraph(const KernelGraphPtr &graph,
|
void GetAllCNodeUInputByGraph(const KernelGraphPtr &graph,
|
||||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_u_inputs) {
|
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *cnode_to_monad_inputs) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
MS_EXCEPTION_IF_NULL(cnode_to_u_inputs);
|
MS_EXCEPTION_IF_NULL(cnode_to_monad_inputs);
|
||||||
for (const auto &kernel : graph->execution_order()) {
|
for (const auto &kernel : graph->execution_order()) {
|
||||||
MS_EXCEPTION_IF_NULL(kernel);
|
MS_EXCEPTION_IF_NULL(kernel);
|
||||||
GetAllUInputByCNode(kernel, cnode_to_u_inputs);
|
GetAllUInputByCNode(kernel, cnode_to_monad_inputs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the first input of update state should be linked, if the other inputs of update state has depend the first
|
// Check if the first input of update state should be linked, if the other inputs of update state has depend the first
|
||||||
// input, it would not be linked.
|
// input, it would not be linked.
|
||||||
bool IsNeedLinkForFirstInput(const CNodePtr &cnode,
|
bool IsNeedLinkForFirstInput(const CNodePtr &cnode,
|
||||||
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_u_inputs) {
|
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_monad_inputs) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (cnode->inputs().size() <= kUpdateStateStateInput) {
|
if (cnode->inputs().size() <= kUpdateStateStateInput) {
|
||||||
MS_LOG(EXCEPTION) << "Invalid update state node:" << cnode->DebugString();
|
MS_LOG(EXCEPTION) << "Invalid update state node:" << cnode->DebugString();
|
||||||
|
@ -1328,8 +1329,8 @@ bool IsNeedLinkForFirstInput(const CNodePtr &cnode,
|
||||||
MS_EXCEPTION_IF_NULL(u_input);
|
MS_EXCEPTION_IF_NULL(u_input);
|
||||||
for (size_t i = kUpdateStateRealInput; i < cnode->inputs().size(); ++i) {
|
for (size_t i = kUpdateStateRealInput; i < cnode->inputs().size(); ++i) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode->input(i));
|
MS_EXCEPTION_IF_NULL(cnode->input(i));
|
||||||
const auto &iter = cnode_to_u_inputs.find(cnode->input(i));
|
const auto &iter = cnode_to_monad_inputs.find(cnode->input(i));
|
||||||
if (iter != cnode_to_u_inputs.end() && iter->second.find(u_input) != iter->second.end()) {
|
if (iter != cnode_to_monad_inputs.end() && iter->second.find(u_input) != iter->second.end()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1406,9 +1407,9 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
|
||||||
MS_EXCEPTION_IF_NULL(communication_nodes);
|
MS_EXCEPTION_IF_NULL(communication_nodes);
|
||||||
|
|
||||||
// Collect all the depend updatestate nodes of the kernels for linking control arrow.
|
// Collect all the depend updatestate nodes of the kernels for linking control arrow.
|
||||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> cnode_to_u_inputs;
|
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> cnode_to_monad_inputs;
|
||||||
MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " start.";
|
MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " start.";
|
||||||
GetAllCNodeUInputByGraph(graph, &cnode_to_u_inputs);
|
GetAllCNodeUInputByGraph(graph, &cnode_to_monad_inputs);
|
||||||
MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " end.";
|
MS_LOG(INFO) << "Get all u input of cnode in graph:" << graph->ToString() << " end.";
|
||||||
|
|
||||||
auto &execution_order = graph->execution_order();
|
auto &execution_order = graph->execution_order();
|
||||||
|
@ -1432,7 +1433,7 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
|
||||||
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
|
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
|
||||||
if (SchedulerHelper::HasMonadControl(input_node, graph)) {
|
if (SchedulerHelper::HasMonadControl(input_node, graph)) {
|
||||||
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph, graph_compiler_info.control_node_parser_,
|
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph, graph_compiler_info.control_node_parser_,
|
||||||
cnode_to_u_inputs);
|
cnode_to_monad_inputs);
|
||||||
}
|
}
|
||||||
if (HasAbstractMonad(input_node)) {
|
if (HasAbstractMonad(input_node)) {
|
||||||
(void)auto_monad_actors->emplace_back(kernel_actor);
|
(void)auto_monad_actors->emplace_back(kernel_actor);
|
||||||
|
@ -1693,7 +1694,7 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
|
||||||
|
|
||||||
void GraphScheduler::LinkControlArrowByAutoMonad(
|
void GraphScheduler::LinkControlArrowByAutoMonad(
|
||||||
AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph, const ControlNodeParserPtr &parser,
|
AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph, const ControlNodeParserPtr &parser,
|
||||||
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_u_inputs) {
|
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_monad_inputs) {
|
||||||
MS_EXCEPTION_IF_NULL(to_actor);
|
MS_EXCEPTION_IF_NULL(to_actor);
|
||||||
MS_EXCEPTION_IF_NULL(from_node);
|
MS_EXCEPTION_IF_NULL(from_node);
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
@ -1712,7 +1713,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(
|
||||||
if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
|
if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
|
||||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||||
for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
|
for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
|
||||||
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser, cnode_to_u_inputs);
|
LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph, parser, cnode_to_monad_inputs);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1732,7 +1733,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(
|
||||||
real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
|
real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
|
||||||
real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
|
real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
|
||||||
} else if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
|
} else if (common::AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
|
||||||
if (IsNeedLinkForFirstInput(input_cnode, cnode_to_u_inputs) &&
|
if (IsNeedLinkForFirstInput(input_cnode, cnode_to_monad_inputs) &&
|
||||||
input_cnode->inputs().size() > kUpdateStateStateInput) {
|
input_cnode->inputs().size() > kUpdateStateStateInput) {
|
||||||
// If all other inputs of the update state do not depend on the first input, we need to link control arrow
|
// If all other inputs of the update state do not depend on the first input, we need to link control arrow
|
||||||
// for the first input.
|
// for the first input.
|
||||||
|
@ -1780,7 +1781,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(
|
||||||
|
|
||||||
// The monad node and make tuple node need recursion.
|
// The monad node and make tuple node need recursion.
|
||||||
if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
|
if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
|
||||||
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser, cnode_to_u_inputs);
|
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph, parser, cnode_to_monad_inputs);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -168,12 +168,13 @@ class BACKEND_EXPORT GraphScheduler {
|
||||||
const KernelWithIndex &to_kernel_with_input_idx);
|
const KernelWithIndex &to_kernel_with_input_idx);
|
||||||
|
|
||||||
// 2. The processing of linking control arrows.
|
// 2. The processing of linking control arrows.
|
||||||
// The parameter cnode_to_u_inputs contains all the update states that each cnode in the graph depends on. When
|
// The parameter cnode_to_monad_inputs contains all the update states that each cnode in the graph depends on. When
|
||||||
// processing the first input of update state, the map is used to check whether it is necessary to link control
|
// processing the first input of update state, the map is used to check whether it is necessary to link control arrow
|
||||||
// arrow for the first input of update state.
|
// for the first input of update state.
|
||||||
void LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
|
void LinkControlArrowByAutoMonad(
|
||||||
const ControlNodeParserPtr &parser = nullptr,
|
AbstractActor *to_actor, const AnfNodePtr &from_node, const KernelGraphPtr &graph,
|
||||||
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_u_inputs = {});
|
const ControlNodeParserPtr &parser = nullptr,
|
||||||
|
const mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> &cnode_to_monad_inputs = {});
|
||||||
// The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
|
// The skipped node doesn't run, so need link the control arrow between the inputs and user of skipped node.
|
||||||
void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node,
|
void LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node,
|
||||||
const KernelGraphPtr &graph) const;
|
const KernelGraphPtr &graph) const;
|
||||||
|
|
Loading…
Reference in New Issue