Adjust some routines of FG and FGM, about the nodes info. IF.

This commit is contained in:
Zhang Qinghua 2020-05-24 18:16:37 +08:00
parent 737bfc9595
commit dbb86cb1be
6 changed files with 116 additions and 104 deletions

View File

@ -198,7 +198,7 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
const AnfNodeSet &FuncGraph::nodes() { return nodes_; } const AnfNodeSet &FuncGraph::nodes() { return nodes_; }
void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; } void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); }
void FuncGraph::ClearNodes() { nodes_.clear(); } void FuncGraph::ClearNodes() { nodes_.clear(); }
@ -215,7 +215,12 @@ void FuncGraph::DropNode(AnfNodePtr node) {
const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; } const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }
void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; } void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
auto &others = source->value_nodes();
for (auto it = others.begin(); it != others.end(); it++) {
AddValueNode(it->first, it->second);
}
}
void FuncGraph::ClearValueNodes() { value_nodes_.clear(); } void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
@ -243,9 +248,9 @@ void FuncGraph::DropValueNode(AnfNodePtr node) {
const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; } const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }
void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) { void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
auto it = others.begin(); auto &others = source->free_variables();
for (; it != others.end(); it++) { for (auto it = others.begin(); it != others.end(); it++) {
if (it->first->func_graph().get() != this) { if (it->first->func_graph().get() != this) {
(void)AddFreeVariable(it->first, it->second); (void)AddFreeVariable(it->first, it->second);
} }
@ -313,31 +318,37 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return func_graphs; return func_graphs;
} }
const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; } const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; }
void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; } void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
auto &others = source->func_graphs_used();
for (auto it = others.begin(); it != others.end(); it++) {
(void)AddFuncGraphUsed(it->first, it->second);
}
func_graphs_used_.erase(source);
}
void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); } void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) { bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) {
if (func_graph_value_nodes_.count(node) == 0) { if (func_graphs_used_.count(fg) == 0) {
func_graph_value_nodes_[node] = count; func_graphs_used_[fg] = count;
return true; return true;
} else { } else {
func_graph_value_nodes_[node] += count; func_graphs_used_[fg] += count;
return false; return false;
} }
} }
bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) { bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) {
if (func_graph_value_nodes_.count(node) != 0) { if (func_graphs_used_.count(fg) != 0) {
if (func_graph_value_nodes_[node] == 1) { if (func_graphs_used_[fg] == 1) {
(void)func_graph_value_nodes_.erase(node); (void)func_graphs_used_.erase(fg);
return true; return true;
} else { } else {
func_graph_value_nodes_[node]--; func_graphs_used_[fg]--;
if (func_graph_value_nodes_[node] < 0) { if (func_graphs_used_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
} }
} }
@ -354,11 +365,13 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; } const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }
void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) { void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
auto it = others.begin(); auto &others = source->func_graph_cnodes_index();
for (; it != others.end(); it++) { for (auto it = others.begin(); it != others.end(); it++) {
// Ignore the user graph who may own itself. // Ignore the user graph who may own itself.
if (it->first->first->func_graph().get() != this) { auto fg = it->first->first->func_graph();
MS_EXCEPTION_IF_NULL(fg);
if (fg.get() != this) {
AddFuncGraphCNodeIndex(it->first, it->second); AddFuncGraphCNodeIndex(it->first, it->second);
} }
} }
@ -388,28 +401,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
} }
} }
const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; } const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }
void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; } void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
auto &others = source->j_func_graphs();
void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); } for (auto it = others.begin(); it != others.end(); it++) {
AddJFuncGraph(it->first, it->second);
void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) {
if (j_func_graph_value_nodes_.count(node) == 0) {
j_func_graph_value_nodes_[node] = count;
} else {
j_func_graph_value_nodes_[node] += count;
} }
} }
void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) { void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }
if (j_func_graph_value_nodes_.count(node) != 0) {
if (j_func_graph_value_nodes_[node] == 1) { void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
(void)j_func_graph_value_nodes_.erase(node); if (j_func_graphs_.count(fg) == 0) {
j_func_graphs_[fg] = count;
} else { } else {
j_func_graph_value_nodes_[node]--; j_func_graphs_[fg] += count;
if (j_func_graph_value_nodes_[node] < 0) { }
MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node }
void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
if (j_func_graphs_.count(fg) != 0) {
if (j_func_graphs_[fg] == 1) {
(void)j_func_graphs_.erase(fg);
} else {
j_func_graphs_[fg]--;
if (j_func_graphs_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info()); << "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
} }
} }

View File

