!18904 [lite]open flags when converting

Merge pull request !18904 from 徐安越/master_core
This commit is contained in:
i-robot 2021-06-28 02:08:22 +00:00 committed by Gitee
commit fe9954ff52
16 changed files with 60 additions and 37 deletions

View File

@ -323,7 +323,8 @@ elseif(WIN32)
else() else()
if(SUPPORT_TRAIN) if(SUPPORT_TRAIN)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR} install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE) COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE
PATTERN "framework.h" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME}) ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
@ -331,7 +332,7 @@ else()
else() else()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR} install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
PATTERN "*registry.h" EXCLUDE) PATTERN "*registry.h" EXCLUDE PATTERN "framework.h" EXCLUDE)
endif() endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME}) COMPONENT ${RUNTIME_COMPONENT_NAME})

View File

@ -16,13 +16,27 @@
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H #ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H #define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
#include <map>
#include <memory> #include <memory>
#include <unordered_map> #include <string>
#include "include/lite_utils.h" #include "include/lite_utils.h"
#include "include/registry/framework.h" #include "include/registry/framework.h"
#include "schema/inner/model_generated.h"
using mindspore::lite::converter::FmkType; using mindspore::lite::converter::FmkType;
namespace mindspore::lite { namespace mindspore::lite {
namespace converter {
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
struct MS_API ConverterParameters {
FmkType fmk_;
schema::QuantType quant_type_;
std::string model_file_;
std::string weight_file_;
std::map<std::string, std::string> attrs_;
};
} // namespace converter
/// \brief ModelParser defined a model parser /// \brief ModelParser defined a model parser
class MS_API ModelParser; class MS_API ModelParser;
@ -56,7 +70,7 @@ class MS_API ModelParserRegistry {
/// \param[in] creator Define function pointer of creating ModelParser. /// \param[in] creator Define function pointer of creating ModelParser.
int RegParser(const FmkType fmk, ModelParserCreator creator); int RegParser(const FmkType fmk, ModelParserCreator creator);
std::unordered_map<FmkType, ModelParserCreator> parsers_; std::map<FmkType, ModelParserCreator> parsers_;
}; };
/// \brief ModelRegistrar defined registration class of ModelParser. /// \brief ModelRegistrar defined registration class of ModelParser.

View File

@ -18,10 +18,9 @@
#include "common/common_test.h" #include "common/common_test.h"
#include "ut/tools/converter/registry/model_parser_test.h" #include "ut/tools/converter/registry/model_parser_test.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"
using mindspore::lite::ModelRegistrar; using mindspore::lite::ModelRegistrar;
using mindspore::lite::converter::Flags; using mindspore::lite::converter::ConverterParameters;
using mindspore::lite::converter::FmkType_CAFFE; using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore { namespace mindspore {
class ModelParserRegistryTest : public mindspore::CommonTest { class ModelParserRegistryTest : public mindspore::CommonTest {
@ -39,8 +38,8 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
TestModelParserCreator); // register test model parser creator, which will overwrite existing. TestModelParserCreator); // register test model parser creator, which will overwrite existing.
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE); auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
ASSERT_NE(model_parser, nullptr); ASSERT_NE(model_parser, nullptr);
Flags flags; ConverterParameters converter_parameters;
auto func_graph = model_parser->Parse(flags); auto func_graph = model_parser->Parse(converter_parameters);
ASSERT_NE(func_graph, nullptr); ASSERT_NE(func_graph, nullptr);
auto node_list = func_graph->GetOrderedCnodes(); auto node_list = func_graph->GetOrderedCnodes();
ASSERT_EQ(node_list.size(), 3); ASSERT_EQ(node_list.size(), 3);

View File

@ -21,7 +21,7 @@
#include "include/registry/model_parser_registry.h" #include "include/registry/model_parser_registry.h"
namespace mindspore { namespace mindspore {
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) { FuncGraphPtr ModelParserTest::Parse(const lite::converter::ConverterParameters &flag) {
// construct funcgraph // construct funcgraph
res_graph_ = std::make_shared<FuncGraph>(); res_graph_ = std::make_shared<FuncGraph>();
auto ret = InitOriginModelStructure(); auto ret = InitOriginModelStructure();

View File

@ -23,13 +23,12 @@
#include "include/registry/model_parser_registry.h" #include "include/registry/model_parser_registry.h"
#include "ut/tools/converter/registry/node_parser_test.h" #include "ut/tools/converter/registry/node_parser_test.h"
#include "tools/converter/model_parser.h" #include "tools/converter/model_parser.h"
#include "tools/converter/converter_flags.h"
namespace mindspore { namespace mindspore {
class ModelParserTest : public lite::ModelParser { class ModelParserTest : public lite::ModelParser {
public: public:
ModelParserTest() = default; ModelParserTest() = default;
FuncGraphPtr Parse(const lite::converter::Flags &flag) override; FuncGraphPtr Parse(const lite::converter::ConverterParameters &flag) override;
private: private:
int InitOriginModelStructure(); int InitOriginModelStructure();

View File

@ -24,13 +24,12 @@
#include "ops/fusion/add_fusion.h" #include "ops/fusion/add_fusion.h"
#include "ops/addn.h" #include "ops/addn.h"
#include "ops/custom.h" #include "ops/custom.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/model_parser.h" #include "tools/converter/model_parser.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h" #include "ut/tools/converter/registry/model_parser_test.h"
using mindspore::lite::ModelRegistrar; using mindspore::lite::ModelRegistrar;
using mindspore::lite::converter::Flags; using mindspore::lite::converter::ConverterParameters;
using mindspore::lite::converter::FmkType_CAFFE; using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore { namespace mindspore {
class PassRegistryTest : public mindspore::CommonTest { class PassRegistryTest : public mindspore::CommonTest {
@ -42,8 +41,8 @@ class PassRegistryTest : public mindspore::CommonTest {
if (model_parser == nullptr) { if (model_parser == nullptr) {
return; return;
} }
Flags flags; ConverterParameters converter_parameters;
func_graph_ = model_parser->Parse(flags); func_graph_ = model_parser->Parse(converter_parameters);
} }
FuncGraphPtr func_graph_ = nullptr; FuncGraphPtr func_graph_ = nullptr;
}; };

View File

@ -31,6 +31,15 @@
#include "tools/converter/import/mindspore_importer.h" #include "tools/converter/import/mindspore_importer.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
namespace {
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
converter_parameters->fmk_ = flag.fmk;
converter_parameters->quant_type_ = flag.quantType;
converter_parameters->model_file_ = flag.modelFile;
converter_parameters->weight_file_ = flag.weightFile;
}
} // namespace
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) { FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
FuncGraphPtr func_graph = nullptr; FuncGraphPtr func_graph = nullptr;
if (flag.fmk == converter::FmkType::FmkType_MS) { if (flag.fmk == converter::FmkType::FmkType_MS) {
@ -45,7 +54,9 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
if (model_parser_ == nullptr) { if (model_parser_ == nullptr) {
return nullptr; return nullptr;
} }
func_graph = model_parser_->Parse(flag); converter::ConverterParameters converter_parameters;
InitConverterParameters(flag, &converter_parameters);
func_graph = model_parser_->Parse(converter_parameters);
} }
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn; MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;

View File

@ -22,9 +22,8 @@
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "tools/converter/converter_context.h" #include "include/registry/model_parser_registry.h"
#include "tools/converter/converter_flags.h" #include "utils/log_adapter.h"
#include "tools/converter/quant_param_holder.h"
namespace mindspore::lite { namespace mindspore::lite {
using namespace schema; using namespace schema;
@ -34,7 +33,7 @@ class ModelParser {
virtual ~ModelParser() = default; virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const converter::Flags &flag) { return this->res_graph_; } virtual FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
protected: protected:
FuncGraphPtr res_graph_ = nullptr; FuncGraphPtr res_graph_ = nullptr;

View File

@ -67,10 +67,10 @@ CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default; CaffeModelParser::~CaffeModelParser() = default;
FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) { FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.modelFile; auto model_file = flag.model_file_;
auto weight_file = flag.weightFile; auto weight_file = flag.weight_file_;
quant_type_ = flag.quantType; quant_type_ = flag.quant_type_;
STATUS status = InitOriginModel(model_file, weight_file); STATUS status = InitOriginModel(model_file, weight_file);
if (status != RET_OK) { if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@ -34,7 +34,7 @@ class CaffeModelParser : public ModelParser {
~CaffeModelParser() override; ~CaffeModelParser() override;
FuncGraphPtr Parse(const converter::Flags &flag) override; FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
private: private:
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);

View File

@ -53,9 +53,9 @@ std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) { FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
string model_file = flag.modelFile; string model_file = flag.model_file_;
quant_type_ = flag.quantType; quant_type_ = flag.quant_type_;
NotSupportOp::GetInstance()->set_fmk_type("ONNX"); NotSupportOp::GetInstance()->set_fmk_type("ONNX");
res_graph_ = std::make_shared<FuncGraph>(); res_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file); auto status = InitOriginModel(model_file);

View File

@ -42,7 +42,7 @@ class OnnxModelParser : public ModelParser {
~OnnxModelParser() override = default; ~OnnxModelParser() override = default;
FuncGraphPtr Parse(const converter::Flags &flag) override; FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs); static int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);

View File

@ -22,6 +22,7 @@
#include "src/common/utils.h" #include "src/common/utils.h"
#include "tools/common/graph_util.h" #include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h" #include "tools/common/protobuf_utils.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h" #include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/optimizer/common/gllo_utils.h" #include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/ops/ops_def.h" #include "tools/converter/ops/ops_def.h"
@ -478,9 +479,9 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
return RET_OK; return RET_OK;
} }
FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) { FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
auto modelFile = flag.modelFile; auto modelFile = flag.model_file_;
quant_type_ = flag.quantType; quant_type_ = flag.quant_type_;
NotSupportOp::GetInstance()->set_fmk_type("TF"); NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb"); auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) { if (status != RET_OK) {

View File

@ -40,7 +40,7 @@ class TFModelParser : public ModelParser {
TFModelParser() = default; TFModelParser() = default;
~TFModelParser() override = default; ~TFModelParser() override = default;
FuncGraphPtr Parse(const converter::Flags &flag) override; FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs); static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);

View File

@ -51,9 +51,9 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
return tflite::UnPackModel(tflite_model_buf_); return tflite::UnPackModel(tflite_model_buf_);
} }
FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) { FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.modelFile; auto model_file = flag.model_file_;
quant_type_ = flag.quantType; quant_type_ = flag.quant_type_;
// load graph // load graph
tflite_model_ = ReadTfliteModel(model_file); tflite_model_ = ReadTfliteModel(model_file);
if (tflite_model_ == nullptr) { if (tflite_model_ == nullptr) {

View File

@ -34,7 +34,7 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override = default; ~TfliteModelParser() override = default;
FuncGraphPtr Parse(const converter::Flags &flag) override; FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs); static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);