forked from mindspore-Ecosystem/mindspore
commit
b3f91a4f22
|
@ -141,7 +141,10 @@ struct PrimitiveEqual {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PrimitiveHasher {
|
struct PrimitiveHasher {
|
||||||
std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); }
|
std::size_t operator()(PrimitivePtr const &prim) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
return prim->Hash();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_
|
||||||
|
|
|
@ -54,8 +54,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
||||||
dout_ = tape_->add_parameter();
|
dout_ = tape_->add_parameter();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
|
void DFunctor::Init(bool is_top) {
|
||||||
func_graph_to_functor_[primal_graph_] = functor;
|
func_graph_to_functor_[primal_graph_] = shared_from_this();
|
||||||
is_top_ = is_top;
|
is_top_ = is_top;
|
||||||
if (is_top) {
|
if (is_top) {
|
||||||
scope_ = primal_graph_->scope();
|
scope_ = primal_graph_->scope();
|
||||||
|
@ -371,7 +371,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
|
||||||
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
|
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
|
||||||
|
|
||||||
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
||||||
functor->Init(functor);
|
functor->Init();
|
||||||
functor->k_graph_ = fg;
|
functor->k_graph_ = fg;
|
||||||
|
|
||||||
return fg;
|
return fg;
|
||||||
|
@ -394,7 +394,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
auto functor = std::make_shared<DFunctor>(primal, resources_);
|
||||||
functor->Init(functor);
|
functor->Init();
|
||||||
functor->MapObject();
|
functor->MapObject();
|
||||||
functor->MapMorphism();
|
functor->MapMorphism();
|
||||||
|
|
||||||
|
|
|
@ -35,14 +35,40 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ad {
|
namespace ad {
|
||||||
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr>;
|
struct PrimitiveTotalEqual {
|
||||||
|
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
|
||||||
|
if (t1->name() != t2->name()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const &attrs1 = t1->attrs();
|
||||||
|
auto const &attrs2 = t2->attrs();
|
||||||
|
if (attrs1.size() != attrs2.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &attr : attrs1) {
|
||||||
|
if (!t2->HasAttr(attr.first)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
|
||||||
class KPrim;
|
class KPrim;
|
||||||
extern KPrim g_k_prims;
|
extern KPrim g_k_prims;
|
||||||
class DFunctor;
|
class DFunctor;
|
||||||
using DFunctorPtr = std::shared_ptr<DFunctor>;
|
using DFunctorPtr = std::shared_ptr<DFunctor>;
|
||||||
|
|
||||||
// D Functor's rules to map closure object and morphisms.
|
// D Functor's rules to map closure object and morphisms.
|
||||||
class DFunctor {
|
class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
||||||
public:
|
public:
|
||||||
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
|
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
|
||||||
~DFunctor() = default;
|
~DFunctor() = default;
|
||||||
|
@ -54,7 +80,7 @@ class DFunctor {
|
||||||
// Construct user defined k object.
|
// Construct user defined k object.
|
||||||
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
|
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
|
||||||
// Register functor objects to form a global view.
|
// Register functor objects to form a global view.
|
||||||
void Init(const DFunctorPtr &functor, bool is_top = false);
|
void Init(bool is_top = false);
|
||||||
bool IsInScope(const AnfNodePtr &node);
|
bool IsInScope(const AnfNodePtr &node);
|
||||||
|
|
||||||
// Clear resources.
|
// Clear resources.
|
||||||
|
|
|
@ -51,7 +51,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
|
||||||
}
|
}
|
||||||
return user_defined;
|
return user_defined;
|
||||||
}
|
}
|
||||||
f->Init(f, is_top);
|
f->Init(is_top);
|
||||||
f->MapObject();
|
f->MapObject();
|
||||||
f->MapMorphism();
|
f->MapMorphism();
|
||||||
auto ret = f->k_graph();
|
auto ret = f->k_graph();
|
||||||
|
|
|
@ -82,7 +82,7 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (prim->name() == "make_tuple") {
|
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||||
MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
|
MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
|
||||||
bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
|
bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
|
||||||
return meta;
|
return meta;
|
||||||
|
@ -111,7 +111,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
||||||
return fprop;
|
return fprop;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (prim->name() == "make_tuple") {
|
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -113,15 +113,30 @@ class InferenceOptPrepareLib {
|
||||||
// predicate functions
|
// predicate functions
|
||||||
inline bool IsNode(const AnfNodePtr &) { return true; }
|
inline bool IsNode(const AnfNodePtr &) { return true; }
|
||||||
|
|
||||||
inline bool IsCNode(const AnfNodePtr &node) { return node->isa<CNode>(); }
|
inline bool IsCNode(const AnfNodePtr &node) {
|
||||||
|
if (node != nullptr) {
|
||||||
|
return node->isa<CNode>();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool IsVNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); }
|
inline bool IsVNode(const AnfNodePtr &node) {
|
||||||
|
if (node != nullptr) {
|
||||||
|
return node->isa<ValueNode>();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
inline bool IsParam(const AnfNodePtr &node) { return node->isa<Parameter>(); }
|
inline bool IsParam(const AnfNodePtr &node) {
|
||||||
|
if (node != nullptr) {
|
||||||
|
return node->isa<Parameter>();
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if CNode Input 0 is Func Graph
|
// Check if CNode Input 0 is Func Graph
|
||||||
inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
||||||
if (!node->isa<CNode>()) {
|
if (node == nullptr || !node->isa<CNode>()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,7 +146,7 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
|
||||||
|
|
||||||
// Check if CNode Input 0 is CNode
|
// Check if CNode Input 0 is CNode
|
||||||
inline bool IsCNodeDup(const AnfNodePtr &node) {
|
inline bool IsCNodeDup(const AnfNodePtr &node) {
|
||||||
if (!node->isa<CNode>()) {
|
if (node == nullptr || !node->isa<CNode>()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue