!42198 support online converter
Merge pull request !42198 from 周超/online_converter
This commit is contained in:
commit
2b2edc30f4
|
@ -19,8 +19,8 @@
|
|||
#include "tools/converter/converter.h"
|
||||
#include "tools/converter/cxx_api/converter_para.h"
|
||||
|
||||
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
|
||||
const std::shared_ptr<mindspore::Context> &context) {
|
||||
mindspore::api::FuncGraphPtr RuntimeConvert(const char *model_buf, const size_t &buf_size,
|
||||
const std::shared_ptr<mindspore::Context> &context) {
|
||||
auto param = std::make_shared<mindspore::ConverterPara>();
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "New ConverterPara failed";
|
||||
|
@ -33,6 +33,7 @@ void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_
|
|||
param->train_model = false;
|
||||
param->export_mindir = mindspore::kMindIR;
|
||||
param->enable_encryption = false;
|
||||
param->is_runtime_converter = true;
|
||||
|
||||
auto device_list = context->MutableDeviceInfo();
|
||||
for (auto &device : device_list) {
|
||||
|
@ -69,11 +70,11 @@ void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_
|
|||
}
|
||||
|
||||
mindspore::lite::ConverterImpl cvt;
|
||||
void *dst_buff;
|
||||
auto ret = cvt.Convert(param, nullptr, model_buf, buf_size, &dst_buff, dst_size);
|
||||
if (ret != mindspore::lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Convert model failed.";
|
||||
mindspore::FuncGraphPtr graph = cvt.Convert(param, model_buf, buf_size);
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert model failed";
|
||||
return nullptr;
|
||||
}
|
||||
return dst_buff;
|
||||
auto api_graph = mindspore::api::MakeShared<mindspore::api::FuncGraph>(graph);
|
||||
return api_graph;
|
||||
}
|
||||
|
|
|
@ -20,12 +20,13 @@
|
|||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/context.h"
|
||||
#include "mindapi/ir/func_graph.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *dst_size,
|
||||
const std::shared_ptr<mindspore::Context> &context);
|
||||
mindspore::api::FuncGraphPtr RuntimeConvert(const char *model_buf, const size_t &buf_size,
|
||||
const std::shared_ptr<mindspore::Context> &context);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_
|
||||
#define MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_
|
||||
#include <string>
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include <dlfcn.h>
|
||||
#include <dirent.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include "utils/file_utils.h"
|
||||
#include "include/api/status.h"
|
||||
|
@ -118,5 +118,16 @@ inline void DLSoClose(void *handle) {
|
|||
} \
|
||||
} while (false)
|
||||
} // namespace mindspore
|
||||
#else
|
||||
inline Status DLSoPath(const std::string &benchmark_so, const std::string &target_so, std::string *target_so_path) {
|
||||
MS_LOG(ERROR) << "Not support dlopen so";
|
||||
return kMEFailed;
|
||||
}
|
||||
|
||||
inline Status DLSoOpen(const std::string &dl_path, const std::string &func_name, void **handle, void **function,
|
||||
bool runtime_convert = false) {
|
||||
MS_LOG(ERROR) << "Not support dlopen so";
|
||||
return kMEFailed;
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_CXX_API_DLUTILS_H_
|
||||
|
|
|
@ -25,6 +25,8 @@
|
|||
#include "src/extendrt/convert/runtime_convert.h"
|
||||
#include "src/common/config_file.h"
|
||||
#include "src/extendrt/utils/serialization.h"
|
||||
#include "mindapi/ir/func_graph.h"
|
||||
#include "mindapi/base/base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
|
@ -33,50 +35,21 @@ constexpr size_t kMaxSectionNum = 100;
|
|||
constexpr size_t kMaxConfigNumPerSection = 1000;
|
||||
} // namespace
|
||||
|
||||
Status ModelImpl::build_by_buffer_impl(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const std::string &model_path) {
|
||||
Status ModelImpl::BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const std::string &model_path) {
|
||||
const void *model_buff = model_data;
|
||||
size_t model_size = data_size;
|
||||
#ifndef _WIN32
|
||||
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
|
||||
MS_LOG(WARNING) << "Need runtime convert";
|
||||
std::string plugin_path;
|
||||
auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(WARNING) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
|
||||
}
|
||||
void *function = nullptr;
|
||||
ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(WARNING) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
|
||||
}
|
||||
auto convert =
|
||||
reinterpret_cast<void *(*)(const char *, const size_t &, size_t *, const std::shared_ptr<Context> &)>(function);
|
||||
if (convert != nullptr) {
|
||||
model_buff = convert(static_cast<const char *>(model_data), data_size, &model_size, model_context);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Not need runtime convert";
|
||||
}
|
||||
#endif
|
||||
auto mindir_path = GetConfig("model_file", "mindir_path");
|
||||
if (mindir_path == "") {
|
||||
// user does not set mindir_path, convert from model_path
|
||||
mindir_path = model_path.substr(0, model_path.rfind("/"));
|
||||
}
|
||||
graph_ = std::make_shared<Graph>();
|
||||
auto ret = mindspore::infer::Serialization::Load(model_buff, model_size, model_type, graph_.get(), Key{},
|
||||
kDecModeAesGcm, mindir_path);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Serialization::Load model failed.";
|
||||
return ret;
|
||||
}
|
||||
session_ = InferSession::CreateSession(model_context);
|
||||
if (session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "Create session failed.";
|
||||
return kLiteNullptr;
|
||||
}
|
||||
ret = session_->Init(model_context);
|
||||
auto ret = session_->Init(model_context);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Init session failed.";
|
||||
return ret;
|
||||
|
@ -87,18 +60,57 @@ Status ModelImpl::build_by_buffer_impl(const void *model_data, size_t data_size,
|
|||
device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
|
||||
});
|
||||
}
|
||||
if (infer::mindir::MindirModelUtil::NeedRuntimeConvert(model_data, data_size)) {
|
||||
return CompileGraphOnline(model_data, data_size, model_context);
|
||||
}
|
||||
graph_ = std::make_shared<Graph>();
|
||||
ret = mindspore::infer::Serialization::Load(model_buff, model_size, model_type, graph_.get(), Key{}, kDecModeAesGcm,
|
||||
mindir_path);
|
||||
if (ret != kSuccess) {
|
||||
MS_LOG(ERROR) << "Serialization::Load model failed.";
|
||||
return ret;
|
||||
}
|
||||
return session_->CompileGraph(graph_->graph_data_->GetFuncGraph());
|
||||
}
|
||||
|
||||
// Builds the model from an in-memory buffer.
//
// Forwards to BuildByBufferImpl without a model path, so that helper's default (empty) path is
// used. The stale call to the pre-rename helper `build_by_buffer_impl` has been dropped; only the
// renamed helper is declared in the class.
Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType model_type,
                        const std::shared_ptr<Context> &model_context) {
  return BuildByBufferImpl(model_data, data_size, model_type, model_context);
}
|
||||
|
||||
// Builds the model from a file on disk.
//
// Reads the file at `model_path` into a buffer and forwards it to BuildByBufferImpl, passing the
// path along as well. The stale call to the pre-rename helper `build_by_buffer_impl` has been
// dropped; only the renamed helper is declared in the class.
Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
                        const std::shared_ptr<Context> &model_context) {
  auto buffer = ReadFile(model_path);
  return BuildByBufferImpl(buffer.Data(), buffer.DataSize(), model_type, model_context, model_path);
}
|
||||
|
||||
// Converts a model buffer at runtime through the converter plugin and compiles the result.
//
// Steps:
//   1. Resolve the path of libruntime_convert_plugin.so via DLSoPath.
//   2. Open it with DLSoOpen and look up the `RuntimeConvert` symbol (the library handle is kept
//      in `handle_`).
//   3. Call the symbol to obtain an api::FuncGraph, unwrap its implementation to a FuncGraph and
//      hand it to `session_->CompileGraph`.
//
// Params:
//   model_data    - raw model bytes, passed to the plugin as `const char *`.
//   data_size     - size of `model_data` in bytes.
//   model_context - context forwarded to the plugin.
// Returns:
//   kLiteError if the plugin cannot be located or opened, kLiteNullptr if the entry point or any
//   graph produced along the way is null, otherwise the status of `session_->CompileGraph`.
Status ModelImpl::CompileGraphOnline(const void *model_data, size_t data_size,
                                     const std::shared_ptr<Context> &model_context) {
  MS_LOG(INFO) << "Need runtime convert";
  std::string plugin_path;
  auto ret = DLSoPath("libmindspore-lite.so", "libruntime_convert_plugin.so", &plugin_path);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Get path of libruntime_convert_plugin.so failed. error: " << ret;
    return kLiteError;
  }
  void *function = nullptr;
  ret = DLSoOpen(plugin_path, "RuntimeConvert", &handle_, &function, true);
  if (ret != kSuccess) {
    // This failure aborts the build, so it is reported as an error rather than a warning.
    MS_LOG(ERROR) << "DLSoOpen RuntimeConvert failed, so path: " << plugin_path;
    return kLiteError;
  }
  using RuntimeConvertFunc =
    mindspore::api::FuncGraphPtr (*)(const char *, const size_t &, const std::shared_ptr<Context> &);
  auto convert = reinterpret_cast<RuntimeConvertFunc>(function);
  if (convert == nullptr) {
    MS_LOG(ERROR) << "convert is nullptr";
    return kLiteNullptr;
  }
  auto api_graph = convert(static_cast<const char *>(model_data), data_size, model_context);
  // The plugin returns nullptr when conversion fails; do not dereference it.
  if (api_graph == nullptr) {
    MS_LOG(ERROR) << "Runtime convert failed, converted graph is nullptr";
    return kLiteNullptr;
  }
  auto inner_graph = std::dynamic_pointer_cast<FuncGraph>(api_graph->impl());
  if (inner_graph == nullptr) {
    MS_LOG(ERROR) << "Cast converted graph to FuncGraph failed";
    return kLiteNullptr;
  }
  return session_->CompileGraph(inner_graph);
}
|
||||
|
||||
Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims) {
|
||||
|
|
|
@ -65,10 +65,9 @@ class ModelImpl {
|
|||
std::string GetConfig(const std::string &section, const std::string &key);
|
||||
|
||||
private:
|
||||
Status build_by_buffer_impl(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const std::string &model_path = "");
|
||||
|
||||
private:
|
||||
Status BuildByBufferImpl(const void *model_data, size_t data_size, ModelType model_type,
|
||||
const std::shared_ptr<Context> &model_context, const std::string &model_path = "");
|
||||
Status CompileGraphOnline(const void *model_data, size_t data_size, const std::shared_ptr<Context> &model_context);
|
||||
friend class Model;
|
||||
friend class Serialization;
|
||||
std::shared_ptr<Graph> graph_ = nullptr;
|
||||
|
|
|
@ -676,6 +676,12 @@ FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph,
|
|||
}
|
||||
return old_graph;
|
||||
}
|
||||
if (param->is_runtime_converter) {
|
||||
if (RunFormatTrans(old_graph) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run format trans failed";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
if (RunPass(old_graph, param) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Proc online transform failed.";
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <set>
|
||||
#include <tuple>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/meta_graph_serializer.h"
|
||||
#include "tools/lite_exporter/anf_exporter.h"
|
||||
|
@ -217,6 +218,27 @@ int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, schema::
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// Converts a model held in memory into an optimized FuncGraph, without serializing it.
//
// Steps: build the func graph from the buffer, record the graph's output names, run the
// func-graph transform, mark the result with the kIsOptimized attribute, then refresh the graph's
// input/output names.
//
// Params:
//   param - converter settings, forwarded to graph building and to the transform.
//   buff  - model bytes.
//   size  - size of `buff` in bytes.
// Returns:
//   the transformed graph, or nullptr if any step fails (an error is logged in that case).
FuncGraphPtr ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, const void *buff, const size_t &size) {
  MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "Input param is nullptr.");
  auto graph = BuildFuncGraph(param, buff, size);
  MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Build func graph return nullptr.");
  // Output names are captured before the transform runs.
  auto ret = SaveOutputNames(graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "save output name failed.";
    return nullptr;
  }

  MS_CHECK_TRUE_MSG(funcgraph_transform_ != nullptr, nullptr, "funcgraph_transform init failed");
  graph = funcgraph_transform_->Transform(graph, param);
  MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "Transform anf graph return nullptr.");
  graph->set_attr(kIsOptimized, MakeValue(true));
  ret = UpdateFuncGraphInputAndOutputNames(graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Update input and output names of funcgraph failed.";
    return nullptr;
  }
  return graph;
}
|
||||
|
||||
schema::MetaGraphT *ConverterImpl::TransferFuncGraph(const std::shared_ptr<ConverterPara> &param,
|
||||
FuncGraphPtr func_graph) {
|
||||
MS_CHECK_TRUE_MSG(metagraph_transform_ != nullptr, nullptr, "metagraph_transform_ init failed");
|
||||
|
@ -570,42 +592,58 @@ int ConverterImpl::ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph) {
|
|||
MS_LOG(ERROR) << cnode->fullname_with_scope() << "get value node failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto value = prim->GetAttr("infer_done");
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << cnode->fullname_with_scope() << " get infer_node attr failed";
|
||||
auto dynamic_shape_prim = std::make_shared<ops::DynamicShape>();
|
||||
if (dynamic_shape_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Make DynamicShape op failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
bool infer_node = GetValue<bool>(value);
|
||||
if (!infer_node) {
|
||||
auto dynamic_shape_prim = std::make_shared<ops::DynamicShape>();
|
||||
if (dynamic_shape_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Make DynamicShape op failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto dynamic_shape_prim_c = dynamic_shape_prim->GetPrim();
|
||||
if (dynamic_shape_prim_c == nullptr) {
|
||||
MS_LOG(ERROR) << "Get the primitive of dynamic shape op failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
auto dynamic_shape_node = graph->NewCNode(dynamic_shape_prim_c, inputs);
|
||||
dynamic_shape_node->set_abstract(ori_abstract);
|
||||
auto manager = Manage(graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!manager->Replace(cnode, dynamic_shape_node)) {
|
||||
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto dynamic_shape_prim_c = dynamic_shape_prim->GetPrim();
|
||||
if (dynamic_shape_prim_c == nullptr) {
|
||||
MS_LOG(ERROR) << "Get the primitive of dynamic shape op failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
inputs.erase(inputs.begin());
|
||||
auto dynamic_shape_node = graph->NewCNode(dynamic_shape_prim_c, inputs);
|
||||
dynamic_shape_node->set_abstract(ori_abstract);
|
||||
auto manager = Manage(graph, true);
|
||||
if (manager == nullptr) {
|
||||
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!manager->Replace(cnode, dynamic_shape_node)) {
|
||||
MS_LOG(ERROR) << "Replace shape node " << cnode->fullname_with_scope() << " failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConverterImpl::SaveOutputNames(const FuncGraphPtr &graph) {
|
||||
std::vector<std::pair<AnfNodePtr, int64_t>> outputs;
|
||||
std::vector<std::string> output_names;
|
||||
std::vector<std::vector<int64_t>> output_dims;
|
||||
auto ret = GetFuncGraphOutputsInfo(graph, &outputs, &output_names, &output_dims);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get outputs info of funcgraph failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<std::string> update_output_names;
|
||||
for (auto &it : outputs) {
|
||||
if (utils::isa<mindspore::CNodePtr>(it.first)) {
|
||||
auto cnode = it.first->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr");
|
||||
AbstractBasePtr abstract = cnode->abstract();
|
||||
MS_CHECK_TRUE_MSG(abstract != nullptr, RET_ERROR, "abstract is nullptr");
|
||||
auto name = abstract->name();
|
||||
update_output_names.emplace_back(name);
|
||||
}
|
||||
}
|
||||
ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(update_output_names);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
|
||||
if (param != nullptr) {
|
||||
std::set valid_values = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx,
|
||||
|
|
|
@ -58,6 +58,7 @@ class ConverterImpl {
|
|||
return FuncGraphConvert(param, graph, meta_graph, true, dst_buff, dst_size);
|
||||
}
|
||||
int Convert(const std::shared_ptr<ConverterPara> &param, schema::MetaGraphT **meta_graph, FuncGraphPtr func_graph);
|
||||
FuncGraphPtr Convert(const std::shared_ptr<ConverterPara> &param, const void *buff, const size_t &size);
|
||||
|
||||
private:
|
||||
FuncGraphPtr BuildFuncGraph(const std::shared_ptr<ConverterPara> &param);
|
||||
|
@ -73,6 +74,7 @@ class ConverterImpl {
|
|||
bool CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config);
|
||||
std::string GetStrFromConfigFile(const std::string &file, const std::string &target_key);
|
||||
int ReplaceShapeWithDynamicShape(const FuncGraphPtr &graph);
|
||||
int SaveOutputNames(const FuncGraphPtr &graph);
|
||||
|
||||
protected:
|
||||
converter::ModelParser *model_parser_ = nullptr;
|
||||
|
|
|
@ -58,6 +58,7 @@ struct ConverterPara {
|
|||
bool pre_infer = false;
|
||||
bool train_model = false;
|
||||
bool no_fusion = false;
|
||||
bool is_runtime_converter = false;
|
||||
std::set<std::string> fusion_blacklists;
|
||||
|
||||
// inner
|
||||
|
|
|
@ -59,7 +59,7 @@ STATUS MindsporeImporter::Mindir2AnfAdjust(const FuncGraphPtr &func_graph,
|
|||
if (value != nullptr) {
|
||||
is_optimized = GetValue<bool>(value);
|
||||
}
|
||||
if (!is_optimized) {
|
||||
if (!is_optimized && !param->is_runtime_converter) {
|
||||
auto mindir_adjust_pass = std::make_shared<MindirAdjust>();
|
||||
MS_CHECK_TRUE_MSG(mindir_adjust_pass != nullptr, RET_NULL_PTR, "mindir_adjust_pass is nullptr.");
|
||||
mindir_adjust_pass->SetFmkType(param->fmk_type);
|
||||
|
@ -327,7 +327,7 @@ FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const std::shared_ptr<Co
|
|||
if (value != nullptr) {
|
||||
is_optimized = GetValue<bool>(value);
|
||||
}
|
||||
if (!is_optimized) {
|
||||
if (!is_optimized && !param->is_runtime_converter) {
|
||||
auto unify_format = std::make_shared<UnifyFormatToNHWC>(converter::kFmkTypeMs, param->train_model);
|
||||
MS_CHECK_TRUE_MSG(unify_format != nullptr, nullptr, "unify_format is nullptr.");
|
||||
if (!unify_format->Run(func_graph)) {
|
||||
|
|
|
@ -383,9 +383,6 @@ int MindIRSerializer::IfSaveTogether(bool *save_together) {
|
|||
}
|
||||
|
||||
int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const std::string &output_file) {
|
||||
if (isRuntimeConvert_) {
|
||||
return RET_OK;
|
||||
}
|
||||
mind_ir::GraphProto *graph_proto = model_proto->mutable_graph();
|
||||
mind_ir::AttributeProto *attr_proto = graph_proto->add_attribute();
|
||||
if (attr_proto != nullptr) {
|
||||
|
@ -393,6 +390,9 @@ int MindIRSerializer::SaveProtoToFile(mind_ir::ModelProto *model_proto, const st
|
|||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_BOOL);
|
||||
attr_proto->set_i(is_fusion_);
|
||||
}
|
||||
if (isRuntimeConvert_) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto realpath = Common::CreatePrefixPath(output_file, true);
|
||||
if (!realpath.has_value()) {
|
||||
|
|
Loading…
Reference in New Issue