diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 285ef02330a..7f8ab779d1c 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -579,23 +579,25 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) { } auto fmk = config->fmk; auto is_train = config->trainModel; - std::unordered_map passes = { - {"DumpGraph", std::make_shared(config)}, - {"RemoveRedundantOpPass", std::make_shared(config->trainModel)}, - {"ToNCHWFormat", std::make_shared(fmk, is_train)}, - {"ToNHWCFormat", std::make_shared(fmk, is_train)}, - {"ConstFoldPass", std::make_shared(fmk, is_train)}, - {"InferShapePass", std::make_shared(fmk, is_train)}, - {"DeleteRedundantTranspose", std::make_shared()}, - {"SpecialNodePostProcess", std::make_shared()}, - {"DecreaseTransposeAlgo", std::make_shared(fmk, is_train)}, - {"SpecifyGraphInputFormat", std::make_shared(config->graphInputFormat)}}; + // pass_name, pass and boolean value to indicate whether can be called by external extension, + std::vector> pass_infos = { + {"DumpGraph", std::make_shared(config), true}, + {"RemoveRedundantOpPass", std::make_shared(config->trainModel), false}, + {"ToNCHWFormat", std::make_shared(fmk, is_train), true}, + {"ToNHWCFormat", std::make_shared(fmk, is_train), true}, + {"ConstFoldPass", std::make_shared(fmk, is_train), true}, + {"InferShapePass", std::make_shared(fmk, is_train), false}, + {"DeleteRedundantTranspose", std::make_shared(), false}, + {"SpecialNodePostProcess", std::make_shared(), false}, + {"DecreaseTransposeAlgo", std::make_shared(fmk, is_train), true}, + {"SpecifyGraphInputFormat", std::make_shared(config->graphInputFormat), false}}; bool succeed_store = true; - for (auto iter = passes.begin(); iter != passes.end(); ++iter) { - MS_CHECK_TRUE_RET(iter->second != nullptr, false); - if (PassStorage::StorePass(iter->first, iter->second) != RET_OK) { - MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " << iter->first - << ", please edit external pass name."; + for (const auto &pass_info : pass_infos) { + MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false); + if (PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info), + std::get(pass_info)) != RET_OK) { + MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " + << std::get<0>(pass_info) << ", please edit external pass name."; succeed_store = false; } } diff --git a/mindspore/lite/tools/converter/optimizer_manager.cc b/mindspore/lite/tools/converter/optimizer_manager.cc index 4812e5027cc..137d0a74924 100644 --- a/mindspore/lite/tools/converter/optimizer_manager.cc +++ b/mindspore/lite/tools/converter/optimizer_manager.cc @@ -16,6 +16,7 @@ #include "tools/converter/optimizer_manager.h" #include +#include #include #include #include "backend/optimizer/common/pass.h" @@ -25,6 +26,7 @@ namespace mindspore { namespace lite { std::map PassStorage::pass_storage_; +std::set PassStorage::inaccessible_for_outer_; bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector &pass_names) { if (func_graph == nullptr) { MS_LOG(ERROR) << "func graph is nullptr."; @@ -58,6 +60,12 @@ bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition posi return false; } auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(position); + for (const auto &pass_name : schedule_task) { + if (!PassStorage::IsAccessibleForOuter(pass_name)) { + MS_LOG(ERROR) << pass_name << " is an inaccessible pass for outer calling."; + return false; + } + } if (!RunOptimizerPass(func_graph, schedule_task)) { MS_LOG(WARNING) << "run external scheduled task failed."; return false; diff --git a/mindspore/lite/tools/converter/optimizer_manager.h b/mindspore/lite/tools/converter/optimizer_manager.h index 0c5e60199ef..7e6f589c95a 100644 --- a/mindspore/lite/tools/converter/optimizer_manager.h +++ b/mindspore/lite/tools/converter/optimizer_manager.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H #include +#include #include #include #include "backend/optimizer/common/pass.h" @@ -29,17 +30,24 @@ namespace mindspore { namespace lite { class PassStorage { public: - static int StorePass(const std::string &pass_name, const opt::PassPtr &pass) { + static int StorePass(const std::string &pass_name, const opt::PassPtr &pass, bool access_for_outer) { if (registry::PassRegistry::GetPassFromStoreRoom(pass_name) != nullptr) { return RET_ERROR; } pass_storage_[pass_name] = pass; + if (!access_for_outer) { + inaccessible_for_outer_.insert(pass_name); + } return RET_OK; } static opt::PassPtr GetPassFromStorage(const std::string &pass_name) { return pass_storage_[pass_name]; } + static bool IsAccessibleForOuter(const std::string &pass_name) { + return inaccessible_for_outer_.find(pass_name) == inaccessible_for_outer_.end(); + } private: static std::map pass_storage_; + static std::set inaccessible_for_outer_; }; bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector &pass_names); diff --git a/mindspore/lite/tools/converter/registry/model_parser_registry.cc b/mindspore/lite/tools/converter/registry/model_parser_registry.cc index 1a79d58897a..cff1580bf51 100644 --- a/mindspore/lite/tools/converter/registry/model_parser_registry.cc +++ b/mindspore/lite/tools/converter/registry/model_parser_registry.cc @@ -25,7 +25,13 @@ namespace { std::map model_parser_room; } // namespace -ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { model_parser_room[fmk] = creator; } +ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { + if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) { + MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType."; + return; + } + model_parser_room[fmk] = creator; +} converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) { if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) { diff --git a/mindspore/lite/tools/converter/registry/node_parser_registry.cc b/mindspore/lite/tools/converter/registry/node_parser_registry.cc index fb7b1077866..4910f254262 100644 --- a/mindspore/lite/tools/converter/registry/node_parser_registry.cc +++ b/mindspore/lite/tools/converter/registry/node_parser_registry.cc @@ -18,10 +18,12 @@ #include #include #include +#include "src/common/log_adapter.h" namespace mindspore { namespace registry { namespace { +constexpr size_t kOpNumLimit = 10000; std::map> node_parser_room; std::mutex node_mutex; } // namespace @@ -29,6 +31,12 @@ NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::v const converter::NodeParserPtr &node_parser) { std::unique_lock lock(node_mutex); std::string node_type_str = CharToString(node_type); + if (node_parser_room.find(fmk_type) != node_parser_room.end()) { + if (node_parser_room[fmk_type].size() == kOpNumLimit) { + MS_LOG(WARNING) << "Op's number is up to the limitation, The parser will not be registered."; + return; + } + } node_parser_room[fmk_type][node_type_str] = node_parser; } diff --git a/mindspore/lite/tools/converter/registry/pass_registry.cc b/mindspore/lite/tools/converter/registry/pass_registry.cc index 5d365ff3818..5a508a3f54c 100644 --- a/mindspore/lite/tools/converter/registry/pass_registry.cc +++ b/mindspore/lite/tools/converter/registry/pass_registry.cc @@ -25,6 +25,7 @@ namespace mindspore { namespace registry { namespace { +constexpr size_t kPassNumLimit = 10000; std::map outer_pass_storage; std::map> external_assigned_passes; std::mutex pass_mutex; @@ -34,6 +35,10 @@ void RegPass(const std::string &pass_name, const PassBasePtr &pass) { return; } std::unique_lock lock(pass_mutex); + if (outer_pass_storage.size() == kPassNumLimit) { + MS_LOG(WARNING) << "ass's number is up to the limitation. The pass will not be registered."; + return; + } outer_pass_storage[pass_name] = pass; } } // namespace @@ -43,6 +48,10 @@ PassRegistry::PassRegistry(const std::vector &pass_name, const PassBasePtr } PassRegistry::PassRegistry(PassPosition position, const std::vector> &names) { + if (position < POSITION_BEGIN || position > POSITION_END) { + MS_LOG(ERROR) << "ILLEGAL position: position must be POSITION_BEGIN or POSITION_END."; + return; + } std::unique_lock lock(pass_mutex); external_assigned_passes[position] = VectorCharToString(names); }