forked from mindspore-Ecosystem/mindspore
abstract out passs
This commit is contained in:
parent
60de7e032e
commit
ecafae75d5
|
@ -189,9 +189,12 @@ if(PLATFORM_ARM64)
|
|||
DESTINATION ${RUNTIME_DIR}/third_party/hiai_ddk/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
|
||||
PATTERN "register_kernel.h")
|
||||
if(SUPPORT_TRAIN)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
|
||||
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
|
||||
|
@ -199,7 +202,7 @@ if(PLATFORM_ARM64)
|
|||
else()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
|
||||
PATTERN "*registry.h" EXCLUDE)
|
||||
PATTERN "registry*" EXCLUDE)
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -261,9 +264,12 @@ elseif(PLATFORM_ARM32)
|
|||
DESTINATION ${RUNTIME_DIR}/third_party/hiai_ddk/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
|
||||
PATTERN "register_kernel.h")
|
||||
if(SUPPORT_TRAIN)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
|
||||
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
|
||||
|
@ -271,7 +277,7 @@ elseif(PLATFORM_ARM32)
|
|||
else()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
|
||||
PATTERN "*registry.h" EXCLUDE)
|
||||
PATTERN "registry*" EXCLUDE)
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -328,8 +334,6 @@ elseif(WIN32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/ccsrc/backend/optimizer/common/pass.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||
|
@ -355,6 +359,8 @@ elseif(WIN32)
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${protobuf_LIBPATH}/libprotobuf.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${flatbuffers_LIBPATH}/libflatbuffers.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
file(GLOB_RECURSE OPENCV_LIB_LIST
|
||||
${opencv_LIBPATH}/../bin/libopencv_core*
|
||||
${opencv_LIBPATH}/../bin/libopencv_imgcodecs*
|
||||
|
@ -379,13 +385,16 @@ elseif(WIN32)
|
|||
${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
endif()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
|
||||
PATTERN "register_kernel.h")
|
||||
if(SUPPORT_TRAIN)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
|
||||
else()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
|
||||
PATTERN "*registry.h" EXCLUDE)
|
||||
PATTERN "registry*" EXCLUDE)
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/build/mindspore/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -414,10 +423,12 @@ elseif(WIN32)
|
|||
install(FILES ${LIB_LIST} DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
endif()
|
||||
else()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/registry/ DESTINATION ${RUNTIME_INC_DIR}/registry
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "register_kernel_interface.h"
|
||||
PATTERN "register_kernel.h")
|
||||
if(SUPPORT_TRAIN)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE
|
||||
PATTERN "framework.h" EXCLUDE)
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "registry*" EXCLUDE)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
|
||||
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
|
||||
|
@ -425,7 +436,7 @@ else()
|
|||
else()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
|
||||
PATTERN "*registry.h" EXCLUDE PATTERN "framework.h" EXCLUDE)
|
||||
PATTERN "registry*" EXCLUDE)
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/schema/model_generated.h DESTINATION ${RUNTIME_INC_DIR}/schema
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -466,8 +477,6 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/core/utils/ DESTINATION ${CONVERTER_ROOT_DIR}/include/core/utils
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
|
||||
install(FILES ${TOP_DIR}/mindspore/ccsrc/backend/optimizer/common/pass.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/model_parser.h
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/include COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/tools/converter/ops/ops_def.h
|
||||
|
@ -495,6 +504,8 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${protobuf_LIBPATH}/libprotobuf.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${flatbuffers_LIBPATH}/libflatbuffers.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${openssl_LIBPATH}/libcrypto.a DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter
|
||||
|
|
|
@ -55,6 +55,10 @@ class MS_CORE_API FuncGraph {
|
|||
static std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &node);
|
||||
|
||||
static FuncGraphPtr Create();
|
||||
|
||||
static AnfNodePtr MakeValueNode(const FuncGraphPtr &func_graph);
|
||||
|
||||
static FuncGraphPtr GetFuncGraphFromAnfNode(const AnfNodePtr &input);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CORE_API_IR_FUNC_GRAPH_H_
|
||||
|
|
|
@ -760,5 +760,15 @@ std::vector<AnfNodePtr> api::FuncGraph::TopoSort(const AnfNodePtr &node) { retur
|
|||
// Create an api::FuncGraph instance.
|
||||
api::FuncGraphPtr api::FuncGraph::Create() { return std::make_shared<mindspore::FuncGraph>(); }
|
||||
|
||||
AnfNodePtr api::FuncGraph::MakeValueNode(const api::FuncGraphPtr &func_graph) {
|
||||
auto fg = std::dynamic_pointer_cast<mindspore::FuncGraph>(func_graph);
|
||||
return NewValueNode(fg);
|
||||
}
|
||||
|
||||
api::FuncGraphPtr api::FuncGraph::GetFuncGraphFromAnfNode(const AnfNodePtr &input) {
|
||||
auto fg = GetValueNode<mindspore::FuncGraphPtr>(input);
|
||||
return fg;
|
||||
}
|
||||
|
||||
const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared<Primitive>("FuncGraph");
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -237,7 +237,7 @@ build_lite() {
|
|||
compile_nnie_script=${BASEPATH}/mindspore/lite/tools/providers/NNIE/Hi3516D/compile_nnie.sh
|
||||
cd ${BASEPATH}/../
|
||||
if [[ "${local_lite_platform}" == "x86_64" ]]; then
|
||||
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev -j $THREAD_NUM
|
||||
sh ${compile_nnie_script} -I x86_64 -b nnie_3516_master_dev_2 -j $THREAD_NUM
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "compile x86_64 for nnie failed."
|
||||
exit 1
|
||||
|
|
|
@ -41,7 +41,7 @@ bool CheckPrimitiveTypeTutorial(const AnfNodePtr &node, const PrimitivePtr &prim
|
|||
} // namespace
|
||||
|
||||
// convert addn to custom op
|
||||
AnfNodePtr PassTutorial::CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||
AnfNodePtr PassTutorial::CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -66,17 +66,17 @@ AnfNodePtr PassTutorial::CreateCustomOp(const FuncGraphPtr func_graph, const CNo
|
|||
return custom_cnode;
|
||||
}
|
||||
|
||||
bool PassTutorial::Run(const FuncGraphPtr &func_graph) {
|
||||
bool PassTutorial::Execute(const api::FuncGraphPtr &func_graph) {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// generate a func_graph manager.
|
||||
auto manager = Manage(func_graph, true);
|
||||
auto manager = api::FuncGraphManager::Manage(func_graph, true);
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
auto node_list = api::FuncGraph::TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
|
|
|
@ -17,18 +17,18 @@
|
|||
#ifndef MINDSPORE_LITE_EXAMPLES_CONVERTER_REGISTER_SRC_PASS_REGISTRY_TUTORIAL_H
|
||||
#define MINDSPORE_LITE_EXAMPLES_CONVERTER_REGISTER_SRC_PASS_REGISTRY_TUTORIAL_H
|
||||
|
||||
#include "include/pass.h"
|
||||
#include "include/registry/pass_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class PassTutorial : public Pass {
|
||||
class PassTutorial : public registry::PassBase {
|
||||
public:
|
||||
PassTutorial() : Pass("pass_tutorial") {}
|
||||
PassTutorial() : PassBase("PassTutorial") {}
|
||||
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
bool Execute(const api::FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode);
|
||||
AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_BASE_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_BASE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "include/lite_utils.h"
|
||||
#include "api/ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace registry {
|
||||
/// \brief PassBase defined a base class, which provides an interface for user to operate FuncGraph.
|
||||
class MS_API PassBase {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
///
|
||||
/// \param[in] name Define pass name, which should be unique with each other.
|
||||
explicit PassBase(const std::string &name = "PassBase") : name_(name) {}
|
||||
|
||||
/// \brief Destructor
|
||||
virtual ~PassBase() = default;
|
||||
|
||||
/// \brief An interface for user to operate FuncGraph.
|
||||
///
|
||||
/// \param[in] func_graph Define the struct of the model.
|
||||
///
|
||||
/// \return Boolean value to represent whether the operation is successful or not.
|
||||
virtual bool Execute(const api::FuncGraphPtr &func_graph) = 0;
|
||||
|
||||
private:
|
||||
const std::string name_;
|
||||
};
|
||||
|
||||
/// \brief PassBasePtr defined a shared_ptr type.
|
||||
using PassBasePtr = std::shared_ptr<PassBase>;
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_PASS_BASE_H_
|
|
@ -19,20 +19,13 @@
|
|||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "include/lite_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
/// \brief P defined a basic interface.
|
||||
///
|
||||
/// \note List public class and interface for reference.
|
||||
class MS_API Pass;
|
||||
using PassPtr = std::shared_ptr<Pass>;
|
||||
} // namespace opt
|
||||
|
||||
namespace registry {
|
||||
class PassBase;
|
||||
using PassBasePtr = std::shared_ptr<PassBase>;
|
||||
/// \brief PassPosition defined where to place user's pass.
|
||||
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
|
||||
|
||||
|
@ -43,7 +36,7 @@ class MS_API PassRegistry {
|
|||
///
|
||||
/// \param[in] pass_name Define the name of the pass, a string which should guarantee uniqueness.
|
||||
/// \param[in] pass Define pass instance.
|
||||
PassRegistry(const std::string &pass_name, const opt::PassPtr &pass);
|
||||
PassRegistry(const std::string &pass_name, const PassBasePtr &pass);
|
||||
|
||||
/// \brief Constructor of PassRegistry to assign which passes are required for external extension.
|
||||
///
|
||||
|
@ -63,10 +56,10 @@ class MS_API PassRegistry {
|
|||
|
||||
/// \brief Static method to obtain pass instance according to passes' name.
|
||||
///
|
||||
/// \param[in] pass_names Define the name of passes.
|
||||
/// \param[in] pass_names Define the name of pass.
|
||||
///
|
||||
/// \return Pass Instance Vector.
|
||||
static std::vector<opt::PassPtr> GetPassFromStoreRoom(const std::vector<std::string> &pass_names);
|
||||
static PassBasePtr GetPassFromStoreRoom(const std::string &pass_name);
|
||||
};
|
||||
|
||||
/// \brief Defined registering macro to register Pass, which called by user directly.
|
||||
|
|
|
@ -40,9 +40,15 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
|
|||
ConverterParameters converter_parameters;
|
||||
auto func_graph = model_parser->Parse(converter_parameters);
|
||||
ASSERT_NE(func_graph, nullptr);
|
||||
auto node_list = func_graph->GetOrderedCnodes();
|
||||
ASSERT_EQ(node_list.size(), 3);
|
||||
auto iter = node_list.begin();
|
||||
auto node_list = func_graph->TopoSort(func_graph->get_return());
|
||||
std::vector<AnfNodePtr> cnode_list;
|
||||
for (auto &node : node_list) {
|
||||
if (node->isa<CNode>()) {
|
||||
cnode_list.push_back(node);
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(cnode_list.size(), 3);
|
||||
auto iter = cnode_list.begin();
|
||||
bool is_add = opt::CheckPrimitiveType(*iter, prim::kPrimAddFusion);
|
||||
ASSERT_EQ(is_add, true);
|
||||
++iter;
|
||||
|
|
|
@ -21,9 +21,9 @@
|
|||
#include "include/registry/model_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
FuncGraphPtr ModelParserTest::Parse(const converter::ConverterParameters &flag) {
|
||||
api::FuncGraphPtr ModelParserTest::Parse(const converter::ConverterParameters &flag) {
|
||||
// construct funcgraph
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
res_graph_ = api::FuncGraph::Create();
|
||||
auto ret = InitOriginModelStructure();
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "obtain origin model structure failed.";
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
class ModelParserTest : public converter::ModelParser {
|
||||
public:
|
||||
ModelParserTest() = default;
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
private:
|
||||
int InitOriginModelStructure();
|
||||
|
|
|
@ -18,13 +18,14 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include "common/common_test.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "include/registry/pass_base.h"
|
||||
#include "include/registry/pass_registry.h"
|
||||
#include "ops/fusion/add_fusion.h"
|
||||
#include "ops/addn.h"
|
||||
#include "ops/custom.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
||||
|
@ -44,14 +45,14 @@ class PassRegistryTest : public mindspore::CommonTest {
|
|||
ConverterParameters converter_parameters;
|
||||
func_graph_ = model_parser->Parse(converter_parameters);
|
||||
}
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
api::FuncGraphPtr func_graph_ = nullptr;
|
||||
};
|
||||
|
||||
namespace opt {
|
||||
// fuse add and add to addn.
|
||||
class Test1Fusion : public Pass {
|
||||
class Test1Fusion : public registry::PassBase {
|
||||
public:
|
||||
Test1Fusion() : Pass("Test1Fusion") {}
|
||||
Test1Fusion() : PassBase("Test1Fusion") {}
|
||||
bool CanFusion(const CNodePtr &cnode) {
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
|
@ -90,15 +91,15 @@ class Test1Fusion : public Pass {
|
|||
return input_cnode_num > 0;
|
||||
}
|
||||
|
||||
bool Run(const FuncGraphPtr &func_graph) override {
|
||||
bool Execute(const api::FuncGraphPtr &func_graph) override {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = Manage(func_graph);
|
||||
auto manager = api::FuncGraphManager::Manage(func_graph);
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
auto node_list = api::FuncGraph::TopoSort(func_graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNode>(node)) {
|
||||
continue;
|
||||
|
@ -130,10 +131,10 @@ class Test1Fusion : public Pass {
|
|||
};
|
||||
|
||||
// convert addn to custom op
|
||||
class Test2Fusion : public Pass {
|
||||
class Test2Fusion : public registry::PassBase {
|
||||
public:
|
||||
Test2Fusion() : Pass("Test2Fusion") {}
|
||||
AnfNodePtr CreateCustomOp(const FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||
Test2Fusion() : PassBase("Test2Fusion") {}
|
||||
AnfNodePtr CreateCustomOp(const api::FuncGraphPtr func_graph, const CNodePtr &cnode) {
|
||||
if (func_graph == nullptr || cnode == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -158,11 +159,11 @@ class Test2Fusion : public Pass {
|
|||
return custom_cnode;
|
||||
}
|
||||
|
||||
bool Run(const FuncGraphPtr &func_graph) override {
|
||||
bool Execute(const api::FuncGraphPtr &func_graph) override {
|
||||
if (func_graph == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto manager = Manage(func_graph);
|
||||
auto manager = api::FuncGraphManager::Manage(func_graph);
|
||||
if (manager == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
@ -194,14 +195,26 @@ REG_SCHEDULED_PASS(POSITION_BEGIN, schedule)
|
|||
TEST_F(PassRegistryTest, TestRegistry) {
|
||||
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(POSITION_BEGIN);
|
||||
ASSERT_EQ(schedule_task.size(), 2);
|
||||
auto passes = registry::PassRegistry::GetPassFromStoreRoom(schedule_task);
|
||||
ASSERT_EQ(passes.size(), 2);
|
||||
std::vector<registry::PassBasePtr> passes;
|
||||
auto pass1 = registry::PassRegistry::GetPassFromStoreRoom("Test1Fusion");
|
||||
ASSERT_NE(pass1, nullptr);
|
||||
passes.push_back(pass1);
|
||||
auto pass2 = registry::PassRegistry::GetPassFromStoreRoom("Test2Fusion");
|
||||
ASSERT_NE(pass2, nullptr);
|
||||
passes.push_back(pass2);
|
||||
ASSERT_NE(func_graph_, nullptr);
|
||||
for (auto &pass : passes) {
|
||||
auto ret = pass->Run(func_graph_);
|
||||
auto ret = pass->Execute(func_graph_);
|
||||
ASSERT_EQ(ret, true);
|
||||
}
|
||||
auto cnode_list = func_graph_->GetOrderedCnodes();
|
||||
std::vector<CNodePtr> cnode_list;
|
||||
auto node_list = api::FuncGraph::TopoSort(func_graph_->get_return());
|
||||
for (auto &node : node_list) {
|
||||
ASSERT_NE(node, nullptr);
|
||||
if (node->isa<CNode>()) {
|
||||
cnode_list.push_back(node->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
ASSERT_EQ(cnode_list.size(), 2);
|
||||
bool is_custom = opt::CheckPrimitiveType(cnode_list.front(), prim::kPrimCustom);
|
||||
ASSERT_EQ(is_custom, true);
|
||||
|
|
|
@ -144,7 +144,7 @@ STATUS AclPass::PreProcGraph(const FuncGraphPtr &func_graph) {
|
|||
return lite::RET_OK;
|
||||
}
|
||||
// The format of nodes (cnode, parameter, val) must be nchw due to interface of convert om
|
||||
if (!lite::RunOptimizerPass(func_graph, {"ToNCHWFormat", "DeleteRedundantTranspose"})) {
|
||||
if (!lite::RunOptimizerPass(func_graph, {"ToNCHWFormat", "DecreaseTransposeAlgo"})) {
|
||||
MS_LOG(ERROR) << "To nchw format success.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
|
|
@ -501,25 +501,36 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
return old_graph;
|
||||
}
|
||||
|
||||
void AnfTransform::AppendPassToStoreRoom(const converter::Flags *config) {
|
||||
bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
|
||||
if (config == nullptr) {
|
||||
MS_LOG(ERROR) << "config is nullptr";
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
auto fmk = config->fmk;
|
||||
auto is_train = config->trainModel;
|
||||
registry::PassRegistry("DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train));
|
||||
registry::PassRegistry("DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>());
|
||||
registry::PassRegistry("InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train));
|
||||
registry::PassRegistry("ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train));
|
||||
registry::PassRegistry("ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train));
|
||||
registry::PassRegistry("SpecifyGraphInputFormat",
|
||||
std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat));
|
||||
registry::PassRegistry("DumpGraph", std::make_shared<opt::DumpGraph>(config));
|
||||
std::unordered_map<std::string, opt::PassPtr> passes = {
|
||||
{"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
|
||||
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
|
||||
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
|
||||
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
|
||||
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)},
|
||||
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat)}};
|
||||
bool succeed_store = true;
|
||||
for (auto iter = passes.begin(); iter != passes.end(); ++iter) {
|
||||
if (PassStorage::StorePass(iter->first, iter->second) != RET_OK) {
|
||||
MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " << iter->first
|
||||
<< ", please edit external pass name.";
|
||||
succeed_store = false;
|
||||
}
|
||||
}
|
||||
return succeed_store;
|
||||
}
|
||||
|
||||
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
|
||||
AppendPassToStoreRoom(config);
|
||||
if (!StoreBuiltinPass(config)) {
|
||||
MS_LOG(ERROR) << "store pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto new_graph = TransformFuncGraph(main_graph, config);
|
||||
if (new_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "optimizer failed.";
|
||||
|
|
|
@ -57,7 +57,7 @@ class AnfTransform {
|
|||
|
||||
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
void AppendPassToStoreRoom(const converter::Flags *config);
|
||||
bool StoreBuiltinPass(const converter::Flags *config);
|
||||
|
||||
static STATUS MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
|
||||
|
||||
|
|
|
@ -44,16 +44,13 @@ void InitConverterParameters(const converter::Flags &flag, converter::ConverterP
|
|||
} // namespace
|
||||
|
||||
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
api::FuncGraphPtr func_graph_base = nullptr;
|
||||
if (flag.fmk == converter::FmkType::kFmkTypeMs) {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
kernel::PopulateTrainParameters();
|
||||
#endif
|
||||
MindsporeImporter ms_import;
|
||||
func_graph = ms_import.ImportMindIR(flag);
|
||||
if (func_graph == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
func_graph_base = ms_import.ImportMindIR(flag);
|
||||
} else {
|
||||
model_parser_ = registry::ModelParserRegistry::GetModelParser(flag.fmk);
|
||||
if (model_parser_ == nullptr) {
|
||||
|
@ -61,13 +58,18 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
|||
}
|
||||
converter::ConverterParameters converter_parameters;
|
||||
InitConverterParameters(flag, &converter_parameters);
|
||||
func_graph = model_parser_->Parse(converter_parameters);
|
||||
func_graph_base = model_parser_->Parse(converter_parameters);
|
||||
}
|
||||
if (func_graph == nullptr) {
|
||||
if (func_graph_base == nullptr) {
|
||||
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT);
|
||||
return nullptr;
|
||||
}
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(func_graph_base);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
|
||||
return nullptr;
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "api/ir/func_graph.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -32,10 +32,10 @@ class ModelParser {
|
|||
|
||||
virtual ~ModelParser() = default;
|
||||
|
||||
virtual FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
|
||||
virtual api::FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
|
||||
|
||||
protected:
|
||||
FuncGraphPtr res_graph_ = nullptr;
|
||||
api::FuncGraphPtr res_graph_ = nullptr;
|
||||
};
|
||||
|
||||
typedef ModelParser *(*ModelParserCreator)();
|
||||
|
|
|
@ -15,31 +15,39 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "include/registry/pass_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::map<std::string, opt::PassPtr> PassStorage::pass_stroge_;
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is nullptr.";
|
||||
return false;
|
||||
}
|
||||
auto schedule_passes = registry::PassRegistry::GetPassFromStoreRoom(pass_names);
|
||||
if (schedule_passes.size() != pass_names.size()) {
|
||||
MS_LOG(ERROR) << "exited pass cannot be obtained.";
|
||||
return false;
|
||||
}
|
||||
int index = 0;
|
||||
for (auto &pass : schedule_passes) {
|
||||
CHECK_NULL_RETURN(pass);
|
||||
if (!pass->Run(func_graph)) {
|
||||
MS_LOG(WARNING) << "run pass failed, pass name is " << pass_names[index];
|
||||
for (auto &pass_name : pass_names) {
|
||||
auto pass_outer = registry::PassRegistry::GetPassFromStoreRoom(pass_name);
|
||||
if (pass_outer != nullptr) {
|
||||
if (!pass_outer->Execute(func_graph)) {
|
||||
MS_LOG(ERROR) << "run pass failed, pass name is " << pass_name;
|
||||
return false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
auto pass_builtin = PassStorage::GetPassFromStorage(pass_name);
|
||||
if (pass_builtin == nullptr) {
|
||||
MS_LOG(ERROR) << "exited pass cannot be obtained, pass name is " << pass_name;
|
||||
return false;
|
||||
}
|
||||
if (!pass_builtin->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "run pass failed, pass name is " << pass_name;
|
||||
return false;
|
||||
}
|
||||
++index;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -17,13 +17,31 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/registry/pass_registry.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class PassStorage {
|
||||
public:
|
||||
static int StorePass(const std::string &pass_name, const opt::PassPtr &pass) {
|
||||
if (registry::PassRegistry::GetPassFromStoreRoom(pass_name) != nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
pass_stroge_[pass_name] = pass;
|
||||
return RET_OK;
|
||||
}
|
||||
static opt::PassPtr GetPassFromStorage(const std::string &pass_name) { return pass_stroge_[pass_name]; }
|
||||
|
||||
private:
|
||||
static std::map<std::string, opt::PassPtr> pass_stroge_;
|
||||
};
|
||||
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names);
|
||||
bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition position);
|
||||
} // namespace lite
|
||||
|
|
|
@ -82,7 +82,7 @@ CaffeModelParser::CaffeModelParser() = default;
|
|||
|
||||
CaffeModelParser::~CaffeModelParser() = default;
|
||||
|
||||
FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
api::FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file;
|
||||
auto weight_file = flag.weight_file;
|
||||
STATUS status = InitOriginModel(model_file, weight_file);
|
||||
|
@ -114,7 +114,9 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
MS_CHECK_TRUE_RET(value_ptr != nullptr, nullptr);
|
||||
res_graph_->set_attr("fmk", value_ptr);
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
|
||||
GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -122,7 +124,7 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeCaffe, false);
|
||||
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -34,7 +34,7 @@ class CaffeModelParser : public converter::ModelParser {
|
|||
|
||||
~CaffeModelParser() override;
|
||||
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
private:
|
||||
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
|
||||
|
|
|
@ -501,7 +501,7 @@ FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, c
|
|||
return loop_body_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file;
|
||||
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
|
@ -514,13 +514,15 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
}
|
||||
MS_ASSERT(onnx_root_graph_ != nullptr);
|
||||
|
||||
status = ConvertOnnxGraph(onnx_root_graph_, res_graph_, &anf_nodes_map_, {}, "root_node");
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
|
||||
status = ConvertOnnxGraph(onnx_root_graph_, func_graph, &anf_nodes_map_, {}, "root_node");
|
||||
if (RET_OK != status) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
MS_LOG(ERROR) << "convert onnx graph failed.";
|
||||
return nullptr;
|
||||
}
|
||||
static auto root_func_manager = Manage(res_graph_);
|
||||
static auto root_func_manager = Manage(func_graph);
|
||||
MS_ASSERT(root_func_manager != nullptr);
|
||||
for (auto &subgraph : all_subgraphs_) {
|
||||
MS_ASSERT(subgraph != nullptr);
|
||||
|
@ -530,7 +532,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::kFmkTypeOnnx)));
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -543,7 +545,7 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag)
|
|||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeOnnx, false);
|
||||
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "create unify_format return nullptr");
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ class OnnxModelParser : public converter::ModelParser {
|
|||
|
||||
~OnnxModelParser() override = default;
|
||||
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
private:
|
||||
STATUS InitOriginModel(const std::string &model_file);
|
||||
|
|
|
@ -492,7 +492,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(const std::vector<const tensor
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
api::FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto modelFile = flag.model_file;
|
||||
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
||||
auto status = ValidateFileStr(modelFile, ".pb");
|
||||
|
@ -528,7 +528,9 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
tf_root_graph_nodes_vec_.emplace_back(&node_def);
|
||||
}
|
||||
|
||||
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_vec_, res_graph_, &anf_root_node_map_, true);
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
|
||||
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_vec_, func_graph, &anf_root_node_map_, true);
|
||||
if (status != RET_OK) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
|
@ -536,7 +538,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
bool success_flag = true;
|
||||
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
||||
auto &node_def = tf_root_graph_->node(i);
|
||||
status = ConvertOps(node_def, tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
|
||||
status = ConvertOps(node_def, tf_root_graph_nodes_, func_graph, &anf_root_node_map_);
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
if (status != RET_OK) {
|
||||
success_flag = false;
|
||||
|
@ -570,7 +572,7 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
return nullptr;
|
||||
}
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
|
@ -584,12 +586,12 @@ FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
|||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTf, false);
|
||||
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
res_graph_->set_manager(nullptr);
|
||||
static auto root_func_manager = Manage(res_graph_);
|
||||
func_graph->set_manager(nullptr);
|
||||
static auto root_func_manager = Manage(func_graph);
|
||||
return res_graph_;
|
||||
}
|
||||
|
||||
|
@ -780,7 +782,12 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
|
|||
<< " second_func_map.size(): " << second_func_map.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
static auto root_func_manager = Manage(res_graph_);
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
static auto root_func_manager = Manage(func_graph);
|
||||
|
||||
for (auto &kv : first_func_map) {
|
||||
auto control_flow_node = kv.first;
|
||||
|
@ -1082,7 +1089,12 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
|
|||
}
|
||||
}
|
||||
}
|
||||
auto status = MakeAnfGraphOutputs(output_nodes, res_graph_);
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "unc graph is invalid.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status = MakeAnfGraphOutputs(output_nodes, func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "make anf graph outputs node error";
|
||||
return status;
|
||||
|
|
|
@ -40,7 +40,7 @@ class TFModelParser : public converter::ModelParser {
|
|||
TFModelParser() = default;
|
||||
~TFModelParser() override = default;
|
||||
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
|
|||
return tflite::UnPackModel(tflite_model_buf_);
|
||||
}
|
||||
|
||||
FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
api::FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file;
|
||||
// load graph
|
||||
tflite_model_ = ReadTfliteModel(model_file);
|
||||
|
@ -81,7 +81,9 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
|||
}
|
||||
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
|
||||
GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
|
@ -95,7 +97,7 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag
|
|||
}
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(kFmkTypeTflite, false);
|
||||
MS_CHECK_TRUE_RET(unify_format != nullptr, nullptr);
|
||||
if (!unify_format->Run(res_graph_)) {
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
MS_LOG(ERROR) << "Run insert transpose failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -542,7 +544,9 @@ STATUS TfliteModelParser::ControlFlowNodePostProcess() {
|
|||
if (control_flow_map_.empty()) {
|
||||
return RET_OK;
|
||||
}
|
||||
static auto root_func_manager = Manage(res_graph_);
|
||||
auto func_graph = std::dynamic_pointer_cast<FuncGraph>(res_graph_);
|
||||
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_ERROR);
|
||||
static auto root_func_manager = Manage(func_graph);
|
||||
for (auto &node_vs_graph : control_flow_map_) {
|
||||
auto control_flow_node = node_vs_graph.first;
|
||||
auto sub_graphs = node_vs_graph.second;
|
||||
|
|
|
@ -36,7 +36,7 @@ class TfliteModelParser : public converter::ModelParser {
|
|||
|
||||
~TfliteModelParser() override = default;
|
||||
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
api::FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
|
||||
|
|
|
@ -20,25 +20,24 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace registry {
|
||||
namespace {
|
||||
std::map<std::string, opt::PassPtr> pass_store_room;
|
||||
std::map<std::string, PassBasePtr> outer_pass_storage;
|
||||
std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
|
||||
std::mutex pass_mutex;
|
||||
void RegPass(const std::string &pass_name, const opt::PassPtr &pass) {
|
||||
void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
|
||||
if (pass == nullptr) {
|
||||
MS_LOG(ERROR) << "pass is nullptr.";
|
||||
return;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
pass_store_room[pass_name] = pass;
|
||||
outer_pass_storage[pass_name] = pass;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
PassRegistry::PassRegistry(const std::string &pass_name, const opt::PassPtr &pass) { RegPass(pass_name, pass); }
|
||||
PassRegistry::PassRegistry(const std::string &pass_name, const PassBasePtr &pass) { RegPass(pass_name, pass); }
|
||||
|
||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::string> &names) {
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
|
@ -49,17 +48,8 @@ std::vector<std::string> PassRegistry::GetOuterScheduleTask(PassPosition positio
|
|||
return external_assigned_passes[position];
|
||||
}
|
||||
|
||||
std::vector<opt::PassPtr> PassRegistry::GetPassFromStoreRoom(const std::vector<std::string> &pass_names) {
|
||||
std::vector<opt::PassPtr> schedule_passes;
|
||||
for (auto &name : pass_names) {
|
||||
auto iter = pass_store_room.find(name);
|
||||
if (iter == pass_store_room.end()) {
|
||||
continue;
|
||||
}
|
||||
MS_CHECK_TRUE_RET(iter->second != nullptr, std::vector<opt::PassPtr>{});
|
||||
schedule_passes.push_back(iter->second);
|
||||
}
|
||||
return schedule_passes;
|
||||
PassBasePtr PassRegistry::GetPassFromStoreRoom(const std::string &pass_name) {
|
||||
return outer_pass_storage.find(pass_name) == outer_pass_storage.end() ? nullptr : outer_pass_storage[pass_name];
|
||||
}
|
||||
} // namespace registry
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,16 +17,15 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_DELETE_REDUNDANT_TRANSPOSE_H_
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DeleteRedundantTranspose : public Pass {
|
||||
class DeleteRedundantTranspose {
|
||||
public:
|
||||
DeleteRedundantTranspose() : Pass("DeleteRedundantTranspose") {}
|
||||
DeleteRedundantTranspose() = default;
|
||||
~DeleteRedundantTranspose() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
bool Run(const FuncGraphPtr &func_graph);
|
||||
|
||||
private:
|
||||
STATUS DeleteNot4DTranspose(const FuncGraphPtr &func_graph);
|
||||
|
|
Loading…
Reference in New Issue