forked from mindspore-Ecosystem/mindspore
[auto-monad] Optimize order list
Use OrderedSet for order list to optimize performance.
This commit is contained in:
parent
adfe6e1bc2
commit
f23b92af4c
|
@ -4,6 +4,9 @@ mindspore/lib
|
||||||
output
|
output
|
||||||
*.ir
|
*.ir
|
||||||
st_tests
|
st_tests
|
||||||
|
kernel_meta/
|
||||||
|
somas_meta/
|
||||||
|
trace_code_graph_*
|
||||||
|
|
||||||
# mindspore lite java
|
# mindspore lite java
|
||||||
mindspore/lite/java/java/.gradle
|
mindspore/lite/java/java/.gradle
|
||||||
|
|
|
@ -323,7 +323,7 @@ class SideEffectFinder {
|
||||||
}
|
}
|
||||||
|
|
||||||
static void UpdateOrderList(const FuncGraphPtr &func_graph) {
|
static void UpdateOrderList(const FuncGraphPtr &func_graph) {
|
||||||
std::list<CNodePtr> new_order_list;
|
OrderedSet<CNodePtr> new_order_list;
|
||||||
const auto &order_list = func_graph->order_list();
|
const auto &order_list = func_graph->order_list();
|
||||||
for (auto &cnode : order_list) {
|
for (auto &cnode : order_list) {
|
||||||
PushToOrderList(func_graph, cnode, &new_order_list);
|
PushToOrderList(func_graph, cnode, &new_order_list);
|
||||||
|
@ -331,10 +331,9 @@ class SideEffectFinder {
|
||||||
func_graph->set_order_list(std::move(new_order_list));
|
func_graph->set_order_list(std::move(new_order_list));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, std::list<CNodePtr> *new_order_list) {
|
static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, OrderedSet<CNodePtr> *new_order_list) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
auto iter = std::find(new_order_list->begin(), new_order_list->end(), cnode);
|
if (new_order_list->contains(cnode)) {
|
||||||
if (iter != new_order_list->end()) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (auto &input : cnode->inputs()) {
|
for (auto &input : cnode->inputs()) {
|
||||||
|
|
|
@ -136,23 +136,21 @@ CNodePtr FuncGraph::NewCNodeInFront(const std::vector<AnfNodePtr> &inputs) {
|
||||||
|
|
||||||
CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
|
CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
|
||||||
CNodePtr cnode = NewCNode(inputs);
|
CNodePtr cnode = NewCNode(inputs);
|
||||||
auto iter = std::find(order_.begin(), order_.end(), position);
|
CNodePtr pos_cnode = dyn_cast<CNode>(position);
|
||||||
|
auto iter = order_.find(pos_cnode);
|
||||||
order_.insert(iter, cnode);
|
order_.insert(iter, cnode);
|
||||||
return cnode;
|
return cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
|
CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector<AnfNodePtr> &inputs) {
|
||||||
CNodePtr cnode = NewCNode(inputs);
|
CNodePtr cnode = NewCNode(inputs);
|
||||||
if (!position->isa<CNode>()) {
|
CNodePtr pos_cnode = dyn_cast<CNode>(position);
|
||||||
order_.push_front(cnode);
|
auto iter = order_.find(pos_cnode);
|
||||||
return cnode;
|
|
||||||
}
|
|
||||||
auto iter = std::find(order_.begin(), order_.end(), position);
|
|
||||||
if (iter == order_.end()) {
|
if (iter == order_.end()) {
|
||||||
order_.push_front(cnode);
|
order_.push_front(cnode);
|
||||||
return cnode;
|
} else {
|
||||||
|
order_.insert(std::next(iter), cnode);
|
||||||
}
|
}
|
||||||
order_.insert(std::next(iter), cnode);
|
|
||||||
return cnode;
|
return cnode;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -616,7 +614,7 @@ void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) {
|
||||||
if (node) {
|
if (node) {
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
if (cnode) {
|
if (cnode) {
|
||||||
order_.remove(cnode);
|
order_.erase(cnode);
|
||||||
MS_LOG(DEBUG) << "Remove the node" << node->DebugString() << " from order list.";
|
MS_LOG(DEBUG) << "Remove the node" << node->DebugString() << " from order list.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -636,7 +634,7 @@ void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// Search old node in order list.
|
// Search old node in order list.
|
||||||
auto iter = std::find(order_.begin(), order_.end(), old_cnode);
|
auto iter = order_.find(old_cnode);
|
||||||
if (iter == order_.end()) {
|
if (iter == order_.end()) {
|
||||||
// Skip if old node not found in order list.
|
// Skip if old node not found in order list.
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -359,15 +359,15 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
||||||
void EraseUnusedNodeInOrder(const AnfNodePtr &n);
|
void EraseUnusedNodeInOrder(const AnfNodePtr &n);
|
||||||
void EraseUnusedNodeInOrder();
|
void EraseUnusedNodeInOrder();
|
||||||
void DumpCNodeList();
|
void DumpCNodeList();
|
||||||
const std::list<CNodePtr> &order_list() const { return order_; }
|
const OrderedSet<CNodePtr> &order_list() const { return order_; }
|
||||||
|
|
||||||
void set_order_list(std::list<CNodePtr> &&order_list) { order_ = std::move(order_list); }
|
void set_order_list(OrderedSet<CNodePtr> &&order_list) { order_ = std::move(order_list); }
|
||||||
|
|
||||||
// Add a cnode at the end of order list.
|
// Add a cnode at the end of order list.
|
||||||
void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); }
|
void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); }
|
||||||
|
|
||||||
// Prepend cnode at the front of order list.
|
// Prepend cnode at the front of order list.
|
||||||
void PrependOrderList(const CNodePtr &cnode) { order_.insert(order_.begin(), cnode); }
|
void PrependOrderList(const CNodePtr &cnode) { order_.push_front(cnode); }
|
||||||
|
|
||||||
// Maintain cnode order list when a cnode is replaced by a new one.
|
// Maintain cnode order list when a cnode is replaced by a new one.
|
||||||
void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node);
|
||||||
|
@ -461,7 +461,7 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
||||||
const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes);
|
const std::vector<AnfNodePtr> &kwarg_values_tuple_nodes);
|
||||||
|
|
||||||
// CNode order which relates to origin code order
|
// CNode order which relates to origin code order
|
||||||
std::list<CNodePtr> order_;
|
OrderedSet<CNodePtr> order_;
|
||||||
bool stub_;
|
bool stub_;
|
||||||
inline static Drawer drawer_ = nullptr;
|
inline static Drawer drawer_ = nullptr;
|
||||||
// Design switch_layer_input as a ptr to
|
// Design switch_layer_input as a ptr to
|
||||||
|
|
|
@ -58,6 +58,8 @@ class OrderedSet {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OrderedSet(OrderedSet &&os) = default;
|
||||||
|
|
||||||
explicit OrderedSet(const sequential_type &other) {
|
explicit OrderedSet(const sequential_type &other) {
|
||||||
for (auto &item : other) {
|
for (auto &item : other) {
|
||||||
add(item);
|
add(item);
|
||||||
|
@ -80,23 +82,27 @@ class OrderedSet {
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add an element to the OrderedSet, without judging return value
|
OrderedSet &operator=(OrderedSet &&os) = default;
|
||||||
void add(const element_type &e) { (void)insert(e); }
|
|
||||||
|
|
||||||
// insert an element to the OrderedSet
|
// insert an element to the OrderedSet after the given position.
|
||||||
std::pair<iterator, bool> insert(const element_type &e) {
|
std::pair<iterator, bool> insert(iterator pos, const element_type &e) {
|
||||||
iterator empty_itr;
|
auto result = mapped_data_.emplace(e, ordered_data_.end());
|
||||||
std::pair<element_type, typename map_type::mapped_type> map_pair = std::make_pair(e, empty_itr);
|
|
||||||
auto result = mapped_data_.insert(map_pair);
|
|
||||||
auto &seq_idx = result.first->second;
|
|
||||||
// if insert success;
|
|
||||||
if (result.second) {
|
if (result.second) {
|
||||||
auto it = ordered_data_.insert(ordered_data_.end(), e);
|
result.first->second = ordered_data_.emplace(pos, e);
|
||||||
seq_idx = it;
|
|
||||||
}
|
}
|
||||||
return std::pair<iterator, bool>(seq_idx, result.second);
|
return {result.first->second, result.second};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add an element to the OrderedSet, without judging return value
|
||||||
|
void add(const element_type &e) { (void)insert(ordered_data_.end(), e); }
|
||||||
|
|
||||||
|
// insert an element to the end of OrderedSet.
|
||||||
|
std::pair<iterator, bool> insert(const element_type &e) { return insert(ordered_data_.end(), e); }
|
||||||
|
|
||||||
|
void push_back(const element_type &e) { (void)insert(ordered_data_.end(), e); }
|
||||||
|
|
||||||
|
void push_front(const element_type &e) { (void)insert(ordered_data_.begin(), e); }
|
||||||
|
|
||||||
// Remove an element, if removed return true, otherwise return false
|
// Remove an element, if removed return true, otherwise return false
|
||||||
bool erase(const element_type &e) {
|
bool erase(const element_type &e) {
|
||||||
auto pos = mapped_data_.find(e);
|
auto pos = mapped_data_.find(e);
|
||||||
|
@ -109,6 +115,16 @@ class OrderedSet {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
iterator erase(iterator pos) {
|
||||||
|
(void)mapped_data_.erase(*pos);
|
||||||
|
return ordered_data_.erase(pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator erase(const_iterator pos) {
|
||||||
|
(void)mapped_data_.erase(*pos);
|
||||||
|
return ordered_data_.erase(pos);
|
||||||
|
}
|
||||||
|
|
||||||
// Return the container size
|
// Return the container size
|
||||||
std::size_t size() const { return mapped_data_.size(); }
|
std::size_t size() const { return mapped_data_.size(); }
|
||||||
|
|
||||||
|
@ -267,6 +283,22 @@ class OrderedSet {
|
||||||
|
|
||||||
bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); }
|
bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); }
|
||||||
|
|
||||||
|
const_iterator find(const element_type &e) const {
|
||||||
|
auto iter = mapped_data_.find(e);
|
||||||
|
if (iter == mapped_data_.end()) {
|
||||||
|
return ordered_data_.end();
|
||||||
|
}
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
iterator find(const element_type &e) {
|
||||||
|
auto iter = mapped_data_.find(e);
|
||||||
|
if (iter == mapped_data_.end()) {
|
||||||
|
return ordered_data_.end();
|
||||||
|
}
|
||||||
|
return iter->second;
|
||||||
|
}
|
||||||
|
|
||||||
// Return the count of an element in set
|
// Return the count of an element in set
|
||||||
std::size_t count(const element_type &e) const { return mapped_data_.count(e); }
|
std::size_t count(const element_type &e) const { return mapped_data_.count(e); }
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue