forked from mindspore-Ecosystem/mindspore
!7568 Resolve nodes swell issue with control flow after grad operation.
Merge pull request !7568 from 张清华/grad_opt
This commit is contained in:
commit
6cb1a3f53a
|
@ -71,6 +71,11 @@ void DFunctor::Init(bool is_top) {
|
|||
}
|
||||
}
|
||||
|
||||
void DFunctor::Finish() {
|
||||
CallDoutHoleOnTape();
|
||||
EliminatePrimalGraph();
|
||||
}
|
||||
|
||||
void DFunctor::Clear() {
|
||||
func_graph_to_functor_.clear();
|
||||
anfnode_to_adjoin_definition_.clear();
|
||||
|
@ -728,10 +733,7 @@ void DFunctor::CallDoutHoleOnTape() {
|
|||
}
|
||||
}
|
||||
}
|
||||
FuncGraphPtr DFunctor::k_graph() {
|
||||
CallDoutHoleOnTape();
|
||||
return k_graph_;
|
||||
}
|
||||
FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
|
||||
|
||||
void DFunctor::BroadCastStopFlag() {
|
||||
// As stop set expanding, all directly or indirectly stopped CNode will be cut off
|
||||
|
@ -768,5 +770,28 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// To replace the primal graph with k graph
|
||||
void DFunctor::EliminatePrimalGraph() {
|
||||
auto k_vnode = NewValueNode(k_graph_);
|
||||
auto idx0 = NewValueNode(SizeToInt(0));
|
||||
auto imm0 = std::make_shared<Int32Imm>(0);
|
||||
idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
|
||||
auto manager = primal_graph_->manager();
|
||||
auto users = primal_graph_->func_graph_cnodes_index();
|
||||
for (auto &it : users) {
|
||||
auto cnode = it.first->first->cast<CNodePtr>();
|
||||
auto index = it.first->second;
|
||||
auto vnode = cnode->inputs()[index];
|
||||
if (index != 0) {
|
||||
MS_LOG(INFO) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}";
|
||||
continue;
|
||||
}
|
||||
cnode->set_input(0, k_vnode); // Replace primal graph with k graph
|
||||
auto construct_wrapper = cnode->func_graph();
|
||||
auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0});
|
||||
manager->Replace(cnode, getitem0);
|
||||
}
|
||||
}
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -64,6 +64,7 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
|||
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
|
||||
// Register functor objects to form a global view.
|
||||
void Init(bool is_top = false);
|
||||
void Finish();
|
||||
bool IsInScope(const AnfNodePtr &node);
|
||||
|
||||
// Clear resources.
|
||||
|
@ -97,6 +98,8 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
|||
void UpdateAdjoint(const AdjointPtr &adjoint_definition);
|
||||
void CallDoutHoleOnTape();
|
||||
void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
|
||||
// Replace the primal graph with k graph
|
||||
void EliminatePrimalGraph();
|
||||
|
||||
std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
|
||||
// Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
|
||||
|
|
|
@ -53,6 +53,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
|
|||
f->Init(is_top);
|
||||
f->MapObject();
|
||||
f->MapMorphism();
|
||||
f->Finish();
|
||||
auto ret = f->k_graph();
|
||||
if (is_top) {
|
||||
DFunctor::Clear();
|
||||
|
|
Loading…
Reference in New Issue