forked from mindspore-Ecosystem/mindspore
!27872 [lite]strengthen external interface
Merge pull request !27872 from 徐安越/master4
This commit is contained in:
commit
cf04d2eb66
|
@ -579,23 +579,25 @@ bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
|
|||
}
|
||||
auto fmk = config->fmk;
|
||||
auto is_train = config->trainModel;
|
||||
std::unordered_map<std::string, opt::PassPtr> passes = {
|
||||
{"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
|
||||
{"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel)},
|
||||
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
|
||||
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
|
||||
{"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train)},
|
||||
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
|
||||
{"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>()},
|
||||
{"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>()},
|
||||
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)},
|
||||
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat)}};
|
||||
// pass_name, pass and boolean value to indicate whether can be called by external extension,
|
||||
std::vector<std::tuple<std::string, opt::PassPtr, bool>> pass_infos = {
|
||||
{"DumpGraph", std::make_shared<opt::DumpGraph>(config), true},
|
||||
{"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel), false},
|
||||
{"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train), true},
|
||||
{"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train), true},
|
||||
{"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train), true},
|
||||
{"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train), false},
|
||||
{"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>(), false},
|
||||
{"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>(), false},
|
||||
{"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train), true},
|
||||
{"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat), false}};
|
||||
bool succeed_store = true;
|
||||
for (auto iter = passes.begin(); iter != passes.end(); ++iter) {
|
||||
MS_CHECK_TRUE_RET(iter->second != nullptr, false);
|
||||
if (PassStorage::StorePass(iter->first, iter->second) != RET_OK) {
|
||||
MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " << iter->first
|
||||
<< ", please edit external pass name.";
|
||||
for (const auto &pass_info : pass_infos) {
|
||||
MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false);
|
||||
if (PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info),
|
||||
std::get<opt::kInputIndexTwo>(pass_info)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is "
|
||||
<< std::get<0>(pass_info) << ", please edit external pass name.";
|
||||
succeed_store = false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "tools/converter/optimizer_manager.h"
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
@ -25,6 +26,7 @@
|
|||
namespace mindspore {
|
||||
namespace lite {
|
||||
std::map<std::string, opt::PassPtr> PassStorage::pass_storage_;
|
||||
std::set<std::string> PassStorage::inaccessible_for_outer_;
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names) {
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "func graph is nullptr.";
|
||||
|
@ -58,6 +60,12 @@ bool RunExternalPass(const FuncGraphPtr &func_graph, registry::PassPosition posi
|
|||
return false;
|
||||
}
|
||||
auto schedule_task = registry::PassRegistry::GetOuterScheduleTask(position);
|
||||
for (const auto &pass_name : schedule_task) {
|
||||
if (!PassStorage::IsAccessibleForOuter(pass_name)) {
|
||||
MS_LOG(ERROR) << pass_name << " is an inaccessible pass for outer calling.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!RunOptimizerPass(func_graph, schedule_task)) {
|
||||
MS_LOG(WARNING) << "run external scheduled task failed.";
|
||||
return false;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_TOOLS_CONVERTER_OPTIMIZER_MANAGER_H
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
|
@ -29,17 +30,24 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class PassStorage {
|
||||
public:
|
||||
static int StorePass(const std::string &pass_name, const opt::PassPtr &pass) {
|
||||
static int StorePass(const std::string &pass_name, const opt::PassPtr &pass, bool access_for_outer) {
|
||||
if (registry::PassRegistry::GetPassFromStoreRoom(pass_name) != nullptr) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
pass_storage_[pass_name] = pass;
|
||||
if (!access_for_outer) {
|
||||
inaccessible_for_outer_.insert(pass_name);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
static opt::PassPtr GetPassFromStorage(const std::string &pass_name) { return pass_storage_[pass_name]; }
|
||||
static bool IsAccessibleForOuter(const std::string &pass_name) {
|
||||
return inaccessible_for_outer_.find(pass_name) == inaccessible_for_outer_.end();
|
||||
}
|
||||
|
||||
private:
|
||||
static std::map<std::string, opt::PassPtr> pass_storage_;
|
||||
static std::set<std::string> inaccessible_for_outer_;
|
||||
};
|
||||
|
||||
bool RunOptimizerPass(const FuncGraphPtr &func_graph, const std::vector<std::string> &pass_names);
|
||||
|
|
|
@ -25,7 +25,13 @@ namespace {
|
|||
std::map<FmkType, ModelParserCreator> model_parser_room;
|
||||
} // namespace
|
||||
|
||||
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) { model_parser_room[fmk] = creator; }
|
||||
ModelParserRegistry::ModelParserRegistry(FmkType fmk, ModelParserCreator creator) {
|
||||
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
|
||||
MS_LOG(ERROR) << "ILLEGAL FMK: fmk must be in FmkType.";
|
||||
return;
|
||||
}
|
||||
model_parser_room[fmk] = creator;
|
||||
}
|
||||
|
||||
converter::ModelParser *ModelParserRegistry::GetModelParser(FmkType fmk) {
|
||||
if (fmk < converter::kFmkTypeTf || fmk > converter::kFmkTypeTflite) {
|
||||
|
|
|
@ -18,10 +18,12 @@
|
|||
#include <map>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include "src/common/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace registry {
|
||||
namespace {
|
||||
constexpr size_t kOpNumLimit = 10000;
|
||||
std::map<converter::FmkType, std::map<std::string, converter::NodeParserPtr>> node_parser_room;
|
||||
std::mutex node_mutex;
|
||||
} // namespace
|
||||
|
@ -29,6 +31,12 @@ NodeParserRegistry::NodeParserRegistry(converter::FmkType fmk_type, const std::v
|
|||
const converter::NodeParserPtr &node_parser) {
|
||||
std::unique_lock<std::mutex> lock(node_mutex);
|
||||
std::string node_type_str = CharToString(node_type);
|
||||
if (node_parser_room.find(fmk_type) != node_parser_room.end()) {
|
||||
if (node_parser_room[fmk_type].size() == kOpNumLimit) {
|
||||
MS_LOG(WARNING) << "Op's number is up to the limitation, The parser will not be registered.";
|
||||
return;
|
||||
}
|
||||
}
|
||||
node_parser_room[fmk_type][node_type_str] = node_parser;
|
||||
}
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
namespace mindspore {
|
||||
namespace registry {
|
||||
namespace {
|
||||
constexpr size_t kPassNumLimit = 10000;
|
||||
std::map<std::string, PassBasePtr> outer_pass_storage;
|
||||
std::map<registry::PassPosition, std::vector<std::string>> external_assigned_passes;
|
||||
std::mutex pass_mutex;
|
||||
|
@ -34,6 +35,10 @@ void RegPass(const std::string &pass_name, const PassBasePtr &pass) {
|
|||
return;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
if (outer_pass_storage.size() == kPassNumLimit) {
|
||||
MS_LOG(WARNING) << "ass's number is up to the limitation. The pass will not be registered.";
|
||||
return;
|
||||
}
|
||||
outer_pass_storage[pass_name] = pass;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -43,6 +48,10 @@ PassRegistry::PassRegistry(const std::vector<char> &pass_name, const PassBasePtr
|
|||
}
|
||||
|
||||
PassRegistry::PassRegistry(PassPosition position, const std::vector<std::vector<char>> &names) {
|
||||
if (position < POSITION_BEGIN || position > POSITION_END) {
|
||||
MS_LOG(ERROR) << "ILLEGAL position: position must be POSITION_BEGIN or POSITION_END.";
|
||||
return;
|
||||
}
|
||||
std::unique_lock<std::mutex> lock(pass_mutex);
|
||||
external_assigned_passes[position] = VectorCharToString(names);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue