From 68e93e3b18f0ff133ddc37c091ce3cba1443ed5e Mon Sep 17 00:00:00 2001
From: chenping
Date: Mon, 13 Dec 2021 19:39:36 +0800
Subject: [PATCH] Online inference debugging for model_zoo models on Ascend710
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/api/context.h                         |   1 +
 mindspore/core/load_mindir/load_model.cc      |  16 ++-
 mindspore/lite/include/lite_types.h           |   9 +-
 mindspore/lite/src/cxx_api/model/model.cc     |   8 +-
 .../lite/src/cxx_api/model/model_impl.cc      |   5 +-
 mindspore/lite/src/inner_context.cc           |   2 +-
 mindspore/lite/src/lite_session.cc            | 103 +++++++++++++++++-
 mindspore/lite/src/lite_session.h             |   9 ++
 mindspore/lite/src/runtime/runtime_convert.cc |  49 ++++++++-
 mindspore/lite/src/runtime/runtime_convert.h  |   4 +-
 .../converter/import/mindspore_importer.cc    |   4 +-
 11 files changed, 192 insertions(+), 18 deletions(-)

diff --git a/include/api/context.h b/include/api/context.h
index e26eb5d726e..d5fde38c945 100644
--- a/include/api/context.h
+++ b/include/api/context.h
@@ -450,6 +450,7 @@ class MS_API AscendDeviceInfo : public DeviceInfoContext {
 
 using Ascend310DeviceInfo = AscendDeviceInfo;
 using Ascend910DeviceInfo = AscendDeviceInfo;
+using Ascend710DeviceInfo = AscendDeviceInfo;
 
 void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
   SetInsertOpConfigPath(StringToChar(cfg_path));
diff --git a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc
index 84b14963567..b5a8891971f 100644
--- a/mindspore/core/load_mindir/load_model.cc
+++ b/mindspore/core/load_mindir/load_model.cc
@@ -203,7 +203,21 @@ FuncGraphPtr MindIRLoader::LoadMindIR(const void *buffer, const size_t &size) {
   }
   MSANFModelParser model_parser;
-  model_parser.SetLite();
+
+  model_parser.SetMindIRDecKey(dec_key_);
+  model_parser.SetMindIRKeySize(key_len_);
+  model_parser.SetMindIRDecMode(dec_mode_);
+  model_parser.set_need_renormalize(need_renormalize_);
+
+  if (!inc_load_) {
+    MSANFModelParser::LoadTensorMapClear();
+  } else {
+    model_parser.SetIncLoad();
+  }
+  if (is_lite_) {
+    model_parser.SetLite();
+  }
+
   FuncGraphPtr func_graph = model_parser.Parse(model);
   return func_graph;
 }
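The reworked LoadMindIR buffer overload above now honors the loader's decryption, renormalization, and incremental-load state instead of unconditionally forcing lite mode. A minimal caller sketch, assuming the five-argument MindIRLoader constructor used by the file-based loader; the AES-GCM mode string and the wrapper function are illustrative, not part of this patch:

    #include "load_mindir/load_model.h"

    // Illustrative driver for the decrypt/lite options forwarded above; the
    // constructor shape is an assumption based on the file-based loader.
    mindspore::FuncGraphPtr LoadLiteMindIR(const void *buf, size_t size,
                                           const unsigned char *key, size_t key_len) {
      mindspore::MindIRLoader loader(/*is_lite=*/true, key, key_len,
                                     /*dec_mode=*/"AES-GCM", /*inc_load=*/false);
      // dec_key_/key_len_/dec_mode_ and is_lite_ now reach MSANFModelParser
      // before Parse() runs, per the hunk above.
      return loader.LoadMindIR(buf, size);
    }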
diff --git a/mindspore/lite/include/lite_types.h b/mindspore/lite/include/lite_types.h
index ea3c468b50a..35a7dbe9460 100644
--- a/mindspore/lite/include/lite_types.h
+++ b/mindspore/lite/include/lite_types.h
@@ -27,10 +27,11 @@ typedef enum {
 /// \brief DeviceType defined for holding user's preferred backend.
 typedef enum {
-  DT_CPU,   /**< CPU device type */
-  DT_GPU,   /**< GPU device type */
-  DT_NPU,   /**< NPU device type */
-  DT_ASCEND /**< ASCEND device type */
+  DT_CPU,    /**< CPU device type */
+  DT_GPU,    /**< GPU device type */
+  DT_NPU,    /**< NPU device type */
+  DT_ASCEND, /**< ASCEND device type */
+  DT_END     /**< NO device type */
 } DeviceType;
 
 typedef enum {
diff --git a/mindspore/lite/src/cxx_api/model/model.cc b/mindspore/lite/src/cxx_api/model/model.cc
index 5d6de2d9049..68c403f429f 100644
--- a/mindspore/lite/src/cxx_api/model/model.cc
+++ b/mindspore/lite/src/cxx_api/model/model.cc
@@ -15,13 +15,13 @@
  */
 
 #include "include/api/model.h"
+#include 
 #ifdef GPU_TENSORRT
 #include 
 #endif
 #ifdef ENABLE_LITE_ACL
 #include "acl/acl_base.h"
 #endif
-#include 
 #include "include/api/callback/callback.h"
 #include "include/api/context.h"
 #include "include/api/dual_abi_helper.h"
@@ -164,7 +164,7 @@ bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type)
     int driver_version = 0;
     int ret = cudaDriverGetVersion(&driver_version);
     if (ret != cudaSuccess || driver_version == 0) {
-      MS_LOG(WARNING) << "No nvidia GPU driver.";
+      MS_LOG(ERROR) << "No nvidia GPU driver.";
       return false;
     }
     return true;
@@ -174,12 +174,12 @@ bool Model::CheckModelSupport(enum DeviceType device_type, ModelType model_type)
   if (device_type == kAscend || device_type == kAscend310) {
     const char *soc_name_c = aclrtGetSocName();
     if (soc_name_c == nullptr) {
-      MS_LOG(WARNING) << "aclrtGetSocName failed.";
+      MS_LOG(ERROR) << "aclrtGetSocName failed.";
       return false;
     }
     std::string soc_name(soc_name_c);
     if (soc_name.find("910") != std::string::npos) {
-      MS_LOG(WARNING) << "Device not support, aclrtGetSocName: " << soc_name;
+      MS_LOG(ERROR) << "Device not support, aclrtGetSocName: " << soc_name;
       return false;
     }
     return true;
diff --git a/mindspore/lite/src/cxx_api/model/model_impl.cc b/mindspore/lite/src/cxx_api/model/model_impl.cc
index d4601c5ca17..569f7f44ae9 100644
--- a/mindspore/lite/src/cxx_api/model/model_impl.cc
+++ b/mindspore/lite/src/cxx_api/model/model_impl.cc
@@ -69,7 +69,8 @@ Status ModelImpl::Build(const void *model_data, size_t data_size, ModelType mode
     return kLiteNullptr;
   }
 
-  auto ret = session->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), model_type, data_size);
+  auto ret =
+    session->LoadModelAndCompileByBuf(static_cast<const char *>(model_data), model_type, data_size, ms_context);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init session failed";
     return kLiteError;
@@ -88,7 +89,7 @@ Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
     return kLiteNullptr;
   }
 
-  auto ret = session->LoadModelAndCompileByPath(model_path, model_type);
+  auto ret = session->LoadModelAndCompileByPath(model_path, model_type, ms_context);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init session failed";
     return kLiteError;
diff --git a/mindspore/lite/src/inner_context.cc b/mindspore/lite/src/inner_context.cc
index dba34ef1403..5e2f9d20be9 100644
--- a/mindspore/lite/src/inner_context.cc
+++ b/mindspore/lite/src/inner_context.cc
@@ -279,7 +279,7 @@ bool InnerContext::IsProviderEnabled() const {
 
 bool InnerContext::IsAllDeviceTypeValid() const {
   return std::all_of(this->device_list_.begin(), this->device_list_.end(), [](const DeviceContext &device) {
-    return device.device_type_ >= DT_CPU && device.device_type_ <= DT_ASCEND;
+    return device.device_type_ >= DT_CPU && device.device_type_ < DT_END;
   });
 }
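DT_END gives IsAllDeviceTypeValid a sentinel upper bound, so the range check stays correct as backends are appended to the enum. On the public API side, the new Ascend710DeviceInfo alias and the hardened CheckModelSupport probe combine naturally; a usage sketch against the headers above, where the wrapper name and device id are placeholders:

    #include <memory>
    #include "include/api/context.h"
    #include "include/api/model.h"

    // Placeholder wrapper: probe the runtime, then describe the Ascend device.
    std::shared_ptr<mindspore::Context> MakeAscend710Context() {
      if (!mindspore::Model::CheckModelSupport(mindspore::kAscend, mindspore::ModelType::kMindIR)) {
        return nullptr;  // driver/SoC probe failed; now logged at ERROR level
      }
      auto context = std::make_shared<mindspore::Context>();
      auto ascend = std::make_shared<mindspore::Ascend710DeviceInfo>();  // new alias
      ascend->SetDeviceID(0);  // device id 0 is a placeholder
      context->MutableDeviceInfo().push_back(ascend);
      return context;
    }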
diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc
index d3497d8d592..14df5b3bdd9 100644
--- a/mindspore/lite/src/lite_session.cc
+++ b/mindspore/lite/src/lite_session.cc
@@ -1533,9 +1533,34 @@ mindspore::ModelType lite::LiteSession::LoadModelByBuff(const char *model_buf, c
     return mindspore::ModelType::kMindIR_Opt;
   }
   MS_LOG(WARNING) << "Invalid mslite model.";
+  return mindspore::ModelType::kMindIR;
+}
+
+mindspore::ModelType lite::LiteSession::LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
+                                                        size_t *size, mindspore::ModelType model_type,
+                                                        const std::shared_ptr<mindspore::Context> &ms_context) {
+  if (model_type == mindspore::ModelType::kMindIR_Opt) {
+    *size = buf_size;
+    *lite_buf = const_cast<char *>(model_buf);
+    return mindspore::ModelType::kMindIR_Opt;
+  }
+
+  if (model_type != mindspore::ModelType::kMindIR) {
+    return mindspore::ModelType::kUnknownType;
+  }
+
+  flatbuffers::Verifier verify((const uint8_t *)model_buf, buf_size);
+  auto version_verify = lite::LiteModel::VersionVerify(&verify);
+  if (version_verify != SCHEMA_INVALID) {
+    MS_LOG(DEBUG) << "The kMindIR type model buffer is valid mslite model buffer";
+    *size = buf_size;
+    *lite_buf = const_cast<char *>(model_buf);
+    return mindspore::ModelType::kMindIR_Opt;
+  }
+  MS_LOG(WARNING) << "Invalid mslite model.";
 #ifdef RUNTIME_CONVERT
-  *lite_buf = RuntimeConvert(model_buf, buf_size, size);
+  *lite_buf = RuntimeConvert(model_buf, buf_size, size, ms_context);
 #else
   MS_LOG(ERROR) << "Please enable runtime convert.";
 #endif
   return mindspore::ModelType::kMindIR;
@@ -1562,6 +1587,27 @@ const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspor
   return lite_buf;
 }
 
+const char *lite::LiteSession::LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
+                                               const std::shared_ptr<mindspore::Context> &ms_context) {
+  size_t buf_size;
+  auto model_buf = lite::ReadFile(file.c_str(), &buf_size);
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "The model path is invalid";
+    return model_buf;
+  }
+
+  char *lite_buf = nullptr;
+  auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, size, model_type, ms_context);
+  if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
+    return nullptr;
+  }
+  if (buf_model_type == mindspore::ModelType::kMindIR) {
+    delete[] model_buf;
+    model_buf = nullptr;
+  }
+  return lite_buf;
+}
+
 int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type,
                                                 const size_t &buf_size) {
   size_t lite_buf_size = 0;
@@ -1592,6 +1638,37 @@ int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore
   return RET_OK;
 }
 
+int lite::LiteSession::LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type,
+                                                const size_t &buf_size,
+                                                const std::shared_ptr<mindspore::Context> &ms_context) {
+  size_t lite_buf_size = 0;
+  char *lite_buf = nullptr;
+  auto buf_model_type = LoadModelByBuff(model_buf, buf_size, &lite_buf, &lite_buf_size, model_type, ms_context);
+  if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
+    MS_LOG(ERROR) << "Invalid model_buf";
+    return RET_ERROR;
+  }
+
+  auto *model = lite::ImportFromBuffer(lite_buf, lite_buf_size, true);
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model failed";
+    return RET_ERROR;
+  }
+  auto ret = CompileGraph(model);
+  model->buf = nullptr;
+  if (buf_model_type == mindspore::ModelType::kMindIR) {
+    delete[] lite_buf;
+    lite_buf = nullptr;
+  }
+  if (ret != lite::RET_OK) {
+    MS_LOG(ERROR) << "Compile model failed";
+    delete model;
+    return RET_ERROR;
+  }
+  set_model(model);
+  return RET_OK;
+}
+
 int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type) {
   size_t model_size;
   auto model_buf = LoadModelByPath(model_path, model_type, &model_size);
@@ -1614,4 +1691,28 @@ int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path,
   set_model(model);
   return RET_OK;
 }
+
+int lite::LiteSession::LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type,
+                                                 const std::shared_ptr<mindspore::Context> &ms_context) {
+  size_t model_size;
+  auto model_buf = LoadModelByPath(model_path, model_type, &model_size, ms_context);
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "Read model file failed";
+    return RET_ERROR;
+  }
+  auto *model = lite::ImportFromBuffer(model_buf, model_size, true);
+  if (model == nullptr) {
+    MS_LOG(ERROR) << "Import model failed";
+    return RET_ERROR;
+  }
+
+  (reinterpret_cast<lite::LiteModel *>(model))->set_keep_model_buf(true);
+  auto ret = CompileGraph(model);
+  if (ret != lite::RET_OK) {
+    MS_LOG(ERROR) << "Compile model failed";
+    return RET_ERROR;
+  }
+  set_model(model);
+  return RET_OK;
+}
 }  // namespace mindspore
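The new LoadModelByBuff overload sniffs whether a kMindIR-typed buffer is in fact already a lite flatbuffer before falling back to online conversion, and its return value doubles as an ownership flag: kMindIR_Opt means *lite_buf aliases the caller's buffer, while kMindIR means RuntimeConvert allocated fresh storage that the session releases after CompileGraph. The sniffing idiom in isolation, as a sketch; VersionVerify and SCHEMA_INVALID are taken from the hunk above and the include paths are assumed:

    #include "flatbuffers/flatbuffers.h"
    #include "src/lite_model.h"  // assumed location of LiteModel/SCHEMA_INVALID

    // Sketch: does this buffer already hold a MindSpore Lite flatbuffer?
    bool IsLiteFlatbuffer(const char *buf, size_t size) {
      flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t *>(buf), size);
      // SCHEMA_INVALID marks a buffer matching no known lite schema version.
      return mindspore::lite::LiteModel::VersionVerify(&verifier) != mindspore::lite::SCHEMA_INVALID;
    }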
diff --git a/mindspore/lite/src/lite_session.h b/mindspore/lite/src/lite_session.h
index 30705da8f93..28e97114c6a 100644
--- a/mindspore/lite/src/lite_session.h
+++ b/mindspore/lite/src/lite_session.h
@@ -54,12 +54,21 @@ class LiteSession : public session::LiteSession {
   static session::LiteSession *CreateSession(const std::string &model_path, const lite::Context *context);
 
   int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size);
+  int LoadModelAndCompileByBuf(const char *model_buf, mindspore::ModelType model_type, const size_t &buf_size,
+                               const std::shared_ptr<mindspore::Context> &ms_context);
 
   int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type);
+  int LoadModelAndCompileByPath(const std::string &model_path, mindspore::ModelType model_type,
+                                const std::shared_ptr<mindspore::Context> &ms_context);
 
   static mindspore::ModelType LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
                                               size_t *size, mindspore::ModelType model_type);
+  static mindspore::ModelType LoadModelByBuff(const char *model_buf, const size_t &buf_size, char **lite_buf,
+                                              size_t *size, mindspore::ModelType model_type,
+                                              const std::shared_ptr<mindspore::Context> &ms_context);
 
   static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size);
+  static const char *LoadModelByPath(const std::string &file, mindspore::ModelType model_type, size_t *size,
+                                     const std::shared_ptr<mindspore::Context> &ms_context);
 
   virtual int Init(InnerContext *context);
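Applications do not call these overloads directly: ModelImpl::Build (see model_impl.cc above) forwards the shared Context from the public Model::Build down to them, which is what makes online conversion reachable from the standard API. A sketch of that entry point, with the model path as a placeholder:

    #include "include/api/model.h"

    // Sketch: the public entry point that reaches the new overloads.
    int RunOnAscend710(const std::shared_ptr<mindspore::Context> &context) {
      mindspore::Model model;
      // "net.mindir" is a placeholder; a kMindIR buffer that is not already a
      // lite flatbuffer now goes through RuntimeConvert with this context.
      auto status = model.Build("net.mindir", mindspore::ModelType::kMindIR, context);
      return status == mindspore::kSuccess ? 0 : -1;
    }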
diff --git a/mindspore/lite/src/runtime/runtime_convert.cc b/mindspore/lite/src/runtime/runtime_convert.cc
index 762b1219cb9..c722f470508 100644
--- a/mindspore/lite/src/runtime/runtime_convert.cc
+++ b/mindspore/lite/src/runtime/runtime_convert.cc
@@ -15,13 +15,17 @@
  */
 
 #ifdef RUNTIME_CONVERT
+#include 
 #include "src/runtime/runtime_convert.h"
+#include "tools/common/string_util.h"
 #include "include/version.h"
 #include "tools/converter/converter.h"
 #include "tools/converter/converter_flags.h"
+#include "acl/acl_base.h"
 
 namespace mindspore::lite {
-char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size) {
+char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size,
+                     const std::shared_ptr<mindspore::Context> &context) {
   if (model_buf == nullptr) {
     MS_LOG(ERROR) << "Invalid input model buffer.";
     return nullptr;
@@ -33,8 +37,49 @@ char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size
   flag->outputDataType = kTypeUnknown;
   flag->saveFP16 = false;
   flag->trainModel = false;
+
+  auto device_list = context->MutableDeviceInfo();
+  for (auto &device : device_list) {
+    if (device->GetDeviceType() == kAscend) {
+      auto ascend_info = device->Cast<AscendDeviceInfo>();
+      std::string dynamic_batch_size = ascend_info->GetDynamicBatchSize();
+      if (!dynamic_batch_size.empty()) {
+        std::vector<std::string> batch_size_string = SplitStringToVector(dynamic_batch_size, ',');
+        for (const auto &item : batch_size_string) {
+          int32_t val;
+          if (ConvertIntNum(item, &val)) {
+            size_t tmp_val = static_cast<size_t>(val);
+            flag->aclModelOptionCfgParam.dynamic_batch_size.push_back(tmp_val);
+          }
+        }
+      }
+      flag->aclModelOptionCfgParam.device_id = ascend_info->GetDeviceID();
+      flag->aclModelOptionCfgParam.output_type = ascend_info->GetOutputType();
+      flag->aclModelOptionCfgParam.input_shape_map = ascend_info->GetInputShapeMap();
+      flag->aclModelOptionCfgParam.input_format = ascend_info->GetInputFormat();
+      flag->aclModelOptionCfgParam.input_shape = ascend_info->GetInputShape();
+      flag->aclModelOptionCfgParam.precision_mode = ascend_info->GetPrecisionMode();
+      flag->aclModelOptionCfgParam.op_select_impl_mode = ascend_info->GetOpSelectImplMode();
+      flag->aclModelOptionCfgParam.fusion_switch_config_file_path = ascend_info->GetFusionSwitchConfigPath();
+      flag->aclModelOptionCfgParam.buffer_optimize = ascend_info->GetBufferOptimizeMode();
+      flag->aclModelOptionCfgParam.insert_op_config_file_path = ascend_info->GetInsertOpConfigPath();
+      flag->aclModelOptionCfgParam.dynamic_image_size = ascend_info->GetDynamicImageSize();
+    } else {
+      continue;
+    }
+  }
+
 #ifdef ENABLE_LITE_ACL
-  flag->device = "Ascend310";
+  const char *soc_name_c = aclrtGetSocName();
+  if (soc_name_c != nullptr) {
+    std::string soc_name(soc_name_c);
+    if (soc_name.find("710") != std::string::npos) {
+      flag->device = "Ascend710";
+    }
+    if (soc_name.find("310") != std::string::npos) {
+      flag->device = "Ascend310";
+    }
+  }
 #endif
 
   Converter cvt;
diff --git a/mindspore/lite/src/runtime/runtime_convert.h b/mindspore/lite/src/runtime/runtime_convert.h
index afca200a784..48ffb5348ca 100644
--- a/mindspore/lite/src/runtime/runtime_convert.h
+++ b/mindspore/lite/src/runtime/runtime_convert.h
@@ -21,9 +21,11 @@
 #include 
 #include 
 #include 
+#include "include/api/context.h"
 
 namespace mindspore::lite {
-char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size);
+char *RuntimeConvert(const char *model_buf, const size_t &buf_size, size_t *size,
+                     const std::shared_ptr<mindspore::Context> &context);
 char *RuntimeConvert(const std::string &file_path, size_t *size);
 }  // namespace mindspore::lite
 #endif  // RUNTIME_CONVERT
diff --git a/mindspore/lite/tools/converter/import/mindspore_importer.cc b/mindspore/lite/tools/converter/import/mindspore_importer.cc
index 8d695a694fa..5919510c6b3 100644
--- a/mindspore/lite/tools/converter/import/mindspore_importer.cc
+++ b/mindspore/lite/tools/converter/import/mindspore_importer.cc
@@ -324,8 +324,8 @@ FuncGraphPtr MindsporeImporter::CheckAndUpdateFuncGraph(const converter::Flags &
     return nullptr;
   }
   ConverterInnerContext::GetInstance()->SetGraphOutputTensorNames(output_tensor_name_);
-  if (flag.device == "Ascend310") {
-    MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend310.";
+  if (flag.device == "Ascend310" || flag.device == "Ascend710") {
+    MS_LOG(INFO) << "There is no need to adjust and pass graph when in Ascend310 or Ascend710.";
     return func_graph;
   }
   if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
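The converter target is now chosen from aclrtGetSocName() instead of being hard-coded to Ascend310, and the importer skips the ANF adjust passes for both Ascend variants. The substring dispatch, isolated here as a helper so it is easy to unit-test; the helper name is hypothetical and the fallback mirrors the previous hard-coded target:

    #include <string>

    // Map an ACL SoC name such as "Ascend710" onto the converter device flag.
    std::string DeviceFromSocName(const std::string &soc_name) {
      if (soc_name.find("710") != std::string::npos) {
        return "Ascend710";
      }
      if (soc_name.find("310") != std::string::npos) {
        return "Ascend310";
      }
      return "Ascend310";  // conservative fallback, matching the old default
    }

The resulting flag->device string is what MindsporeImporter::CheckAndUpdateFuncGraph checks above when deciding to return the graph unmodified.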