!2295 Optimization for ad

Merge pull request !2295 from Kang/master
This commit is contained in:
mindspore-ci-bot 2020-06-19 11:10:58 +08:00 committed by Gitee
commit b3f91a4f22
6 changed files with 60 additions and 16 deletions

View File

@ -141,7 +141,10 @@ struct PrimitiveEqual {
};
struct PrimitiveHasher {
std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); }
std::size_t operator()(PrimitivePtr const &prim) const {
MS_EXCEPTION_IF_NULL(prim);
return prim->Hash();
}
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_

View File

@ -54,8 +54,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
dout_ = tape_->add_parameter();
}
void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
func_graph_to_functor_[primal_graph_] = functor;
void DFunctor::Init(bool is_top) {
func_graph_to_functor_[primal_graph_] = shared_from_this();
is_top_ = is_top;
if (is_top) {
scope_ = primal_graph_->scope();
@ -371,7 +371,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
auto functor = std::make_shared<DFunctor>(primal, resources_);
functor->Init(functor);
functor->Init();
functor->k_graph_ = fg;
return fg;
@ -394,7 +394,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
}
auto functor = std::make_shared<DFunctor>(primal, resources_);
functor->Init(functor);
functor->Init();
functor->MapObject();
functor->MapMorphism();

View File

@ -35,14 +35,40 @@
namespace mindspore {
namespace ad {
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr>;
struct PrimitiveTotalEqual {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
if (t1->name() != t2->name()) {
return false;
}
auto const &attrs1 = t1->attrs();
auto const &attrs2 = t2->attrs();
if (attrs1.size() != attrs2.size()) {
return false;
}
for (auto &attr : attrs1) {
if (!t2->HasAttr(attr.first)) {
return false;
}
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
return false;
}
}
return true;
}
};
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim;
extern KPrim g_k_prims;
class DFunctor;
using DFunctorPtr = std::shared_ptr<DFunctor>;
// D Functor's rules to map closure object and morphisms.
class DFunctor {
class DFunctor : public std::enable_shared_from_this<DFunctor> {
public:
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
~DFunctor() = default;
@ -54,7 +80,7 @@ class DFunctor {
// Construct user defined k object.
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
// Register functor objects to form a global view.
void Init(const DFunctorPtr &functor, bool is_top = false);
void Init(bool is_top = false);
bool IsInScope(const AnfNodePtr &node);
// Clear resources.

View File

@ -51,7 +51,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
}
return user_defined;
}
f->Init(f, is_top);
f->Init(is_top);
f->MapObject();
f->MapMorphism();
auto ret = f->k_graph();

View File

@ -82,7 +82,7 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
return iter->second;
}
if (prim->name() == "make_tuple") {
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
return meta;
@ -111,7 +111,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
return fprop;
}
if (prim->name() == "make_tuple") {
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
return nullptr;
}

View File

@ -113,15 +113,30 @@ class InferenceOptPrepareLib {
// predicate functions
inline bool IsNode(const AnfNodePtr &) { return true; }
inline bool IsCNode(const AnfNodePtr &node) { return node->isa<CNode>(); }
inline bool IsCNode(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<CNode>();
}
return false;
}
inline bool IsVNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); }
inline bool IsVNode(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<ValueNode>();
}
return false;
}
inline bool IsParam(const AnfNodePtr &node) { return node->isa<Parameter>(); }
inline bool IsParam(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<Parameter>();
}
return false;
}
// Check if CNode Input 0 is Func Graph
inline bool IsCNodeGraph(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}
@ -131,7 +146,7 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
// Check if CNode Input 0 is CNode
inline bool IsCNodeDup(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
if (node == nullptr || !node->isa<CNode>()) {
return false;
}