From 96aea47492799c460e6b4f9c69e39140a89a4c72 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Thu, 4 Mar 2021 11:30:53 +0800 Subject: [PATCH] 310 add context options Signed-off-by: zhoufeng --- include/api/context.h | 24 +++++++++++++ mindspore/ccsrc/cxx_api/context.cc | 35 +++++++++++++++++++ .../ccsrc/cxx_api/graph/acl/acl_env_guard.cc | 4 +-- .../ccsrc/cxx_api/graph/acl/acl_graph_impl.cc | 2 +- .../cxx_api/model/acl/acl_model_options.cc | 3 +- .../cxx_api/model/acl/acl_model_options.h | 1 + 6 files changed, 65 insertions(+), 4 deletions(-) diff --git a/include/api/context.h b/include/api/context.h index 90dfa408d63..1d8852bdeec 100644 --- a/include/api/context.h +++ b/include/api/context.h @@ -45,10 +45,16 @@ struct MS_API GlobalContext : public Context { static void SetGlobalDeviceID(const uint32_t &device_id); static uint32_t GetGlobalDeviceID(); + static inline void SetGlobalDumpConfigPath(const std::string &cfg_path); + static inline std::string GetGlobalDumpConfigPath(); + private: // api without std::string static void SetGlobalDeviceTarget(const std::vector &device_target); static std::vector GetGlobalDeviceTargetChar(); + + static void SetGlobalDumpConfigPath(const std::vector &cfg_path); + static std::vector GetGlobalDumpConfigPathChar(); }; struct MS_API ModelContext : public Context { @@ -72,6 +78,9 @@ struct MS_API ModelContext : public Context { const std::string &op_select_impl_mode); static inline std::string GetOpSelectImplMode(const std::shared_ptr &context); + static inline void SetFusionSwitchConfigPath(const std::shared_ptr &context, const std::string &cfg_path); + static inline std::string GetFusionSwitchConfigPath(const std::shared_ptr &context); + private: // api without std::string static void SetInsertOpConfigPath(const std::shared_ptr &context, const std::vector &cfg_path); @@ -89,6 +98,9 @@ struct MS_API ModelContext : public Context { static void SetOpSelectImplMode(const std::shared_ptr &context, const std::vector &op_select_impl_mode); static std::vector GetOpSelectImplModeChar(const std::shared_ptr &context); + + static void SetFusionSwitchConfigPath(const std::shared_ptr &context, const std::vector &cfg_path); + static std::vector GetFusionSwitchConfigPathChar(const std::shared_ptr &context); }; void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { @@ -96,6 +108,11 @@ void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) { } std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); } +void GlobalContext::SetGlobalDumpConfigPath(const std::string &cfg_path) { + SetGlobalDumpConfigPath(StringToChar(cfg_path)); +} +std::string GlobalContext::GetGlobalDumpConfigPath() { return CharToString(GetGlobalDumpConfigPathChar()); } + void ModelContext::SetInsertOpConfigPath(const std::shared_ptr &context, const std::string &cfg_path) { SetInsertOpConfigPath(context, StringToChar(cfg_path)); } @@ -131,5 +148,12 @@ void ModelContext::SetOpSelectImplMode(const std::shared_ptr &context, std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr &context) { return CharToString(GetOpSelectImplModeChar(context)); } + +void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr &context, const std::string &cfg_path) { + SetFusionSwitchConfigPath(context, StringToChar(cfg_path)); +} +std::string ModelContext::GetFusionSwitchConfigPath(const std::shared_ptr &context) { + return CharToString(GetFusionSwitchConfigPathChar(context)); +} } // namespace mindspore #endif // MINDSPORE_INCLUDE_API_CONTEXT_H diff --git a/mindspore/ccsrc/cxx_api/context.cc b/mindspore/ccsrc/cxx_api/context.cc index d9679639f93..b8ef34b9556 100644 --- a/mindspore/ccsrc/cxx_api/context.cc +++ b/mindspore/ccsrc/cxx_api/context.cc @@ -21,6 +21,7 @@ constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target"; constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id"; +constexpr auto kGlobalContextDumpCfgPath = "mindspore.ascend.globalcontext.dump_config_file_path"; constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file constexpr auto kModelOptionInputFormat = "mindspore.option.input_format"; // nchw or nhwc constexpr auto kModelOptionInputShape = "mindspore.option.input_shape"; @@ -29,6 +30,7 @@ constexpr auto kModelOptionOutputType = "mindspore.option.output_type"; // "FP3 constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode"; // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16" constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode"; +constexpr auto KModelOptionFusionSwitchCfgPath = "mindspore.option.fusion_switch_config_file_path"; namespace mindspore { struct Context::Data { @@ -93,6 +95,23 @@ uint32_t GlobalContext::GetGlobalDeviceID() { return GetValue(global_context, kGlobalContextDeviceID); } +void GlobalContext::SetGlobalDumpConfigPath(const std::vector &cfg_path) { + auto global_context = GetGlobalContext(); + MS_EXCEPTION_IF_NULL(global_context); + if (global_context->data == nullptr) { + global_context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(global_context->data); + } + global_context->data->params[kGlobalContextDumpCfgPath] = CharToString(cfg_path); +} + +std::vector GlobalContext::GetGlobalDumpConfigPathChar() { + auto global_context = GetGlobalContext(); + MS_EXCEPTION_IF_NULL(global_context); + const std::string &ref = GetValue(global_context, kGlobalContextDumpCfgPath); + return StringToChar(ref); +} + void ModelContext::SetInsertOpConfigPath(const std::shared_ptr &context, const std::vector &cfg_path) { MS_EXCEPTION_IF_NULL(context); if (context->data == nullptr) { @@ -182,4 +201,20 @@ std::vector ModelContext::GetOpSelectImplModeChar(const std::shared_ptr(context, kModelOptionOpSelectImplMode); return StringToChar(ref); } + +void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr &context, + const std::vector &cfg_path) { + MS_EXCEPTION_IF_NULL(context); + if (context->data == nullptr) { + context->data = std::make_shared(); + MS_EXCEPTION_IF_NULL(context->data); + } + context->data->params[KModelOptionFusionSwitchCfgPath] = CharToString(cfg_path); +} + +std::vector ModelContext::GetFusionSwitchConfigPathChar(const std::shared_ptr &context) { + MS_EXCEPTION_IF_NULL(context); + const std::string &ref = GetValue(context, KModelOptionFusionSwitchCfgPath); + return StringToChar(ref); +} } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc b/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc index 624d1c8832e..1dd030f8ff9 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_env_guard.cc @@ -23,7 +23,7 @@ std::mutex AclEnvGuard::global_acl_env_mutex_; AclEnvGuard::AclEnvGuard(std::string_view cfg_file) { errno_ = aclInit(cfg_file.data()); - if (errno_ != ACL_ERROR_NONE) { + if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) { MS_LOG(ERROR) << "Execute aclInit Failed"; return; } @@ -32,7 +32,7 @@ AclEnvGuard::AclEnvGuard(std::string_view cfg_file) { AclEnvGuard::~AclEnvGuard() { errno_ = aclFinalize(); - if (errno_ != ACL_ERROR_NONE) { + if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_FINALIZE) { MS_LOG(ERROR) << "Finalize acl failed"; } MS_LOG(INFO) << "Acl finalize success"; diff --git a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc index 439161910b5..6867366ac11 100644 --- a/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/acl/acl_graph_impl.cc @@ -90,7 +90,7 @@ Status AclGraphImpl::InitEnv() { return kSuccess; } - acl_env_ = AclEnvGuard::GetAclEnv(""); + acl_env_ = AclEnvGuard::GetAclEnv(GlobalContext::GetGlobalDumpConfigPath()); if (acl_env_ == nullptr) { MS_LOG(ERROR) << "Acl init failed."; return kMCDeviceError; diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc index ffc059f7709..d98be0f5896 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.cc @@ -43,6 +43,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr &context) { precision_mode = ModelContext::GetPrecisionMode(context); op_select_impl_mode = ModelContext::GetOpSelectImplMode(context); + fusion_switch_cfg_path = ModelContext::GetFusionSwitchConfigPath(context); } std::tuple, std::map> AclModelOptions::GenAclOptions() @@ -50,7 +51,7 @@ std::tuple, std::map init_options_map = { {&op_select_impl_mode, ge::ir_option::OP_SELECT_IMPL_MODE}, {&soc_version, ge::ir_option::SOC_VERSION}, - }; + {&fusion_switch_cfg_path, ge::ir_option::FUSION_SWITCH_FILE}}; const std::map build_options_map = { {&insert_op_cfg_path, ge::ir_option::INSERT_OP_FILE}, {&input_format, ge::ir_option::INPUT_FORMAT}, diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.h index 46f783c68e7..5bc32bf8def 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.h +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_options.h @@ -34,6 +34,7 @@ struct AclModelOptions { std::string output_type; std::string precision_mode; std::string op_select_impl_mode; + std::string fusion_switch_cfg_path; std::string soc_version = "Ascend310"; explicit AclModelOptions(const std::shared_ptr &context);