From 1740aac860beb87e1d3bddcfa48692a3739a8ce5 Mon Sep 17 00:00:00 2001 From: zhaosida Date: Wed, 31 Mar 2021 14:50:50 +0800 Subject: [PATCH] fusion atomic clear node --- .../ccsrc/backend/optimizer/somas/somas.cc | 89 ++++++-------- .../device/ascend/kernel_build_ascend.cc | 116 ++++++++++++++++-- .../device/ascend/tasksink/task_generator.cc | 62 +++++----- 3 files changed, 178 insertions(+), 89 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/somas/somas.cc b/mindspore/ccsrc/backend/optimizer/somas/somas.cc index 8c6928e7a75..64c9c62ed48 100644 --- a/mindspore/ccsrc/backend/optimizer/somas/somas.cc +++ b/mindspore/ccsrc/backend/optimizer/somas/somas.cc @@ -567,59 +567,46 @@ void Somas::InitAtomicCleanInputs(bool is_all_nop_node, const CNodePtr &kernel) auto stream = node->GetStream(); MS_EXCEPTION_IF_NULL(stream); - MS_EXCEPTION_IF_NULL(kernel->inputs()[1]); - auto pre_node = (kernel->inputs()[1])->cast(); - auto iter = nodes_map_.find(pre_node.get()); - if (iter == nodes_map_.end()) { - MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" << pre_node->fullname_with_scope() - << "] is not init."; - } - auto pre_somas_node = iter->second; - // set clean output tensors - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); - for (auto index : clean_output_indexs) { - if (index > pre_somas_node->output_tensors_.size()) { - MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope() - << "]'s outputs size " << pre_somas_node->output_tensors_.size(); - } - auto input_somas_tensor = pre_somas_node->output_tensors_[index]; - MS_EXCEPTION_IF_NULL(input_somas_tensor); - node->input_tensors_.push_back(input_somas_tensor); - input_somas_tensor->destinations_.insert(node); - input_somas_tensor->destinationStreams_.insert(stream); - if (input_somas_tensor->lifetime_.start_ > node->GetId()) { - input_somas_tensor->lifetime_.start_ = node->GetId(); - } - node->ancestor_nodes_.insert(pre_somas_node); - auto input_tensor_stream = input_somas_tensor->GetSourceStream(); - if (input_tensor_stream != stream) { - stream->ancestor_streams_.insert(input_tensor_stream); - input_somas_tensor->between_streams_ = true; + auto input_tensor_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_tensor_num; i++) { + MS_EXCEPTION_IF_NULL(kernel->inputs()[i + 1]); + auto pre_node = kernel->input(i + 1)->cast(); + auto iter = nodes_map_.find(pre_node.get()); + if (iter == nodes_map_.end()) { + MS_LOG(EXCEPTION) << "Kernel[" << kernel->fullname_with_scope() << "]'s input [" + << pre_node->fullname_with_scope() << "] is not init."; + } + auto pre_somas_node = iter->second; + // set clean output tensors + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { + auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); + for (auto index : clean_output_indexs) { + if (index > pre_somas_node->output_tensors_.size()) { + MS_LOG(EXCEPTION) << "Output index " << index << " exceed input node [" << pre_node->fullname_with_scope() + << "]'s outputs size " << pre_somas_node->output_tensors_.size(); + } + auto input_somas_tensor = pre_somas_node->output_tensors_[index]; + MS_EXCEPTION_IF_NULL(input_somas_tensor); + node->input_tensors_.push_back(input_somas_tensor); + input_somas_tensor->lifelong_value_ = kLifeLongGraphAll; + MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_ + << " 's output" << index << " to lifelong"; } } - } - // set clean workspace tensors - if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { - auto clean_workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); - for (const auto &index : clean_workspace_indexs) { - if (index > pre_somas_node->output_tensors_.size()) { - MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope() - << "]'s Workspace size " << pre_somas_node->workspace_tensors_.size(); - } - auto input_somas_tensor = pre_somas_node->workspace_tensors_[index]; - MS_EXCEPTION_IF_NULL(input_somas_tensor); - node->input_tensors_.push_back(input_somas_tensor); - input_somas_tensor->destinations_.insert(node); - input_somas_tensor->destinationStreams_.insert(stream); - if (input_somas_tensor->lifetime_.start_ > node->GetId()) { - input_somas_tensor->lifetime_.start_ = node->GetId(); - } - node->ancestor_nodes_.insert(pre_somas_node); - auto input_tensor_stream = input_somas_tensor->GetSourceStream(); - if (input_tensor_stream != stream) { - stream->ancestor_streams_.insert(input_tensor_stream); - input_somas_tensor->between_streams_ = true; + // set clean workspace tensors + if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { + auto clean_workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); + for (const auto &index : clean_workspace_indexs) { + if (index > pre_somas_node->output_tensors_.size()) { + MS_LOG(EXCEPTION) << "Workspace index " << index << " exceed input node [" << pre_node->fullname_with_scope() + << "]'s Workspace size " << pre_somas_node->workspace_tensors_.size(); + } + auto input_somas_tensor = pre_somas_node->workspace_tensors_[index]; + MS_EXCEPTION_IF_NULL(input_somas_tensor); + node->input_tensors_.push_back(input_somas_tensor); + input_somas_tensor->lifelong_value_ = kLifeLongGraphAll; + MS_LOG(INFO) << "Set " << node->scope_full_name_ << "'s Input node " << pre_somas_node->scope_full_name_ + << " 's workspace" << index << " to lifelong"; } } } diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc index 988cdd07c4d..808b1c4bd9b 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc @@ -40,6 +40,8 @@ namespace device { namespace ascend { using mindspore::kernel::tbe::TbeUtils; using std::make_shared; +constexpr size_t kMaxAttrMemListSize = 192; + static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { kernel::KernelModPtr kernel_mod_ptr = nullptr; KernelType kernel_type = AnfAlgo::GetKernelType(anf_node); @@ -159,6 +161,30 @@ static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_gr new_nodes->push_back(clear_zero); } +static void AddFusionTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph, + const mindspore::CNodePtr &stream_node, + const std::vector &fusion_clear_inputs, + const std::vector &clean_size_list, + std::vector *new_nodes) { + auto clear_zero_prim = std::make_shared(kAtomicAddrCleanOpName); + MS_EXCEPTION_IF_NULL(clear_zero_prim); + auto new_value_node = NewValueNode(clear_zero_prim); + MS_EXCEPTION_IF_NULL(new_value_node); + std::vector inputs = {new_value_node}; + inputs.insert(inputs.end(), fusion_clear_inputs.begin(), fusion_clear_inputs.end()); + CNodePtr clear_zero = kernel_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clear_zero); + AbstractBasePtr abstract = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract); + clear_zero->set_abstract(abstract); + auto builder = std::make_shared(); + builder->SetKernelType(KernelType::TBE_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get()); + AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size_list), clear_zero); + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(stream_node.get()), clear_zero.get()); + new_nodes->insert(new_nodes->begin(), clear_zero); +} + static bool IsAtomicNode(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); @@ -264,23 +290,23 @@ std::map> GetCommunicationOpInputInfo( return comm_input_info_map; } -void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); +static void TbeClearZeroNodeFusion(mindspore::session::KernelGraph *const kernel_graph) { std::vector new_nodes; + std::vector clean_size_list; + std::vector fusion_clear_inputs; std::map> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); for (const auto &anf_node : kernel_graph->execution_order()) { std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); bool is_comm_input = false; + // set communication input output index attr if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { auto indexes = comm_input_info_map[anf_node]; AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); is_comm_input = true; } - if (is_comm_input) { - AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); - } else if (apply_function_name == prim::kPrimMaxPoolGrad->name() && - AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { + if (apply_function_name == prim::kPrimMaxPoolGrad->name() && + AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { auto clear_zero_prim = std::make_shared(kClearZeroOpName); MS_EXCEPTION_IF_NULL(clear_zero_prim); auto new_value_node = NewValueNode(clear_zero_prim); @@ -299,15 +325,85 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { // set the distinction label of clear same with anf AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); new_nodes.push_back(clear_zero); - } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { - if (IsAtomicNode(anf_node)) { - AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); + } else if (is_comm_input || + (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL && IsAtomicNode(anf_node))) { + auto clean_sizes = CalCleanZerosSize(anf_node); + if (!clean_sizes.empty()) { + auto clean_total_num = clean_size_list.size() + clean_sizes.size(); + if (clean_total_num >= kMaxAttrMemListSize) { + // create clean node + auto stream_node = new_nodes.empty() ? anf_node : new_nodes.front(); + AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes); + clean_size_list.clear(); + fusion_clear_inputs.clear(); + } + clean_size_list.insert(clean_size_list.end(), clean_sizes.begin(), clean_sizes.end()); + fusion_clear_inputs.emplace_back(anf_node); + MS_LOG(DEBUG) << "fusion_clear_inputs size: " << fusion_clear_inputs.size() + << ", clean_size_list: " << clean_size_list.size(); } } - new_nodes.push_back(anf_node); + new_nodes.emplace_back(anf_node); + } + + if (!fusion_clear_inputs.empty() && !clean_size_list.empty()) { + // create clean node + auto stream_node = new_nodes.front(); + AddFusionTbeClearZeroNode(kernel_graph, stream_node, fusion_clear_inputs, clean_size_list, &new_nodes); } kernel_graph->set_execution_order(new_nodes); } + +void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + static const auto enable_fusion_clear = (common::GetEnv("ENV_FUSION_CLEAR") == "1"); + bool is_dynamic_graph = kernel_graph->is_dynamic_shape(); + if (!is_dynamic_graph && enable_fusion_clear) { + TbeClearZeroNodeFusion(kernel_graph); + } else { + std::vector new_nodes; + std::map> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); + for (const auto &anf_node : kernel_graph->execution_order()) { + std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); + bool is_comm_input = false; + if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { + auto indexes = comm_input_info_map[anf_node]; + AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); + is_comm_input = true; + } + + if (is_comm_input) { + AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); + } else if (apply_function_name == prim::kPrimMaxPoolGrad->name() && + AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { + auto clear_zero_prim = std::make_shared(kClearZeroOpName); + MS_EXCEPTION_IF_NULL(clear_zero_prim); + auto new_value_node = NewValueNode(clear_zero_prim); + MS_EXCEPTION_IF_NULL(new_value_node); + std::vector inputs = {new_value_node}; + inputs.push_back(anf_node); + CNodePtr clear_zero = kernel_graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(clear_zero); + auto kernel_info = std::make_shared(); + MS_EXCEPTION_IF_NULL(kernel_info); + clear_zero->set_kernel_info(kernel_info); + AbstractBasePtr abstract = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract); + AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector({"x"})), clear_zero); + SelectKernelInfo(clear_zero); + // set the distinction label of clear same with anf + AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get()); + new_nodes.push_back(clear_zero); + } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) { + if (IsAtomicNode(anf_node)) { + AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); + } + } + new_nodes.push_back(anf_node); + } + kernel_graph->set_execution_order(new_nodes); + } +} } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc index 660787f877a..944c4550cdc 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc @@ -92,42 +92,48 @@ void TaskGenerator::LaunchAddrCleanAkgKernel(const CNodePtr &anf_node_ptr, Addre void TaskGenerator::LaunchAddrCleanKernel(const CNodePtr &anf_node_ptr, AddressPtrList *kernel_inputs) { MS_EXCEPTION_IF_NULL(anf_node_ptr); MS_EXCEPTION_IF_NULL(kernel_inputs); - if (anf_node_ptr->inputs().size() != 2) { + // akg process + if (AnfAlgo::GetKernelType(anf_node_ptr) == KernelType::AKG_KERNEL) { LaunchAddrCleanAkgKernel(anf_node_ptr, kernel_inputs); return; } - MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[1]); - auto pre_node = (anf_node_ptr->inputs()[1])->cast(); - // set clean output addr - if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { - auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); - for (auto index : clean_output_indexs) { - auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); - kernel::AddressPtr input = std::make_shared(); - MS_EXCEPTION_IF_NULL(input); - input->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(input->addr); - input->size = device_address->size_; - kernel_inputs->push_back(input); + // tbe process + auto input_tensor_num = AnfAlgo::GetInputTensorNum(anf_node_ptr); + for (size_t i = 0; i < input_tensor_num; i++) { + // set clean output addr + MS_EXCEPTION_IF_NULL(anf_node_ptr->inputs()[i + 1]); + auto pre_node = anf_node_ptr->input(i + 1)->cast(); + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) { + auto clean_output_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicOutputIndexs); + for (auto index : clean_output_indexs) { + auto device_address = AnfAlgo::GetOutputAddr(pre_node, index); + kernel::AddressPtr input = std::make_shared(); + MS_EXCEPTION_IF_NULL(input); + input->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(input->addr); + input->size = device_address->size_; + kernel_inputs->push_back(input); + } + MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); } - MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexs.size(); - } - // set clean workspace address - if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { - auto clean_workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); - for (const auto &index : clean_workspace_indexs) { - auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); - kernel::AddressPtr workspace = std::make_shared(); - MS_EXCEPTION_IF_NULL(workspace); - workspace->addr = device_address->ptr_; - MS_EXCEPTION_IF_NULL(workspace->addr); - workspace->size = device_address->size_; - kernel_inputs->push_back(workspace); + // set clean workspace address + if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) { + auto clean_workspace_indexs = AnfAlgo::GetNodeAttr>(pre_node, kAttrAtomicWorkspaceIndexs); + for (const auto &index : clean_workspace_indexs) { + auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index); + kernel::AddressPtr workspace = std::make_shared(); + MS_EXCEPTION_IF_NULL(workspace); + workspace->addr = device_address->ptr_; + MS_EXCEPTION_IF_NULL(workspace->addr); + workspace->size = device_address->size_; + kernel_inputs->push_back(workspace); + } + MS_LOG(DEBUG) << "AtomicAddClean clean workspace size:" << clean_workspace_indexs.size(); } } auto clear_mems = AnfAlgo::GetNodeAttr>(anf_node_ptr, kAttrAtomicAddMemSize); if (kernel_inputs->size() != clear_mems.size()) { - MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size,kerenl_inputs size:" + MS_LOG(EXCEPTION) << "AtomicAddClean kernel inputs size not equal clear memory size, kernel inputs size:" << kernel_inputs->size() << ",clean mem size" << clear_mems.size(); } }