Remove the useless collectors in manager.

This commit is contained in:
Zhang Qinghua 2020-05-22 11:36:26 +08:00
parent 3ae925115f
commit f31564ce98
4 changed files with 46 additions and 402 deletions

View File

@ -640,53 +640,14 @@ void FuncGraphTransaction::Commit() {
} }
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
: manager_(manager), include_func_graph_none_(false) { : manager_(manager), include_func_graph_none_(false) {}
manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
manager_->signals()->AddEdge.connect(this, &FuncGraphAnalysis::OnAddEdge);
manager_->signals()->DropEdge.connect(this, &FuncGraphAnalysis::OnDropEdge);
manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
}
NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() {
include_func_graph_none_ = true;
nodes_analysis_[nullptr] = AnfNodeSet();
manager_->signals()->AddNode.connect(this, &NodesCollector::OnAddNode);
manager_->signals()->DropNode.connect(this, &NodesCollector::OnDropNode);
}
void NodesCollector::OnAddNode(AnfNodePtr n) {
if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) {
nodes_analysis_[n->func_graph()] = AnfNodeSet();
}
nodes_analysis_[n->func_graph()].add(n);
}
void NodesCollector::OnDropNode(AnfNodePtr n) {
(void)nodes_analysis_[n->func_graph()].erase(n);
auto graph = n->func_graph();
// Remove the node from order list.
if (graph) {
graph->EraseUnusedNodeInOrder(n);
}
}
void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// change the owner of node except for the src's return node
for (auto &it : nodes_analysis_[src]) {
nodes_analysis_[dst].add(it);
}
(void)nodes_analysis_.erase(src);
}
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
} }
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
template <typename ValueT, class CollectorHash, class CollectorEqual> template <typename ValueT, class CollectorHash, class CollectorEqual>
@ -735,65 +696,6 @@ bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const F
} }
} }
void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (inp->isa<ValueNode>()) {
(void)Mod(node->func_graph(), inp, direction);
}
}
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_nodes_map_.erase(src);
}
void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp,
EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)),
direction);
}
}
void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
// Ignore the user graph who may own itself.
if (dst != it.first->first->func_graph()) {
(void)Inc(dst, it.first, it.second);
}
}
(void)count_nodes_map_.erase(src);
}
void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inp);
FuncGraphPtr fg1 = node->func_graph();
FuncGraphPtr fg2 = inp->func_graph();
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
(void)Mod(fg1, inp, direction);
}
}
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_nodes_map_[src]) {
FuncGraphPtr fg2 = it.first->func_graph();
if (fg2 != dst) {
(void)Inc(dst, it.first, it.second);
}
}
(void)count_nodes_map_.erase(src);
}
static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
FuncGraphPtr gn = std::make_shared<FuncGraph>();
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
return gn;
}
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph]; auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) { if (d.count(key) == 0) {
@ -833,87 +735,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr
} }
} }
void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(inp);
FuncGraphPtr fg1 = node->func_graph();
FuncGraphPtr fg2 = inp->func_graph();
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
(void)Mod(fg2, fg1, direction);
}
}
void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_func_graphs_map_[src]) {
FuncGraphPtr fg = it.first;
if (fg != dst) {
(void)Inc(dst, fg, it.second);
}
}
(void)count_func_graphs_map_.erase(src);
}
void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
FuncGraphPtr fg1 = node->func_graph();
// possible child parent
if (IsValueNode<FuncGraph>(inp)) {
FuncGraphPtr fg2 = GetValueNode<FuncGraphPtr>(inp);
if (Mod(fg1, ParentProxy(fg2), direction)) {
manager_->signals()->InvalidateComputer();
}
}
// from fv
FuncGraphPtr fg2 = inp->func_graph();
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
// node use fv will in here, fg1's node use fg2's node, so fg1 is child and fg2 is parent
if (Mod(fg1, fg2, direction)) {
manager_->signals()->InvalidateComputer();
}
}
}
void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
for (auto &it : count_func_graphs_map_[src]) {
if (it.first != dst) {
(void)Inc(dst, it.first, it.second);
}
}
(void)count_func_graphs_map_.erase(src);
}
void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
MS_EXCEPTION_IF_NULL(node);
if (IsValueNode<FuncGraph>(inp)) {
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
}
}
void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use
for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_func_graphs_map_[dst].erase(src);
(void)count_func_graphs_map_.erase(src);
}
void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) {
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
MS_LOG(DEBUG) << node->func_graph()->ToString() << " users func graph "
<< GetValueNode<FuncGraphPtr>(inp)->ToString() << " which contains J(func_graph), dir: " << direction;
}
}
void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
// all graph use in src need to change to dst, so meger the to dst use
for (auto &it : count_func_graphs_map_[src]) {
(void)Inc(dst, it.first, it.second);
}
(void)count_func_graphs_map_.erase(src);
}
DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);

