!7568 Resolve nodes swell issue with control flow after grad operation.

Merge pull request !7568 from 张清华/grad_opt
This commit is contained in:
mindspore-ci-bot 2020-10-24 19:25:12 +08:00 committed by Gitee
commit 6cb1a3f53a
3 changed files with 33 additions and 4 deletions

View File

@ -71,6 +71,11 @@ void DFunctor::Init(bool is_top) {
}
}
void DFunctor::Finish() {
CallDoutHoleOnTape();
EliminatePrimalGraph();
}
void DFunctor::Clear() {
func_graph_to_functor_.clear();
anfnode_to_adjoin_definition_.clear();
@ -728,10 +733,7 @@ void DFunctor::CallDoutHoleOnTape() {
}
}
}
FuncGraphPtr DFunctor::k_graph() {
CallDoutHoleOnTape();
return k_graph_;
}
FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
void DFunctor::BroadCastStopFlag() {
// As stop set expanding, all directly or indirectly stopped CNode will be cut off
@ -768,5 +770,28 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
}
return true;
}
// To replace the primal graph with k graph
void DFunctor::EliminatePrimalGraph() {
auto k_vnode = NewValueNode(k_graph_);
auto idx0 = NewValueNode(SizeToInt(0));
auto imm0 = std::make_shared<Int32Imm>(0);
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
auto manager = primal_graph_->manager();
auto users = primal_graph_->func_graph_cnodes_index();
for (auto &it : users) {
auto cnode = it.first->first->cast<CNodePtr>();
auto index = it.first->second;
auto vnode = cnode->inputs()[index];
if (index != 0) {
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
continue;
}
cnode->set_input(0, k_vnode); // Replace primal graph with k graph
auto construct_wrapper = cnode->func_graph();
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
manager->Replace(cnode, getitem0);
}
}
} // namespace ad
} // namespace mindspore

View File

@ -64,6 +64,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
// Register functor objects to form a global view.
void Init(bool is_top = false);
void Finish();
bool IsInScope(const AnfNodePtr &node);
// Clear resources.
@ -97,6 +98,8 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
void UpdateAdjoint(const AdjointPtr &adjoint_definition);
void CallDoutHoleOnTape();
void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
// Replace the primal graph with k graph
void EliminatePrimalGraph();
std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.

View File

@ -53,6 +53,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
f->Init(is_top);
f->MapObject();
f->MapMorphism();
f->Finish();
auto ret = f->k_graph();
if (is_top) {
DFunctor::Clear();