!42198 support online converter

Merge pull request !42198 from 周超/online_converter
This commit is contained in:
i-robot 2022-09-17 07:27:34 +00:00 committed by Gitee
commit 2b2edc30f4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 152 additions and 81 deletions

View File

@ -19,7 +19,7 @@
#include "tools/converter/converter.h" #include "tools/converter/converter.h"
#include "tools/converter/cxx_api/converter_para.h" #include "tools/converter/cxx_api/converter_para.h"
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size, mindspore::api::FuncGraphPtr RuntimeConvert(const char *model_buf, const size_t &buf_size,
const std::shared_ptr<mindspore::Context> &context) { const std::shared_ptr<mindspore::Context> &context) {
auto param = std::make_shared<mindspore::ConverterPara>(); auto param = std::make_shared<mindspore::ConverterPara>();
if (param == nullptr) { if (param == nullptr) {
@ -33,6 +33,7 @@ void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_
param->train_model = false; param->train_model = false;
param->export_mindir = mindspore::kMindIR; param->export_mindir = mindspore::kMindIR;
param->enable_encryption = false; param->enable_encryption = false;
param->is_runtime_converter = true;
auto device_list = context->MutableDeviceInfo(); auto device_list = context->MutableDeviceInfo();
for (auto &device : device_list) { for (auto &device : device_list) {
@ -69,11 +70,11 @@ void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_
} }
mindspore::lite::ConverterImpl cvt; mindspore::lite::ConverterImpl cvt;
void *dst_buff; mindspore::FuncGraphPtr graph = cvt.Convert(param, model_buf, buf_size);
auto ret = cvt.Convert(param, nullptr, model_buf, buf_size, &dst_buff, dst_size); if (graph == nullptr) {
if (ret != mindspore::lite::RET_OK) { MS_LOG(ERROR) << "Convert model failed";
MS_LOG(ERROR) << "Convert model failed.";
return nullptr; return nullptr;
} }
return dst_buff; auto api_graph = mindspore::api::MakeShared<mindspore::api::FuncGraph>(graph);
return api_graph;
} }

View File

@ -20,11 +20,12 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "include/api/context.h" #include "include/api/context.h"
#include "mindapi/ir/func_graph.h"
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size, mindspore::api::FuncGraphPtr RuntimeConvert(const char *model_buf, const size_t &buf_size,
const std::shared_ptr<mindspore::Context> &context); const std::shared_ptr<mindspore::Context> &context);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -15,11 +15,11 @@
*/ */
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_ #ifndef MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_ #define MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_
#include <string>
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
#include <dlfcn.h> #include <dlfcn.h>
#include <dirent.h> #include <dirent.h>
#include <memory> #include <memory>
#include <string>
#include <fstream> #include <fstream>
#include "utils/file_utils.h" #include "utils/file_utils.h"
#include "include/api/status.h" #include "include/api/status.h"
@ -118,5 +118,16 @@ inline void DLSoClose(void *handle) {
} \ } \
} while (false) } while (false)
} // namespace mindspore } // namespace mindspore
#else
inline Status DLSoPath(const std::string &benchmark_so, const std::string &target_so, std::string *target_so_path) {
MS_LOG(ERROR) << "Not support dlopen so";
return kMEFailed;
}
inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function,
bool runtime_convert = false) {
MS_LOG(ERROR) << "Not support dlopen so";
return kMEFailed;
}
#endif #endif
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_ #endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_

View File

