diff --git a/mindspore/ccsrc/ir/primitive_base.h b/mindspore/ccsrc/ir/primitive_base.h index ffaa12c4d5b..b34c43d00e6 100644 --- a/mindspore/ccsrc/ir/primitive_base.h +++ b/mindspore/ccsrc/ir/primitive_base.h @@ -141,7 +141,10 @@ struct PrimitiveEqual { }; struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } + std::size_t operator()(PrimitivePtr const &prim) const { + MS_EXCEPTION_IF_NULL(prim); + return prim->Hash(); + } }; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/optimizer/ad/dfunctor.cc index c189c337d89..cde90db3467 100644 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/optimizer/ad/dfunctor.cc @@ -54,8 +54,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas dout_ = tape_->add_parameter(); } -void DFunctor::Init(const DFunctorPtr &functor, bool is_top) { - func_graph_to_functor_[primal_graph_] = functor; +void DFunctor::Init(bool is_top) { + func_graph_to_functor_[primal_graph_] = shared_from_this(); is_top_ = is_top; if (is_top) { scope_ = primal_graph_->scope(); @@ -371,7 +371,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); auto functor = std::make_shared(primal, resources_); - functor->Init(functor); + functor->Init(); functor->k_graph_ = fg; return fg; @@ -394,7 +394,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) { } auto functor = std::make_shared(primal, resources_); - functor->Init(functor); + functor->Init(); functor->MapObject(); functor->MapMorphism(); diff --git a/mindspore/ccsrc/optimizer/ad/dfunctor.h b/mindspore/ccsrc/optimizer/ad/dfunctor.h index e35d0569086..598dd958694 100644 --- a/mindspore/ccsrc/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/optimizer/ad/dfunctor.h @@ -35,14 +35,40 @@ namespace mindspore { namespace ad { -using Registry = std::unordered_map; +struct PrimitiveTotalEqual { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { + if (t1->name() != t2->name()) { + return false; + } + + auto const &attrs1 = t1->attrs(); + auto const &attrs2 = t2->attrs(); + if (attrs1.size() != attrs2.size()) { + return false; + } + + for (auto &attr : attrs1) { + if (!t2->HasAttr(attr.first)) { + return false; + } + + if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) { + return false; + } + } + + return true; + } +}; + +using Registry = std::unordered_map; class KPrim; extern KPrim g_k_prims; class DFunctor; using DFunctorPtr = std::shared_ptr; // D Functor's rules to map closure object and morphisms. -class DFunctor { +class DFunctor : public std::enable_shared_from_this { public: DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); ~DFunctor() = default; @@ -54,7 +80,7 @@ class DFunctor { // Construct user defined k object. FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); // Register functor objects to form a global view. - void Init(const DFunctorPtr &functor, bool is_top = false); + void Init(bool is_top = false); bool IsInScope(const AnfNodePtr &node); // Clear resources. diff --git a/mindspore/ccsrc/optimizer/ad/grad.cc b/mindspore/ccsrc/optimizer/ad/grad.cc index 4b5efeeefd9..43d2a66ad2d 100644 --- a/mindspore/ccsrc/optimizer/ad/grad.cc +++ b/mindspore/ccsrc/optimizer/ad/grad.cc @@ -51,7 +51,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt } return user_defined; } - f->Init(f, is_top); + f->Init(is_top); f->MapObject(); f->MapMorphism(); auto ret = f->k_graph(); diff --git a/mindspore/ccsrc/optimizer/ad/kprim.cc b/mindspore/ccsrc/optimizer/ad/kprim.cc index 711c4b7a714..a9883cbf63e 100644 --- a/mindspore/ccsrc/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/optimizer/ad/kprim.cc @@ -82,7 +82,7 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { return iter->second; } - if (prim->name() == "make_tuple") { + if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { MetaFuncGraphPtr meta = std::make_shared("make_tuple_gradient"); bprop_registry_meta_[prim::kPrimMakeTuple] = meta; return meta; @@ -111,7 +111,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R return fprop; } - if (prim->name() == "make_tuple") { + if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) { return nullptr; } diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h index c5694ab8fa4..a388ccb7c85 100644 --- a/mindspore/ccsrc/optimizer/irpass.h +++ b/mindspore/ccsrc/optimizer/irpass.h @@ -113,15 +113,30 @@ class InferenceOptPrepareLib { // predicate functions inline bool IsNode(const AnfNodePtr &) { return true; } -inline bool IsCNode(const AnfNodePtr &node) { return node->isa(); } +inline bool IsCNode(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} -inline bool IsVNode(const AnfNodePtr &node) { return node->isa(); } +inline bool IsVNode(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} -inline bool IsParam(const AnfNodePtr &node) { return node->isa(); } +inline bool IsParam(const AnfNodePtr &node) { + if (node != nullptr) { + return node->isa(); + } + return false; +} // Check if CNode Input 0 is Func Graph inline bool IsCNodeGraph(const AnfNodePtr &node) { - if (!node->isa()) { + if (node == nullptr || !node->isa()) { return false; } @@ -131,7 +146,7 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) { // Check if CNode Input 0 is CNode inline bool IsCNodeDup(const AnfNodePtr &node) { - if (!node->isa()) { + if (node == nullptr || !node->isa()) { return false; }