forked from mindspore-Ecosystem/mindspore
Use FuncGraph generation number replacing set.
This commit is contained in:
parent
f31564ce98
commit
737bfc9595
|
@ -47,6 +47,7 @@ FuncGraph::FuncGraph()
|
|||
: flags_(),
|
||||
transforms_(),
|
||||
parameter_default_value_(),
|
||||
seen_(0),
|
||||
parameters_(),
|
||||
has_vararg_(false),
|
||||
has_kwarg_(false),
|
||||
|
@ -981,6 +982,11 @@ void FuncGraph::SetEffectDepends(const std::vector<AnfNodePtr> &depend_inputs) {
|
|||
}
|
||||
}
|
||||
|
||||
size_t NewFgSeenGeneration() {
|
||||
static size_t fg_seen_generation = 0;
|
||||
return ++fg_seen_generation;
|
||||
}
|
||||
|
||||
const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
|
||||
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -289,6 +289,7 @@ class FuncGraph : public FuncGraphBase {
|
|||
// parameter default value
|
||||
std::map<std::string, AnfNodePtr> parameter_default_value_;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> make_ref_params_;
|
||||
size_t seen_;
|
||||
|
||||
std::list<CNodePtr> GetOrderedCnodes();
|
||||
void EraseUnusedNodeInOrder(const AnfNodePtr &n);
|
||||
|
@ -364,6 +365,8 @@ inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphP
|
|||
return fg->NewCNode(inputs);
|
||||
}
|
||||
|
||||
size_t NewFgSeenGeneration();
|
||||
|
||||
// Find the root cnodes of a segment of cnodes.
|
||||
std::shared_ptr<OrderedSet<CNodePtr>> FindRoots(const std::vector<CNodePtr> &segment);
|
||||
// Find the leaf cnodes of a segment of cnodes.
|
||||
|
|
|
@ -755,8 +755,8 @@ void DepComputer::Recompute(const FuncGraphPtr &fg) {
|
|||
}
|
||||
}
|
||||
|
||||
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
|
||||
if (path == nullptr || path->contains(fg)) {
|
||||
FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, size_t seen_num) {
|
||||
if (fg->seen_ == seen_num) {
|
||||
return std::make_shared<FuncGraphSet>();
|
||||
}
|
||||
FuncGraphSetPtr parents = std::make_shared<FuncGraphSet>();
|
||||
|
@ -770,9 +770,9 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
|
|||
// Search the fv in fg's child func graph.
|
||||
auto &fg_value_nodes = fg->func_graph_value_nodes();
|
||||
for (auto &fg_value_node : fg_value_nodes) {
|
||||
path->add(fg);
|
||||
fg->seen_ = seen_num;
|
||||
auto gt = GetValueNode<FuncGraphPtr>(fg_value_node.first);
|
||||
parents->update(SeekParents(gt, path));
|
||||
parents->update(SeekParents(gt, seen_num));
|
||||
}
|
||||
(void)parents->erase(fg);
|
||||
return parents;
|
||||
|
@ -780,7 +780,7 @@ FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &f
|
|||
|
||||
void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
func_graph_parents_total_analysis_[fg].update(SeekParents(fg));
|
||||
func_graph_parents_total_analysis_[fg].update(SeekParents(fg, NewFgSeenGeneration()));
|
||||
}
|
||||
|
||||
bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) {
|
||||
|
@ -968,9 +968,8 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list<F
|
|||
}
|
||||
}
|
||||
|
||||
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) {
|
||||
MS_EXCEPTION_IF_NULL(path);
|
||||
if (path->contains(fg)) {
|
||||
bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, size_t seen_num) {
|
||||
if (fg->seen_ == seen_num) {
|
||||
MS_LOG(DEBUG) << fg->ToString() << " had been checked";
|
||||
return false;
|
||||
}
|
||||
|
@ -978,19 +977,20 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
|
|||
if (!j_fg_value_nodes.empty()) {
|
||||
// check g1->J(fg)->g2->g cycle;
|
||||
auto contains_j =
|
||||
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(),
|
||||
[path](const std::pair<AnfNodePtr, int> iter) { return !path->contains(GetValueNode<FuncGraphPtr>(iter.first)); });
|
||||
std::find_if(j_fg_value_nodes.begin(), j_fg_value_nodes.end(), [seen_num](const std::pair<AnfNodePtr, int> iter) {
|
||||
return GetValueNode<FuncGraphPtr>(iter.first)->seen_ != seen_num;
|
||||
});
|
||||
if (contains_j != j_fg_value_nodes.end()) {
|
||||
MS_LOG(DEBUG) << fg->ToString() << " contains J(" << contains_j->first->ToString() << ")";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
path->add(fg);
|
||||
fg->seen_ = seen_num;
|
||||
|
||||
// check if func graphs used contains J(func_graph);
|
||||
for (auto &item : fg->func_graph_value_nodes()) {
|
||||
auto used_g = GetValueNode<FuncGraphPtr>(item.first);
|
||||
if (SeekJ(used_g, path)) {
|
||||
if (SeekJ(used_g, seen_num)) {
|
||||
MS_LOG(DEBUG) << fg->ToString() << " users func graph " << used_g->ToString() << " which contains J(func_graph)";
|
||||
return true;
|
||||
}
|
||||
|
@ -1000,7 +1000,6 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPt
|
|||
}
|
||||
|
||||
void FuncGraphJTotalComputer::RealRecompute(FuncGraphPtr fg) {
|
||||
std::shared_ptr<FuncGraphSet> path = std::make_shared<FuncGraphSet>();
|
||||
this->j_total_analysis_[fg] = SeekJ(fg, path);
|
||||
this->j_total_analysis_[fg] = SeekJ(fg, NewFgSeenGeneration());
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -283,7 +283,7 @@ class FuncGraphParentsTotalComputer final : public DepComputer {
|
|||
void RealRecompute(FuncGraphPtr fg) override;
|
||||
|
||||
private:
|
||||
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared<FuncGraphSet>());
|
||||
FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, size_t seen_num);
|
||||
};
|
||||
|
||||
using FuncGraphToFuncGraphMap = OrderedMap<FuncGraphPtr, FuncGraphPtr>;
|
||||
|
@ -423,7 +423,7 @@ class FuncGraphJTotalComputer final : public DepComputer {
|
|||
void ExtraReset() override { j_total_analysis_.clear(); }
|
||||
|
||||
void RealRecompute(FuncGraphPtr fg) override;
|
||||
bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path);
|
||||
bool SeekJ(const FuncGraphPtr &fg, size_t seen_num);
|
||||
};
|
||||
|
||||
class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> {
|
||||
|
|
Loading…
Reference in New Issue