abstract out pass

xuanyue 2021-09-13 15:19:34 +08:00
parent 60de7e032e
commit ecafae75d5
29 changed files with 274 additions and 136 deletions

View File

@ -189,9 +189,12 @@ if(PLATFORM_ARM64)
DESTINATION ${RUNTIME_DIR}/third_party/hiai_ddk/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
PATTERN "register_kernel.h")
if(SUPPORT_TRAIN)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
@ -199,7 +202,7 @@ if(PLATFORM_ARM64)
else()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
PATTERN "*registry.h" EXCLUDE)
PATTERN "registry*" EXCLUDE)
endif()
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -261,9 +264,12 @@ elseif(PLATFORM_ARM32)
DESTINATION ${RUNTIME_DIR}/third_party/hiai_ddk/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
endif()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
PATTERN "register_kernel.h")
if(SUPPORT_TRAIN)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
@ -271,7 +277,7 @@ elseif(PLATFORM_ARM32)
else()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
PATTERN "*registry.h" EXCLUDE)
PATTERN "registry*" EXCLUDE)
endif()
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -328,8 +334,6 @@ elseif(WIN32)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/ccsrc/backend/optimizer/common/pass.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
@ -355,6 +359,8 @@ elseif(WIN32)
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${protobuf_LIBPATH}/libprotobuf.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${flatbuffers_LIBPATH}/libflatbuffers.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
file(GLOB_RECURSE OPENCV_LIB_LIST
${opencv_LIBPATH}/../bin/libopencv_core*
${opencv_LIBPATH}/../bin/libopencv_imgcodecs*
@ -379,13 +385,16 @@ elseif(WIN32)
${RUNTIME_COMPONENT_NAME})
endif()
endif()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
PATTERN "register_kernel.h")
if(SUPPORT_TRAIN)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
else()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
PATTERN "*registry.h" EXCLUDE)
PATTERN "registry*" EXCLUDE)
endif()
install(FILES ${TOP_DIR}/build/mindspore/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -414,10 +423,12 @@ elseif(WIN32)
install(FILES ${LIB_LIST} DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
else()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
PATTERN "register_kernel.h")
if(SUPPORT_TRAIN)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE
PATTERN "framework.h" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
@ -425,7 +436,7 @@ else()
else()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
PATTERN "*registry.h" EXCLUDE PATTERN "framework.h" EXCLUDE)
PATTERN "registry*" EXCLUDE)
endif()
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -466,8 +477,6 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(FILES ${TOP_DIR}/mindspore/ccsrc/backend/optimizer/common/pass.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
@ -495,6 +504,8 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${protobuf_LIBPATH}/libprotobuf.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${flatbuffers_LIBPATH}/libflatbuffers.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${openssl_LIBPATH}/libcrypto.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter

View File

@ -55,6 +55,10 @@ class MS_CORE_API FuncGraph {
static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node);
static FuncGraphPtr Create();
static AnfNodePtr MakeValueNode(const FuncGraphPtr &func_graph);
static FuncGraphPtr GetFuncGraphFromAnfNode(const AnfNodePtr &input);
};
} // namespace mindspore::api
#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_H_

View File

@ -760,5 +760,15 @@ std::vector<AnfNodePtr> api::FuncGraph::TopoSort(const AnfNodePtr &node) { retur
// Create an api::FuncGraph instance.
api::FuncGraphPtr api::FuncGraph::Create() { return std::make_shared<mindspore::FuncGraph>(); }
AnfNodePtr api::FuncGraph::MakeValueNode(const api::FuncGraphPtr &func_graph) {
auto fg = std::dynamic_pointer_cast<mindspore::FuncGraph>(func_graph);
return NewValueNode(fg);
}
api::FuncGraphPtr api::FuncGraph::GetFuncGraphFromAnfNode(const AnfNodePtr &input) {
auto fg = GetValueNode<mindspore::FuncGraphPtr>(input);
return fg;
}
const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
} // namespace mindspore
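For reference, a minimal sketch of how the api::FuncGraph helpers added above might be combined; the function name FuncGraphHelperSketch is illustrative, everything else is declared in api/ir/func_graph.h:
#include "api/ir/func_graph.h"
namespace mindspore {
// Illustrative only: round-trip a FuncGraph through a ValueNode using the new helpers.
void FuncGraphHelperSketch() {
  // Create an abstract FuncGraph instance.
  api::FuncGraphPtr graph = api::FuncGraph::Create();
  // Wrap it into a ValueNode so another graph can reference it.
  AnfNodePtr value_node = api::FuncGraph::MakeValueNode(graph);
  // Recover the FuncGraph held by the node; nullptr if the node holds no FuncGraph.
  api::FuncGraphPtr recovered = api::FuncGraph::GetFuncGraphFromAnfNode(value_node);
  (void)recovered;
}
}  // namespace mindspore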

View File

@ -237,7 +237,7 @@ build_lite() {
compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
cd ${BASEPATH}/../
if [[ "${local_lite_platform}" == "x86_64" ]]; then
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev -j $THREAD_NUM
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev_2 -j $THREAD_NUM
if [[ $? -ne 0 ]]; then
echo "compile x86_64 for nnie failed."
exit 1

View File

@ -41,7 +41,7 @@ bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &prim
} // namespace
// convert addn to custom op
AnfNodePtr PassTutorial::CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode) {
if (cnode == nullptr) {
return nullptr;
}
@ -66,17 +66,17 @@ AnfNodePtr PassTutorial::CreateCustomOp(const FuncGraphPtr func_graph, const CNo
return custom_cnode;
}
bool PassTutorial::Run(const FuncGraphPtr &func_graph) {
bool PassTutorial::Execute(const api::FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return false;
}
// generate a func_graph manager.
auto manager = Manage(func_graph, true);
auto manager = api::FuncGraphManager::Manage(func_graph, true);
if (manager == nullptr) {
return false;
}
auto node_list = TopoSort(func_graph->get_return());
auto node_list = api::FuncGraph::TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;

View File

@ -17,18 +17,18 @@
#ifndef MINDSPORE_LITE_EXAMPLES_CONVERTER_REGISTER_SRC_PASS_REGISTRY_TUTORIAL_H
#define MINDSPORE_LITE_EXAMPLES_CONVERTER_REGISTER_SRC_PASS_REGISTRY_TUTORIAL_H
#include "include/pass.h"
#include "include/registry/pass_base.h"
namespace mindspore {
namespace opt {
class PassTutorial : public Pass {
class PassTutorial : public registry::PassBase {
public:
PassTutorial() : Pass("pass_tutorial") {}
PassTutorial() : PassBase("PassTutorial") {}
bool Run(const FuncGraphPtr &func_graph) override;
bool Execute(const api::FuncGraphPtr &func_graph) override;
private:
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode);
AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode);
};
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_BASE_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_BASE_H_
#include <memory>
#include <string>
#include "include/lite_utils.h"
#include "api/ir/func_graph.h"
namespace mindspore {
namespace registry {
/// \brief PassBase defines a base class, which provides an interface for users to operate on a FuncGraph.
class MS_API PassBase {
public:
/// \brief Constructor
///
/// \param[in] name Define the pass name, which should be unique among passes.
explicit PassBase(const std::string &name = "PassBase") : name_(name) {}
/// \brief Destructor
virtual ~PassBase() = default;
/// \brief An interface for users to operate on the FuncGraph.
///
/// \param[in] func_graph Define the structure of the model.
///
/// \return Boolean value indicating whether the operation succeeded.
virtual bool Execute(const api::FuncGraphPtr &func_graph) = 0;
private:
const std::string name_;
};
/// \brief PassBasePtr defines a shared_ptr type of PassBase.
using PassBasePtr = std::shared_ptr<PassBase>;
} // namespace registry
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_BASE_H_
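A minimal sketch of a user-defined pass built on this interface, mirroring the tutorial changes elsewhere in this commit; the class name MyPass is illustrative, and supporting headers for CNode/utils are omitted for brevity:
#include "include/registry/pass_base.h"
namespace mindspore {
namespace opt {
class MyPass : public registry::PassBase {
 public:
  MyPass() : PassBase("MyPass") {}
  bool Execute(const api::FuncGraphPtr &func_graph) override {
    if (func_graph == nullptr) {
      return false;
    }
    // Keep node/edge bookkeeping consistent while the graph is edited.
    auto manager = api::FuncGraphManager::Manage(func_graph, true);
    if (manager == nullptr) {
      return false;
    }
    // Visit every CNode in topological order.
    auto node_list = api::FuncGraph::TopoSort(func_graph->get_return());
    for (auto &node : node_list) {
      if (!utils::isa<CNode>(node)) {
        continue;
      }
      // ... transform the node here ...
    }
    return true;
  }
};
}  // namespace opt
}  // namespace mindspore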

View File

@ -19,20 +19,13 @@
#include <vector>
#include <string>
#include <utility>
#include <memory>
#include "include/lite_utils.h"
namespace mindspore {
namespace opt {
/// \brief P defined a basic interface.
///
/// \note List public class and interface for reference.
class MS_API Pass;
using PassPtr = std::shared_ptr<Pass>;
} // namespace opt
namespace registry {
class PassBase;
using PassBasePtr = std::shared_ptr<PassBase>;
/// \brief PassPosition defines where to place the user's pass.
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
@ -43,7 +36,7 @@ class MS_API PassRegistry {
///
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
/// \param[in] pass Define pass instance.
PassRegistry(const std::string &pass_name, const opt::PassPtr &pass);
PassRegistry(const std::string &pass_name, const PassBasePtr &pass);
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
///
@ -63,10 +56,10 @@ class MS_API PassRegistry {
/// \brief Static method to obtain a pass instance according to the pass name.
///
/// \param[in] pass_names Define the name of passes.
/// \param[in] pass_name Define the name of the pass.
///
/// \return Pass instance.
static std::vector<opt::PassPtr> GetPassFromStoreRoom(const std::vector<std::string> &pass_names);
static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name);
};
/// \brief Defines the registration macro to register a Pass, which is called by the user directly.
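A minimal sketch of registering a PassBase instance and fetching it back through the new single-name lookup; MyPass refers to the illustrative class sketched above, and the global object names here are placeholders:
#include <memory>
#include "include/registry/pass_base.h"
#include "include/registry/pass_registry.h"
namespace mindspore {
namespace {
// Put the pass into the outer store room under a unique name.
registry::PassRegistry g_my_pass_reg("MyPass", std::make_shared<opt::MyPass>());
// Ask the converter to schedule it at the beginning of the optimization pipeline.
registry::PassRegistry g_my_pass_schedule(registry::POSITION_BEGIN, {"MyPass"});
}  // namespace
void LookupSketch() {
  // Look up a single pass by name; returns nullptr if nothing was registered under it.
  registry::PassBasePtr pass = registry::PassRegistry::GetPassFromStoreRoom("MyPass");
  if (pass != nullptr) {
    // The converter would call pass->Execute(func_graph) at the scheduled position.
  }
}
}  // namespace mindspore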

View File

@ -40,9 +40,15 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
ConverterParameters converter_parameters;
auto func_graph = model_parser->Parse(converter_parameters);
ASSERT_NE(func_graph, nullptr);
auto node_list = func_graph->GetOrderedCnodes();
ASSERT_EQ(node_list.size(), 3);
auto iter = node_list.begin();
auto node_list = func_graph->TopoSort(func_graph->get_return());
std::vector<AnfNodePtr> cnode_list;
for (auto &node : node_list) {
if (node->isa<CNode>()) {
cnode_list.push_back(node);
}
}
ASSERT_EQ(cnode_list.size(), 3);
auto iter = cnode_list.begin();
bool is_add = opt::CheckPrimitiveType(*iter, prim::kPrimAddFusion);
ASSERT_EQ(is_add, true);
++iter;

View File

@ -21,9 +21,9 @@
#include "include/registry/model_parser_registry.h"
namespace mindspore {
FuncGraphPtr ModelParserTest::Parse(const converter::ConverterParameters &flag) {
api::FuncGraphPtr ModelParserTest::Parse(const converter::ConverterParameters &flag) {
// construct funcgraph
res_graph_ = std::make_shared<FuncGraph>();
res_graph_ = api::FuncGraph::Create();
auto ret = InitOriginModelStructure();
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "obtain origin model structure failed.";

View File

@ -28,7 +28,7 @@ namespace mindspore {
class ModelParserTest : public converter::ModelParser {
public:
ModelParserTest() = default;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
private:
int InitOriginModelStructure();

View File

@ -18,13 +18,14 @@
#include <string>
#include <vector>
#include "common/common_test.h"
#include "backend/optimizer/common/pass.h"
#include "include/registry/model_parser_registry.h"
#include "include/registry/pass_base.h"
#include "include/registry/pass_registry.h"
#include "ops/fusion/add_fusion.h"
#include "ops/addn.h"
#include "ops/custom.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h"
@ -44,14 +45,14 @@ class PassRegistryTest : public mindspore::CommonTest {
ConverterParameters converter_parameters;
func_graph_ = model_parser->Parse(converter_parameters);
}
FuncGraphPtr func_graph_ = nullptr;
api::FuncGraphPtr func_graph_ = nullptr;
};
namespace opt {
// fuse add and add to addn.
class Test1Fusion : public Pass {
class Test1Fusion : public registry::PassBase {
public:
Test1Fusion() : Pass("Test1Fusion") {}
Test1Fusion() : PassBase("Test1Fusion") {}
bool CanFusion(const CNodePtr &cnode) {
if (cnode == nullptr) {
return false;
@ -90,15 +91,15 @@ class Test1Fusion : public Pass {
return input_cnode_num > 0;
}
bool Run(const FuncGraphPtr &func_graph) override {
bool Execute(const api::FuncGraphPtr &func_graph) override {
if (func_graph == nullptr) {
return false;
}
auto manager = Manage(func_graph);
auto manager = api::FuncGraphManager::Manage(func_graph);
if (manager == nullptr) {
return false;
}
auto node_list = TopoSort(func_graph->get_return());
auto node_list = api::FuncGraph::TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
@ -130,10 +131,10 @@ class Test1Fusion : public Pass {
};
// convert addn to custom op
class Test2Fusion : public Pass {
class Test2Fusion : public registry::PassBase {
public:
Test2Fusion() : Pass("Test2Fusion") {}
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
Test2Fusion() : PassBase("Test2Fusion") {}
AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode) {
if (func_graph == nullptr || cnode == nullptr) {
return nullptr;
}
@ -158,11 +159,11 @@ class Test2Fusion : public Pass {
return custom_cnode;
}
bool Run(const FuncGraphPtr &func_graph) override {
bool Execute(const api::FuncGraphPtr &func_graph) override {
if (func_graph == nullptr) {
return false;
}
auto manager = Manage(func_graph);
auto manager = api::FuncGraphManager::Manage(func_graph);
if (manager == nullptr) {
return false;
}
@ -194,14 +195,26 @@ REG_SCHEDULED_PASS(POSITION_BEGIN, schedule)
TEST_F(PassRegistryTest, TestRegistry) {
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
ASSERT_EQ(schedule_task.size(), 2);
auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
ASSERT_EQ(passes.size(), 2);
std::vector<registry::PassBasePtr> passes;
auto pass1 = registry::PassRegistry::GetPassFromStoreRoom("Test1Fusion");
ASSERT_NE(pass1, nullptr);
passes.push_back(pass1);
auto pass2 = registry::PassRegistry::GetPassFromStoreRoom("Test2Fusion");
ASSERT_NE(pass2, nullptr);
passes.push_back(pass2);
ASSERT_NE(func_graph_, nullptr);
for (auto &pass : passes) {
auto ret = pass->Run(func_graph_);
auto ret = pass->Execute(func_graph_);
ASSERT_EQ(ret, true);
}
auto cnode_list = func_graph_->GetOrderedCnodes();
std::vector<CNodePtr> cnode_list;
auto node_list = api::FuncGraph::TopoSort(func_graph_->get_return());
for (auto &node : node_list) {
ASSERT_NE(node, nullptr);
if (node->isa<CNode>()) {
cnode_list.push_back(node->cast<CNodePtr>());
}
}
ASSERT_EQ(cnode_list.size(), 2);
bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom);
ASSERT_EQ(is_custom, true);

View File

@ -144,7 +144,7 @@ STATUS AclPass::PreProcGraph(const FuncGraphPtr &func_graph) {
return lite::RET_OK;
}
// The format of nodes (cnode, parameter, val) must be nchw due to interface of convert om
if (!lite::RunOptimizerPass(func_graph, {"ToNCHWFormat", "DeleteRedundantTranspose"})) {
if (!lite::RunOptimizerPass(func_graph, {"ToNCHWFormat", "DecreaseTransposeAlgo"})) {
MS_LOG(ERROR) << "To nchw format failed.";
return lite::RET_ERROR;
}

View File

@ -501,25 +501,36 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return old_graph;
}
void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
if (config == nullptr) {
MS_LOG(ERROR) << "config is nullptr";
return;
return false;
}
auto fmk = config->fmk;
auto is_train = config->trainModel;
registry::PassRegistry("DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train));
registry::PassRegistry("DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>());
registry::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train));
registry::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train));
registry::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
registry::PassRegistry("SpecifyGraphInputFormat",
std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat));
registry::PassRegistry("DumpGraph", std::make_shared<opt::DumpGraph>(config));
std::unordered_map<std::string, opt::PassPtr> passes = {
{"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)},
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat)}};
bool succeed_store = true;
for (auto iter = passes.begin(); iter != passes.end(); ++iter) {
if (PassStorage::StorePass(iter->first, iter->second) != RET_OK) {
MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " << iter->first
<< ", please edit external pass name.";
succeed_store = false;
}
}
return succeed_store;
}
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
AppendPassToStoreRoom(config);
if (!StoreBuiltinPass(config)) {
MS_LOG(ERROR) << "store pass failed.";
return nullptr;
}
auto new_graph = TransformFuncGraph(main_graph, config);
if (new_graph == nullptr) {
MS_LOG(ERROR) << "optimizer failed.";

View File

@ -57,7 +57,7 @@ class AnfTransform {
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
void AppendPassToStoreRoom(const converter::Flags *config);
bool StoreBuiltinPass(const converter::Flags *config);
static STATUS MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode);

View File

@ -44,16 +44,13 @@ void InitConverterParameters(const converter::Flags &flag, converter::ConverterP
} // namespace
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
FuncGraphPtr func_graph = nullptr;
api::FuncGraphPtr func_graph_base = nullptr;
if (flag.fmk == converter::FmkType::kFmkTypeMs) {
#ifdef SUPPORT_TRAIN
kernel::PopulateTrainParameters();
#endif
MindsporeImporter ms_import;
func_graph = ms_import.ImportMindIR(flag);
if (func_graph == nullptr) {
return nullptr;
}
func_graph_base = ms_import.ImportMindIR(flag);
} else {
model_parser_ = registry::ModelParserRegistry::GetModelParser(flag.fmk);
if (model_parser_ == nullptr) {
@ -61,13 +58,18 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
}
converter::ConverterParameters converter_parameters;
InitConverterParameters(flag, &converter_parameters);
func_graph = model_parser_->Parse(converter_parameters);
func_graph_base = model_parser_->Parse(converter_parameters);
}
if (func_graph == nullptr) {
if (func_graph_base == nullptr) {
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT);
return nullptr;
}
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(func_graph_base);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is invalid.";
return nullptr;
}
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
return nullptr;

View File

@ -21,7 +21,7 @@
#include <memory>
#include "schema/inner/model_generated.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "api/ir/func_graph.h"
#include "include/registry/model_parser_registry.h"
#include "utils/log_adapter.h"
@ -32,10 +32,10 @@ class ModelParser {
virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
virtual api::FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
protected:
FuncGraphPtr res_graph_ = nullptr;
api::FuncGraphPtr res_graph_ = nullptr;
};
typedef ModelParser *(*ModelParserCreator)();

View File

@ -15,31 +15,39 @@
*/
#include "tools/converter/optimizer_manager.h"
#include <map>
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "src/common/log_util.h"
#include "include/registry/pass_base.h"
namespace mindspore {
namespace lite {
std::map<std::string, opt::PassPtr> PassStorage::pass_storage_;
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is nullptr.";
return false;
}
auto schedule_passes = registry::PassRegistry::GetPassFromStoreRoom(pass_names);
if (schedule_passes.size() != pass_names.size()) {
MS_LOG(ERROR) << "exited pass cannot be obtained.";
return false;
}
int index = 0;
for (auto &pass : schedule_passes) {
CHECK_NULL_RETURN(pass);
if (!pass->Run(func_graph)) {
MS_LOG(WARNING) << "run pass failed, pass name is " << pass_names[index];
for (auto &pass_name : pass_names) {
auto pass_outer = registry::PassRegistry::GetPassFromStoreRoom(pass_name);
if (pass_outer != nullptr) {
if (!pass_outer->Execute(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << pass_name;
return false;
}
continue;
}
auto pass_builtin = PassStorage::GetPassFromStorage(pass_name);
if (pass_builtin == nullptr) {
MS_LOG(ERROR) << "the required pass cannot be obtained, pass name is " << pass_name;
return false;
}
if (!pass_builtin->Run(func_graph)) {
MS_LOG(ERROR) << "run pass failed, pass name is " << pass_name;
return false;
}
++index;
}
return true;
}

View File

@ -17,13 +17,31 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
#include <map>
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "include/errorcode.h"
#include "include/registry/pass_registry.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace lite {
class PassStorage {
public:
static int StorePass(const std::string &pass_name, const opt::PassPtr &pass) {
if (registry::PassRegistry::GetPassFromStoreRoom(pass_name) != nullptr) {
return RET_ERROR;
}
pass_storage_[pass_name] = pass;
return RET_OK;
}
static opt::PassPtr GetPassFromStorage(const std::string &pass_name) { return pass_storage_[pass_name]; }
private:
static std::map<std::string, opt::PassPtr> pass_storage_;
};
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names);
bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition position);
} // namespace lite

View File

@ -82,7 +82,7 @@ CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default;
FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
api::FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file;
auto weight_file = flag.weight_file;
STATUS status = InitOriginModel(model_file, weight_file);
@ -114,7 +114,9 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
res_graph_->set_attr("fmk", value_ptr);
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
GetAllFuncGraph(func_graph, &all_func_graphs);
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -122,7 +124,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false);
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
if (!unify_format->Run(res_graph_)) {
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}

View File

@ -34,7 +34,7 @@ class CaffeModelParser : public converter::ModelParser {
~CaffeModelParser() override;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
private:
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);

View File

@ -501,7 +501,7 @@ FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, c
return loop_body_graph;
}
FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file;
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
res_graph_ = std::make_shared<FuncGraph>();
@ -514,13 +514,15 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
}
MS_ASSERT(onnx_root_graph_ != nullptr);
status = ConvertOnnxGraph(onnx_root_graph_, res_graph_, &anf_nodes_map_, {}, "root_node");
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
status = ConvertOnnxGraph(onnx_root_graph_, func_graph, &anf_nodes_map_, {}, "root_node");
if (RET_OK != status) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
MS_LOG(ERROR) << "convert onnx graph failed.";
return nullptr;
}
static auto root_func_manager = Manage(res_graph_);
static auto root_func_manager = Manage(func_graph);
MS_ASSERT(root_func_manager != nullptr);
for (auto &subgraph : all_subgraphs_) {
MS_ASSERT(subgraph != nullptr);
@ -530,7 +532,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
GetAllFuncGraph(func_graph, &all_func_graphs);
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -543,7 +545,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false);
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
if (!unify_format->Run(res_graph_)) {
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}

View File

@ -41,7 +41,7 @@ class OnnxModelParser : public converter::ModelParser {
~OnnxModelParser() override = default;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
private:
STATUS InitOriginModel(const std::string &model_file);

View File

@ -492,7 +492,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(const std::vector<const tensor
return RET_OK;
}
FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
api::FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
auto modelFile = flag.model_file;
NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb");
@ -528,7 +528,9 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
tf_root_graph_nodes_vec_.emplace_back(&node_def);
}
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_vec_, res_graph_, &anf_root_node_map_, true);
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_vec_, func_graph, &anf_root_node_map_, true);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
@ -536,7 +538,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
bool success_flag = true;
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
auto &node_def = tf_root_graph_->node(i);
status = ConvertOps(node_def, tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
status = ConvertOps(node_def, tf_root_graph_nodes_, func_graph, &anf_root_node_map_);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
if (status != RET_OK) {
success_flag = false;
@ -570,7 +572,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
return nullptr;
}
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
GetAllFuncGraph(func_graph, &all_func_graphs);
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
@ -584,12 +586,12 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false);
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
if (!unify_format->Run(res_graph_)) {
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
res_graph_->set_manager(nullptr);
static auto root_func_manager = Manage(res_graph_);
func_graph->set_manager(nullptr);
static auto root_func_manager = Manage(func_graph);
return res_graph_;
}
@ -780,7 +782,12 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
<< " second_func_map.size(): " << second_func_map.size();
return RET_ERROR;
}
static auto root_func_manager = Manage(res_graph_);
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is invalid.";
return RET_ERROR;
}
static auto root_func_manager = Manage(func_graph);
for (auto &kv : first_func_map) {
auto control_flow_node = kv.first;
@ -1082,7 +1089,12 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
}
}
}
auto status = MakeAnfGraphOutputs(output_nodes, res_graph_);
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is invalid.";
return RET_ERROR;
}
auto status = MakeAnfGraphOutputs(output_nodes, func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "make anf graph outputs node error";
return status;

View File

@ -40,7 +40,7 @@ class TFModelParser : public converter::ModelParser {
TFModelParser() = default;
~TFModelParser() override = default;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);

View File

@ -56,7 +56,7 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
return tflite::UnPackModel(tflite_model_buf_);
}
FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file;
// load graph
tflite_model_ = ReadTfliteModel(model_file);
@ -81,7 +81,9 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
}
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
GetAllFuncGraph(func_graph, &all_func_graphs);
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
@ -95,7 +97,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false);
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
if (!unify_format->Run(res_graph_)) {
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";
return nullptr;
}
@ -542,7 +544,9 @@ STATUS TfliteModelParser::ControlFlowNodePostProcess() {
if (control_flow_map_.empty()) {
return RET_OK;
}
static auto root_func_manager = Manage(res_graph_);
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_ERROR);
static auto root_func_manager = Manage(func_graph);
for (auto &node_vs_graph : control_flow_map_) {
auto control_flow_node = node_vs_graph.first;
auto sub_graphs = node_vs_graph.second;

View File

@ -36,7 +36,7 @@ class TfliteModelParser : public converter::ModelParser {
~TfliteModelParser() override = default;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);

View File

@ -20,25 +20,24 @@
#include <string>
#include <vector>
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
namespace mindspore {
namespace registry {
namespace {
std::map<std::string, opt::PassPtr> pass_store_room;
std::map<std::string, PassBasePtr> outer_pass_storage;
std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
std::mutex pass_mutex;
void RegPass(const std::string &pass_name, const opt::PassPtr &pass) {
void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
if (pass == nullptr) {
MS_LOG(ERROR) << "pass is nullptr.";
return;
}
std::unique_lock<std::mutex> lock(pass_mutex);
pass_store_room[pass_name] = pass;
outer_pass_storage[pass_name] = pass;
}
} // namespace
PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass) { RegPass(pass_name, pass); }
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names) {
std::unique_lock<std::mutex> lock(pass_mutex);
@ -49,17 +48,8 @@ std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition positio
return external_assigned_passes[position];
}
std::vector<opt::PassPtr> PassRegistry::GetPassFromStoreRoom(const std::vector<std::string> &pass_names) {
std::vector<opt::PassPtr> schedule_passes;
for (auto &name : pass_names) {
auto iter = pass_store_room.find(name);
if (iter == pass_store_room.end()) {
continue;
}
MS_CHECK_TRUE_RET(iter->second != nullptr, std::vector<opt::PassPtr>{});
schedule_passes.push_back(iter->second);
}
return schedule_passes;
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) {
return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
}
} // namespace registry
} // namespace mindspore

View File

@ -17,16 +17,15 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
class DeleteRedundantTranspose : public Pass {
class DeleteRedundantTranspose {
public:
DeleteRedundantTranspose() : Pass("DeleteRedundantTranspose") {}
DeleteRedundantTranspose() = default;
~DeleteRedundantTranspose() = default;
bool Run(const FuncGraphPtr &func_graph) override;
bool Run(const FuncGraphPtr &func_graph);
private:
STATUS DeleteNot4DTranspose(const FuncGraphPtr &func_graph);