@ -189,21 +189,21 @@ class FuncGraph : public FuncGraphBase {
// get all nodes belonging to this func graph // get all nodes belonging to this func graph
const AnfNodeSet &nodes(); const AnfNodeSet &nodes();
void CopyNodes(const AnfNodeSet &other_nodes); void CopyNodes(const FuncGraphPtr &source);
void ClearNodes(); void ClearNodes();
void AddNode(AnfNodePtr node); void AddNode(AnfNodePtr node);
void DropNode(AnfNodePtr node); void DropNode(AnfNodePtr node);
// get all value_nodes belonging to this func graph // get all value_nodes belonging to this func graph
const AnfNodeCounterMap &value_nodes(); const AnfNodeCounterMap &value_nodes();
void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes); void CopyValueNodes(const FuncGraphPtr &source);
void ClearValueNodes(); void ClearValueNodes();
void AddValueNode(AnfNodePtr node, int count = 1); void AddValueNode(AnfNodePtr node, int count = 1);
void DropValueNode(AnfNodePtr node); void DropValueNode(AnfNodePtr node);
// get all free vars directly used in this func graph // get all free vars directly used in this func graph
const AnfNodeCounterMap &free_variables(); const AnfNodeCounterMap &free_variables();
void CopyFreeVariables(const AnfNodeCounterMap &others); void CopyFreeVariables(const FuncGraphPtr &source);
void ClearFreeVariables(); void ClearFreeVariables();
bool AddFreeVariable(AnfNodePtr node, int count = 1); bool AddFreeVariable(AnfNodePtr node, int count = 1);
bool DropFreeVariable(AnfNodePtr node); bool DropFreeVariable(AnfNodePtr node);
@ -218,25 +218,25 @@ class FuncGraph : public FuncGraphBase {
std::vector<FuncGraphPtr> free_variables_func_graphs(); std::vector<FuncGraphPtr> free_variables_func_graphs();
// get all value nodes of func graph directly used by this func graph // get all value nodes of func graph directly used by this func graph
const AnfNodeCounterMap &func_graph_value_nodes(); const FuncGraphCounterMap &func_graphs_used();
void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others); void CopyFuncGraphsUsed(const FuncGraphPtr &source);
void ClearFuncGraphValueNodes(); void ClearFuncGraphsUsed();
bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1); bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
bool DropFuncGraphValueNode(AnfNodePtr node); bool DropFuncGraphUsed(FuncGraphPtr fg);
// get all value nodes of J func graph directly used by this func graph // get all value nodes of J func graph directly used by this func graph
const AnfNodeCounterMap &j_func_graph_value_nodes(); const FuncGraphCounterMap &j_func_graphs();
void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others); void CopyJFuncGraphs(const FuncGraphPtr &source);
void ClearJFuncGraphValueNodes(); void ClearJFuncGraphs();
void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1); void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
void DropJFuncGraphValueNode(AnfNodePtr node); void DropJFuncGraph(FuncGraphPtr fg);
// get all func graphs nested used by this func graph // get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total(); const FuncGraphSet &func_graphs_used_total();
// get all user value nodes of this func graph, by CNode and its input's index // get all user value nodes of this func graph, by CNode and its input's index
const CNodeIndexCounterMap &func_graph_cnodes_index(); const CNodeIndexCounterMap &func_graph_cnodes_index();
void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes); void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
void ClearFuncGraphCNodesIndex(); void ClearFuncGraphCNodesIndex();
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1); void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node); void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
@ -311,13 +311,13 @@ class FuncGraph : public FuncGraphBase {
AnfNodeCounterMap value_nodes_; AnfNodeCounterMap value_nodes_;
// all func graph value nodes of the function // all func graph value nodes of the function
AnfNodeCounterMap func_graph_value_nodes_; FuncGraphCounterMap func_graphs_used_;
// all free variables of the function // all free variables of the function
AnfNodeCounterMap free_variables_; AnfNodeCounterMap free_variables_;
// all value nodes calling J in the function // all value nodes calling J in the function
AnfNodeCounterMap j_func_graph_value_nodes_; FuncGraphCounterMap j_func_graphs_;
// all user value nodes of this func graph, recording by CNode and its input's index // all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_; CNodeIndexCounterMap func_graph_cnodes_index_;

View File

@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if (!clone_all_used_graphs_) { if (!clone_all_used_graphs_) {
return; return;
} }
auto &used = func_graph->func_graph_value_nodes(); auto &used = func_graph->func_graphs_used();
for (auto &fg_value_node : used) { for (auto &fg : used) {
todo_.push_back({GetValueNode<FuncGraphPtr>(fg_value_node.first), nullptr, {}}); todo_.push_back({fg.first, nullptr, {}});
} }
} }

View File

