forked from mindspore-Ecosystem/mindspore
!18904 [lite]open flags when converting
Merge pull request !18904 from 徐安越/master_core
This commit is contained in:
commit
fe9954ff52
|
@ -323,7 +323,8 @@ elseif(WIN32)
|
|||
else()
|
||||
if(SUPPORT_TRAIN)
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE
|
||||
PATTERN "framework.h" EXCLUDE)
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
|
||||
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
|
||||
|
@ -331,7 +332,7 @@ else()
|
|||
else()
|
||||
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
|
||||
PATTERN "*registry.h" EXCLUDE)
|
||||
PATTERN "*registry.h" EXCLUDE PATTERN "framework.h" EXCLUDE)
|
||||
endif()
|
||||
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
|
|
@ -16,13 +16,27 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include "include/lite_utils.h"
|
||||
#include "include/registry/framework.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore::lite {
|
||||
namespace converter {
|
||||
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
|
||||
struct MS_API ConverterParameters {
|
||||
FmkType fmk_;
|
||||
schema::QuantType quant_type_;
|
||||
std::string model_file_;
|
||||
std::string weight_file_;
|
||||
std::map<std::string, std::string> attrs_;
|
||||
};
|
||||
} // namespace converter
|
||||
|
||||
/// \brief ModelParser defined a model parser
|
||||
class MS_API ModelParser;
|
||||
|
||||
|
@ -56,7 +70,7 @@ class MS_API ModelParserRegistry {
|
|||
/// \param[in] creator Define function pointer of creating ModelParser.
|
||||
int RegParser(const FmkType fmk, ModelParserCreator creator);
|
||||
|
||||
std::unordered_map<FmkType, ModelParserCreator> parsers_;
|
||||
std::map<FmkType, ModelParserCreator> parsers_;
|
||||
};
|
||||
|
||||
/// \brief ModelRegistrar defined registration class of ModelParser.
|
||||
|
|
|
@ -18,10 +18,9 @@
|
|||
#include "common/common_test.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
using mindspore::lite::converter::Flags;
|
||||
using mindspore::lite::converter::ConverterParameters;
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
namespace mindspore {
|
||||
class ModelParserRegistryTest : public mindspore::CommonTest {
|
||||
|
@ -39,8 +38,8 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
|
|||
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
|
||||
ASSERT_NE(model_parser, nullptr);
|
||||
Flags flags;
|
||||
auto func_graph = model_parser->Parse(flags);
|
||||
ConverterParameters converter_parameters;
|
||||
auto func_graph = model_parser->Parse(converter_parameters);
|
||||
ASSERT_NE(func_graph, nullptr);
|
||||
auto node_list = func_graph->GetOrderedCnodes();
|
||||
ASSERT_EQ(node_list.size(), 3);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "include/registry/model_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) {
|
||||
FuncGraphPtr ModelParserTest::Parse(const lite::converter::ConverterParameters &flag) {
|
||||
// construct funcgraph
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
auto ret = InitOriginModelStructure();
|
||||
|
|
|
@ -23,13 +23,12 @@
|
|||
#include "include/registry/model_parser_registry.h"
|
||||
#include "ut/tools/converter/registry/node_parser_test.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ModelParserTest : public lite::ModelParser {
|
||||
public:
|
||||
ModelParserTest() = default;
|
||||
FuncGraphPtr Parse(const lite::converter::Flags &flag) override;
|
||||
FuncGraphPtr Parse(const lite::converter::ConverterParameters &flag) override;
|
||||
|
||||
private:
|
||||
int InitOriginModelStructure();
|
||||
|
|
|
@ -24,13 +24,12 @@
|
|||
#include "ops/fusion/add_fusion.h"
|
||||
#include "ops/addn.h"
|
||||
#include "ops/custom.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
using mindspore::lite::converter::Flags;
|
||||
using mindspore::lite::converter::ConverterParameters;
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
namespace mindspore {
|
||||
class PassRegistryTest : public mindspore::CommonTest {
|
||||
|
@ -42,8 +41,8 @@ class PassRegistryTest : public mindspore::CommonTest {
|
|||
if (model_parser == nullptr) {
|
||||
return;
|
||||
}
|
||||
Flags flags;
|
||||
func_graph_ = model_parser->Parse(flags);
|
||||
ConverterParameters converter_parameters;
|
||||
func_graph_ = model_parser->Parse(converter_parameters);
|
||||
}
|
||||
FuncGraphPtr func_graph_ = nullptr;
|
||||
};
|
||||
|
|
|
@ -31,6 +31,15 @@
|
|||
#include "tools/converter/import/mindspore_importer.h"
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
|
||||
converter_parameters->fmk_ = flag.fmk;
|
||||
converter_parameters->quant_type_ = flag.quantType;
|
||||
converter_parameters->model_file_ = flag.modelFile;
|
||||
converter_parameters->weight_file_ = flag.weightFile;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
||||
FuncGraphPtr func_graph = nullptr;
|
||||
if (flag.fmk == converter::FmkType::FmkType_MS) {
|
||||
|
@ -45,7 +54,9 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
|||
if (model_parser_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
func_graph = model_parser_->Parse(flag);
|
||||
converter::ConverterParameters converter_parameters;
|
||||
InitConverterParameters(flag, &converter_parameters);
|
||||
func_graph = model_parser_->Parse(converter_parameters);
|
||||
}
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
|
||||
|
|
|
@ -22,9 +22,8 @@
|
|||
#include "schema/inner/model_generated.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/quant_param_holder.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
using namespace schema;
|
||||
|
@ -34,7 +33,7 @@ class ModelParser {
|
|||
|
||||
virtual ~ModelParser() = default;
|
||||
|
||||
virtual FuncGraphPtr Parse(const converter::Flags &flag) { return this->res_graph_; }
|
||||
virtual FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
|
||||
|
||||
protected:
|
||||
FuncGraphPtr res_graph_ = nullptr;
|
||||
|
|
|
@ -67,10 +67,10 @@ CaffeModelParser::CaffeModelParser() = default;
|
|||
|
||||
CaffeModelParser::~CaffeModelParser() = default;
|
||||
|
||||
FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
|
||||
auto model_file = flag.modelFile;
|
||||
auto weight_file = flag.weightFile;
|
||||
quant_type_ = flag.quantType;
|
||||
FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file_;
|
||||
auto weight_file = flag.weight_file_;
|
||||
quant_type_ = flag.quant_type_;
|
||||
STATUS status = InitOriginModel(model_file, weight_file);
|
||||
if (status != RET_OK) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
|
|
@ -34,7 +34,7 @@ class CaffeModelParser : public ModelParser {
|
|||
|
||||
~CaffeModelParser() override;
|
||||
|
||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
private:
|
||||
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
|
||||
|
|
|
@ -53,9 +53,9 @@ std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
|||
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
||||
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
||||
|
||||
FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
|
||||
string model_file = flag.modelFile;
|
||||
quant_type_ = flag.quantType;
|
||||
FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
string model_file = flag.model_file_;
|
||||
quant_type_ = flag.quant_type_;
|
||||
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
auto status = InitOriginModel(model_file);
|
||||
|
|
|
@ -42,7 +42,7 @@ class OnnxModelParser : public ModelParser {
|
|||
|
||||
~OnnxModelParser() override = default;
|
||||
|
||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
static int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/protobuf_utils.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
|
@ -478,9 +479,9 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
|
||||
auto modelFile = flag.modelFile;
|
||||
quant_type_ = flag.quantType;
|
||||
FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto modelFile = flag.model_file_;
|
||||
quant_type_ = flag.quant_type_;
|
||||
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
||||
auto status = ValidateFileStr(modelFile, ".pb");
|
||||
if (status != RET_OK) {
|
||||
|
|
|
@ -40,7 +40,7 @@ class TFModelParser : public ModelParser {
|
|||
TFModelParser() = default;
|
||||
~TFModelParser() override = default;
|
||||
|
||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
|
||||
|
|
|
@ -51,9 +51,9 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
|
|||
return tflite::UnPackModel(tflite_model_buf_);
|
||||
}
|
||||
|
||||
FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) {
|
||||
auto model_file = flag.modelFile;
|
||||
quant_type_ = flag.quantType;
|
||||
FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
|
||||
auto model_file = flag.model_file_;
|
||||
quant_type_ = flag.quant_type_;
|
||||
// load graph
|
||||
tflite_model_ = ReadTfliteModel(model_file);
|
||||
if (tflite_model_ == nullptr) {
|
||||
|
|
|
@ -34,7 +34,7 @@ class TfliteModelParser : public ModelParser {
|
|||
|
||||
~TfliteModelParser() override = default;
|
||||
|
||||
FuncGraphPtr Parse(const converter::Flags &flag) override;
|
||||
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
|
||||
|
||||
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);
|
||||
|
||||
|
|
Loading…
Reference in New Issue