forked from mindspore-Ecosystem/mindspore
Adjust some routines of FG and FGM, about the nodes info. IF.
This commit is contained in:
parent
737bfc9595
commit
dbb86cb1be
|
@ -198,7 +198,7 @@ GraphDebugInfoPtr FuncGraph::debug_info() {
|
||||||
|
|
||||||
const AnfNodeSet &FuncGraph::nodes() { return nodes_; }
|
const AnfNodeSet &FuncGraph::nodes() { return nodes_; }
|
||||||
|
|
||||||
void FuncGraph::CopyNodes(const AnfNodeSet &other_nodes) { nodes_ = other_nodes; }
|
void FuncGraph::CopyNodes(const FuncGraphPtr &source) { nodes_ = source->nodes(); }
|
||||||
|
|
||||||
void FuncGraph::ClearNodes() { nodes_.clear(); }
|
void FuncGraph::ClearNodes() { nodes_.clear(); }
|
||||||
|
|
||||||
|
@ -215,7 +215,12 @@ void FuncGraph::DropNode(AnfNodePtr node) {
|
||||||
|
|
||||||
const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }
|
const AnfNodeCounterMap &FuncGraph::value_nodes() { return value_nodes_; }
|
||||||
|
|
||||||
void FuncGraph::CopyValueNodes(const AnfNodeCounterMap &other_value_nodes) { value_nodes_ = other_value_nodes; }
|
void FuncGraph::CopyValueNodes(const FuncGraphPtr &source) {
|
||||||
|
auto &others = source->value_nodes();
|
||||||
|
for (auto it = others.begin(); it != others.end(); it++) {
|
||||||
|
AddValueNode(it->first, it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
|
void FuncGraph::ClearValueNodes() { value_nodes_.clear(); }
|
||||||
|
|
||||||
|
@ -243,9 +248,9 @@ void FuncGraph::DropValueNode(AnfNodePtr node) {
|
||||||
|
|
||||||
const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }
|
const AnfNodeCounterMap &FuncGraph::free_variables() { return free_variables_; }
|
||||||
|
|
||||||
void FuncGraph::CopyFreeVariables(const AnfNodeCounterMap &others) {
|
void FuncGraph::CopyFreeVariables(const FuncGraphPtr &source) {
|
||||||
auto it = others.begin();
|
auto &others = source->free_variables();
|
||||||
for (; it != others.end(); it++) {
|
for (auto it = others.begin(); it != others.end(); it++) {
|
||||||
if (it->first->func_graph().get() != this) {
|
if (it->first->func_graph().get() != this) {
|
||||||
(void)AddFreeVariable(it->first, it->second);
|
(void)AddFreeVariable(it->first, it->second);
|
||||||
}
|
}
|
||||||
|
@ -313,31 +318,37 @@ std::vector<FuncGraphPtr> FuncGraph::free_variables_func_graphs() {
|
||||||
return func_graphs;
|
return func_graphs;
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodeCounterMap &FuncGraph::func_graph_value_nodes() { return func_graph_value_nodes_; }
|
const FuncGraphCounterMap &FuncGraph::func_graphs_used() { return func_graphs_used_; }
|
||||||
|
|
||||||
void FuncGraph::CopyFuncGraphValueNodes(const AnfNodeCounterMap &others) { func_graph_value_nodes_ = others; }
|
void FuncGraph::CopyFuncGraphsUsed(const FuncGraphPtr &source) {
|
||||||
|
auto &others = source->func_graphs_used();
|
||||||
|
for (auto it = others.begin(); it != others.end(); it++) {
|
||||||
|
(void)AddFuncGraphUsed(it->first, it->second);
|
||||||
|
}
|
||||||
|
func_graphs_used_.erase(source);
|
||||||
|
}
|
||||||
|
|
||||||
void FuncGraph::ClearFuncGraphValueNodes() { func_graph_value_nodes_.clear(); }
|
void FuncGraph::ClearFuncGraphsUsed() { func_graphs_used_.clear(); }
|
||||||
|
|
||||||
bool FuncGraph::AddFuncGraphValueNode(AnfNodePtr node, int count) {
|
bool FuncGraph::AddFuncGraphUsed(FuncGraphPtr fg, int count) {
|
||||||
if (func_graph_value_nodes_.count(node) == 0) {
|
if (func_graphs_used_.count(fg) == 0) {
|
||||||
func_graph_value_nodes_[node] = count;
|
func_graphs_used_[fg] = count;
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
func_graph_value_nodes_[node] += count;
|
func_graphs_used_[fg] += count;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FuncGraph::DropFuncGraphValueNode(AnfNodePtr node) {
|
bool FuncGraph::DropFuncGraphUsed(FuncGraphPtr fg) {
|
||||||
if (func_graph_value_nodes_.count(node) != 0) {
|
if (func_graphs_used_.count(fg) != 0) {
|
||||||
if (func_graph_value_nodes_[node] == 1) {
|
if (func_graphs_used_[fg] == 1) {
|
||||||
(void)func_graph_value_nodes_.erase(node);
|
(void)func_graphs_used_.erase(fg);
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
func_graph_value_nodes_[node]--;
|
func_graphs_used_[fg]--;
|
||||||
if (func_graph_value_nodes_[node] < 0) {
|
if (func_graphs_used_[fg] < 0) {
|
||||||
MS_LOG(EXCEPTION) << "Count of value node(FuncGraph) '" << node
|
MS_LOG(EXCEPTION) << "Count of FuncGraph '" << fg
|
||||||
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
|
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -354,11 +365,13 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
|
||||||
|
|
||||||
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }
|
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() { return func_graph_cnodes_index_; }
|
||||||
|
|
||||||
void FuncGraph::CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &others) {
|
void FuncGraph::CopyFuncGraphCNodesIndex(const FuncGraphPtr &source) {
|
||||||
auto it = others.begin();
|
auto &others = source->func_graph_cnodes_index();
|
||||||
for (; it != others.end(); it++) {
|
for (auto it = others.begin(); it != others.end(); it++) {
|
||||||
// Ignore the user graph who may own itself.
|
// Ignore the user graph who may own itself.
|
||||||
if (it->first->first->func_graph().get() != this) {
|
auto fg = it->first->first->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
if (fg.get() != this) {
|
||||||
AddFuncGraphCNodeIndex(it->first, it->second);
|
AddFuncGraphCNodeIndex(it->first, it->second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -388,28 +401,33 @@ void FuncGraph::DropFuncGraphCNodeIndex(CNodeIndexPairPtr pair) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const AnfNodeCounterMap &FuncGraph::j_func_graph_value_nodes() { return j_func_graph_value_nodes_; }
|
const FuncGraphCounterMap &FuncGraph::j_func_graphs() { return j_func_graphs_; }
|
||||||
|
|
||||||
void FuncGraph::CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others) { j_func_graph_value_nodes_ = others; }
|
void FuncGraph::CopyJFuncGraphs(const FuncGraphPtr &source) {
|
||||||
|
auto &others = source->j_func_graphs();
|
||||||
void FuncGraph::ClearJFuncGraphValueNodes() { j_func_graph_value_nodes_.clear(); }
|
for (auto it = others.begin(); it != others.end(); it++) {
|
||||||
|
AddJFuncGraph(it->first, it->second);
|
||||||
void FuncGraph::AddJFuncGraphValueNode(AnfNodePtr node, int count) {
|
|
||||||
if (j_func_graph_value_nodes_.count(node) == 0) {
|
|
||||||
j_func_graph_value_nodes_[node] = count;
|
|
||||||
} else {
|
|
||||||
j_func_graph_value_nodes_[node] += count;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FuncGraph::DropJFuncGraphValueNode(AnfNodePtr node) {
|
void FuncGraph::ClearJFuncGraphs() { j_func_graphs_.clear(); }
|
||||||
if (j_func_graph_value_nodes_.count(node) != 0) {
|
|
||||||
if (j_func_graph_value_nodes_[node] == 1) {
|
void FuncGraph::AddJFuncGraph(FuncGraphPtr fg, int count) {
|
||||||
(void)j_func_graph_value_nodes_.erase(node);
|
if (j_func_graphs_.count(fg) == 0) {
|
||||||
|
j_func_graphs_[fg] = count;
|
||||||
|
} else {
|
||||||
|
j_func_graphs_[fg] += count;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void FuncGraph::DropJFuncGraph(FuncGraphPtr fg) {
|
||||||
|
if (j_func_graphs_.count(fg) != 0) {
|
||||||
|
if (j_func_graphs_[fg] == 1) {
|
||||||
|
(void)j_func_graphs_.erase(fg);
|
||||||
} else {
|
} else {
|
||||||
j_func_graph_value_nodes_[node]--;
|
j_func_graphs_[fg]--;
|
||||||
if (j_func_graph_value_nodes_[node] < 0) {
|
if (j_func_graphs_[fg] < 0) {
|
||||||
MS_LOG(EXCEPTION) << "Count of value node(J FuncGraph) '" << node
|
MS_LOG(EXCEPTION) << "Count of J FuncGraph '" << fg
|
||||||
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
|
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(debug_info());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -189,21 +189,21 @@ class FuncGraph : public FuncGraphBase {
|
||||||
|
|
||||||
// get all nodes belonging to this func graph
|
// get all nodes belonging to this func graph
|
||||||
const AnfNodeSet &nodes();
|
const AnfNodeSet &nodes();
|
||||||
void CopyNodes(const AnfNodeSet &other_nodes);
|
void CopyNodes(const FuncGraphPtr &source);
|
||||||
void ClearNodes();
|
void ClearNodes();
|
||||||
void AddNode(AnfNodePtr node);
|
void AddNode(AnfNodePtr node);
|
||||||
void DropNode(AnfNodePtr node);
|
void DropNode(AnfNodePtr node);
|
||||||
|
|
||||||
// get all value_nodes belonging to this func graph
|
// get all value_nodes belonging to this func graph
|
||||||
const AnfNodeCounterMap &value_nodes();
|
const AnfNodeCounterMap &value_nodes();
|
||||||
void CopyValueNodes(const AnfNodeCounterMap &other_value_nodes);
|
void CopyValueNodes(const FuncGraphPtr &source);
|
||||||
void ClearValueNodes();
|
void ClearValueNodes();
|
||||||
void AddValueNode(AnfNodePtr node, int count = 1);
|
void AddValueNode(AnfNodePtr node, int count = 1);
|
||||||
void DropValueNode(AnfNodePtr node);
|
void DropValueNode(AnfNodePtr node);
|
||||||
|
|
||||||
// get all free vars directly used in this func graph
|
// get all free vars directly used in this func graph
|
||||||
const AnfNodeCounterMap &free_variables();
|
const AnfNodeCounterMap &free_variables();
|
||||||
void CopyFreeVariables(const AnfNodeCounterMap &others);
|
void CopyFreeVariables(const FuncGraphPtr &source);
|
||||||
void ClearFreeVariables();
|
void ClearFreeVariables();
|
||||||
bool AddFreeVariable(AnfNodePtr node, int count = 1);
|
bool AddFreeVariable(AnfNodePtr node, int count = 1);
|
||||||
bool DropFreeVariable(AnfNodePtr node);
|
bool DropFreeVariable(AnfNodePtr node);
|
||||||
|
@ -218,25 +218,25 @@ class FuncGraph : public FuncGraphBase {
|
||||||
std::vector<FuncGraphPtr> free_variables_func_graphs();
|
std::vector<FuncGraphPtr> free_variables_func_graphs();
|
||||||
|
|
||||||
// get all value nodes of func graph directly used by this func graph
|
// get all value nodes of func graph directly used by this func graph
|
||||||
const AnfNodeCounterMap &func_graph_value_nodes();
|
const FuncGraphCounterMap &func_graphs_used();
|
||||||
void CopyFuncGraphValueNodes(const AnfNodeCounterMap &others);
|
void CopyFuncGraphsUsed(const FuncGraphPtr &source);
|
||||||
void ClearFuncGraphValueNodes();
|
void ClearFuncGraphsUsed();
|
||||||
bool AddFuncGraphValueNode(AnfNodePtr node, int count = 1);
|
bool AddFuncGraphUsed(FuncGraphPtr fg, int count = 1);
|
||||||
bool DropFuncGraphValueNode(AnfNodePtr node);
|
bool DropFuncGraphUsed(FuncGraphPtr fg);
|
||||||
|
|
||||||
// get all value nodes of J func graph directly used by this func graph
|
// get all value nodes of J func graph directly used by this func graph
|
||||||
const AnfNodeCounterMap &j_func_graph_value_nodes();
|
const FuncGraphCounterMap &j_func_graphs();
|
||||||
void CopyJFuncGraphValueNodes(const AnfNodeCounterMap &others);
|
void CopyJFuncGraphs(const FuncGraphPtr &source);
|
||||||
void ClearJFuncGraphValueNodes();
|
void ClearJFuncGraphs();
|
||||||
void AddJFuncGraphValueNode(AnfNodePtr node, int count = 1);
|
void AddJFuncGraph(FuncGraphPtr fg, int count = 1);
|
||||||
void DropJFuncGraphValueNode(AnfNodePtr node);
|
void DropJFuncGraph(FuncGraphPtr fg);
|
||||||
|
|
||||||
// get all func graphs nested used by this func graph
|
// get all func graphs nested used by this func graph
|
||||||
const FuncGraphSet &func_graphs_used_total();
|
const FuncGraphSet &func_graphs_used_total();
|
||||||
|
|
||||||
// get all user value nodes of this func graph, by CNode and its input's index
|
// get all user value nodes of this func graph, by CNode and its input's index
|
||||||
const CNodeIndexCounterMap &func_graph_cnodes_index();
|
const CNodeIndexCounterMap &func_graph_cnodes_index();
|
||||||
void CopyFuncGraphCNodesIndex(const CNodeIndexCounterMap &other_value_nodes);
|
void CopyFuncGraphCNodesIndex(const FuncGraphPtr &source);
|
||||||
void ClearFuncGraphCNodesIndex();
|
void ClearFuncGraphCNodesIndex();
|
||||||
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
|
void AddFuncGraphCNodeIndex(CNodeIndexPairPtr node, int count = 1);
|
||||||
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
|
void DropFuncGraphCNodeIndex(CNodeIndexPairPtr node);
|
||||||
|
@ -311,13 +311,13 @@ class FuncGraph : public FuncGraphBase {
|
||||||
AnfNodeCounterMap value_nodes_;
|
AnfNodeCounterMap value_nodes_;
|
||||||
|
|
||||||
// all func graph value nodes of the function
|
// all func graph value nodes of the function
|
||||||
AnfNodeCounterMap func_graph_value_nodes_;
|
FuncGraphCounterMap func_graphs_used_;
|
||||||
|
|
||||||
// all free variables of the function
|
// all free variables of the function
|
||||||
AnfNodeCounterMap free_variables_;
|
AnfNodeCounterMap free_variables_;
|
||||||
|
|
||||||
// all value nodes calling J in the function
|
// all value nodes calling J in the function
|
||||||
AnfNodeCounterMap j_func_graph_value_nodes_;
|
FuncGraphCounterMap j_func_graphs_;
|
||||||
|
|
||||||
// all user value nodes of this func graph, recording by CNode and its input's index
|
// all user value nodes of this func graph, recording by CNode and its input's index
|
||||||
CNodeIndexCounterMap func_graph_cnodes_index_;
|
CNodeIndexCounterMap func_graph_cnodes_index_;
|
||||||
|
|
|
@ -153,9 +153,9 @@ void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) {
|
||||||
if (!clone_all_used_graphs_) {
|
if (!clone_all_used_graphs_) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto &used = func_graph->func_graph_value_nodes();
|
auto &used = func_graph->func_graphs_used();
|
||||||
for (auto &fg_value_node : used) {
|
for (auto &fg : used) {
|
||||||
todo_.push_back({GetValueNode<FuncGraphPtr>(fg_value_node.first), nullptr, {}});
|
todo_.push_back({fg.first, nullptr, {}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -196,7 +196,6 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
AddIntoManaged(func_graph);
|
AddIntoManaged(func_graph);
|
||||||
MS_EXCEPTION_IF_NULL(signals_);
|
|
||||||
std::vector<AnfNodePtr> para = func_graph->parameters();
|
std::vector<AnfNodePtr> para = func_graph->parameters();
|
||||||
AcquireNodes(para);
|
AcquireNodes(para);
|
||||||
std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
|
std::vector<AnfNodePtr> return_vec({func_graph->get_return()});
|
||||||
|
@ -301,7 +300,6 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
|
||||||
std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
|
std::vector<AnfNodePtr> return_vec = {func_graph->get_return()};
|
||||||
todo.update(MaybeDropNodes(return_vec));
|
todo.update(MaybeDropNodes(return_vec));
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(signals_);
|
|
||||||
for (auto &fg : dropped) {
|
for (auto &fg : dropped) {
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
all_nodes_.difference_update(fg->parameters());
|
all_nodes_.difference_update(fg->parameters());
|
||||||
|
@ -334,7 +332,6 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E
|
||||||
}
|
}
|
||||||
auto &users_node = node_users_[inp];
|
auto &users_node = node_users_[inp];
|
||||||
users_node.add(make_pair(node, index));
|
users_node.add(make_pair(node, index));
|
||||||
MS_EXCEPTION_IF_NULL(signals_);
|
|
||||||
AddEdge(node, index, inp);
|
AddEdge(node, index, inp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -384,8 +381,6 @@ void FuncGraphManager::AcquireNodes(const std::vector<AnfNodePtr> &nodes) {
|
||||||
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
|
FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> &nodes) {
|
||||||
AnfNodeSet nodes_ordered(nodes);
|
AnfNodeSet nodes_ordered(nodes);
|
||||||
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
|
FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>();
|
||||||
MS_EXCEPTION_IF_NULL(signals_);
|
|
||||||
|
|
||||||
while (!nodes_ordered.empty()) {
|
while (!nodes_ordered.empty()) {
|
||||||
AnfNodePtr node = nodes_ordered.pop();
|
AnfNodePtr node = nodes_ordered.pop();
|
||||||
MS_EXCEPTION_IF_NULL(node);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
@ -475,13 +470,13 @@ inline void FuncGraphManager::AddEdge(AnfNodePtr node, int index, AnfNodePtr inp
|
||||||
if (input->isa<ValueNode>()) {
|
if (input->isa<ValueNode>()) {
|
||||||
fg->AddValueNode(input);
|
fg->AddValueNode(input);
|
||||||
if (IsValueNode<FuncGraph>(input)) {
|
if (IsValueNode<FuncGraph>(input)) {
|
||||||
if (fg->AddFuncGraphValueNode(input)) {
|
|
||||||
signals_->InvalidateComputer();
|
|
||||||
}
|
|
||||||
auto used = GetValueNode<FuncGraphPtr>(input);
|
auto used = GetValueNode<FuncGraphPtr>(input);
|
||||||
used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
|
used->AddFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
|
||||||
|
if (fg->AddFuncGraphUsed(used)) {
|
||||||
|
signals_->InvalidateComputer();
|
||||||
|
}
|
||||||
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
|
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
|
||||||
fg->AddJFuncGraphValueNode(input);
|
fg->AddJFuncGraph(used);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (fg != nullptr && fg != input->func_graph()) {
|
} else if (fg != nullptr && fg != input->func_graph()) {
|
||||||
|
@ -496,13 +491,13 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
|
||||||
if (input->isa<ValueNode>()) {
|
if (input->isa<ValueNode>()) {
|
||||||
fg->DropValueNode(input);
|
fg->DropValueNode(input);
|
||||||
if (IsValueNode<FuncGraph>(input)) {
|
if (IsValueNode<FuncGraph>(input)) {
|
||||||
if (fg->DropFuncGraphValueNode(input)) {
|
|
||||||
signals_->InvalidateComputer();
|
|
||||||
}
|
|
||||||
auto used = GetValueNode<FuncGraphPtr>(input);
|
auto used = GetValueNode<FuncGraphPtr>(input);
|
||||||
used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
|
used->DropFuncGraphCNodeIndex(std::make_shared<CNodeIndexPair>(std::make_pair(node, index)));
|
||||||
|
if (fg->DropFuncGraphUsed(used)) {
|
||||||
|
signals_->InvalidateComputer();
|
||||||
|
}
|
||||||
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
|
if (IsPrimitiveCNode(node, prim::kPrimJ)) {
|
||||||
fg->DropJFuncGraphValueNode(input);
|
fg->DropJFuncGraph(used);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (fg != nullptr && fg != input->func_graph()) {
|
} else if (fg != nullptr && fg != input->func_graph()) {
|
||||||
|
@ -513,19 +508,19 @@ inline void FuncGraphManager::DropEdge(AnfNodePtr node, int index, AnfNodePtr in
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
|
inline void FuncGraphManager::MoveAllNodes(FuncGraphPtr source, FuncGraphPtr target) {
|
||||||
target->CopyNodes(source->nodes());
|
target->CopyNodes(source);
|
||||||
target->CopyValueNodes(source->value_nodes());
|
target->CopyValueNodes(source);
|
||||||
target->CopyFuncGraphCNodesIndex(source->func_graph_cnodes_index());
|
target->CopyFuncGraphCNodesIndex(source);
|
||||||
target->CopyFreeVariables(source->free_variables());
|
target->CopyFreeVariables(source);
|
||||||
target->CopyFuncGraphValueNodes(source->func_graph_value_nodes());
|
target->CopyFuncGraphsUsed(source);
|
||||||
target->CopyJFuncGraphValueNodes(source->j_func_graph_value_nodes());
|
target->CopyJFuncGraphs(source);
|
||||||
signals_->InvalidateComputer();
|
signals_->InvalidateComputer();
|
||||||
source->ClearNodes();
|
source->ClearNodes();
|
||||||
source->ClearValueNodes();
|
source->ClearValueNodes();
|
||||||
source->ClearFuncGraphCNodesIndex();
|
source->ClearFuncGraphCNodesIndex();
|
||||||
source->ClearFreeVariables();
|
source->ClearFreeVariables();
|
||||||
source->ClearFuncGraphValueNodes();
|
source->ClearFuncGraphsUsed();
|
||||||
source->ClearJFuncGraphValueNodes();
|
source->ClearJFuncGraphs();
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphTransaction FuncGraphManager::Transact() {
|
FuncGraphTransaction FuncGraphManager::Transact() {
|
||||||
|
@ -768,10 +763,10 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search the fv in fg's child func graph.
|
// Search the fv in fg's child func graph.
|
||||||
auto &fg_value_nodes = fg->func_graph_value_nodes();
|
auto &fgs = fg->func_graphs_used();
|
||||||
for (auto &fg_value_node : fg_value_nodes) {
|
for (auto &item : fgs) {
|
||||||
fg->seen_ = seen_num;
|
fg->seen_ = seen_num;
|
||||||
auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first);
|
auto gt = item.first;
|
||||||
parents->update(SeekParents(gt, seen_num));
|
parents->update(SeekParents(gt, seen_num));
|
||||||
}
|
}
|
||||||
(void)parents->erase(fg);
|
(void)parents->erase(fg);
|
||||||
|
@ -865,15 +860,15 @@ void FVTotalComputer::RealRecompute() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto &used = fg->func_graph_value_nodes();
|
auto &used = fg->func_graphs_used();
|
||||||
for (auto &iter : used) {
|
for (auto &iter : used) {
|
||||||
auto p = manager->parent(GetValueNode<FuncGraphPtr>(iter.first));
|
auto p = manager->parent(iter.first);
|
||||||
if (p == nullptr) {
|
if (p == nullptr) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
auto curr = fg;
|
auto curr = fg;
|
||||||
while (curr != p) {
|
while (curr != p) {
|
||||||
(void)CounterFuncGraphCollector::Mod(curr, GetValueNode<FuncGraphPtr>(iter.first), iter.second);
|
(void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second);
|
||||||
curr = manager->parent(curr);
|
curr = manager->parent(curr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -899,8 +894,8 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
||||||
while (!todo.empty()) {
|
while (!todo.empty()) {
|
||||||
todo_new.clear();
|
todo_new.clear();
|
||||||
for (auto > : todo) {
|
for (auto > : todo) {
|
||||||
for (auto &item : gt->func_graph_value_nodes()) {
|
for (auto &item : gt->func_graphs_used()) {
|
||||||
auto used_fg = GetValueNode<FuncGraphPtr>(item.first);
|
auto used_fg = item.first;
|
||||||
if (used_fg == fg) {
|
if (used_fg == fg) {
|
||||||
func_graph_used_total_analysis_[fg].add(used_fg);
|
func_graph_used_total_analysis_[fg].add(used_fg);
|
||||||
continue;
|
continue;
|
||||||
|
@ -925,8 +920,8 @@ bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &f
|
||||||
while (!todo.empty()) {
|
while (!todo.empty()) {
|
||||||
todo_new.clear();
|
todo_new.clear();
|
||||||
for (auto > : todo) {
|
for (auto > : todo) {
|
||||||
for (auto &item : gt->func_graph_value_nodes()) {
|
for (auto &item : gt->func_graphs_used()) {
|
||||||
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
|
auto used_g = item.first;
|
||||||
if (used_g == fg) {
|
if (used_g == fg) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -957,9 +952,9 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
trace->push_back(fg);
|
trace->push_back(fg);
|
||||||
auto &items = fg->func_graph_value_nodes();
|
auto &items = fg->func_graphs_used();
|
||||||
for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
|
for (auto iter = items.begin(); iter != items.end(); (void)iter++) {
|
||||||
CheckRecursiveGraphs(GetValueNode<FuncGraphPtr>(iter->first), trace);
|
CheckRecursiveGraphs(iter->first, trace);
|
||||||
}
|
}
|
||||||
trace->pop_back();
|
trace->pop_back();
|
||||||
if (!recursive_map_.count(fg)) {
|
if (!recursive_map_.count(fg)) {
|
||||||
|
@ -973,14 +968,13 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
|
||||||
MS_LOG(DEBUG) << fg->ToString() << " had been checked";
|
MS_LOG(DEBUG) << fg->ToString() << " had been checked";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto &j_fg_value_nodes = fg->j_func_graph_value_nodes();
|
auto &j_fgs = fg->j_func_graphs();
|
||||||
if (!j_fg_value_nodes.empty()) {
|
if (!j_fgs.empty()) {
|
||||||
// check g1->J(fg)->g2->g cycle;
|
// check g1->J(fg)->g2->g cycle;
|
||||||
auto contains_j =
|
auto contains_j = std::find_if(j_fgs.begin(), j_fgs.end(), [seen_num](const std::pair<FuncGraphPtr, int> iter) {
|
||||||
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> iter) {
|
return iter.first->seen_ != seen_num;
|
||||||
return GetValueNode<FuncGraphPtr>(iter.first)->seen_ != seen_num;
|
});
|
||||||
});
|
if (contains_j != j_fgs.end()) {
|
||||||
if (contains_j != j_fg_value_nodes.end()) {
|
|
||||||
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
|
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -988,8 +982,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
|
||||||
fg->seen_ = seen_num;
|
fg->seen_ = seen_num;
|
||||||
|
|
||||||
// check if func graphs used contains J(func_graph);
|
// check if func graphs used contains J(func_graph);
|
||||||
for (auto &item : fg->func_graph_value_nodes()) {
|
for (auto &item : fg->func_graphs_used()) {
|
||||||
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
|
auto used_g = item.first;
|
||||||
if (SeekJ(used_g, seen_num)) {
|
if (SeekJ(used_g, seen_num)) {
|
||||||
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
|
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -2187,7 +2187,7 @@ void MarkForwardCNode(const FuncGraphPtr &root) {
|
||||||
SetForwardFlag(all_nodes);
|
SetForwardFlag(all_nodes);
|
||||||
} else {
|
} else {
|
||||||
for (auto &func_graph : graph_set) {
|
for (auto &func_graph : graph_set) {
|
||||||
MS_LOG(INFO) << "The sub graph size of root is " << root->func_graph_value_nodes().size();
|
MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size();
|
||||||
auto return_node = func_graph->get_return();
|
auto return_node = func_graph->get_return();
|
||||||
MS_EXCEPTION_IF_NULL(return_node);
|
MS_EXCEPTION_IF_NULL(return_node);
|
||||||
auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
|
auto all_dfs_nodes = DeepLinkedGraphSearch(return_node);
|
||||||
|
|
|
@ -462,8 +462,8 @@ TEST_F(TestManager, test_nested_manual) {
|
||||||
ASSERT_EQ(1, iter.second.size());
|
ASSERT_EQ(1, iter.second.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
ASSERT_EQ(1, f->func_graph_value_nodes().size());
|
ASSERT_EQ(1, f->func_graphs_used().size());
|
||||||
ASSERT_EQ(0, g->func_graph_value_nodes().size());
|
ASSERT_EQ(0, g->func_graphs_used().size());
|
||||||
|
|
||||||
ASSERT_EQ(0, f->free_variables().size());
|
ASSERT_EQ(0, f->free_variables().size());
|
||||||
ASSERT_EQ(1, g->free_variables().size());
|
ASSERT_EQ(1, g->free_variables().size());
|
||||||
|
|
Loading…
Reference in New Issue