forked from mindspore-Ecosystem/mindspore
!47992 [MS][LITE] Support large model inference.
Merge pull request !47992 from youshu/ys_large_model_impl2
This commit is contained in:
commit
9f0b6227dd
|
@ -40,6 +40,17 @@ bool Common::NeedMapping(const std::string &origin_name) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds a string of `str_len` characters, each drawn uniformly from 'a'..'z'.
//
// A std::mt19937 engine is created and seeded from std::random_device on every call,
// so consecutive calls do not share generator state.
//
// @param str_len Number of characters to generate; 0 yields an empty string.
// @return The generated lowercase string.
std::string Common::GetRandomStr(size_t str_len) {
  std::random_device seed_source;
  std::mt19937 engine(seed_source());
  std::uniform_int_distribution<int> letter_picker{'a', 'z'};

  std::string result;
  result.reserve(str_len);
  for (size_t i = 0; i < str_len; ++i) {
    // The distribution yields an int in ['a', 'z']; store it as a character.
    result.push_back(static_cast<char>(letter_picker(engine)));
  }
  return result;
}
|
||||||
|
|
||||||
std::string Common::GetRandomStr() {
|
std::string Common::GetRandomStr() {
|
||||||
std::string npy_suffix = ".npy";
|
std::string npy_suffix = ".npy";
|
||||||
#ifndef _MSC_VER
|
#ifndef _MSC_VER
|
||||||
|
|
|
@ -22,6 +22,12 @@
|
||||||
#include "acl/acl_rt.h"
|
#include "acl/acl_rt.h"
|
||||||
#include "cxx_api/model/aoe/auto_tune_process.h"
|
#include "cxx_api/model/aoe/auto_tune_process.h"
|
||||||
#include "plugin/device/ascend/optimizer/ge_optimization.h"
|
#include "plugin/device/ascend/optimizer/ge_optimization.h"
|
||||||
|
#if defined(ENABLE_CLOUD_FUSION_INFERENCE)
|
||||||
|
#include "tools/mindir_exporter/mindir_serializer.h"
|
||||||
|
#include "extendrt/cxx_api/file_utils.h"
|
||||||
|
#include "include/common/debug/common.h"
|
||||||
|
#include "load_mindir/load_model.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -182,6 +188,105 @@ Status ModelConverter::SaveModel(const ge::ModelBufferData &model) const {
|
||||||
return kSuccess;
|
return kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(ENABLE_CLOUD_FUSION_INFERENCE)
|
||||||
|
// Converts a FuncGraph into an OM model buffer by way of a MindIR file exported under /tmp.
//
// Steps:
//   1. Serialize `func_graph` to "/tmp/<random 12 letters>" through lite::MindIRSerialize
//      (file output only, no in-memory buffer is requested).
//   2. Hand two callbacks to MultiProcess::MainProcess:
//        - the child callback reads the exported file back, loads it as a FuncGraph, converts it to
//          AscendIR, runs AOE tuning, builds the OM model and sends the bytes to the parent;
//        - the parent callback receives those bytes into the buffer returned by this function.
//
// @param func_graph Graph to export and convert.
// @return The OM model bytes received from the child. If the export fails an empty buffer is
//         returned; if MainProcess fails an error is logged and the buffer is returned as it is.
Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
  MultiProcess multi_process;
  Buffer buffer_ret;
  const size_t file_name_len = 12;
  const std::string rand_file_name = Common::GetRandomStr(file_name_len);
  const std::string dir_prefix = "/tmp/";
  const std::string out_graph_path = dir_prefix + rand_file_name;
  // NOTE(review): nothing in this function removes the file(s) exported under /tmp, and the name is
  // only a random string - confirm whether cleanup / exclusive creation is handled elsewhere.

  auto param = std::make_shared<ConverterPara>();
  param->output_file = out_graph_path;
  int ret = lite::MindIRSerialize(param, func_graph, false, nullptr, nullptr);
  if (ret != lite::RET_OK) {
    MS_LOG(ERROR) << "Export to mindir failed";
    return buffer_ret;
  }

  auto parent_process = [&buffer_ret](MultiProcess *multi_process) -> Status {
    MS_EXCEPTION_IF_NULL(multi_process);

    // receive convert model result from child
    CreateBufferCall call = [&buffer_ret](size_t msg_len) -> uint8_t * {
      // Size the result buffer to the announced message length and expose its storage.
      (void)buffer_ret.ResizeData(msg_len);
      return static_cast<uint8_t *>(buffer_ret.MutableData());
    };
    auto status = multi_process->ReceiveMsg(call);
    if (status != kSuccess) {
      MS_LOG_ERROR << "Receive result model from child process failed";
      return status;
    }
    return kSuccess;
  };
  auto child_process = [this, &dir_prefix, &out_graph_path](MultiProcess *multi_process) -> Status {
    MS_EXCEPTION_IF_NULL(multi_process);

    char real_path_mem[PATH_MAX] = {0};

    // For split saved models, model name has _graph suffix.
    std::string model_file_path_regular = out_graph_path + ".mindir";
    std::string model_file_path_split = out_graph_path + "_graph.mindir";
    std::string selected_file_path = model_file_path_regular;

    // Check if model with regular naming exists, choose split naming if not.
    // realpath() is used purely as an existence probe; the resolved path itself is not needed.
    auto real_path_ret = realpath(common::SafeCStr(model_file_path_regular), real_path_mem);
    if (real_path_ret == nullptr) {
      selected_file_path = model_file_path_split;
    }

    // Load model by file
    Buffer buffer = ReadFile(selected_file_path);
    if (buffer.Data() == nullptr || buffer.DataSize() == 0) {
      // Fail early with a precise message instead of handing an empty buffer to the loader.
      MS_LOG(ERROR) << "Read exported MindIR file failed, file: " << selected_file_path;
      return kMCFailed;
    }
    MindIRLoader mindir_loader(true, nullptr, 0, kDecModeAesGcm, false);
    auto func_graph = mindir_loader.LoadMindIR(buffer.Data(), buffer.DataSize(), dir_prefix);
    if (func_graph == nullptr) {
      MS_LOG(ERROR) << "Fail to load model, function graph is nullptr.";
      return kMCFailed;
    }

    auto df_graph = ConvertFuncGraphToAIR(func_graph);
    if (df_graph == nullptr) {
      MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
      return kMCFailed;
    }

    // Init/build options stay empty when the options object is no longer alive.
    std::map<std::string, std::string> init_options;
    std::map<std::string, std::string> build_options;
    auto option = options_.lock();
    if (option != nullptr) {
      std::tie(init_options, build_options) = option->GenAclOptions();
    }
    if (AutoTuneProcess::AoeOfflineTurningGraph(options_, df_graph) != kSuccess) {
      MS_LOG(ERROR) << "Aoe tune graph failed.";
      return kMCFailed;
    }
    Buffer model_result = BuildAirModel(df_graph, init_options, build_options);
    if (model_result.DataSize() == 0) {
      MS_LOG(ERROR) << "Convert model from MindIR to OM failed";
      return kMCFailed;
    }

    // send result model to parent
    auto status = multi_process->SendMsg(model_result.Data(), model_result.DataSize());
    if (status != kSuccess) {
      MS_LOG_ERROR << "Send result model to parent process failed";
      return status;
    }
    return kSuccess;
  };
  ClearCurrentRtCtx();
  auto status = multi_process.MainProcess(parent_process, child_process);
  if (status != kSuccess) {
    MS_LOG_ERROR << "Convert MindIR model to OM model failed";
  } else {
    MS_LOG_INFO << "Convert MindIR model to OM model success";
  }
  return buffer_ret;
}
|
||||||
|
|
||||||
|
#else
|
||||||
Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
|
Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
|
||||||
MultiProcess multi_process;
|
MultiProcess multi_process;
|
||||||
Buffer buffer_ret;
|
Buffer buffer_ret;
|
||||||
|
@ -254,6 +359,7 @@ Buffer ModelConverter::LoadMindIR(const FuncGraphPtr &func_graph) {
|
||||||
}
|
}
|
||||||
return buffer_ret;
|
return buffer_ret;
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
|
Buffer ModelConverter::LoadAscendIRInner(const Buffer &model_data) {
|
||||||
ge::Model load_model = ge::Model("loadmodel", "version2");
|
ge::Model load_model = ge::Model("loadmodel", "version2");
|
||||||
|
|
|
@ -202,7 +202,8 @@ Status MultiProcess::ReceiveMsg(const CreateBufferCall &create_buffer_call) cons
|
||||||
msg_buffer = create_buffer_call(msg_len);
|
msg_buffer = create_buffer_call(msg_len);
|
||||||
}
|
}
|
||||||
MS_EXCEPTION_IF_NULL(msg_buffer);
|
MS_EXCEPTION_IF_NULL(msg_buffer);
|
||||||
auto ret = memcpy_s(msg_buffer + cur_offset, msg_len - cur_offset, shmat_data_addr_, receive_msg_->msg_len);
|
size_t destMax = std::min(shmat_data_max_size_, msg_len - cur_offset);
|
||||||
|
auto ret = memcpy_s(msg_buffer + cur_offset, destMax, shmat_data_addr_, receive_msg_->msg_len);
|
||||||
if (ret != EOK) {
|
if (ret != EOK) {
|
||||||
MS_LOG(INFO) << "memcpy_s failed, ret = " << ret;
|
MS_LOG(INFO) << "memcpy_s failed, ret = " << ret;
|
||||||
return kMEFailed;
|
return kMEFailed;
|
||||||
|
|
|
@ -37,6 +37,7 @@ class COMMON_EXPORT Common {
|
||||||
~Common() = default;
|
~Common() = default;
|
||||||
static bool NeedMapping(const std::string &origin_name);
|
static bool NeedMapping(const std::string &origin_name);
|
||||||
static std::string GetRandomStr();
|
static std::string GetRandomStr();
|
||||||
|
static std::string GetRandomStr(size_t str_len);
|
||||||
static bool MappingName(const std::string &input_path, std::optional<std::string> *prefix_path,
|
static bool MappingName(const std::string &input_path, std::optional<std::string> *prefix_path,
|
||||||
std::optional<std::string> *origin_name, std::optional<std::string> *mapped_name);
|
std::optional<std::string> *origin_name, std::optional<std::string> *mapped_name);
|
||||||
static std::optional<std::string> CreatePrefixPath(const std::string &input_path,
|
static std::optional<std::string> CreatePrefixPath(const std::string &input_path,
|
||||||
|
|
|
@ -24,6 +24,12 @@ static const TraceLabelType global_trace_type = (common::GetEnv("MS_DEV_TRACE_LA
|
||||||
? TraceLabelType::kWithUniqueId
|
? TraceLabelType::kWithUniqueId
|
||||||
: TraceLabelType::kShortSymbol;
|
: TraceLabelType::kShortSymbol;
|
||||||
TraceLabelType GetGlobalTraceLabelType() { return global_trace_type; }
|
TraceLabelType GetGlobalTraceLabelType() { return global_trace_type; }
|
||||||
|
// Returns the trace label type selected by the MS_DEV_TRACE_LABEL_WITH_UNIQUE_ID environment
// variable, which is queried at the time of the call.
//
// @return TraceLabelType::kWithUniqueId when the variable equals "1",
//         TraceLabelType::kShortSymbol otherwise.
TraceLabelType GetCurrentTraceLabelType() {
  const bool unique_id_requested = (common::GetEnv("MS_DEV_TRACE_LABEL_WITH_UNIQUE_ID") == "1");
  return unique_id_requested ? TraceLabelType::kWithUniqueId : TraceLabelType::kShortSymbol;
}
|
||||||
|
|
||||||
struct NameWithTrace {
|
struct NameWithTrace {
|
||||||
std::string name;
|
std::string name;
|
||||||
|
@ -106,7 +112,8 @@ std::string CombineUniqueID(const DebugInfoPtr &debug_info) {
|
||||||
std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
|
std::string LabelStringUnique(const DebugInfoPtr &debug_info) { return CombineUniqueID(debug_info); }
|
||||||
|
|
||||||
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
|
std::string Label(const DebugInfoPtr &debug_info, TraceLabelType trace_label) {
|
||||||
if (GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) {
|
if ((GetGlobalTraceLabelType() == TraceLabelType::kWithUniqueId) ||
|
||||||
|
(GetCurrentTraceLabelType() == TraceLabelType::kWithUniqueId)) {
|
||||||
return LabelStringUnique(debug_info);
|
return LabelStringUnique(debug_info);
|
||||||
}
|
}
|
||||||
return LabelString(debug_info, trace_label);
|
return LabelString(debug_info, trace_label);
|
||||||
|
|
|
@ -184,7 +184,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
|
||||||
|
|
||||||
if(MSLITE_ENABLE_CONVERTER)
|
if(MSLITE_ENABLE_CONVERTER)
|
||||||
add_subdirectory(convert)
|
add_subdirectory(convert)
|
||||||
target_link_libraries(mindspore-extendrt mindspore_converter)
|
target_link_libraries(mindspore-extendrt -Wl,--no-as-needed mindspore_converter)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(TEST_CLOUD_INFER off)
|
set(TEST_CLOUD_INFER off)
|
||||||
|
|
|
@ -18,6 +18,8 @@ file(GLOB ACL_SRC
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/common/*.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/common/*.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/mapper/*.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/mapper/*.cc
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/infer/*.cc
|
${CMAKE_CURRENT_SOURCE_DIR}/infer/*.cc
|
||||||
|
${TOP_DIR}/mindspore/lite/src/extendrt/utils/serialization.cc
|
||||||
|
${TOP_DIR}/mindspore/lite/src/extendrt/cxx_api/serialization.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
set(ACL_SRC ${ACL_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/acl_pass.cc)
|
set(ACL_SRC ${ACL_SRC} ${CMAKE_CURRENT_SOURCE_DIR}/acl_pass.cc)
|
||||||
|
|
|
@ -41,8 +41,12 @@ STATUS MatMulFusionMapper::Mapper(const CNodePtr &cnode) {
|
||||||
ops::BatchMatMul batch_mat_mul;
|
ops::BatchMatMul batch_mat_mul;
|
||||||
dst_prim = batch_mat_mul.GetPrim();
|
dst_prim = batch_mat_mul.GetPrim();
|
||||||
}
|
}
|
||||||
dst_prim->AddAttr("transpose_x1", transpose_a);
|
if (transpose_a != nullptr) {
|
||||||
dst_prim->AddAttr("transpose_x2", transpose_b);
|
dst_prim->AddAttr("transpose_x1", transpose_a);
|
||||||
|
}
|
||||||
|
if (transpose_b != nullptr) {
|
||||||
|
dst_prim->AddAttr("transpose_x2", transpose_b);
|
||||||
|
}
|
||||||
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
if (MoveAttrMap(cnode, dst_prim) != RET_OK) {
|
||||||
MS_LOG(ERROR) << "MatMulFusion mapper failed.";
|
MS_LOG(ERROR) << "MatMulFusion mapper failed.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
|
|
@ -317,26 +317,6 @@ STATUS ConverterFuncGraph::Optimize(const std::shared_ptr<ConverterPara> &param,
|
||||||
int ConverterFuncGraph::Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, void **buff,
|
int ConverterFuncGraph::Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, void **buff,
|
||||||
size_t *size) {
|
size_t *size) {
|
||||||
mindspore::lite::MindIRSerializer serializer;
|
mindspore::lite::MindIRSerializer serializer;
|
||||||
auto fv_count = 0;
|
|
||||||
std::vector<AnfNodePtr> params;
|
|
||||||
std::vector<AnfNodePtr> reorder_param;
|
|
||||||
reorder_param.reserve(func_graph->parameters().size());
|
|
||||||
for (const auto &node : func_graph->parameters()) {
|
|
||||||
auto param_node = node->cast<ParameterPtr>();
|
|
||||||
if (param_node == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "The parameters() in func graph should be all Parameter Node. but got " << node->DebugString();
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
if (param_node->has_default()) {
|
|
||||||
(void)params.emplace_back(param_node);
|
|
||||||
++fv_count;
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
(void)reorder_param.emplace_back(param_node);
|
|
||||||
}
|
|
||||||
std::copy(params.begin(), params.end(), std::back_inserter(reorder_param));
|
|
||||||
func_graph->set_parameters(reorder_param);
|
|
||||||
func_graph->set_fv_param_count(fv_count);
|
|
||||||
auto ret = serializer.Save(param, func_graph);
|
auto ret = serializer.Save(param, func_graph);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "MindIR serialize fail";
|
MS_LOG(ERROR) << "MindIR serialize fail";
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <dirent.h>
|
#include <dirent.h>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <algorithm>
|
||||||
#include "mindspore/ccsrc/include/common/debug/dump_proto.h"
|
#include "mindspore/ccsrc/include/common/debug/dump_proto.h"
|
||||||
#include "mindspore/ccsrc/include/common/utils/utils.h"
|
#include "mindspore/ccsrc/include/common/utils/utils.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
|
@ -100,6 +101,81 @@ int MindIRSerializer::RemoveQuantParameterHolder(FuncGraphPtr func_graph) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int MindIRSerializer::UpdateParamCount(const FuncGraphPtr &func_graph) {
|
||||||
|
auto fv_count = 0;
|
||||||
|
std::vector<AnfNodePtr> params;
|
||||||
|
std::vector<AnfNodePtr> reorder_param;
|
||||||
|
reorder_param.reserve(func_graph->parameters().size());
|
||||||
|
for (const auto &node : func_graph->parameters()) {
|
||||||
|
auto param_node = node->cast<ParameterPtr>();
|
||||||
|
if (param_node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The parameters() in func graph should be all Parameter Node. but got " << node->DebugString();
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
if (param_node->has_default()) {
|
||||||
|
(void)params.emplace_back(param_node);
|
||||||
|
++fv_count;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
(void)reorder_param.emplace_back(param_node);
|
||||||
|
}
|
||||||
|
std::copy(params.begin(), params.end(), std::back_inserter(reorder_param));
|
||||||
|
func_graph->set_parameters(reorder_param);
|
||||||
|
func_graph->set_fv_param_count(fv_count);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int MindIRSerializer::PreProcSaveTogether(const FuncGraphPtr &func_graph) {
|
||||||
|
if (func_graph == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "func_graph is nullptr.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = UpdateParamCount(func_graph);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Update parameter count failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = ConvertQuantHolderToQuantizationParam(func_graph);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "add quant parameter holder failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = RemoveQuantParameterHolder(func_graph);
|
||||||
|
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||||
|
MS_LOG(ERROR) << "remove quant parameter holder failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse func_graph as model proto
|
||||||
|
std::string proto_string = GetBinaryProtoString(func_graph);
|
||||||
|
if (proto_string.empty()) {
|
||||||
|
MS_LOG(ERROR) << "parse proto string failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!model_proto_.ParseFromString(proto_string)) {
|
||||||
|
MS_LOG(ERROR) << "parse model proto from string failed.";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = ParamDict(func_graph);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "parse param form funcgraph failed.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
ret = IfSaveTogether(&save_together_);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "error occur when check condition of saving together.";
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph) {
|
int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph) {
|
||||||
if (func_graph == nullptr) {
|
if (func_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "func_graph is nullptr.";
|
MS_LOG(ERROR) << "func_graph is nullptr.";
|
||||||
|
@ -116,38 +192,13 @@ int MindIRSerializer::Save(const std::shared_ptr<ConverterPara> &param, const Fu
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = ConvertQuantHolderToQuantizationParam(func_graph);
|
// Serialize to protobuf using unique parameter name label.
|
||||||
|
common::SetEnv("MS_DEV_TRACE_LABEL_WITH_UNIQUE_ID", "1", 0);
|
||||||
|
|
||||||
|
// Do preprocess on func_graph and check conditions for saving together.
|
||||||
|
ret = PreProcSaveTogether(func_graph);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "add quant parameter holder failed.";
|
MS_LOG(ERROR) << "PreProcSaveTogether failed";
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = RemoveQuantParameterHolder(func_graph);
|
|
||||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
|
||||||
MS_LOG(ERROR) << "remove quant parameter holder failed.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto proto_string = GetBinaryProtoString(func_graph);
|
|
||||||
if (proto_string.empty()) {
|
|
||||||
MS_LOG(ERROR) << "parse proto string failed.";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!model_proto_.ParseFromString(proto_string)) {
|
|
||||||
MS_LOG(ERROR) << "parse model proto from string failed.";
|
|
||||||
return RET_NULL_PTR;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = ParamDict(func_graph);
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "parse param form funcgraph failed.";
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
ret = IfSaveTogether(&save_together_);
|
|
||||||
if (ret != RET_OK) {
|
|
||||||
MS_LOG(ERROR) << "error occur when check condition of saving together.";
|
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -550,14 +601,17 @@ int MindIRSerializer::GetBuffAndSize(void **buff, size_t *size) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, void **buff,
|
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, bool need_buff,
|
||||||
size_t *size) {
|
void **buff, size_t *size) {
|
||||||
mindspore::lite::MindIRSerializer serializer;
|
mindspore::lite::MindIRSerializer serializer;
|
||||||
auto ret = serializer.Save(param, func_graph);
|
auto ret = serializer.Save(param, func_graph);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "MindIR serialize fail";
|
MS_LOG(ERROR) << "MindIR serialize fail";
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
return serializer.GetBuffAndSize(buff, size);
|
if (need_buff) {
|
||||||
|
return serializer.GetBuffAndSize(buff, size);
|
||||||
|
}
|
||||||
|
return RET_OK;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -42,6 +42,7 @@ class MindIRSerializer {
|
||||||
}
|
}
|
||||||
int Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph);
|
int Save(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph);
|
||||||
int GetBuffAndSize(void **buff, size_t *size);
|
int GetBuffAndSize(void **buff, size_t *size);
|
||||||
|
int PreProcSaveTogether(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int ParserPath(const std::string &output_path);
|
int ParserPath(const std::string &output_path);
|
||||||
|
@ -52,6 +53,7 @@ class MindIRSerializer {
|
||||||
int ConvertQuantHolderToQuantizationParam(const FuncGraphPtr &func_graph);
|
int ConvertQuantHolderToQuantizationParam(const FuncGraphPtr &func_graph);
|
||||||
std::shared_ptr<mindspore::QuantizationParam> ConvertQuantParamTToQuantizationParam(
|
std::shared_ptr<mindspore::QuantizationParam> ConvertQuantParamTToQuantizationParam(
|
||||||
std::vector<schema::QuantParamT> quant_param);
|
std::vector<schema::QuantParamT> quant_param);
|
||||||
|
int UpdateParamCount(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int ParamDict(const FuncGraphPtr &func_graph);
|
int ParamDict(const FuncGraphPtr &func_graph);
|
||||||
|
@ -77,7 +79,7 @@ class MindIRSerializer {
|
||||||
std::shared_ptr<system::FileSystem> fs_{};
|
std::shared_ptr<system::FileSystem> fs_{};
|
||||||
};
|
};
|
||||||
// export func_graph
|
// export func_graph
|
||||||
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, void **buff,
|
int MindIRSerialize(const std::shared_ptr<ConverterPara> ¶m, const FuncGraphPtr &func_graph, bool need_buff,
|
||||||
size_t *size);
|
void **buff, size_t *size);
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
#endif // MINDSPORE_LITE_TOOLS_MINDIR_EXPORTER_MINDIR_SERIALIZER_H_
|
#endif // MINDSPORE_LITE_TOOLS_MINDIR_EXPORTER_MINDIR_SERIALIZER_H_
|
||||||
|
|
Loading…
Reference in New Issue