From 1e8997f4cb55b0808f76f5c8bdbd1c763b2b16f7 Mon Sep 17 00:00:00 2001 From: kswang Date: Thu, 23 Apr 2020 20:56:42 +0800 Subject: [PATCH] optimize sort for mem reuse and fix memreuse bug --- .../mem_reuse/mem_reuse_allocator.cc | 3 +- .../ccsrc/session/anf_runtime_algorithm.cc | 15 +- .../ccsrc/session/anf_runtime_algorithm.h | 1 - mindspore/ccsrc/session/kernel_graph.cc | 137 +++++++++--------- mindspore/ccsrc/session/kernel_graph.h | 5 +- 5 files changed, 78 insertions(+), 83 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc index 1cecd170d30..8a3647d9801 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse_allocator.cc @@ -251,9 +251,10 @@ void BestFitMemReuse::ReleaseNodeUnusedOutput(const KernelDef *kernel_def_ptr) { } size_t BestFitMemReuse::FindIndx(const std::vector &membuf_ptr_list, int fac_idx) const { - size_t membuf_index = 0; + size_t membuf_index = membuf_ptr_list.size(); for (size_t n = 0; n < membuf_ptr_list.size(); ++n) { auto membuf = membuf_ptr_list[n]; + MS_EXCEPTION_IF_NULL(membuf); if (membuf->index_ == fac_idx) { membuf_index = n; break; diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 525ff44dd8b..3d5be5298ae 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -851,17 +851,12 @@ void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &i bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_name = AnfAlgo::GetCNodeName(node); - auto kernel_type = AnfAlgo::GetKernelType(node); - if (kernel_name == kAllReduceOpName || kernel_type == HCCL_KERNEL) { - return true; + if (!node->isa()) { + return false; } - return false; -} - -bool AnfRuntimeAlgorithm::IsAllReduceOp(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa() && AnfAlgo::GetCNodeName(node) == kAllReduceOpName) { + auto kernel_name = AnfAlgo::GetCNodeName(node); + if (kernel_name == kAllReduceOpName || kernel_name == kAllGatherOpName || kernel_name == kBroadcastOpName || + kernel_name == kReduceScatterOpName) { return true; } return false; diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 78359cdd5a4..a70a63b6786 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -176,7 +176,6 @@ class AnfRuntimeAlgorithm { // get real input index for some tbe ops which input order is different between me and tbe impl static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); static bool IsCommunicationOp(const AnfNodePtr &node); - static bool IsAllReduceOp(const AnfNodePtr &node); static bool IsGetNext(const NotNull &node); }; } // namespace session diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index 139539ccb23..cdadf389a6f 100755 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -49,80 +49,81 @@ std::vector KernelGraph::outputs() const { return std::vector(); } +void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, + std::unordered_set *visited_nodes) { + MS_EXCEPTION_IF_NULL(visit_queue); + MS_EXCEPTION_IF_NULL(visited_nodes); + auto it = node_output_edges_.find(node); + if (it == node_output_edges_.end()) { + // value node and parameter has no input,no need to print log + if (node->isa()) { + MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; + } + return; + } + + // visit all reduce node first, then other nodes + std::vector active_nodes; + for (const auto &output_edge : it->second) { + auto next_node = output_edge.first; + if (node_input_num_.find(next_node) == node_input_num_.end()) { + MS_EXCEPTION_IF_NULL(next_node); + MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; + } + MS_EXCEPTION_IF_NULL(next_node); + MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() + << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; + if (node_input_num_[next_node] < output_edge.second) { + MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] + << ",depend edge:" << output_edge.second; + } + node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; + // allreduce first + if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) { + (void)visited_nodes->insert(next_node); + if (AnfAlgo::IsCommunicationOp(next_node)) { + MS_LOG(DEBUG) << "visit node:" << next_node->DebugString(); + visit_queue->push(next_node); + } else { + active_nodes.emplace_back(next_node); + } + } + } + + for (auto &node : active_nodes) { + MS_LOG(DEBUG) << "visit node:" << node->DebugString(); + visit_queue->push(node); + } +} + void KernelGraph::SetExecOrderByDefault() { - std::stack seed_nodes; + std::queue seed_nodes; UpdateNodeEdgeList(&seed_nodes); execution_order_.clear(); std::unordered_set visited_nodes; std::queue zero_input_nodes; - - auto visit_node_descendant = [&visited_nodes, this](const AnfNodePtr &node, std::queue *visit_queue) { - auto it = node_output_edges_.find(node); - if (it == node_output_edges_.end()) { - // value node and parameter has no input,no need to print log - if (node->isa()) { - MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]"; - } - return; - } - - // visit all reduce node first, then other nodes - std::vector active_nodes; - for (const auto &output_edge : it->second) { - auto next_node = output_edge.first; - if (node_input_num_.find(next_node) == node_input_num_.end()) { - MS_EXCEPTION_IF_NULL(next_node); - MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]"; - } - MS_EXCEPTION_IF_NULL(next_node); - MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() - << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; - if (node_input_num_[next_node] < output_edge.second) { - MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" - << node_input_num_[next_node] << ",depend edge:" << output_edge.second; - } - node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; - // allreduce first - if (node_input_num_[next_node] == 0 && visited_nodes.find(next_node) == visited_nodes.end()) { - (void)visited_nodes.insert(next_node); - if (AnfAlgo::IsAllReduceOp(next_node)) { - MS_LOG(DEBUG) << "visit node:" << next_node->DebugString(); - visit_queue->push(next_node); - } else { - active_nodes.emplace_back(next_node); - } - } - } - - for (auto &node : active_nodes) { - MS_LOG(DEBUG) << "visit node:" << node->DebugString(); - visit_queue->push(node); - } - }; - - AnfNodePtr last_allreduce_node = nullptr; - std::queue allreduce_descendants; - while (!seed_nodes.empty() || last_allreduce_node != nullptr) { + AnfNodePtr last_communication_node = nullptr; + std::queue communication_descendants; + while (!seed_nodes.empty() || last_communication_node != nullptr) { // seed nodes first, then visit last all reduce node descendant if (seed_nodes.empty()) { - visit_node_descendant(last_allreduce_node, &allreduce_descendants); - last_allreduce_node = nullptr; + VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); + last_communication_node = nullptr; } else { - zero_input_nodes.push(seed_nodes.top()); + zero_input_nodes.push(seed_nodes.front()); seed_nodes.pop(); } - // all reduce node descendant first, then common queue - while (!zero_input_nodes.empty() || !allreduce_descendants.empty()) { + while (!zero_input_nodes.empty() || !communication_descendants.empty()) { AnfNodePtr node = nullptr; - bool is_allreduce_descendant = false; - if (allreduce_descendants.empty()) { + bool is_communication_descendant = false; + if (communication_descendants.empty()) { node = zero_input_nodes.front(); zero_input_nodes.pop(); } else { - node = allreduce_descendants.front(); - allreduce_descendants.pop(); - is_allreduce_descendant = true; + node = communication_descendants.front(); + communication_descendants.pop(); + is_communication_descendant = true; } // add execute node MS_EXCEPTION_IF_NULL(node); @@ -130,19 +131,18 @@ void KernelGraph::SetExecOrderByDefault() { execution_order_.push_back(node->cast()); } // for all reduce node, visit last all reduce node descendant - if (AnfAlgo::IsAllReduceOp(node)) { - if (last_allreduce_node != nullptr) { - visit_node_descendant(last_allreduce_node, &allreduce_descendants); + if (AnfAlgo::IsCommunicationOp(node)) { + if (last_communication_node != nullptr) { + VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes); } - last_allreduce_node = node; - } else if (is_allreduce_descendant) { - visit_node_descendant(node, &allreduce_descendants); + last_communication_node = node; + } else if (is_communication_descendant) { + VisitNodeDescendants(node, &communication_descendants, &visited_nodes); } else { - visit_node_descendant(node, &zero_input_nodes); + VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); } } } - CheckLoop(); } @@ -467,7 +467,7 @@ bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue *seed_nodes) { +void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) { node_output_edges_.clear(); node_input_num_.clear(); node_input_edges_.clear(); @@ -483,7 +483,6 @@ void KernelGraph::UpdateNodeEdgeList(std::stack *seed_nodes) { seed_nodes->push(node); continue; } - if (!node->isa()) { continue; } diff --git a/mindspore/ccsrc/session/kernel_graph.h b/mindspore/ccsrc/session/kernel_graph.h index 54b16014a3d..a33e8f7bd6d 100755 --- a/mindspore/ccsrc/session/kernel_graph.h +++ b/mindspore/ccsrc/session/kernel_graph.h @@ -22,7 +22,6 @@ #include #include #include -#include #include #include #include "ir/func_graph.h" @@ -94,8 +93,10 @@ class KernelGraph : public FuncGraph { private: // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); + void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, + std::unordered_set *visited_nodes); // update node edge list - void UpdateNodeEdgeList(std::stack *seed_nodes); + void UpdateNodeEdgeList(std::queue *seed_nodes); // add node depend edge by data edge or control depend void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); // handle control depend