@ -25,6 +25,8 @@
#include "src/extendrt/convert/runtime_convert.h" #include "src/extendrt/convert/runtime_convert.h"
#include "src/common/config_file.h" #include "src/common/config_file.h"
#include "src/extendrt/utils/serialization.h" #include "src/extendrt/utils/serialization.h"
#include "mindapi/ir/func_graph.h"
#include "mindapi/base/base.h"
namespace mindspore { namespace mindspore {
namespace { namespace {
@ -33,50 +35,21 @@ constexpr size_t kMaxSectionNum = 100;
constexpr size_t kMaxConfigNumPerSection = 1000; constexpr size_t kMaxConfigNumPerSection = 1000;
} // namespace } // namespace
Status ModelImpl::build_by_buffer_impl(const void *model_data, size_t data_size, ModelType model_type, Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const std::string &model_path) { const std::shared_ptr<Context> &model_context, const std::string &model_path) {
const void *model_buff = model_data; const void *model_buff = model_data;
size_t model_size = data_size; size_t model_size = data_size;
#ifndef _WIN32
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
MS_LOG(WARNING) << "Need runtime convert";
std::string plugin_path;
auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
if (ret != kSuccess) {
MS_LOG(WARNING) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
}
void *function = nullptr;
ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
if (ret != kSuccess) {
MS_LOG(WARNING) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
}
auto convert =
reinterpret_cast<void *(*)(const char *, const size_t &, size_t *, const std::shared_ptr<Context> &)>(function);
if (convert != nullptr) {
model_buff = convert(static_cast<const char *>(model_data), data_size, &model_size, model_context);
}
} else {
MS_LOG(WARNING) << "Not need runtime convert";
}
#endif
auto mindir_path = GetConfig("model_file", "mindir_path"); auto mindir_path = GetConfig("model_file", "mindir_path");
if (mindir_path == "") { if (mindir_path == "") {
// user does not set mindir_path, convert from model_path // user does not set mindir_path, convert from model_path
mindir_path = model_path.substr(0, model_path.rfind("/")); mindir_path = model_path.substr(0, model_path.rfind("/"));
} }
graph_ = std::make_shared<Graph>();
auto ret = mindspore::infer::Serialization::Load(model_buff, model_size, model_type, graph_.get(), Key{},
kDecModeAesGcm, mindir_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Serialization::Load model failed.";
return ret;
}
session_ = InferSession::CreateSession(model_context); session_ = InferSession::CreateSession(model_context);
if (session_ == nullptr) { if (session_ == nullptr) {
MS_LOG(ERROR) << "Create session failed."; MS_LOG(ERROR) << "Create session failed.";
return kLiteNullptr; return kLiteNullptr;
} }
ret = session_->Init(model_context); auto ret = session_->Init(model_context);
if (ret != kSuccess) { if (ret != kSuccess) {
MS_LOG(ERROR) << "Init session failed."; MS_LOG(ERROR) << "Init session failed.";
return ret; return ret;
@ -87,18 +60,57 @@ Status ModelImpl::build_by_buffer_impl(const void *model_data, size_t data_size,
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice)); device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
}); });
} }
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
return CompileGraphOnline(model_data, data_size, model_context);
}
graph_ = std::make_shared<Graph>();
ret = mindspore::infer::Serialization::Load(model_buff, model_size, model_type, graph_.get(), Key{}, kDecModeAesGcm,
mindir_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Serialization::Load model failed.";
return ret;
}
return session_->CompileGraph(graph_->graph_data_->GetFuncGraph()); return session_->CompileGraph(graph_->graph_data_->GetFuncGraph());
} }
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type, Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context) { const std::shared_ptr<Context> &model_context) {
return build_by_buffer_impl(model_data, data_size, model_type, model_context); return BuildByBufferImpl(model_data, data_size, model_type, model_context);
} }
Status ModelImpl::Build(const std::string &model_path, ModelType model_type, Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context) { const std::shared_ptr<Context> &model_context) {
auto buffer = ReadFile(model_path); auto buffer = ReadFile(model_path);
return this->build_by_buffer_impl(buffer.Data(), buffer.DataSize(), model_type, model_context, model_path); return BuildByBufferImpl(buffer.Data(), buffer.DataSize(), model_type, model_context, model_path);
}
Status ModelImpl::CompileGraphOnline(const void *model_data, size_t data_size,
const std::shared_ptr<Context> &model_context) {
MS_LOG(INFO) << "Need runtime convert";
std::string plugin_path;
auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
return kLiteError;
}
void *function = nullptr;
ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
if (ret != kSuccess) {
MS_LOG(WARNING) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
return kLiteError;
}
auto convert =
reinterpret_cast<mindspore::api::FuncGraphPtr (*)(const char *, const size_t &, const std::shared_ptr<Context> &)>(
function);
if (convert != nullptr) {
auto api_graph = convert(static_cast<const char *>(model_data), data_size, model_context);
auto impl = api_graph->impl();
auto inner_graph = std::dynamic_pointer_cast<FuncGraph>(impl);
return session_->CompileGraph(inner_graph);
} else {
MS_LOG(ERROR) << "convert is nullptr";
return kLiteNullptr;
}
} }
Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) { Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {

View File

@ -65,10 +65,9 @@ class ModelImpl {
std::string GetConfig(const std::string &section, const std::string &key); std::string GetConfig(const std::string &section, const std::string &key);
private: private:
Status build_by_buffer_impl(const void *model_data, size_t data_size, ModelType model_type, Status BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const std::string &model_path = ""); const std::shared_ptr<Context> &model_context, const std::string &model_path = "");
Status CompileGraphOnline(const void *model_data, size_t data_size, const std::shared_ptr<Context> &model_context);
private:
friend class Model; friend class Model;
friend class Serialization; friend class Serialization;
std::shared_ptr<Graph> graph_ = nullptr; std::shared_ptr<Graph> graph_ = nullptr;

View File

@ -676,6 +676,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
} }
return old_graph; return old_graph;
} }
if (param->is_runtime_converter) {
if (RunFormatTrans(old_graph) != RET_OK) {
MS_LOG(ERROR) << "Run format trans failed";
return nullptr;
}
}
if (RunPass(old_graph, param) != RET_OK) { if (RunPass(old_graph, param) != RET_OK) {
MS_LOG(ERROR) << "Proc online transform failed."; MS_LOG(ERROR) << "Proc online transform failed.";

View File

@ -21,6 +21,7 @@
#include <set> #include <set>
#include <tuple> #include <tuple>
#include <algorithm> #include <algorithm>
#include <utility>
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "tools/common/meta_graph_serializer.h" #include "tools/common/meta_graph_serializer.h"
#include "tools/lite_exporter/anf_exporter.h" #include "tools/lite_exporter/anf_exporter.h"
@ -217,6 +218,27 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::
return RET_OK; return RET_OK;
} }
FuncGraphPtr ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, const void *buff, const size_t &size) {
auto graph = BuildFuncGraph(param, buff, size);
MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Build func graph return nullptr.");
auto ret = SaveOutputNames(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "save output name failed.";
return nullptr;
}
MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
graph = funcgraph_transform_->Transform(graph, param);
MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr.");
graph->set_attr(kIsOptimized, MakeValue(true));
ret = UpdateFuncGraphInputAndOutputNames(graph);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
return nullptr;
}
return graph;
}
schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<ConverterPara> &param,
FuncGraphPtr func_graph) { FuncGraphPtr func_graph) {
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed"); MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
@ -570,13 +592,6 @@ int ConverterImpl::ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << "get value node failed"; MS_LOG(ERROR) << cnode->fullname_with_scope() << "get value node failed";
return RET_ERROR; return RET_ERROR;
} }
auto value = prim->GetAttr("infer_done");
if (value == nullptr) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " get infer_node attr failed";
return RET_ERROR;
}
bool infer_node = GetValue<bool>(value);
if (!infer_node) {
auto dynamic_shape_prim = std::make_shared<ops::DynamicShape>(); auto dynamic_shape_prim = std::make_shared<ops::DynamicShape>();
if (dynamic_shape_prim == nullptr) { if (dynamic_shape_prim == nullptr) {
MS_LOG(ERROR) << "Make DynamicShape op failed"; MS_LOG(ERROR) << "Make DynamicShape op failed";
@ -602,7 +617,30 @@ int ConverterImpl::ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph) {
} }
} }
} }
return RET_OK;
} }
int ConverterImpl::SaveOutputNames(const FuncGraphPtr &graph) {
std::vector<std::pair<AnfNodePtr, int64_t>> outputs;
std::vector<std::string> output_names;
std::vector<std::vector<int64_t>> output_dims;
auto ret = GetFuncGraphOutputsInfo(graph, &outputs, &output_names, &output_dims);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "Get outputs info of funcgraph failed.";
return RET_ERROR;
}
std::vector<std::string> update_output_names;
for (auto &it : outputs) {
if (utils::isa<mindspore::CNodePtr>(it.first)) {
auto cnode = it.first->cast<CNodePtr>();
MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
AbstractBasePtr abstract = cnode->abstract();
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr");
auto name = abstract->name();
update_output_names.emplace_back(name);
}
}
ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(update_output_names);
return RET_OK; return RET_OK;
} }