View File

@ -140,44 +140,6 @@ class FuncGraphAnalysis {
using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>; using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
// graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis {
public:
explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default;
void Reset() { ExtraReset(); }
void OnInvalidateCollector() { Reset(); }
protected:
// inherit from FuncGraphAnalysis
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
// subclass can override;
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
};
class NodesCollector final : public DepCollector {
public:
explicit NodesCollector(const FuncGraphManager *m);
~NodesCollector() override = default;
const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; }
size_t size() const override { return nodes_analysis_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }
void OnDropFuncGraph(FuncGraphPtr fg) override { (void)nodes_analysis_.erase(fg); }
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
FuncGraphToAnfNodeMap nodes_analysis_;
protected:
void ExtraReset() override { nodes_analysis_.clear(); }
void OnAddNode(AnfNodePtr n) override;
void OnDropNode(AnfNodePtr n) override;
};
struct CNodeIndexHasher { struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const { std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair); MS_EXCEPTION_IF_NULL(pair);
@ -204,59 +166,21 @@ struct CNodeIndexEqual {
} }
}; };
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>> // graphs analysis which compute in write, read needn't recompute
class CounterAnfNodeCollector : public DepCollector { class DepCollector : public FuncGraphAnalysis {
public: public:
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} explicit DepCollector(const FuncGraphManager *manager);
~CounterAnfNodeCollector() override = default; ~DepCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
size_t size() const override { return count_nodes_map_.size(); } void Reset() { ExtraReset(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { void OnInvalidateCollector() { Reset(); }
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
protected: protected:
void ExtraReset() override { count_nodes_map_.clear(); } // inherit from FuncGraphAnalysis
}; void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> { // subclass can override;
public: virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~ValueNodesCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// Record the CNode and its input index, who points to the function graph.
class FuncGraphUsersCNodeIndexCollector final
: public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> {
public:
explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FuncGraphUsersCNodeIndexCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
public:
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
~FVDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
}; };
class CounterFuncGraphCollector : public DepCollector { class CounterFuncGraphCollector : public DepCollector {
@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector {
void ExtraReset() override { count_func_graphs_map_.clear(); } void ExtraReset() override { count_func_graphs_map_.clear(); }
}; };
class FuncGraphChildDirect final : public CounterFuncGraphCollector { template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
class CounterAnfNodeCollector : public DepCollector {
public: public:
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
~FuncGraphChildDirect() override = default; size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
protected: protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override; void ExtraReset() override { count_nodes_map_.clear(); }
};
// graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy
// parents:
// 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
~FuncGraphParentsDirectCollector() override = default;
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
// graph's all used graphs: key is g, value is g used graph
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
~FuncGraphsUsedCollector() override = default;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
};
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
public:
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
~FuncGraphJDirectCollector() override = default;
protected:
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
}; };
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;

View File

@ -132,18 +132,6 @@ class NestingSpecs {
CheckAnfNodeCounter(counter_p); CheckAnfNodeCounter(counter_p);
return; return;
} }
auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
if (counter_pair != nullptr) {
CheckCNodeIndexPairCounter(counter_pair);
return;
}
auto nodes = dynamic_pointer_cast<NodesCollector>(results);
if (nodes != nullptr) {
CheckNodes(nodes);
return;
}
} }
private: private:
@ -205,33 +193,7 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }
void CheckNodes(std::shared_ptr<NodesCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->nodes_analysis()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
if (!node->isa<CNode>() && !Name(node).empty()) {
v.insert(Name(node));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
// Add CheckNesting function // Add CheckNesting function
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) { void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) { for (auto& iter : results->count_nodes_map()) {
@ -258,32 +220,6 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }
void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first->first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) { void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
std::map<std::string, std::set<std::string>> clean_results; std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_func_graphs_map()) { for (auto& iter : results->count_func_graphs_map()) {
@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
} }
// Add TestManager::CheckManager function to checkout the result // Add TestManager::CheckManager function to checkout the result
void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) { void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
auto size = mng->func_graphs().size(); auto size = mng->func_graphs().size();
ASSERT_EQ(size + 1, mng->nodes().size());
ASSERT_EQ(size, mng->free_variables_total().size()); ASSERT_EQ(size, mng->free_variables_total().size());
ASSERT_EQ(size, mng->valuenodes().size());
ASSERT_EQ(size, mng->free_variables_direct().size());
ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
ASSERT_EQ(size, mng->func_graph_parents_direct().size());
ASSERT_EQ(size, mng->func_graphs_used().size());
} }
TEST_F(TestManager, test_scalar_add_manual) { TEST_F(TestManager, test_scalar_add_manual) {
@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
auto nodes = mng->nodes(); ASSERT_EQ(2, f->nodes().size());
ASSERT_EQ(3, nodes[nullptr].size()); ASSERT_EQ(1, g->nodes().size());
ASSERT_EQ(2, nodes[f].size());
ASSERT_EQ(1, nodes[g].size());
auto users = mng->node_users(); auto users = mng->node_users();
for (auto& iter : users) { for (auto& iter : users) {
ASSERT_EQ(1, iter.second.size()); ASSERT_EQ(1, iter.second.size());
} }
auto graphs_used = mng->func_graphs_used(); ASSERT_EQ(1, f->func_graph_value_nodes().size());
ASSERT_EQ(1, graphs_used[f].size()); ASSERT_EQ(0, g->func_graph_value_nodes().size());
ASSERT_EQ(0, graphs_used[g].size());
auto fv_direct = mng->free_variables_direct(); ASSERT_EQ(0, f->free_variables().size());
ASSERT_EQ(0, fv_direct[f].size()); ASSERT_EQ(1, g->free_variables().size());
ASSERT_EQ(1, fv_direct[g].size());
auto fv_total = mng->free_variables_total(); auto fv_total = mng->free_variables_total();
ASSERT_EQ(0, fv_total[f].size()); ASSERT_EQ(0, fv_total[f].size());
ASSERT_EQ(1, fv_total[g].size()); ASSERT_EQ(1, fv_total[g].size());
auto cnodes = mng->func_graph_cnodes_index(); ASSERT_EQ(0, f->func_graph_cnodes_index().size());
ASSERT_EQ(0, cnodes[f].size()); ASSERT_EQ(1, g->func_graph_cnodes_index().size());
ASSERT_EQ(1, cnodes[g].size());
} }
TEST_F(TestManager, test_deep_nested2_manual) { TEST_F(TestManager, test_deep_nested2_manual) {
@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {
ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size()); ASSERT_EQ(4, gfn->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
ASSERT_EQ(25, mng->node_users().size()); ASSERT_EQ(25, mng->node_users().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {
ASSERT_EQ(3, mng->func_graphs().size()); ASSERT_EQ(3, mng->func_graphs().size());
ASSERT_EQ(1, mng->roots().size()); ASSERT_EQ(1, mng->roots().size());
ASSERT_EQ(4, mng->nodes().size());
ASSERT_EQ(20, mng->all_nodes().size()); ASSERT_EQ(20, mng->all_nodes().size());
CheckAnalysisSize(mng); CheckAnalysisSize(mng);
} }
@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
FuncGraphPtr fg = getPyFun("ir_get_fn"); FuncGraphPtr fg = getPyFun("ir_get_fn");
auto mng = Manage(fg); auto mng = Manage(fg);
const FuncGraphToAnfNodeMap& nodes = mng->nodes(); const auto &fgs = mng->func_graphs();
ASSERT_TRUE(nodes.find(fg) != nodes.end()); ASSERT_TRUE(fgs.contains(fg));
FuncGraphSet s; FuncGraphSet s;
s.add(fg); s.add(fg);
mng->MaybeDropFuncGraphs(s); mng->MaybeDropFuncGraphs(s);
ASSERT_TRUE(nodes.find(fg) != nodes.end()); ASSERT_TRUE(fgs.contains(fg));
} }
TEST_F(TestManager, test_keep_roots) { TEST_F(TestManager, test_keep_roots) {

View File

@ -26,15 +26,14 @@
namespace mindspore { namespace mindspore {
void CheckNoFreeVariables(FuncGraphPtr root) { void CheckNoFreeVariables(FuncGraphPtr root) {
auto mng = Manage(root); auto mng = Manage(root);
for (auto &iter : mng->nodes()) { for (auto &iter : mng->func_graphs()) {
auto g = iter.first; auto g = iter;
auto nodes = iter.second;
if (g == nullptr) { if (g == nullptr) {
continue; continue;
} }
ASSERT_TRUE(g->parent() == nullptr); ASSERT_TRUE(g->parent() == nullptr);
auto nodes = g->nodes();
for (auto &node : nodes) { for (auto &node : nodes) {
ASSERT_EQ(node->func_graph(), g); ASSERT_EQ(node->func_graph(), g);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();