Adjust some routines of FG and FGM, about the nodes info. IF.

This commit is contained in:
Zhang Qinghua 2020-05-24 18:16:37 +08:00
parent 737bfc9595
commit dbb86cb1be
6 changed files with 116 additions and 104 deletions

View File

@ -198,7 +198,7 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
const AnfNodeSet &FuncGraph::nodes() { return nodes_; }
void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; }
void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); }
void FuncGraph::ClearNodes() { nodes_.clear(); }
@ -215,7 +215,12 @@ void FuncGraph::DropNode(AnfNodePtr node) {
const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }
void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; }
void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
auto &others = source->value_nodes();
for (auto it = others.begin(); it != others.end(); it++) {
AddValueNode(it->first, it->second);
}
}
void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
@ -243,9 +248,9 @@ void FuncGraph::DropValueNode(AnfNodePtr node) {
const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }
void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) {
auto it = others.begin();
for (; it != others.end(); it++) {
void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
auto &others = source->free_variables();
for (auto it = others.begin(); it != others.end(); it++) {
if (it->first->func_graph().get() != this) {
(void)AddFreeVariable(it->first, it->second);
}
@ -313,31 +318,37 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
return func_graphs;
}
const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; }
const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; }
void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; }
void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
auto &others = source->func_graphs_used();
for (auto it = others.begin(); it != others.end(); it++) {
(void)AddFuncGraphUsed(it->first, it->second);
}
func_graphs_used_.erase(source);
}
void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); }
void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) {
if (func_graph_value_nodes_.count(node) == 0) {
func_graph_value_nodes_[node] = count;
bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) {
if (func_graphs_used_.count(fg) == 0) {
func_graphs_used_[fg] = count;
return true;
} else {
func_graph_value_nodes_[node] += count;
func_graphs_used_[fg] += count;
return false;
}
}
bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) {
if (func_graph_value_nodes_.count(node) != 0) {
if (func_graph_value_nodes_[node] == 1) {
(void)func_graph_value_nodes_.erase(node);
bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) {
if (func_graphs_used_.count(fg) != 0) {
if (func_graphs_used_[fg] == 1) {
(void)func_graphs_used_.erase(fg);
return true;
} else {
func_graph_value_nodes_[node]--;
if (func_graph_value_nodes_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node
func_graphs_used_[fg]--;
if (func_graphs_used_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}
@ -354,11 +365,13 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }
void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) {
auto it = others.begin();
for (; it != others.end(); it++) {
void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
auto &others = source->func_graph_cnodes_index();
for (auto it = others.begin(); it != others.end(); it++) {
// Ignore the user graph who may own itself.
if (it->first->first->func_graph().get() != this) {
auto fg = it->first->first->func_graph();
MS_EXCEPTION_IF_NULL(fg);
if (fg.get() != this) {
AddFuncGraphCNodeIndex(it->first, it->second);
}
}
@ -388,28 +401,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
}
}
const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; }
const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }
void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; }
void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); }
void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) {
if (j_func_graph_value_nodes_.count(node) == 0) {
j_func_graph_value_nodes_[node] = count;
} else {
j_func_graph_value_nodes_[node] += count;
void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
auto &others = source->j_func_graphs();
for (auto it = others.begin(); it != others.end(); it++) {
AddJFuncGraph(it->first, it->second);
}
}
void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) {
if (j_func_graph_value_nodes_.count(node) != 0) {
if (j_func_graph_value_nodes_[node] == 1) {
(void)j_func_graph_value_nodes_.erase(node);
void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }
void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
if (j_func_graphs_.count(fg) == 0) {
j_func_graphs_[fg] = count;
} else {
j_func_graphs_[fg] += count;
}
}
void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
if (j_func_graphs_.count(fg) != 0) {
if (j_func_graphs_[fg] == 1) {
(void)j_func_graphs_.erase(fg);
} else {
j_func_graph_value_nodes_[node]--;
if (j_func_graph_value_nodes_[node] < 0) {
MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node
j_func_graphs_[fg]--;
if (j_func_graphs_[fg] < 0) {
MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
}
}

View File

@ -189,21 +189,21 @@ class FuncGraph : public FuncGraphBase {
// get all nodes belonging to this func graph
const AnfNodeSet &nodes();
void CopyNodes(const AnfNodeSet &other_nodes);
void CopyNodes(const FuncGraphPtr &source);
void ClearNodes();
void AddNode(AnfNodePtr node);
void DropNode(AnfNodePtr node);
// get all value_nodes belonging to this func graph
const AnfNodeCounterMap &value_nodes();
void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes);
void CopyValueNodes(const FuncGraphPtr &source);
void ClearValueNodes();
void AddValueNode(AnfNodePtr node, int count = 1);
void DropValueNode(AnfNodePtr node);
// get all free vars directly used in this func graph
const AnfNodeCounterMap &free_variables();
void CopyFreeVariables(const AnfNodeCounterMap &others);
void CopyFreeVariables(const FuncGraphPtr &source);
void ClearFreeVariables();
bool AddFreeVariable(AnfNodePtr node, int count = 1);
bool DropFreeVariable(AnfNodePtr node);
@ -218,25 +218,25 @@ class FuncGraph : public FuncGraphBase {
std::vector<FuncGraphPtr> free_variables_func_graphs();
// get all value nodes of func graph directly used by this func graph
const AnfNodeCounterMap &func_graph_value_nodes();
void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others);
void ClearFuncGraphValueNodes();
bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1);
bool DropFuncGraphValueNode(AnfNodePtr node);
const FuncGraphCounterMap &func_graphs_used();
void CopyFuncGraphsUsed(const FuncGraphPtr &source);
void ClearFuncGraphsUsed();
bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
bool DropFuncGraphUsed(FuncGraphPtr fg);
// get all value nodes of J func graph directly used by this func graph
const AnfNodeCounterMap &j_func_graph_value_nodes();
void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others);
void ClearJFuncGraphValueNodes();
void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1);
void DropJFuncGraphValueNode(AnfNodePtr node);
const FuncGraphCounterMap &j_func_graphs();
void CopyJFuncGraphs(const FuncGraphPtr &source);
void ClearJFuncGraphs();
void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
void DropJFuncGraph(FuncGraphPtr fg);
// get all func graphs nested used by this func graph
const FuncGraphSet &func_graphs_used_total();
// get all user value nodes of this func graph, by CNode and its input's index
const CNodeIndexCounterMap &func_graph_cnodes_index();
void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes);
void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
void ClearFuncGraphCNodesIndex();
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
@ -311,13 +311,13 @@ class FuncGraph : public FuncGraphBase {
AnfNodeCounterMap value_nodes_;
// all func graph value nodes of the function
AnfNodeCounterMap func_graph_value_nodes_;
FuncGraphCounterMap func_graphs_used_;
// all free variables of the function
AnfNodeCounterMap free_variables_;
// all value nodes calling J in the function
AnfNodeCounterMap j_func_graph_value_nodes_;
FuncGraphCounterMap j_func_graphs_;
// all user value nodes of this func graph, recording by CNode and its input's index
CNodeIndexCounterMap func_graph_cnodes_index_;

View File

@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
if (!clone_all_used_graphs_) {
return;
}
auto &used = func_graph->func_graph_value_nodes();
for (auto &fg_value_node : used) {
todo_.push_back({GetValueNode<FuncGraphPtr>(fg_value_node.first), nullptr, {}});
auto &used = func_graph->func_graphs_used();
for (auto &fg : used) {
todo_.push_back({fg.first, nullptr, {}});
}
}

View File

@ -196,7 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
return;
}
AddIntoManaged(func_graph);
MS_EXCEPTION_IF_NULL(signals_);
std::vector<AnfNodePtr> para = func_graph->parameters();
AcquireNodes(para);
std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
@ -301,7 +300,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
todo.update(MaybeDropNodes(return_vec));
}
MS_EXCEPTION_IF_NULL(signals_);
for (auto &fg : dropped) {
MS_EXCEPTION_IF_NULL(fg);
all_nodes_.difference_update(fg->parameters());
@ -334,7 +332,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
}
auto &users_node = node_users_[inp];
users_node.add(make_pair(node, index));
MS_EXCEPTION_IF_NULL(signals_);
AddEdge(node, index, inp);
}
}
@ -384,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
AnfNodeSet nodes_ordered(nodes);
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
MS_EXCEPTION_IF_NULL(signals_);
while (!nodes_ordered.empty()) {
AnfNodePtr node = nodes_ordered.pop();
MS_EXCEPTION_IF_NULL(node);
@ -475,13 +470,13 @@ inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr inp
if (input->isa<ValueNode>()) {
fg->AddValueNode(input);
if (IsValueNode<FuncGraph>(input)) {
if (fg->AddFuncGraphValueNode(input)) {
signals_->InvalidateComputer();
}
auto used = GetValueNode<FuncGraphPtr>(input);
used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (fg->AddFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->AddJFuncGraphValueNode(input);
fg->AddJFuncGraph(used);
}
}
} else if (fg != nullptr && fg != input->func_graph()) {
@ -496,13 +491,13 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
if (input->isa<ValueNode>()) {
fg->DropValueNode(input);
if (IsValueNode<FuncGraph>(input)) {
if (fg->DropFuncGraphValueNode(input)) {
signals_->InvalidateComputer();
}
auto used = GetValueNode<FuncGraphPtr>(input);
used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
if (fg->DropFuncGraphUsed(used)) {
signals_->InvalidateComputer();
}
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
fg->DropJFuncGraphValueNode(input);
fg->DropJFuncGraph(used);
}
}
} else if (fg != nullptr && fg != input->func_graph()) {
@ -513,19 +508,19 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
}
inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
target->CopyNodes(source->nodes());
target->CopyValueNodes(source->value_nodes());
target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index());
target->CopyFreeVariables(source->free_variables());
target->CopyFuncGraphValueNodes(source->func_graph_value_nodes());
target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes());
target->CopyNodes(source);
target->CopyValueNodes(source);
target->CopyFuncGraphCNodesIndex(source);
target->CopyFreeVariables(source);
target->CopyFuncGraphsUsed(source);
target->CopyJFuncGraphs(source);
signals_->InvalidateComputer();
source->ClearNodes();
source->ClearValueNodes();
source->ClearFuncGraphCNodesIndex();
source->ClearFreeVariables();
source->ClearFuncGraphValueNodes();
source->ClearJFuncGraphValueNodes();
source->ClearFuncGraphsUsed();
source->ClearJFuncGraphs();
}
FuncGraphTransaction FuncGraphManager::Transact() {
@ -768,10 +763,10 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
}
// Search the fv in fg's child func graph.
auto &fg_value_nodes = fg->func_graph_value_nodes();
for (auto &fg_value_node : fg_value_nodes) {
auto &fgs = fg->func_graphs_used();
for (auto &item : fgs) {
fg->seen_ = seen_num;
auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first);
auto gt = item.first;
parents->update(SeekParents(gt, seen_num));
}
(void)parents->erase(fg);
@ -865,15 +860,15 @@ void FVTotalComputer::RealRecompute() {
}
}
auto &used = fg->func_graph_value_nodes();
auto &used = fg->func_graphs_used();
for (auto &iter : used) {
auto p = manager->parent(GetValueNode<FuncGraphPtr>(iter.first));
auto p = manager->parent(iter.first);
if (p == nullptr) {
continue;
}
auto curr = fg;
while (curr != p) {
(void)CounterFuncGraphCollector::Mod(curr, GetValueNode<FuncGraphPtr>(iter.first), iter.second);
(void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second);
curr = manager->parent(curr);
}
}
@ -899,8 +894,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
while (!todo.empty()) {
todo_new.clear();
for (auto &gt : todo) {
for (auto &item : gt->func_graph_value_nodes()) {
auto used_fg = GetValueNode<FuncGraphPtr>(item.first);
for (auto &item : gt->func_graphs_used()) {
auto used_fg = item.first;
if (used_fg == fg) {
func_graph_used_total_analysis_[fg].add(used_fg);
continue;
@ -925,8 +920,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f
while (!todo.empty()) {
todo_new.clear();
for (auto &gt : todo) {
for (auto &item : gt->func_graph_value_nodes()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
for (auto &item : gt->func_graphs_used()) {
auto used_g = item.first;
if (used_g == fg) {
return true;
}
@ -957,9 +952,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
}
} else {
trace->push_back(fg);
auto &items = fg->func_graph_value_nodes();
auto &items = fg->func_graphs_used();
for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
CheckRecursiveGraphs(GetValueNode<FuncGraphPtr>(iter->first), trace);
CheckRecursiveGraphs(iter->first, trace);
}
trace->pop_back();
if (!recursive_map_.count(fg)) {
@ -973,14 +968,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
MS_LOG(DEBUG) << fg->ToString() << " had been checked";
return false;
}
auto &j_fg_value_nodes = fg->j_func_graph_value_nodes();
if (!j_fg_value_nodes.empty()) {
auto &j_fgs = fg->j_func_graphs();
if (!j_fgs.empty()) {
// check g1->J(fg)->g2->g cycle;
auto contains_j =
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> iter) {
return GetValueNode<FuncGraphPtr>(iter.first)->seen_ != seen_num;
});
if (contains_j != j_fg_value_nodes.end()) {
auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) {
return iter.first->seen_ != seen_num;
});
if (contains_j != j_fgs.end()) {
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
return true;
}
@ -988,8 +982,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
fg->seen_ = seen_num;
// check if func graphs used contains J(func_graph);
for (auto &item : fg->func_graph_value_nodes()) {
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
for (auto &item : fg->func_graphs_used()) {
auto used_g = item.first;
if (SeekJ(used_g, seen_num)) {
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
return true;

View File

@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) {
SetForwardFlag(all_nodes);
} else {
for (auto &func_graph : graph_set) {
MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size();
MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
auto return_node = func_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);

View File

@ -462,8 +462,8 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(1, iter.second.size());
}
ASSERT_EQ(1, f->func_graph_value_nodes().size());
ASSERT_EQ(0, g->func_graph_value_nodes().size());
ASSERT_EQ(1, f->func_graphs_used().size());
ASSERT_EQ(0, g->func_graphs_used().size());
ASSERT_EQ(0, f->free_variables().size());
ASSERT_EQ(1, g->free_variables().size());