split converter so

sunsuodong 2022-04-19 19:39:04 -07:00
parent 1e54e390ca
commit b0776fd004
71 changed files with 822 additions and 837 deletions

View File

@ -5,6 +5,11 @@ set(RUNTIME_PKG_NAME ${PKG_NAME_PREFIX}-${RUNTIME_COMPONENT_NAME})
set(CONVERTER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/converter)
set(OBFUSCATOR_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/obfuscator)
set(CROPPER_ROOT_DIR ${RUNTIME_PKG_NAME}/tools/cropper)
if(WIN32)
set(BUILD_DIR ${TOP_DIR}/build)
else()
set(BUILD_DIR ${TOP_DIR}/mindspore/lite/build)
endif()
set(TEST_CASE_DIR ${TOP_DIR}/mindspore/lite/test/build)
set(RUNTIME_DIR ${RUNTIME_PKG_NAME}/runtime)
@ -483,6 +488,8 @@ if(PLATFORM_ARM64)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin.so
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so.0
@ -713,8 +720,10 @@ elseif(WIN32)
file(GLOB LIB_LIST ${CXX_DIR}/libstdc++-6.dll ${CXX_DIR}/libwinpthread-1.dll
${CXX_DIR}/libssp-0.dll ${CXX_DIR}/libgcc_s_*-1.dll)
if(MSLITE_ENABLE_CONVERTER)
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/converter_lite.exe
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/converter_lite/converter_lite.exe
DESTINATION ${CONVERTER_ROOT_DIR}/converter COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/libmindspore_converter.dll
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin.dll
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
@ -867,6 +876,8 @@ else()
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h")
install(TARGETS converter_lite RUNTIME DESTINATION ${CONVERTER_ROOT_DIR}/converter
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${BUILD_DIR}/tools/converter/libmindspore_converter.so
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin.so
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0 DESTINATION ${CONVERTER_ROOT_DIR}/lib RENAME libglog.so.0

View File

@ -24,6 +24,7 @@
#include "src/runtime/lite_session.h"
#include "src/runtime/kernel_exec.h"
#include "src/common/file_utils.h"
#include "include/converter.h"
namespace mindspore {
class MindrtParallelTest : public mindspore::CommonTest {
@ -96,12 +97,12 @@ int CheckRuntime1(lite::LiteSession *session) {
}
TEST_F(MindrtParallelTest, offline1) {
const char *converter_argv[] = {"./converter", "--fmk=TFLITE",
"--modelFile=./mindrtParallel/mindrt_parallel_model.tflite",
"--outputFile=./mindrtParallel/mindrt_parallel_model_split",
"--configFile=./mindrtParallel/mindrt_parallel_model.config"};
int converter_ret = mindspore::lite::RunConverter(5, converter_argv);
ASSERT_EQ(converter_ret, lite::RET_OK);
mindspore::Converter converter(converter::kFmkTypeTflite, "./mindrtParallel/mindrt_parallel_model.tflite",
"./mindrtParallel/mindrt_parallel_model_split");
converter.SetConfigFile("./mindrtParallel/mindrt_parallel_model.config");
auto status = converter.Convert();
ASSERT_EQ(status, kSuccess);
size_t size = 0;
char *graph_buf = lite::ReadFile("./mindrtParallel/mindrt_parallel_model_split.ms", &size);
@ -135,11 +136,12 @@ TEST_F(MindrtParallelTest, offline1) {
}
TEST_F(MindrtParallelTest, runtime1) {
const char *converter_argv[] = {"./converter", "--fmk=TFLITE",
"--modelFile=./mindrtParallel/mindrt_parallel_model.tflite",
"--outputFile=./mindrtParallel/mindrt_parallel_model"};
int converter_ret = mindspore::lite::RunConverter(4, converter_argv);
ASSERT_EQ(converter_ret, lite::RET_OK);
mindspore::Converter converter(converter::kFmkTypeTflite, "./mindrtParallel/mindrt_parallel_model.tflite",
"./mindrtParallel/mindrt_parallel_model");
converter.SetConfigFile("./mindrtParallel/mindrt_parallel_model.config");
auto status = converter.Convert();
ASSERT_EQ(status, kSuccess);
size_t size = 0;
char *graph_buf = lite::ReadFile("./mindrtParallel/mindrt_parallel_model.ms", &size);
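
Taken together, these two tests show the API change behind this commit: the argv-based mindspore::lite::RunConverter call is replaced by the public mindspore::Converter class from include/converter.h. A minimal standalone sketch of the new usage, with hypothetical file names:

    #include "include/converter.h"

    int main() {
      // Hypothetical paths; the framework type and calls mirror the tests above.
      mindspore::Converter converter(mindspore::converter::kFmkTypeTflite, "model.tflite", "model_out");
      converter.SetConfigFile("convert.cfg");  // optional, as in the tests
      auto status = converter.Convert();
      // The serializer appends ".ms", so the result lands in model_out.ms.
      return status == mindspore::kSuccess ? 0 : 1;
    }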

View File

@ -112,7 +112,7 @@ TEST_F(ActivationFusionTest, TestHardTanhReluNode) {
auto meta_graph = BuildGraph(schema::ActivationType_HARD_TANH, schema::ActivationType_RELU);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -126,7 +126,7 @@ TEST_F(ActivationFusionTest, TestRelu6HardTanhNode) {
auto meta_graph = BuildGraph(schema::ActivationType_RELU6, schema::ActivationType_HARD_TANH);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -140,7 +140,7 @@ TEST_F(ActivationFusionTest, TestBadCase_ReluSigmoid) {
auto meta_graph = BuildGraph(schema::ActivationType_RELU, schema::ActivationType_SIGMOID);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 2);

View File

@ -157,7 +157,7 @@ TEST_F(AddConcatActivationFusionTest, TestAddConcatReluNode) {
auto meta_graph = BuildGraph(schema::ActivationType_RELU6);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), kGraphNodeSize);

View File

@ -136,7 +136,7 @@ TEST_F(ConvActivationFusionTest, TestConvReluNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -149,7 +149,7 @@ TEST_F(ConvActivationFusionTest, TestConvRelu6Node) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_RELU6);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -162,7 +162,7 @@ TEST_F(ConvActivationFusionTest, TestBadCase_ConvRelu) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::ActivationType_LEAKY_RELU);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 2);

View File

@ -145,7 +145,7 @@ TEST_F(ConvBiasAddFusionTest, TestConvAddNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_BiasAdd);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -156,7 +156,7 @@ TEST_F(ConvBiasAddFusionTest, TestDeptiwiseConvAddNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_AddFusion);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -166,7 +166,7 @@ TEST_F(ConvBiasAddFusionTest, TestBadCase_ConvAdd) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_MatMulFusion);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 2);

View File

@ -262,7 +262,7 @@ TEST_F(ConvBNFusionTest, TestConvAddNode) {
auto meta_graph = BuildCaffeGraph(schema::PrimitiveType_Conv2DFusion);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -272,7 +272,7 @@ TEST_F(ConvBNFusionTest, TestDeptiwiseConvAddNode) {
auto meta_graph = BuildTFGraph(schema::PrimitiveType_Conv2DFusion);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);

View File

@ -187,7 +187,7 @@ TEST_F(ConvScaleFusionTest, TestConvScaleNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, true);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -198,7 +198,7 @@ TEST_F(ConvScaleFusionTest, TestDeptiwiseConvScaleNode) {
auto meta_graph = BuildGraph(schema::PrimitiveType_Conv2DFusion, false);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);

View File

@ -21,7 +21,6 @@
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "ops/fusion/activation.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
class MatMulActivationFusionInoutTest : public FusionInoutTest {
@ -29,10 +28,7 @@ class MatMulActivationFusionInoutTest : public FusionInoutTest {
MatMulActivationFusionInoutTest() = default;
protected:
void InitPass() override {
converter::Flags ctx;
this->pass_ = std::make_shared<opt::MatMulActivationFusion>(ctx);
}
void InitPass() override { this->pass_ = std::make_shared<opt::MatMulActivationFusion>(nullptr); }
void InitGraph() override {
this->graph_ = std::make_shared<FuncGraph>();

View File

@ -136,7 +136,7 @@ TEST_F(MatMulAddFusionTest, TestMatMulMulNode) {
auto meta_graph = BuildGraph();
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);

View File

@ -158,7 +158,7 @@ TEST_F(TransMatMulFusionTest, TestTransMatMulNode1) {
auto meta_graph = BuildGraph(trans_param_a, trans_param_b, output_dims);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -180,7 +180,7 @@ TEST_F(TransMatMulFusionTest, TestTransMatMulNode2) {
auto meta_graph = BuildGraph(trans_param_a, trans_param_b, output_dims);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 1);
@ -202,7 +202,7 @@ TEST_F(TransMatMulFusionTest, TestBadCase_TransMatMul) {
auto meta_graph = BuildGraph(trans_param_a, trans_param_b, output_dims);
auto func_graph = lite::AnfImporterFromMetaGraphT::Fb2Anf(meta_graph.get());
auto anf_transform = new lite::AnfTransform();
auto new_graph = anf_transform->Transform(func_graph);
auto new_graph = anf_transform->Transform(func_graph, nullptr);
ASSERT_NE(nullptr, new_graph);
auto new_meta_graph = lite::Export(new_graph);
ASSERT_EQ(new_meta_graph->nodes.size(), 3);

View File

@ -4,6 +4,7 @@ set(USE_GLOG on)
if(MSLITE_ENABLE_MODEL_ENCRYPTION)
add_compile_definitions(ENABLE_OPENSSL)
endif()
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
set(TOOLS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
@ -38,7 +39,6 @@ include_directories(${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel)
file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/converter.cc
${CMAKE_CURRENT_SOURCE_DIR}/converter_flags.cc
${CMAKE_CURRENT_SOURCE_DIR}/anf_transform.cc
${CMAKE_CURRENT_SOURCE_DIR}/graphdef_transform.cc
${CMAKE_CURRENT_SOURCE_DIR}/optimizer.cc
@ -72,8 +72,6 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/*.cc
)
list(REMOVE_ITEM CONVERTER_SRC cxx_api/converter.cc)
if((NOT WIN32) AND MSLITE_ENABLE_DPICO_ATC_ADAPTER)
add_subdirectory(adapter/dpico)
endif()
@ -297,7 +295,8 @@ set_property(SOURCE ${LITE_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindsp
add_library(converter_runtime_mid OBJECT ${LITE_SRC})
add_dependencies(converter_runtime_mid fbs_src fbs_inner_src)
target_compile_options(converter_runtime_mid PRIVATE "-Wno-stringop-overflow")
add_library(mindspore_converter SHARED $<TARGET_OBJECTS:converter_runtime_mid>)
target_compile_options(mindspore_converter PRIVATE "-Wno-stringop-overflow")
add_library(converter_src_mid OBJECT ${CONVERTER_SRC})
add_dependencies(converter_src_mid fbs_src fbs_inner_src)
@ -307,32 +306,27 @@ add_dependencies(ccsrc_src_mid fbs_src fbs_inner_src)
target_compile_definitions(ccsrc_src_mid PRIVATE BACKEND_DLL)
target_compile_definitions(converter_src_mid PRIVATE BACKEND_DLL)
add_executable(converter_lite
main.cc
)
add_dependencies(converter_lite fbs_src fbs_inner_src)
if(NOT ENABLE_CLOUD_AND_LITE)
add_dependencies(converter_lite nnacl_mid)
add_dependencies(mindspore_converter nnacl_mid)
endif()
if((NOT WIN32) AND MSLITE_ENABLE_DPICO_ATC_ADAPTER)
add_dependencies(converter_lite dpico_atc_adapter)
add_dependencies(mindspore_converter dpico_atc_adapter)
endif()
if(MSLITE_GPU_BACKEND STREQUAL opencl)
include_directories(${SRC_DIR}/runtime/kernel/opencl)
target_link_libraries(converter_lite PRIVATE opencl_kernel_mid)
target_link_libraries(mindspore_converter opencl_kernel_mid)
endif()
if(MSLITE_ENABLE_FP16)
target_link_libraries(converter_lite PRIVATE
target_link_libraries(mindspore_converter
nnacl_fp16_mid
)
endif()
target_link_libraries(converter_lite PRIVATE
target_link_libraries(mindspore_converter
ccsrc_src_mid
converter_src_mid
converter_runtime_mid
cpu_ops_mid
nnacl_mid
cpu_kernel_mid
@ -359,34 +353,36 @@ target_link_libraries(converter_lite PRIVATE
)
if(SUPPORT_TRAIN)
target_link_libraries(converter_lite PRIVATE train_cpu_kernel_mid)
target_link_libraries(mindspore_converter train_cpu_kernel_mid)
endif()
if(ENABLE_CONVERT_PYTORCH_MODEL)
target_link_libraries(converter_lite PRIVATE pytorch_parser_mid)
target_link_libraries(mindspore_converter pytorch_parser_mid)
endif()
if(NOT ENABLE_CLOUD_AND_LITE)
target_link_libraries(converter_lite PRIVATE
target_link_libraries(mindspore_converter
ccsrc_debug_common_mid_
mindir_proto_mid
_mindspore_transform_express_ir_obj)
endif()
if(MSLITE_ENABLE_ACL)
target_link_libraries(converter_lite PRIVATE
target_link_libraries(mindspore_converter
lite_acl_mid
ascend_kernel_mid)
endif()
if(NOT MSVC)
target_link_libraries(converter_lite PRIVATE pthread)
target_link_libraries(mindspore_converter pthread)
endif()
if(NOT WIN32)
target_link_libraries(converter_lite PRIVATE dl)
target_link_libraries(mindspore_converter dl)
endif()
if(ENABLE_MODEL_OBF)
target_link_libraries(converter_lite PRIVATE
target_link_libraries(mindspore_converter
${OBF_LIB_DIR}/libmsdeobfuscator-lite.so)
endif()
add_subdirectory(converter_lite)

View File

@ -21,9 +21,9 @@
namespace mindspore {
namespace opt {
AclPass::AclPass(const converter::Flags &config) : Pass("ACL") {
AclPass::AclPass(const std::shared_ptr<ConverterPara> &param) : Pass("ACL") {
#ifdef ENABLE_LITE_ACL
impl_ = std::make_shared<AclPassImpl>(config);
impl_ = std::make_shared<AclPassImpl>(param);
#endif
}

View File

@ -20,7 +20,7 @@
#define USE_DEPRECATED_API
#include <memory>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore {
namespace opt {
@ -29,7 +29,7 @@ using AclPassImplPtr = std::shared_ptr<AclPassImpl>;
class AclPass : public Pass {
public:
explicit AclPass(const converter::Flags &config);
explicit AclPass(const std::shared_ptr<ConverterPara> &param);
~AclPass() override = default;
bool Run(const FuncGraphPtr &func_graph) override;

View File

@ -101,9 +101,9 @@ STATUS PreProcForOnnx(const FuncGraphPtr &func_graph) {
}
} // namespace
AclPassImpl::AclPassImpl(const converter::Flags &config)
: fmk_type_(config.fmk),
user_options_cfg_(std::move(config.aclModelOptionCfgParam)),
AclPassImpl::AclPassImpl(const std::shared_ptr<ConverterPara> &param)
: fmk_type_(param->fmk_type),
user_options_cfg_(std::move(param->aclModelOptionCfgParam)),
om_parameter_(nullptr),
custom_node_(nullptr) {}

View File

@ -27,8 +27,8 @@
#include "include/registry/converter_context.h"
#include "cxx_api/model/acl/acl_model_options.h"
#include "tools/converter/adapter/acl/common/acl_types.h"
#include "tools/converter/converter_flags.h"
#include "ops/custom.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore {
namespace opt {
@ -37,7 +37,7 @@ using mindspore::lite::STATUS;
class AclPassImpl {
public:
explicit AclPassImpl(const converter::Flags &config);
explicit AclPassImpl(const std::shared_ptr<ConverterPara> &param);
~AclPassImpl() = default;
bool Run(const FuncGraphPtr &func_graph);

View File

@ -171,13 +171,12 @@ STATUS AnfTransform::MarkTrainOp(const FuncGraphPtr &func_graph) {
return RET_OK;
}
int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto status = MarkTrainOp(old_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "MarkTrainOp failed.";
return RET_ERROR;
}
CHECK_NULL_RETURN(config);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer);
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
@ -190,8 +189,8 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>(config->fmk));
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>(config->fmk));
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>(param->fmk_type));
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>(param->fmk_type));
fusion_pm->AddPass(std::make_shared<opt::GroupNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());
@ -200,7 +199,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
fusion_pm->AddPass(std::make_shared<opt::BatchNormToScaleFusion>());
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
fusion_pm->AddPass(std::make_shared<opt::ActivationFusion>());
if (config->fullQuantParam.target_device != quant::NVGPU) {
if (param->fullQuantParam.target_device != quant::NVGPU) {
fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
}
fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
@ -212,7 +211,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model));
fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
@ -226,7 +225,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
fusion_pm->AddPass(std::make_shared<opt::FullConnectedFusion>());
fusion_pm->AddPass(std::make_shared<opt::FullconnectedAddFusion>());
fusion_pm->AddPass(std::make_shared<opt::TensorDotFusion>());
fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>(*config));
fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>(param));
optimizer->AddPassManager(fusion_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
MS_LOG(ERROR) << "run op fusion failed.";
@ -235,12 +234,12 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
return RET_OK;
}
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
MS_LOG(DEBUG) << "Run ParallelPass start";
if (config->trainModel || config->parallel_split_config_.parallel_split_type_ == converter::SplitNo) {
if (param->train_model || param->parallel_split_config.parallel_split_type_ == SplitNo) {
return RET_OK;
}
if (config->parallel_split_config_.parallel_split_type_ == converter::SplitByUserRatio) {
if (param->parallel_split_config.parallel_split_type_ == SplitByUserRatio) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer);
auto graph_inputs = old_graph->get_inputs();
@ -261,9 +260,8 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
}
}
// 1. deal with split strategy
std::unordered_map<std::string, opt::SplitStrategy> split_strategys =
opt::ParserSplitStrategy(config->parallel_split_config_.parallel_compute_rates_,
config->parallel_split_config_.parallel_devices_, split_mode);
std::unordered_map<std::string, opt::SplitStrategy> split_strategys = opt::ParserSplitStrategy(
param->parallel_split_config.parallel_compute_rates_, param->parallel_split_config.parallel_devices_, split_mode);
if (split_strategys.empty()) {
MS_LOG(WARNING) << "No valid split_strategy. Run convert without split";
return RET_OK;
@ -279,7 +277,7 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
// we do not deal with single conv node
for (int match_number = max_match_number; match_number > opt::kDefaultBatch; --match_number) {
// 3. multi_conv parallel pass
parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, config->fmk, match_number));
parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, param->fmk_type, match_number));
parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
}
@ -293,18 +291,18 @@ int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter
return RET_OK;
}
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
CHECK_NULL_RETURN(optimizer);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
CHECK_NULL_RETURN(graph_pm);
if (config->fmk == converter::kFmkTypeTflite || config->fmk == converter::kFmkTypeTf ||
config->fmk == converter::kFmkTypeOnnx) {
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
param->fmk_type == converter::kFmkTypeOnnx) {
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
}
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
CHECK_NULL_RETURN(slice_prepose_pass);
slice_prepose_pass->SetFmkType(config->fmk);
slice_prepose_pass->SetFmkType(param->fmk_type);
graph_pm->AddPass(slice_prepose_pass);
optimizer->AddPassManager(graph_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
@ -314,8 +312,8 @@ int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::F
return RET_OK;
}
int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto acl_pass = std::make_shared<opt::AclPass>(*config);
int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto acl_pass = std::make_shared<opt::AclPass>(param);
CHECK_NULL_RETURN(acl_pass);
if (!acl_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Acl pass failed.";
@ -326,8 +324,8 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter:
CHECK_NULL_RETURN(optimizer);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
CHECK_NULL_RETURN(convert_pm);
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel));
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel));
convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
convert_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
optimizer->AddPassManager(convert_pm);
if (optimizer->Optimize(old_graph) == nullptr) {
@ -337,14 +335,14 @@ int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter:
return RET_OK;
}
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
CHECK_NULL_RETURN(optimizer);
CHECK_NULL_RETURN(const_fold_pm);
const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel));
if (!config->trainModel) {
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk, config->trainModel));
const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
if (!param->train_model) {
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model));
}
const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
const_fold_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
@ -356,8 +354,8 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
return RET_OK;
}
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, converter::Flags *config) {
quant::QuantizationOptimizer optimizer(config);
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
quant::QuantizationOptimizer optimizer(param);
auto ret = optimizer.Run(old_graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post training quantization failed.";
@ -366,8 +364,8 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, converter::Flags *co
return RET_OK;
}
bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto eliminate_cast_pass = std::make_shared<opt::EliminateRedundantCastPass>(config->fmk, config->trainModel);
bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto eliminate_cast_pass = std::make_shared<opt::EliminateRedundantCastPass>(param->fmk_type, param->train_model);
MS_CHECK_TRUE_RET(eliminate_cast_pass != nullptr, false);
if (!eliminate_cast_pass->Run(old_graph)) {
MS_LOG(ERROR) << "Run cast elimination pass failed.";
@ -390,11 +388,12 @@ bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const converter::F
return true;
}
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
const std::shared_ptr<ConverterPara> &param) {
MS_ASSERT(old_graph != nullptr);
MS_ASSERT(config != nullptr);
MS_ASSERT(param != nullptr);
auto status = RunConvertPass(old_graph, config);
auto status = RunConvertPass(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr;
@ -405,7 +404,7 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return nullptr;
}
status = RunConstFoldPass(old_graph, config);
status = RunConstFoldPass(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run const fold pass failed.";
return nullptr;
@ -420,13 +419,13 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
}
}
if (!RunEliminateRedundantPass(old_graph, config)) {
if (!RunEliminateRedundantPass(old_graph, param)) {
MS_LOG(ERROR) << "Run elimination of redundant pass failed.";
return nullptr;
}
if (!config->disableFusion) {
status = RunFusionPass(old_graph, config);
if (!param->no_fusion) {
status = RunFusionPass(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run fusion pass failed.";
return nullptr;
@ -451,23 +450,23 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
}
}
status = RunGraphPass(old_graph, config);
status = RunGraphPass(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr;
}
status = RunParallelPass(old_graph, config);
status = RunParallelPass(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run convert pass failed.";
return nullptr;
}
if (!config->pluginsPath.empty() && config->commonQuantParam.quant_type != schema::QuantType_QUANT_NONE) {
if (!param->plugins_path.empty() && param->commonQuantParam.quant_type != schema::QuantType_QUANT_NONE) {
MS_LOG(ERROR) << "Unsupported external extension with quantization.";
return nullptr;
}
status = DoQuantize(old_graph, const_cast<converter::Flags *>(config));
status = DoQuantize(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";
return nullptr;
@ -476,17 +475,17 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
return old_graph;
}
bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
if (config == nullptr) {
bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> &param) {
if (param == nullptr) {
MS_LOG(ERROR) << "config is nullptr";
return false;
}
auto fmk = config->fmk;
auto is_train = config->trainModel;
auto fmk = param->fmk_type;
auto is_train = param->train_model;
// pass_name, pass, and a boolean indicating whether the pass can be called by external extensions.
std::vector<std::tuple<std::string, opt::PassPtr, bool>> pass_infos = {
{"DumpGraph", std::make_shared<opt::DumpGraph>(config), true},
{"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel), false},
{"DumpGraph", std::make_shared<opt::DumpGraph>(param), true},
{"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(param->train_model), false},
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train), true},
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train), true},
{"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train), true},
@ -494,25 +493,25 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
{"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>(), false},
{"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>(), false},
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train), true},
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat), false}};
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(param->input_format), false}};
for (const auto &pass_info : pass_infos) {
MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false);
PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info), std::get<opt::kInputIndexTwo>(pass_info));
}
auto dump_graph_outer = std::make_shared<opt::DumpGraph>(config);
auto dump_graph_outer = std::make_shared<opt::DumpGraph>(param);
MS_CHECK_TRUE_MSG(dump_graph_outer != nullptr, false, "dumpGraph object is a nullptr.");
registry::PassRegistry("DumpGraph", dump_graph_outer);
return true;
}
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr<ConverterPara> &param) {
MS_CHECK_TRUE_MSG(main_graph != nullptr, nullptr, "Input func_graph is nullptr");
MS_CHECK_TRUE_MSG(config != nullptr, nullptr, "Input converter config is nullptr");
if (!StoreBuiltinPass(config)) {
MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "Input converter param is nullptr");
if (!StoreBuiltinPass(param)) {
MS_LOG(ERROR) << "store pass failed.";
return nullptr;
}
auto new_graph = TransformFuncGraph(main_graph, config);
auto new_graph = TransformFuncGraph(main_graph, param);
if (new_graph == nullptr) {
MS_LOG(ERROR) << "optimizer failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);

View File

@ -23,7 +23,6 @@
#include "backend/common/optimizer/optimizer.h"
#include "schema/inner/model_generated.h"
#include "tools/common/meta_graph_serializer.h"
#include "tools/converter/converter_flags.h"
#include "ir/anf.h"
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/converter_context.h"
@ -34,24 +33,24 @@ class AnfTransform {
public:
AnfTransform();
virtual ~AnfTransform();
FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
private:
FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunConvertPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int RunConstFoldPass(const FuncGraphPtr &olde_graph, const converter::Flags *config);
static int RunConstFoldPass(const FuncGraphPtr &olde_graph, const std::shared_ptr<ConverterPara> &param);
static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
static int RunParallelPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static int DoQuantize(const FuncGraphPtr &old_graph, converter::Flags *config);
static int DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param);
static bool StoreBuiltinPass(const converter::Flags *config);
static bool StoreBuiltinPass(const std::shared_ptr<ConverterPara> &param);
static STATUS MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
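
Every entry point in this header now takes a std::shared_ptr<ConverterPara> where it previously took a converter::Flags pointer. A minimal caller sketch under the new signature (field values are hypothetical; ConverterPara comes from tools/converter/cxx_api/converter_para.h, the header this commit includes elsewhere):

    auto param = std::make_shared<mindspore::ConverterPara>();
    param->fmk_type = mindspore::converter::kFmkTypeOnnx;  // hypothetical framework
    param->train_model = false;
    param->no_fusion = false;  // leave the fusion pass manager enabled
    mindspore::lite::AnfTransform transform;
    // Note: the fusion unit tests above pass nullptr instead of a real param.
    auto new_graph = transform.Transform(old_graph, param);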

View File

@ -19,7 +19,7 @@
#include <memory>
#include <vector>
#include <set>
#include "tools/converter/converter_flags.h"
#include <algorithm>
#include "src/common/log_adapter.h"
#include "tools/common/meta_graph_serializer.h"
#include "tools/lite_exporter/anf_exporter.h"
@ -42,43 +42,52 @@
#include "include/api/model.h"
#include "tools/mindir_exporter/mindir_serializer.h"
#include "src/common/primitive_t_utils.h"
#include "tools/converter/config_parser/acl_option_param_parser.h"
#include "tools/converter/config_parser/micro_param_parser.h"
#include "tools/converter/config_parser/preprocess_parser.h"
#include "tools/converter/config_parser/quant_param_parser.h"
#include "tools/common/string_util.h"
#include "src/common/file_utils.h"
namespace mindspore {
extern "C" {
void common_log_init();
}
namespace lite {
namespace {
constexpr size_t kMaxNum1024 = 1024;
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
MS_ASSERT(converter_parameters != nullptr);
converter_parameters->fmk = flag.fmk;
converter_parameters->model_file = flag.modelFile;
converter_parameters->weight_file = flag.weightFile;
}
constexpr size_t kPluginPathMaxNum = 10;
constexpr int kPathLengthUpperLimit = 1024;
constexpr size_t kEncMaxLen = 16;
FuncGraphPtr ConvertGraph(const api::FuncGraphPtr &func_graph) {
auto impl = func_graph->impl();
return std::dynamic_pointer_cast<FuncGraph>(impl);
}
} // namespace
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara> &param) {
api::FuncGraphPtr func_graph_base = nullptr;
if (flag.fmk == converter::FmkType::kFmkTypeMs) {
if (param->fmk_type == converter::FmkType::kFmkTypeMs) {
#ifdef SUPPORT_TRAIN
kernel::PopulateTrainParameters();
#endif
MindsporeImporter ms_import;
func_graph_base = api::MakeShared<api::FuncGraph>(ms_import.ImportMindIR(flag));
func_graph_base = api::MakeShared<api::FuncGraph>(ms_import.ImportMindIR(param));
} else {
model_parser_ = registry::ModelParserRegistry::GetModelParser(flag.fmk);
model_parser_ = registry::ModelParserRegistry::GetModelParser(param->fmk_type);
if (model_parser_ == nullptr) {
MS_LOG(ERROR) << "Unsupported to converter models with fmk: " << flag.fmkIn;
MS_LOG(ERROR) << "Unsupported to converter models with fmk: " << param->fmk_type;
return nullptr;
}
converter::ConverterParameters converter_parameters;
InitConverterParameters(flag, &converter_parameters);
converter_parameters.fmk = param->fmk_type;
converter_parameters.model_file = param->model_file;
converter_parameters.weight_file = param->weight_file;
func_graph_base = model_parser_->Parse(converter_parameters);
}
if (func_graph_base == nullptr) {
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << param->fmk_type;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT);
return nullptr;
}
@ -94,9 +103,10 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
return func_graph;
}
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size) {
FuncGraphPtr ConverterImpl::BuildFuncGraph(const std::shared_ptr<ConverterPara> &param, const void *buf,
const size_t &size) {
MindsporeImporter ms_import;
FuncGraphPtr func_graph = ms_import.ImportMindIR(flag, buf, size);
FuncGraphPtr func_graph = ms_import.ImportMindIR(param, buf, size);
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Get funcGraph failed.";
return nullptr;
@ -110,39 +120,50 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag, const void
return func_graph;
}
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf,
const size_t &size) {
if (flag == nullptr || buf == nullptr) {
MS_LOG(ERROR) << "Input flag is nullptr";
schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, const void *buf,
const size_t &size) {
if (param == nullptr || buf == nullptr) {
MS_LOG(ERROR) << "Input param is nullptr";
return nullptr;
}
auto graph = BuildFuncGraph(*flag, buf, size);
auto graph = BuildFuncGraph(param, buf, size);
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
}
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed.");
// funcgraph_transform
graph = funcgraph_transform_->Transform(graph, flag.get());
graph = funcgraph_transform_->Transform(graph, param);
MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr.");
// export protobuf
auto status = MindIRSerialize(flag, graph);
auto status = MindIRSerialize(param, graph);
if (status != RET_OK) {
MS_LOG(WARNING) << "Export to mindir proto return nullptr.";
}
return TransferFuncGraph(flag, graph);
return TransferFuncGraph(param, graph);
}
schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
if (flag == nullptr) {
MS_LOG(ERROR) << "Input flag is nullptr";
schema::MetaGraphT *ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param) {
if (param == nullptr) {
MS_LOG(ERROR) << "Input param is nullptr";
return nullptr;
}
param->aclModelOptionCfgParam.om_file_path = param->output_file;
param->aclModelOptionCfgParam.offline = true;
if (!param->config_file.empty()) {
auto ret = InitConfigFile(param);
if (ret != RET_OK) {
std::cerr << "Init config file failed." << std::endl;
return nullptr;
}
}
// load plugin
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
if (!flag->pluginsPath.empty()) {
for (auto &path : flag->pluginsPath) {
if (!param->plugins_path.empty()) {
for (auto &path : param->plugins_path) {
auto dl_loader = std::make_shared<DynamicLibraryLoader>();
MS_CHECK_TRUE_RET(dl_loader != nullptr, nullptr);
auto status = dl_loader->Open(path);
@ -154,7 +175,7 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
}
}
auto graph = BuildFuncGraph(*flag);
auto graph = BuildFuncGraph(param);
if (graph == nullptr) {
MS_LOG(ERROR) << "Parser/Import model return nullptr";
return nullptr;
@ -162,23 +183,23 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
// funcgraph transform
graph = funcgraph_transform_->Transform(graph, flag.get());
graph = funcgraph_transform_->Transform(graph, param);
if (graph == nullptr) {
MS_LOG(ERROR) << "Transform anf graph return nullptr";
return nullptr;
}
// export protobuf
auto status = MindIRSerialize(flag, graph);
auto status = MindIRSerialize(param, graph);
if (status != RET_OK) {
MS_LOG(WARNING) << "Export to mindir proto return nullptr.";
}
return TransferFuncGraph(flag, graph);
return TransferFuncGraph(param, graph);
}
schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag,
FuncGraphPtr func_graph) {
schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<ConverterPara> &param,
FuncGraphPtr func_graph) {
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
#ifdef MSLITE_ENABLE_GRAPH_KERNEL
if (graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
@ -187,7 +208,7 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter
#endif
// protobuf -> flatbuffer
auto meta_graph = Export(func_graph, false, false, flag->trainModel);
auto meta_graph = Export(func_graph, false, false, param->train_model);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta graph return nullptr";
return nullptr;
@ -195,7 +216,7 @@ schema::MetaGraphT *Converter::TransferFuncGraph(const std::unique_ptr<converter
// metagraph compile
metagraph_transform_->SetGraphDef(meta_graph);
auto status = metagraph_transform_->Transform(*flag);
auto status = metagraph_transform_->Transform(param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Transform meta graph failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@ -233,8 +254,8 @@ int CheckExistCustomOps(const schema::MetaGraphT *meta_graph, bool *exist_custom
return RET_OK;
}
int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr<converter::Flags> &flags) {
if (flags->trainModel) {
int PreInference(const schema::MetaGraphT &meta_graph, bool train_model) {
if (train_model) {
MS_LOG(WARNING) << "train model dont support pre-infer.";
return RET_OK;
}
@ -298,32 +319,213 @@ int PreInference(const schema::MetaGraphT &meta_graph, const std::unique_ptr<con
return RET_OK;
}
int RunConverter(int argc, const char **argv) {
std::ostringstream oss;
auto flags = std::make_unique<converter::Flags>();
if (flags == nullptr) {
oss.clear();
oss << "NEW FLAGS ERROR:" << RET_MEMORY_FAILED << " " << GetErrorInfo(RET_MEMORY_FAILED);
MS_LOG(ERROR) << oss.str();
std::cout << oss.str() << std::endl;
return RET_MEMORY_FAILED;
int ConverterImpl::InitConfigFile(const std::shared_ptr<ConverterPara> &param) {
lite::ConfigFileParser config_parser;
auto ret = config_parser.ParseConfigFile(param->config_file);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse config file failed.";
return ret;
}
auto status = flags->Init(argc, argv);
if (status != RET_OK) {
if (status != RET_SUCCESS_EXIT) {
oss.clear();
oss << "CONVERTER::FLAGS INIT FAILED:" << status << " " << GetErrorInfo(status);
MS_LOG(ERROR) << oss.str();
std::cout << oss.str() << std::endl;
ret = lite::PreprocessParser::ParsePreprocess(config_parser.GetDataPreProcessString(), &param->dataPreProcessParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse preprocess failed.";
return ret;
}
ret = lite::QuantParamParser::ParseCommonQuant(config_parser.GetCommonQuantString(), &param->commonQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse common quant param failed.";
return ret;
}
ret = lite::QuantParamParser::ParseFullQuant(config_parser.GetFullQuantString(), &param->fullQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse full quant param failed.";
return ret;
}
ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(),
&param->mixedBitWeightQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
return ret;
}
ret = InitExtendedIntegrationInfo(param, config_parser);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse extended integration info failed.";
return ret;
}
lite::AclOptionParamParser acl_param_parser;
ret = acl_param_parser.ParseAclOptionCfg(config_parser.GetAclOptionCfgString(), &param->aclModelOptionCfgParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse acl option param failed.";
return ret;
}
(void)CheckOfflineParallelConfig(param->config_file, &param->parallel_split_config);
lite::MicroParamParser micro_param_parser;
ret = micro_param_parser.ParseMicroParam(config_parser.GetMicroParamString(), &param->microParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse micro param failed.";
return ret;
}
return RET_OK;
}
int ConverterImpl::InitExtendedIntegrationInfo(const std::shared_ptr<ConverterPara> &param,
const lite::ConfigFileParser &config_parser) {
auto extended_info = config_parser.GetRegistryInfoString();
if (!extended_info.plugin_path.empty()) {
const char delimiter = ';';
auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, delimiter);
if (relative_path.size() > kPluginPathMaxNum) {
MS_LOG(ERROR) << "extended plugin library's num is too big, which shouldn't be larger than " << kPluginPathMaxNum;
return RET_INPUT_PARAM_INVALID;
}
for (auto &i : relative_path) {
param->plugins_path.push_back(lite::RealPath(i.c_str()));
}
return status;
}
// Load graph
MS_LOG(DEBUG) << "start reading model file";
Converter cvt;
auto meta_graph = cvt.Convert(flags);
if (!extended_info.disable_fusion.empty()) {
if (extended_info.disable_fusion == "on") {
param->no_fusion = true;
} else if (extended_info.disable_fusion == "off") {
param->no_fusion = false;
} else {
std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off" << std::endl;
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
bool ConverterImpl::CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
// device: [device0 device1] ---> {cpu, gpu}
// computeRate: [x: y] x >=0 && y >=0 && x/y < 10
MS_ASSERT(parallel_split_config != nullptr);
std::vector<std::string> config_devices = {"cpu", "gpu", "npu"};
auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate);
if (compute_rate_result.empty()) {
return false;
}
std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0);
if (device0_result.empty()) {
return false;
}
std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1);
if (device1_result.empty()) {
return false;
}
bool device0_flag = false;
bool device1_flag = false;
for (const auto &device : config_devices) {
if (device == device0_result) {
device0_flag = true;
}
if (device == device1_result) {
device1_flag = true;
}
}
if (!device0_flag || !device1_flag) {
return false;
}
const char delimiter = ';';
std::vector<std::string> device_rates = lite::SplitStringToVector(compute_rate_result, delimiter);
const char colon = ':';
for (const auto &device : device_rates) {
std::vector<std::string> rate = lite::SplitStringToVector(device, colon);
int64_t compute_rate = 0;
try {
compute_rate = std::stoi(rate.back());
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get compute rate failed: " << e.what();
return false;
}
parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
}
const size_t support_rates_num = 2;
if (parallel_split_config->parallel_compute_rates_.size() != support_rates_num) {
return false;
}
int64_t bigger_rate = INT32_MIN;
int64_t smaller_rate = INT32_MAX;
for (const auto &rate : parallel_split_config->parallel_compute_rates_) {
if (rate <= 0 || rate > INT32_MAX) {
return false;
}
bigger_rate = std::max(rate, bigger_rate);
smaller_rate = std::min(rate, smaller_rate);
}
parallel_split_config->parallel_devices_.push_back(device0_result);
parallel_split_config->parallel_devices_.push_back(device1_result);
// parallel_split_type may be extended by other user attributes
parallel_split_config->parallel_split_type_ = SplitByUserRatio;
if (smaller_rate == 0) {
MS_LOG(ERROR) << "smaller_rate is zero";
return false;
}
return bigger_rate / smaller_rate <= kMaxSplitRatio;
}
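
For reference, a hypothetical config fragment that satisfies the checks above. CheckOfflineParallelConfig splits the computeRate value on ';' and reads the number after the last ':' of each part, so exactly two positive rates are required, with a ratio of at most kMaxSplitRatio (10); device0 and device1 must each be one of cpu, gpu, or npu:

    # offline parallel split (hypothetical values; '#' lines are ignored by GetStrFromConfigFile below)
    computeRate=device0:3;device1:7
    device0=cpu
    device1=gpu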
std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
std::string res;
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr";
return res;
}
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
if (resolved_path == nullptr) {
MS_LOG(ERROR) << "new resolved_path failed";
return "";
}
#ifdef _WIN32
auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit);
#else
char *real_path = realpath(file.c_str(), resolved_path.get());
#endif
if (real_path == nullptr || strlen(real_path) == 0) {
MS_LOG(ERROR) << "file path is not valid : " << file;
return "";
}
std::ifstream ifs(resolved_path.get());
if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << real_path << " is not exist";
return res;
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "file: " << real_path << "open failed";
return res;
}
std::string line;
while (std::getline(ifs, line)) {
lite::Trim(&line);
if (line.empty() || line.at(0) == '#' || line.at(0) == '[') {
continue;
}
auto index = line.find('=');
if (index == std::string::npos) {
MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
return "";
}
auto key = line.substr(0, index);
auto value = line.substr(index + 1);
lite::Trim(&key);
lite::Trim(&value);
if (key == target_key) {
return value;
}
}
return res;
}
int RunConverter(const std::shared_ptr<ConverterPara> &param) {
mindspore::common_log_init();
ConverterImpl converter_impl;
auto meta_graph = converter_impl.Convert(param);
NotSupportOp::GetInstance()->PrintOps();
status = ReturnCode::GetSingleReturnCode()->status_code();
int status = ReturnCode::GetSingleReturnCode()->status_code();
std::ostringstream oss;
if (meta_graph == nullptr) {
oss.clear();
oss << "CONVERT RESULT FAILED:" << status << " " << GetErrorInfo(status);
@ -335,8 +537,8 @@ int RunConverter(int argc, const char **argv) {
// save graph to file
meta_graph->version = Version();
if (flags->infer) {
status = PreInference(*meta_graph, flags);
if (param->pre_infer) {
status = PreInference(*meta_graph, param->train_model);
if (status != RET_OK) {
oss.clear();
oss << "PRE INFERENCE FAILED:" << status << " " << GetErrorInfo(status);
@ -347,10 +549,10 @@ int RunConverter(int argc, const char **argv) {
}
}
if (flags->microParam.enable_micro) {
status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, flags->outputFile, flags->microParam.codegen_mode,
flags->microParam.target, flags->microParam.support_parallel,
flags->microParam.debug_mode);
if (param->microParam.enable_micro) {
status = micro::Coder::MicroSourceCodeGeneration(*meta_graph, param->output_file, param->microParam.codegen_mode,
param->microParam.target, param->microParam.support_parallel,
param->microParam.debug_mode);
if (status != RET_OK) {
delete meta_graph;
oss.clear();
@ -360,7 +562,27 @@ int RunConverter(int argc, const char **argv) {
return status;
}
} else {
status = MetaGraphSerializer::Save(*meta_graph, flags->outputFile, flags->encKey, flags->keyLen, flags->encMode);
unsigned char encKey[kEncMaxLen] = {0};
size_t keyLen = 0;
if (param->enable_encryption) {
if (!param->encrypt_key.empty()) {
keyLen = lite::Hex2ByteArray(param->encrypt_key, encKey, kEncMaxLen);
if (keyLen != kEncMaxLen) {
MS_LOG(ERROR) << "enc_key must expressed in hexadecimal characters "
<< " and only support AES-GCM method and the key length is 16.";
return RET_INPUT_PARAM_INVALID;
}
} else {
MS_LOG(ERROR) << "If you don't need to use model encryption, please set --encryption=false.";
return RET_INPUT_PARAM_INVALID;
}
}
status = MetaGraphSerializer::Save(*meta_graph, param->output_file, encKey, keyLen, param->encrypt_mode);
if (memset_s(encKey, kEncMaxLen, 0, kEncMaxLen) != EOK) {
MS_LOG(ERROR) << "memset failed.";
delete meta_graph;
return RET_ERROR;
}
if (status != RET_OK) {
delete meta_graph;
oss.clear();
@@ -370,15 +592,6 @@ int RunConverter(int argc, const char **argv) {
return status;
}
}
// clear key
flags->dec_key.clear();
flags->encKeyStr.clear();
status = memset_s(flags->encKey, converter::kEncMaxLen, 0, converter::kEncMaxLen);
if (status != EOK) {
MS_LOG(ERROR) << "memset failed.";
delete meta_graph;
return RET_ERROR;
}
delete meta_graph;
oss.clear();
oss << "CONVERT RESULT SUCCESS:" << status;

View File

@@ -20,39 +20,54 @@
#define USE_DEPRECATED_API
#include <memory>
#include <string>
#include "include/converter.h"
#include "include/registry/model_parser.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/graphdef_transform.h"
#include "include/registry/model_parser_registry.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/anf_transform.h"
#include "tools/converter/converter_context.h"
#include "tools/common/graph_util.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/adapter/acl/common/acl_types.h"
#include "micro/coder/config.h"
#include "tools/converter/cxx_api/converter_para.h"
#include "tools/converter/config_parser/config_file_parser.h"
namespace mindspore {
namespace lite {
class Converter {
constexpr auto kMaxSplitRatio = 10;
constexpr auto kComputeRate = "computeRate";
constexpr auto kSplitDevice0 = "device0";
constexpr auto kSplitDevice1 = "device1";
class ConverterImpl {
public:
Converter() = default;
~Converter() {
ConverterImpl() = default;
~ConverterImpl() {
delete model_parser_;
this->model_parser_ = nullptr;
}
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag, const void *buf, const size_t &size);
schema::MetaGraphT *Convert(const std::shared_ptr<ConverterPara> &param);
schema::MetaGraphT *Convert(const std::shared_ptr<ConverterPara> &param, const void *buf, const size_t &size);
private:
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag, const void *buf, const size_t &size);
schema::MetaGraphT *TransferFuncGraph(const std::unique_ptr<converter::Flags> &flag, FuncGraphPtr func_graph);
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param);
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param, const void *buf, const size_t &size);
schema::MetaGraphT *TransferFuncGraph(const std::shared_ptr<ConverterPara> &param, FuncGraphPtr func_graph);
int InitConfigFile(const std::shared_ptr<ConverterPara> &param);
int InitExtendedIntegrationInfo(const std::shared_ptr<ConverterPara> &param,
const lite::ConfigFileParser &config_file_parser);
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
protected:
converter::ModelParser *model_parser_ = nullptr;
std::unique_ptr<GraphDefTransform> metagraph_transform_ = std::make_unique<GraphDefTransform>();
std::unique_ptr<AnfTransform> funcgraph_transform_ = std::make_unique<AnfTransform>();
};
int RunConverter(int argc, const char **argv);
} // namespace lite
} // namespace mindspore
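With the argc/argv overloads gone, library callers drive conversion through ConverterPara; a minimal sketch under the field names used in this refactor (paths are hypothetical):

auto param = std::make_shared<mindspore::ConverterPara>();
param->fmk_type = mindspore::converter::kFmkTypeTflite;
param->model_file = "model.tflite";  // hypothetical input path
param->output_file = "model";        // hypothetical output path
int status = mindspore::lite::RunConverter(param);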

View File

@@ -1,8 +1,12 @@
remove_definitions(-DUSE_GLOG)
link_directories(${opencv_INC}/../lib)
add_executable(converter_lite main.cc converter_flags.cc
${TOP_DIR}/mindspore/lite/src/common/log.cc
${TOP_DIR}/mindspore/lite/src/common/utils.cc
${TOP_DIR}/mindspore/core/utils/status.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../common/flag_parser.cc
${CMAKE_CURRENT_SOURCE_DIR}/../../common/string_util.cc)
target_link_libraries(converter_lite PRIVATE mindspore-converter)
target_link_libraries(converter_lite mindspore_converter)

View File

@@ -13,8 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/converter_flags.h"
#include "tools/converter/converter_lite/converter_flags.h"
#include <climits>
#include <cstdlib>
#include <string>
@@ -22,27 +21,13 @@
#include <vector>
#include <memory>
#include <algorithm>
#include "ir/dtype/type_id.h"
#include "common/file_utils.h"
#include "tools/common/string_util.h"
#include "common/log_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/config_parser/config_file_parser.h"
#include "tools/converter/config_parser/preprocess_parser.h"
#include "tools/converter/config_parser/quant_param_parser.h"
#include "tools/converter/config_parser/acl_option_param_parser.h"
#include "tools/converter/config_parser/micro_param_parser.h"
namespace mindspore {
namespace converter {
#include "tools/common/string_util.h"
namespace mindspore::converter {
using mindspore::lite::RET_INPUT_PARAM_INVALID;
using mindspore::lite::RET_OK;
namespace {
constexpr size_t kPluginPathMaxNum = 10;
constexpr int kQuantBitNumInt16 = 16;
constexpr int kPathLengthUpperLimit = 1024;
constexpr int kMinShapeSizeInStr = 2;
} // namespace
Flags::Flags() {
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
AddFlag(&Flags::modelFile, "modelFile",
@@ -106,13 +91,13 @@ Flags::Flags() {
int Flags::InitInputOutputDataType() {
if (this->inputDataTypeStr == "FLOAT") {
this->inputDataType = TypeId::kNumberTypeFloat32;
this->inputDataType = DataType::kNumberTypeFloat32;
} else if (this->inputDataTypeStr == "INT8") {
this->inputDataType = TypeId::kNumberTypeInt8;
this->inputDataType = DataType::kNumberTypeInt8;
} else if (this->inputDataTypeStr == "UINT8") {
this->inputDataType = TypeId::kNumberTypeUInt8;
this->inputDataType = DataType::kNumberTypeUInt8;
} else if (this->inputDataTypeStr == "DEFAULT") {
this->inputDataType = TypeId::kTypeUnknown;
this->inputDataType = DataType::kTypeUnknown;
} else {
std::cerr
<< "INPUT INVALID: inputDataType is invalid: %s, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT, got: "
@@ -121,13 +106,13 @@ int Flags::InitInputOutputDataType() {
}
if (this->outputDataTypeStr == "FLOAT") {
this->outputDataType = TypeId::kNumberTypeFloat32;
this->outputDataType = DataType::kNumberTypeFloat32;
} else if (this->outputDataTypeStr == "INT8") {
this->outputDataType = TypeId::kNumberTypeInt8;
this->outputDataType = DataType::kNumberTypeInt8;
} else if (this->outputDataTypeStr == "UINT8") {
this->outputDataType = TypeId::kNumberTypeUInt8;
this->outputDataType = DataType::kNumberTypeUInt8;
} else if (this->outputDataTypeStr == "DEFAULT") {
this->outputDataType = TypeId::kTypeUnknown;
this->outputDataType = DataType::kTypeUnknown;
} else {
std::cerr
<< "INPUT INVALID: outputDataType is invalid: %s, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT, got: "
@@ -177,11 +162,11 @@ int Flags::InitTrainModel() {
std::cerr << "INPUT ILLEGAL: train model converter supporting only MINDIR format" << std::endl;
return RET_INPUT_PARAM_INVALID;
}
if ((this->inputDataType != TypeId::kNumberTypeFloat32) && (this->inputDataType != TypeId::kTypeUnknown)) {
if ((this->inputDataType != DataType::kNumberTypeFloat32) && (this->inputDataType != DataType::kTypeUnknown)) {
std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 input tensors" << std::endl;
return RET_INPUT_PARAM_INVALID;
}
if ((this->outputDataType != TypeId::kNumberTypeFloat32) && (this->outputDataType != TypeId::kTypeUnknown)) {
if ((this->outputDataType != DataType::kNumberTypeFloat32) && (this->outputDataType != DataType::kTypeUnknown)) {
std::cerr << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors" << std::endl;
return RET_INPUT_PARAM_INVALID;
}
@@ -202,7 +187,11 @@ int Flags::InitInTensorShape() const {
}
shape.clear();
auto string_split = lite::StrSplit(shape_str, std::string(":"));
CHECK_LESS_RETURN(string_split.size(), kMinShapeSizeInStr);
constexpr size_t kMinShapeSizeInStr = 2;
if (string_split.size() < kMinShapeSizeInStr) {
MS_LOG(ERROR) << "shape size must not be less than " << kMinShapeSizeInStr;
return mindspore::lite::RET_ERROR;
}
auto name = string_split[0];
for (size_t i = 1; i < string_split.size() - 1; ++i) {
name += ":" + string_split[i];
@@ -236,7 +225,7 @@ int Flags::InitInTensorShape() const {
shape.push_back(dim_value);
}
}
lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(name, shape);
graph_input_shape_map[name] = shape;
}
return RET_OK;
}
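The rejoin loop above exists because a tensor name may itself contain ':'; only the final ':'-separated field is read as the comma-separated dims. An illustrative flag value (tensor name and dims are placeholders):

// --inTensorShape=encoder/input:0:1,224,224,3
// parses into graph_input_shape_map["encoder/input:0"] = {1, 224, 224, 3}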
@@ -253,89 +242,6 @@ int Flags::InitGraphInputFormat() {
return RET_OK;
}
int Flags::InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser) {
auto extended_info = config_file_parser.GetRegistryInfoString();
if (!extended_info.plugin_path.empty()) {
const char delimiter = ';';
auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, delimiter);
if (relative_path.size() > kPluginPathMaxNum) {
MS_LOG(ERROR) << "extended plugin library's num is too big, which shouldn't be larger than " << kPluginPathMaxNum;
return RET_INPUT_PARAM_INVALID;
}
for (auto &i : relative_path) {
this->pluginsPath.push_back(lite::RealPath(i.c_str()));
}
}
if (!extended_info.disable_fusion.empty()) {
if (extended_info.disable_fusion == "on") {
this->disableFusion = true;
} else if (extended_info.disable_fusion == "off") {
this->disableFusion = false;
} else {
std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off" << std::endl;
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
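The strings consumed above come from the converter config file's registry options; a hypothetical snippet (section name and library paths are illustrative, not confirmed by this diff):

// [registry]
// plugin_path=libplugin_a.so;libplugin_b.so   <- ';'-separated, at most kPluginPathMaxNum (10) entries
// disable_fusion=off                          <- must be "on" or "off"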
void Flags::InitAclDefaultOption() {
this->aclModelOptionCfgParam.om_file_path = this->outputFile;
this->aclModelOptionCfgParam.offline = true;
}
int Flags::InitConfigFile() {
lite::ConfigFileParser config_file_parser;
auto ret = config_file_parser.ParseConfigFile(this->configFile);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse config file failed.";
return ret;
}
ret =
lite::PreprocessParser::ParsePreprocess(config_file_parser.GetDataPreProcessString(), &this->dataPreProcessParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse preprocess failed.";
return ret;
}
ret = lite::QuantParamParser::ParseCommonQuant(config_file_parser.GetCommonQuantString(), &this->commonQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse common quant param failed.";
return ret;
}
ret = lite::QuantParamParser::ParseFullQuant(config_file_parser.GetFullQuantString(), &this->fullQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse full quant param failed.";
return ret;
}
ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_file_parser.GetMixedBitWeightQuantString(),
&this->mixedBitWeightQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
return ret;
}
ret = InitExtendedIntegrationInfo(config_file_parser);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse extended integration info failed.";
return ret;
}
lite::AclOptionParamParser acl_param_parser;
ret = acl_param_parser.ParseAclOptionCfg(config_file_parser.GetAclOptionCfgString(), &this->aclModelOptionCfgParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse acl option param failed.";
return ret;
}
(void)CheckOfflineParallelConfig(this->configFile, &parallel_split_config_);
lite::MicroParamParser micro_param_parser;
ret = micro_param_parser.ParseMicroParam(config_file_parser.GetMicroParamString(), &this->microParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse micro param failed.";
return ret;
}
return RET_OK;
}
int Flags::InitSaveFP16() {
if (saveFP16Str == "on") {
saveFP16 = true;
@@ -398,13 +304,6 @@ int Flags::InitEncrypt() {
MS_LOG(ERROR) << "If you don't need to use model encryption, please set --encryption=false.";
return RET_INPUT_PARAM_INVALID;
}
keyLen = lite::Hex2ByteArray(encKeyStr, encKey, kEncMaxLen);
if (keyLen != kEncMaxLen) {
MS_LOG(ERROR) << "enc_key " << encKeyStr << " must expressed in hexadecimal characters "
<< " and only support AES-GCM method and the key length is 16.";
return RET_INPUT_PARAM_INVALID;
}
encKeyStr.clear();
}
return RET_OK;
}
@@ -412,7 +311,7 @@ int Flags::InitEncrypt() {
int Flags::PreInit(int argc, const char **argv) {
if (argc == 1) {
std::cout << this->Usage() << std::endl;
return lite::RET_SUCCESS_EXIT;
return lite::RET_OK;
}
lite::Option<std::string> err = this->ParseFlags(argc, argv);
@@ -424,7 +323,7 @@ int Flags::PreInit(int argc, const char **argv) {
if (this->help) {
std::cout << this->Usage() << std::endl;
return lite::RET_SUCCESS_EXIT;
return lite::RET_OK;
}
if (this->modelFile.empty()) {
std::cerr << "INPUT MISSING: model file path is necessary" << std::endl;
@@ -450,15 +349,6 @@ int Flags::PreInit(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID;
}
if (!this->configFile.empty()) {
auto ret = InitConfigFile();
if (ret != RET_OK) {
std::cerr << "Init config file failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
}
InitAclDefaultOption();
return RET_OK;
}
@@ -526,120 +416,4 @@ int Flags::Init(int argc, const char **argv) {
}
return RET_OK;
}
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
// device: [device0 device1] ---> {cpu, gpu}
// computeRate: [x: y] x >=0 && y >=0 && x/y < 10
MS_ASSERT(parallel_split_config != nullptr);
std::vector<std::string> config_devices = {"cpu", "gpu", "npu"};
auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate);
if (compute_rate_result.empty()) {
return false;
}
std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0);
if (device0_result.empty()) {
return false;
}
std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1);
if (device1_result.empty()) {
return false;
}
bool device0_flag = false;
bool device1_flag = false;
for (const auto &device : config_devices) {
if (device == device0_result) {
device0_flag = true;
}
if (device == device1_result) {
device1_flag = true;
}
}
if (!device0_flag || !device1_flag) {
return false;
}
const char delimiter = ';';
std::vector<std::string> device_rates = lite::SplitStringToVector(compute_rate_result, delimiter);
const char colon = ':';
for (const auto &device : device_rates) {
std::vector<std::string> rate = lite::SplitStringToVector(device, colon);
int64_t compute_rate = 0;
try {
compute_rate = std::stoi(rate.back());
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Get compute rate failed: " << e.what();
return false;
}
parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
}
if (parallel_split_config->parallel_compute_rates_.size() != 2) {
return false;
}
int64_t bigger_rate = INT32_MIN;
int64_t smaller_rate = INT32_MAX;
for (const auto &rate : parallel_split_config->parallel_compute_rates_) {
if (rate <= 0 || rate > INT32_MAX) {
return false;
}
bigger_rate = std::max(rate, bigger_rate);
smaller_rate = std::min(rate, smaller_rate);
}
parallel_split_config->parallel_devices_.push_back(device0_result);
parallel_split_config->parallel_devices_.push_back(device1_result);
// parallel_split_type_ may be extended by other user attributes
parallel_split_config->parallel_split_type_ = SplitByUserRatio;
return bigger_rate / smaller_rate <= kMaxSplitRatio;
}
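Putting the checks above together: both devices must come from {cpu, gpu, npu}, computeRate must yield exactly two ';'-separated entries whose value after the last ':' is a positive rate, and the larger/smaller ratio may not exceed kMaxSplitRatio (10). A hypothetical config that passes:

// computeRate=device0:3;device1:7   <- rates 3 and 7; 7/3 < 10, so accepted
// device0=cpu
// device1=gpu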
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
std::string res;
if (file.empty()) {
MS_LOG(ERROR) << "file is nullptr";
return res;
}
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
if (resolved_path == nullptr) {
MS_LOG(ERROR) << "new resolved_path failed";
return "";
}
#ifdef _WIN32
auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit);
#else
char *real_path = realpath(file.c_str(), resolved_path.get());
#endif
if (real_path == nullptr || strlen(real_path) == 0) {
MS_LOG(ERROR) << "file path is not valid : " << file;
return "";
}
std::ifstream ifs(resolved_path.get());
if (!ifs.good()) {
MS_LOG(ERROR) << "file: " << real_path << " is not exist";
return res;
}
if (!ifs.is_open()) {
MS_LOG(ERROR) << "file: " << real_path << "open failed";
return res;
}
std::string line;
while (std::getline(ifs, line)) {
lite::Trim(&line);
if (line.empty() || line.at(0) == '#' || line.at(0) == '[') {
continue;
}
auto index = line.find('=');
if (index == std::string::npos) {
MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
return "";
}
auto key = line.substr(0, index);
auto value = line.substr(index + 1);
lite::Trim(&key);
lite::Trim(&value);
if (key == target_key) {
return value;
}
}
return res;
}
} // namespace converter
} // namespace mindspore
} // namespace mindspore::converter

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,104 +13,65 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H_
#include <string>
#include <vector>
#include <map>
#include "include/api/format.h"
#include "include/api/data_type.h"
#include "include/registry/converter_context.h"
#include "tools/common/flag_parser.h"
#include "ir/dtype/type_id.h"
#include "schema/inner/model_generated.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/adapter/acl/common/acl_types.h"
#include "micro/coder/config.h"
namespace mindspore {
namespace lite {
class ConfigFileParser;
} // namespace lite
namespace converter {
using mindspore::schema::QuantType;
enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 };
constexpr auto kMaxSplitRatio = 10;
constexpr auto kComputeRate = "computeRate";
constexpr auto kSplitDevice0 = "device0";
constexpr auto kSplitDevice1 = "device1";
constexpr size_t kEncMaxLen = 16;
struct ParallelSplitConfig {
ParallelSplitType parallel_split_type_ = SplitNo;
std::vector<int64_t> parallel_compute_rates_;
std::vector<std::string> parallel_devices_;
};
class Flags : public virtual mindspore::lite::FlagParser {
public:
Flags();
~Flags() = default;
int InitInputOutputDataType();
int InitFmk();
int InitTrainModel();
int InitConfigFile();
int InitInTensorShape() const;
int InitGraphInputFormat();
int InitExtendedIntegrationInfo(const lite::ConfigFileParser &config_file_parser);
int InitEncrypt();
int InitPreInference();
int InitSaveFP16();
int InitNoFusion();
void InitAclDefaultOption();
int InitExportMindIR();
int Init(int argc, const char **argv);
int PreInit(int argc, const char **argv);
std::string modelFile;
std::string outputFile;
std::string fmkIn;
FmkType fmk;
std::string modelFile;
std::string outputFile;
std::string weightFile;
TypeId inputDataType;
TypeId outputDataType;
std::string saveFP16Str = "off";
bool saveFP16 = false;
std::string noFusionStr = "false";
bool disableFusion = false;
std::string inputDataTypeStr;
DataType inputDataType;
std::string outputDataTypeStr;
ParallelSplitConfig parallel_split_config_{};
DataType outputDataType;
std::string configFile;
std::string trainModelIn;
bool trainModel = false;
std::vector<std::string> pluginsPath;
bool disableFusion = false;
std::string inTensorShape;
mutable std::map<std::string, std::vector<int64_t>> graph_input_shape_map;
std::string dec_key = "";
std::string dec_mode = "AES-GCM";
std::string graphInputFormatStr;
std::string device;
mindspore::Format graphInputFormat = mindspore::NHWC;
std::string encKeyStr;
std::string encMode = "AES-GCM";
std::string inferStr;
bool infer = false;
std::string exportMindIR;
bool export_mindir = false;
#ifdef ENABLE_OPENSSL
std::string encryptionStr = "true";
bool encryption = true;
@@ -118,22 +79,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string encryptionStr = "false";
bool encryption = false;
#endif
bool infer = false;
unsigned char encKey[kEncMaxLen];
size_t keyLen = 0;
bool export_mindir = false;
lite::quant::CommonQuantParam commonQuantParam;
lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
lite::quant::FullQuantParam fullQuantParam;
lite::preprocess::DataPreProcessParam dataPreProcessParam;
lite::acl::AclModelOptionCfg aclModelOptionCfgParam;
lite::micro::MicroParam microParam;
};
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
} // namespace converter
} // namespace mindspore

View File

@@ -0,0 +1,82 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(__linux__) && !defined(Debug)
#include <csignal>
#endif
#define USE_DEPRECATED_API
#include <iostream>
#include "include/converter.h"
#include "include/api/status.h"
#include "tools/converter/converter_lite/converter_flags.h"
#if defined(__linux__) && !defined(Debug)
void SignalHandler(int sig) {
printf("encounter an unknown error, please verify the input model file or build the debug version\n");
exit(1);
}
#endif
int main(int argc, const char **argv) {
#if defined(__linux__) && !defined(Debug)
signal(SIGSEGV, SignalHandler);
signal(SIGABRT, SignalHandler);
signal(SIGFPE, SignalHandler);
signal(SIGBUS, SignalHandler);
#endif
#ifndef Debug
try {
#endif
mindspore::converter::Flags flags;
auto ret = flags.Init(argc, argv);
if (ret != mindspore::kSuccess) {
MS_LOG(ERROR) << "Flags Init failed. Ret: " << ret;
std::cout << "Flags Init failed. Ret: " << ret << std::endl;
return ret;
}
mindspore::Converter converter(flags.fmk, flags.modelFile, flags.outputFile, flags.weightFile);
converter.SetConfigFile(flags.configFile);
converter.SetWeightFp16(flags.saveFP16);
converter.SetInputShape(flags.graph_input_shape_map);
converter.SetInputFormat(flags.graphInputFormat);
converter.SetInputDataType(flags.inputDataType);
converter.SetOutputDataType(flags.outputDataType);
converter.SetExportMindIR(flags.export_mindir);
converter.SetDecryptKey(flags.dec_key);
flags.dec_key.clear();
converter.SetDecryptMode(flags.dec_mode);
converter.SetEnableEncryption(flags.encryption);
converter.SetEncryptKey(flags.encKeyStr);
flags.encKeyStr.clear();
converter.SetInfer(flags.infer);
converter.SetTrainModel(flags.trainModel);
converter.SetNoFusion(flags.disableFusion);
auto status = converter.Convert();
if (status != mindspore::kSuccess) {
MS_LOG(ERROR) << "Convert failed. Ret: " << status;
std::cout << "Convert failed. Ret: " << status << std::endl;
}
return status.StatusCode();
#ifndef Debug
} catch (const std::exception &e) {
std::cout << e.what() << std::endl;
std::cout << "encounter an unknown error, please verify the input model file or build the debug version\n";
return mindspore::kLiteError;
}
#endif
}

View File

@@ -16,6 +16,7 @@
#include "include/converter.h"
#include "include/api/data_type.h"
#include "tools/converter/cxx_api/converter_para.h"
#include "tools/converter/converter_context.h"
#include "src/common/log_adapter.h"
namespace mindspore {
@@ -65,6 +66,9 @@ bool Converter::GetWeightFp16() const {
void Converter::SetInputShape(const std::map<std::string, std::vector<int64_t>> &input_shape) {
if (data_ != nullptr) {
for (auto &it : input_shape) {
lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(it.first, it.second);
}
data_->input_shape = input_shape;
}
}
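A sketch of the call (constructor argument order follows main.cc above; tensor name, dims, and paths are placeholders, with no separate weight file). Besides storing the map in ConverterPara, each entry is mirrored into the global ConverterInnerContext:

mindspore::Converter converter(mindspore::converter::kFmkTypeTflite, "model.tflite", "model", "");
std::map<std::string, std::vector<int64_t>> shapes = {{"graph_input", {1, 224, 224, 3}}};
converter.SetInputShape(shapes);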

View File

@@ -59,7 +59,7 @@ void CloneGraphInputs(const FuncGraphPtr &origin, const FuncGraphPtr &mirror, No
}
AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const FuncGraphPtr &mirror_graph,
const converter::Flags *flags) {
const std::shared_ptr<ConverterPara> &param) {
MS_ASSERT(cnode != nullptr && mirror_graph != nullptr);
if (index >= cnode->size()) {
MS_LOG(ERROR) << "input index out of range.";
@@ -92,9 +92,9 @@ AnfNodePtr CloneParameterAndValueNode(const CNodePtr &cnode, size_t index, const
DataInfo data_info;
STATUS status;
if (utils::isa<Parameter>(node)) {
status = FetchDataFromParameterNode(cnode, index, flags->fmk, &data_info, true);
status = FetchDataFromParameterNode(cnode, index, param->fmk_type, &data_info, true);
} else if (utils::isa<ValueNode>(node)) {
status = FetchDataFromValueNode(cnode, index, flags->fmk, flags->trainModel, &data_info, true);
status = FetchDataFromValueNode(cnode, index, param->fmk_type, param->train_model, &data_info, true);
} else {
status = RET_ERROR;
}
@@ -152,10 +152,10 @@ PrimitivePtr ClonePrimitive(const CNodePtr &cnode) {
}
} // namespace
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags,
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> &param,
std::map<FuncGraphPtr, FuncGraphPtr> *cloned_func_graph) {
MS_ASSERT(graph != nullptr);
MS_ASSERT(flags != nullptr);
MS_ASSERT(param != nullptr);
MS_ASSERT(cloned_func_graph != nullptr);
auto cloned_func_graph_iter = cloned_func_graph->find(graph);
if (cloned_func_graph_iter != cloned_func_graph->end()) {
@@ -196,10 +196,10 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *f
if (mirror_input == nullptr) {
if (IsValueNode<FuncGraph>(origin_input)) {
auto sub_func_graph = GetValueNode<FuncGraphPtr>(origin_input);
auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, flags, cloned_func_graph);
auto mirror_sub_graph = CloneFuncGraph(sub_func_graph, param, cloned_func_graph);
mirror_input = NewValueNode(mirror_sub_graph);
} else {
mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, flags);
mirror_input = CloneParameterAndValueNode(cnode, i, mirror_graph, param);
}
if (mirror_input == nullptr) {
MS_LOG(ERROR) << "node input cannot be found.";
@@ -226,11 +226,11 @@ FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *f
return mirror_graph;
}
STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> &param) {
CHECK_NULL_RETURN(graph);
CHECK_NULL_RETURN(flags);
CHECK_NULL_RETURN(param);
std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
auto mirror_graph = CloneFuncGraph(graph, flags, &cloned_func_graph);
auto mirror_graph = CloneFuncGraph(graph, param, &cloned_func_graph);
if (mirror_graph == nullptr) {
MS_LOG(ERROR) << "Clone funcGraph failed.";
return RET_ERROR;
@@ -253,8 +253,8 @@ STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
CHECK_NULL_RETURN(optimizer);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
CHECK_NULL_RETURN(graph_pm);
if (flags->fmk == converter::kFmkTypeTflite || flags->fmk == converter::kFmkTypeTf ||
flags->fmk == converter::kFmkTypeOnnx) {
if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
param->fmk_type == converter::kFmkTypeOnnx) {
graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
}
optimizer->AddPassManager(graph_pm);
@@ -274,7 +274,7 @@ STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags) {
return RET_ERROR;
}
metagraph_transform->SetGraphDef(meta_graph);
auto status = metagraph_transform->Transform(*flags);
auto status = metagraph_transform->Transform(param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Transform meta graph failed " << status;
delete meta_graph;

View File

@@ -18,14 +18,16 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_EXPORT_MODEL_H
#include <map>
#include "tools/converter/converter_flags.h"
#include <memory>
#include "include/errorcode.h"
#include "ir/func_graph.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore {
namespace lite {
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const converter::Flags *flags,
FuncGraphPtr CloneFuncGraph(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> &param,
std::map<FuncGraphPtr, FuncGraphPtr> *cloned_func_graph);
STATUS ExportModel(const FuncGraphPtr &graph, const converter::Flags *flags);
STATUS ExportModel(const FuncGraphPtr &graph, const std::shared_ptr<ConverterPara> &param);
} // namespace lite
} // namespace mindspore

View File

@@ -19,7 +19,6 @@
#include <algorithm>
#include "schema/model_generated.h"
#include "src/common/log_adapter.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
@@ -51,11 +50,11 @@ std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT
return old_nodes;
}
int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT) {
int QuantTransform(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT *graph_defT) {
MS_ASSERT(graph_defT != nullptr);
// quantization
if (ctx.commonQuantParam.quant_type == schema::QuantType_QUANT_NONE ||
ctx.commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE ||
param->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
{
// quantization
// init old node indices
@@ -63,7 +62,7 @@ int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT)
Optimizer tensor_quant_optimizer;
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
auto status = tensor_quant_optimizer.Run(graph_defT);
@@ -78,8 +77,9 @@ int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT)
Optimizer quant_node_optimizer;
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
auto old_nodes = GetGraphNodes(*graph_defT);
quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType));
quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(static_cast<TypeId>(param->input_data_type),
static_cast<TypeId>(param->output_data_type)));
quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
@@ -94,12 +94,12 @@ int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT)
}
} // namespace
int GraphDefTransform::Transform(const converter::Flags &ctx) {
int GraphDefTransform::Transform(const std::shared_ptr<ConverterPara> &param) {
STATUS status;
{
auto old_nodes = GetGraphNodes(*graph_defT_);
Optimizer unused_op_remove_optimizer;
if (!ctx.trainModel) {
if (!param->train_model) {
unused_op_remove_optimizer.AddPass(new (std::nothrow) DropoutNodeRemovePass());
}
unused_op_remove_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
@@ -116,7 +116,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// init old node indices
auto old_nodes = GetGraphNodes(*graph_defT_);
Optimizer format_trans_optimizer;
if (!ctx.trainModel && ctx.fmk != converter::kFmkTypeOnnx) {
if (!param->train_model && param->fmk_type != converter::kFmkTypeOnnx) {
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
}
@@ -127,7 +127,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
auto ret = QuantTransform(ctx, graph_defT_);
auto ret = QuantTransform(param, graph_defT_);
if (ret != RET_OK && ret != RET_NO_CHANGE) {
return ret;
}
@@ -149,10 +149,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
{
Optimizer forming_model_optimizer;
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(ctx));
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(param->fmk_type));
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass(param));
forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.saveFP16));
forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(param->weight_fp16));
status = forming_model_optimizer.Run(graph_defT_);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";

View File

@@ -23,7 +23,6 @@
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
#include "tools/common/meta_graph_serializer.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
namespace lite {
@@ -35,7 +34,7 @@ class GraphDefTransform {
public:
GraphDefTransform();
virtual ~GraphDefTransform();
virtual int Transform(const converter::Flags &ctx);
virtual int Transform(const std::shared_ptr<ConverterPara> &param);
void SetGraphDef(schema::MetaGraphT *dst_def);
protected:

View File

@@ -19,8 +19,8 @@
#include <set>
#include <string>
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/registry/converter_context.h"
using mindspore::converter::FmkType;
using mindspore::schema::QuantType;

View File

@@ -20,8 +20,8 @@
#include <string>
#include <vector>
#include <set>
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/registry/converter_context.h"
using mindspore::converter::FmkType;
using mindspore::schema::QuantType;

View File

@@ -42,11 +42,12 @@ constexpr size_t kDependInputNum = 3;
constexpr size_t kDependFirstInputIdx = 1;
constexpr size_t kTupleGetItemFirstInputIdx = 1;
} // namespace
STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag) {
STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph,
const std::shared_ptr<ConverterPara> &param) {
MS_ASSERT(func_graph != nullptr);
auto primitive_adjust_pass = std::make_shared<PrimitiveAdjust>();
MS_CHECK_TRUE_MSG(primitive_adjust_pass != nullptr, RET_NULL_PTR, "primitive_adjust_pass is nullptr.");
primitive_adjust_pass->SetFmkType(flag.fmk);
primitive_adjust_pass->SetFmkType(param->fmk_type);
if (!primitive_adjust_pass->Run(func_graph)) {
MS_LOG(ERROR) << "primitive adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
@@ -54,14 +55,14 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const
}
auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr.");
mindir_adjust_pass->SetFmkType(flag.fmk);
mindir_adjust_pass->SetTrainFlag(flag.trainModel);
mindir_adjust_pass->SetFmkType(param->fmk_type);
mindir_adjust_pass->SetTrainFlag(param->train_model);
if (!mindir_adjust_pass->Run(func_graph)) {
MS_LOG(ERROR) << "MindIr adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
if (!flag.trainModel) {
if (!param->train_model) {
auto cast_op_adjust = std::make_shared<CastOpAdjust>();
MS_CHECK_TRUE_MSG(cast_op_adjust != nullptr, RET_NULL_PTR, "cast_op_adjust is nullptr.");
if (!cast_op_adjust->Run(func_graph)) {
@@ -72,7 +73,7 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const
}
auto mindir_control_flow_adjust = std::make_shared<MindIRControlFlowAdjust>();
MS_CHECK_TRUE_MSG(mindir_control_flow_adjust != nullptr, RET_NULL_PTR, "mindir_control_flow_adjust is nullptr.");
mindir_control_flow_adjust->SetFmkType(flag.fmk);
mindir_control_flow_adjust->SetFmkType(param->fmk_type);
if (!mindir_control_flow_adjust->Run(func_graph)) {
MS_LOG(ERROR) << "MindIR control flow adjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
@@ -221,35 +222,37 @@ void MindsporeImporter::RemoveUnusedGraphInput(const FuncGraphPtr &func_graph) {
}
}
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size) {
FuncGraphPtr MindsporeImporter::ImportMindIR(const std::shared_ptr<ConverterPara> &param, const void *buff,
const size_t &size) {
MindIRLoader mindir_loader;
auto func_graph = mindir_loader.LoadMindIR(buff, size);
return CheckAndUpdateFuncGraph(flag, func_graph);
return CheckAndUpdateFuncGraph(param, func_graph);
}
FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
FuncGraphPtr MindsporeImporter::ImportMindIR(const std::shared_ptr<ConverterPara> &param) {
FuncGraphPtr func_graph;
if (!flag.dec_key.empty()) {
if (!param->decrypt_key.empty()) {
unsigned char key[32];
const size_t key_len = Hex2ByteArray(flag.dec_key, key, 32);
const size_t key_len = Hex2ByteArray(param->decrypt_key, key, 32);
if (key_len == 0) {
return nullptr;
}
MindIRLoader mindir_loader(false, key, key_len, flag.dec_mode, false);
func_graph = mindir_loader.LoadMindIR(flag.modelFile);
MindIRLoader mindir_loader(false, key, key_len, param->decrypt_mode, false);
func_graph = mindir_loader.LoadMindIR(param->model_file);
auto ret = memset_s(key, sizeof(key), 0, key_len);
if (ret != 0) {
MS_LOG(EXCEPTION) << "memset_s error";
}
} else {
MindIRLoader mindir_loader;
func_graph = mindir_loader.LoadMindIR(flag.modelFile);
func_graph = mindir_loader.LoadMindIR(param->model_file);
}
return CheckAndUpdateFuncGraph(flag, func_graph);
return CheckAndUpdateFuncGraph(param, func_graph);
}
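A sketch of the decrypt branch above (key value is a placeholder): the key is hex-decoded into the 32-byte buffer, two hex characters per byte, and decrypt_mode mirrors the flag side's dec_mode default of "AES-GCM":

auto param = std::make_shared<mindspore::ConverterPara>();
param->model_file = "encrypted.mindir";                    // hypothetical path
param->decrypt_key = "0123456789abcdef0123456789abcdef";   // placeholder hex string
param->decrypt_mode = "AES-GCM";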
FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph) {
FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr<ConverterPara> &param,
FuncGraphPtr func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
MS_LOG(ERROR)
@@ -299,12 +302,12 @@ FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &
MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend.";
return func_graph;
#endif
if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
if ((status = Mindir2AnfAdjust(func_graph, param)) != RET_OK) {
MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, flag.trainModel);
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model);
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr.");
if (!unify_format->Run(func_graph)) {
MS_LOG(ERROR) << "Run insert transpose failed.";

View File

@@ -20,23 +20,25 @@
#include <set>
#include <string>
#include <vector>
#include "tools/converter/converter_flags.h"
#include <memory>
#include "load_mindir/load_model.h"
#include "tools/converter/cxx_api/converter_para.h"
#include "include/errorcode.h"
namespace mindspore::lite {
class MindsporeImporter {
public:
MindsporeImporter() = default;
~MindsporeImporter() = default;
FuncGraphPtr ImportMindIR(const converter::Flags &flag);
FuncGraphPtr ImportMindIR(const converter::Flags &flag, const void *buff, const size_t &size);
FuncGraphPtr ImportMindIR(const std::shared_ptr<ConverterPara> &param);
FuncGraphPtr ImportMindIR(const std::shared_ptr<ConverterPara> &param, const void *buff, const size_t &size);
private:
static void RemoveUnusedGraphInput(const FuncGraphPtr &func_graph);
STATUS GetFuncGraphOutputName(const CNodePtr &cnode);
STATUS TraceOutput(const AnfNodePtr &node);
FuncGraphPtr CheckAndUpdateFuncGraph(const converter::Flags &flag, FuncGraphPtr func_graph);
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const converter::Flags &flag);
FuncGraphPtr CheckAndUpdateFuncGraph(const std::shared_ptr<ConverterPara> &param, FuncGraphPtr func_graph);
STATUS Mindir2AnfAdjust(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param);
std::vector<std::string> output_tensor_name_;
};

View File

@@ -21,8 +21,8 @@
#include <string>
#include <vector>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/registry/converter_context.h"
using mindspore::converter::FmkType;
namespace mindspore {

View File

@@ -28,7 +28,6 @@
#include "src/runtime/infer_manager.h"
#include "src/common/primitive_t_utils.h"
#include "tools/common/node_util.h"
#include "tools/converter/converter_flags.h"
#include "src/common/string_utils.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"

View File

@@ -25,10 +25,8 @@
#include <set>
#include "tools/common/graph_util.h"
#include "tools/converter/optimizer.h"
#include "tools/converter/converter_flags.h"
#include "include/registry/converter_context.h"
using mindspore::converter::kFmkTypeTf;
using mindspore::schema::TensorT;
namespace mindspore {
namespace lite {
struct InferTensor {
@@ -61,7 +59,7 @@ class InferShapePass : public GraphPass {
void InitInferTensor(MetaGraphT *graph);
int InferSubgraph(const int64_t &subgraph_index, MetaGraphT *graph);
converter::FmkType fmk_type_ = kFmkTypeTf;
converter::FmkType fmk_type_ = converter::kFmkTypeTf;
std::vector<InferTensor> tensors_ = {};
std::set<CNodeT *> partial_cnode_inferred_{};
};

View File

@@ -23,7 +23,7 @@ STATUS SetUnusedQuantParamToDefaultPass::Run(schema::MetaGraphT *graph) {
for (auto &tensor : graph->allTensors) {
bool has_quant_param = false;
for (auto &quant_param : tensor->quantParams) {
if (ctx_.fullQuantParam.target_device != quant::NVGPU) {
if (param_->fullQuantParam.target_device != quant::NVGPU) {
quant_param->min = 0.0;
quant_param->max = 0.0;
}

View File

@@ -17,21 +17,22 @@
#define LITE_UNUSED_QUANT_PARAM_DATA_REMOVE_PASS_H
#include <memory>
#include "tools/converter/optimizer.h"
#include "tools/converter/converter_flags.h"
#include "tools/common/graph_util.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore {
namespace lite {
class SetUnusedQuantParamToDefaultPass : public GraphPass {
public:
SetUnusedQuantParamToDefaultPass() {}
explicit SetUnusedQuantParamToDefaultPass(const converter::Flags &ctx) : ctx_(ctx) {}
explicit SetUnusedQuantParamToDefaultPass(const std::shared_ptr<ConverterPara> &param) : param_(param) {}
~SetUnusedQuantParamToDefaultPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
converter::Flags ctx_;
const std::shared_ptr<ConverterPara> param_;
};
} // namespace lite
} // namespace mindspore

View File

@@ -1,57 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(__linux__) && !defined(Debug)
#include <csignal>
#endif
#define USE_DEPRECATED_API
#include "tools/converter/converter.h"
#if defined(__linux__) && !defined(Debug)
void SignalHandler(int sig) {
printf("encounter an unknown error, please verify the input model file or build the debug version\n");
exit(1);
}
#endif
namespace mindspore {
extern "C" {
extern void common_log_init();
}
} // namespace mindspore
int main(int argc, const char **argv) {
#if defined(__linux__) && !defined(Debug)
signal(SIGSEGV, SignalHandler);
signal(SIGABRT, SignalHandler);
signal(SIGFPE, SignalHandler);
signal(SIGBUS, SignalHandler);
#endif
int ret = 0;
#ifndef Debug
try {
#endif
mindspore::common_log_init();
ret = mindspore::lite::RunConverter(argc, argv);
#ifndef Debug
} catch (const std::exception &e) {
ret = mindspore::lite::RET_ERROR;
std::cout << e.what() << std::endl;
std::cout << "encounter an unknown error, please verify the input model file or build the debug version\n";
}
#endif
return ret;
}

View File

@@ -22,11 +22,9 @@
#include <vector>
#include <map>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
using mindspore::converter::FmkType;
namespace mindspore::opt {
typedef enum { kThenBranch = 0, kElseBranch = 1 } BranchType;

View File

@@ -22,9 +22,9 @@
#include <vector>
#include <memory>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/registry/converter_context.h"
using mindspore::converter::FmkType;
namespace mindspore::opt {

View File

@@ -21,11 +21,9 @@
#include <vector>
#include <map>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/parser/tf/functionalize_control_op_pass.h"
using mindspore::converter::FmkType;
namespace mindspore::opt {
constexpr const int POS_INVALID = -1;

View File

@@ -3,6 +3,7 @@ merge_parser(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_BINARY_DIR}/tools/converter/par
file(GLOB_RECURSE TFLITE_SRC_LIST ${CMAKE_BINARY_DIR}/tools/converter/parser/tflite/tflite_op_parser.cc)
set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
set_property(SOURCE ${TFLITE_SRC_LIST} PROPERTY COMPILE_DEFINITIONS FLATBUFFERS_LOCALE_INDEPENDENT=0)
add_library(tflite_parser_mid OBJECT
${TFLITE_SRC_LIST}
)

View File

@@ -313,7 +313,7 @@ int BiasCorrectionStrategy::Int8Inference(const MSKernelCallBack &before_call_ba
MS_LOG(ERROR) << "New model failed.";
return RET_ERROR;
}
auto ret = BuildModelByFuncGraph(int8_model, quant_func_graph, flags_);
auto ret = BuildModelByFuncGraph(int8_model, quant_func_graph, param_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build error.";
return RET_ERROR;
@@ -637,7 +637,7 @@ MSKernelCallBack BiasCorrectionStrategy::GetNVGPUInt8AfterCallBack() {
int BiasCorrectionStrategy::DoBiasCorrection(const FuncGraphPtr &quant_func_graph) {
int status;
switch (this->flags_.fullQuantParam.target_device) {
switch (param_->fullQuantParam.target_device) {
case CPU:
status = DoCPUBiasCorrection(quant_func_graph);
break;
@@ -645,8 +645,7 @@ int BiasCorrectionStrategy::DoBiasCorrection(const FuncGraphPtr &quant_func_grap
status = DoNVGPUBiasCorrection(quant_func_graph);
break;
default:
MS_LOG(ERROR) << "Unsupported target device " << this->flags_.fullQuantParam.target_device
<< " for bias correction.";
MS_LOG(ERROR) << "Unsupported target device " << param_->fullQuantParam.target_device << " for bias correction.";
return RET_ERROR;
}
if (status != RET_OK) {

View File

@@ -39,10 +39,10 @@ enum CallBackType {
class BiasCorrectionStrategy {
public:
BiasCorrectionStrategy(const converter::Flags &flags, const std::shared_ptr<Calibrator> &calibrator,
BiasCorrectionStrategy(const std::shared_ptr<ConverterPara> &param, const std::shared_ptr<Calibrator> &calibrator,
const std::shared_ptr<QuantStrategy> &quant_strategy,
std::shared_ptr<mindspore::Model> fp32_ms_model, int activation_q_min, int activation_q_max)
: flags_(flags),
: param_(param),
calibrator_(calibrator),
quant_strategy_(quant_strategy),
fp32_ms_model_(fp32_ms_model),
@@ -133,7 +133,7 @@ class BiasCorrectionStrategy {
}
private:
converter::Flags flags_;
const std::shared_ptr<ConverterPara> param_;
std::shared_ptr<Calibrator> calibrator_{nullptr};
std::shared_ptr<QuantStrategy> quant_strategy_{nullptr};
std::shared_ptr<mindspore::Model> fp32_ms_model_{nullptr};

View File

@@ -28,8 +28,7 @@
namespace mindspore::lite::quant {
class CLEStrategy {
public:
explicit CLEStrategy(const FuncGraphPtr &func_graph, const converter::Flags &flags)
: func_graph_(func_graph), flags_(flags) {}
explicit CLEStrategy(const FuncGraphPtr &func_graph) : func_graph_(func_graph) {}
~CLEStrategy();
@@ -54,7 +53,6 @@
private:
FuncGraphPtr func_graph_ = nullptr;
converter::Flags flags_;
CLEPattern *cle_pattern_ = nullptr;
};
} // namespace mindspore::lite::quant

View File

@@ -663,11 +663,11 @@ int DebugInfoManager::SaveOutputInfo(const std::string &file_path) {
int DebugInfoManager::StatisticsDataPerRound(
const std::shared_ptr<mindspore::Model> &origin, const std::shared_ptr<mindspore::Model> &quant,
const std::map<std::string, OpParameter *> &op_parameters, const converter::Flags &config,
const std::map<std::string, OpParameter *> &op_parameters, const std::shared_ptr<ConverterPara> &param,
const std::map<std::string, mindspore::schema::Tensor *> &origin_input_tensor_map,
const std::map<std::string, mindspore::schema::Tensor *> &quant_input_tensor_map, const int &round) {
int ret;
auto data_preprocess = config.dataPreProcessParam;
auto data_preprocess = param->dataPreProcessParam;
for (auto tensor : origin->GetInputs()) {
if (data_preprocess.calibrate_size > 0) {
ret = preprocess::PreProcess(data_preprocess, tensor.Name(), round, &tensor);
@@ -681,8 +681,8 @@ int DebugInfoManager::StatisticsDataPerRound(
}
std::cout << "Statistics the original data distribution. Round " << round << std::endl;
auto origin_before_callBack =
GetBeforeCallBack(origin_input_tensor_map, op_parameters, true, config.commonQuantParam.debug_mode);
auto origin_after_callBack = GetAfterCallBack(op_parameters, true, config.commonQuantParam.debug_mode);
GetBeforeCallBack(origin_input_tensor_map, op_parameters, true, param->commonQuantParam.debug_mode);
auto origin_after_callBack = GetAfterCallBack(op_parameters, true, param->commonQuantParam.debug_mode);
auto origin_outputs = origin->GetOutputs();
auto status = origin->Predict(origin->GetInputs(), &origin_outputs, origin_before_callBack, origin_after_callBack);
if (status != kSuccess) {
@@ -692,8 +692,8 @@ int DebugInfoManager::StatisticsDataPerRound(
std::cout << "Statistics the quant data distribution. Round " << round << std::endl;
auto quant_before_callBack =
GetBeforeCallBack(quant_input_tensor_map, op_parameters, false, config.commonQuantParam.debug_mode);
auto quant_after_callBack = GetAfterCallBack(op_parameters, false, config.commonQuantParam.debug_mode);
GetBeforeCallBack(quant_input_tensor_map, op_parameters, false, param->commonQuantParam.debug_mode);
auto quant_after_callBack = GetAfterCallBack(op_parameters, false, param->commonQuantParam.debug_mode);
for (auto &tensor : quant->GetInputs()) {
auto tensor_data = tensor.MutableData();
CHECK_NULL_RETURN(tensor_data);
@@ -722,7 +722,7 @@ std::string DebugInfoManager::CreateFilePath(const std::string &dir_path, const
int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptr<mindspore::Model> &origin,
const std::shared_ptr<mindspore::Model> &quant,
const std::map<std::string, OpParameter *> &op_parameters,
const converter::Flags &config,
const std::shared_ptr<ConverterPara> &param,
const mindspore::lite::LiteModel &origin_lite_model,
const mindspore::lite::LiteModel &quant_lite_model) {
auto begin = GetTimeUs();
@@ -732,27 +732,27 @@ int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptr<mindspore::Mo
origin_outputs_[tensor.Name()] = tensor;
}
int ret;
auto data_preprocess = config.dataPreProcessParam;
auto data_preprocess = param->dataPreProcessParam;
// When the calibration data set does not exist, use 1 round of random numbers for comparison
int rounds = data_preprocess.calibrate_size > 0 ? data_preprocess.calibrate_size : 1;
for (int round = 0; round < rounds; round++) {
ret = StatisticsDataPerRound(origin, quant, op_parameters, config, origin_input_tensor_map, quant_input_tensor_map,
ret = StatisticsDataPerRound(origin, quant, op_parameters, param, origin_input_tensor_map, quant_input_tensor_map,
round);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Statistics Data failed for round: " << round;
FreeBuffer();
return RET_ERROR;
}
ret = GetClipAndCos(config.commonQuantParam.debug_mode);
ret = GetClipAndCos(param->commonQuantParam.debug_mode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get clip and cos failed.";
FreeBuffer();
return ret;
}
GetOutputInfo();
if (config.commonQuantParam.debug_mode == quant::DETAIL) {
if (param->commonQuantParam.debug_mode == quant::DETAIL) {
auto file_name = "round_" + std::to_string(round) + ".csv";
auto file_path = CreateFilePath(config.commonQuantParam.debug_info_save_path, file_name);
auto file_path = CreateFilePath(param->commonQuantParam.debug_info_save_path, file_name);
ret = SaveInfo(file_path);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Failed to save debug info to " + file_path;
@@ -765,7 +765,7 @@ int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptr<mindspore::Mo
}
auto file_name = "quant_param.csv";
auto quant_param_save_path = CreateFilePath(config.commonQuantParam.debug_info_save_path, file_name);
auto quant_param_save_path = CreateFilePath(param->commonQuantParam.debug_info_save_path, file_name);
ret = SaveQuantParam(quant_param_save_path);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Failed to save quant param to " + quant_param_save_path;
@ -773,7 +773,7 @@ int DebugInfoManager::CompareOriginWithQuant(const std::shared_ptr<mindspore::Mo
}
file_name = "output_summary.csv";
auto output_param_save_path = CreateFilePath(config.commonQuantParam.debug_info_save_path, file_name);
auto output_param_save_path = CreateFilePath(param->commonQuantParam.debug_info_save_path, file_name);
ret = SaveOutputInfo(output_param_save_path);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Failed to save output param to " + output_param_save_path;

View File

@@ -87,7 +87,8 @@ class DebugInfoManager {
public:
int CompareOriginWithQuant(const std::shared_ptr<mindspore::Model> &origin,
const std::shared_ptr<mindspore::Model> &quant,
const std::map<std::string, OpParameter *> &op_parameters, const converter::Flags &config,
const std::map<std::string, OpParameter *> &op_parameters,
const std::shared_ptr<ConverterPara> &param,
const mindspore::lite::LiteModel &origin_lite_model,
const mindspore::lite::LiteModel &quant_lite_model);
@@ -158,7 +159,8 @@ class DebugInfoManager {
int StatisticsDataPerRound(const std::shared_ptr<mindspore::Model> &origin,
const std::shared_ptr<mindspore::Model> &quant,
const std::map<std::string, OpParameter *> &op_parameters, const converter::Flags &config,
const std::map<std::string, OpParameter *> &op_parameters,
const std::shared_ptr<ConverterPara> &param,
const std::map<string, schema::Tensor *> &origin_input_tensor_map,
const std::map<string, schema::Tensor *> &quant_input_tensor_map, const int &round);

View File

@@ -22,9 +22,9 @@
namespace mindspore::lite::quant {
int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
// Dynamic quantization doesn't support filters.
flags_.commonQuantParam.min_quant_weight_channel = 0;
flags_.commonQuantParam.min_quant_weight_size = 0;
auto quantizer = WeightQuantizer(flags_);
param_->commonQuantParam.min_quant_weight_channel = 0;
param_->commonQuantParam.min_quant_weight_size = 0;
auto quantizer = WeightQuantizer(param_);
const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
auto ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
@@ -36,7 +36,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
const std::set<PrimitivePtr> support_dynamic_quant_ops = {
prim::kPrimMatMulFusion,
};
ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, flags_.commonQuantParam.skip_quant_node);
ret = manager.InsertDynamicQuantNode(func_graph, support_dynamic_quant_ops, param_->commonQuantParam.skip_quant_node);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Insert dynamic quant failed.";
return ret;

View File

@@ -41,8 +41,8 @@
namespace mindspore::lite::quant {
class DynamicQuantizer : public Quantizer {
public:
explicit DynamicQuantizer(const converter::Flags &flags) : Quantizer(flags) {
bit_num_ = flags.commonQuantParam.bit_num;
explicit DynamicQuantizer(const std::shared_ptr<ConverterPara> &param) : Quantizer(param) {
bit_num_ = param->commonQuantParam.bit_num;
}
~DynamicQuantizer() = default;

View File

@@ -180,7 +180,7 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const Parame
MS_LOG(ERROR) << op_name << " Do bias quant failed.";
return ret;
}
} else if (flags_.fullQuantParam.per_channel && CheckNodeInSet(cnode, per_channel_ops_)) {
} else if (param_->fullQuantParam.per_channel && CheckNodeInSet(cnode, per_channel_ops_)) {
ret = DoParameterWeightQuant(cnode, input_node, primitive, input_index, true);
if (ret != RET_OK) {
MS_LOG(ERROR) << op_name << " Do bias quant failed.";
@@ -443,7 +443,7 @@ void FullQuantQuantizer::InitKirinConfig() {
weight_channel_symmetric_ = true;
weight_layer_symmetric_ = false;
support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection};
flags_.fullQuantParam.bias_correction = false;
param_->fullQuantParam.bias_correction = false;
per_channel_ops_ = {prim::kPrimConv2DFusion};
}
@@ -457,7 +457,7 @@ void FullQuantQuantizer::InitNvGpuConfig() {
support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMul, prim::kPrimActivation,
prim::kPrimConv2dTransposeFusion};
per_channel_ops_ = {};
flags_.fullQuantParam.bias_correction = false;
param_->fullQuantParam.bias_correction = false;
}
void FullQuantQuantizer::InitQMinMax() {
@@ -509,7 +509,7 @@ int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
}
int FullQuantQuantizer::InitDeviceConfig(const FuncGraphPtr &func_graph) {
switch (flags_.fullQuantParam.target_device) {
switch (param_->fullQuantParam.target_device) {
case CPU:
InitCpuConfig();
break;
@@ -520,18 +520,18 @@ int FullQuantQuantizer::InitDeviceConfig(const FuncGraphPtr &func_graph) {
InitNvGpuConfig();
break;
default:
MS_LOG(ERROR) << " Unsupported device " << flags_.fullQuantParam.target_device;
MS_LOG(ERROR) << " Unsupported device " << param_->fullQuantParam.target_device;
return RET_ERROR;
break;
}
InitQMinMax();
calibrator_ = std::make_shared<Calibrator>(this->bit_num_, activation_q_max_, activation_q_min_,
this->flags_.fullQuantParam.activation_quant_method,
this->flags_.dataPreProcessParam, activation_symmetric_);
param_->fullQuantParam.activation_quant_method,
param_->dataPreProcessParam, activation_symmetric_);
MSLITE_CHECK_PTR(calibrator_);
quant_strategy_ = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
flags_.commonQuantParam.min_quant_weight_channel,
flags_.commonQuantParam.skip_quant_node);
quant_strategy_ = std::make_unique<QuantStrategy>(param_->commonQuantParam.min_quant_weight_size,
param_->commonQuantParam.min_quant_weight_channel,
param_->commonQuantParam.skip_quant_node);
CHECK_NULL_RETURN(quant_strategy_);
auto ret = MarkQuantNode(func_graph);
if (ret != RET_OK) {
@@ -593,7 +593,7 @@ int FullQuantQuantizer::DoInference(CollectType collect_type) {
int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_ASSERT(func_graph != nullptr);
MS_LOG(INFO) << "start to parse config file";
if (flags_.dataPreProcessParam.calibrate_path.empty()) {
if (param_->dataPreProcessParam.calibrate_path.empty()) {
MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
return RET_INPUT_PARAM_INVALID;
}
@@ -611,7 +611,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(ERROR) << "New model failed.";
return RET_ERROR;
}
auto ret = BuildModelByFuncGraph(fp32_ms_model_, func_graph, flags_);
auto ret = BuildModelByFuncGraph(fp32_ms_model_, func_graph, param_);
if (ret != mindspore::kSuccess) {
MS_LOG(ERROR) << "Build model failed.";
return RET_ERROR;
@@ -623,7 +623,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
return status;
}
if (flags_.fullQuantParam.activation_quant_method == KL) {
if (param_->fullQuantParam.activation_quant_method == KL) {
status = QuantWithKL();
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant with KL failed.";
@@ -649,9 +649,9 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
}
}
if (this->flags_.fullQuantParam.bias_correction) {
if (param_->fullQuantParam.bias_correction) {
MS_LOG(INFO) << "do bias correction";
BiasCorrectionStrategy strategy(flags_, calibrator_, quant_strategy_, fp32_ms_model_, activation_q_min_,
BiasCorrectionStrategy strategy(param_, calibrator_, quant_strategy_, fp32_ms_model_, activation_q_min_,
activation_q_max_);
status = strategy.DoBiasCorrection(func_graph);
if (status != RET_OK) {
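The calibrator constructed above collects per-tensor float ranges over the calibration set (optionally refined by the KL method) before quant parameters are derived. A standalone sketch of turning a calibrated range into asymmetric activation quant parameters; the bounds and names are illustrative, not the Calibrator API:

#include <algorithm>
#include <cmath>

struct QuantParam {
  float scale;
  int zero_point;
};

// Map a calibrated float range [rmin, rmax] onto the integer range [q_min, q_max].
QuantParam ComputeActivationQuantParam(float rmin, float rmax, int q_min, int q_max) {
  rmin = std::min(rmin, 0.0f);  // the range must contain zero so it is exactly representable
  rmax = std::max(rmax, 0.0f);
  float scale = (rmax - rmin) / static_cast<float>(q_max - q_min);
  if (scale == 0.0f) scale = 1.0f;
  int zp = q_min - static_cast<int>(std::round(rmin / scale));
  zp = std::min(q_max, std::max(q_min, zp));
  return {scale, zp};
}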

View File

@@ -39,8 +39,8 @@
namespace mindspore::lite::quant {
class FullQuantQuantizer : public Quantizer {
public:
explicit FullQuantQuantizer(const converter::Flags &flags) : Quantizer(flags) {
bit_num_ = flags.commonQuantParam.bit_num;
explicit FullQuantQuantizer(const std::shared_ptr<ConverterPara> &param) : Quantizer(param) {
bit_num_ = param_->commonQuantParam.bit_num;
}
~FullQuantQuantizer() override;

View File

@@ -44,12 +44,12 @@ MinMax ParameterOptimizer::GetFineTuneRange(std::vector<float> *candidate_scales
return min_max;
}
int ParameterOptimizer::CloneFuncGraph(const FuncGraphPtr &func_graph, converter::Flags *flags,
int ParameterOptimizer::CloneFuncGraph(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
FuncGraphPtr *func_graph_bak) {
CHECK_NULL_RETURN(func_graph_bak);
CHECK_NULL_RETURN(flags);
CHECK_NULL_RETURN(param);
std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
*func_graph_bak = lite::CloneFuncGraph(func_graph, flags, &cloned_func_graph);
*func_graph_bak = lite::CloneFuncGraph(func_graph, param, &cloned_func_graph);
CHECK_NULL_RETURN(*func_graph_bak);
static auto root_func_manager = Manage(*func_graph_bak);
std::set<FuncGraphPtr> all_func_graphs = {};
@@ -60,11 +60,12 @@ int ParameterOptimizer::CloneFuncGraph(const FuncGraphPtr &func_graph, converter
return RET_OK;
}
int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph,
const std::shared_ptr<ConverterPara> &param,
std::shared_ptr<mindspore::Model> origin_model, int origin_model_size,
const InferenceParam &param, double *init_scale,
const InferenceParam &infer_param, double *init_scale,
std::vector<float> *candidate_scales, bool is_run_all) {
CHECK_NULL_RETURN(flags);
CHECK_NULL_RETURN(param);
CHECK_NULL_RETURN(origin_model);
CHECK_NULL_RETURN(init_scale);
CHECK_NULL_RETURN(candidate_scales);
@@ -75,18 +76,18 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
float best_compress_cos_sim = 0.0f;
int best_compress_model_size = 0;
size_t over_error_count = 0;
for (size_t round = 0; round < param.rounds; round++) {
auto scale = param.start_scale + round * param.step;
flags->commonQuantParam.quant_type = schema::QuantType_QUANT_WEIGHT;
for (size_t round = 0; round < infer_param.rounds; round++) {
auto scale = infer_param.start_scale + round * infer_param.step;
param->commonQuantParam.quant_type = schema::QuantType_QUANT_WEIGHT;
FuncGraphPtr func_graph_bak;
auto ret = CloneFuncGraph(func_graph, flags, &func_graph_bak);
auto ret = CloneFuncGraph(func_graph, param, &func_graph_bak);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Clone FuncGraph failed.";
return ret;
}
// quant
auto quantizer = std::make_unique<quant::WeightQuantizer>(*flags);
auto quantizer = std::make_unique<quant::WeightQuantizer>(param);
CHECK_NULL_RETURN(quantizer);
auto status = quantizer->DoQuantize(func_graph_bak, scale);
if (status != RET_OK) {
@@ -98,7 +99,7 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
int weight_quant_size;
auto weight_quant_model = std::make_shared<mindspore::Model>();
CHECK_NULL_RETURN(weight_quant_model);
auto build_status = BuildModelByFuncGraph(weight_quant_model, func_graph_bak, *flags, &weight_quant_size);
auto build_status = BuildModelByFuncGraph(weight_quant_model, func_graph_bak, param, &weight_quant_size);
if (build_status != kSuccess) {
MS_LOG(WARNING) << "build model failed!";
continue;
@@ -155,28 +156,29 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
return RET_OK;
}
int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph,
const std::shared_ptr<ConverterPara> &param,
std::shared_ptr<mindspore::Model> origin_model, int *origin_model_size) {
CHECK_NULL_RETURN(flags);
CHECK_NULL_RETURN(param);
CHECK_NULL_RETURN(origin_model);
CHECK_NULL_RETURN(origin_model_size);
FuncGraphPtr func_graph_bak;
auto ret = CloneFuncGraph(func_graph, flags, &func_graph_bak);
auto ret = CloneFuncGraph(func_graph, param, &func_graph_bak);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Clone FuncGraph failed.";
return RET_ERROR;
}
flags->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
param->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
*origin_model_size = 0;
auto status = BuildModelByFuncGraph(origin_model, func_graph_bak, *flags, origin_model_size);
auto status = BuildModelByFuncGraph(origin_model, func_graph_bak, param, origin_model_size);
if (status != kSuccess) {
MS_LOG(ERROR) << "build model failed!";
return RET_ERROR;
}
auto origin_inputs = origin_model->GetInputs();
for (auto input : origin_inputs) {
if (flags->dataPreProcessParam.calibrate_size > 0) {
ret = preprocess::PreProcess(flags->dataPreProcessParam, input.Name(), 0, &input);
if (param->dataPreProcessParam.calibrate_size > 0) {
ret = preprocess::PreProcess(param->dataPreProcessParam, input.Name(), 0, &input);
} else {
ret = GenerateRandomData(&input);
}
@@ -195,16 +197,16 @@ int ParameterOptimizer::OriginModelInference(const FuncGraphPtr &func_graph, con
return RET_OK;
}
int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, converter::Flags *flags,
int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
double *init_scale) {
CHECK_NULL_RETURN(flags);
CHECK_NULL_RETURN(param);
CHECK_NULL_RETURN(init_scale);
double default_init_scale = *init_scale;
auto origin_model = std::make_shared<mindspore::Model>();
int origin_model_size;
auto ret = OriginModelInference(func_graph, flags, origin_model, &origin_model_size);
auto ret = OriginModelInference(func_graph, param, origin_model, &origin_model_size);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Origin Model Inference failed.";
return ret;
@@ -215,14 +217,14 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
float step = 0.005f;
std::vector<float> candidate_scales;
InferenceParam param{};
param.rounds = giant_rounds;
param.start_scale = start_scale;
param.step = step;
param.thread_num = flags->commonQuantParam.thread_num;
InferenceParam infer_param{};
infer_param.rounds = giant_rounds;
infer_param.start_scale = start_scale;
infer_param.step = step;
infer_param.thread_num = param->commonQuantParam.thread_num;
std::cout << "==========Search with giant step==============\n";
ret = WeightQuantModelInference(func_graph, flags, origin_model, origin_model_size, param, init_scale,
ret = WeightQuantModelInference(func_graph, param, origin_model, origin_model_size, infer_param, init_scale,
&candidate_scales, false);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Weight quant graph inference failed.";
@@ -240,12 +242,12 @@ int ParameterOptimizer::GridSearchForScale(const FuncGraphPtr &func_graph, conve
const int baby_step_rounds = 25;
step = (min_max.max - min_max.min) / baby_step_rounds;
param.rounds = baby_step_rounds;
param.start_scale = start_scale;
param.step = step;
param.thread_num = flags->commonQuantParam.thread_num;
infer_param.rounds = baby_step_rounds;
infer_param.start_scale = start_scale;
infer_param.step = step;
infer_param.thread_num = param->commonQuantParam.thread_num;
std::cout << "==========Search with baby step==============\n";
ret = WeightQuantModelInference(func_graph, flags, origin_model, origin_model_size, param, init_scale,
ret = WeightQuantModelInference(func_graph, param, origin_model, origin_model_size, infer_param, init_scale,
&candidate_scales, true);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Weight quant graph inference failed.";

View File

@@ -28,7 +28,6 @@
#include "tools/converter/parser/parser_utils.h"
#include "include/model.h"
#include "base/base.h"
#include "tools/converter/converter_flags.h"
namespace mindspore::lite::quant {
struct InferenceParam {
@@ -44,19 +43,21 @@ class ParameterOptimizer {
~ParameterOptimizer() = default;
int GridSearchForScale(const FuncGraphPtr &func_graph, converter::Flags *flags, double *init_scale);
int GridSearchForScale(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
double *init_scale);
private:
MinMax GetFineTuneRange(std::vector<float> *candidate_scales);
int CloneFuncGraph(const FuncGraphPtr &func_graph, converter::Flags *flags, FuncGraphPtr *func_graph_bak);
int CloneFuncGraph(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
FuncGraphPtr *func_graph_bak);
int WeightQuantModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
int WeightQuantModelInference(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
std::shared_ptr<mindspore::Model> origin_model, int origin_model_size,
const InferenceParam &param, double *init_scale, std::vector<float> *candidate_scales,
bool is_run_all);
const InferenceParam &infer_param, double *init_scale,
std::vector<float> *candidate_scales, bool is_run_all);
int OriginModelInference(const FuncGraphPtr &func_graph, converter::Flags *flags,
int OriginModelInference(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
std::shared_ptr<mindspore::Model> origin_model, int *origin_model_size);
};
} // namespace mindspore::lite::quant

View File

@@ -59,8 +59,8 @@ void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_f
}
}
int DoFullQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto quantizer = std::make_unique<FullQuantQuantizer>(*config);
int DoFullQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto quantizer = std::make_unique<FullQuantQuantizer>(param);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New FullQuantQuantizer failed";
return RET_ERROR;
@@ -73,16 +73,16 @@ int DoFullQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
return RET_OK;
}
int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
double init_scale = config->mixedBitWeightQuantParam.init_scale;
if (config->commonQuantParam.bit_num == 0 && config->mixedBitWeightQuantParam.auto_tune) {
int DoWeightQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
double init_scale = param->mixedBitWeightQuantParam.init_scale;
if (param->commonQuantParam.bit_num == 0 && param->mixedBitWeightQuantParam.auto_tune) {
ParameterOptimizer optimizer;
auto status = optimizer.GridSearchForScale(old_graph, const_cast<converter::Flags *>(config), &init_scale);
auto status = optimizer.GridSearchForScale(old_graph, param, &init_scale);
if (status != RET_OK) {
MS_LOG(ERROR) << "Grid search with scale failed.";
return status;
}
auto quantizer = std::make_unique<WeightQuantizer>(*config);
auto quantizer = std::make_unique<WeightQuantizer>(param);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
return RET_ERROR;
@@ -93,7 +93,7 @@ int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config)
return RET_ERROR;
}
} else {
auto quantizer = std::make_unique<WeightQuantizer>(*config);
auto quantizer = std::make_unique<WeightQuantizer>(param);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";
return RET_ERROR;
@@ -107,8 +107,8 @@ int DoWeightQuant(const FuncGraphPtr &old_graph, const converter::Flags *config)
return RET_OK;
}
int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config) {
auto quantizer = std::make_unique<DynamicQuantizer>(*config);
int DoDynamicQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
auto quantizer = std::make_unique<DynamicQuantizer>(param);
if (quantizer == nullptr) {
MS_LOG(ERROR) << "New DynamicQuantizer failed";
return RET_ERROR;
@@ -121,7 +121,7 @@ int DoDynamicQuant(const FuncGraphPtr &old_graph, const converter::Flags *config
return RET_OK;
}
lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const converter::Flags &flags) {
lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param) {
auto meta_graph = Export(func_graph, true, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph failed";
@@ -131,7 +131,7 @@ lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const converter:
// transform
GraphDefTransform fb_transform;
fb_transform.SetGraphDef(meta_graph);
auto status = fb_transform.Transform(flags);
auto status = fb_transform.Transform(param);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed";
delete meta_graph;
@@ -152,11 +152,11 @@ lite::LiteModel *ParseLiteModel(const FuncGraphPtr &func_graph, const converter:
return static_cast<LiteModel *>(LiteModel::Import((const char *)content, size));
}
int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config,
int DoQuantDebug(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param,
const std::shared_ptr<mindspore::Model> &origin_model, mindspore::lite::LiteModel *origin_lite_model) {
auto quant_model = std::make_shared<mindspore::Model>();
CHECK_NULL_RETURN(quant_model);
auto ret = BuildModelByFuncGraph(quant_model, old_graph, *config);
auto ret = BuildModelByFuncGraph(quant_model, old_graph, param);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed";
return RET_ERROR;
@@ -165,12 +165,12 @@ int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config,
FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
DebugInfoManager manager;
auto quant_lite_model = ParseLiteModel(old_graph, *config);
auto quant_lite_model = ParseLiteModel(old_graph, param);
if (quant_lite_model == nullptr) {
MS_LOG(ERROR) << "Parse lite model failed";
return RET_ERROR;
}
auto status = manager.CompareOriginWithQuant(origin_model, quant_model, op_parameters, *config, *origin_lite_model,
auto status = manager.CompareOriginWithQuant(origin_model, quant_model, op_parameters, param, *origin_lite_model,
*quant_lite_model);
auto free_buffer = [&] {
for (auto parameter : op_parameters) {
@@ -190,17 +190,17 @@ int DoQuantDebug(const FuncGraphPtr &old_graph, const converter::Flags *config,
return RET_OK;
}
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
CHECK_NULL_RETURN(config);
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
CHECK_NULL_RETURN(param);
if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
return RET_OK;
}
int status;
bool per_layer =
config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL && !config->fullQuantParam.per_channel;
param->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL && !param->fullQuantParam.per_channel;
if (per_layer) {
CLEStrategy cle_strategy(old_graph, *config);
CLEStrategy cle_strategy(old_graph);
status = cle_strategy.Run();
if (status != RET_OK) {
MS_LOG(ERROR) << "do pre process failed!";
@@ -210,43 +210,44 @@ int DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags
std::shared_ptr<mindspore::Model> origin;
lite::LiteModel *origin_lite_model = nullptr;
if (config->commonQuantParam.is_debug) { // Bak fp32 model for debug
converter::Flags new_flag = *config;
new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
if (param->commonQuantParam.is_debug) {  // Back up the fp32 model for debug
auto quant_type = param->commonQuantParam.quant_type;
param->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
origin = std::make_shared<mindspore::Model>();
CHECK_NULL_RETURN(origin);
auto ret = BuildModelByFuncGraph(origin, old_graph, new_flag);
auto ret = BuildModelByFuncGraph(origin, old_graph, param);
param->commonQuantParam.quant_type = quant_type;
if (ret != kSuccess) {
MS_LOG(ERROR) << "Build model failed";
return RET_ERROR;
}
origin_lite_model = ParseLiteModel(old_graph, *config);
origin_lite_model = ParseLiteModel(old_graph, param);
if (origin_lite_model == nullptr) {
MS_LOG(ERROR) << "Parse lite model failed.";
return RET_ERROR;
}
}
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) { // Full Quantization
status = DoFullQuant(old_graph, config);
if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) { // Full Quantization
status = DoFullQuant(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do full quant failed.";
return status;
}
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) { // Weight Quantization
status = DoWeightQuant(old_graph, config);
} else if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) { // Weight Quantization
status = DoWeightQuant(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do weight quant failed.";
return status;
}
} else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) { // Dynamic Quantization
status = DoDynamicQuant(old_graph, config);
} else if (param->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) { // Dynamic Quantization
status = DoDynamicQuant(old_graph, param);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do dynamic quant failed.";
return status;
}
}
if (config->commonQuantParam.is_debug) {
status = DoQuantDebug(old_graph, config, origin, origin_lite_model);
if (param->commonQuantParam.is_debug) {
status = DoQuantDebug(old_graph, param, origin, origin_lite_model);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do quant debug failed.";
return status;
@@ -260,7 +261,7 @@ int QuantizationOptimizer::Run(const mindspore::FuncGraphPtr &func_graph) {
GetFuncGraphs(func_graph, &all_func_graphs);
// Support for multi-subgraph models
for (auto &item : all_func_graphs) {
auto status = DoSingleGraphQuantize(item, flags_);
auto status = DoSingleGraphQuantize(item, param_);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do Quantize failed.";
return status;
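Note how the debug branch above mutates the shared param in place: quant_type is set to QUANT_NONE to build the FP32 baseline and then restored by hand, a direct consequence of every pass now observing one ConverterPara instead of its own Flags copy. A small RAII guard is one way to make that save/restore pattern exception-safe; this is a sketch only, not part of the converter code:

// Restores a value to its original state when the scope ends.
template <typename T>
class ScopedRestore {
 public:
  explicit ScopedRestore(T *target) : target_(target), saved_(*target) {}
  ~ScopedRestore() { *target_ = saved_; }
  ScopedRestore(const ScopedRestore &) = delete;
  ScopedRestore &operator=(const ScopedRestore &) = delete;

 private:
  T *target_;
  T saved_;
};

// Hypothetical usage: the quant type reverts automatically, even on early return.
// ScopedRestore<schema::QuantType> guard(&param->commonQuantParam.quant_type);
// param->commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;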

View File

@@ -18,19 +18,20 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZATION_OPTIMIZER_H
#include <utility>
#include <map>
#include <memory>
#include <vector>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore::lite::quant {
class QuantizationOptimizer {
public:
explicit QuantizationOptimizer(converter::Flags *flags) : flags_(flags) {}
explicit QuantizationOptimizer(const std::shared_ptr<ConverterPara> &param) : param_(param) {}
~QuantizationOptimizer() = default;
int Run(const FuncGraphPtr &func_graph);
private:
converter::Flags *flags_;
const std::shared_ptr<ConverterPara> &param_;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZATION_OPTIMIZER_H

View File

@@ -255,13 +255,13 @@ std::string NodePrimitiveType(const CNodePtr &cnode) {
}
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
const converter::Flags &flags) {
const std::shared_ptr<ConverterPara> &param) {
int size = 0;
return BuildModelByFuncGraph(model, func_graph, flags, &size);
return BuildModelByFuncGraph(model, func_graph, param, &size);
}
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
const converter::Flags &flags, int *size) {
const std::shared_ptr<ConverterPara> &param, int *size) {
auto meta_graph = Export(func_graph, true, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph failed";
@@ -271,7 +271,7 @@ Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, con
// transform
GraphDefTransform fb_transform;
fb_transform.SetGraphDef(meta_graph);
auto status = fb_transform.Transform(flags);
auto status = fb_transform.Transform(param);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed";
delete meta_graph;

View File

@@ -18,9 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
#ifndef _MSC_VER
#include <dirent.h>
#endif
#include <sys/stat.h>
@@ -54,6 +52,7 @@
#include "src/common/file_utils.h"
#include "src/common/quant_utils.h"
#include "include/api/model.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore::lite::quant {
enum WeightQuantType {
@@ -188,10 +187,10 @@ int FixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPt
std::string NodePrimitiveType(const CNodePtr &cnode);
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
const converter::Flags &flags);
const std::shared_ptr<mindspore::ConverterPara> &param);
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
const converter::Flags &flags, int *size);
const std::shared_ptr<mindspore::ConverterPara> &param, int *size);
mindspore::lite::Tensor *MSTensorToLiteTensor(const mindspore::MSTensor &tensor);

View File

@@ -24,20 +24,20 @@
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "base/base.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/quant_param_holder.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore::lite::quant {
class Quantizer {
public:
explicit Quantizer(const converter::Flags &config) : flags_(config) {}
explicit Quantizer(const std::shared_ptr<mindspore::ConverterPara> &param) : param_(param) {}
virtual ~Quantizer() = default;
virtual int DoQuantize(FuncGraphPtr func_graph) = 0;
protected:
converter::Flags flags_;
const std::shared_ptr<mindspore::ConverterPara> param_;
};
} // namespace mindspore::lite::quant
#endif
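This base class carries the central change of the commit: quantizers used to copy a converter::Flags by value, so a tweak made in one pass (e.g. DynamicQuantizer zeroing min_quant_weight_size) was invisible to the others, whereas a shared_ptr<ConverterPara> gives every pass the same mutable view. A minimal self-contained sketch of the pattern with simplified stand-in types:

#include <cassert>
#include <memory>

struct CommonQuantParam { int min_quant_weight_size = 512; };  // simplified stand-in
struct ConverterParaSketch { CommonQuantParam commonQuantParam; };  // stand-in for ConverterPara

class QuantizerSketch {
 public:
  explicit QuantizerSketch(const std::shared_ptr<ConverterParaSketch> &param) : param_(param) {}

 protected:
  std::shared_ptr<ConverterParaSketch> param_;  // shared, not copied
};

class DynamicSketch : public QuantizerSketch {
 public:
  using QuantizerSketch::QuantizerSketch;
  void Tune() { param_->commonQuantParam.min_quant_weight_size = 0; }  // visible to every holder
};

int main() {
  auto param = std::make_shared<ConverterParaSketch>();
  DynamicSketch d(param);
  d.Tune();
  assert(param->commonQuantParam.min_quant_weight_size == 0);  // the mutation is shared
}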

View File

@@ -37,7 +37,7 @@ int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
continue;
}
auto op_name = cnode->fullname_with_scope();
if (flags_.commonQuantParam.skip_quant_node.find(op_name) != flags_.commonQuantParam.skip_quant_node.end()) {
if (param_->commonQuantParam.skip_quant_node.find(op_name) != param_->commonQuantParam.skip_quant_node.end()) {
MS_LOG(INFO) << op_name << " is skip dynamic quant.";
continue;
}
@@ -95,9 +95,9 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
continue;
}
auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
flags_.commonQuantParam.min_quant_weight_channel,
flags_.commonQuantParam.skip_quant_node);
auto quant_strategy = std::make_unique<QuantStrategy>(param_->commonQuantParam.min_quant_weight_size,
param_->commonQuantParam.min_quant_weight_channel,
param_->commonQuantParam.skip_quant_node);
CHECK_NULL_RETURN(quant_strategy);
int preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
if (!quant_strategy->CanTensorQuantized(cnode, input, preferred_dim)) {
@@ -114,15 +114,15 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
}
auto status = RET_ERROR;
if (is_mixed_bit_) {
status = MixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
status = MixedBitQuantFilter(parameter, tensor_info, primitive, param_->commonQuantParam.quant_type,
WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1,
preferred_dim, symmetric);
} else if (type_id_ == kNumberTypeInt8) {
status =
FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max, q_min,
bit_num_, tmp_weight_quant_type, type_id_, idx - 1, preferred_dim, symmetric);
status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, param_->commonQuantParam.quant_type,
q_max, q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1,
preferred_dim, symmetric);
} else if (type_id_ == kNumberTypeInt16) {
status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, param_->commonQuantParam.quant_type,
q_max, q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1,
preferred_dim, symmetric);
}
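FixedBitQuantFilter quantizes the weight tensor along its preferred dimension using the q_min/q_max bounds derived from bit_num. A standalone sketch of the per-channel symmetric case for bit widths up to 8 (illustrative names; the real filter also writes the quant params back into the primitive):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize weights laid out as [channels][elems_per_channel], one scale per channel.
// Assumes bit_num <= 8 so the result fits an int8_t, and w.size() divisible by channels.
void PerChannelQuant(const std::vector<float> &w, int channels, int bit_num,
                     std::vector<int8_t> *q, std::vector<float> *scales) {
  const int q_max = (1 << (bit_num - 1)) - 1;  // e.g. 127 for 8 bits
  const int q_min = -(1 << (bit_num - 1));
  const size_t per_channel = w.size() / channels;
  q->resize(w.size());
  scales->resize(channels);
  for (int c = 0; c < channels; ++c) {
    float max_abs = 0.0f;
    for (size_t i = 0; i < per_channel; ++i)
      max_abs = std::max(max_abs, std::fabs(w[c * per_channel + i]));
    const float scale = max_abs > 0.0f ? max_abs / q_max : 1.0f;
    (*scales)[c] = scale;
    for (size_t i = 0; i < per_channel; ++i) {
      int v = static_cast<int>(std::round(w[c * per_channel + i] / scale));
      (*q)[c * per_channel + i] = static_cast<int8_t>(std::min(q_max, std::max(q_min, v)));
    }
  }
}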

View File

@@ -41,12 +41,12 @@
namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer {
public:
explicit WeightQuantizer(const converter::Flags &flags) : Quantizer(flags) {
bit_num_ = flags.commonQuantParam.bit_num;
explicit WeightQuantizer(const std::shared_ptr<ConverterPara> &param) : Quantizer(param) {
bit_num_ = param_->commonQuantParam.bit_num;
if (bit_num_ == 0) {
type_id_ = kNumberTypeInt16;
is_mixed_bit_ = true;
mixed_bit_init_scale_ = flags.mixedBitWeightQuantParam.init_scale;
mixed_bit_init_scale_ = param_->mixedBitWeightQuantParam.init_scale;
}
// parse param for fixed bit quant.
if (!is_mixed_bit_) {

View File

@@ -23,8 +23,8 @@
#include "ir/primitive.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "tools/converter/converter_flags.h"
#include "nnacl/op_base.h"
#include "include/registry/converter_context.h"
namespace mindspore {
namespace lite {

View File

@@ -93,12 +93,12 @@ int MindIRSerializer::RemoveQuantParameterHolder(FuncGraphPtr func_graph) {
return RET_OK;
}
int MindIRSerializer::Save(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph) {
int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func_graph is nullptr.";
return RET_NULL_PTR;
}
auto output_file = flag->outputFile;
auto output_file = param->output_file;
auto ret = ParserPath(output_file);
if (ret != RET_OK) {
MS_LOG(ERROR) << "parse path failed.";
@@ -403,13 +403,13 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
}
#endif
int MindIRSerialize(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph) {
int MindIRSerialize(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph) {
#ifndef ENABLE_CLOUD_AND_LITE
if (!flag->export_mindir) {
if (!param->export_mindir) {
return RET_OK;
}
mindspore::lite::MindIRSerializer serializer;
return serializer.Save(flag, func_graph);
return serializer.Save(param, func_graph);
#else
MS_LOG(INFO) << "No need to serialize mindir when load model online.";
return RET_OK;
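Save first splits output_file into a directory and a model-name prefix before writing the .mindir. A sketch of that kind of path parsing, using an illustrative helper rather than the serializer's actual ParserPath:

#include <string>

// Split "/a/b/model" into dir="/a/b/" and name="model"; handles both separators.
void SplitOutputPath(const std::string &output_file, std::string *dir, std::string *name) {
  auto pos = output_file.find_last_of("/\\");
  if (pos == std::string::npos) {
    *dir = "./";
    *name = output_file;
  } else {
    *dir = output_file.substr(0, pos + 1);
    *name = output_file.substr(pos + 1);
  }
}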

View File

@@ -24,9 +24,9 @@
#include <set>
#include "mindspore/core/ir/func_graph.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/converter_flags.h"
#include "proto/mind_ir.pb.h"
#include "mindspore/core/utils/system/env.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore::lite {
#ifndef ENABLE_CLOUD_AND_LITE
@@ -40,7 +40,7 @@ class MindIRSerializer {
data_fs_ = nullptr;
}
}
int Save(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph);
int Save(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph);
private:
int ParserPath(const std::string &output_path);
@@ -74,6 +74,6 @@ class MindIRSerializer {
};
#endif
// export func_graph
int MindIRSerialize(const std::unique_ptr<converter::Flags> &flag, const FuncGraphPtr &func_graph);
int MindIRSerialize(const std::shared_ptr<ConverterPara> &param, const FuncGraphPtr &func_graph);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_TOOLS_MINDIR_SERIALIZER_MINDIR_SERIALIZER_H_

View File

@@ -19,7 +19,7 @@
#include <string>
#include "backend/common/optimizer/optimizer.h"
#include "tools/converter/converter_flags.h"
#include "include/registry/converter_context.h"
using mindspore::converter::FmkType;
namespace mindspore::opt {

View File

@@ -35,8 +35,8 @@ const BaseRef MatMulActivationFusion::DefinePattern() const {
const AnfNodePtr MatMulActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
// The int8 MatMul kernel does not support fused matmul+activation
if (ctx_.commonQuantParam.quant_type == schema::QuantType_QUANT_ALL ||
ctx_.commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) {
if (param_->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL ||
param_->commonQuantParam.quant_type == schema::QuantType_QUANT_DYNAMIC) {
return nullptr;
}
if (func_graph == nullptr || node == nullptr) {
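The early return above encodes a backend constraint: the int8 MatMul kernel has no fused-activation variant, so under full or dynamic quantization the fusion must not fire. A stand-in sketch of such a guard with simplified IR types (not the optimizer's API):

#include <memory>

enum class QuantType { None, All, Dynamic };
struct NodeSketch { /* simplified stand-in for an IR node */ };

// Returns the fused node, or nullptr when the target kernel cannot absorb the activation.
std::shared_ptr<NodeSketch> TryFuseMatMulActivation(const std::shared_ptr<NodeSketch> &matmul,
                                                    const std::shared_ptr<NodeSketch> &act,
                                                    QuantType quant_type) {
  if (quant_type == QuantType::All || quant_type == QuantType::Dynamic) {
    return nullptr;  // int8 MatMul has no fused-activation variant; keep the nodes separate
  }
  if (matmul == nullptr || act == nullptr) return nullptr;
  return std::make_shared<NodeSketch>();  // stand-in for the fused MatMul+activation node
}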

View File

@@ -17,19 +17,18 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ACTIVATION_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_MATMUL_ACTIVATION_FUSION_H_
#include <memory>
#include <string>
#include "backend/common/optimizer/optimizer.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/cxx_api/converter_para.h"
namespace mindspore {
namespace opt {
class MatMulActivationFusion : public PatternProcessPass {
public:
explicit MatMulActivationFusion(const converter::Flags &ctx, bool multigraph = true)
: PatternProcessPass("MatMulActivationFusion", multigraph) {
ctx_ = ctx;
}
explicit MatMulActivationFusion(const std::shared_ptr<ConverterPara> &param, bool multigraph = true)
: PatternProcessPass("MatMulActivationFusion", multigraph), param_(param) {}
~MatMulActivationFusion() = default;
private:
@@ -37,7 +36,7 @@ class MatMulActivationFusion : public PatternProcessPass {
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
private:
converter::Flags ctx_;
const std::shared_ptr<ConverterPara> param_;
};
} // namespace opt
} // namespace mindspore

View File

@@ -17,11 +17,8 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_CLIP_CONVERT_ACTIVATION_PASS_H_
#include <string>
#include "tools/converter/converter_flags.h"
#include "backend/common/optimizer/pass.h"
using mindspore::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class ClipConvertActivationPass : public Pass {
public:

View File

@@ -16,6 +16,7 @@
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_DUMP_GRAPH_H_
#include <memory>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/export_model.h"
#include "include/registry/pass_base.h"
@@ -25,11 +26,11 @@ namespace mindspore {
namespace opt {
class DumpGraph : public registry::PassBase, public Pass {
public:
explicit DumpGraph(const converter::Flags *flags = nullptr) : Pass("DumpGraph"), flags_(flags) {}
explicit DumpGraph(const std::shared_ptr<ConverterPara> &param) : Pass("DumpGraph"), param_(param) {}
~DumpGraph() = default;
bool Run(const FuncGraphPtr &graph) override {
MS_CHECK_TRUE_MSG(graph != nullptr, false, "funcGraph is a nullptr.");
if (lite::ExportModel(graph, flags_) != lite::RET_OK) {
if (lite::ExportModel(graph, param_) != lite::RET_OK) {
MS_LOG(ERROR) << "dump graph failed.";
return false;
}
@@ -45,7 +46,7 @@ class DumpGraph : public registry::PassBase, public Pass {
}
private:
const converter::Flags *flags_{nullptr};
const std::shared_ptr<ConverterPara> param_;
};
} // namespace opt
} // namespace mindspore

View File

@@ -20,10 +20,10 @@
#include <memory>
#include <utility>
#include <string>
#include "tools/converter/converter_flags.h"
#include "backend/common/optimizer/pass.h"
#include "include/errorcode.h"
#include "mindspore/core/ir/manager.h"
#include "include/registry/converter_context.h"
using mindspore::converter::FmkType;
namespace mindspore::opt {

View File

@@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_
#include <string>
#include "backend/common/optimizer/pass.h"
#include "tools/converter/converter_flags.h"
#include "include/registry/converter_context.h"
using mindspore::converter::FmkType;
namespace mindspore::opt {