!42198 support online converter

Merge pull request !42198 from 周超/online_converter
This commit is contained in:
i-robot 2022-09-17 07:27:34 +00:00 committed by Gitee
commit 2b2edc30f4
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 152 additions and 81 deletions

View File

@ -19,8 +19,8 @@
#include "tools/converter/converter.h"
#include "tools/converter/cxx_api/converter_para.h"
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
const std::shared_ptr<mindspore::Context> &context) {
mindspore::api::FuncGraphPtr RuntimeConvert(const char *model_buf, const size_t &buf_size,
const std::shared_ptr<mindspore::Context> &context) {
auto param = std::make_shared<mindspore::ConverterPara>();
if (param == nullptr) {
MS_LOG(ERROR) << "New ConverterPara failed";
@ -33,6 +33,7 @@ void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_
param->train_model = false;
param->export_mindir = mindspore::kMindIR;
param->enable_encryption = false;
param->is_runtime_converter = true;
auto device_list = context->MutableDeviceInfo();
for (auto &device : device_list) {
@ -69,11 +70,11 @@ void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_
}
mindspore::lite::ConverterImpl cvt;
void *dst_buff;
auto ret = cvt.Convert(param, nullptr, model_buf, buf_size, &dst_buff, dst_size);
if (ret != mindspore::lite::RET_OK) {
MS_LOG(ERROR) << "Convert model failed.";
mindspore::FuncGraphPtr graph = cvt.Convert(param, model_buf, buf_size);
if (graph == nullptr) {
MS_LOG(ERROR) << "Convert model failed";
return nullptr;
}
return dst_buff;
auto api_graph = mindspore::api::MakeShared<mindspore::api::FuncGraph>(graph);
return api_graph;
}

View File

@ -20,12 +20,13 @@
#include <string>
#include <memory>
#include "include/api/context.h"
#include "mindapi/ir/func_graph.h"
#ifdef __cplusplus
extern "C" {
#endif
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
const std::shared_ptr<mindspore::Context> &context);
mindspore::api::FuncGraphPtr RuntimeConvert(const char *model_buf, const size_t &buf_size,
const std::shared_ptr<mindspore::Context> &context);
#ifdef __cplusplus
}
#endif

View File

@ -15,11 +15,11 @@
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_
#include <string>
#if !defined(_WIN32) && !defined(_WIN64)
#include <dlfcn.h>
#include <dirent.h>
#include <memory>
#include <string>
#include <fstream>
#include "utils/file_utils.h"
#include "include/api/status.h"
@ -118,5 +118,16 @@ inline void DLSoClose(void *handle) {
} \
} while (false)
} // namespace mindspore
#else
inline Status DLSoPath(const std::string &benchmark_so, const std::string &target_so, std::string *target_so_path) {
MS_LOG(ERROR) << "Not support dlopen so";
return kMEFailed;
}
inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function,
bool runtime_convert = false) {
MS_LOG(ERROR) << "Not support dlopen so";
return kMEFailed;
}
#endif
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_

View File

