!5611 Find the set of points which forming loop

Merge pull request !5611 from Margaret_wangrui/get_loop
This commit is contained in:
mindspore-ci-bot 2020-09-02 11:08:15 +08:00 committed by Gitee
commit bfe3b9d171
2 changed files with 63 additions and 2 deletions

View File

@ -279,6 +279,59 @@ std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
return re_order;
}
void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) {
MS_EXCEPTION_IF_NULL(node);
auto node_input_it = node_input_edges_.find(node);
if (node_input_it == node_input_edges_.end()) {
MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges.";
return;
}
visited_nodes_.insert(node);
for (auto input_edge : node_input_edges_[node]) {
size_t input_num = node_input_num_[input_edge.first];
if (input_num == 0) {
continue;
}
if (find(visited_nodes_.begin(), visited_nodes_.end(), input_edge.first) == visited_nodes_.end()) {
MS_EXCEPTION_IF_NULL(input_edge.first);
edge_to_[input_edge.first] = node;
GetLoopNodesByDFS(input_edge.first, loop_num);
} else {
AnfNodePtr node_iter = node;
MS_EXCEPTION_IF_NULL(node_iter);
MS_LOG(DEBUG) << "Print loop nodes start:";
for (; node_iter != input_edge.first; node_iter = edge_to_[node_iter]) {
MS_EXCEPTION_IF_NULL(node_iter);
loop_nodes_.push(node_iter);
node_input_num_[node_iter]--;
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString();
}
loop_nodes_.push(node_iter);
loop_nodes_.push(node);
(*loop_num)++;
node_input_num_[node_iter]--;
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString();
MS_LOG(DEBUG) << "Get loop node:" << node->DebugString();
MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num;
}
}
}
uint32_t KernelGraph::GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes) {
uint32_t loop_num = 0;
for (auto iter = none_zero_nodes.begin(); iter != none_zero_nodes.end(); iter++) {
auto node = iter->first;
MS_EXCEPTION_IF_NULL(node);
if (node_input_num_[node] == 0) {
continue;
}
edge_to_.clear();
visited_nodes_.clear();
GetLoopNodesByDFS(node, &loop_num);
}
return loop_num;
}
void KernelGraph::CheckLoop() {
std::map<AnfNodePtr, size_t> none_zero_nodes;
if (node_input_edges_.size() != node_input_num_.size()) {
@ -303,6 +356,7 @@ void KernelGraph::CheckLoop() {
}
// if don't consider control depend and loop exit,a exception will be throw
if (!none_zero_nodes.empty()) {
MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes);
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
}
}

View File

@ -24,6 +24,7 @@
#include <queue>
#include <map>
#include <set>
#include <stack>
#include <unordered_set>
#include "ir/func_graph.h"
#include "ir/anf.h"
@ -90,8 +91,6 @@ class KernelGraph : public FuncGraph {
void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
// get map
std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; }
// checkout whether loop exist in graph
void CheckLoop();
// check whether graph is executable
bool executable() const { return executable_; }
// set executable of graph
@ -199,6 +198,10 @@ class KernelGraph : public FuncGraph {
AnfNodePtr TransCNodeTuple(const CNodePtr &node);
AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx);
std::vector<CNodePtr> SortStartLabelAndEndGoto();
// checkout whether loop exist in graph
void CheckLoop();
uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes);
void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num);
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
std::vector<AnfNodePtr> child_graph_result_;
@ -243,6 +246,10 @@ class KernelGraph : public FuncGraph {
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
uint32_t current_epoch_;
std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
std::set<AnfNodePtr> visited_nodes_;
std::map<AnfNodePtr, AnfNodePtr> edge_to_;
std::stack<AnfNodePtr> loop_nodes_;
};
} // namespace session
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;