!169 fix the method to calculate the children graph

Merge pull request !169 from xychow/fix-manager-children-issue
This commit is contained in:
mindspore-ci-bot 2020-04-11 14:07:45 +08:00 committed by Gitee
commit 066f20e791
3 changed files with 18 additions and 61 deletions

View File

@ -985,40 +985,14 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) {
}
}
// children include:
// A. func graphs which use variables in fg as free variables; (child_direct_)
// B. func graphs which call func func graph in A. (all_users_)
FuncGraphSetPtr ChildrenComputer::SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) {
if (path == nullptr || path->contains(fg)) {
return std::make_shared<FuncGraphSet>();
}
std::shared_ptr<FuncGraphSet> children = std::make_shared<FuncGraphSet>();
auto& deps = *child_direct_;
auto& users = *all_users_;
MS_LOG(DEBUG) << "" << fg->ToString() << " start func graph dep size:" << deps[fg].size();
for (auto& dep : deps[fg]) {
FuncGraphPtr child = dep.first;
children->add(child);
path->add(child);
MS_LOG(DEBUG) << "Child func graph:" << fg->ToString() << " child " << child->ToString();
for (auto& user : users[child]) {
auto user_func_graph = user.first;
MS_LOG(DEBUG) << "Func graph:" << fg->ToString() << " user " << user_func_graph->ToString();
children->add(user_func_graph);
path->add(user_func_graph);
}
children->update(SeekChildren(child, path));
}
(void)children->erase(fg);
MS_LOG(DEBUG) << "End in children: " << children->size();
return children;
}
void ChildrenComputer::RealRecompute(FuncGraphPtr fg) {
MS_EXCEPTION_IF_NULL(manager_);
child_direct_ = &manager_->func_graph_child_direct();
all_users_ = &manager_->func_graph_users();
children_analysis_[fg].update(SeekChildren(fg));
auto used_fg_total = manager_->func_graphs_used_total(fg);
for (auto& used_fg : used_fg_total) {
if (manager_->parent(used_fg) == fg) {
children_analysis_[fg].add(used_fg);
}
}
}
void ScopeComputer::RealRecompute(FuncGraphPtr fg) {

View File

@ -398,11 +398,8 @@ class ParentComputer final : public DepComputer {
// graph's children graph except self
class ChildrenComputer final : public DepComputer {
public:
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m), all_users_(nullptr), child_direct_(nullptr) {}
~ChildrenComputer() override {
all_users_ = nullptr;
child_direct_ = nullptr;
}
explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {}
~ChildrenComputer() override = default;
FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; }
@ -414,13 +411,6 @@ class ChildrenComputer final : public DepComputer {
void ExtraReset() override { children_analysis_.clear(); }
void RealRecompute(FuncGraphPtr fg) override;
private:
FuncGraphSetPtr SeekChildren(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared<FuncGraphSet>());
// when SeekChildren calls itself recursively, it can access these variables by class member
// other than pass by formal parameters, it can save 2 parameters for SeekChildren().
FuncGraphToFuncGraphCounterMap* all_users_;
FuncGraphToFuncGraphCounterMap* child_direct_;
};
// graph's children graph include self

View File

@ -38,16 +38,6 @@ def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
def refactor_fac(n):
""" grad_refactor_fac """
if n == 0:
return 1
return n * refactor_fac(n-1)
def test_refactor():
res = refactor_fac(3)
assert res == 6
@ms_function
def while_upper_bound(upper):
rval = 2
@ -386,16 +376,19 @@ def test_grad_while():
assert grad_while(5) == (60,)
@ms_function
def fac(n):
""" fac """
def factorial(n):
""" factorial """
if n == 0:
return 1
return n * fac(n-1)
return n * factorial(n-1)
def test_fac():
""" test_fac """
res = fac(4)
assert res == 24
def test_factorial():
res = factorial(3)
assert res == 6
def test_grad_factorial():
res = C.grad(factorial)(3)
assert res == 11
def _for(x):
""" _for """