From ed26666210aeaf4d7599f36298fbd513fa2c6781 Mon Sep 17 00:00:00 2001 From: l00591931 Date: Mon, 31 May 2021 11:38:36 +0800 Subject: [PATCH] Fix bug for primal_attr --- .../ccsrc/frontend/optimizer/ad/dfunctor.h | 18 ++++++--- .../ccsrc/frontend/optimizer/ad/kprim.cc | 40 +++++++++---------- mindspore/core/ir/anf.h | 4 +- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index a51543d307d..4879c58fde8 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -33,6 +33,7 @@ #include "frontend/optimizer/ad/adjoint.h" #include "frontend/operator/ops.h" #include "debug/trace.h" +#include "utils/utils.h" namespace mindspore { namespace ad { @@ -142,8 +143,7 @@ class KPrim { FuncGraphPtr GetPossibleBprop(const PrimitivePtr &prim); private: - FuncGraphPtr GetBprop(const PrimitivePtr &prim, const std::unordered_map &primal_attrs, - const std::vector &primal_debug_infos); + FuncGraphPtr GetBprop(const PrimitivePtr &prim); FuncGraphPtr GetFprop(const PrimitivePtr &prim); FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); @@ -152,7 +152,8 @@ class KPrim { // Refer the comment in KUserDefinedCellBprop. template FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const FuncGraphPtr ¤t_primal_fg, - const CNodePtr &cnode); + const CNodePtr &cnode, const std::unordered_map &primal_attrs, + const std::vector &primal_debug_infos); AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg); void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const PrimitivePtr &primitive, const FuncGraphPtr &outer, @@ -169,15 +170,20 @@ class KPrim { template FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg, - const CNodePtr &cnode) { + const CNodePtr &cnode, const std::unordered_map &primal_attrs, + const std::vector &primal_debug_infos) { MS_EXCEPTION_IF_NULL(primal); MS_EXCEPTION_IF_NULL(bprop_fg); CheckBprop(bprop_fg, primal->ToString()); auto debug_info = std::make_shared(); debug_info->set_name(primal->ToString()); - - auto cloned_bprop_fg = BasicClone(bprop_fg); + FuncGraphPtr cloned_bprop_fg; + { + PrimalAttrGuard primal_attr_guard(primal_attrs); + PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos); + cloned_bprop_fg = BasicClone(bprop_fg); + } MS_EXCEPTION_IF_NULL(cloned_bprop_fg); cloned_bprop_fg->debug_info()->set_name(""); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 439d0bb2b2e..cd79a608334 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -40,8 +40,7 @@ namespace mindspore { namespace ad { KPrim g_k_prims; -FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const std::unordered_map &primal_attrs, - const std::vector &primal_debug_infos) { +FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { // Set a child scope named "grad'PrimitiveName'" for the bprop function, // and add "Gradients" to the front. static const std::string gradients_scope = "Gradients/"; @@ -50,8 +49,6 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const std::unordered_map< auto scope = std::make_shared(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() + grad_op_child_scope_prefix + prim->name()); ScopeGuard scope_guard(scope); - PrimalAttrGuard primal_attr_guard(primal_attrs); - PrimalDebugInfoGuard primal_debug_info_guard(primal_debug_infos); py::function fn; if (prim->is_base()) { @@ -87,7 +84,7 @@ FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) { } if (bprop_fg == nullptr) { - bprop_fg = GetBprop(prim, {}, {}); + bprop_fg = GetBprop(prim); if (bprop_fg != nullptr) { // Set bprop_g graph cache bprop_registry_[prim] = bprop_fg; @@ -190,20 +187,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_ } if (bprop_fg == nullptr) { - std::unordered_map primal_attrs; - std::vector primal_debug_infos; - if (resources != nullptr) { - auto manager = resources->manager(); - auto &users = manager->node_users()[value_node]; - for (auto user_iter = users.begin(); user_iter != users.end(); user_iter++) { - primal_debug_infos.push_back(user_iter->first->debug_info()); - } - } - if (cnode != nullptr) { - const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId(); - primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr); - } - bprop_fg = GetBprop(prim, primal_attrs, primal_debug_infos); + bprop_fg = GetBprop(prim); if (bprop_fg != nullptr) { // Set bprop_g graph cache bprop_registry_[prim] = bprop_fg; @@ -214,7 +198,21 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_ } AdjustForAutoMonad(prim, bprop_fg); - auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode); + std::unordered_map primal_attrs; + std::vector primal_debug_infos; + if (resources != nullptr) { + auto manager = resources->manager(); + auto &users = manager->node_users()[value_node]; + for (auto user_iter = users.begin(); user_iter != users.end(); user_iter++) { + primal_debug_infos.push_back(user_iter->first->debug_info()); + } + } + if (cnode != nullptr) { + primal_attrs = cnode->primal_attrs(); + const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId(); + primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr); + } + auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << prim->name() << " prim bprop function to J expanded func graph. NodeInfo: " @@ -376,7 +374,7 @@ FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const Fu // primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph. // current_primal_fg is specalized and AutoMoaded primal_fg; auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph(); - auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr); + auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {}); if (expanded_fg == nullptr) { MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString() << " Cell bprop function to K expanded func graph. NodeInfo: " diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 358bae907f8..43dde4eb3b3 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -312,7 +312,9 @@ class CNode : public AnfNode, public EffectInfoHolder { std::vector primal_debug_infos() { return primal_debug_infos_; } - void set_primal_debug_infos(const std::vector &debug_infos) { primal_debug_infos_ = debug_infos; } + void set_primal_debug_infos(const std::vector &debug_infos) { + primal_debug_infos_.insert(primal_debug_infos_.end(), debug_infos.begin(), debug_infos.end()); + } void AddPrimalDebugInfo(const NodeDebugInfoPtr debug_info) { if (std::find(primal_debug_infos_.begin(), primal_debug_infos_.end(), debug_info) != primal_debug_infos_.end()) {