!27872 [lite]strengthen external interface

Merge pull request !27872 from 徐安越/master4
This commit is contained in:
i-robot 2021-12-18 13:19:19 +00:00 committed by Gitee
commit cf04d2eb66
6 changed files with 59 additions and 18 deletions

View File

@ -579,23 +579,25 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
}
auto fmk = config->fmk;
auto is_train = config->trainModel;
std::unordered_map<std::string, opt::PassPtr> passes = {
{"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
{"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel)},
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
{"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train)},
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
{"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()},
{"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>()},
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)},
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat)}};
// pass_name, pass and boolean value to indicate whether can be called by external extension,
std::vector<std::tuple<std::string, opt::PassPtr, bool>> pass_infos = {
{"DumpGraph", std::make_shared<opt::DumpGraph>(config), true},
{"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel), false},
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train), true},
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train), true},
{"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train), true},
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train), false},
{"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>(), false},
{"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>(), false},
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train), true},
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat), false}};
bool succeed_store = true;
for (auto iter = passes.begin(); iter != passes.end(); ++iter) {
MS_CHECK_TRUE_RET(iter->second != nullptr, false);
if (PassStorage::StorePass(iter->first, iter->second) != RET_OK) {
MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " << iter->first
<< ", please edit external pass name.";
for (const auto &pass_info : pass_infos) {
MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false);
if (PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info),
std::get<opt::kInputIndexTwo>(pass_info)) != RET_OK) {
MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is "
<< std::get<0>(pass_info) << ", please edit external pass name.";
succeed_store = false;
}
}

View File

@ -16,6 +16,7 @@
#include "tools/converter/optimizer_manager.h"
#include <map>
#include <set>
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
@ -25,6 +26,7 @@
namespace mindspore {
namespace lite {
std::map<std::string, opt::PassPtr> PassStorage::pass_storage_;
std::set<std::string> PassStorage::inaccessible_for_outer_;
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
if (func_graph == nullptr) {
MS_LOG(ERROR) << "func graph is nullptr.";
@ -58,6 +60,12 @@ bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition posi
return false;
}
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(position);
for (const auto &pass_name : schedule_task) {
if (!PassStorage::IsAccessibleForOuter(pass_name)) {
MS_LOG(ERROR) << pass_name << " is an inaccessible pass for outer calling.";
return false;
}
}
if (!RunOptimizerPass(func_graph, schedule_task)) {
MS_LOG(WARNING) << "run external scheduled task failed.";
return false;

View File

@ -18,6 +18,7 @@
#define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
#include <map>
#include <set>
#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
@ -29,17 +30,24 @@ namespace mindspore {
namespace lite {
class PassStorage {
public:
static int StorePass(const std::string &pass_name, const opt::PassPtr &pass) {
static int StorePass(const std::string &pass_name, const opt::PassPtr &pass, bool access_for_outer) {
if (registry::PassRegistry::GetPassFromStoreRoom(pass_name) != nullptr) {
return RET_ERROR;
}
pass_storage_[pass_name] = pass;
if (!access_for_outer) {
inaccessible_for_outer_.insert(pass_name);
}
return RET_OK;
}
static opt::PassPtr GetPassFromStorage(const std::string &pass_name) { return pass_storage_[pass_name]; }
static bool IsAccessibleForOuter(const std::string &pass_name) {
return inaccessible_for_outer_.find(pass_name) == inaccessible_for_outer_.end();
}
private:
static std::map<std::string, opt::PassPtr> pass_storage_;
static std::set<std::string> inaccessible_for_outer_;
};
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names);

View File

@ -25,7 +25,13 @@ namespace {
std::map<FmkType, ModelParserCreator> model_parser_room;
} // namespace
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { model_parser_room[fmk] = creator; }
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
return;
}
model_parser_room[fmk] = creator;
}
converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {

View File

@ -18,10 +18,12 @@
#include <map>
#include <mutex>
#include <string>
#include "src/common/log_adapter.h"
namespace mindspore {
namespace registry {
namespace {
constexpr size_t kOpNumLimit = 10000;
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
std::mutex node_mutex;
} // namespace
@ -29,6 +31,12 @@ NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::v
const converter::NodeParserPtr &node_parser) {
std::unique_lock<std::mutex> lock(node_mutex);
std::string node_type_str = CharToString(node_type);
if (node_parser_room.find(fmk_type) != node_parser_room.end()) {
if (node_parser_room[fmk_type].size() == kOpNumLimit) {
MS_LOG(WARNING) << "Op's number is up to the limitation, The parser will not be registered.";
return;
}
}
node_parser_room[fmk_type][node_type_str] = node_parser;
}

View File

@ -25,6 +25,7 @@
namespace mindspore {
namespace registry {
namespace {
constexpr size_t kPassNumLimit = 10000;
std::map<std::string, PassBasePtr> outer_pass_storage;
std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
std::mutex pass_mutex;
@ -34,6 +35,10 @@ void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
return;
}
std::unique_lock<std::mutex> lock(pass_mutex);
if (outer_pass_storage.size() == kPassNumLimit) {
MS_LOG(WARNING) << "ass's number is up to the limitation. The pass will not be registered.";
return;
}
outer_pass_storage[pass_name] = pass;
}
} // namespace
@ -43,6 +48,10 @@ PassRegistry::PassRegistry(const std::vector<char> &pass_name, const PassBasePtr
}
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names) {
if (position < POSITION_BEGIN || position > POSITION_END) {
MS_LOG(ERROR) << "ILLEGAL position: position must be POSITION_BEGIN or POSITION_END.";
return;
}
std::unique_lock<std::mutex> lock(pass_mutex);
external_assigned_passes[position] = VectorCharToString(names);
}