!18904 [lite]open flags when converting

Merge pull request !18904 from 徐安越/master_core
This commit is contained in:
i-robot 2021-06-28 02:08:22 +00:00 committed by Gitee
commit fe9954ff52
16 changed files with 60 additions and 37 deletions

View File

@ -323,7 +323,8 @@ elseif(WIN32)
else()
if(SUPPORT_TRAIN)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "*registry.h" EXCLUDE
PATTERN "framework.h" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.so DESTINATION
${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_TRAIN_LIB_NAME}.a DESTINATION
@ -331,7 +332,7 @@ else()
else()
install(DIRECTORY ${TOP_DIR}/mindspore/lite/include/ DESTINATION ${RUNTIME_INC_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "train*" EXCLUDE
PATTERN "*registry.h" EXCLUDE)
PATTERN "*registry.h" EXCLUDE PATTERN "framework.h" EXCLUDE)
endif()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})

View File

@ -16,13 +16,27 @@
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
#include <map>
#include <memory>
#include <unordered_map>
#include <string>
#include "include/lite_utils.h"
#include "include/registry/framework.h"
#include "schema/inner/model_generated.h"
using mindspore::lite::converter::FmkType;
namespace mindspore::lite {
namespace converter {
/// \brief ConverterParameters defined read-only converter parameters used by users in ModelParser.
struct MS_API ConverterParameters {
FmkType fmk_;
schema::QuantType quant_type_;
std::string model_file_;
std::string weight_file_;
std::map<std::string, std::string> attrs_;
};
} // namespace converter
/// \brief ModelParser defined a model parser
class MS_API ModelParser;
@ -56,7 +70,7 @@ class MS_API ModelParserRegistry {
/// \param[in] creator Define function pointer of creating ModelParser.
int RegParser(const FmkType fmk, ModelParserCreator creator);
std::unordered_map<FmkType, ModelParserCreator> parsers_;
std::map<FmkType, ModelParserCreator> parsers_;
};
/// \brief ModelRegistrar defined registration class of ModelParser.

View File

@ -18,10 +18,9 @@
#include "common/common_test.h"
#include "ut/tools/converter/registry/model_parser_test.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"
using mindspore::lite::ModelRegistrar;
using mindspore::lite::converter::Flags;
using mindspore::lite::converter::ConverterParameters;
using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore {
class ModelParserRegistryTest : public mindspore::CommonTest {
@ -39,8 +38,8 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
ASSERT_NE(model_parser, nullptr);
Flags flags;
auto func_graph = model_parser->Parse(flags);
ConverterParameters converter_parameters;
auto func_graph = model_parser->Parse(converter_parameters);
ASSERT_NE(func_graph, nullptr);
auto node_list = func_graph->GetOrderedCnodes();
ASSERT_EQ(node_list.size(), 3);

View File

@ -21,7 +21,7 @@
#include "include/registry/model_parser_registry.h"
namespace mindspore {
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) {
FuncGraphPtr ModelParserTest::Parse(const lite::converter::ConverterParameters &flag) {
// construct funcgraph
res_graph_ = std::make_shared<FuncGraph>();
auto ret = InitOriginModelStructure();

View File

@ -23,13 +23,12 @@
#include "include/registry/model_parser_registry.h"
#include "ut/tools/converter/registry/node_parser_test.h"
#include "tools/converter/model_parser.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
class ModelParserTest : public lite::ModelParser {
public:
ModelParserTest() = default;
FuncGraphPtr Parse(const lite::converter::Flags &flag) override;
FuncGraphPtr Parse(const lite::converter::ConverterParameters &flag) override;
private:
int InitOriginModelStructure();

View File

@ -24,13 +24,12 @@
#include "ops/fusion/add_fusion.h"
#include "ops/addn.h"
#include "ops/custom.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/model_parser.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h"
using mindspore::lite::ModelRegistrar;
using mindspore::lite::converter::Flags;
using mindspore::lite::converter::ConverterParameters;
using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore {
class PassRegistryTest : public mindspore::CommonTest {
@ -42,8 +41,8 @@ class PassRegistryTest : public mindspore::CommonTest {
if (model_parser == nullptr) {
return;
}
Flags flags;
func_graph_ = model_parser->Parse(flags);
ConverterParameters converter_parameters;
func_graph_ = model_parser->Parse(converter_parameters);
}
FuncGraphPtr func_graph_ = nullptr;
};

View File

@ -31,6 +31,15 @@
#include "tools/converter/import/mindspore_importer.h"
namespace mindspore {
namespace lite {
namespace {
void InitConverterParameters(const converter::Flags &flag, converter::ConverterParameters *converter_parameters) {
converter_parameters->fmk_ = flag.fmk;
converter_parameters->quant_type_ = flag.quantType;
converter_parameters->model_file_ = flag.modelFile;
converter_parameters->weight_file_ = flag.weightFile;
}
} // namespace
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
FuncGraphPtr func_graph = nullptr;
if (flag.fmk == converter::FmkType::FmkType_MS) {
@ -45,7 +54,9 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
if (model_parser_ == nullptr) {
return nullptr;
}
func_graph = model_parser_->Parse(flag);
converter::ConverterParameters converter_parameters;
InitConverterParameters(flag, &converter_parameters);
func_graph = model_parser_->Parse(converter_parameters);
}
if (func_graph == nullptr) {
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;

View File

@ -22,9 +22,8 @@
#include "schema/inner/model_generated.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/quant_param_holder.h"
#include "include/registry/model_parser_registry.h"
#include "utils/log_adapter.h"
namespace mindspore::lite {
using namespace schema;
@ -34,7 +33,7 @@ class ModelParser {
virtual ~ModelParser() = default;
virtual FuncGraphPtr Parse(const converter::Flags &flag) { return this->res_graph_; }
virtual FuncGraphPtr Parse(const converter::ConverterParameters &flags) { return this->res_graph_; }
protected:
FuncGraphPtr res_graph_ = nullptr;

View File

@ -67,10 +67,10 @@ CaffeModelParser::CaffeModelParser() = default;
CaffeModelParser::~CaffeModelParser() = default;
FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
auto model_file = flag.modelFile;
auto weight_file = flag.weightFile;
quant_type_ = flag.quantType;
FuncGraphPtr CaffeModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file_;
auto weight_file = flag.weight_file_;
quant_type_ = flag.quant_type_;
STATUS status = InitOriginModel(model_file, weight_file);
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);

View File

@ -34,7 +34,7 @@ class CaffeModelParser : public ModelParser {
~CaffeModelParser() override;
FuncGraphPtr Parse(const converter::Flags &flag) override;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
private:
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);

View File

@ -53,9 +53,9 @@ std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
string model_file = flag.modelFile;
quant_type_ = flag.quantType;
FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) {
string model_file = flag.model_file_;
quant_type_ = flag.quant_type_;
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
res_graph_ = std::make_shared<FuncGraph>();
auto status = InitOriginModel(model_file);

View File

@ -42,7 +42,7 @@ class OnnxModelParser : public ModelParser {
~OnnxModelParser() override = default;
FuncGraphPtr Parse(const converter::Flags &flag) override;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);

View File

@ -22,6 +22,7 @@
#include "src/common/utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/protobuf_utils.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/ops/ops_def.h"
@ -478,9 +479,9 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
return RET_OK;
}
FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
auto modelFile = flag.modelFile;
quant_type_ = flag.quantType;
FuncGraphPtr TFModelParser::Parse(const converter::ConverterParameters &flag) {
auto modelFile = flag.model_file_;
quant_type_ = flag.quant_type_;
NotSupportOp::GetInstance()->set_fmk_type("TF");
auto status = ValidateFileStr(modelFile, ".pb");
if (status != RET_OK) {

View File

@ -40,7 +40,7 @@ class TFModelParser : public ModelParser {
TFModelParser() = default;
~TFModelParser() override = default;
FuncGraphPtr Parse(const converter::Flags &flag) override;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);

View File

@ -51,9 +51,9 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const std::st
return tflite::UnPackModel(tflite_model_buf_);
}
FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) {
auto model_file = flag.modelFile;
quant_type_ = flag.quantType;
FuncGraphPtr TfliteModelParser::Parse(const converter::ConverterParameters &flag) {
auto model_file = flag.model_file_;
quant_type_ = flag.quant_type_;
// load graph
tflite_model_ = ReadTfliteModel(model_file);
if (tflite_model_ == nullptr) {

View File

@ -34,7 +34,7 @@ class TfliteModelParser : public ModelParser {
~TfliteModelParser() override = default;
FuncGraphPtr Parse(const converter::Flags &flag) override;
FuncGraphPtr Parse(const converter::ConverterParameters &flag) override;
static int Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs);