forked from mindspore-Ecosystem/mindspore
commit
b3f91a4f22
|
@ -141,7 +141,10 @@ struct PrimitiveEqual {
|
|||
};
|
||||
|
||||
struct PrimitiveHasher {
|
||||
std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); }
|
||||
std::size_t operator()(PrimitivePtr const &prim) const {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return prim->Hash();
|
||||
}
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
||||
|
|
|
@ -54,8 +54,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
|||
dout_ = tape_->add_parameter();
|
||||
}
|
||||
|
||||
void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
|
||||
func_graph_to_functor_[primal_graph_] = functor;
|
||||
void DFunctor::Init(bool is_top) {
|
||||
func_graph_to_functor_[primal_graph_] = shared_from_this();
|
||||
is_top_ = is_top;
|
||||
if (is_top) {
|
||||
scope_ = primal_graph_->scope();
|
||||
|
@ -371,7 +371,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
|||
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
|
||||
|
||||
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
||||
functor->Init(functor);
|
||||
functor->Init();
|
||||
functor->k_graph_ = fg;
|
||||
|
||||
return fg;
|
||||
|
@ -394,7 +394,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
|
|||
}
|
||||
|
||||
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
||||
functor->Init(functor);
|
||||
functor->Init();
|
||||
functor->MapObject();
|
||||
functor->MapMorphism();
|
||||
|
||||
|
|
|
@ -35,14 +35,40 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr>;
|
||||
struct PrimitiveTotalEqual {
|
||||
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||
if (t1->name() != t2->name()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto const &attrs1 = t1->attrs();
|
||||
auto const &attrs2 = t2->attrs();
|
||||
if (attrs1.size() != attrs2.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto &attr : attrs1) {
|
||||
if (!t2->HasAttr(attr.first)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||
class KPrim;
|
||||
extern KPrim g_k_prims;
|
||||
class DFunctor;
|
||||
using DFunctorPtr = std::shared_ptr<DFunctor>;
|
||||
|
||||
// D Functor's rules to map closure object and morphisms.
|
||||
class DFunctor {
|
||||
class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
||||
public:
|
||||
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
|
||||
~DFunctor() = default;
|
||||
|
@ -54,7 +80,7 @@ class DFunctor {
|
|||
// Construct user defined k object.
|
||||
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
|
||||
// Register functor objects to form a global view.
|
||||
void Init(const DFunctorPtr &functor, bool is_top = false);
|
||||
void Init(bool is_top = false);
|
||||
bool IsInScope(const AnfNodePtr &node);
|
||||
|
||||
// Clear resources.
|
||||
|
|
|
@ -51,7 +51,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
|
|||
}
|
||||
return user_defined;
|
||||
}
|
||||
f->Init(f, is_top);
|
||||
f->Init(is_top);
|
||||
f->MapObject();
|
||||
f->MapMorphism();
|
||||
auto ret = f->k_graph();
|
||||
|
|
|
@ -82,7 +82,7 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
|
|||
return iter->second;
|
||||
}
|
||||
|
||||
if (prim->name() == "make_tuple") {
|
||||
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||
MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
|
||||
bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
|
||||
return meta;
|
||||
|
@ -111,7 +111,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
|||
return fprop;
|
||||
}
|
||||
|
||||
if (prim->name() == "make_tuple") {
|
||||
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -113,15 +113,30 @@ class InferenceOptPrepareLib {
|
|||
// predicate functions
|
||||
inline bool IsNode(const AnfNodePtr &) { return true; }
|
||||
|
||||
inline bool IsCNode(const AnfNodePtr &node) { return node->isa<CNode>(); }
|
||||
inline bool IsCNode(const AnfNodePtr &node) {
|
||||
if (node != nullptr) {
|
||||
return node->isa<CNode>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool IsVNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); }
|
||||
inline bool IsVNode(const AnfNodePtr &node) {
|
||||
if (node != nullptr) {
|
||||
return node->isa<ValueNode>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool IsParam(const AnfNodePtr &node) { return node->isa<Parameter>(); }
|
||||
inline bool IsParam(const AnfNodePtr &node) {
|
||||
if (node != nullptr) {
|
||||
return node->isa<Parameter>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if CNode Input 0 is Func Graph
|
||||
inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -131,7 +146,7 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
|||
|
||||
// Check if CNode Input 0 is CNode
|
||||
inline bool IsCNodeDup(const AnfNodePtr &node) {
|
||||
if (!node->isa<CNode>()) {
|
||||
if (node == nullptr || !node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue