forked from mindspore-Ecosystem/mindspore
Remove the useless collectors in manager.
This commit is contained in:
parent
3ae925115f
commit
f31564ce98
|
@ -640,53 +640,14 @@ void FuncGraphTransaction::Commit() {
|
|||
}
|
||||
|
||||
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
|
||||
: manager_(manager), include_func_graph_none_(false) {
|
||||
manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph);
|
||||
manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph);
|
||||
manager_->signals()->AddEdge.connect(this, &FuncGraphAnalysis::OnAddEdge);
|
||||
manager_->signals()->DropEdge.connect(this, &FuncGraphAnalysis::OnDropEdge);
|
||||
manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode);
|
||||
}
|
||||
|
||||
NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() {
|
||||
include_func_graph_none_ = true;
|
||||
nodes_analysis_[nullptr] = AnfNodeSet();
|
||||
|
||||
manager_->signals()->AddNode.connect(this, &NodesCollector::OnAddNode);
|
||||
manager_->signals()->DropNode.connect(this, &NodesCollector::OnDropNode);
|
||||
}
|
||||
|
||||
void NodesCollector::OnAddNode(AnfNodePtr n) {
|
||||
if (nodes_analysis_.find(n->func_graph()) == nodes_analysis_.end()) {
|
||||
nodes_analysis_[n->func_graph()] = AnfNodeSet();
|
||||
}
|
||||
nodes_analysis_[n->func_graph()].add(n);
|
||||
}
|
||||
|
||||
void NodesCollector::OnDropNode(AnfNodePtr n) {
|
||||
(void)nodes_analysis_[n->func_graph()].erase(n);
|
||||
auto graph = n->func_graph();
|
||||
// Remove the node from order list.
|
||||
if (graph) {
|
||||
graph->EraseUnusedNodeInOrder(n);
|
||||
}
|
||||
}
|
||||
|
||||
void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
// change the owner of node except for the src's return node
|
||||
for (auto &it : nodes_analysis_[src]) {
|
||||
nodes_analysis_[dst].add(it);
|
||||
}
|
||||
(void)nodes_analysis_.erase(src);
|
||||
}
|
||||
|
||||
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
|
||||
: manager_(manager), include_func_graph_none_(false) {}
|
||||
|
||||
DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector);
|
||||
}
|
||||
|
||||
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
|
||||
|
||||
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
|
||||
|
||||
template <typename ValueT, class CollectorHash, class CollectorEqual>
|
||||
|
@ -735,65 +696,6 @@ bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const F
|
|||
}
|
||||
}
|
||||
|
||||
void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (inp->isa<ValueNode>()) {
|
||||
(void)Mod(node->func_graph(), inp, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphUsersCNodeIndexCollector::OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp,
|
||||
EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsValueNode<FuncGraph>(inp)) {
|
||||
(void)Mod(GetValueNode<FuncGraphPtr>(inp), std::make_shared<CNodeIndexPair>(std::make_pair(node, index)),
|
||||
direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphUsersCNodeIndexCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
// Ignore the user graph who may own itself.
|
||||
if (dst != it.first->first->func_graph()) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
FuncGraphPtr fg1 = node->func_graph();
|
||||
FuncGraphPtr fg2 = inp->func_graph();
|
||||
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
|
||||
(void)Mod(fg1, inp, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_nodes_map_[src]) {
|
||||
FuncGraphPtr fg2 = it.first->func_graph();
|
||||
if (fg2 != dst) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
}
|
||||
(void)count_nodes_map_.erase(src);
|
||||
}
|
||||
|
||||
static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) {
|
||||
FuncGraphPtr gn = std::make_shared<FuncGraph>();
|
||||
(void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg)));
|
||||
return gn;
|
||||
}
|
||||
|
||||
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
|
||||
auto &d = count_func_graphs_map_[func_graph];
|
||||
if (d.count(key) == 0) {
|
||||
|
@ -833,87 +735,6 @@ bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGr
|
|||
}
|
||||
}
|
||||
|
||||
void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(inp);
|
||||
FuncGraphPtr fg1 = node->func_graph();
|
||||
FuncGraphPtr fg2 = inp->func_graph();
|
||||
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
|
||||
(void)Mod(fg2, fg1, direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_func_graphs_map_[src]) {
|
||||
FuncGraphPtr fg = it.first;
|
||||
if (fg != dst) {
|
||||
(void)Inc(dst, fg, it.second);
|
||||
}
|
||||
}
|
||||
(void)count_func_graphs_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
FuncGraphPtr fg1 = node->func_graph();
|
||||
// possible child parent
|
||||
if (IsValueNode<FuncGraph>(inp)) {
|
||||
FuncGraphPtr fg2 = GetValueNode<FuncGraphPtr>(inp);
|
||||
if (Mod(fg1, ParentProxy(fg2), direction)) {
|
||||
manager_->signals()->InvalidateComputer();
|
||||
}
|
||||
}
|
||||
// from fv
|
||||
FuncGraphPtr fg2 = inp->func_graph();
|
||||
if (nullptr != fg1 && nullptr != fg2 && fg1 != fg2) {
|
||||
// node use fv will in here, fg1's node use fg2's node, so fg1 is child and fg2 is parent
|
||||
if (Mod(fg1, fg2, direction)) {
|
||||
manager_->signals()->InvalidateComputer();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
for (auto &it : count_func_graphs_map_[src]) {
|
||||
if (it.first != dst) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
}
|
||||
(void)count_func_graphs_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (IsValueNode<FuncGraph>(inp)) {
|
||||
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
// all graph use in src need to change to dst, so meger the to dst use
|
||||
for (auto &it : count_func_graphs_map_[src]) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
(void)count_func_graphs_map_[dst].erase(src);
|
||||
(void)count_func_graphs_map_.erase(src);
|
||||
}
|
||||
|
||||
void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProcessDirection direction) {
|
||||
if (IsValueNode<FuncGraph>(inp) && IsPrimitiveCNode(node, prim::kPrimJ)) {
|
||||
(void)Mod(node->func_graph(), GetValueNode<FuncGraphPtr>(inp), direction);
|
||||
MS_LOG(DEBUG) << node->func_graph()->ToString() << " users func graph "
|
||||
<< GetValueNode<FuncGraphPtr>(inp)->ToString() << " which contains J(func_graph), dir: " << direction;
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) {
|
||||
// all graph use in src need to change to dst, so meger the to dst use
|
||||
for (auto &it : count_func_graphs_map_[src]) {
|
||||
(void)Inc(dst, it.first, it.second);
|
||||
}
|
||||
(void)count_func_graphs_map_.erase(src);
|
||||
}
|
||||
|
||||
DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
|
||||
|
|
|
@ -140,44 +140,6 @@ class FuncGraphAnalysis {
|
|||
|
||||
using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
|
||||
|
||||
// graphs analysis which compute in write, read needn't recompute
|
||||
class DepCollector : public FuncGraphAnalysis {
|
||||
public:
|
||||
explicit DepCollector(const FuncGraphManager *manager);
|
||||
~DepCollector() override = default;
|
||||
|
||||
void Reset() { ExtraReset(); }
|
||||
void OnInvalidateCollector() { Reset(); }
|
||||
|
||||
protected:
|
||||
// inherit from FuncGraphAnalysis
|
||||
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
|
||||
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
|
||||
// subclass can override;
|
||||
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
|
||||
};
|
||||
|
||||
class NodesCollector final : public DepCollector {
|
||||
public:
|
||||
explicit NodesCollector(const FuncGraphManager *m);
|
||||
~NodesCollector() override = default;
|
||||
|
||||
const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; }
|
||||
size_t size() const override { return nodes_analysis_.size(); }
|
||||
void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); }
|
||||
|
||||
void OnDropFuncGraph(FuncGraphPtr fg) override { (void)nodes_analysis_.erase(fg); }
|
||||
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
FuncGraphToAnfNodeMap nodes_analysis_;
|
||||
|
||||
protected:
|
||||
void ExtraReset() override { nodes_analysis_.clear(); }
|
||||
void OnAddNode(AnfNodePtr n) override;
|
||||
void OnDropNode(AnfNodePtr n) override;
|
||||
};
|
||||
|
||||
struct CNodeIndexHasher {
|
||||
std::size_t operator()(const CNodeIndexPairPtr pair) const {
|
||||
MS_EXCEPTION_IF_NULL(pair);
|
||||
|
@ -204,59 +166,21 @@ struct CNodeIndexEqual {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
|
||||
class CounterAnfNodeCollector : public DepCollector {
|
||||
// graphs analysis which compute in write, read needn't recompute
|
||||
class DepCollector : public FuncGraphAnalysis {
|
||||
public:
|
||||
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||
~CounterAnfNodeCollector() override = default;
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
|
||||
explicit DepCollector(const FuncGraphManager *manager);
|
||||
~DepCollector() override = default;
|
||||
|
||||
size_t size() const override { return count_nodes_map_.size(); }
|
||||
void OnAddFuncGraph(FuncGraphPtr fg) final {
|
||||
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
|
||||
}
|
||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
|
||||
|
||||
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
|
||||
void Reset() { ExtraReset(); }
|
||||
void OnInvalidateCollector() { Reset(); }
|
||||
|
||||
protected:
|
||||
void ExtraReset() override { count_nodes_map_.clear(); }
|
||||
};
|
||||
|
||||
class ValueNodesCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
|
||||
public:
|
||||
explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~ValueNodesCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
// Record the CNode and its input index, who points to the function graph.
|
||||
class FuncGraphUsersCNodeIndexCollector final
|
||||
: public CounterAnfNodeCollector<CNodeIndexPairPtr, CNodeIndexHasher, CNodeIndexEqual> {
|
||||
public:
|
||||
explicit FuncGraphUsersCNodeIndexCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~FuncGraphUsersCNodeIndexCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class FVDirectCollector final : public CounterAnfNodeCollector<AnfNodePtr> {
|
||||
public:
|
||||
explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {}
|
||||
~FVDirectCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
// inherit from FuncGraphAnalysis
|
||||
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
|
||||
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
|
||||
// subclass can override;
|
||||
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
|
||||
};
|
||||
|
||||
class CounterFuncGraphCollector : public DepCollector {
|
||||
|
@ -278,50 +202,27 @@ class CounterFuncGraphCollector : public DepCollector {
|
|||
void ExtraReset() override { count_func_graphs_map_.clear(); }
|
||||
};
|
||||
|
||||
class FuncGraphChildDirect final : public CounterFuncGraphCollector {
|
||||
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
|
||||
class CounterAnfNodeCollector : public DepCollector {
|
||||
public:
|
||||
explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||
~CounterAnfNodeCollector() override = default;
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
|
||||
|
||||
~FuncGraphChildDirect() override = default;
|
||||
size_t size() const override { return count_nodes_map_.size(); }
|
||||
void OnAddFuncGraph(FuncGraphPtr fg) final {
|
||||
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
|
||||
}
|
||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
|
||||
|
||||
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
// graph's all parents, parentsdirect have a map, which key is graph, value is this graph's all direct and proxy
|
||||
// parents:
|
||||
// 1.proxy parent: graph g use graph f, key is g, value is ParentProxy(f) because f's parent will be g's parent
|
||||
// 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f
|
||||
class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector {
|
||||
public:
|
||||
explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||
~FuncGraphParentsDirectCollector() override = default;
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
// graph's all used graphs: key is g, value is g used graph
|
||||
class FuncGraphsUsedCollector final : public CounterFuncGraphCollector {
|
||||
public:
|
||||
explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||
void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override;
|
||||
~FuncGraphsUsedCollector() override = default;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
};
|
||||
|
||||
class FuncGraphJDirectCollector final : public CounterFuncGraphCollector {
|
||||
public:
|
||||
explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {}
|
||||
void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override;
|
||||
~FuncGraphJDirectCollector() override = default;
|
||||
|
||||
protected:
|
||||
void OnModEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction) override;
|
||||
void ExtraReset() override { count_nodes_map_.clear(); }
|
||||
};
|
||||
|
||||
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
|
||||
|
|
|
@ -132,18 +132,6 @@ class NestingSpecs {
|
|||
CheckAnfNodeCounter(counter_p);
|
||||
return;
|
||||
}
|
||||
|
||||
auto counter_pair = dynamic_pointer_cast<CounterAnfNodeCollector<CNodeIndexPairPtr>>(results);
|
||||
if (counter_pair != nullptr) {
|
||||
CheckCNodeIndexPairCounter(counter_pair);
|
||||
return;
|
||||
}
|
||||
|
||||
auto nodes = dynamic_pointer_cast<NodesCollector>(results);
|
||||
if (nodes != nullptr) {
|
||||
CheckNodes(nodes);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -205,33 +193,7 @@ class NestingSpecs {
|
|||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
void CheckNodes(std::shared_ptr<NodesCollector> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->nodes_analysis()) {
|
||||
auto key = iter.first;
|
||||
auto value = iter.second;
|
||||
if (key == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string k = Name(key);
|
||||
|
||||
std::set<std::string> v;
|
||||
for (auto& node : value) {
|
||||
if (!node->isa<CNode>() && !Name(node).empty()) {
|
||||
v.insert(Name(node));
|
||||
}
|
||||
}
|
||||
|
||||
if (!v.empty()) {
|
||||
clean_results[k] = v;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
// Add CheckNesting function
|
||||
|
||||
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_nodes_map()) {
|
||||
|
@ -258,32 +220,6 @@ class NestingSpecs {
|
|||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
void CheckCNodeIndexPairCounter(std::shared_ptr<CounterAnfNodeCollector<CNodeIndexPairPtr>> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_nodes_map()) {
|
||||
auto key = iter.first;
|
||||
auto value = iter.second;
|
||||
if (key == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string k = Name(key);
|
||||
|
||||
std::set<std::string> v;
|
||||
for (auto& node : value) {
|
||||
auto fg = node.first->first;
|
||||
if (!Name(fg).empty()) {
|
||||
v.insert(Name(fg));
|
||||
}
|
||||
}
|
||||
|
||||
if (!v.empty()) {
|
||||
clean_results[k] = v;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_func_graphs_map()) {
|
||||
|
@ -471,17 +407,10 @@ std::vector<FuncGraphPtr> MakeNestedGraph2() {
|
|||
}
|
||||
|
||||
// Add TestManager::CheckManager function to checkout the result
|
||||
|
||||
void TestManager::CheckAnalysisSize(std::shared_ptr<FuncGraphManager> mng) {
|
||||
auto size = mng->func_graphs().size();
|
||||
|
||||
ASSERT_EQ(size + 1, mng->nodes().size());
|
||||
ASSERT_EQ(size, mng->free_variables_total().size());
|
||||
ASSERT_EQ(size, mng->valuenodes().size());
|
||||
ASSERT_EQ(size, mng->free_variables_direct().size());
|
||||
ASSERT_EQ(size, mng->func_graph_cnodes_index().size());
|
||||
ASSERT_EQ(size, mng->func_graph_parents_direct().size());
|
||||
ASSERT_EQ(size, mng->func_graphs_used().size());
|
||||
}
|
||||
|
||||
TEST_F(TestManager, test_scalar_add_manual) {
|
||||
|
@ -525,31 +454,26 @@ TEST_F(TestManager, test_nested_manual) {
|
|||
ASSERT_EQ(1, mng->roots().size());
|
||||
CheckAnalysisSize(mng);
|
||||
|
||||
auto nodes = mng->nodes();
|
||||
ASSERT_EQ(3, nodes[nullptr].size());
|
||||
ASSERT_EQ(2, nodes[f].size());
|
||||
ASSERT_EQ(1, nodes[g].size());
|
||||
ASSERT_EQ(2, f->nodes().size());
|
||||
ASSERT_EQ(1, g->nodes().size());
|
||||
|
||||
auto users = mng->node_users();
|
||||
for (auto& iter : users) {
|
||||
ASSERT_EQ(1, iter.second.size());
|
||||
}
|
||||
|
||||
auto graphs_used = mng->func_graphs_used();
|
||||
ASSERT_EQ(1, graphs_used[f].size());
|
||||
ASSERT_EQ(0, graphs_used[g].size());
|
||||
ASSERT_EQ(1, f->func_graph_value_nodes().size());
|
||||
ASSERT_EQ(0, g->func_graph_value_nodes().size());
|
||||
|
||||
auto fv_direct = mng->free_variables_direct();
|
||||
ASSERT_EQ(0, fv_direct[f].size());
|
||||
ASSERT_EQ(1, fv_direct[g].size());
|
||||
ASSERT_EQ(0, f->free_variables().size());
|
||||
ASSERT_EQ(1, g->free_variables().size());
|
||||
|
||||
auto fv_total = mng->free_variables_total();
|
||||
ASSERT_EQ(0, fv_total[f].size());
|
||||
ASSERT_EQ(1, fv_total[g].size());
|
||||
|
||||
auto cnodes = mng->func_graph_cnodes_index();
|
||||
ASSERT_EQ(0, cnodes[f].size());
|
||||
ASSERT_EQ(1, cnodes[g].size());
|
||||
ASSERT_EQ(0, f->func_graph_cnodes_index().size());
|
||||
ASSERT_EQ(1, g->func_graph_cnodes_index().size());
|
||||
}
|
||||
|
||||
TEST_F(TestManager, test_deep_nested2_manual) {
|
||||
|
@ -567,7 +491,7 @@ TEST_F(TestManager, test_deep_nested2_manual) {
|
|||
|
||||
ASSERT_EQ(3, mng->func_graphs().size());
|
||||
ASSERT_EQ(1, mng->roots().size());
|
||||
ASSERT_EQ(4, mng->nodes().size());
|
||||
ASSERT_EQ(4, gfn->nodes().size());
|
||||
ASSERT_EQ(20, mng->all_nodes().size());
|
||||
ASSERT_EQ(25, mng->node_users().size());
|
||||
CheckAnalysisSize(mng);
|
||||
|
@ -631,7 +555,6 @@ TEST_F(TestManager, test_deep_nested_manual) {
|
|||
|
||||
ASSERT_EQ(3, mng->func_graphs().size());
|
||||
ASSERT_EQ(1, mng->roots().size());
|
||||
ASSERT_EQ(4, mng->nodes().size());
|
||||
ASSERT_EQ(20, mng->all_nodes().size());
|
||||
CheckAnalysisSize(mng);
|
||||
}
|
||||
|
@ -716,12 +639,12 @@ TEST_F(TestManager, test_drop_root) {
|
|||
FuncGraphPtr fg = getPyFun("ir_get_fn");
|
||||
|
||||
auto mng = Manage(fg);
|
||||
const FuncGraphToAnfNodeMap& nodes = mng->nodes();
|
||||
ASSERT_TRUE(nodes.find(fg) != nodes.end());
|
||||
const auto &fgs = mng->func_graphs();
|
||||
ASSERT_TRUE(fgs.contains(fg));
|
||||
FuncGraphSet s;
|
||||
s.add(fg);
|
||||
mng->MaybeDropFuncGraphs(s);
|
||||
ASSERT_TRUE(nodes.find(fg) != nodes.end());
|
||||
ASSERT_TRUE(fgs.contains(fg));
|
||||
}
|
||||
|
||||
TEST_F(TestManager, test_keep_roots) {
|
||||
|
|
|
@ -26,15 +26,14 @@
|
|||
namespace mindspore {
|
||||
void CheckNoFreeVariables(FuncGraphPtr root) {
|
||||
auto mng = Manage(root);
|
||||
for (auto &iter : mng->nodes()) {
|
||||
auto g = iter.first;
|
||||
auto nodes = iter.second;
|
||||
for (auto &iter : mng->func_graphs()) {
|
||||
auto g = iter;
|
||||
if (g == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ASSERT_TRUE(g->parent() == nullptr);
|
||||
|
||||
auto nodes = g->nodes();
|
||||
for (auto &node : nodes) {
|
||||
ASSERT_EQ(node->func_graph(), g);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
|
Loading…
Reference in New Issue