diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 0483029f061..57c62b8116a 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -234,8 +234,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { AdjointPtr node_adjoint = nullptr; AnfNodePtr k = nullptr; if (IsValueNode(node)) { - TraceGuard trace_guard(std::make_shared(cnode_morph->debug_info())); - k = MapToK(node); + k = MapToK(cnode_morph, i); node_adjoint = std::make_shared(node, k, tape_); anfnode_to_adjoin_[node] = node_adjoint; } else { @@ -597,6 +596,31 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { return NewValueNode(functor->k_graph_); } +// Construct representation graph for primitive CNode. +AnfNodePtr DFunctor::MapToK(const CNodePtr &primal_user, size_t index) { + auto primal = primal_user->input(index); + ScopeGuard scope_guard(primal->scope()); + // Map primitive to K + if (IsValueNode(primal)) { + auto value_node = primal->cast(); + auto prim = GetValueNode(value_node); + if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { + MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; + need_cut_ = true; + } + auto k_prim = g_k_prims.KPrimitive(primal_user, value_node, resources_); + if (k_prim != nullptr) { + return NewValueNode(k_prim); + } + // When failed to find k_prim, try k_meta. + auto k_meta = g_k_prims.KMetaFuncGraph(prim); + if (k_meta != nullptr) { + return NewValueNode(k_meta); + } + } + return MapToK(primal); +} + // Construct representation graph for given node. AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { ScopeGuard scope_guard(primal->scope()); @@ -608,7 +632,7 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) { MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; need_cut_ = true; } - auto k_prim = g_k_prims.KPrimitive(value_node, resources_); + auto k_prim = g_k_prims.KPrimitive(nullptr, value_node, resources_); if (k_prim != nullptr) { return NewValueNode(k_prim); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index e9f29d0f2a6..83ca4fa41d5 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -81,8 +81,10 @@ class DFunctor : public std::enable_shared_from_this { void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint); AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv); AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv); - // Map Anfnode object from D category to K category. + // Map AnfNode object from D category to K category. AnfNodePtr MapToK(const AnfNodePtr &primal); + // Map CNode object from D category to K category. + AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index); // Map FuncGraph object from D category to K category. AnfNodePtr MapToK(const FuncGraphPtr &primal); // MapObject impls. @@ -129,7 +131,8 @@ class KPrim { KPrim() = default; ~KPrim() = default; - FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); + FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node, + const pipeline::ResourceBasePtr &resources); MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); @@ -145,7 +148,7 @@ class KPrim { FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); // Given a bprop rule, do the K mapping. template - FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g); + FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const CNodePtr &cnode); AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, std::vector *const transf_args); @@ -156,7 +159,7 @@ class KPrim { }; template -FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { +FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(primal); MS_EXCEPTION_IF_NULL(bprop_fg); CheckBprop(bprop_fg, primal->ToString()); @@ -197,8 +200,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); - auto out_value = outer->NewCNode(transf_args); - + CNodePtr out_value = nullptr; + if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out. + TraceGuard trace_guard(std::make_shared(cnode->debug_info())); + out_value = outer->NewCNode(transf_args); + } else { + out_value = outer->NewCNode(transf_args); + } (void)mng->Replace(out_param, out_value); TraceGuard guard(std::make_shared(out_param->debug_info())); @@ -207,7 +215,6 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) { // We remove all parameters except new_dout. std::vector newBpropParams = {new_dout}; cloned_bprop_fg->set_parameters(newBpropParams); - outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)})); return BasicClone(outer); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc index 12654232ef6..165a911d4d7 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -64,7 +64,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt } FuncGraphPtr Kprim(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { - auto fg = g_k_prims.KPrimitive(value_node, resources); + auto fg = g_k_prims.KPrimitive(nullptr, value_node, resources); if (fg == nullptr) { return nullptr; } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index c55c539a38e..6cc6468b12c 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -102,7 +102,8 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; } -FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) { +FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node, + const pipeline::ResourceBasePtr &resources) { if (!IsValueNode(value_node)) { MS_LOG(EXCEPTION) << "Primitive node is not valid."; } @@ -141,7 +142,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R } } - auto expanded_fg = BpropToK(prim, bprop_fg); + auto expanded_fg = BpropToK(prim, bprop_fg, cnode); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << prim->name() << " prim bprop function to J expanded func graph. NodeInfo: " @@ -220,7 +221,7 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) { MS_EXCEPTION_IF_NULL(bprop_fg); auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph(); - auto expanded_fg = BpropToK(fprop_fg, bprop_fg); + auto expanded_fg = BpropToK(fprop_fg, bprop_fg, nullptr); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString() << " Cell bprop function to K expanded func graph. NodeInfo: " diff --git a/tests/ut/cpp/operator/grad_implementations_test.cc b/tests/ut/cpp/operator/grad_implementations_test.cc index f55553ab721..9d37d6c4744 100644 --- a/tests/ut/cpp/operator/grad_implementations_test.cc +++ b/tests/ut/cpp/operator/grad_implementations_test.cc @@ -33,11 +33,11 @@ class TestGradImplementations : public UT::Common { }; TEST_F(TestGradImplementations, TestGetAugmentedGraph) { - FuncGraphPtr fg = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); + FuncGraphPtr fg = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr); ASSERT_TRUE(fg != nullptr); draw::Draw("gradImpl_TestGetAugmentedFuncGraph.dot", fg); - auto fg1 = ad::g_k_prims.KPrimitive(NewValueNode(kPrimScalarMul), nullptr); + auto fg1 = ad::g_k_prims.KPrimitive(nullptr, NewValueNode(kPrimScalarMul), nullptr); FuncGraphPairMapEquiv equiv_graph; NodeMapEquiv equiv_node;