From ab823632b1649bf69b0ebd06197e986284e60a63 Mon Sep 17 00:00:00 2001 From: fangzhou12 Date: Thu, 7 Apr 2022 10:19:26 +0800 Subject: [PATCH] add new core usage demo --- .../node_parser/add_parser_tutorial.cc | 10 +++++----- .../node_parser/add_parser_tutorial.h | 6 +++--- .../pass/pass_registry_tutorial.cc | 19 ++++++++++--------- .../pass/pass_registry_tutorial.h | 2 +- 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.cc b/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.cc index d8a73cbc0e4..9d62315ffb9 100644 --- a/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.cc +++ b/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.cc @@ -22,15 +22,15 @@ namespace mindspore { namespace converter { -ops::PrimitiveC *AddParserTutorial::Parse(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_model) { - auto prim = std::make_unique(); +ops::BaseOperatorPtr AddParserTutorial::Parse(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_model) { + auto prim = api::MakeShared(); if (prim == nullptr) { return nullptr; } prim->set_activation_type(mindspore::NO_ACTIVATION); // user need to analyze tflite_op's attr. - return prim.release(); + return prim; } REG_NODE_PARSER(kFmkTypeTflite, ADD, std::make_shared()); diff --git a/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.h b/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.h index 6c6c185d1ee..28306d537eb 100644 --- a/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.h +++ b/mindspore/lite/examples/converter_extend/node_parser/add_parser_tutorial.h @@ -26,9 +26,9 @@ class AddParserTutorial : public NodeParser { public: AddParserTutorial() = default; ~AddParserTutorial() = default; - ops::PrimitiveC *Parse(const std::unique_ptr &tflite_op, - const std::unique_ptr &tflite_subgraph, - const std::unique_ptr &tflite_model) override; + ops::BaseOperatorPtr Parse(const std::unique_ptr &tflite_op, + const std::unique_ptr &tflite_subgraph, + const std::unique_ptr &tflite_model) override; }; } // namespace converter } // namespace mindspore diff --git a/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.cc b/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.cc index 5c0747e365d..cd149c89c4d 100644 --- a/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.cc +++ b/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.cc @@ -21,19 +21,20 @@ #include #include "include/registry/pass_registry.h" #include "ops/custom.h" +#include "ops/fusion/add_fusion.h" namespace mindspore { namespace opt { namespace { // check a certain node is designated node's type. -bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &primitive_type) { +bool CheckPrimitiveTypeTutorial(const api::AnfNodePtr &node, const api::PrimitivePtr &primitive_type) { if (node == nullptr) { return false; } - if (node->isa()) { - auto cnode = node->cast(); + if (node->isa()) { + auto cnode = node->cast(); return IsPrimitive(cnode->input(0), primitive_type); - } else if (node->isa()) { + } else if (node->isa()) { return IsPrimitive(node, primitive_type); } return false; @@ -41,11 +42,11 @@ bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &prim } // namespace // convert addn to custom op -AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode) { +api::AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const api::CNodePtr &cnode) { if (cnode == nullptr) { return nullptr; } - auto primc = std::make_shared(); + auto primc = api::MakeShared(); if (primc == nullptr) { return nullptr; } @@ -78,13 +79,13 @@ bool PassTutorial::Execute(const api::FuncGraphPtr &func_graph) { } auto node_list = api::FuncGraph::TopoSort(func_graph->get_return()); for (auto &node : node_list) { - if (!utils::isa(node)) { + if (!api::utils::isa(node)) { continue; } - if (!CheckPrimitiveTypeTutorial(node, prim::kPrimAddFusion)) { + if (!CheckPrimitiveTypeTutorial(node, mindspore::api::MakeShared())) { continue; } - auto cnode = node->cast(); + auto cnode = node->cast(); auto custome_cnode = CreateCustomOp(func_graph, cnode); if (custome_cnode == nullptr) { return false; diff --git a/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.h b/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.h index eb2797a51d5..1272fa627ac 100644 --- a/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.h +++ b/mindspore/lite/examples/converter_extend/pass/pass_registry_tutorial.h @@ -30,7 +30,7 @@ class PassTutorial : public registry::PassBase { bool Execute(const api::FuncGraphPtr &func_graph) override; private: - AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode); + api::AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const api::CNodePtr &cnode); }; } // namespace opt } // namespace mindspore