forked from mindspore-Ecosystem/mindspore
!5611 Find the set of points which forming loop
Merge pull request !5611 from Margaret_wangrui/get_loop
This commit is contained in:
commit
bfe3b9d171
|
@ -279,6 +279,59 @@ std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
|
|||
return re_order;
|
||||
}
|
||||
|
||||
void KernelGraph::GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_input_it = node_input_edges_.find(node);
|
||||
if (node_input_it == node_input_edges_.end()) {
|
||||
MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] don't have input edges.";
|
||||
return;
|
||||
}
|
||||
visited_nodes_.insert(node);
|
||||
for (auto input_edge : node_input_edges_[node]) {
|
||||
size_t input_num = node_input_num_[input_edge.first];
|
||||
if (input_num == 0) {
|
||||
continue;
|
||||
}
|
||||
if (find(visited_nodes_.begin(), visited_nodes_.end(), input_edge.first) == visited_nodes_.end()) {
|
||||
MS_EXCEPTION_IF_NULL(input_edge.first);
|
||||
edge_to_[input_edge.first] = node;
|
||||
GetLoopNodesByDFS(input_edge.first, loop_num);
|
||||
} else {
|
||||
AnfNodePtr node_iter = node;
|
||||
MS_EXCEPTION_IF_NULL(node_iter);
|
||||
MS_LOG(DEBUG) << "Print loop nodes start:";
|
||||
for (; node_iter != input_edge.first; node_iter = edge_to_[node_iter]) {
|
||||
MS_EXCEPTION_IF_NULL(node_iter);
|
||||
loop_nodes_.push(node_iter);
|
||||
node_input_num_[node_iter]--;
|
||||
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString();
|
||||
}
|
||||
loop_nodes_.push(node_iter);
|
||||
loop_nodes_.push(node);
|
||||
(*loop_num)++;
|
||||
node_input_num_[node_iter]--;
|
||||
MS_LOG(DEBUG) << "Get loop node:" << node_iter->DebugString();
|
||||
MS_LOG(DEBUG) << "Get loop node:" << node->DebugString();
|
||||
MS_LOG(DEBUG) << "Print loop nodes end, Loop num:" << *loop_num;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t KernelGraph::GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes) {
|
||||
uint32_t loop_num = 0;
|
||||
for (auto iter = none_zero_nodes.begin(); iter != none_zero_nodes.end(); iter++) {
|
||||
auto node = iter->first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node_input_num_[node] == 0) {
|
||||
continue;
|
||||
}
|
||||
edge_to_.clear();
|
||||
visited_nodes_.clear();
|
||||
GetLoopNodesByDFS(node, &loop_num);
|
||||
}
|
||||
return loop_num;
|
||||
}
|
||||
|
||||
void KernelGraph::CheckLoop() {
|
||||
std::map<AnfNodePtr, size_t> none_zero_nodes;
|
||||
if (node_input_edges_.size() != node_input_num_.size()) {
|
||||
|
@ -303,6 +356,7 @@ void KernelGraph::CheckLoop() {
|
|||
}
|
||||
// if don't consider control depend and loop exit,a exception will be throw
|
||||
if (!none_zero_nodes.empty()) {
|
||||
MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes);
|
||||
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include <queue>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <unordered_set>
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -90,8 +91,6 @@ class KernelGraph : public FuncGraph {
|
|||
void AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair);
|
||||
// get map
|
||||
std::map<AnfWithOutIndex, AnfWithOutIndex> GetRefMap() const { return ref_out_in_map_; }
|
||||
// checkout whether loop exist in graph
|
||||
void CheckLoop();
|
||||
// check whether graph is executable
|
||||
bool executable() const { return executable_; }
|
||||
// set executable of graph
|
||||
|
@ -199,6 +198,10 @@ class KernelGraph : public FuncGraph {
|
|||
AnfNodePtr TransCNodeTuple(const CNodePtr &node);
|
||||
AnfNodePtr CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx);
|
||||
std::vector<CNodePtr> SortStartLabelAndEndGoto();
|
||||
// checkout whether loop exist in graph
|
||||
void CheckLoop();
|
||||
uint32_t GetLoopNum(std::map<AnfNodePtr, size_t> none_zero_nodes);
|
||||
void GetLoopNodesByDFS(AnfNodePtr node, uint32_t *loop_num);
|
||||
|
||||
std::shared_ptr<std::vector<AnfNodePtr>> inputs_;
|
||||
std::vector<AnfNodePtr> child_graph_result_;
|
||||
|
@ -243,6 +246,10 @@ class KernelGraph : public FuncGraph {
|
|||
std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
|
||||
uint32_t current_epoch_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;
|
||||
|
||||
std::set<AnfNodePtr> visited_nodes_;
|
||||
std::map<AnfNodePtr, AnfNodePtr> edge_to_;
|
||||
std::stack<AnfNodePtr> loop_nodes_;
|
||||
};
|
||||
} // namespace session
|
||||
using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
|
||||
|
|
Loading…
Reference in New Issue