forked from mindspore-Ecosystem/mindspore
!847 Optimize the collectors in manager which listen to the graphs and nodes changes.
Merge pull request !847 from ZhangQinghua/opt
This commit is contained in:
commit
db3dd02838
|
@ -263,18 +263,15 @@ const FuncGraphSet &FuncGraph::func_graphs_used_total() {
|
|||
return used;
|
||||
}
|
||||
|
||||
const FuncGraphCounterMap &FuncGraph::func_graph_users() {
|
||||
const CNodeIndexCounterMap &FuncGraph::func_graph_cnodes_index() {
|
||||
auto mng = manager_.lock();
|
||||
if (mng == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "BUG: no manager for this func graph: " << ToString()
|
||||
<< " NodeInfo: " << trace::GetDebugInfo(debug_info());
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto &users = mng->func_graph_users();
|
||||
return users[shared_from_base<FuncGraph>()];
|
||||
}
|
||||
|
||||
const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() {
|
||||
auto mng = manager_.lock();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
auto &users = mng->func_graph_user_cnodes();
|
||||
return users[shared_from_base<FuncGraph>()];
|
||||
auto &cnode = mng->func_graph_cnodes_index();
|
||||
return cnode[shared_from_base<FuncGraph>()];
|
||||
}
|
||||
|
||||
FuncGraphPtr FuncGraph::parent() {
|
||||
|
|
|
@ -37,6 +37,7 @@ namespace mindspore {
|
|||
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
|
||||
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
|
||||
using AnfNodeCounterMap = OrderedMap<AnfNodePtr, int>;
|
||||
using CNodeIndexCounterMap = OrderedMap<CNodeIndexPairPtr, int, CNodeIndexHasher, CNodeIndexEqual>;
|
||||
|
||||
const char FUNC_GRAPH_FLAG_IGNORE_VALUES[] = "ignore_values";
|
||||
const char FUNC_GRAPH_FLAG_DEFER_INLINE[] = "defer_inline";
|
||||
|
@ -203,11 +204,8 @@ class FuncGraph : public FuncGraphBase {
|
|||
// get all func graphs nested used by this func graph
|
||||
const FuncGraphSet &func_graphs_used_total();
|
||||
|
||||
// get all users of this func graph
|
||||
const FuncGraphCounterMap &func_graph_users();
|
||||
|
||||
// get all user cnodes of this func graph
|
||||
const AnfNodeCounterMap &func_graph_user_cnodes();
|
||||
// get all user value nodes of this func graph
|
||||
const CNodeIndexCounterMap &func_graph_cnodes_index();
|
||||
|
||||
// Return the parent of this graph.
|
||||
FuncGraphPtr parent();
|
||||
|
|
|
@ -182,9 +182,11 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
|
|||
}
|
||||
target_func_graph->set_return(return_node);
|
||||
|
||||
auto &value_nodes = manager_->func_graph_valuenodes()[func_graph];
|
||||
for (auto &value_node : value_nodes) {
|
||||
CloneValueNode(value_node.first, target_func_graph);
|
||||
auto &cnodes = manager_->func_graph_cnodes_index()[func_graph];
|
||||
for (auto &cnode : cnodes) {
|
||||
auto parent = cnode.first->first->cast<CNodePtr>();
|
||||
auto valuenode = parent->input(cnode.first->second);
|
||||
CloneValueNode(valuenode, target_func_graph);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -386,8 +388,8 @@ void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraph
|
|||
if (lift_params.empty()) {
|
||||
return;
|
||||
}
|
||||
for (auto &user : func_graph_user->func_graph_users()) {
|
||||
LiftParameters(user.first, func_graph_user, lift_params);
|
||||
for (auto &cnode : func_graph_user->func_graph_cnodes_index()) {
|
||||
LiftParameters(cnode.first->first->func_graph(), func_graph_user, lift_params);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -395,8 +397,8 @@ void Cloner::Lift() {
|
|||
for (auto &func_graph_params : repl_func_graph_params_) {
|
||||
auto &func_graph = func_graph_params.first;
|
||||
auto ¶ms = func_graph_params.second;
|
||||
for (auto &user : func_graph->func_graph_users()) {
|
||||
LiftParameters(user.first, func_graph, params);
|
||||
for (auto &cnode : func_graph->func_graph_cnodes_index()) {
|
||||
LiftParameters(cnode.first->first->func_graph(), func_graph, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,13 +78,16 @@ void FuncGraphManager::Reset() {
|
|||
node_users_ = NodeUsersMap();
|
||||
|
||||
signals_ = std::make_shared<Signals>();
|
||||
// FuncGraph --> AnfNode
|
||||
nodes_ = std::make_shared<NodesCollector>(this);
|
||||
|
||||
// FuncGraph --> {AnfNode, Count}
|
||||
valuenodes_ = std::make_shared<ValueNodesCollector>(this);
|
||||
free_variables_direct_ = std::make_shared<FVDirectCollector>(this);
|
||||
func_graph_valuenodes_ = std::make_shared<FuncGraphValueNodesCollector>(this);
|
||||
func_graph_cnodes_index_ = std::make_shared<FuncGraphUsersCNodeIndexCollector>(this);
|
||||
|
||||
// FuncGraph --> {FuncGraph, Count}
|
||||
func_graphs_used_ = std::make_shared<FuncGraphsUsedCollector>(this);
|
||||
func_graph_users_ = std::make_shared<FuncGraphUsersCollector>(this);
|
||||
func_graph_user_cnodes_ = std::make_shared<FuncGraphUserNodesCollector>(this);
|
||||
func_graph_child_direct_ = std::make_shared<FuncGraphChildDirect>(this);
|
||||
func_graph_parents_direct_ = std::make_shared<FuncGraphParentsDirectCollector>(this);
|
||||
func_graph_j_direct_ = std::make_shared<FuncGraphJDirectCollector>(this);
|
||||
|
@ -300,9 +303,9 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool
|
|||
MS_LOG(DEBUG) << "Cannot drop as roots contains func graph: " << func_graph->ToString();
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(func_graph_users_);
|
||||
auto &users = func_graph_users_->count_func_graphs_map()[func_graph];
|
||||
if (!users.empty() && !ignore_users) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph_cnodes_index_);
|
||||
auto &users_cnode_index = func_graph_cnodes_index_->count_nodes_map()[func_graph];
|
||||
if (!users_cnode_index.empty() && !ignore_users) {
|
||||
MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString();
|
||||
continue;
|
||||
}
|
||||
|
@ -472,10 +475,6 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t
|
|||
node->set_scope(scope);
|
||||
}
|
||||
}
|
||||
for (auto &used : source->func_graphs_used()) {
|
||||
(void)func_graph_users_->Inc(used.first, target, used.second);
|
||||
(void)this->func_graph_users()[used.first].erase(source);
|
||||
}
|
||||
for (auto &child : this->func_graph_child_direct()[source]) {
|
||||
(void)func_graph_parents_direct_->Inc(child.first, target, child.second);
|
||||
(void)this->func_graph_parents_direct()[child.first].erase(source);
|
||||
|
@ -661,7 +660,9 @@ DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAna
|
|||
|
||||
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
|
||||
|
||||
bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
|
||||
template <typename ValueT, class CollectorHash, class CollectorEqual>
|
||||
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph,
|
||||
const ValueT &key, int count) {
|
||||
auto &d = count_nodes_map_[func_graph];
|
||||
if (d.count(key) == 0) {
|
||||
d[key] = count;
|
||||
|
@ -672,7 +673,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodeP
|
|||
return false;
|
||||
}
|
||||
|
||||
bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) {
|
||||
template <typename ValueT, class CollectorHash, class CollectorEqual>
|
||||
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph,
|
||||
const ValueT &key, int count) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto &d = count_nodes_map_[func_graph];
|
||||
if (d.count(key) != 0) {
|
||||
|
@ -682,7 +685,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP
|
|||
} else {
|
||||
d[key] -= count;
|
||||
if (d[key] < 0) {
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key
|
||||
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
}
|
||||
|
@ -690,17 +693,78 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodeP
|
|||
return false;
|
||||
}
|
||||
|
||||
bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) {
|
||||
template <typename ValueT, class CollectorHash, class CollectorEqual>
|
||||
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph,
|
||||
const ValueT &key, int count) {
|
||||
if (count > 0) {
|
||||
return Inc(func_graph, key, count);
|
||||
} else if (count < 0) {
|
||||
return Dec(func_graph, key, -count);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key
|
||||
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
}
|
||||
|
||||
void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (inp->isa<ValueNode>()) {
|
||||
(void)Mod(node->func_graph(), inp, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp,
|
||||
EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsValueNode<FuncGraph>(inp)) {
|
||||
(void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)),
|
||||
direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
// Ignore the user graph who may own itself.
|
||||
if (dst != it.first->first->func_graph()) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
FuncGraphPtr fg1 = node->func_graph();
|
||||
FuncGraphPtr fg2 = inp->func_graph();
|
||||
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
|
||||
(void)Mod(fg1, inp, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
FuncGraphPtr fg2 = it.first->func_graph();
|
||||
if (fg2 != dst) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
|
||||
FuncGraphPtr gn = std::make_shared<FuncGraph>();
|
||||
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
|
||||
return gn;
|
||||
}
|
||||
|
||||
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
|
||||
auto &d = count_func_graphs_map_[func_graph];
|
||||
if (d.count(key) == 0) {
|
||||
|
@ -740,60 +804,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr
|
|||
}
|
||||
}
|
||||
|
||||
void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (inp->isa<ValueNode>()) {
|
||||
(void)Mod(node->func_graph(), inp, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
// if inp is a graph ValueNode, this graph's FuncGraphValueNodesCollector's value is inp self
|
||||
void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
if (IsValueNode<FuncGraph>(inp)) {
|
||||
(void)Mod(GetValueNode<FuncGraphPtr>(inp), inp, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
FuncGraphPtr fg1 = node->func_graph();
|
||||
FuncGraphPtr fg2 = inp->func_graph();
|
||||
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
|
||||
(void)Mod(fg1, inp, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
FuncGraphPtr fg2 = it.first->func_graph();
|
||||
if (fg2 != dst) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
|
||||
FuncGraphPtr gn = std::make_shared<FuncGraph>();
|
||||
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
|
||||
return gn;
|
||||
}
|
||||
|
||||
void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
|
@ -859,32 +869,6 @@ void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst)
|
|||
(void)count_func_graphs_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphUsersCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsValueNode<FuncGraph>(inp)) {
|
||||
(void)Mod(GetValueNode<FuncGraphPtr>(inp), node->func_graph(), direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphUsersCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr) {
|
||||
// all graph use in src need to change to dst, so add dst user
|
||||
(void)count_func_graphs_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsValueNode<FuncGraph>(inp)) {
|
||||
(void)Mod(GetValueNode<FuncGraphPtr>(inp), node, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) {
|
||||
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
|
||||
|
|
|
@ -100,8 +100,12 @@ struct Signals {
|
|||
|
||||
enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 };
|
||||
|
||||
using CNodeIndexPair = std::pair<AnfNodePtr, int>;
|
||||
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
|
||||
|
||||
using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>;
|
||||
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<AnfNodePtr, int>>;
|
||||
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
|
||||
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>;
|
||||
|
||||
// analysis base class
|
||||
class FuncGraphAnalysis {
|
||||
|
@ -174,6 +178,87 @@ class NodesCollector final : public DepCollector {
|
|||
void OnDropNode(AnfNodePtr n) override;
|
||||
};
|
||||
|
||||
struct CNodeIndexHasher {
|
||||
std::size_t operator()(const CNodeIndexPairPtr pair) const {
|
||||
MS_EXCEPTION_IF_NULL(pair);
|
||||
MS_EXCEPTION_IF_NULL(pair->first);
|
||||
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
|
||||
}
|
||||
};
|
||||
|
||||
struct CNodeIndexEqual {
|
||||
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
|
||||
if (lhs == nullptr || rhs == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (lhs == rhs) {
|
||||
return true;
|
||||
}
|
||||
if (lhs->first != rhs->first) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->second != rhs->second) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
|
||||
class CounterAnfNodeCollector : public DepCollector {
|
||||
public:
|
||||
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||
~CounterAnfNodeCollector() override = default;
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
|
||||
|
||||
size_t size() const override { return count_nodes_map_.size(); }
|
||||
void OnAddFuncGraph(FuncGraphPtr fg) final {
|
||||
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
|
||||
}
|
||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
|
||||
|
||||
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
|
||||
|
||||
protected:
|
||||
void ExtraReset() override { count_nodes_map_.clear(); }
|
||||
};
|
||||
|
||||
class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
|
||||
public:
|
||||
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~ValueNodesCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
// Record the CNode and its input index, who points to the function graph.
|
||||
class FuncGraphUsersCNodeIndexCollector final
|
||||
: public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> {
|
||||
public:
|
||||
explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~FuncGraphUsersCNodeIndexCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
|
||||
public:
|
||||
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~FVDirectCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class CounterFuncGraphCollector : public DepCollector {
|
||||
public:
|
||||
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||
|
@ -193,56 +278,6 @@ class CounterFuncGraphCollector : public DepCollector {
|
|||
void ExtraReset() override { count_func_graphs_map_.clear(); }
|
||||
};
|
||||
|
||||
class CounterAnfNodeCollector : public DepCollector {
|
||||
public:
|
||||
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||
~CounterAnfNodeCollector() override = default;
|
||||
FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; }
|
||||
|
||||
size_t size() const override { return count_nodes_map_.size(); }
|
||||
void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>(); }
|
||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
|
||||
|
||||
bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
|
||||
bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
|
||||
bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count);
|
||||
|
||||
FuncGraphToAnfNodeCounterMap count_nodes_map_;
|
||||
|
||||
protected:
|
||||
void ExtraReset() override { count_nodes_map_.clear(); }
|
||||
};
|
||||
|
||||
class ValueNodesCollector final : public CounterAnfNodeCollector {
|
||||
public:
|
||||
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~ValueNodesCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector {
|
||||
public:
|
||||
explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~FuncGraphValueNodesCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class FVDirectCollector final : public CounterAnfNodeCollector {
|
||||
public:
|
||||
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~FVDirectCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class FuncGraphChildDirect final : public CounterFuncGraphCollector {
|
||||
public:
|
||||
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||
|
@ -279,28 +314,6 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
|
|||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
// graph's all user graphs: key is g, value is graphs who used g
|
||||
class FuncGraphUsersCollector final : public CounterFuncGraphCollector {
|
||||
public:
|
||||
explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
~FuncGraphUsersCollector() override = default;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
// graph's all user cnodes: key is g, value is cnodes who used g
|
||||
class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector {
|
||||
public:
|
||||
explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
~FuncGraphUserNodesCollector() override = default;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
|
||||
public:
|
||||
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||
|
@ -433,7 +446,9 @@ class ScopeComputer final : public DepComputer {
|
|||
|
||||
using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
|
||||
|
||||
class FVTotalComputer final : public DepComputer, public CounterAnfNodeCollector, public CounterFuncGraphCollector {
|
||||
class FVTotalComputer final : public DepComputer,
|
||||
public CounterAnfNodeCollector<AnfNodePtr>,
|
||||
public CounterFuncGraphCollector {
|
||||
public:
|
||||
explicit FVTotalComputer(const FuncGraphManager *m)
|
||||
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
|
||||
|
@ -549,18 +564,18 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
|||
|
||||
FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; }
|
||||
|
||||
FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; }
|
||||
FuncGraphToAnfNodeCounterMap<AnfNodePtr> &valuenodes() const { return valuenodes_->count_nodes_map_; }
|
||||
|
||||
FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; }
|
||||
FuncGraphToAnfNodeCounterMap<AnfNodePtr> &free_variables_direct() const {
|
||||
return free_variables_direct_->count_nodes_map_;
|
||||
}
|
||||
|
||||
FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; }
|
||||
FuncGraphToAnfNodeCounterMap<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> &func_graph_cnodes_index() const {
|
||||
return func_graph_cnodes_index_->count_nodes_map_;
|
||||
}
|
||||
|
||||
FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; }
|
||||
|
||||
FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; }
|
||||
|
||||
FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; }
|
||||
|
||||
FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const {
|
||||
return func_graph_child_direct_->count_func_graphs_map_;
|
||||
}
|
||||
|
@ -598,10 +613,8 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
|||
std::shared_ptr<NodesCollector> nodes_;
|
||||
std::shared_ptr<ValueNodesCollector> valuenodes_;
|
||||
std::shared_ptr<FVDirectCollector> free_variables_direct_;
|
||||
std::shared_ptr<FuncGraphValueNodesCollector> func_graph_valuenodes_;
|
||||
std::shared_ptr<FuncGraphUsersCNodeIndexCollector> func_graph_cnodes_index_;
|
||||
std::shared_ptr<FuncGraphsUsedCollector> func_graphs_used_;
|
||||
std::shared_ptr<FuncGraphUsersCollector> func_graph_users_;
|
||||
std::shared_ptr<FuncGraphUserNodesCollector> func_graph_user_cnodes_;
|
||||
std::shared_ptr<FuncGraphChildDirect> func_graph_child_direct_;
|
||||
std::shared_ptr<FuncGraphParentsDirectCollector> func_graph_parents_direct_;
|
||||
std::shared_ptr<FuncGraphJDirectCollector> func_graph_j_direct_;
|
||||
|
|
|
@ -81,10 +81,10 @@ bool IsTrivial(const FuncGraphPtr &fg, AnfNodePtr) {
|
|||
}
|
||||
|
||||
bool IsUniqueUse(const FuncGraphPtr &fg, AnfNodePtr) {
|
||||
auto &users = fg->func_graph_users();
|
||||
auto &cnodes = fg->func_graph_cnodes_index();
|
||||
int n_use =
|
||||
std::accumulate(users.begin(), users.end(), 0,
|
||||
[](int sum, const std::pair<const FuncGraphPtr, int> &item) { return sum + item.second; });
|
||||
std::accumulate(cnodes.begin(), cnodes.end(), 0,
|
||||
[](int sum, const std::pair<const CNodeIndexPairPtr, int> &item) { return sum + item.second; });
|
||||
return n_use == 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -486,7 +486,8 @@ void CompileGraph::AddExternal(const LinConvertResult &result) {
|
|||
}
|
||||
|
||||
void TraverseGraphMap(
|
||||
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts,
|
||||
const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr,
|
||||
const FuncGraphToAnfNodeCounterMap<AnfNodePtr> &cts,
|
||||
const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
|
||||
MS_EXCEPTION_IF_NULL(manager_ptr);
|
||||
MS_EXCEPTION_IF_NULL(tr);
|
||||
|
|
|
@ -127,12 +127,18 @@ class NestingSpecs {
|
|||
return;
|
||||
}
|
||||
|
||||
auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector>(results);
|
||||
auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector<AnfNodePtr>>(results);
|
||||
if (counter_p != nullptr) {
|
||||
CheckAnfNodeCounter(counter_p);
|
||||
return;
|
||||
}
|
||||
|
||||
auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
|
||||
if (counter_pair != nullptr) {
|
||||
CheckCNodeIndexPairCounter(counter_pair);
|
||||
return;
|
||||
}
|
||||
|
||||
auto nodes = dynamic_pointer_cast<NodesCollector>(results);
|
||||
if (nodes != nullptr) {
|
||||
CheckNodes(nodes);
|
||||
|
@ -226,7 +232,7 @@ class NestingSpecs {
|
|||
|
||||
// Add CheckNesting function
|
||||
|
||||
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector> results) {
|
||||
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_nodes_map()) {
|
||||
auto key = iter.first;
|
||||
|
@ -252,6 +258,32 @@ class NestingSpecs {
|
|||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_nodes_map()) {
|
||||
auto key = iter.first;
|
||||
auto value = iter.second;
|
||||
if (key == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string k = Name(key);
|
||||
|
||||
std::set<std::string> v;
|
||||
for (auto& node : value) {
|
||||
auto fg = node.first->first;
|
||||
if (!Name(fg).empty()) {
|
||||
v.insert(Name(fg));
|
||||
}
|
||||
}
|
||||
|
||||
if (!v.empty()) {
|
||||
clean_results[k] = v;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_func_graphs_map()) {
|
||||
|
@ -447,9 +479,8 @@ void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
|
|||
ASSERT_EQ(size, mng->free_variables_total().size());
|
||||
ASSERT_EQ(size, mng->valuenodes().size());
|
||||
ASSERT_EQ(size, mng->free_variables_direct().size());
|
||||
ASSERT_EQ(size, mng->func_graph_valuenodes().size());
|
||||
ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
|
||||
ASSERT_EQ(size, mng->func_graph_parents_direct().size());
|
||||
ASSERT_EQ(size, mng->func_graph_users().size());
|
||||
ASSERT_EQ(size, mng->func_graphs_used().size());
|
||||
}
|
||||
|
||||
|
@ -508,10 +539,6 @@ TEST_F(TestManager, test_nested_manual) {
|
|||
ASSERT_EQ(1, graphs_used[f].size());
|
||||
ASSERT_EQ(0, graphs_used[g].size());
|
||||
|
||||
auto graph_users = mng->func_graph_users();
|
||||
ASSERT_EQ(0, graph_users[f].size());
|
||||
ASSERT_EQ(1, graph_users[g].size());
|
||||
|
||||
auto fv_direct = mng->free_variables_direct();
|
||||
ASSERT_EQ(0, fv_direct[f].size());
|
||||
ASSERT_EQ(1, fv_direct[g].size());
|
||||
|
@ -520,9 +547,9 @@ TEST_F(TestManager, test_nested_manual) {
|
|||
ASSERT_EQ(0, fv_total[f].size());
|
||||
ASSERT_EQ(1, fv_total[g].size());
|
||||
|
||||
auto graph_valuenodes = mng->func_graph_valuenodes();
|
||||
ASSERT_EQ(0, graph_valuenodes[f].size());
|
||||
ASSERT_EQ(1, graph_valuenodes[g].size());
|
||||
auto cnodes = mng->func_graph_cnodes_index();
|
||||
ASSERT_EQ(0, cnodes[f].size());
|
||||
ASSERT_EQ(1, cnodes[g].size());
|
||||
}
|
||||
|
||||
TEST_F(TestManager, test_deep_nested2_manual) {
|
||||
|
|
Loading…
Reference in New Issue