forked from mindspore-Ecosystem/mindspore
Refactoring the fg manager module.
This commit is contained in:
parent
aae2e410ba
commit
923d3fee04
|
@ -38,6 +38,32 @@ namespace mindspore {
|
|||
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
|
||||
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
|
||||
|
||||
struct CNodeIndexHasher {
|
||||
std::size_t operator()(const CNodeIndexPairPtr pair) const {
|
||||
MS_EXCEPTION_IF_NULL(pair);
|
||||
MS_EXCEPTION_IF_NULL(pair->first);
|
||||
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
|
||||
}
|
||||
};
|
||||
|
||||
struct CNodeIndexEqual {
|
||||
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
|
||||
if (lhs == nullptr || rhs == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (lhs == rhs) {
|
||||
return true;
|
||||
}
|
||||
if (lhs->first != rhs->first) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->second != rhs->second) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
|
||||
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
|
||||
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;
|
||||
|
|
|
@ -633,103 +633,7 @@ void FuncGraphTransaction::Commit() {
|
|||
manager_->CommitChanges(changes);
|
||||
}
|
||||
|
||||
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager)
|
||||
: manager_(manager), include_func_graph_none_(false) {}
|
||||
|
||||
DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
}
|
||||
|
||||
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
|
||||
|
||||
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
|
||||
|
||||
template <typename ValueT, class CollectorHash, class CollectorEqual>
|
||||
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph,
|
||||
const ValueT &key, int count) {
|
||||
auto &d = count_nodes_map_[func_graph];
|
||||
if (d.count(key) == 0) {
|
||||
d[key] = count;
|
||||
return true;
|
||||
} else {
|
||||
d[key] += count;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename ValueT, class CollectorHash, class CollectorEqual>
|
||||
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph,
|
||||
const ValueT &key, int count) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto &d = count_nodes_map_[func_graph];
|
||||
if (d.count(key) != 0) {
|
||||
if (d[key] == count) {
|
||||
(void)d.erase(key);
|
||||
return true;
|
||||
} else {
|
||||
d[key] -= count;
|
||||
if (d[key] < 0) {
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key
|
||||
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename ValueT, class CollectorHash, class CollectorEqual>
|
||||
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph,
|
||||
const ValueT &key, int count) {
|
||||
if (count > 0) {
|
||||
return Inc(func_graph, key, count);
|
||||
} else if (count < 0) {
|
||||
return Dec(func_graph, key, -count);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key
|
||||
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
}
|
||||
|
||||
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
|
||||
auto &d = count_func_graphs_map_[func_graph];
|
||||
if (d.count(key) == 0) {
|
||||
d[key] = count;
|
||||
return true;
|
||||
} else {
|
||||
d[key] += count;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
|
||||
auto &d = count_func_graphs_map_[func_graph];
|
||||
if (d.count(key) != 0) {
|
||||
if (d[key] == count) {
|
||||
(void)d.erase(key);
|
||||
return true;
|
||||
} else {
|
||||
d[key] -= count;
|
||||
if (d[key] < 0) {
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
|
||||
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
|
||||
if (count > 0) {
|
||||
return Inc(func_graph, key, count);
|
||||
} else if (count < 0) {
|
||||
return Dec(func_graph, key, -count);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
|
||||
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
|
||||
}
|
||||
}
|
||||
|
||||
DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
|
||||
DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) {
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
|
||||
validate_ = false;
|
||||
|
@ -839,16 +743,15 @@ void FVTotalComputer::RealRecompute() {
|
|||
|
||||
for (auto &fg : manager->func_graphs()) {
|
||||
fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
|
||||
count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
|
||||
count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
|
||||
}
|
||||
|
||||
for (auto &fg : manager->func_graphs()) {
|
||||
// add all free variable nodes
|
||||
AnfNodeCounterMap items = fg->free_variables();
|
||||
for (auto &iter : items) {
|
||||
auto curr = fg;
|
||||
while (curr != nullptr) {
|
||||
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second);
|
||||
fv_total_analysis_[curr][iter.first] = iter.second;
|
||||
curr = manager->parent(curr);
|
||||
if (curr != nullptr) {
|
||||
const AnfNodeSet &all_nodes = curr->nodes();
|
||||
|
@ -859,6 +762,7 @@ void FVTotalComputer::RealRecompute() {
|
|||
}
|
||||
}
|
||||
|
||||
// add all FGs of free variables
|
||||
auto &used = fg->func_graphs_used();
|
||||
for (auto &iter : used) {
|
||||
auto p = manager->parent(iter.first);
|
||||
|
@ -867,21 +771,11 @@ void FVTotalComputer::RealRecompute() {
|
|||
}
|
||||
auto curr = fg;
|
||||
while (curr != p) {
|
||||
(void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second);
|
||||
fv_total_analysis_[curr][iter.first] = iter.second;
|
||||
curr = manager->parent(curr);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto &fg : manager->func_graphs()) {
|
||||
auto &fvp = count_nodes_map_[fg];
|
||||
auto &fvg = count_func_graphs_map_[fg];
|
||||
for (auto &item : fvp) {
|
||||
fv_total_analysis_[fg][item.first] = item.second;
|
||||
}
|
||||
for (auto &item : fvg) {
|
||||
fv_total_analysis_[fg][item.first] = item.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
||||
|
|
|
@ -88,14 +88,6 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool ma
|
|||
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);
|
||||
|
||||
struct Signals {
|
||||
Signal<void(FuncGraphPtr)> AddFuncGraph;
|
||||
Signal<void(FuncGraphPtr)> DropFuncGraph;
|
||||
Signal<void(AnfNodePtr)> AddNode;
|
||||
Signal<void(AnfNodePtr)> DropNode;
|
||||
Signal<void(AnfNodePtr, int, AnfNodePtr)> AddEdge;
|
||||
Signal<void(AnfNodePtr, int, AnfNodePtr)> DropEdge;
|
||||
Signal<void(FuncGraphPtr, FuncGraphPtr)> MoveAllCNode;
|
||||
Signal<void()> InvalidateCollector;
|
||||
Signal<void()> InvalidateComputer;
|
||||
};
|
||||
|
||||
|
@ -103,136 +95,15 @@ enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 };
|
|||
|
||||
using CNodeIndexPair = std::pair<AnfNodePtr, int>;
|
||||
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
|
||||
|
||||
using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>;
|
||||
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
|
||||
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>;
|
||||
|
||||
// analysis base class
|
||||
class FuncGraphAnalysis {
|
||||
public:
|
||||
explicit FuncGraphAnalysis(const FuncGraphManager *const manager);
|
||||
|
||||
virtual ~FuncGraphAnalysis() { manager_ = nullptr; }
|
||||
|
||||
virtual size_t size() const { return 0; }
|
||||
|
||||
virtual void OnAddFuncGraph(FuncGraphPtr) {}
|
||||
|
||||
virtual void OnDropFuncGraph(FuncGraphPtr) {}
|
||||
|
||||
virtual void OnMoveAllCNode(FuncGraphPtr, FuncGraphPtr) {}
|
||||
|
||||
protected:
|
||||
// subclass can reset their own member;
|
||||
virtual void ExtraReset() {}
|
||||
|
||||
virtual void OnAddNode(AnfNodePtr n) {}
|
||||
|
||||
virtual void OnDropNode(AnfNodePtr n) {}
|
||||
|
||||
virtual void OnAddEdge(AnfNodePtr, int, AnfNodePtr) {}
|
||||
|
||||
virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {}
|
||||
|
||||
const FuncGraphManager *manager_;
|
||||
bool include_func_graph_none_;
|
||||
};
|
||||
|
||||
using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
|
||||
|
||||
struct CNodeIndexHasher {
|
||||
std::size_t operator()(const CNodeIndexPairPtr pair) const {
|
||||
MS_EXCEPTION_IF_NULL(pair);
|
||||
MS_EXCEPTION_IF_NULL(pair->first);
|
||||
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
|
||||
}
|
||||
};
|
||||
|
||||
struct CNodeIndexEqual {
|
||||
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
|
||||
if (lhs == nullptr || rhs == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (lhs == rhs) {
|
||||
return true;
|
||||
}
|
||||
if (lhs->first != rhs->first) {
|
||||
return false;
|
||||
}
|
||||
if (lhs->second != rhs->second) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// graphs analysis which compute in write, read needn't recompute
|
||||
class DepCollector : public FuncGraphAnalysis {
|
||||
public:
|
||||
explicit DepCollector(const FuncGraphManager *manager);
|
||||
~DepCollector() override = default;
|
||||
|
||||
void Reset() { ExtraReset(); }
|
||||
void OnInvalidateCollector() { Reset(); }
|
||||
|
||||
protected:
|
||||
// inherit from FuncGraphAnalysis
|
||||
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
|
||||
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
|
||||
// subclass can override;
|
||||
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
|
||||
};
|
||||
|
||||
class CounterFuncGraphCollector : public DepCollector {
|
||||
public:
|
||||
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||
~CounterFuncGraphCollector() override = default;
|
||||
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
|
||||
// inherit from FuncGraphAnalysis
|
||||
size_t size() const override { return count_func_graphs_map_.size(); }
|
||||
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
|
||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
|
||||
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
|
||||
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
|
||||
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
|
||||
|
||||
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
|
||||
|
||||
protected:
|
||||
void ExtraReset() override { count_func_graphs_map_.clear(); }
|
||||
};
|
||||
|
||||
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
|
||||
class CounterAnfNodeCollector : public DepCollector {
|
||||
public:
|
||||
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
|
||||
~CounterAnfNodeCollector() override = default;
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
|
||||
|
||||
size_t size() const override { return count_nodes_map_.size(); }
|
||||
void OnAddFuncGraph(FuncGraphPtr fg) final {
|
||||
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
|
||||
}
|
||||
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
|
||||
|
||||
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
|
||||
|
||||
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
|
||||
|
||||
protected:
|
||||
void ExtraReset() override { count_nodes_map_.clear(); }
|
||||
};
|
||||
|
||||
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
|
||||
|
||||
// graphs analysis which need dynamic compute by DepCollector in each read
|
||||
class DepComputer : public FuncGraphAnalysis {
|
||||
// analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
|
||||
class DepComputer {
|
||||
public:
|
||||
explicit DepComputer(const FuncGraphManager *manager);
|
||||
~DepComputer() override = default;
|
||||
virtual ~DepComputer() { manager_ = nullptr; }
|
||||
|
||||
virtual size_t size() const { return 0; }
|
||||
|
||||
void Reset() {
|
||||
ExtraReset();
|
||||
|
@ -250,15 +121,14 @@ class DepComputer : public FuncGraphAnalysis {
|
|||
|
||||
bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
|
||||
|
||||
void OnAddFuncGraph(FuncGraphPtr) final { Reset(); }
|
||||
|
||||
void OnDropFuncGraph(FuncGraphPtr) final { Reset(); }
|
||||
|
||||
protected:
|
||||
// subclass can reset their own member;
|
||||
virtual void ExtraReset() {}
|
||||
// subclass do the real compute
|
||||
virtual void RealRecompute() {}
|
||||
virtual void RealRecompute(FuncGraphPtr) {}
|
||||
|
||||
const FuncGraphManager *manager_;
|
||||
bool validate_;
|
||||
OrderedMap<FuncGraphPtr, bool> func_graphs_validate_;
|
||||
|
||||
|
@ -345,12 +215,9 @@ class ScopeComputer final : public DepComputer {
|
|||
|
||||
using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
|
||||
|
||||
class FVTotalComputer final : public DepComputer,
|
||||
public CounterAnfNodeCollector<AnfNodePtr>,
|
||||
public CounterFuncGraphCollector {
|
||||
class FVTotalComputer final : public DepComputer {
|
||||
public:
|
||||
explicit FVTotalComputer(const FuncGraphManager *m)
|
||||
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
|
||||
explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
|
||||
~FVTotalComputer() override = default;
|
||||
|
||||
FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }
|
||||
|
|
|
@ -104,7 +104,7 @@ class NestingSpecs {
|
|||
return name;
|
||||
}
|
||||
|
||||
void Check(std::shared_ptr<FuncGraphAnalysis> results) {
|
||||
void Check(std::shared_ptr<DepComputer> results) {
|
||||
if (expected_.empty() && expected_recursive_.empty()) {
|
||||
return;
|
||||
}
|
||||
|
@ -120,18 +120,6 @@ class NestingSpecs {
|
|||
CheckRecursive(recursive);
|
||||
return;
|
||||
}
|
||||
|
||||
auto counter_g = dynamic_pointer_cast<CounterFuncGraphCollector>(results);
|
||||
if (counter_g != nullptr) {
|
||||
CheckGraphCounter(counter_g);
|
||||
return;
|
||||
}
|
||||
|
||||
auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector<AnfNodePtr>>(results);
|
||||
if (counter_p != nullptr) {
|
||||
CheckAnfNodeCounter(counter_p);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -193,59 +181,6 @@ class NestingSpecs {
|
|||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
// Add CheckNesting function
|
||||
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_nodes_map()) {
|
||||
auto key = iter.first;
|
||||
auto value = iter.second;
|
||||
if (key == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string k = Name(key);
|
||||
|
||||
std::set<std::string> v;
|
||||
for (auto& node : value) {
|
||||
auto fg = node.first;
|
||||
if (!Name(fg).empty()) {
|
||||
v.insert(Name(fg));
|
||||
}
|
||||
}
|
||||
|
||||
if (!v.empty()) {
|
||||
clean_results[k] = v;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
|
||||
std::map<std::string, std::set<std::string>> clean_results;
|
||||
for (auto& iter : results->count_func_graphs_map()) {
|
||||
auto key = iter.first;
|
||||
auto value = iter.second;
|
||||
if (key == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string k = Name(key);
|
||||
|
||||
std::set<std::string> v;
|
||||
for (auto& node : value) {
|
||||
auto fg = node.first;
|
||||
if (!Name(fg).empty()) {
|
||||
v.insert(Name(fg));
|
||||
}
|
||||
}
|
||||
|
||||
if (!v.empty()) {
|
||||
clean_results[k] = v;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT_EQ(clean_results, expected_);
|
||||
}
|
||||
|
||||
void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
|
||||
std::map<std::string, bool> clean_results;
|
||||
for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {
|
||||
|
|
Loading…
Reference in New Issue