forked from mindspore-Ecosystem/mindspore
!18904 [lite]open flags when converting
Merge pull request !18904 from 徐安越/master_core
This commit is contained in:
commit
fe9954ff52
|
@ -323,7 +323,8 @@ elseif(WIN32)
|
||||||
else()
|
else()
|
||||||
if(SUPPORT_TRAIN)
|
if(SUPPORT_TRAIN)
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE
|
||||||
|
PATTERN "framework.h" EXCLUDE)
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
|
||||||
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
|
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
|
||||||
|
@ -331,7 +332,7 @@ else()
|
||||||
else()
|
else()
|
||||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
|
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
|
||||||
PATTERN "*registry.h" EXCLUDE)
|
PATTERN "*registry.h" EXCLUDE PATTERN "framework.h" EXCLUDE)
|
||||||
endif()
|
endif()
|
||||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||||
|
|
|
@ -16,13 +16,27 @@
|
||||||
|
|
||||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <string>
|
||||||
#include "include/lite_utils.h"
|
#include "include/lite_utils.h"
|
||||||
#include "include/registry/framework.h"
|
#include "include/registry/framework.h"
|
||||||
|
#include "schema/inner/model_generated.h"
|
||||||
|
|
||||||
using mindspore::lite::converter::FmkType;
|
using mindspore::lite::converter::FmkType;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
namespace converter {
|
||||||
|
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
|
||||||
|
struct MS_API ConverterParameters {
|
||||||
|
FmkType fmk_;
|
||||||
|
schema::QuantType quant_type_;
|
||||||
|
std::string model_file_;
|
||||||
|
std::string weight_file_;
|
||||||
|
std::map<std::string, std::string> attrs_;
|
||||||
|
};
|
||||||
|
} // namespace converter
|
||||||
|
|
||||||
/// \brief ModelParser defined a model parser
|
/// \brief ModelParser defined a model parser
|
||||||
class MS_API ModelParser;
|
class MS_API ModelParser;
|
||||||
|
|
||||||
|
@ -56,7 +70,7 @@ class MS_API ModelParserRegistry {
|
||||||
/// \param[in] creator Define function pointer of creating ModelParser.
|
/// \param[in] creator Define function pointer of creating ModelParser.
|
||||||
int RegParser(const FmkType fmk, ModelParserCreator creator);
|
int RegParser(const FmkType fmk, ModelParserCreator creator);
|
||||||
|
|
||||||
std::unordered_map<FmkType, ModelParserCreator> parsers_;
|
std::map<FmkType, ModelParserCreator> parsers_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// \brief ModelRegistrar defined registration class of ModelParser.
|
/// \brief ModelRegistrar defined registration class of ModelParser.
|
||||||
|
|
|
@ -18,10 +18,9 @@
|
||||||
#include "common/common_test.h"
|
#include "common/common_test.h"
|
||||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
|
||||||
|
|
||||||
using mindspore::lite::ModelRegistrar;
|
using mindspore::lite::ModelRegistrar;
|
||||||
using mindspore::lite::converter::Flags;
|
using mindspore::lite::converter::ConverterParameters;
|
||||||
using mindspore::lite::converter::FmkType_CAFFE;
|
using mindspore::lite::converter::FmkType_CAFFE;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelParserRegistryTest : public mindspore::CommonTest {
|
class ModelParserRegistryTest : public mindspore::CommonTest {
|
||||||
|
@ -39,8 +38,8 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
|
||||||
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
|
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
|
||||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
|
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
|
||||||
ASSERT_NE(model_parser, nullptr);
|
ASSERT_NE(model_parser, nullptr);
|
||||||
Flags flags;
|
ConverterParameters converter_parameters;
|
||||||
auto func_graph = model_parser->Parse(flags);
|
auto func_graph = model_parser->Parse(converter_parameters);
|
||||||
ASSERT_NE(func_graph, nullptr);
|
ASSERT_NE(func_graph, nullptr);
|
||||||
auto node_list = func_graph->GetOrderedCnodes();
|
auto node_list = func_graph->GetOrderedCnodes();
|
||||||
ASSERT_EQ(node_list.size(), 3);
|
ASSERT_EQ(node_list.size(), 3);
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) {
|
FuncGraphPtr ModelParserTest::Parse(const lite::converter::ConverterParameters &flag) {
|
||||||
// construct funcgraph
|
// construct funcgraph
|
||||||
res_graph_ = std::make_shared<FuncGraph>();
|
res_graph_ = std::make_shared<FuncGraph>();
|
||||||
auto ret = InitOriginModelStructure();
|
auto ret = InitOriginModelStructure();
|
||||||
|
|
|
@ -23,13 +23,12 @@
|
||||||
#include "include/registry/model_parser_registry.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "ut/tools/converter/registry/node_parser_test.h"
|
#include "ut/tools/converter/registry/node_parser_test.h"
|
||||||
#include "tools/converter/model_parser.h"
|
#include "tools/converter/model_parser.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class ModelParserTest : public lite::ModelParser {
|
class ModelParserTest : public lite::ModelParser {
|
||||||
public:
|
public:
|
||||||
ModelParserTest() = default;
|
ModelParserTest() = default;
|
||||||
FuncGraphPtr Parse(const lite::converter::Flags &flag) override;
|
FuncGraphPtr Parse(const lite::converter::ConverterParameters &flag) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int InitOriginModelStructure();
|
int InitOriginModelStructure();
|
||||||
|
|
|
@ -24,13 +24,12 @@
|
||||||
#include "ops/fusion/add_fusion.h"
|
#include "ops/fusion/add_fusion.h"
|
||||||
#include "ops/addn.h"
|
#include "ops/addn.h"
|
||||||
#include "ops/custom.h"
|
#include "ops/custom.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
|
||||||
#include "tools/converter/model_parser.h"
|
#include "tools/converter/model_parser.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||||
|
|
||||||
using mindspore::lite::ModelRegistrar;
|
using mindspore::lite::ModelRegistrar;
|
||||||
using mindspore::lite::converter::Flags;
|
using mindspore::lite::converter::ConverterParameters;
|
||||||
using mindspore::lite::converter::FmkType_CAFFE;
|
using mindspore::lite::converter::FmkType_CAFFE;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
class PassRegistryTest : public mindspore::CommonTest {
|
class PassRegistryTest : public mindspore::CommonTest {
|
||||||
|
@ -42,8 +41,8 @@ class PassRegistryTest : public mindspore::CommonTest {
|
||||||
if (model_parser == nullptr) {
|
if (model_parser == nullptr) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Flags flags;
|
ConverterParameters converter_parameters;
|
||||||
func_graph_ = model_parser->Parse(flags);
|
func_graph_ = model_parser->Parse(converter_parameters);
|
||||||
}
|
}
|
||||||
FuncGraphPtr func_graph_ = nullptr;
|
FuncGraphPtr func_graph_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
|
@ -31,6 +31,15 @@
|
||||||
#include "tools/converter/import/mindspore_importer.h"
|
#include "tools/converter/import/mindspore_importer.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace lite {
|
namespace lite {
|
||||||
|
namespace {
|
||||||
|
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
|
||||||
|
converter_parameters->fmk_ = flag.fmk;
|
||||||
|
converter_parameters->quant_type_ = flag.quantType;
|
||||||
|
converter_parameters->model_file_ = flag.modelFile;
|
||||||
|
converter_parameters->weight_file_ = flag.weightFile;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
||||||
FuncGraphPtr func_graph = nullptr;
|
FuncGraphPtr func_graph = nullptr;
|
||||||
if (flag.fmk == converter::FmkType::FmkType_MS) {
|
if (flag.fmk == converter::FmkType::FmkType_MS) {
|
||||||
|
@ -45,7 +54,9 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
||||||
if (model_parser_ == nullptr) {
|
if (model_parser_ == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
func_graph = model_parser_->Parse(flag);
|
converter::ConverterParameters converter_parameters;
|
||||||
|
InitConverterParameters(flag, &converter_parameters);
|
||||||
|
func_graph = model_parser_->Parse(converter_parameters);
|
||||||
}
|
}
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
|
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
|
||||||
|
|
|
@ -22,9 +22,8 @@
|
||||||
#include "schema/inner/model_generated.h"
|
#include "schema/inner/model_generated.h"
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "tools/converter/converter_context.h"
|
#include "include/registry/model_parser_registry.h"
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "utils/log_adapter.h"
|
||||||
#include "tools/converter/quant_param_holder.h"
|
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
using namespace schema;
|
using namespace schema;
|
||||||
|
@ -34,7 +33,7 @@ class ModelParser {
|
||||||
|
|
||||||
virtual ~ModelParser() = default;
|
virtual ~ModelParser() = default;
|
||||||
|
|
||||||
virtual FuncGraphPtr Parse(const converter::Flags &flag) { return this->res_graph_; }
|
virtual FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
FuncGraphPtr res_graph_ = nullptr;
|
FuncGraphPtr res_graph_ = nullptr;
|
||||||
|
|
|
@ -67,10 +67,10 @@ CaffeModelParser::CaffeModelParser() = default;
|
||||||
|
|
||||||
CaffeModelParser::~CaffeModelParser() = default;
|
CaffeModelParser::~CaffeModelParser() = default;
|
||||||
|
|
||||||
FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
|
FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||||
auto model_file = flag.modelFile;
|
auto model_file = flag.model_file_;
|
||||||
auto weight_file = flag.weightFile;
|
auto weight_file = flag.weight_file_;
|
||||||
quant_type_ = flag.quantType;
|
quant_type_ = flag.quant_type_;
|
||||||
STATUS status = InitOriginModel(model_file, weight_file);
|
STATUS status = InitOriginModel(model_file, weight_file);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
|
|
|
@ -34,7 +34,7 @@ class CaffeModelParser : public ModelParser {
|
||||||
|
|
||||||
~CaffeModelParser() override;
|
~CaffeModelParser() override;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
|
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
|
||||||
|
|
|
@ -53,9 +53,9 @@ std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
||||||
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
||||||
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
||||||
|
|
||||||
FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
|
FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||||
string model_file = flag.modelFile;
|
string model_file = flag.model_file_;
|
||||||
quant_type_ = flag.quantType;
|
quant_type_ = flag.quant_type_;
|
||||||
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
||||||
res_graph_ = std::make_shared<FuncGraph>();
|
res_graph_ = std::make_shared<FuncGraph>();
|
||||||
auto status = InitOriginModel(model_file);
|
auto status = InitOriginModel(model_file);
|
||||||
|
|
|
@ -42,7 +42,7 @@ class OnnxModelParser : public ModelParser {
|
||||||
|
|
||||||
~OnnxModelParser() override = default;
|
~OnnxModelParser() override = default;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||||
|
|
||||||
static int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
static int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "src/common/utils.h"
|
#include "src/common/utils.h"
|
||||||
#include "tools/common/graph_util.h"
|
#include "tools/common/graph_util.h"
|
||||||
#include "tools/common/protobuf_utils.h"
|
#include "tools/common/protobuf_utils.h"
|
||||||
|
#include "tools/converter/converter_context.h"
|
||||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
#include "tools/converter/ops/ops_def.h"
|
#include "tools/converter/ops/ops_def.h"
|
||||||
|
@ -478,9 +479,9 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
|
FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||||
auto modelFile = flag.modelFile;
|
auto modelFile = flag.model_file_;
|
||||||
quant_type_ = flag.quantType;
|
quant_type_ = flag.quant_type_;
|
||||||
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
||||||
auto status = ValidateFileStr(modelFile, ".pb");
|
auto status = ValidateFileStr(modelFile, ".pb");
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
|
|
|
@ -40,7 +40,7 @@ class TFModelParser : public ModelParser {
|
||||||
TFModelParser() = default;
|
TFModelParser() = default;
|
||||||
~TFModelParser() override = default;
|
~TFModelParser() override = default;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||||
|
|
||||||
static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||||
|
|
||||||
|
|
|
@ -51,9 +51,9 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
|
||||||
return tflite::UnPackModel(tflite_model_buf_);
|
return tflite::UnPackModel(tflite_model_buf_);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) {
|
FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||||
auto model_file = flag.modelFile;
|
auto model_file = flag.model_file_;
|
||||||
quant_type_ = flag.quantType;
|
quant_type_ = flag.quant_type_;
|
||||||
// load graph
|
// load graph
|
||||||
tflite_model_ = ReadTfliteModel(model_file);
|
tflite_model_ = ReadTfliteModel(model_file);
|
||||||
if (tflite_model_ == nullptr) {
|
if (tflite_model_ == nullptr) {
|
||||||
|
|
|
@ -34,7 +34,7 @@ class TfliteModelParser : public ModelParser {
|
||||||
|
|
||||||
~TfliteModelParser() override = default;
|
~TfliteModelParser() override = default;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||||
|
|
||||||
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue