!18821 [lite]limit fmk when user register ModelParser

Merge pull request !18821 from 徐安越/master_core
This commit is contained in:
i-robot 2021-06-25 08:15:51 +00:00 committed by Gitee
commit 6ab3bea212
13 changed files with 76 additions and 31 deletions

View File

@ -0,0 +1,36 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
#define MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
#include "include/lite_utils.h"
namespace mindspore {
namespace lite {
namespace converter {
/// \brief FmkType defined frameworks which converter tool supports.
enum MS_API FmkType : int {
FmkType_TF = 0,
FmkType_CAFFE = 1,
FmkType_ONNX = 2,
FmkType_MS = 3,
FmkType_TFLITE = 4,
};
} // namespace converter
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_

View File

@ -16,11 +16,12 @@
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
#include <string>
#include <memory>
#include <unordered_map>
#include "include/lite_utils.h"
#include "include/registry/framework.h"
using mindspore::lite::converter::FmkType;
namespace mindspore::lite {
/// \brief ModelParser defined a model parser
class MS_API ModelParser;
@ -47,15 +48,15 @@ class MS_API ModelParserRegistry {
/// \param[in] fmk Define identification of a certain framework.
///
/// \return Pointer of ModelParser.
ModelParser *GetModelParser(const std::string &fmk);
ModelParser *GetModelParser(const FmkType fmk);
/// \brief Method to register model parser.
///
/// \param[in] fmk Define identification of a certain framework.
/// \param[in] creator Define function pointer of creating ModelParser.
void RegParser(const std::string &fmk, ModelParserCreator creator);
int RegParser(const FmkType fmk, ModelParserCreator creator);
std::unordered_map<std::string, ModelParserCreator> parsers_;
std::unordered_map<FmkType, ModelParserCreator> parsers_;
};
/// \brief ModelRegistrar defined registration class of ModelParser.
@ -65,7 +66,7 @@ class MS_API ModelRegistrar {
///
/// \param[in] fmk Define identification of a certain framework.
/// \param[in] creator Define function pointer of creating ModelParser.
ModelRegistrar(const std::string &fmk, ModelParserCreator creator) {
ModelRegistrar(const FmkType fmk, ModelParserCreator creator) {
ModelParserRegistry::GetInstance()->RegParser(fmk, creator);
}
@ -77,7 +78,7 @@ class MS_API ModelRegistrar {
///
/// \param[in] fmk Define identification of a certain framework.
/// \param[in] parserCreator Define function pointer of creating ModelParser.
#define REG_MODEL_PARSER(fmk, parserCreator) static ModelRegistrar g_##type##fmk##ModelParserReg(#fmk, parserCreator);
#define REG_MODEL_PARSER(fmk, parserCreator) static ModelRegistrar g_##type##fmk##ModelParserReg(fmk, parserCreator);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H

View File

@ -20,7 +20,9 @@
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"
using mindspore::lite::ModelRegistrar;
using mindspore::lite::converter::Flags;
using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore {
class ModelParserRegistryTest : public mindspore::CommonTest {
public:
@ -33,7 +35,9 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
ASSERT_NE(add_parser, nullptr);
auto proposal_parser = node_parser_reg->GetNodeParser("proposal");
ASSERT_NE(proposal_parser, nullptr);
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
REG_MODEL_PARSER(FmkType_CAFFE,
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
ASSERT_NE(model_parser, nullptr);
Flags flags;
auto func_graph = model_parser->Parse(flags);

View File

@ -20,7 +20,6 @@
#include "include/errorcode.h"
#include "include/registry/model_parser_registry.h"
using mindspore::lite::ModelRegistrar;
namespace mindspore {
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) {
// construct funcgraph
@ -169,5 +168,4 @@ lite::ModelParser *TestModelParserCreator() {
}
return model_parser;
}
REG_MODEL_PARSER(TEST, TestModelParserCreator);
} // namespace mindspore

View File

@ -40,6 +40,8 @@ class ModelParserTest : public lite::ModelParser {
std::map<std::string, std::vector<std::string>> model_layers_info_;
std::vector<std::string> model_structure_;
};
lite::ModelParser *TestModelParserCreator();
} // namespace mindspore
#endif // LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H

View File

@ -27,14 +27,18 @@
#include "tools/converter/converter_flags.h"
#include "tools/converter/model_parser.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ut/tools/converter/registry/model_parser_test.h"
using mindspore::lite::ModelRegistrar;
using mindspore::lite::converter::Flags;
using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore {
class PassRegistryTest : public mindspore::CommonTest {
public:
PassRegistryTest() = default;
void SetUp() override {
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator);
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
if (model_parser == nullptr) {
return;
}

View File

@ -40,7 +40,7 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
return nullptr;
}
} else {
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmkIn);
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmk);
if (model_parser_ == nullptr) {
return nullptr;
}

View File

@ -19,6 +19,7 @@
#include <string>
#include <vector>
#include "include/registry/framework.h"
#include "tools/common/flag_parser.h"
#include "ir/dtype/type_id.h"
#include "schema/inner/model_generated.h"
@ -31,15 +32,6 @@ using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_WeightQuant;
namespace converter {
enum FmkType {
FmkType_TF = 0,
FmkType_CAFFE = 1,
FmkType_ONNX = 2,
FmkType_MS = 3,
FmkType_TFLITE = 4,
FmkType_ONNX_LOW_VERSION = 5
};
enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 };
constexpr auto kMaxSplitRatio = 10;
constexpr auto kComputeRate = "computeRate";

View File

@ -32,6 +32,7 @@
#include "tools/converter/parser/parser_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::lite::converter::FmkType_CAFFE;
namespace mindspore::lite {
namespace {
namespace {
@ -603,5 +604,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
}
return layer.name();
}
REG_MODEL_PARSER(CAFFE, LiteModelParserCreator<CaffeModelParser>)
REG_MODEL_PARSER(FmkType_CAFFE, LiteModelParserCreator<CaffeModelParser>)
} // namespace mindspore::lite

View File

@ -36,6 +36,7 @@
#include "tools/converter/parser/parser_utils.h"
#include "ops/transpose.h"
using mindspore::lite::converter::FmkType_ONNX;
namespace mindspore {
namespace lite {
namespace {
@ -248,11 +249,7 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
}
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
onnx_root_graph_ = onnx_model_.graph();
if (OnnxNodeParser::opset_version() > 15) {
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
} else {
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
}
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
return RET_OK;
}
STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
@ -1353,6 +1350,6 @@ int OnnxModelParser::Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graph
return RET_OK;
}
REG_MODEL_PARSER(ONNX, LiteModelParserCreator<OnnxModelParser>)
REG_MODEL_PARSER(FmkType_ONNX, LiteModelParserCreator<OnnxModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -33,6 +33,7 @@
#include "tools/converter/parser/parser_utils.h"
#include "tools/common/tensor_util.h"
using mindspore::lite::converter::FmkType_TF;
namespace mindspore {
namespace lite {
namespace {
@ -1239,6 +1240,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
return RET_OK;
}
REG_MODEL_PARSER(TF, LiteModelParserCreator<TFModelParser>)
REG_MODEL_PARSER(FmkType_TF, LiteModelParserCreator<TFModelParser>)
} // namespace lite
} // namespace mindspore

View File

@ -31,6 +31,7 @@
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
#include "tools/converter/parser/parser_utils.h"
using mindspore::lite::converter::FmkType_TFLITE;
namespace mindspore::lite {
namespace {
constexpr size_t kConvWeightIndex = 2;
@ -628,5 +629,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
return RET_OK;
}
REG_MODEL_PARSER(TFLITE, LiteModelParserCreator<TfliteModelParser>)
REG_MODEL_PARSER(FmkType_TFLITE, LiteModelParserCreator<TfliteModelParser>)
} // namespace mindspore::lite

View File

@ -16,7 +16,10 @@
#include "include/registry/model_parser_registry.h"
#include <string>
#include <set>
#include <unordered_map>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
namespace mindspore {
namespace lite {
@ -25,7 +28,7 @@ ModelParserRegistry *ModelParserRegistry::GetInstance() {
return &instance;
}
ModelParser *ModelParserRegistry::GetModelParser(const std::string &fmk) {
ModelParser *ModelParserRegistry::GetModelParser(const FmkType fmk) {
auto it = parsers_.find(fmk);
if (it != parsers_.end()) {
auto creator = it->second;
@ -34,9 +37,14 @@ ModelParser *ModelParserRegistry::GetModelParser(const std::string &fmk) {
return nullptr;
}
void ModelParserRegistry::RegParser(const std::string &fmk, ModelParserCreator creator) {
int ModelParserRegistry::RegParser(const FmkType fmk, ModelParserCreator creator) {
if (fmk < converter::FmkType_TF || fmk > converter::FmkType_TFLITE) {
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
return RET_ERROR;
}
auto instance = ModelParserRegistry::GetInstance();
instance->parsers_[fmk] = creator;
return RET_OK;
}
} // namespace lite