From 044feecbd750e1ce6cc3e4030d55cf4998d1aee4 Mon Sep 17 00:00:00 2001
From: gaoyong10
Date: Mon, 31 Oct 2022 14:52:22 +0800
Subject: [PATCH] Zero copy for the atomic addr clean op. Decouple formal and
 real parameters. Do not skip nop nodes that are graph outputs.

---
 .../common/session/kernel_graph_mgr.cc        | 28 ++++++-
 .../backend/common/session/kernel_graph_mgr.h |  5 +-
 .../hal/device/tasksink/rtmodel_zero_copy.cc  | 77 +++++++++++++++----
 .../runtime/graph_scheduler/graph_compiler.cc | 12 +--
 4 files changed, 95 insertions(+), 27 deletions(-)

diff --git a/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc b/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc
index 704493be588..6b203b61986 100644
--- a/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc
+++ b/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.cc
@@ -252,7 +252,11 @@ void KernelGraphMgr::InitInternalOutputParameter(const AnfNodePtr &out_node, con
   builder.SetOutputsDeviceType({type});
   builder.SetOutputsFormat({format});
   d_kernel_info->set_select_kernel_build_info(builder.Build());
-  AnfAlgo::SetOutputAddr(address, 0, parameter.get());
+  // If the flag is enabled, the graph will run in subgraph sink mode, so the internal parameter cannot share
+  // the same device address.
+  if (!node_graph->has_flag(kFlagEnableZeroCopyInGraph)) {
+    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
+  }
   auto abstract = std::make_shared(TypeIdToType(type), parameter->Shape()->cast());
   parameter->set_abstract(abstract);
@@ -873,10 +877,16 @@ void KernelGraphMgr::AddParameterToGraphInputs(const std::vector &pa
 }
 
 KernelGraphPtr KernelGraphMgr::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
-                                                    DeviceType device_target, bool common_opt) {
+                                                    DeviceType device_target, bool common_opt,
+                                                    bool is_enable_zero_copy) {
   mindspore::HashMap<AnfNodePtr, AnfNodePtr> other_graph_cnode;
   auto graph = NewKernelGraph();
   MS_EXCEPTION_IF_NULL(graph);
+  // Set the zero copy flag in subgraph sink mode.
+  if (is_enable_zero_copy) {
+    MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
+    graph->set_flag(kFlagEnableZeroCopyInGraph, true);
+  }
   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
   for (const auto &node : lst) {
     MS_EXCEPTION_IF_NULL(node);
@@ -1146,6 +1156,17 @@ std::string KernelGraphMgr::AddPartialParametersMap(const AnfNodePtr &partial_no
   return graph_target;
 }
 
+namespace {
+bool IsNeedAddPartialParameter(const AnfNodePtr &user, const std::string &kernel_target,
+                               const std::shared_ptr<KernelGraph> &graph) {
+  // If the flag is enabled, the graph will run in subgraph sink mode, so the real parameter on the partial
+  // cannot share the same device address with the formal parameter.
+  MS_EXCEPTION_IF_NULL(graph);
+  return common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
+         !ExistGraphCaller(user) && (!graph->has_flag(kFlagEnableZeroCopyInGraph));
+}
+}  // namespace
+
 void KernelGraphMgr::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
                                           const FuncGraphManagerPtr &front_func_graph_manager,
                                           const std::shared_ptr<KernelGraph> &backend_graph) {
@@ -1176,8 +1197,7 @@
   if (internal_output) {
     auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
     for (auto &user : users) {
-      if (common::AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
-          !ExistGraphCaller(user)) {
+      if (IsNeedAddPartialParameter(user, kernel_target, backend_graph)) {
         auto partial_target = AddPartialParametersMap(user);
         if (partial_target != kNoTarget && partial_target != kernel_target) {
           unique_target = false;
diff --git a/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.h b/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.h
index d68a5a1bb97..f660b7fd321 100644
--- a/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.h
+++ b/mindspore/ccsrc/backend/common/session/kernel_graph_mgr.h
@@ -46,9 +46,12 @@ class BACKEND_EXPORT KernelGraphMgr {
   KernelGraphMgr() {}
   virtual ~KernelGraphMgr() {}
 
+  // The parameter is_enable_zero_copy indicates whether the parameters in the graph can avoid copies when the graph
+  // is executed. It is true in subgraph sink mode, where device address sharing for partial parameters and internal
+  // parameters in the graph is disabled.
   std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
                                                     DeviceType device_target = DeviceType::kUnknown,
-                                                    bool common_opt = true);
+                                                    bool common_opt = true, bool is_enable_zero_copy = false);
 
   std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                     std::vector<KernelGraphPtr> *all_out_graph,
diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc b/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc
index e5336673e52..8c3a6f1176a 100644
--- a/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc
+++ b/mindspore/ccsrc/plugin/device/ascend/hal/device/tasksink/rtmodel_zero_copy.cc
@@ -335,33 +335,83 @@ bool ZeroCopyTask::UpdateArgs(void *stream) {
 }
 
 namespace {
-void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, const session::KernelGraph &graph,
+std::vector<KernelWithIndex> GetInputNodeWithIndex(const CNodePtr &node, const TaskPtr &task,
+                                                   const std::vector<KernelWithIndex> &output_with_indexs,
+                                                   std::set<std::pair<AnfNodePtr, size_t>> *node_to_offset) {
+  std::vector<KernelWithIndex> input_node_with_indexs;
+  auto input_num = common::AnfAlgo::GetInputTensorNum(node);
+  if (common::AnfAlgo::GetCNodeName(node) == kAtomicAddrCleanOpName) {
+    // For the atomic addr clean op, the args in the task are not the inputs of the kernel; the real output indexes
+    // should be fetched from the input node.
+    for (size_t i = 0; i < input_num; ++i) {
+      const auto &input = node->input(i + 1);
+      MS_EXCEPTION_IF_NULL(input);
+      if (input->isa<CNode>() && common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, input->cast<CNodePtr>())) {
+        auto clean_output_indexs = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(input, kAttrAtomicOutputIndexs);
+        for (auto index : clean_output_indexs) {
+          MS_LOG(DEBUG) << "atomic addr clean index:" << index << " for node:" << input->fullname_with_scope();
+          input_node_with_indexs.emplace_back(input, index);
+        }
+      }
+    }
+    if (input_node_with_indexs.size() != (task->ArgsSize() / sizeof(void *))) {
+      MS_LOG(ERROR) << "Invalid input size:" << input_node_with_indexs.size()
+                    << " task size:" << (task->ArgsSize() / sizeof(void *)) << " for node:" << node->DebugString();
+    }
+  } else {
+    for (size_t i = 0; i < input_num; ++i) {
+      if (node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
+        input_node_with_indexs.emplace_back(nullptr, i);
+        continue;
+      }
+
+      size_t input_index_in_graph = AnfAlgo::GetInputGraphIdxByKernelIdx(node, i);
+      KernelWithIndex input_with_index{node, input_index_in_graph};
+      do {
+        input_with_index = common::AnfAlgo::GetPrevNodeOutput(input_with_index.first, input_with_index.second, false);
+        if (std::find_if(output_with_indexs.begin(), output_with_indexs.end(),
+                         [input_with_index](const KernelWithIndex &output) {
+                           const auto &real_output = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
+                           return real_output == input_with_index;
+                         }) != output_with_indexs.end()) {
+          break;
+        }
+      } while (input_with_index.first != nullptr && common::AnfAlgo::IsNopNode(input_with_index.first));
+      MS_LOG(DEBUG) << "Add input node:" << input_with_index.first->fullname_with_scope()
+                    << " index:" << input_with_index.second << " for node:" << node->fullname_with_scope();
+      input_node_with_indexs.emplace_back(input_with_index);
+    }
+  }
+  return input_node_with_indexs;
+}
+
+void GenerateZeroCopyTaskForInput(const CNodePtr &node, const TaskPtr &task, const session::KernelGraph &graph,
                                   std::vector<ZeroCopyTaskPtr> *zero_copy_tasks,
                                   std::set<std::pair<AnfNodePtr, size_t>> *node_to_offset) {
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(zero_copy_tasks);
   MS_EXCEPTION_IF_NULL(node_to_offset);
 
-  auto input_num = common::AnfAlgo::GetInputTensorNum(node);
   const auto &output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph.output());
   const auto &ref_node_map = graph.GetRefMap();
-  for (size_t i = 0; i < input_num; ++i) {
-    if (node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
+  std::vector<KernelWithIndex> input_node_with_indexs =
+    GetInputNodeWithIndex(node, task, output_with_indexs, node_to_offset);
+
+  for (size_t i = 0; i < input_node_with_indexs.size(); ++i) {
+    KernelWithIndex input_with_index = input_node_with_indexs[i];
+    const auto input = input_with_index.first;
+    if (input == nullptr || node_to_offset->find(std::make_pair(node, i)) != node_to_offset->end()) {
       continue;
     }
-    size_t input_index_in_graph = AnfAlgo::GetInputGraphIdxByKernelIdx(node, i);
-    const auto &input_with_index = common::AnfAlgo::GetPrevNodeOutput(node, input_index_in_graph, true);
-    const auto input = input_with_index.first;
-    MS_EXCEPTION_IF_NULL(input);
 
     if (input->isa<Parameter>()) {
       // 1. Input parameter.
      zero_copy_tasks->emplace_back(
        std::make_shared<ParameterZeroCopyTask>(input, task->Args(), i * sizeof(void *), task->task_name()));
      node_to_offset->emplace(node, i);
      MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                    << " ptr from parameter input:" << input->DebugString();
+                    << " ptr from parameter input:" << input->fullname_with_scope();
     } else if (input->isa<CNode>()) {
       // 2. Input which is graph output.
       if (std::find_if(output_with_indexs.begin(), output_with_indexs.end(),
@@ -373,7 +423,8 @@
            input, input_with_index.second, task->Args(), i * sizeof(void *), task->task_name()));
         node_to_offset->emplace(node, i);
         MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                      << " ptr from cnode input:" << input->DebugString() << " cnode index:" << input_with_index.second;
+                      << " ptr from cnode input:" << input->fullname_with_scope()
+                      << " cnode index:" << input_with_index.second;
       } else {
         // 3. Input which is a ref node whose input is a parameter, like:
         // refnode(parameter, node1)
@@ -385,7 +436,7 @@ void GenerateZeroCopyTaskForInput(const AnfNodePtr &node, const TaskPtr &task, c
         zero_copy_tasks->emplace_back(std::make_shared<ParameterZeroCopyTask>(
           parameter, task->Args(), i * sizeof(void *), task->task_name()));
         MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                      << " ptr from parameter input:" << parameter->DebugString();
+                      << " ptr from parameter input:" << parameter->fullname_with_scope();
         node_to_offset->emplace(node, i);
       }
     }
@@ -427,7 +478,7 @@ void GenerateZeroCopyTaskForOutput(const AnfNodePtr &node, const TaskPtr &task,
          std::make_shared<CNodeZeroCopyTask>(ref_iter->second.first, ref_iter->second.second, task->Args(),
                                              input_index * sizeof(void *), task->task_name()));
        MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " input index:" << i
-                     << " ptr from cnode input:" << ref_iter->second.first->DebugString()
+                     << " ptr from cnode input:" << ref_iter->second.first->fullname_with_scope()
                      << " cnode index:" << ref_iter->second.second;
        node_to_offset->emplace(node, input_index);
        zero_copy_ref_nodes->emplace(ref_iter->second);
@@ -436,7 +487,7 @@
        zero_copy_tasks->emplace_back(std::make_shared<ParameterZeroCopyTask>(
          ref_iter->second.first, task->Args(), (input_num + i) * sizeof(void *), task->task_name()));
        MS_LOG(DEBUG) << "Add zero copy task for node:" << node->fullname_with_scope() << " output index:" << i
-                     << " ptr from parameter input:" << ref_iter->second.first->DebugString();
+                     << " ptr from parameter input:" << ref_iter->second.first->fullname_with_scope();
        node_to_offset->emplace(node, input_num + i);
      }
    }
diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc
index 1139853d591..ae55067af2e 100644
--- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc
+++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc
@@ -255,12 +255,11 @@ void OptimizeNopNode(KernelGraph *graph) {
   graph->set_execution_order(new_execution_order);
 }
 
-bool SetZeroCopyFlag(const KernelGraphPtr &graph, bool run_in_pynative) {
+bool IsEnableZeroCopy(bool run_in_pynative) {
   if (run_in_pynative) {
     return false;
   }
 
-  MS_EXCEPTION_IF_NULL(graph);
   auto ms_context = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(ms_context);
   bool task_sink = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
@@ -283,8 +282,6 @@
   if (common::GetEnv("DISABLE_ZERO_COPY") == "1") {
     return false;
   }
-  MS_LOG(INFO) << "Set zero copy flag for graph:" << graph->ToString();
-  graph->set_flag(kFlagEnableZeroCopyInGraph, true);
   return true;
 }
 }  // namespace
@@ -299,7 +296,8 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod
   auto nodes = segment->nodes_;
   auto device_terget = device_context->GetDeviceType();
   // Generate kernel graph.
-  KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, device_terget);
+  KernelGraphPtr graph =
+    session_->ConstructKernelGraph(nodes, outputs, device_terget, true, IsEnableZeroCopy(run_in_pynative));
   MS_EXCEPTION_IF_NULL(graph);
   opt::EliminateIllegalDataTypePass(graph);
   SetGraphDependency(graph, segment);
@@ -477,10 +475,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
     DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
   }
 #endif
-  // If the zero copy flag has been set in graph, the relationship between partial and parameter should be disabled.
-  if (SetZeroCopyFlag(graph, run_in_pynative)) {
-    session_->ClearPartialParameterMap();
-  }
   MS_EXCEPTION_IF_NULL(device_context->kernel_executor_);
   // Execute optimization pass.
   device_context->kernel_executor_->OptimizeGraph(graph);
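
The graph_compiler.cc change above turns SetZeroCopyFlag into IsEnableZeroCopy: the decision is made before the kernel graph exists and is passed into ConstructKernelGraph, which tags the graph with kFlagEnableZeroCopyInGraph so that InitInternalOutputParameter and IsNeedAddPartialParameter can already see the flag during construction. The standalone sketch below models only the decision flow visible in the hunks (pynative disables zero copy, task sink is required, DISABLE_ZERO_COPY=1 is an escape hatch); ZeroCopyQuery and its fields are illustrative stand-ins, not MindSpore APIs, and the real function reads MsContext and contains further checks elided from the diff.

#include <cstdlib>
#include <iostream>
#include <string>

// Illustrative stand-in for the values IsEnableZeroCopy() reads from its caller and from MsContext.
// The struct and field names are assumptions made for this sketch.
struct ZeroCopyQuery {
  bool run_in_pynative = false;
  bool task_sink = false;       // mirrors MS_CTX_ENABLE_TASK_SINK
  bool graph_sink_mode = true;  // stand-in for the sink-mode checks elided from the hunk
};

// Mirrors the conditions visible in the diff.
bool IsEnableZeroCopySketch(const ZeroCopyQuery &query) {
  if (query.run_in_pynative) {
    return false;
  }
  if (!query.task_sink || !query.graph_sink_mode) {
    return false;
  }
  const char *disable = std::getenv("DISABLE_ZERO_COPY");
  if (disable != nullptr && std::string(disable) == "1") {
    return false;
  }
  return true;
}

int main() {
  ZeroCopyQuery query;
  query.task_sink = true;
  // The result would be forwarded as the is_enable_zero_copy argument of ConstructKernelGraph.
  std::cout << std::boolalpha << IsEnableZeroCopySketch(query) << std::endl;
  return 0;
}

Setting the flag inside ConstructKernelGraph, rather than after compilation, is what lets the construction-time helpers in the first two hunks skip the device-address sharing for internal and partial parameters.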
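GetInputNodeWithIndex is the core of the "zero copy for the atomic addr clean op" part: for an AtomicAddrClean kernel, each pointer slot in the task args is the address of another kernel's output that has to be zeroed, so the (node, output index) pairs are flattened from the kAttrAtomicOutputIndexs attribute of the inputs instead of being resolved through GetPrevNodeOutput, and the count is checked against ArgsSize() / sizeof(void *). The sketch below models that flattening with stand-in types; FakeNode, atomic_output_indexs and CollectAtomicCleanArgs are hypothetical names used only for illustration.

#include <cstddef>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-ins for AnfNodePtr / KernelWithIndex; the real patch works on MindSpore IR nodes.
struct FakeNode {
  std::string name;
  std::vector<size_t> atomic_output_indexs;  // stand-in for the kAttrAtomicOutputIndexs attribute
};
using NodeWithIndex = std::pair<const FakeNode *, size_t>;

// For an AtomicAddrClean-like kernel, every arg slot of the task is the address of an output
// that must be zeroed, so the (node, output index) pairs are flattened from the inputs'
// attributes rather than resolved as ordinary kernel inputs.
std::vector<NodeWithIndex> CollectAtomicCleanArgs(const std::vector<FakeNode> &inputs,
                                                  size_t args_size_in_bytes) {
  std::vector<NodeWithIndex> result;
  for (const auto &input : inputs) {
    for (size_t index : input.atomic_output_indexs) {
      result.emplace_back(&input, index);
    }
  }
  if (result.size() != args_size_in_bytes / sizeof(void *)) {
    std::cerr << "Invalid input size:" << result.size()
              << " task size:" << (args_size_in_bytes / sizeof(void *)) << std::endl;
  }
  return result;
}

int main() {
  std::vector<FakeNode> inputs = {{"conv0", {0, 1}}, {"matmul1", {0}}};
  auto args = CollectAtomicCleanArgs(inputs, 3 * sizeof(void *));
  for (size_t i = 0; i < args.size(); ++i) {
    std::cout << "arg slot " << i << " <- " << args[i].first->name << " output " << args[i].second << std::endl;
  }
  return 0;
}

Each collected pair can then be matched against graph outputs and the ref map exactly like a normal kernel input, which is what allows the clean op's arg slots to take part in zero copy.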
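The do/while loop in GetInputNodeWithIndex also changes how nop nodes are handled: the third argument of GetPrevNodeOutput goes from true to false, and the walk stops as soon as the current (node, index) pair is itself a graph output, so a nop node that is an output keeps its own address instead of being bypassed (the "do not skip nop nodes that are graph outputs" part of the commit message). The sketch below reproduces that traversal on a toy graph; Node, ResolveInput and the assumption that a nop node forwards its i-th input to its i-th output are simplifications, not MindSpore code.

#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Stand-ins for AnfNodePtr / KernelWithIndex; the real code uses
// common::AnfAlgo::GetPrevNodeOutput and IsNopNode on MindSpore IR.
struct Node {
  std::string name;
  bool is_nop = false;
  // (producer, producer output index) for each input of this node.
  std::vector<std::pair<std::shared_ptr<Node>, size_t>> inputs;
};
using NodePtr = std::shared_ptr<Node>;
using NodeWithIndex = std::pair<NodePtr, size_t>;

// Walk backward through nop nodes, but stop early if the current pair is a graph output:
// an output nop node must keep its own address, so it is not skipped.
NodeWithIndex ResolveInput(const NodePtr &node, size_t input_index,
                           const std::set<std::pair<std::string, size_t>> &graph_outputs) {
  NodeWithIndex cur{node, input_index};
  do {
    // A nop node is assumed to forward its i-th input to its i-th output, so the output
    // index of the current pair selects the next input to follow.
    cur = cur.first->inputs.at(cur.second);
    if (graph_outputs.count({cur.first->name, cur.second}) > 0) {
      break;  // The node is itself a graph output: bind to it instead of skipping it.
    }
  } while (cur.first != nullptr && cur.first->is_nop);
  return cur;
}

int main() {
  auto producer = std::make_shared<Node>(Node{"producer", false, {}});
  auto nop = std::make_shared<Node>(Node{"reshape_nop", true, {{producer, 0}}});
  auto consumer = std::make_shared<Node>(Node{"consumer", false, {{nop, 0}}});

  // When the nop node is itself a graph output, the walk stops at it.
  std::set<std::pair<std::string, size_t>> outputs = {{"reshape_nop", 0}};
  std::cout << ResolveInput(consumer, 0, outputs).first->name << std::endl;  // reshape_nop

  // Otherwise the nop node is skipped and the real producer is used.
  std::cout << ResolveInput(consumer, 0, {}).first->name << std::endl;  // producer
  return 0;
}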