forked from mindspore-Ecosystem/mindspore
!169 fix the method to calculate the children graph
Merge pull request !169 from xychow/fix-manager-children-issue
This commit is contained in:
commit
066f20e791
|
@ -985,40 +985,14 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
|
|||
}
|
||||
}
|
||||
|
||||
// children include:
|
||||
// A. func graphs which use variables in fg as free variables; (child_direct_)
|
||||
// B. func graphs which call func func graph in A. (all_users_)
|
||||
FuncGraphSetPtr ChildrenComputer::SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
|
||||
if (path == nullptr || path->contains(fg)) {
|
||||
return std::make_shared<FuncGraphSet>();
|
||||
}
|
||||
std::shared_ptr<FuncGraphSet> children = std::make_shared<FuncGraphSet>();
|
||||
auto& deps = *child_direct_;
|
||||
auto& users = *all_users_;
|
||||
MS_LOG(DEBUG) << "" << fg->ToString() << " start func graph dep size:" << deps[fg].size();
|
||||
for (auto& dep : deps[fg]) {
|
||||
FuncGraphPtr child = dep.first;
|
||||
children->add(child);
|
||||
path->add(child);
|
||||
MS_LOG(DEBUG) << "Child func graph:" << fg->ToString() << " child " << child->ToString();
|
||||
for (auto& user : users[child]) {
|
||||
auto user_func_graph = user.first;
|
||||
MS_LOG(DEBUG) << "Func graph:" << fg->ToString() << " user " << user_func_graph->ToString();
|
||||
children->add(user_func_graph);
|
||||
path->add(user_func_graph);
|
||||
}
|
||||
children->update(SeekChildren(child, path));
|
||||
}
|
||||
(void)children->erase(fg);
|
||||
MS_LOG(DEBUG) << "End in children: " << children->size();
|
||||
return children;
|
||||
}
|
||||
|
||||
void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
|
||||
MS_EXCEPTION_IF_NULL(manager_);
|
||||
child_direct_ = &manager_->func_graph_child_direct();
|
||||
all_users_ = &manager_->func_graph_users();
|
||||
children_analysis_[fg].update(SeekChildren(fg));
|
||||
auto used_fg_total = manager_->func_graphs_used_total(fg);
|
||||
for (auto& used_fg : used_fg_total) {
|
||||
if (manager_->parent(used_fg) == fg) {
|
||||
children_analysis_[fg].add(used_fg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ScopeComputer::RealRecompute(FuncGraphPtr fg) {
|
||||
|
|
|
@ -398,11 +398,8 @@ class ParentComputer final : public DepComputer {
|
|||
// graph's children graph except self
|
||||
class ChildrenComputer final : public DepComputer {
|
||||
public:
|
||||
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m), all_users_(nullptr), child_direct_(nullptr) {}
|
||||
~ChildrenComputer() override {
|
||||
all_users_ = nullptr;
|
||||
child_direct_ = nullptr;
|
||||
}
|
||||
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {}
|
||||
~ChildrenComputer() override = default;
|
||||
|
||||
FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; }
|
||||
|
||||
|
@ -414,13 +411,6 @@ class ChildrenComputer final : public DepComputer {
|
|||
void ExtraReset() override { children_analysis_.clear(); }
|
||||
|
||||
void RealRecompute(FuncGraphPtr fg) override;
|
||||
|
||||
private:
|
||||
FuncGraphSetPtr SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>());
|
||||
// when SeekChildren calls itself recursively, it can access these variables by class member
|
||||
// other than pass by formal parameters, it can save 2 parameters for SeekChildren().
|
||||
FuncGraphToFuncGraphCounterMap* all_users_;
|
||||
FuncGraphToFuncGraphCounterMap* child_direct_;
|
||||
};
|
||||
|
||||
// graph's children graph include self
|
||||
|
|
|
@ -38,16 +38,6 @@ def setup_module(module):
|
|||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
|
||||
@ms_function
|
||||
def refactor_fac(n):
|
||||
""" grad_refactor_fac """
|
||||
if n == 0:
|
||||
return 1
|
||||
return n * refactor_fac(n-1)
|
||||
def test_refactor():
|
||||
res = refactor_fac(3)
|
||||
assert res == 6
|
||||
|
||||
@ms_function
|
||||
def while_upper_bound(upper):
|
||||
rval = 2
|
||||
|
@ -386,16 +376,19 @@ def test_grad_while():
|
|||
assert grad_while(5) == (60,)
|
||||
|
||||
@ms_function
|
||||
def fac(n):
|
||||
""" fac """
|
||||
def factorial(n):
|
||||
""" factorial """
|
||||
if n == 0:
|
||||
return 1
|
||||
return n * fac(n-1)
|
||||
return n * factorial(n-1)
|
||||
|
||||
def test_fac():
|
||||
""" test_fac """
|
||||
res = fac(4)
|
||||
assert res == 24
|
||||
def test_factorial():
|
||||
res = factorial(3)
|
||||
assert res == 6
|
||||
|
||||
def test_grad_factorial():
|
||||
res = C.grad(factorial)(3)
|
||||
assert res == 11
|
||||
|
||||
def _for(x):
|
||||
""" _for """
|
||||
|
|
Loading…
Reference in New Issue