!12871 add 310 options

From: @zhoufeng54
Reviewed-by: @kisnwang,@xu-yfei
Signed-off-by: @xu-yfei
This commit is contained in:
mindspore-ci-bot 2021-03-05 09:44:26 +08:00 committed by Gitee
commit 7036d35e9d
6 changed files with 65 additions and 4 deletions

View File

@ -45,10 +45,16 @@ struct MS_API GlobalContext : public Context {
static void SetGlobalDeviceID(const uint32_t &device_id);
static uint32_t GetGlobalDeviceID();
static inline void SetGlobalDumpConfigPath(const std::string &cfg_path);
static inline std::string GetGlobalDumpConfigPath();
private:
// api without std::string
static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
static std::vector<char> GetGlobalDeviceTargetChar();
static void SetGlobalDumpConfigPath(const std::vector<char> &cfg_path);
static std::vector<char> GetGlobalDumpConfigPathChar();
};
struct MS_API ModelContext : public Context {
@ -72,6 +78,9 @@ struct MS_API ModelContext : public Context {
const std::string &op_select_impl_mode);
static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
static inline void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
static inline std::string GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context);
private:
// api without std::string
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
@ -89,6 +98,9 @@ struct MS_API ModelContext : public Context {
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
const std::vector<char> &op_select_impl_mode);
static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
static void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
static std::vector<char> GetFusionSwitchConfigPathChar(const std::shared_ptr<Context> &context);
};
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
@ -96,6 +108,11 @@ void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
}
std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }
void GlobalContext::SetGlobalDumpConfigPath(const std::string &cfg_path) {
SetGlobalDumpConfigPath(StringToChar(cfg_path));
}
std::string GlobalContext::GetGlobalDumpConfigPath() { return CharToString(GetGlobalDumpConfigPathChar()); }
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
SetInsertOpConfigPath(context, StringToChar(cfg_path));
}
@ -131,5 +148,12 @@ void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
return CharToString(GetOpSelectImplModeChar(context));
}
void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
SetFusionSwitchConfigPath(context, StringToChar(cfg_path));
}
std::string ModelContext::GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context) {
return CharToString(GetFusionSwitchConfigPathChar(context));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H

View File

@ -21,6 +21,7 @@
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
constexpr auto kGlobalContextDumpCfgPath = "mindspore.ascend.globalcontext.dump_config_file_path";
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file
constexpr auto kModelOptionInputFormat = "mindspore.option.input_format"; // nchw or nhwc
constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
@ -29,6 +30,7 @@ constexpr auto kModelOptionOutputType = "mindspore.option.output_type"; // "FP3
constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
constexpr auto KModelOptionFusionSwitchCfgPath = "mindspore.option.fusion_switch_config_file_path";
namespace mindspore {
struct Context::Data {
@ -93,6 +95,23 @@ uint32_t GlobalContext::GetGlobalDeviceID() {
return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
}
void GlobalContext::SetGlobalDumpConfigPath(const std::vector<char> &cfg_path) {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
if (global_context->data == nullptr) {
global_context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(global_context->data);
}
global_context->data->params[kGlobalContextDumpCfgPath] = CharToString(cfg_path);
}
std::vector<char> GlobalContext::GetGlobalDumpConfigPathChar() {
auto global_context = GetGlobalContext();
MS_EXCEPTION_IF_NULL(global_context);
const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDumpCfgPath);
return StringToChar(ref);
}
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(context);
if (context->data == nullptr) {
@ -182,4 +201,20 @@ std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Co
const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
return StringToChar(ref);
}
void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context,
const std::vector<char> &cfg_path) {
MS_EXCEPTION_IF_NULL(context);
if (context->data == nullptr) {
context->data = std::make_shared<Data>();
MS_EXCEPTION_IF_NULL(context->data);
}
context->data->params[KModelOptionFusionSwitchCfgPath] = CharToString(cfg_path);
}
std::vector<char> ModelContext::GetFusionSwitchConfigPathChar(const std::shared_ptr<Context> &context) {
MS_EXCEPTION_IF_NULL(context);
const std::string &ref = GetValue<std::string>(context, KModelOptionFusionSwitchCfgPath);
return StringToChar(ref);
}
} // namespace mindspore

View File

@ -23,7 +23,7 @@ std::mutex AclEnvGuard::global_acl_env_mutex_;
AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {
errno_ = aclInit(cfg_file.data());
if (errno_ != ACL_ERROR_NONE) {
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_INITIALIZE) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return;
}
@ -32,7 +32,7 @@ AclEnvGuard::AclEnvGuard(std::string_view cfg_file) {
AclEnvGuard::~AclEnvGuard() {
errno_ = aclFinalize();
if (errno_ != ACL_ERROR_NONE) {
if (errno_ != ACL_ERROR_NONE && errno_ != ACL_ERROR_REPEAT_FINALIZE) {
MS_LOG(ERROR) << "Finalize acl failed";
}
MS_LOG(INFO) << "Acl finalize success";

View File

@ -90,7 +90,7 @@ Status AclGraphImpl::InitEnv() {
return kSuccess;
}
acl_env_ = AclEnvGuard::GetAclEnv("");
acl_env_ = AclEnvGuard::GetAclEnv(GlobalContext::GetGlobalDumpConfigPath());
if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed.";
return kMCDeviceError;

View File

@ -43,6 +43,7 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
precision_mode = ModelContext::GetPrecisionMode(context);
op_select_impl_mode = ModelContext::GetOpSelectImplMode(context);
fusion_switch_cfg_path = ModelContext::GetFusionSwitchConfigPath(context);
}
std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string>> AclModelOptions::GenAclOptions()
@ -50,7 +51,7 @@ std::tuple<std::map<std::string, std::string>, std::map<std::string, std::string
const std::map<std::string const *, std::string> init_options_map = {
{&op_select_impl_mode, ge::ir_option::OP_SELECT_IMPL_MODE},
{&soc_version, ge::ir_option::SOC_VERSION},
};
{&fusion_switch_cfg_path, ge::ir_option::FUSION_SWITCH_FILE}};
const std::map<std::string const *, std::string> build_options_map = {
{&insert_op_cfg_path, ge::ir_option::INSERT_OP_FILE}, {&input_format, ge::ir_option::INPUT_FORMAT},

View File

@ -34,6 +34,7 @@ struct AclModelOptions {
std::string output_type;
std::string precision_mode;
std::string op_select_impl_mode;
std::string fusion_switch_cfg_path;
std::string soc_version = "Ascend310";
explicit AclModelOptions(const std::shared_ptr<Context> &context);