View File

@ -58,6 +58,7 @@ class ConverterImpl {
return FuncGraphConvert(param, graph, meta_graph, true, dst_buff, dst_size); return FuncGraphConvert(param, graph, meta_graph, true, dst_buff, dst_size);
} }
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph); int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph);
FuncGraphPtr Convert(const std::shared_ptr<ConverterPara> &param, const void *buff, const size_t &size);
private: private:
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param); FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param);
@ -73,6 +74,7 @@ class ConverterImpl {
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config); bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key); std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
int ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph); int ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph);
int SaveOutputNames(const FuncGraphPtr &graph);
protected: protected:
converter::ModelParser *model_parser_ = nullptr; converter::ModelParser *model_parser_ = nullptr;

View File

@ -58,6 +58,7 @@ struct ConverterPara {
bool pre_infer = false; bool pre_infer = false;
bool train_model = false; bool train_model = false;
bool no_fusion = false; bool no_fusion = false;
bool is_runtime_converter = false;
std::set<std::string> fusion_blacklists; std::set<std::string> fusion_blacklists;
// inner // inner

View File

@ -59,7 +59,7 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph,
if (value != nullptr) { if (value != nullptr) {
is_optimized = GetValue<bool>(value); is_optimized = GetValue<bool>(value);
} }
if (!is_optimized) { if (!is_optimized && !param->is_runtime_converter) {
auto mindir_adjust_pass = std::make_shared<MindirAdjust>(); auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr."); MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr.");
mindir_adjust_pass->SetFmkType(param->fmk_type); mindir_adjust_pass->SetFmkType(param->fmk_type);
@ -327,7 +327,7 @@ FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr<Co
if (value != nullptr) { if (value != nullptr) {
is_optimized = GetValue<bool>(value); is_optimized = GetValue<bool>(value);
} }
if (!is_optimized) { if (!is_optimized && !param->is_runtime_converter) {
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model); auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model);
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr."); MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr.");
if (!unify_format->Run(func_graph)) { if (!unify_format->Run(func_graph)) {

View File

@ -383,9 +383,6 @@ int MindIRSerializer::IfSaveTogether(bool *save_together) {
} }
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) { int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
if (isRuntimeConvert_) {
return RET_OK;
}
mind_ir::GraphProto *graph_proto = model_proto->mutable_graph(); mind_ir::GraphProto *graph_proto = model_proto->mutable_graph();
mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute(); mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
if (attr_proto != nullptr) { if (attr_proto != nullptr) {
@ -393,6 +390,9 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL); attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
attr_proto->set_i(is_fusion_); attr_proto->set_i(is_fusion_);
} }
if (isRuntimeConvert_) {
return RET_OK;
}
auto realpath = Common::CreatePrefixPath(output_file, true); auto realpath = Common::CreatePrefixPath(output_file, true);
if (!realpath.has_value()) { if (!realpath.has_value()) {