@ -25,6 +25,8 @@
#include "src/extendrt/convert/runtime_convert.h"
#include "src/common/config_file.h"
#include "src/extendrt/utils/serialization.h"
#include "mindapi/ir/func_graph.h"
#include "mindapi/base/base.h"
namespace mindspore {
namespace {
@ -33,50 +35,21 @@ constexpr size_t kMaxSectionNum = 100;
constexpr size_t kMaxConfigNumPerSection = 1000;
} // namespace
Status ModelImpl::build_by_buffer_impl(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const std::string &model_path) {
Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const std::string &model_path) {
const void *model_buff = model_data;
size_t model_size = data_size;
#ifndef _WIN32
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
MS_LOG(WARNING) << "Need runtime convert";
std::string plugin_path;
auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
if (ret != kSuccess) {
MS_LOG(WARNING) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
}
void *function = nullptr;
ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
if (ret != kSuccess) {
MS_LOG(WARNING) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
}
auto convert =
reinterpret_cast<void *(*)(const char *, const size_t &, size_t *, const std::shared_ptr<Context> &)>(function);
if (convert != nullptr) {
model_buff = convert(static_cast<const char *>(model_data), data_size, &model_size, model_context);
}
} else {
MS_LOG(WARNING) << "Not need runtime convert";
}
#endif
auto mindir_path = GetConfig("model_file", "mindir_path");
if (mindir_path == "") {
// user does not set mindir_path, convert from model_path
mindir_path = model_path.substr(0, model_path.rfind("/"));
}
graph_ = std::make_shared<Graph>();
auto ret = mindspore::infer::Serialization::Load(model_buff, model_size, model_type, graph_.get(), Key{},
kDecModeAesGcm, mindir_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Serialization::Load model failed.";
return ret;
}
session_ = InferSession::CreateSession(model_context);
if (session_ == nullptr) {
MS_LOG(ERROR) << "Create session failed.";
return kLiteNullptr;
}
ret = session_->Init(model_context);
auto ret = session_->Init(model_context);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Init session failed.";
return ret;
@ -87,18 +60,57 @@ Status ModelImpl::build_by_buffer_impl(const void *model_data, size_t data_size,
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
});
}
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
return CompileGraphOnline(model_data, data_size, model_context);
}
graph_ = std::make_shared<Graph>();
ret = mindspore::infer::Serialization::Load(model_buff, model_size, model_type, graph_.get(), Key{}, kDecModeAesGcm,
mindir_path);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Serialization::Load model failed.";
return ret;
}
return session_->CompileGraph(graph_->graph_data_->GetFuncGraph());
}
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context) {
return build_by_buffer_impl(model_data, data_size, model_type, model_context);
return BuildByBufferImpl(model_data, data_size, model_type, model_context);
}
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
const std::shared_ptr<Context> &model_context) {
auto buffer = ReadFile(model_path);
return this->build_by_buffer_impl(buffer.Data(), buffer.DataSize(), model_type, model_context, model_path);
return BuildByBufferImpl(buffer.Data(), buffer.DataSize(), model_type, model_context, model_path);
}
// Converts the model buffer at runtime through the "RuntimeConvert" entry point of
// libruntime_convert_plugin.so and compiles the resulting func graph on session_.
//
// model_data / data_size: raw model buffer handed to the plugin.
// model_context: context forwarded to the plugin.
//
// Returns the status of session_->CompileGraph on success. Returns kLiteError when the plugin
// cannot be located or opened, and kLiteNullptr when the entry point, the converted graph or its
// inner FuncGraph is missing.
Status ModelImpl::CompileGraphOnline(const void *model_data, size_t data_size,
                                     const std::shared_ptr<Context> &model_context) {
  MS_LOG(INFO) << "Need runtime convert";
  std::string plugin_path;
  auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
    return kLiteError;
  }
  void *function = nullptr;
  ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
  if (ret != kSuccess) {
    // This path aborts the build, so it is reported as an error like the DLSoPath failure above.
    MS_LOG(ERROR) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
    return kLiteError;
  }
  auto convert =
    reinterpret_cast<mindspore::api::FuncGraphPtr (*)(const char *, const size_t &, const std::shared_ptr<Context> &)>(
      function);
  if (convert == nullptr) {
    MS_LOG(ERROR) << "convert is nullptr";
    return kLiteNullptr;
  }
  // RuntimeConvert returns nullptr when conversion fails, so the result must be checked before use.
  auto api_graph = convert(static_cast<const char *>(model_data), data_size, model_context);
  if (api_graph == nullptr) {
    MS_LOG(ERROR) << "Runtime convert failed, converted graph is nullptr";
    return kLiteNullptr;
  }
  // The cast yields nullptr when impl() is empty or is not a FuncGraph.
  auto inner_graph = std::dynamic_pointer_cast<FuncGraph>(api_graph->impl());
  if (inner_graph == nullptr) {
    MS_LOG(ERROR) << "Get inner func graph from converted graph failed";
    return kLiteNullptr;
  }
  return session_->CompileGraph(inner_graph);
}
Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {

View File

@ -65,10 +65,9 @@ class ModelImpl {
std::string GetConfig(const std::string &section, const std::string &key);
private:
Status build_by_buffer_impl(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const std::string &model_path = "");
private:
Status BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
const std::shared_ptr<Context> &model_context, const std::string &model_path = "");
Status CompileGraphOnline(const void *model_data, size_t data_size, const std::shared_ptr<Context> &model_context);
friend class Model;
friend class Serialization;
std::shared_ptr<Graph> graph_ = nullptr;

View File

@ -676,6 +676,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
}
return old_graph;
}
if (param->is_runtime_converter) {
if (RunFormatTrans(old_graph) != RET_OK) {
MS_LOG(ERROR) << "Run format trans failed";
return nullptr;
}
}
if (RunPass(old_graph, param) != RET_OK) {
MS_LOG(ERROR) << "Proc online transform failed.";

View File

@ -21,6 +21,7 @@
#include <set>
#include <tuple>
#include <algorithm>
#include <utility>
#include "src/common/log_adapter.h"
#include "tools/common/meta_graph_serializer.h"
#include "tools/lite_exporter/anf_exporter.h"
@ -217,6 +218,27 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::
return RET_OK;
}
// Builds a func graph from an in-memory model buffer and returns the transformed graph.
//
// Steps: build the graph from buff/size, record its output names, run funcgraph_transform_ on it,
// set the kIsOptimized attribute to true, then update the graph's input and output names.
// Returns nullptr as soon as any step fails.
FuncGraphPtr ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, const void *buff, const size_t &size) {
  auto func_graph = BuildFuncGraph(param, buff, size);
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "Build func graph return nullptr.");

  // Output names are captured before the transform runs.
  if (SaveOutputNames(func_graph) != RET_OK) {
    MS_LOG(ERROR) << "save output name failed.";
    return nullptr;
  }

  MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
  func_graph = funcgraph_transform_->Transform(func_graph, param);
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "Transform anf graph return nullptr.");
  func_graph->set_attr(kIsOptimized, MakeValue(true));

  if (UpdateFuncGraphInputAndOutputNames(func_graph) != RET_OK) {
    MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
    return nullptr;
  }
  return func_graph;
}
schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<ConverterPara> &param,
FuncGraphPtr func_graph) {
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
@ -570,42 +592,58 @@ int ConverterImpl::ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << "get value node failed";
return RET_ERROR;
}
auto value = prim->GetAttr("infer_done");
if (value == nullptr) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " get infer_node attr failed";
auto dynamic_shape_prim = std::make_shared<ops::DynamicShape>();
if (dynamic_shape_prim == nullptr) {
MS_LOG(ERROR) << "Make DynamicShape op failed";
return RET_ERROR;
}
bool infer_node = GetValue<bool>(value);
if (!infer_node) {
auto dynamic_shape_prim = std::make_shared<ops::DynamicShape>();
if (dynamic_shape_prim == nullptr) {
MS_LOG(ERROR) << "Make DynamicShape op failed";
return RET_ERROR;
}
auto dynamic_shape_prim_c = dynamic_shape_prim->GetPrim();
if (dynamic_shape_prim_c == nullptr) {
MS_LOG(ERROR) << "Get the primitive of dynamic shape op failed";
return RET_ERROR;
}
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
auto dynamic_shape_node = graph->NewCNode(dynamic_shape_prim_c, inputs);
dynamic_shape_node->set_abstract(ori_abstract);
auto manager = Manage(graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
return RET_ERROR;
}
if (!manager->Replace(cnode, dynamic_shape_node)) {
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
return RET_ERROR;
}
auto dynamic_shape_prim_c = dynamic_shape_prim->GetPrim();
if (dynamic_shape_prim_c == nullptr) {
MS_LOG(ERROR) << "Get the primitive of dynamic shape op failed";
return RET_ERROR;
}
auto inputs = cnode->inputs();
inputs.erase(inputs.begin());
auto dynamic_shape_node = graph->NewCNode(dynamic_shape_prim_c, inputs);
dynamic_shape_node->set_abstract(ori_abstract);
auto manager = Manage(graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
return RET_ERROR;
}
if (!manager->Replace(cnode, dynamic_shape_node)) {
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
return RET_ERROR;
}
}
}
return RET_OK;
}
// Stores the names of the graph's CNode outputs in ConverterInnerContext.
//
// Output nodes are obtained from GetFuncGraphOutputsInfo. For every output that is a CNode, the
// name of its abstract is collected; other outputs are skipped. The collected list is passed to
// SetGraphOutputTensorNames.
// Returns RET_OK on success, or RET_ERROR when the output info cannot be read or a CNode output
// has no abstract.
int ConverterImpl::SaveOutputNames(const FuncGraphPtr &graph) {
  std::vector<std::pair<AnfNodePtr, int64_t>> graph_outputs;
  std::vector<std::string> origin_names;
  std::vector<std::vector<int64_t>> origin_dims;
  if (GetFuncGraphOutputsInfo(graph, &graph_outputs, &origin_names, &origin_dims) != lite::RET_OK) {
    MS_LOG(ERROR) << "Get outputs info of funcgraph failed.";
    return RET_ERROR;
  }

  std::vector<std::string> tensor_names;
  for (auto &output : graph_outputs) {
    auto &node = output.first;
    if (!utils::isa<mindspore::CNodePtr>(node)) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
    AbstractBasePtr abstract = cnode->abstract();
    MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr");
    tensor_names.emplace_back(abstract->name());
  }

  ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(tensor_names);
  return RET_OK;
}
int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
if (param != nullptr) {
std::set valid_values = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx,

View File

@ -58,6 +58,7 @@ class ConverterImpl {
return FuncGraphConvert(param, graph, meta_graph, true, dst_buff, dst_size);
}
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph);
FuncGraphPtr Convert(const std::shared_ptr<ConverterPara> &param, const void *buff, const size_t &size);
private:
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param);
@ -73,6 +74,7 @@ class ConverterImpl {
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
int ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph);
int SaveOutputNames(const FuncGraphPtr &graph);
protected:
converter::ModelParser *model_parser_ = nullptr;

