forked from mindspore-Ecosystem/mindspore
registry modelParser and pass
This commit is contained in:
parent
2b8083915e
commit
3869a02fc3
|
@ -217,6 +217,8 @@ elseif(WIN32)
|
|||
install(FILES ${LIB_LIST} DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/mindspore_core/gvar/libmindspore_gvar.dll
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/build/mindspore/tools/converter/registry/libmslite_converter_plugin_reg.dll
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${glog_LIBPATH}/../bin/libglog.dll DESTINATION ${CONVERTER_ROOT_DIR}/lib
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${NNACL_FILES} DESTINATION ${CODEGEN_ROOT_DIR}/include/nnacl COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
@ -286,6 +288,8 @@ else()
|
|||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/mindspore_core/gvar/libmindspore_gvar.so
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${TOP_DIR}/mindspore/lite/build/tools/converter/registry/libmslite_converter_plugin_reg.so
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/lib COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
install(FILES ${glog_LIBPATH}/libglog.so.0.4.0
|
||||
DESTINATION ${CONVERTER_ROOT_DIR}/third_party/glog/lib RENAME libglog.so.0
|
||||
COMPONENT ${RUNTIME_COMPONENT_NAME})
|
||||
|
|
|
@ -46,6 +46,7 @@ set(LITE_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/common/file_utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/graph_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/dynamic_library_loader.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/log_adapter.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/string_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/common/prim_util.cc
|
||||
|
|
|
@ -91,8 +91,8 @@ int SimpleTile(void *cdata, int task_id) {
|
|||
void TileCPUKernel::FillOneDimTileParam() {
|
||||
// check if tile exact one dim
|
||||
int large_one_multiple_count = 0;
|
||||
int multiple;
|
||||
int mul_index;
|
||||
int multiple = 0;
|
||||
int mul_index = 0;
|
||||
for (auto i = 0; i < tile_parameter_->in_dim_; ++i) {
|
||||
if (tile_parameter_->multiples_[i] > 1) {
|
||||
large_one_multiple_count++;
|
||||
|
|
|
@ -430,6 +430,7 @@ if(ENABLE_CONVERTER)
|
|||
add_dependencies(lite-test fbs_inner_src)
|
||||
target_link_libraries(lite-test
|
||||
anf_exporter_mid
|
||||
mslite_converter_plugin_reg
|
||||
tflite_parser_mid
|
||||
caffe_parser_mid
|
||||
onnx_parser_mid
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/optimizer/fusion/constant_folding_fusion.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "test/common/import_from_meta_graphT.h"
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "test/common/import_from_meta_graphT.h"
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "test/common/import_from_meta_graphT.h"
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "include/context.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "test/common/import_from_meta_graphT.h"
|
||||
|
|
|
@ -32,6 +32,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/utils.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/dynamic_library_loader.cc
|
||||
|
||||
../optimizer/common/node_pass_extends.cc
|
||||
../optimizer/common/pass_manager_extends.cc
|
||||
|
@ -107,6 +108,7 @@ add_subdirectory(parser/onnx)
|
|||
add_subdirectory(parser/tf)
|
||||
add_subdirectory(legacy_optimizer)
|
||||
add_subdirectory(quantizer)
|
||||
add_subdirectory(registry)
|
||||
add_subdirectory(${CORE_DIR} mindspore_core)
|
||||
|
||||
set(SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../src)
|
||||
|
@ -216,6 +218,7 @@ add_dependencies(converter_lite fbs_src)
|
|||
add_dependencies(converter_lite fbs_inner_src)
|
||||
|
||||
target_link_libraries(converter_lite PRIVATE
|
||||
mslite_converter_plugin_reg
|
||||
tflite_parser_mid
|
||||
tf_parser_mid
|
||||
caffe_parser_mid
|
||||
|
@ -234,3 +237,7 @@ target_link_libraries(converter_lite PRIVATE
|
|||
mindspore::flatbuffers
|
||||
pthread
|
||||
)
|
||||
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(converter_lite PRIVATE dl)
|
||||
endif()
|
||||
|
|
|
@ -70,6 +70,7 @@
|
|||
#include "tools/optimizer/fisson/iter_node_outputs.h"
|
||||
#include "tools/optimizer/fisson/node_out_shapes.h"
|
||||
#include "tools/optimizer/parallel/parallel_pass.h"
|
||||
#include "tools/converter/registry/pass_registry.h"
|
||||
|
||||
using std::string;
|
||||
namespace mindspore::lite {
|
||||
|
@ -323,6 +324,22 @@ int AnfTransform::RunPrecedingPass(const FuncGraphPtr &old_graph, const converte
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS AnfTransform::RunPluginPass(const FuncGraphPtr &old_graph, int position) {
|
||||
auto instance = opt::PassRegistry::GetInstance();
|
||||
auto plugin_passes = instance->GetPasses();
|
||||
if (plugin_passes.find(position) == plugin_passes.end()) {
|
||||
MS_LOG(DEBUG) << "there is no plugin pass in current position.";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto plugin_pass = plugin_passes.at(position);
|
||||
if (!plugin_pass->Run(old_graph)) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
|
||||
// quant
|
||||
if (config->quantType == schema::QuantType_PostTraining) {
|
||||
|
@ -383,11 +400,28 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
MS_LOG(ERROR) << "Run convert pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
status = RunFusionPass(fg, config);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run fusion pass failed.";
|
||||
return nullptr;
|
||||
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||
format_pass->Init(config->fmk, config->trainModel);
|
||||
if (!format_pass->RunOnlyForShape(old_graph)) {
|
||||
MS_LOG(ERROR) << "Run format pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = RunPluginPass(old_graph, opt::POSITION_BEGIN);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run plugin pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (auto &fg : func_graphs_) {
|
||||
if (!config->disableFusion) {
|
||||
status = RunFusionPass(fg, config);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run fusion pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
status = RunConv1DAdjustPass(fg, config);
|
||||
|
@ -397,13 +431,19 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, con
|
|||
}
|
||||
}
|
||||
|
||||
auto format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||
format_pass = std::make_shared<opt::UnifyFormatPass>();
|
||||
format_pass->Init(config->fmk, config->trainModel);
|
||||
if (!format_pass->Run(old_graph)) {
|
||||
MS_LOG(ERROR) << "Run format pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = RunPluginPass(old_graph, opt::POSITION_END);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run plugin pass failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (auto &fg : func_graphs_) {
|
||||
status = RunGraphPass(fg, config);
|
||||
if (status != RET_OK) {
|
||||
|
|
|
@ -63,6 +63,8 @@ class AnfTransform {
|
|||
|
||||
static int RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
static STATUS RunPluginPass(const FuncGraphPtr &old_graph, int position);
|
||||
|
||||
int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config);
|
||||
|
||||
void GetAllFuncGraph(const FuncGraphPtr &func_graph);
|
||||
|
|
|
@ -20,33 +20,37 @@
|
|||
#include "tools/converter/converter_flags.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/storage.h"
|
||||
#include "parser/caffe/caffe_converter.h"
|
||||
#include "parser/tflite/tflite_converter.h"
|
||||
#include "parser/onnx/onnx_converter.h"
|
||||
#include "parser/tf/tf_converter.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "include/version.h"
|
||||
#include "src/train/train_populate_parameter.h"
|
||||
#include "tools/converter/registry/model_parser_registry.h"
|
||||
#include "src/common/dynamic_library_loader.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using FmkType = converter::FmkType;
|
||||
FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
||||
if (flag.fmkIn == "MINDIR") {
|
||||
kernel::PopulateTrainParameters();
|
||||
auto func_graph = LoadMindIR(flag.modelFile);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "get funcgraph failed.";
|
||||
return nullptr;
|
||||
}
|
||||
func_graph->set_attr("graph_name", MakeValue("main_graph"));
|
||||
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
|
||||
|
||||
MindsporeImporter::MindsporeImporter() { kernel::PopulateTrainParameters(); }
|
||||
|
||||
std::unique_ptr<Converter> Converter::CreateConverter(converter::FmkType fmk) {
|
||||
switch (fmk) {
|
||||
case FmkType::FmkType_MS:
|
||||
return std::make_unique<MindsporeImporter>();
|
||||
case FmkType::FmkType_CAFFE:
|
||||
return std::make_unique<CaffeConverter>();
|
||||
case FmkType::FmkType_TFLITE:
|
||||
return std::make_unique<TfliteConverter>();
|
||||
case FmkType::FmkType_ONNX:
|
||||
return std::make_unique<OnnxConverter>();
|
||||
case FmkType::FmkType_TF:
|
||||
return std::make_unique<TFConverter>();
|
||||
default: {
|
||||
auto status = UpdateFuncGraphInputsAndOutputsDtype(func_graph);
|
||||
if (RET_OK != status) {
|
||||
MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return func_graph;
|
||||
} else {
|
||||
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmkIn);
|
||||
if (model_parser_ != nullptr) {
|
||||
return model_parser_->Parse(flag.modelFile, flag.weightFile);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "get funcGraph failed for fmk:" << flag.fmkIn;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
@ -57,7 +61,21 @@ schema::MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &
|
|||
MS_LOG(ERROR) << "Input flag is nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
auto graph = BuildFuncGraph(flag->modelFile, flag->weightFile, flag->quantType);
|
||||
|
||||
// load plugin
|
||||
if (!flag->pluginsPath.empty()) {
|
||||
DynamicLibraryLoader dynamic_library_loader{};
|
||||
for (auto &path : flag->pluginsPath) {
|
||||
auto status = dynamic_library_loader.Open(path.c_str());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "open dynamic library failed.";
|
||||
return nullptr;
|
||||
}
|
||||
dynamic_library_loader.Close();
|
||||
}
|
||||
}
|
||||
|
||||
auto graph = BuildFuncGraph(*flag);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Parser/Import model return nullptr";
|
||||
return nullptr;
|
||||
|
@ -110,16 +128,8 @@ int RunConverter(int argc, const char **argv) {
|
|||
}
|
||||
// Load graph
|
||||
MS_LOG(DEBUG) << "start reading model file";
|
||||
auto converter = Converter::CreateConverter(flags->fmk);
|
||||
if (converter == nullptr) {
|
||||
oss.clear();
|
||||
oss << "UNSUPPORTED FMKTYPE " << flags->fmk << ":" << RET_INPUT_PARAM_INVALID << " "
|
||||
<< GetErrorInfo(RET_INPUT_PARAM_INVALID);
|
||||
MS_LOG(ERROR) << oss.str();
|
||||
std::cout << oss.str() << std::endl;
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
auto meta_graph = converter->Convert(flags);
|
||||
Converter cvt;
|
||||
auto meta_graph = cvt.Convert(flags);
|
||||
NotSupportOp::GetInstance()->PrintOps();
|
||||
status = ReturnCode::GetSingleReturnCode()->status_code();
|
||||
if (meta_graph == nullptr) {
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/registry/model_parser_registry.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/anf_transform.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
|
@ -32,47 +33,17 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
class Converter {
|
||||
public:
|
||||
static std::unique_ptr<Converter> CreateConverter(converter::FmkType fmk);
|
||||
|
||||
virtual ~Converter() = default;
|
||||
|
||||
virtual schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
|
||||
|
||||
virtual FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
schema::QuantType quant_type) = 0;
|
||||
Converter() = default;
|
||||
~Converter() { delete model_parser_; }
|
||||
schema::MetaGraphT *Convert(const std::unique_ptr<converter::Flags> &flag);
|
||||
FuncGraphPtr BuildFuncGraph(const converter::Flags &flag);
|
||||
|
||||
protected:
|
||||
Converter() = default;
|
||||
|
||||
ModelParser *model_parser_ = nullptr;
|
||||
std::unique_ptr<GraphDefTransform> metagraph_transform_ = std::make_unique<GraphDefTransform>();
|
||||
std::unique_ptr<AnfTransform> funcgraph_transform_ = std::make_unique<AnfTransform>();
|
||||
};
|
||||
|
||||
class MindsporeImporter : public Converter {
|
||||
public:
|
||||
MindsporeImporter();
|
||||
|
||||
~MindsporeImporter() override = default;
|
||||
|
||||
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
schema::QuantType quant_type) override {
|
||||
auto func_graph = LoadMindIR(model_file);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "get funcgraph failed.";
|
||||
return nullptr;
|
||||
}
|
||||
func_graph->set_attr("graph_name", MakeValue("main_graph"));
|
||||
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
|
||||
|
||||
auto status = UpdateFuncGraphInputsAndOutputsDtype(func_graph);
|
||||
if (RET_OK != status) {
|
||||
MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return func_graph;
|
||||
}
|
||||
};
|
||||
|
||||
int RunConverter(int argc, const char **argv);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,9 +15,13 @@
|
|||
*/
|
||||
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include <regex>
|
||||
#include <climits>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -180,6 +184,28 @@ int Flags::InitTrainModel() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int Flags::InitConfigFile() {
|
||||
auto plugins_path_str = GetStrFromConfigFile(this->configFile, "plugin_path");
|
||||
if (!plugins_path_str.empty()) {
|
||||
const char *delimiter = ";";
|
||||
this->pluginsPath = SplitStringToVector(plugins_path_str, *delimiter);
|
||||
}
|
||||
|
||||
auto disable_fusion_flag = GetStrFromConfigFile(this->configFile, "disable_fusion");
|
||||
if (!disable_fusion_flag.empty()) {
|
||||
if (disable_fusion_flag == "on") {
|
||||
this->disableFusion = true;
|
||||
} else if (disable_fusion_flag == "off") {
|
||||
this->disableFusion = false;
|
||||
} else {
|
||||
std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int Flags::Init(int argc, const char **argv) {
|
||||
int ret;
|
||||
if (argc == 1) {
|
||||
|
@ -222,6 +248,14 @@ int Flags::Init(int argc, const char **argv) {
|
|||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (!this->configFile.empty()) {
|
||||
ret = InitConfigFile();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init config file failed.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
ret = InitInputOutputDataType();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init input output datatype failed.";
|
||||
|
@ -248,6 +282,80 @@ int Flags::Init(int argc, const char **argv) {
|
|||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
|
||||
std::string res;
|
||||
if (file.empty()) {
|
||||
MS_LOG(ERROR) << "file is nullptr";
|
||||
return res;
|
||||
}
|
||||
|
||||
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
|
||||
if (resolved_path == nullptr) {
|
||||
MS_LOG(ERROR) << "new resolved_path failed";
|
||||
return "";
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
char *real_path = _fullpath(resolved_path.get(), file.c_str(), 1024);
|
||||
#else
|
||||
char *real_path = realpath(file.c_str(), resolved_path.get());
|
||||
#endif
|
||||
if (real_path == nullptr || strlen(real_path) == 0) {
|
||||
MS_LOG(ERROR) << "file path is not valid : " << file;
|
||||
return "";
|
||||
}
|
||||
std::ifstream ifs(resolved_path.get());
|
||||
if (!ifs.good()) {
|
||||
MS_LOG(ERROR) << "file: " << real_path << " is not exist";
|
||||
return res;
|
||||
}
|
||||
if (!ifs.is_open()) {
|
||||
MS_LOG(ERROR) << "file: " << real_path << "open failed";
|
||||
return res;
|
||||
}
|
||||
std::string line;
|
||||
while (std::getline(ifs, line)) {
|
||||
lite::Trim(&line);
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto index = line.find('=');
|
||||
if (index == std::string::npos) {
|
||||
MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
|
||||
return "";
|
||||
}
|
||||
auto key = line.substr(0, index);
|
||||
auto value = line.substr(index + 1);
|
||||
lite::Trim(&key);
|
||||
lite::Trim(&value);
|
||||
if (key == target_key) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter) {
|
||||
if (raw_str.empty()) {
|
||||
MS_LOG(ERROR) << "input string is empty.";
|
||||
return {};
|
||||
}
|
||||
std::vector<std::string> res;
|
||||
std::string::size_type last_pos = 0;
|
||||
auto cur_pos = raw_str.find(delimiter);
|
||||
while (cur_pos != std::string::npos) {
|
||||
res.push_back(raw_str.substr(last_pos, cur_pos - last_pos));
|
||||
cur_pos++;
|
||||
last_pos = cur_pos;
|
||||
cur_pos = raw_str.find(delimiter, cur_pos);
|
||||
}
|
||||
if (last_pos < raw_str.size()) {
|
||||
res.push_back(raw_str.substr(last_pos, raw_str.size() - last_pos + 1));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace converter
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#define MINDSPORE_LITE_TOOLS_CONVERTER_CONVERTER_FLAGS_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tools/common/flag_parser.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
@ -57,6 +58,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
|
||||
int InitTrainModel();
|
||||
|
||||
int InitConfigFile();
|
||||
|
||||
int Init(int argc, const char **argv);
|
||||
|
||||
public:
|
||||
|
@ -83,7 +86,13 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
int quantWeightChannel;
|
||||
std::string trainModelIn;
|
||||
bool trainModel = false;
|
||||
std::vector<std::string> pluginsPath;
|
||||
bool disableFusion = false;
|
||||
};
|
||||
|
||||
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
|
||||
|
||||
std::vector<std::string> SplitStringToVector(const std::string &raw_str, const char &delimiter);
|
||||
} // namespace converter
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,8 +34,8 @@ class ModelParser {
|
|||
|
||||
virtual ~ModelParser() = default;
|
||||
|
||||
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) {
|
||||
auto ret = ParseToFuncGraph(model_file, weight_file, quant_type);
|
||||
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file) {
|
||||
auto ret = ParseToFuncGraph(model_file, weight_file);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse to func graph failed : " << ret;
|
||||
return nullptr;
|
||||
|
@ -49,14 +49,25 @@ class ModelParser {
|
|||
}
|
||||
|
||||
protected:
|
||||
virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) = 0;
|
||||
virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) = 0;
|
||||
|
||||
virtual int PostAdjust() = 0;
|
||||
|
||||
protected:
|
||||
FuncGraphPtr res_graph_ = nullptr;
|
||||
};
|
||||
|
||||
typedef ModelParser *(*ModelParserCreator)();
|
||||
|
||||
template <class T>
|
||||
ModelParser *LiteModelParserCreator() {
|
||||
auto *parser = new (std::nothrow) T();
|
||||
if (parser == nullptr) {
|
||||
MS_LOG(ERROR) << "new model parser failed";
|
||||
return nullptr;
|
||||
}
|
||||
return parser;
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif
|
||||
|
|
|
@ -1,41 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/parser/caffe/caffe_model_parser.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class CaffeConverter : public Converter {
|
||||
public:
|
||||
CaffeConverter() = default;
|
||||
|
||||
~CaffeConverter() override = default;
|
||||
|
||||
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
schema::QuantType quant_type) override {
|
||||
CaffeModelParser parser;
|
||||
return parser.Parse(model_file, weight_file, quant_type);
|
||||
}
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_CONVERTER_H_
|
|
@ -26,6 +26,8 @@
|
|||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/quant_param_holder.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
bool IsSkipedLayer(const caffe::LayerParameter &layer) {
|
||||
|
@ -56,8 +58,7 @@ CaffeModelParser::CaffeModelParser() = default;
|
|||
|
||||
CaffeModelParser::~CaffeModelParser() = default;
|
||||
|
||||
int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) {
|
||||
STATUS status = InitOriginModel(model_file, weight_file);
|
||||
if (status != RET_OK) {
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
|
@ -493,4 +494,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
|
|||
}
|
||||
|
||||
int CaffeModelParser::PostAdjust() { return RET_OK; }
|
||||
|
||||
REG_MODEL_PARSER(CAFFE, LiteModelParserCreator<CaffeModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -22,9 +22,11 @@
|
|||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/registry/model_parser_registry.h"
|
||||
#include "proto/caffe.pb.h"
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
using STATUS = int;
|
||||
namespace mindspore::lite {
|
||||
class CaffeModelParser : public ModelParser {
|
||||
public:
|
||||
|
@ -32,8 +34,7 @@ class CaffeModelParser : public ModelParser {
|
|||
|
||||
~CaffeModelParser() override;
|
||||
|
||||
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) override;
|
||||
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override;
|
||||
|
||||
int PostAdjust() override;
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ STATUS OnnxConstantParser::AddDataInfoAttr(const onnx::TensorProto &onnx_const_t
|
|||
std::vector<int64_t> shape_vector(onnx_const_tensor.dims().begin(), onnx_const_tensor.dims().end());
|
||||
auto tensor_info = std::make_shared<tensor::Tensor>(data_type, shape_vector);
|
||||
if (tensor_info == nullptr) {
|
||||
MS_LOG(ERROR) << "new a paramValueLite failed.";
|
||||
MS_LOG(ERROR) << "new a tensor::Tensor failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (OnnxModelParser::CopyOnnxTensorData(onnx_const_tensor, tensor_info) != RET_OK) {
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxConverter : public Converter {
|
||||
public:
|
||||
OnnxConverter() = default;
|
||||
|
||||
~OnnxConverter() override = default;
|
||||
|
||||
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
schema::QuantType quant_type) override {
|
||||
OnnxModelParser parser;
|
||||
return parser.Parse(model_file, weight_file, quant_type);
|
||||
}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_CONVERTER_H
|
|
@ -21,7 +21,6 @@
|
|||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/protobuf_utils.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
@ -29,6 +28,8 @@
|
|||
#include "ops/tensor_list_stack.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/quant_param_holder.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -43,8 +44,7 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
|||
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
||||
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
||||
|
||||
int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) {
|
||||
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
||||
res_graph_ = std::make_shared<FuncGraph>();
|
||||
auto status = InitOriginModel(model_file);
|
||||
|
@ -723,7 +723,7 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
|
|||
return const_node;
|
||||
}
|
||||
|
||||
ValueNodePtr CreateValueNode(const PrimitiveType &op_type) {
|
||||
ValueNodePtr CreateValueNode(const schema::PrimitiveType &op_type) {
|
||||
auto node_type = schema::EnumNamePrimitiveType(op_type);
|
||||
auto op_primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
|
||||
if (op_primc_fns.find(node_type) == op_primc_fns.end()) {
|
||||
|
@ -1178,5 +1178,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type
|
|||
}
|
||||
|
||||
int OnnxModelParser::PostAdjust() { return 0; }
|
||||
|
||||
REG_MODEL_PARSER(ONNX, LiteModelParserCreator<OnnxModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,8 +29,10 @@
|
|||
#include <unordered_map>
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/registry/model_parser_registry.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -40,8 +42,7 @@ class OnnxModelParser : public ModelParser {
|
|||
|
||||
~OnnxModelParser() override = default;
|
||||
|
||||
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) override;
|
||||
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override;
|
||||
|
||||
int PostAdjust() override;
|
||||
|
||||
|
@ -73,9 +74,9 @@ class OnnxModelParser : public ModelParser {
|
|||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, const CNodePtr &cnode);
|
||||
STATUS ConvertOpQuantParams(const onnx::NodeProto &onnx_node, ops::PrimitiveC *primitive_c);
|
||||
STATUS ParseQuantParam(const onnx::NodeProto &onnx_node);
|
||||
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<QuantParamT> *quant_params);
|
||||
STATUS CopyTensorQuantParam(const std::string &tensor_name, QuantParamT *quant_param, bool scale_or_not);
|
||||
STATUS SetTensorQuantParam(const std::string &tensor_name, std::vector<schema::QuantParamT> *quant_params);
|
||||
STATUS SetTensorQuantParamFromNode(const std::string &tensor_name, std::vector<schema::QuantParamT> *quant_params);
|
||||
STATUS CopyTensorQuantParam(const std::string &tensor_name, schema::QuantParamT *quant_param, bool scale_or_not);
|
||||
STATUS ConvertLoopOnnxNode(const onnx::NodeProto &onnx_node,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map,
|
||||
const std::string &root_node_name);
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/parser/tf/tf_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TFConverter : public Converter {
|
||||
public:
|
||||
TFConverter() = default;
|
||||
|
||||
~TFConverter() override = default;
|
||||
|
||||
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
schema::QuantType quant_type) override {
|
||||
TFModelParser parser;
|
||||
return parser.Parse(model_file, weight_file, quant_type);
|
||||
}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_CONVERTER_H_
|
|
@ -28,6 +28,7 @@
|
|||
#include "ir/anf.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "tools/converter/quant_param_holder.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
@ -472,8 +473,7 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile,
|
||||
const QuantType &quantType) {
|
||||
int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile) {
|
||||
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
||||
auto status = ValidateFileStr(modelFile, ".pb");
|
||||
if (status != RET_OK) {
|
||||
|
@ -559,8 +559,8 @@ int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::str
|
|||
|
||||
STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
|
||||
const tensorflow::FunctionDef &tf_sub_fuction, CNodePtr cnode,
|
||||
FuncGraphPtr sub_func_graph) {
|
||||
const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,
|
||||
const FuncGraphPtr &sub_func_graph) {
|
||||
std::vector<ParameterPtr> sub_graph_inputs;
|
||||
auto &tf_sub_signature = tf_sub_fuction.signature();
|
||||
auto &sub_graph_name = tf_sub_signature.name();
|
||||
|
@ -598,7 +598,7 @@ STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorfl
|
|||
STATUS TFModelParser::ConvertSubgraphOutputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
|
||||
const std::unordered_map<std::string, AnfNodePtr> &anf_sub_node_map,
|
||||
const tensorflow::FunctionDef &tf_sub_fuction,
|
||||
FuncGraphPtr sub_func_graph) {
|
||||
const FuncGraphPtr &sub_func_graph) {
|
||||
auto &tf_sub_signature = tf_sub_fuction.signature();
|
||||
auto &sub_graph_name = tf_sub_signature.name();
|
||||
|
||||
|
@ -949,7 +949,7 @@ STATUS TFModelParser::ConvertOps(const tensorflow::NodeDef &node_def,
|
|||
return status;
|
||||
}
|
||||
|
||||
STATUS TFModelParser::ProcessControlFlowOp(CNodePtr anf_node, const string op_type,
|
||||
STATUS TFModelParser::ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type,
|
||||
const tensorflow::NodeDef &node_def) {
|
||||
if (op_type == "StatelessWhile" || op_type == "While") {
|
||||
MS_LOG(INFO) << "find while node:" << node_def.name();
|
||||
|
@ -1077,5 +1077,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
|
|||
}
|
||||
|
||||
int TFModelParser::PostAdjust() { return 0; }
|
||||
|
||||
REG_MODEL_PARSER(TF, LiteModelParserCreator<TFModelParser>)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "securec/include/securec.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/registry/model_parser_registry.h"
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -38,14 +39,15 @@ class TFModelParser : public ModelParser {
|
|||
TFModelParser() = default;
|
||||
~TFModelParser() override = default;
|
||||
|
||||
int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);
|
||||
int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile) override;
|
||||
|
||||
int PostAdjust() override;
|
||||
|
||||
private:
|
||||
static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
|
||||
STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value,
|
||||
const TypeId &type, const ParameterPtr ¶meter, std::vector<int64_t> *shape_vector);
|
||||
static STATUS ConvertConstTensor(const tensorflow::NodeDef &node_def, const tensorflow::AttrValue &attr_value,
|
||||
const TypeId &type, const ParameterPtr ¶meter,
|
||||
std::vector<int64_t> *shape_vector);
|
||||
static STATUS SetTensorInfoFromType(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
|
||||
STATUS ConvertParameter(const tensorflow::NodeDef &node, const ParameterPtr ¶meter,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
|
||||
|
@ -63,7 +65,7 @@ class TFModelParser : public ModelParser {
|
|||
const std::map<std::string, const tensorflow::NodeDef *> &tf_node_map,
|
||||
const FuncGraphPtr &func_graph_ptr, std::unordered_map<std::string, AnfNodePtr> *anf_node_map);
|
||||
|
||||
STATUS ProcessControlFlowOp(CNodePtr anf_node, const string op_type, const tensorflow::NodeDef &node_def);
|
||||
STATUS ProcessControlFlowOp(const CNodePtr &anf_node, const string &op_type, const tensorflow::NodeDef &node_def);
|
||||
|
||||
STATUS ConvertRootGraphOutputs();
|
||||
|
||||
|
@ -71,17 +73,18 @@ class TFModelParser : public ModelParser {
|
|||
|
||||
STATUS ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
|
||||
std::unordered_map<std::string, AnfNodePtr> *anf_sub_node_map,
|
||||
const tensorflow::FunctionDef &tf_sub_fuction, CNodePtr cnode,
|
||||
FuncGraphPtr sub_func_graph);
|
||||
const tensorflow::FunctionDef &tf_sub_fuction, const CNodePtr &cnode,
|
||||
const FuncGraphPtr &sub_func_graph);
|
||||
|
||||
static STATUS ConvertSubgraphOutputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
|
||||
const std::unordered_map<std::string, AnfNodePtr> &anf_sub_node_map,
|
||||
const tensorflow::FunctionDef &tf_sub_fuction, FuncGraphPtr sub_func_graph);
|
||||
const tensorflow::FunctionDef &tf_sub_fuction,
|
||||
const FuncGraphPtr &sub_func_graph);
|
||||
|
||||
STATUS ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGraphPtr> &first_func_map,
|
||||
const std::map<CNodePtr, FuncGraphPtr> &second_func_map);
|
||||
|
||||
STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c);
|
||||
static STATUS ConvertQuantParams(const size_t &input_size, const size_t &output_size, ops::PrimitiveC *primitive_c);
|
||||
|
||||
static STATUS MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes, const FuncGraphPtr &anf_graph);
|
||||
|
||||
|
|
|
@ -1,42 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "tools/converter/parser/tflite/tflite_model_parser.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class TfliteConverter : public Converter {
|
||||
public:
|
||||
TfliteConverter() = default;
|
||||
|
||||
~TfliteConverter() override = default;
|
||||
|
||||
FuncGraphPtr BuildFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
schema::QuantType quant_type) override {
|
||||
TfliteModelParser parser;
|
||||
return parser.Parse(model_file, weight_file, quant_type);
|
||||
}
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TFLITE_CONVERTER_H_
|
|
@ -19,12 +19,14 @@
|
|||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "tools/converter/converter_flags.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/converter/quant_param_holder.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *model_path) {
|
||||
|
@ -42,8 +44,7 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
|
|||
return tflite::UnPackModel(tflite_model_buf_);
|
||||
}
|
||||
|
||||
int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) {
|
||||
int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) {
|
||||
// load graph
|
||||
tflite_model_ = ReadTfliteModel(model_file.c_str());
|
||||
if (tflite_model_ == nullptr) {
|
||||
|
@ -489,4 +490,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
|
|||
}
|
||||
|
||||
int TfliteModelParser::PostAdjust() { return 0; }
|
||||
|
||||
REG_MODEL_PARSER(TFLITE, LiteModelParserCreator<TfliteModelParser>)
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include "tools/converter/model_parser.h"
|
||||
#include "tools/converter/registry/model_parser_registry.h"
|
||||
#include "tools/converter/parser/tflite/tflite_node_parser_registry.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
||||
|
@ -32,8 +33,7 @@ class TfliteModelParser : public ModelParser {
|
|||
|
||||
~TfliteModelParser() override = default;
|
||||
|
||||
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||
const QuantType &quant_type) override;
|
||||
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file) override;
|
||||
|
||||
int PostAdjust() override;
|
||||
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
file(GLOB CONVERT_REG_SRC
|
||||
pass_registry.cc
|
||||
model_parser_registry.cc
|
||||
)
|
||||
|
||||
set_property(SOURCE ${CONVERT_REG_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
|
||||
add_library(mslite_converter_plugin_reg SHARED ${CONVERT_REG_SRC})
|
||||
|
||||
add_dependencies(mslite_converter_plugin_reg fbs_src)
|
||||
add_dependencies(mslite_converter_plugin_reg fbs_inner_src)
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/registry/model_parser_registry.h"
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
ModelParserRegistry *ModelParserRegistry::GetInstance() {
|
||||
static ModelParserRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
ModelParser *ModelParserRegistry::GetModelParser(const std::string &fmk) {
|
||||
auto it = parsers_.find(fmk);
|
||||
if (it != parsers_.end()) {
|
||||
auto creator = it->second;
|
||||
return creator();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ModelParserRegistry::RegParser(const std::string &fmk, ModelParserCreator creator) {
|
||||
auto instance = ModelParserRegistry::GetInstance();
|
||||
instance->parsers_[fmk] = creator;
|
||||
}
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "include/lite_utils.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
class MS_API ModelParser;
|
||||
typedef ModelParser *(*ModelParserCreator)();
|
||||
|
||||
class MS_API ModelParserRegistry {
|
||||
public:
|
||||
ModelParserRegistry() = default;
|
||||
~ModelParserRegistry() = default;
|
||||
|
||||
static ModelParserRegistry *GetInstance();
|
||||
ModelParser *GetModelParser(const std::string &fmk);
|
||||
void RegParser(const std::string &fmk, ModelParserCreator creator);
|
||||
|
||||
std::unordered_map<std::string, ModelParserCreator> parsers_;
|
||||
};
|
||||
|
||||
class MS_API ModelRegistrar {
|
||||
public:
|
||||
ModelRegistrar(const std::string &fmk, ModelParserCreator creator) {
|
||||
ModelParserRegistry::GetInstance()->RegParser(fmk, creator);
|
||||
}
|
||||
~ModelRegistrar() = default;
|
||||
};
|
||||
|
||||
#define REG_MODEL_PARSER(fmk, parserCreator) static ModelRegistrar g_##type##fmk##ModelParserReg(#fmk, parserCreator);
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_MODEL_PARSER_REGISTRY_H
|
|
@ -0,0 +1,45 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/registry/pass_registry.h"
|
||||
#include <iostream>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
PassRegistry *PassRegistry::GetInstance() {
|
||||
static PassRegistry instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
void PassRegistry::RegPass(int position, const PassPtr &pass) {
|
||||
if (pass == nullptr) {
|
||||
std::cout << "pass is nullptr" << std::endl;
|
||||
return;
|
||||
}
|
||||
auto instance = PassRegistry::GetInstance();
|
||||
std::unique_lock<std::mutex> lock(instance->mutex_);
|
||||
instance->passes_[position] = pass;
|
||||
}
|
||||
|
||||
const std::unordered_map<int, PassPtr> &PassRegistry::GetPasses() const {
|
||||
auto instance = PassRegistry::GetInstance();
|
||||
std::unique_lock<std::mutex> lock(instance->mutex_);
|
||||
return instance->passes_;
|
||||
}
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "include/lite_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
enum MS_API PassPosition { POSITION_BEGIN = 0, POSITION_END = 1 };
|
||||
|
||||
class MS_API Pass;
|
||||
using PassPtr = std::shared_ptr<Pass>;
|
||||
class MS_API PassRegistry {
|
||||
public:
|
||||
virtual ~PassRegistry() = default;
|
||||
static PassRegistry *GetInstance();
|
||||
void RegPass(int position, const PassPtr &pass);
|
||||
const std::unordered_map<int, PassPtr> &GetPasses() const;
|
||||
|
||||
private:
|
||||
PassRegistry() = default;
|
||||
|
||||
private:
|
||||
std::unordered_map<int, PassPtr> passes_;
|
||||
std::mutex mutex_;
|
||||
};
|
||||
|
||||
class MS_API PassRegistrar {
|
||||
public:
|
||||
PassRegistrar(int pos, const PassPtr &pass) { PassRegistry::GetInstance()->RegPass(pos, pass); }
|
||||
~PassRegistrar() = default;
|
||||
};
|
||||
|
||||
#define REG_PASS(position, pass) static PassRegistrar g_##position##PassReg(position, std::make_shared<pass>());
|
||||
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_REGISTRY_PASS_REGISTRY_H_
|
Loading…
Reference in New Issue