@ -196,7 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
return; return;
} }
AddIntoManaged(func_graph); AddIntoManaged(func_graph);
MS_EXCEPTION_IF_NULL(signals_);
std::vector<AnfNodePtr> para = func_graph->parameters(); std::vector<AnfNodePtr> para = func_graph->parameters();
AcquireNodes(para); AcquireNodes(para);
std::vector<AnfNodePtr> return_vec({func_graph->get_return()}); std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
@ -301,7 +300,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
std::vector<AnfNodePtr> return_vec = {func_graph->get_return()}; std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
todo.update(MaybeDropNodes(return_vec)); todo.update(MaybeDropNodes(return_vec));
} }
MS_EXCEPTION_IF_NULL(signals_);
for (auto &fg : dropped) { for (auto &fg : dropped) {
MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg);
all_nodes_.difference_update(fg->parameters()); all_nodes_.difference_update(fg->parameters());
@ -334,7 +332,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
} }
auto &users_node = node_users_[inp]; auto &users_node = node_users_[inp];
users_node.add(make_pair(node, index)); users_node.add(make_pair(node, index));
MS_EXCEPTION_IF_NULL(signals_);
AddEdge(node, index, inp); AddEdge(node, index, inp);
} }
} }
@ -384,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) { FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
AnfNodeSet nodes_ordered(nodes); AnfNodeSet nodes_ordered(nodes);
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
MS_EXCEPTION_IF_NULL(signals_);
while (!nodes_ordered.empty()) { while (!nodes_ordered.empty()) {
AnfNodePtr node = nodes_ordered.pop(); AnfNodePtr node = nodes_ordered.pop();
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
@ -475,13 +470,13 @@ inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr inp
if (input->isa<ValueNode>()) { if (input->isa<ValueNode>()) {
fg->AddValueNode(input); fg->AddValueNode(input);
if (IsValueNode<FuncGraph>(input)) { if (IsValueNode<FuncGraph>(input)) {
if (fg->AddFuncGraphValueNode(input)) {
signals_->InvalidateComputer();
}
auto used = GetValueNode<FuncGraphPtr>(input); auto used = GetValueNode<FuncGraphPtr>(input);
used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index))); used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (fg->AddFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) { if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->AddJFuncGraphValueNode(input); fg->AddJFuncGraph(used);
} }
} }
} else if (fg != nullptr && fg != input->func_graph()) { } else if (fg != nullptr && fg != input->func_graph()) {
@ -496,13 +491,13 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
if (input->isa<ValueNode>()) { if (input->isa<ValueNode>()) {
fg->DropValueNode(input); fg->DropValueNode(input);
if (IsValueNode<FuncGraph>(input)) { if (IsValueNode<FuncGraph>(input)) {
if (fg->DropFuncGraphValueNode(input)) {
signals_->InvalidateComputer();
}
auto used = GetValueNode<FuncGraphPtr>(input); auto used = GetValueNode<FuncGraphPtr>(input);
used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index))); used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (fg->DropFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) { if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->DropJFuncGraphValueNode(input); fg->DropJFuncGraph(used);
} }
} }
} else if (fg != nullptr && fg != input->func_graph()) { } else if (fg != nullptr && fg != input->func_graph()) {
@ -513,19 +508,19 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
} }
inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) { inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
target->CopyNodes(source->nodes()); target->CopyNodes(source);
target->CopyValueNodes(source->value_nodes()); target->CopyValueNodes(source);
target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index()); target->CopyFuncGraphCNodesIndex(source);
target->CopyFreeVariables(source->free_variables()); target->CopyFreeVariables(source);
target->CopyFuncGraphValueNodes(source->func_graph_value_nodes()); target->CopyFuncGraphsUsed(source);
target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes()); target->CopyJFuncGraphs(source);
signals_->InvalidateComputer(); signals_->InvalidateComputer();
source->ClearNodes(); source->ClearNodes();
source->ClearValueNodes(); source->ClearValueNodes();
source->ClearFuncGraphCNodesIndex(); source->ClearFuncGraphCNodesIndex();
source->ClearFreeVariables(); source->ClearFreeVariables();
source->ClearFuncGraphValueNodes(); source->ClearFuncGraphsUsed();
source->ClearJFuncGraphValueNodes(); source->ClearJFuncGraphs();
} }
FuncGraphTransaction FuncGraphManager::Transact() { FuncGraphTransaction FuncGraphManager::Transact() {
@ -768,10 +763,10 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
} }
// Search the fv in fg's child func graph. // Search the fv in fg's child func graph.
auto &fg_value_nodes = fg->func_graph_value_nodes(); auto &fgs = fg->func_graphs_used();
for (auto &fg_value_node : fg_value_nodes) { for (auto &item : fgs) {
fg->seen_ = seen_num; fg->seen_ = seen_num;
auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first); auto gt = item.first;
parents->update(SeekParents(gt, seen_num)); parents->update(SeekParents(gt, seen_num));
} }
(void)parents->erase(fg); (void)parents->erase(fg);
@ -865,15 +860,15 @@ void FVTotalComputer::RealRecompute() {
} }
} }
auto &used = fg->func_graph_value_nodes(); auto &used = fg->func_graphs_used();
for (auto &iter : used) { for (auto &iter : used) {
auto p = manager->parent(GetValueNode<FuncGraphPtr>(iter.first)); auto p = manager->parent(iter.first);
if (p == nullptr) { if (p == nullptr) {
continue; continue;
} }
auto curr = fg; auto curr = fg;
while (curr != p) { while (curr != p) {
(void)CounterFuncGraphCollector::Mod(curr, GetValueNode<FuncGraphPtr>(iter.first), iter.second); (void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second);
curr = manager->parent(curr); curr = manager->parent(curr);
} }
} }
@ -899,8 +894,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
while (!todo.empty()) { while (!todo.empty()) {
todo_new.clear(); todo_new.clear();
for (auto &gt : todo) { for (auto &gt : todo) {
for (auto &item : gt->func_graph_value_nodes()) { for (auto &item : gt->func_graphs_used()) {
auto used_fg = GetValueNode<FuncGraphPtr>(item.first); auto used_fg = item.first;
if (used_fg == fg) { if (used_fg == fg) {
func_graph_used_total_analysis_[fg].add(used_fg); func_graph_used_total_analysis_[fg].add(used_fg);
continue; continue;
@ -925,8 +920,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f
while (!todo.empty()) { while (!todo.empty()) {
todo_new.clear(); todo_new.clear();
for (auto &gt : todo) { for (auto &gt : todo) {
for (auto &item : gt->func_graph_value_nodes()) { for (auto &item : gt->func_graphs_used()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first); auto used_g = item.first;
if (used_g == fg) { if (used_g == fg) {
return true; return true;
} }
@ -957,9 +952,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
} }
} else { } else {
trace->push_back(fg); trace->push_back(fg);
auto &items = fg->func_graph_value_nodes(); auto &items = fg->func_graphs_used();
for (auto iter = items.begin(); iter != items.end(); (void)iter++) { for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
CheckRecursiveGraphs(GetValueNode<FuncGraphPtr>(iter->first), trace); CheckRecursiveGraphs(iter->first, trace);
} }
trace->pop_back(); trace->pop_back();
if (!recursive_map_.count(fg)) { if (!recursive_map_.count(fg)) {
@ -973,14 +968,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
MS_LOG(DEBUG) << fg->ToString() << " had been checked"; MS_LOG(DEBUG) << fg->ToString() << " had been checked";
return false; return false;
} }
auto &j_fg_value_nodes = fg->j_func_graph_value_nodes(); auto &j_fgs = fg->j_func_graphs();
if (!j_fg_value_nodes.empty()) { if (!j_fgs.empty()) {
// check g1->J(fg)->g2->g cycle; // check g1->J(fg)->g2->g cycle;
auto contains_j = auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) {
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> iter) { return iter.first->seen_ != seen_num;
return GetValueNode<FuncGraphPtr>(iter.first)->seen_ != seen_num;
}); });
if (contains_j != j_fg_value_nodes.end()) { if (contains_j != j_fgs.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")"; MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
return true; return true;
} }
@ -988,8 +982,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
fg->seen_ = seen_num; fg->seen_ = seen_num;
// check if func graphs used contains J(func_graph); // check if func graphs used contains J(func_graph);
for (auto &item : fg->func_graph_value_nodes()) { for (auto &item : fg->func_graphs_used()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first); auto used_g = item.first;
if (SeekJ(used_g, seen_num)) { if (SeekJ(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)"; MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
return true; return true;

View File

@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) {
SetForwardFlag(all_nodes); SetForwardFlag(all_nodes);
} else { } else {
for (auto &func_graph : graph_set) { for (auto &func_graph : graph_set) {
MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size(); MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
auto return_node = func_graph->get_return(); auto return_node = func_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node); MS_EXCEPTION_IF_NULL(return_node);
auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);

View File

@ -462,8 +462,8 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(1, iter.second.size()); ASSERT_EQ(1, iter.second.size());
} }
ASSERT_EQ(1, f->func_graph_value_nodes().size()); ASSERT_EQ(1, f->func_graphs_used().size());
ASSERT_EQ(0, g->func_graph_value_nodes().size()); ASSERT_EQ(0, g->func_graphs_used().size());
ASSERT_EQ(0, f->free_variables().size()); ASSERT_EQ(0, f->free_variables().size());
ASSERT_EQ(1, g->free_variables().size()); ASSERT_EQ(1, g->free_variables().size());