forked from mindspore-Ecosystem/mindspore
!18821 [lite]limit fmk when user register ModelParser
Merge pull request !18821 from 徐安越/master_core
This commit is contained in:
commit
6ab3bea212
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
|
||||
|
||||
#include "include/lite_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace converter {
|
||||
/// \brief FmkType defined frameworks which converter tool supports.
|
||||
enum MS_API FmkType : int {
|
||||
FmkType_TF = 0,
|
||||
FmkType_CAFFE = 1,
|
||||
FmkType_ONNX = 2,
|
||||
FmkType_MS = 3,
|
||||
FmkType_TFLITE = 4,
|
||||
};
|
||||
} // namespace converter
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_FRAMEWORK_H_
|
|
@ -16,11 +16,12 @@
|
|||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
#define MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "include/lite_utils.h"
|
||||
#include "include/registry/framework.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType;
|
||||
namespace mindspore::lite {
|
||||
/// \brief ModelParser defined a model parser
|
||||
class MS_API ModelParser;
|
||||
|
@ -47,15 +48,15 @@ class MS_API ModelParserRegistry {
|
|||
/// \param[in] fmk Define identification of a certain framework.
|
||||
///
|
||||
/// \return Pointer of ModelParser.
|
||||
ModelParser *GetModelParser(const std::string &fmk);
|
||||
ModelParser *GetModelParser(const FmkType fmk);
|
||||
|
||||
/// \brief Method to register model parser.
|
||||
///
|
||||
/// \param[in] fmk Define identification of a certain framework.
|
||||
/// \param[in] creator Define function pointer of creating ModelParser.
|
||||
void RegParser(const std::string &fmk, ModelParserCreator creator);
|
||||
int RegParser(const FmkType fmk, ModelParserCreator creator);
|
||||
|
||||
std::unordered_map<std::string, ModelParserCreator> parsers_;
|
||||
std::unordered_map<FmkType, ModelParserCreator> parsers_;
|
||||
};
|
||||
|
||||
/// \brief ModelRegistrar defined registration class of ModelParser.
|
||||
|
@ -65,7 +66,7 @@ class MS_API ModelRegistrar {
|
|||
///
|
||||
/// \param[in] fmk Define identification of a certain framework.
|
||||
/// \param[in] creator Define function pointer of creating ModelParser.
|
||||
ModelRegistrar(const std::string &fmk, ModelParserCreator creator) {
|
||||
ModelRegistrar(const FmkType fmk, ModelParserCreator creator) {
|
||||
ModelParserRegistry::GetInstance()->RegParser(fmk, creator);
|
||||
}
|
||||
|
||||
|
@ -77,7 +78,7 @@ class MS_API ModelRegistrar {
|
|||
///
|
||||
/// \param[in] fmk Define identification of a certain framework.
|
||||
/// \param[in] parserCreator Define function pointer of creating ModelParser.
|
||||
#define REG_MODEL_PARSER(fmk, parserCreator) static ModelRegistrar g_##type##fmk##ModelParserReg(#fmk, parserCreator);
|
||||
#define REG_MODEL_PARSER(fmk, parserCreator) static ModelRegistrar g_##type##fmk##ModelParserReg(fmk, parserCreator);
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_INCLUDE_REGISTRY_MODEL_PARSER_REGISTRY_H
|
||||
|
|
|
@ -20,7 +20,9 @@
|
|||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
using mindspore::lite::converter::Flags;
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
namespace mindspore {
|
||||
class ModelParserRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
|
@ -33,7 +35,9 @@ TEST_F(ModelParserRegistryTest, TestRegistry) {
|
|||
ASSERT_NE(add_parser, nullptr);
|
||||
auto proposal_parser = node_parser_reg->GetNodeParser("proposal");
|
||||
ASSERT_NE(proposal_parser, nullptr);
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
|
||||
REG_MODEL_PARSER(FmkType_CAFFE,
|
||||
TestModelParserCreator); // register test model parser creator, which will overwrite existing.
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
|
||||
ASSERT_NE(model_parser, nullptr);
|
||||
Flags flags;
|
||||
auto func_graph = model_parser->Parse(flags);
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
namespace mindspore {
|
||||
FuncGraphPtr ModelParserTest::Parse(const lite::converter::Flags &flag) {
|
||||
// construct funcgraph
|
||||
|
@ -169,5 +168,4 @@ lite::ModelParser *TestModelParserCreator() {
|
|||
}
|
||||
return model_parser;
|
||||
}
|
||||
REG_MODEL_PARSER(TEST, TestModelParserCreator);
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,6 +40,8 @@ class ModelParserTest : public lite::ModelParser {
|
|||
std::map<std::string, std::vector<std::string>> model_layers_info_;
|
||||
std::vector<std::string> model_structure_;
|
||||
};
|
||||
|
||||
lite::ModelParser *TestModelParserCreator();
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_TEST_UT_TOOLS_CONVERTER_REGISTRY_MODEL_PARSER_TEST_H
|
||||
|
|
|
@ -27,14 +27,18 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "ut/tools/converter/registry/model_parser_test.h"
|
||||
|
||||
using mindspore::lite::ModelRegistrar;
|
||||
using mindspore::lite::converter::Flags;
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
namespace mindspore {
|
||||
class PassRegistryTest : public mindspore::CommonTest {
|
||||
public:
|
||||
PassRegistryTest() = default;
|
||||
void SetUp() override {
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser("TEST");
|
||||
REG_MODEL_PARSER(FmkType_CAFFE, TestModelParserCreator);
|
||||
auto model_parser = lite::ModelParserRegistry::GetInstance()->GetModelParser(FmkType_CAFFE);
|
||||
if (model_parser == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
|||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmkIn);
|
||||
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmk);
|
||||
if (model_parser_ == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "include/registry/framework.h"
|
||||
#include "tools/common/flag_parser.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
@ -31,15 +32,6 @@ using mindspore::schema::QuantType_PostTraining;
|
|||
using mindspore::schema::QuantType_QUANT_NONE;
|
||||
using mindspore::schema::QuantType_WeightQuant;
|
||||
namespace converter {
|
||||
enum FmkType {
|
||||
FmkType_TF = 0,
|
||||
FmkType_CAFFE = 1,
|
||||
FmkType_ONNX = 2,
|
||||
FmkType_MS = 3,
|
||||
FmkType_TFLITE = 4,
|
||||
FmkType_ONNX_LOW_VERSION = 5
|
||||
};
|
||||
|
||||
enum ParallelSplitType { SplitNo = 0, SplitByUserRatio = 1, SplitByUserAttr = 2 };
|
||||
constexpr auto kMaxSplitRatio = 10;
|
||||
constexpr auto kComputeRate = "computeRate";
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_CAFFE;
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
namespace {
|
||||
|
@ -603,5 +604,5 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
|
|||
}
|
||||
return layer.name();
|
||||
}
|
||||
REG_MODEL_PARSER(CAFFE, LiteModelParserCreator<CaffeModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_CAFFE, LiteModelParserCreator<CaffeModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "ops/transpose.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_ONNX;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
@ -248,11 +249,7 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
|
|||
}
|
||||
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
|
||||
onnx_root_graph_ = onnx_model_.graph();
|
||||
if (OnnxNodeParser::opset_version() > 15) {
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||
} else {
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
|
||||
}
|
||||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS OnnxModelParser::ConvertOnnxGraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph,
|
||||
|
@ -1353,6 +1350,6 @@ int OnnxModelParser::Onnx2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graph
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(ONNX, LiteModelParserCreator<OnnxModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_ONNX, LiteModelParserCreator<OnnxModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "tools/converter/parser/parser_utils.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TF;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
|
@ -1239,6 +1240,6 @@ int TFModelParser::TF2AnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(TF, LiteModelParserCreator<TFModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_TF, LiteModelParserCreator<TFModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "tools/converter/parser/tflite/tflite_inputs_adjust.h"
|
||||
#include "tools/converter/parser/parser_utils.h"
|
||||
|
||||
using mindspore::lite::converter::FmkType_TFLITE;
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
constexpr size_t kConvWeightIndex = 2;
|
||||
|
@ -628,5 +629,5 @@ int TfliteModelParser::Tflite2AnfAdjust(const std::set<FuncGraphPtr> &all_func_g
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_MODEL_PARSER(TFLITE, LiteModelParserCreator<TfliteModelParser>)
|
||||
REG_MODEL_PARSER(FmkType_TFLITE, LiteModelParserCreator<TfliteModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -16,7 +16,10 @@
|
|||
|
||||
#include "include/registry/model_parser_registry.h"
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -25,7 +28,7 @@ ModelParserRegistry *ModelParserRegistry::GetInstance() {
|
|||
return &instance;
|
||||
}
|
||||
|
||||
ModelParser *ModelParserRegistry::GetModelParser(const std::string &fmk) {
|
||||
ModelParser *ModelParserRegistry::GetModelParser(const FmkType fmk) {
|
||||
auto it = parsers_.find(fmk);
|
||||
if (it != parsers_.end()) {
|
||||
auto creator = it->second;
|
||||
|
@ -34,9 +37,14 @@ ModelParser *ModelParserRegistry::GetModelParser(const std::string &fmk) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void ModelParserRegistry::RegParser(const std::string &fmk, ModelParserCreator creator) {
|
||||
int ModelParserRegistry::RegParser(const FmkType fmk, ModelParserCreator creator) {
|
||||
if (fmk < converter::FmkType_TF || fmk > converter::FmkType_TFLITE) {
|
||||
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto instance = ModelParserRegistry::GetInstance();
|
||||
instance->parsers_[fmk] = creator;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
|
|
Loading…
Reference in New Issue