Refactoring the fg manager module.

This commit is contained in:
Zhang Qinghua 2020-06-15 15:47:39 +08:00
parent aae2e410ba
commit 923d3fee04
4 changed files with 42 additions and 320 deletions

View File

@ -38,6 +38,32 @@ namespace mindspore {
using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>; using BaseRefCounterMap = OrderedMap<BaseRef, int, BaseRefHash>;
using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>; using FuncGraphCounterMap = OrderedMap<FuncGraphPtr, int>;
struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair);
MS_EXCEPTION_IF_NULL(pair->first);
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
}
};
struct CNodeIndexEqual {
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
if (lhs == nullptr || rhs == nullptr) {
return false;
}
if (lhs == rhs) {
return true;
}
if (lhs->first != rhs->first) {
return false;
}
if (lhs->second != rhs->second) {
return false;
}
return true;
}
};
template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>> template <typename ValueT, class CounterHash = std::hash<ValueT>, class CounterEqual = std::equal_to<ValueT>>
using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>; using CounterOrderedMap = OrderedMap<ValueT, int, CounterHash, CounterEqual>;
using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>; using AnfNodeCounterMap = CounterOrderedMap<AnfNodePtr>;

View File

@ -633,103 +633,7 @@ void FuncGraphTransaction::Commit() {
manager_->CommitChanges(changes); manager_->CommitChanges(changes);
} }
FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) DepComputer::DepComputer(const FuncGraphManager *const manager) : manager_(manager) {
: manager_(manager), include_func_graph_none_(false) {}
DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_);
}
void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); }
void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); }
template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Inc(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
auto &d = count_nodes_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}
template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Dec(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
MS_EXCEPTION_IF_NULL(func_graph);
auto &d = count_nodes_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}
template <typename ValueT, class CollectorHash, class CollectorEqual>
bool CounterAnfNodeCollector<ValueT, CollectorHash, CollectorEqual>::Mod(const FuncGraphPtr &func_graph,
const ValueT &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) == 0) {
d[key] = count;
return true;
} else {
d[key] += count;
}
return false;
}
bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) {
auto &d = count_func_graphs_map_[func_graph];
if (d.count(key) != 0) {
if (d[key] == count) {
(void)d.erase(key);
return true;
} else {
d[key] -= count;
if (d[key] < 0) {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' dec from 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
}
return false;
}
bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) {
if (count > 0) {
return Inc(func_graph, key, count);
} else if (count < 0) {
return Dec(func_graph, key, -count);
} else {
MS_LOG(EXCEPTION) << "Count of key '" << key->ToString()
<< "' cannot be 0. NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info());
}
}
DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) {
MS_EXCEPTION_IF_NULL(manager_); MS_EXCEPTION_IF_NULL(manager_);
manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer);
validate_ = false; validate_ = false;
@ -839,16 +743,15 @@ void FVTotalComputer::RealRecompute() {
for (auto &fg : manager->func_graphs()) { for (auto &fg : manager->func_graphs()) {
fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>(); fv_total_analysis_[fg] = OrderedMap<BaseRef, int, BaseRefHash>();
count_nodes_map_[fg] = OrderedMap<AnfNodePtr, int>();
count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>();
} }
for (auto &fg : manager->func_graphs()) { for (auto &fg : manager->func_graphs()) {
// add all free variable nodes
AnfNodeCounterMap items = fg->free_variables(); AnfNodeCounterMap items = fg->free_variables();
for (auto &iter : items) { for (auto &iter : items) {
auto curr = fg; auto curr = fg;
while (curr != nullptr) { while (curr != nullptr) {
(void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); fv_total_analysis_[curr][iter.first] = iter.second;
curr = manager->parent(curr); curr = manager->parent(curr);
if (curr != nullptr) { if (curr != nullptr) {
const AnfNodeSet &all_nodes = curr->nodes(); const AnfNodeSet &all_nodes = curr->nodes();
@ -859,6 +762,7 @@ void FVTotalComputer::RealRecompute() {
} }
} }
// add all FGs of free variables
auto &used = fg->func_graphs_used(); auto &used = fg->func_graphs_used();
for (auto &iter : used) { for (auto &iter : used) {
auto p = manager->parent(iter.first); auto p = manager->parent(iter.first);
@ -867,21 +771,11 @@ void FVTotalComputer::RealRecompute() {
} }
auto curr = fg; auto curr = fg;
while (curr != p) { while (curr != p) {
(void)CounterFuncGraphCollector::Mod(curr, iter.first, iter.second); fv_total_analysis_[curr][iter.first] = iter.second;
curr = manager->parent(curr); curr = manager->parent(curr);
} }
} }
} }
for (auto &fg : manager->func_graphs()) {
auto &fvp = count_nodes_map_[fg];
auto &fvg = count_func_graphs_map_[fg];
for (auto &item : fvp) {
fv_total_analysis_[fg][item.first] = item.second;
}
for (auto &item : fvg) {
fv_total_analysis_[fg][item.first] = item.second;
}
}
} }
void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) {

View File

@ -88,14 +88,6 @@ FuncGraphManagerPtr Manage(const std::vector<FuncGraphPtr> &func_graphs, bool ma
FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true); FuncGraphManagerPtr MakeManager(const std::vector<FuncGraphPtr> &func_graphs = {}, bool manage = true);
struct Signals { struct Signals {
Signal<void(FuncGraphPtr)> AddFuncGraph;
Signal<void(FuncGraphPtr)> DropFuncGraph;
Signal<void(AnfNodePtr)> AddNode;
Signal<void(AnfNodePtr)> DropNode;
Signal<void(AnfNodePtr, int, AnfNodePtr)> AddEdge;
Signal<void(AnfNodePtr, int, AnfNodePtr)> DropEdge;
Signal<void(FuncGraphPtr, FuncGraphPtr)> MoveAllCNode;
Signal<void()> InvalidateCollector;
Signal<void()> InvalidateComputer; Signal<void()> InvalidateComputer;
}; };
@ -103,136 +95,15 @@ enum EdgeProcessDirection { kDecEdge = -1, kIncEdge = 1 };
using CNodeIndexPair = std::pair<AnfNodePtr, int>; using CNodeIndexPair = std::pair<AnfNodePtr, int>;
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>; using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
using FuncGraphToFuncGraphCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<FuncGraphPtr, int>>;
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
using FuncGraphToAnfNodeCounterMap = OrderedMap<FuncGraphPtr, OrderedMap<ValueT, int, CollectorHash, CollectorEqual>>;
// analysis base class
class FuncGraphAnalysis {
public:
explicit FuncGraphAnalysis(const FuncGraphManager *const manager);
virtual ~FuncGraphAnalysis() { manager_ = nullptr; }
virtual size_t size() const { return 0; }
virtual void OnAddFuncGraph(FuncGraphPtr) {}
virtual void OnDropFuncGraph(FuncGraphPtr) {}
virtual void OnMoveAllCNode(FuncGraphPtr, FuncGraphPtr) {}
protected:
// subclass can reset their own member;
virtual void ExtraReset() {}
virtual void OnAddNode(AnfNodePtr n) {}
virtual void OnDropNode(AnfNodePtr n) {}
virtual void OnAddEdge(AnfNodePtr, int, AnfNodePtr) {}
virtual void OnDropEdge(AnfNodePtr, int, AnfNodePtr) {}
const FuncGraphManager *manager_;
bool include_func_graph_none_;
};
using FuncGraphToAnfNodeMap = OrderedMap<FuncGraphPtr, AnfNodeSet>;
struct CNodeIndexHasher {
std::size_t operator()(const CNodeIndexPairPtr pair) const {
MS_EXCEPTION_IF_NULL(pair);
MS_EXCEPTION_IF_NULL(pair->first);
return hash_combine(pair->first->hash(), std::hash<int>()(pair->second));
}
};
struct CNodeIndexEqual {
bool operator()(const CNodeIndexPairPtr lhs, const CNodeIndexPairPtr rhs) const {
if (lhs == nullptr || rhs == nullptr) {
return false;
}
if (lhs == rhs) {
return true;
}
if (lhs->first != rhs->first) {
return false;
}
if (lhs->second != rhs->second) {
return false;
}
return true;
}
};
// graphs analysis which compute in write, read needn't recompute
class DepCollector : public FuncGraphAnalysis {
public:
explicit DepCollector(const FuncGraphManager *manager);
~DepCollector() override = default;
void Reset() { ExtraReset(); }
void OnInvalidateCollector() { Reset(); }
protected:
// inherit from FuncGraphAnalysis
void OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
void OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) override;
// subclass can override;
virtual void OnModEdge(AnfNodePtr, int, AnfNodePtr, EdgeProcessDirection) {}
};
class CounterFuncGraphCollector : public DepCollector {
public:
explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterFuncGraphCollector() override = default;
FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; }
// inherit from FuncGraphAnalysis
size_t size() const override { return count_func_graphs_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap<FuncGraphPtr, int>(); }
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count);
FuncGraphToFuncGraphCounterMap count_func_graphs_map_;
protected:
void ExtraReset() override { count_func_graphs_map_.clear(); }
};
template <typename ValueT, class CollectorHash = std::hash<ValueT>, class CollectorEqual = std::equal_to<ValueT>>
class CounterAnfNodeCollector : public DepCollector {
public:
explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {}
~CounterAnfNodeCollector() override = default;
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> &count_nodes_map() { return count_nodes_map_; }
size_t size() const override { return count_nodes_map_.size(); }
void OnAddFuncGraph(FuncGraphPtr fg) final {
count_nodes_map_[fg] = OrderedMap<ValueT, int, CollectorHash, CollectorEqual>();
}
void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); }
bool Inc(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Dec(const FuncGraphPtr &func_graph, const ValueT &key, int count);
bool Mod(const FuncGraphPtr &func_graph, const ValueT &key, int count);
FuncGraphToAnfNodeCounterMap<ValueT, CollectorHash, CollectorEqual> count_nodes_map_;
protected:
void ExtraReset() override { count_nodes_map_.clear(); }
};
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>; using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
// graphs analysis which need dynamic compute by DepCollector in each read // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
class DepComputer : public FuncGraphAnalysis { class DepComputer {
public: public:
explicit DepComputer(const FuncGraphManager *manager); explicit DepComputer(const FuncGraphManager *manager);
~DepComputer() override = default; virtual ~DepComputer() { manager_ = nullptr; }
virtual size_t size() const { return 0; }
void Reset() { void Reset() {
ExtraReset(); ExtraReset();
@ -250,15 +121,14 @@ class DepComputer : public FuncGraphAnalysis {
bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; }
void OnAddFuncGraph(FuncGraphPtr) final { Reset(); }
void OnDropFuncGraph(FuncGraphPtr) final { Reset(); }
protected: protected:
// subclass can reset their own member;
virtual void ExtraReset() {}
// subclass do the real compute // subclass do the real compute
virtual void RealRecompute() {} virtual void RealRecompute() {}
virtual void RealRecompute(FuncGraphPtr) {} virtual void RealRecompute(FuncGraphPtr) {}
const FuncGraphManager *manager_;
bool validate_; bool validate_;
OrderedMap<FuncGraphPtr, bool> func_graphs_validate_; OrderedMap<FuncGraphPtr, bool> func_graphs_validate_;
@ -345,12 +215,9 @@ class ScopeComputer final : public DepComputer {
using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>; using FVTotalMap = OrderedMap<FuncGraphPtr, OrderedMap<BaseRef, int, BaseRefHash>>;
class FVTotalComputer final : public DepComputer, class FVTotalComputer final : public DepComputer {
public CounterAnfNodeCollector<AnfNodePtr>,
public CounterFuncGraphCollector {
public: public:
explicit FVTotalComputer(const FuncGraphManager *m) explicit FVTotalComputer(const FuncGraphManager *m) : DepComputer(m) {}
: DepComputer(m), CounterAnfNodeCollector(m), CounterFuncGraphCollector(m) {}
~FVTotalComputer() override = default; ~FVTotalComputer() override = default;
FVTotalMap &fv_total_analysis() { return fv_total_analysis_; } FVTotalMap &fv_total_analysis() { return fv_total_analysis_; }

View File

@ -104,7 +104,7 @@ class NestingSpecs {
return name; return name;
} }
void Check(std::shared_ptr<FuncGraphAnalysis> results) { void Check(std::shared_ptr<DepComputer> results) {
if (expected_.empty() && expected_recursive_.empty()) { if (expected_.empty() && expected_recursive_.empty()) {
return; return;
} }
@ -120,18 +120,6 @@ class NestingSpecs {
CheckRecursive(recursive); CheckRecursive(recursive);
return; return;
} }
auto counter_g = dynamic_pointer_cast<CounterFuncGraphCollector>(results);
if (counter_g != nullptr) {
CheckGraphCounter(counter_g);
return;
}
auto counter_p = dynamic_pointer_cast<CounterAnfNodeCollector<AnfNodePtr>>(results);
if (counter_p != nullptr) {
CheckAnfNodeCounter(counter_p);
return;
}
} }
private: private:
@ -193,59 +181,6 @@ class NestingSpecs {
ASSERT_EQ(clean_results, expected_); ASSERT_EQ(clean_results, expected_);
} }
// Add CheckNesting function
void CheckAnfNodeCounter(std::shared_ptr<CounterAnfNodeCollector<AnfNodePtr>> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_nodes_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
void CheckGraphCounter(std::shared_ptr<CounterFuncGraphCollector> results) {
std::map<std::string, std::set<std::string>> clean_results;
for (auto& iter : results->count_func_graphs_map()) {
auto key = iter.first;
auto value = iter.second;
if (key == nullptr) {
continue;
}
std::string k = Name(key);
std::set<std::string> v;
for (auto& node : value) {
auto fg = node.first;
if (!Name(fg).empty()) {
v.insert(Name(fg));
}
}
if (!v.empty()) {
clean_results[k] = v;
}
}
ASSERT_EQ(clean_results, expected_);
}
void CheckRecursive(std::shared_ptr<RecursiveComputer> results) { void CheckRecursive(std::shared_ptr<RecursiveComputer> results) {
std::map<std::string, bool> clean_results; std::map<std::string, bool> clean_results;
for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) { for (auto iter = results->recursive_analysis().begin(); iter != results->recursive_analysis().end(); ++iter) {