View File

@ -58,6 +58,7 @@ struct ConverterPara {
bool pre_infer = false;
bool train_model = false;
bool no_fusion = false;
bool is_runtime_converter = false;
std::set<std::string> fusion_blacklists;
// inner

View File

@ -59,7 +59,7 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph,
if (value != nullptr) {
is_optimized = GetValue<bool>(value);
}
if (!is_optimized) {
if (!is_optimized && !param->is_runtime_converter) {
auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr.");
mindir_adjust_pass->SetFmkType(param->fmk_type);
@ -327,7 +327,7 @@ FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr<Co
if (value != nullptr) {
is_optimized = GetValue<bool>(value);
}
if (!is_optimized) {
if (!is_optimized && !param->is_runtime_converter) {
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model);
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr.");
if (!unify_format->Run(func_graph)) {

View File

@ -383,9 +383,6 @@ int MindIRSerializer::IfSaveTogether(bool *save_together) {
}
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
if (isRuntimeConvert_) {
return RET_OK;
}
mind_ir::GraphProto *graph_proto = model_proto->mutable_graph();
mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
if (attr_proto != nullptr) {
@ -393,6 +390,9 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
attr_proto->set_i(is_fusion_);
}
if (isRuntimeConvert_) {
return RET_OK;
}
auto realpath = Common::CreatePrefixPath(output_file, true);
if (!realpath.has_value()) {