forked from mindspore-Ecosystem/mindspore
!32681 [MSLITE]add new core usage demo
Merge pull request !32681 from fangzhou0329/fz_dev
This commit is contained in:
commit
f912e3d15e
|
@ -22,15 +22,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace converter {
|
||||
ops::PrimitiveC *AddParserTutorial::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto prim = std::make_unique<ops::AddFusion>();
|
||||
ops::BaseOperatorPtr AddParserTutorial::Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) {
|
||||
auto prim = api::MakeShared<ops::AddFusion>();
|
||||
if (prim == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_activation_type(mindspore::NO_ACTIVATION); // user need to analyze tflite_op's attr.
|
||||
return prim.release();
|
||||
return prim;
|
||||
}
|
||||
|
||||
REG_NODE_PARSER(kFmkTypeTflite, ADD, std::make_shared<AddParserTutorial>());
|
||||
|
|
|
@ -26,9 +26,9 @@ class AddParserTutorial : public NodeParser {
|
|||
public:
|
||||
AddParserTutorial() = default;
|
||||
~AddParserTutorial() = default;
|
||||
ops::PrimitiveC *Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
ops::BaseOperatorPtr Parse(const std::unique_ptr<tflite::OperatorT> &tflite_op,
|
||||
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
|
||||
const std::unique_ptr<tflite::ModelT> &tflite_model) override;
|
||||
};
|
||||
} // namespace converter
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,19 +21,20 @@
|
|||
#include <vector>
|
||||
#include "include/registry/pass_registry.h"
|
||||
#include "ops/custom.h"
|
||||
#include "ops/fusion/add_fusion.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// check a certain node is designated node's type.
|
||||
bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
|
||||
bool CheckPrimitiveTypeTutorial(const api::AnfNodePtr &node, const api::PrimitivePtr &primitive_type) {
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (node->isa<api::CNode>()) {
|
||||
auto cnode = node->cast<api::CNodePtr>();
|
||||
return IsPrimitive(cnode->input(0), primitive_type);
|
||||
} else if (node->isa<ValueNode>()) {
|
||||
} else if (node->isa<api::ValueNode>()) {
|
||||
return IsPrimitive(node, primitive_type);
|
||||
}
|
||||
return false;
|
||||
|
@ -41,11 +42,11 @@ bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &prim
|
|||
} // namespace
|
||||
|
||||
// convert addn to custom op
|
||||
AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||
api::AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const api::CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto primc = std::make_shared<ops::Custom>();
|
||||
auto primc = api::MakeShared<ops::Custom>();
|
||||
if (primc == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -78,13 +79,13 @@ bool PassTutorial::Execute(const api::FuncGraphPtr &func_graph) {
|
|||
}
|
||||
auto node_list = api::FuncGraph::TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
if (!api::utils::isa<api::CNode>(node)) {
|
||||
continue;
|
||||
}
|
||||
if (!CheckPrimitiveTypeTutorial(node, prim::kPrimAddFusion)) {
|
||||
if (!CheckPrimitiveTypeTutorial(node, mindspore::api::MakeShared<mindspore::ops::AddFusion>())) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto cnode = node->cast<api::CNodePtr>();
|
||||
auto custome_cnode = CreateCustomOp(func_graph, cnode);
|
||||
if (custome_cnode == nullptr) {
|
||||
return false;
|
||||
|
|
|
@ -30,7 +30,7 @@ class PassTutorial : public registry::PassBase {
|
|||
bool Execute(const api::FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode);
|
||||
api::AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const api::CNodePtr &cnode);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue