From f23b92af4c14c0eef63c5eec34a2ec81df1ab054 Mon Sep 17 00:00:00 2001 From: He Wei Date: Thu, 18 Feb 2021 16:52:48 +0800 Subject: [PATCH] [auto-monad] Optimize order list Use OrderedSet for order list to optimize performance. --- .gitignore | 3 + .../jit/static_analysis/auto_monad.cc | 7 +-- mindspore/core/ir/func_graph.cc | 18 +++--- mindspore/core/ir/func_graph.h | 8 +-- mindspore/core/utils/ordered_set.h | 56 +++++++++++++++---- 5 files changed, 62 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 2258c328f2d..9a0fc864723 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ mindspore/lib output *.ir st_tests +kernel_meta/ +somas_meta/ +trace_code_graph_* # mindspore lite java mindspore/lite/java/java/.gradle diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc index 4c3eae03629..53b34c38625 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc @@ -323,7 +323,7 @@ class SideEffectFinder { } static void UpdateOrderList(const FuncGraphPtr &func_graph) { - std::list new_order_list; + OrderedSet new_order_list; const auto &order_list = func_graph->order_list(); for (auto &cnode : order_list) { PushToOrderList(func_graph, cnode, &new_order_list); @@ -331,10 +331,9 @@ class SideEffectFinder { func_graph->set_order_list(std::move(new_order_list)); } - static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, std::list *new_order_list) { + static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, OrderedSet *new_order_list) { MS_EXCEPTION_IF_NULL(cnode); - auto iter = std::find(new_order_list->begin(), new_order_list->end(), cnode); - if (iter != new_order_list->end()) { + if (new_order_list->contains(cnode)) { return; } for (auto &input : cnode->inputs()) { diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index d77553aef89..fe3634ef4b6 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -136,23 +136,21 @@ CNodePtr FuncGraph::NewCNodeInFront(const std::vector &inputs) { CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector &inputs) { CNodePtr cnode = NewCNode(inputs); - auto iter = std::find(order_.begin(), order_.end(), position); + CNodePtr pos_cnode = dyn_cast(position); + auto iter = order_.find(pos_cnode); order_.insert(iter, cnode); return cnode; } CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector &inputs) { CNodePtr cnode = NewCNode(inputs); - if (!position->isa()) { - order_.push_front(cnode); - return cnode; - } - auto iter = std::find(order_.begin(), order_.end(), position); + CNodePtr pos_cnode = dyn_cast(position); + auto iter = order_.find(pos_cnode); if (iter == order_.end()) { order_.push_front(cnode); - return cnode; + } else { + order_.insert(std::next(iter), cnode); } - order_.insert(std::next(iter), cnode); return cnode; } @@ -616,7 +614,7 @@ void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) { if (node) { auto cnode = node->cast(); if (cnode) { - order_.remove(cnode); + order_.erase(cnode); MS_LOG(DEBUG) << "Remove the node" << node->DebugString() << " from order list."; } } @@ -636,7 +634,7 @@ void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new return; } // Search old node in order list. - auto iter = std::find(order_.begin(), order_.end(), old_cnode); + auto iter = order_.find(old_cnode); if (iter == order_.end()) { // Skip if old node not found in order list. return; diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 1ce01e2cc1f..9d70155c157 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -359,15 +359,15 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { void EraseUnusedNodeInOrder(const AnfNodePtr &n); void EraseUnusedNodeInOrder(); void DumpCNodeList(); - const std::list &order_list() const { return order_; } + const OrderedSet &order_list() const { return order_; } - void set_order_list(std::list &&order_list) { order_ = std::move(order_list); } + void set_order_list(OrderedSet &&order_list) { order_ = std::move(order_list); } // Add a cnode at the end of order list. void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); } // Prepend cnode at the front of order list. - void PrependOrderList(const CNodePtr &cnode) { order_.insert(order_.begin(), cnode); } + void PrependOrderList(const CNodePtr &cnode) { order_.push_front(cnode); } // Maintain cnode order list when a cnode is replaced by a new one. void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node); @@ -461,7 +461,7 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { const std::vector &kwarg_values_tuple_nodes); // CNode order which relates to origin code order - std::list order_; + OrderedSet order_; bool stub_; inline static Drawer drawer_ = nullptr; // Design switch_layer_input as a ptr to diff --git a/mindspore/core/utils/ordered_set.h b/mindspore/core/utils/ordered_set.h index 3f10acd2acc..4ca82d8e917 100644 --- a/mindspore/core/utils/ordered_set.h +++ b/mindspore/core/utils/ordered_set.h @@ -58,6 +58,8 @@ class OrderedSet { } } + OrderedSet(OrderedSet &&os) = default; + explicit OrderedSet(const sequential_type &other) { for (auto &item : other) { add(item); @@ -80,23 +82,27 @@ class OrderedSet { return *this; } - // Add an element to the OrderedSet, without judging return value - void add(const element_type &e) { (void)insert(e); } + OrderedSet &operator=(OrderedSet &&os) = default; - // insert an element to the OrderedSet - std::pair insert(const element_type &e) { - iterator empty_itr; - std::pair map_pair = std::make_pair(e, empty_itr); - auto result = mapped_data_.insert(map_pair); - auto &seq_idx = result.first->second; - // if insert success; + // insert an element to the OrderedSet after the given position. + std::pair insert(iterator pos, const element_type &e) { + auto result = mapped_data_.emplace(e, ordered_data_.end()); if (result.second) { - auto it = ordered_data_.insert(ordered_data_.end(), e); - seq_idx = it; + result.first->second = ordered_data_.emplace(pos, e); } - return std::pair(seq_idx, result.second); + return {result.first->second, result.second}; } + // Add an element to the OrderedSet, without judging return value + void add(const element_type &e) { (void)insert(ordered_data_.end(), e); } + + // insert an element to the end of OrderedSet. + std::pair insert(const element_type &e) { return insert(ordered_data_.end(), e); } + + void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); } + + void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); } + // Remove an element, if removed return true, otherwise return false bool erase(const element_type &e) { auto pos = mapped_data_.find(e); @@ -109,6 +115,16 @@ class OrderedSet { return true; } + iterator erase(iterator pos) { + (void)mapped_data_.erase(*pos); + return ordered_data_.erase(pos); + } + + iterator erase(const_iterator pos) { + (void)mapped_data_.erase(*pos); + return ordered_data_.erase(pos); + } + // Return the container size std::size_t size() const { return mapped_data_.size(); } @@ -267,6 +283,22 @@ class OrderedSet { bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } + const_iterator find(const element_type &e) const { + auto iter = mapped_data_.find(e); + if (iter == mapped_data_.end()) { + return ordered_data_.end(); + } + return iter->second; + } + + iterator find(const element_type &e) { + auto iter = mapped_data_.find(e); + if (iter == mapped_data_.end()) { + return ordered_data_.end(); + } + return iter->second; + } + // Return the count of an element in set std::size_t count(const element_type &e) const { return mapped_data_.count(e); }