!2295 Optimization for ad

Merge pull request !2295 from Kang/master
This commit is contained in:
mindspore-ci-bot 2020-06-19 11:10:58 +08:00 committed by Gitee
commit b3f91a4f22
6 changed files with 60 additions and 16 deletions

View File

@ -141,7 +141,10 @@ struct PrimitiveEqual {
}; };
struct PrimitiveHasher { struct PrimitiveHasher {
std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } std::size_t operator()(PrimitivePtr const &prim) const {
MS_EXCEPTION_IF_NULL(prim);
return prim->Hash();
}
}; };
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_ #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_BASE_H_

View File

@ -54,8 +54,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
dout_ = tape_->add_parameter(); dout_ = tape_->add_parameter();
} }
void DFunctor::Init(const DFunctorPtr &functor, bool is_top) { void DFunctor::Init(bool is_top) {
func_graph_to_functor_[primal_graph_] = functor; func_graph_to_functor_[primal_graph_] = shared_from_this();
is_top_ = is_top; is_top_ = is_top;
if (is_top) { if (is_top) {
scope_ = primal_graph_->scope(); scope_ = primal_graph_->scope();
@ -371,7 +371,7 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false); primal->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
auto functor = std::make_shared<DFunctor>(primal, resources_); auto functor = std::make_shared<DFunctor>(primal, resources_);
functor->Init(functor); functor->Init();
functor->k_graph_ = fg; functor->k_graph_ = fg;
return fg; return fg;
@ -394,7 +394,7 @@ AnfNodePtr DFunctor::MapToK(const FuncGraphPtr &primal) {
} }
auto functor = std::make_shared<DFunctor>(primal, resources_); auto functor = std::make_shared<DFunctor>(primal, resources_);
functor->Init(functor); functor->Init();
functor->MapObject(); functor->MapObject();
functor->MapMorphism(); functor->MapMorphism();

View File

@ -35,14 +35,40 @@
namespace mindspore { namespace mindspore {
namespace ad { namespace ad {
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr>; struct PrimitiveTotalEqual {
bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const {
if (t1->name() != t2->name()) {
return false;
}
auto const &attrs1 = t1->attrs();
auto const &attrs2 = t2->attrs();
if (attrs1.size() != attrs2.size()) {
return false;
}
for (auto &attr : attrs1) {
if (!t2->HasAttr(attr.first)) {
return false;
}
if (!(*(attr.second) == *(t2->GetAttr(attr.first)))) {
return false;
}
}
return true;
}
};
using Registry = std::unordered_map<PrimitivePtr, FuncGraphPtr, PrimitiveHasher, PrimitiveTotalEqual>;
class KPrim; class KPrim;
extern KPrim g_k_prims; extern KPrim g_k_prims;
class DFunctor; class DFunctor;
using DFunctorPtr = std::shared_ptr<DFunctor>; using DFunctorPtr = std::shared_ptr<DFunctor>;
// D Functor's rules to map closure object and morphisms. // D Functor's rules to map closure object and morphisms.
class DFunctor { class DFunctor : public std::enable_shared_from_this<DFunctor> {
public: public:
DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources); DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources);
~DFunctor() = default; ~DFunctor() = default;
@ -54,7 +80,7 @@ class DFunctor {
// Construct user defined k object. // Construct user defined k object.
FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); FuncGraphPtr KUserDefined(const FuncGraphPtr &primal);
// Register functor objects to form a global view. // Register functor objects to form a global view.
void Init(const DFunctorPtr &functor, bool is_top = false); void Init(bool is_top = false);
bool IsInScope(const AnfNodePtr &node); bool IsInScope(const AnfNodePtr &node);
// Clear resources. // Clear resources.

View File

@ -51,7 +51,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
} }
return user_defined; return user_defined;
} }
f->Init(f, is_top); f->Init(is_top);
f->MapObject(); f->MapObject();
f->MapMorphism(); f->MapMorphism();
auto ret = f->k_graph(); auto ret = f->k_graph();

View File

@ -82,7 +82,7 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
return iter->second; return iter->second;
} }
if (prim->name() == "make_tuple") { if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient"); MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
bprop_registry_meta_[prim::kPrimMakeTuple] = meta; bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
return meta; return meta;
@ -111,7 +111,7 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
return fprop; return fprop;
} }
if (prim->name() == "make_tuple") { if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
return nullptr; return nullptr;
} }

View File

@ -113,15 +113,30 @@ class InferenceOptPrepareLib {
// predicate functions // predicate functions
inline bool IsNode(const AnfNodePtr &) { return true; } inline bool IsNode(const AnfNodePtr &) { return true; }
inline bool IsCNode(const AnfNodePtr &node) { return node->isa<CNode>(); } inline bool IsCNode(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<CNode>();
}
return false;
}
inline bool IsVNode(const AnfNodePtr &node) { return node->isa<ValueNode>(); } inline bool IsVNode(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<ValueNode>();
}
return false;
}
inline bool IsParam(const AnfNodePtr &node) { return node->isa<Parameter>(); } inline bool IsParam(const AnfNodePtr &node) {
if (node != nullptr) {
return node->isa<Parameter>();
}
return false;
}
// Check if CNode Input 0 is Func Graph // Check if CNode Input 0 is Func Graph
inline bool IsCNodeGraph(const AnfNodePtr &node) { inline bool IsCNodeGraph(const AnfNodePtr &node) {
if (!node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {
return false; return false;
} }
@ -131,7 +146,7 @@ inline bool IsCNodeGraph(const AnfNodePtr &node) {
// Check if CNode Input 0 is CNode // Check if CNode Input 0 is CNode
inline bool IsCNodeDup(const AnfNodePtr &node) { inline bool IsCNodeDup(const AnfNodePtr &node) {
if (!node->isa<CNode>()) { if (node == nullptr || !node->isa<CNode>()) {
return false; return false;
} }