diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
index 8864bb4f5b1..2b3d08bacb7 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc
@@ -21,6 +21,7 @@
 #include
 #include
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
@@ -67,6 +68,29 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimAssignAdd,
 #endif
   };
+  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
+  auto &flags = context::GraphKernelFlags::GetInstance();
+  auto &enable_ops_only = flags.enable_expand_ops_only;
+  if (!enable_ops_only.empty()) {
+    expand_ops.clear();
+    std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()),
+                   new_prim);
+  } else {
+    auto &enable_ops = flags.enable_expand_ops;
+    auto &disable_ops = flags.disable_expand_ops;
+    if (!enable_ops.empty()) {
+      std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim);
+    }
+    if (!disable_ops.empty()) {
+      for (auto iter = expand_ops.begin(); iter != expand_ops.end();) {
+        if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) {
+          expand_ops.erase(iter++);
+        } else {
+          ++iter;
+        }
+      }
+    }
+  }
   return expand_ops;
 }
 }  // namespace
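
The new branch in GetExpandOps() gives enable_expand_ops_only absolute priority: when it is non-empty the default list is discarded entirely; otherwise enable_expand_ops extends the defaults and disable_expand_ops prunes them. A self-contained sketch of that precedence, using plain strings and placeholder op names (Square, BiasAdd, GeLU, Tile) instead of MindSpore's PrimitivePtr:

    // precedence_sketch.cc - illustration only, not part of the patch.
    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    std::set<std::string> ResolveExpandOps(std::set<std::string> defaults, const std::vector<std::string> &only,
                                           const std::vector<std::string> &enable,
                                           const std::vector<std::string> &disable) {
      if (!only.empty()) {
        return {only.begin(), only.end()};  // the "only" list replaces the defaults entirely
      }
      defaults.insert(enable.begin(), enable.end());  // append the extra ops
      for (const auto &op : disable) {
        defaults.erase(op);  // then remove the disabled ones
      }
      return defaults;
    }

    int main() {
      // Placeholder op names; the real defaults are the kPrim* list in GetExpandOps().
      auto ops = ResolveExpandOps({"Square", "BiasAdd", "GeLU"}, {}, {"Tile"}, {"GeLU"});
      for (const auto &op : ops) {
        std::cout << op << " ";  // prints: BiasAdd Square Tile
      }
      std::cout << std::endl;
      return 0;
    }
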
diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc
index 400a0fc808d..086acf9c93e 100644
--- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc
+++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc
@@ -17,6 +17,7 @@
 #include "backend/optimizer/mem_reuse/mem_reuse.h"
 #include
 #include
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
 #include "backend/optimizer/common/helper.h"
@@ -462,9 +463,7 @@ void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
   MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
 #endif
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  enable_visit_kernel_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
+  enable_visit_kernel_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();
 }

 uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc
index d687b29077d..5ae2cff99ce 100644
--- a/mindspore/ccsrc/backend/session/ascend_session.cc
+++ b/mindspore/ccsrc/backend/session/ascend_session.cc
@@ -46,6 +46,7 @@
 #include "runtime/device/ascend/ascend_stream_assign.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/ms_utils.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/optimizer/common/helper.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "utils/config_manager.h"
@@ -846,9 +847,7 @@ void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
 }

 void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     return;
   }
   opt::GraphKernelOptimize(kernel_graph);
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index ac9e3a80507..ad7b74c4986 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -69,6 +69,7 @@
 #include "utils/ms_utils.h"
 #include "utils/config_manager.h"
 #include "utils/ms_context.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "utils/utils.h"
 #if ENABLE_CPU && ENABLE_GPU
 #include "ps/util.h"
@@ -127,8 +128,6 @@ void GPUSession::StartKernelRT() const {

 void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared());
@@ -136,7 +135,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
   pm->AddPass(std::make_shared());
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     pm->AddPass(std::make_shared("cast_all"));
   }
   pm->AddPass(std::make_shared("combine_momentum"));
@@ -181,9 +180,7 @@ void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
 }

 void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     return;
   }
   opt::GraphKernelOptimize(kernel_graph);
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index 3af722277c9..8e1f4897584 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -40,6 +40,7 @@
 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
 #include "frontend/optimizer/recompute.h"
 #include "utils/log_adapter.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "pipeline/jit/pipeline_split.h"
 #include "pipeline/jit/static_analysis/auto_monad.h"
 #include "frontend/optimizer/irpass/gradient_eliminate.h"
@@ -354,9 +355,7 @@ void InitOpt(const ResourcePtr &res) {
   g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
   g_pass_opts["opt_after_recompute"] =
     Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
     g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
   }
diff --git a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
index efd8f46767d..cc359182bb9 100644
--- a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+++ b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
@@ -97,6 +97,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
                          .value("tune_mode", MsCtxParam::MS_CTX_TUNE_MODE)
                          .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
                          .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
+                         .value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
                          .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
                        (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
                          .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index 5bf13e92d92..15b05369085 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -24,6 +24,7 @@
 #include "runtime/device/gpu/distribution/collective_init.h"
 #include "utils/convert_utils.h"
 #include "utils/ms_context.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "runtime/device/gpu/gpu_common.h"
 #include "utils/ms_utils.h"
@@ -66,9 +67,7 @@ bool GPUKernelRuntime::SyncStream() {
 }

 bool GPUKernelRuntime::Init() {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  enable_relation_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
+  enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();
   if (device_init_ == true) {
     GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.cc b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc
new file mode 100644
index 00000000000..25e73002f1a
--- /dev/null
+++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.cc
@@ -0,0 +1,196 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/context/graph_kernel_flags.h"
+
+#include
+#include
+#include
+#include
+#include "nlohmann/json.hpp"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace context {
+namespace {
+// Split a string into tokens by the given delimiters.
+std::vector<std::string> GetTokens(const std::string &str, const std::string &delim) {
+  std::vector<std::string> tokens;
+  std::vector<char> c_str(str.begin(), str.end());
+  c_str.push_back('\0');
+  char *saveptr;
+  char *pch = strtok_r(&c_str[0], delim.c_str(), &saveptr);
+  while (pch != NULL) {
+    tokens.emplace_back(pch);
+    pch = strtok_r(NULL, delim.c_str(), &saveptr);
+  }
+  return tokens;
+}
+
+// Parse a flag string into a key-value pair.
+// Flag format: "--key=value". A bool flag's value can be implicit: "--key" means "--key=true".
+std::pair<std::string, std::string> ParseFlag(const std::string &flag) {
+  auto i = flag.find("--");
+  // check that the string starts with "--".
+  if (i != 0 || flag.size() == 2) {
+    return std::pair<std::string, std::string>();
+  }
+  i += 2;
+
+  auto j = flag.find('=', i + 1);  // the key should not be empty, "--=" is invalid
+  if (j == std::string::npos) {
+    // no value, treated as a bool flag.
+    return std::make_pair(flag.substr(i), "");
+  } else if (j + 1 != flag.size() && flag.find('=', j + 1) == std::string::npos) {
+    // normal "--key=value" format
+    return std::make_pair(flag.substr(i, j - i), flag.substr(j + 1));
+  }
+  // a string with two "=" is invalid.
+  return std::pair<std::string, std::string>();
+}
+
+std::map<std::string, std::string> ParseFlags(const std::string &flags) {
+  std::map<std::string, std::string> flag_map;
+  auto tokens = GetTokens(flags, " ");
+  for (const auto &token : tokens) {
+    auto flag = ParseFlag(token);
+    if (flag.first != "") {
+      if (!flag_map.insert(flag).second) {
+        MS_LOG(WARNING) << "Repeated GraphKernel flag: " << flag.first;
+      }
+    } else {
+      MS_LOG(WARNING) << "Invalid GraphKernel flag: " << token;
+    }
+  }
+  return flag_map;
+}
+
+class FlagRegister {
+ public:
+  explicit FlagRegister(std::map<std::string, std::string> *flag_map) : flag_map_(*flag_map) {}
+  ~FlagRegister() = default;
+
+  template <typename T>
+  void AddFlag(std::string flag_name, T *flag_var) {
+    auto iter = flag_map_.find(flag_name);
+    if (iter != flag_map_.end()) {
+      T var;
+      bool ret = ParseValue(iter->second, &var);
+      if (ret) {
+        *flag_var = std::move(var);
+      } else {
+        if (iter->second.empty()) {
+          MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first;
+        } else {
+          MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first << "=" << iter->second;
+        }
+      }
+      flag_map_.erase(iter);
+    }
+  }
+
+ private:
+  bool ParseValue(const std::string &s, std::vector<std::string> *result) {
+    *result = GetTokens(s, ",");
+    return !result->empty();
+  }
+
+  bool ParseValue(const std::string &s, bool *result) {
+    *result = (s.empty() || s == "true" || s == "on" || s == "1");
+    return *result || s == "false" || s == "off" || s == "0";
+  }
+
+  template <typename T>
+  bool ParseValue(const std::string &s, T *result) {
+    if (s.empty()) {
+      return false;
+    }
+    std::istringstream iss(s);
+    iss >> (*result);
+    return iss.eof();
+  }
+
+  template <typename T>
+  bool ParseValue(const std::string &s, std::vector<T> *result) {
+    result->clear();
+    auto tokens = GetTokens(s, ",");
+    if (tokens.empty()) {
+      return false;
+    }
+    for (const auto &tok : tokens) {
+      T temp;
+      if (!ParseValue(tok, &temp)) {
+        result->clear();
+        return false;
+      }
+      result->emplace_back(temp);
+    }
+    return true;
+  }
+
+  std::map<std::string, std::string> &flag_map_;
+};
+}  // namespace
+
+void GraphKernelFlags::Refresh() {
+  auto flag_map = ParseFlags(flags_cache_);
+  RegisterFlags(&flag_map);
+  for (auto &item : flag_map) {
+    MS_LOG(WARNING) << "Unknown GraphKernel flag: " << item.first;
+  }
+}
+
+void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
+  FlagRegister reg(flag_map);
+
+  reg.AddFlag("dump_as_text", &dump_as_text);
+
+  reg.AddFlag("opt_level", &opt_level);
+  reg.AddFlag("auto_tune", &auto_tune);
+  reg.AddFlag("cluster_limit", &cluster_limit);
+
+  reg.AddFlag("enable_expand_ops", &enable_expand_ops);
+  reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
+  reg.AddFlag("disable_expand_ops", &disable_expand_ops);
+  reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
+  reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
+  reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
+  reg.AddFlag("enable_pass_only", &enable_pass_only);
+  reg.AddFlag("disable_pass", &disable_pass);
+}
+
+std::string GraphKernelFlags::DumpAllFlags() const {
+  nlohmann::json json;
+  json["dump_as_text"] = dump_as_text;
+
+  json["opt_level"] = opt_level;
+  json["auto_tune"] = auto_tune;
+  json["cluster_limit"] = cluster_limit;
+
+  json["enable_expand_ops"] = enable_expand_ops;
+  json["enable_expand_ops_only"] = enable_expand_ops_only;
+  json["disable_expand_ops"] = disable_expand_ops;
+  json["enable_cluster_ops"] = enable_cluster_ops;
+  json["enable_cluster_ops_only"] = enable_cluster_ops_only;
+  json["disable_cluster_ops"] = disable_cluster_ops;
+  json["enable_pass_only"] = enable_pass_only;
+  json["disable_pass"] = disable_pass;
+
+  return json.dump();
+}
+}  // namespace context
+}  // namespace mindspore
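
The flag string consumed by ParseFlags() is a space-separated list of "--key" or "--key=value" entries. The following standalone sketch (simplified parsing and example flag values, not the patch's parser) shows the key/value breakdown such a string is expected to produce:

    // flag_format_sketch.cc - illustration only.
    #include <iostream>
    #include <map>
    #include <sstream>
    #include <string>

    int main() {
      const std::string flags = "--opt_level=1 --dump_as_text --enable_expand_ops=Square,GeLU";
      std::map<std::string, std::string> parsed;
      std::istringstream iss(flags);
      std::string token;
      while (iss >> token) {                      // flags are separated by spaces
        if (token.rfind("--", 0) != 0) continue;  // every flag starts with "--"
        auto eq = token.find('=');
        if (eq == std::string::npos) {
          parsed[token.substr(2)] = "true";       // "--key" is shorthand for "--key=true"
        } else {
          parsed[token.substr(2, eq - 2)] = token.substr(eq + 1);
        }
      }
      for (const auto &kv : parsed) {
        std::cout << kv.first << " = " << kv.second << "\n";
      }
      // Expected output:
      //   dump_as_text = true
      //   enable_expand_ops = Square,GeLU
      //   opt_level = 1
      return 0;
    }
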
diff --git a/mindspore/ccsrc/utils/context/graph_kernel_flags.h b/mindspore/ccsrc/utils/context/graph_kernel_flags.h
new file mode 100644
index 00000000000..d419750bd4d
--- /dev/null
+++ b/mindspore/ccsrc/utils/context/graph_kernel_flags.h
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
+#define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
+
+#include
+#include
+#include
+#include
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace context {
+class GraphKernelFlags {
+ public:
+  static const GraphKernelFlags &GetInstance() {
+    static std::unique_ptr<GraphKernelFlags> flags(nullptr);
+    auto contexts = GetGraphKernelContext();
+    if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_cache_) {
+      flags.reset(new GraphKernelFlags(contexts.first, contexts.second));
+      flags->Refresh();
+    }
+    return *flags;
+  }
+
+  // Dump all flags to a json-format string.
+  std::string DumpAllFlags() const;
+
+  // Check whether graph_kernel is enabled.
+  bool IsEnableGraphKernel() const { return opt_level > 0; }
+
+  GraphKernelFlags(const GraphKernelFlags &flags) = delete;
+  ~GraphKernelFlags() = default;
+
+ public:
+  /**
+   * dump_as_text, not supported yet.
+   */
+  bool dump_as_text{false};
+
+  /**
+   * opt_level, value from 0 to 3.
+   * 0: GraphKernel disabled
+   * 1: GraphKernel enabled
+   * 2 and 3 are not supported yet.
+   * The default value is controlled by the context `enable_graph_kernel`,
+   * but if it is also set in `graph_kernel_flags`, the flag prevails.
+   */
+  unsigned int opt_level{0};
+
+  /**
+   * auto_tune, not supported yet.
+   */
+  unsigned int auto_tune{0};
+
+  /**
+   * cluster_limit, not supported yet.
+   */
+  unsigned int cluster_limit{30};
+
+  /**
+   * Additional expanding operators (case sensitive).
+   * The operators to be added into the default expanding operator list.
+   */
+  std::vector<std::string> enable_expand_ops;
+
+  /**
+   * Expanding operators to be enabled (case sensitive).
+   * Unlike "enable_expand_ops", the default list is overwritten by this list.
+   * Note that "enable_expand_ops" and "disable_expand_ops" are ignored if this flag is set.
+   */
+  std::vector<std::string> enable_expand_ops_only;
+
+  /**
+   * Expanding operators to be disabled (case sensitive).
+   * The behavior is undefined when this list overlaps with "enable_expand_ops".
+   */
+  std::vector<std::string> disable_expand_ops;
+
+  /**
+   * enable_cluster_ops, not supported yet.
+   */
+  std::vector<std::string> enable_cluster_ops;
+
+  /**
+   * enable_cluster_ops_only, not supported yet.
+   */
+  std::vector<std::string> enable_cluster_ops_only;
+
+  /**
+   * disable_cluster_ops, not supported yet.
+   */
+  std::vector<std::string> disable_cluster_ops;
+
+  /**
+   * enable_pass_only, not supported yet.
+   */
+  std::vector<std::string> enable_pass_only;
+
+  /**
+   * disable_pass, not supported yet.
+   */
+  std::vector<std::string> disable_pass;
+
+ private:
+  GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel)
+      : flags_cache_(graph_kernel_flags), enable_cache_(enable_graph_kernel) {
+    opt_level = enable_graph_kernel ? 1 : 0;
+  }
+
+  // Get `graph_kernel_flags` and `enable_graph_kernel` from the global context.
+  static std::pair<std::string, bool> GetGraphKernelContext() {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    // The environment variable takes priority over the context parameter.
+    auto env_flags = std::getenv("MS_GRAPH_KERNEL_FLAGS");
+    std::string flags = env_flags ? std::string(env_flags) : context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS);
+    return std::make_pair(flags, context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL));
+  }
+
+  // Parse and refresh the flags.
+  void Refresh();
+  // Register the flags defined above.
+  void RegisterFlags(std::map<std::string, std::string> *flag_map);
+
+  // Cache the flag string to detect whether the flags have changed.
+  std::string flags_cache_;
+  // Cache the enable_graph_kernel value to detect whether the context has changed.
+  bool enable_cache_;
+};
+}  // namespace context
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
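
Consumer code is expected to go through GetInstance() and the public members declared above. A minimal sketch of that pattern (the function name and body are illustrative; the accessors are the ones declared in this header, and the pattern mirrors GetExpandOps() and the session classes in this patch):

    #include <string>
    #include "utils/context/graph_kernel_flags.h"

    // Hypothetical consumer: collect the operators added via "--enable_expand_ops=A,B,C".
    std::string CollectExtraExpandOps() {
      const auto &flags = mindspore::context::GraphKernelFlags::GetInstance();
      if (!flags.IsEnableGraphKernel()) {  // true only when opt_level > 0
        return "";
      }
      std::string joined;
      for (const auto &op : flags.enable_expand_ops) {
        joined += op + " ";
      }
      return joined;
    }
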
diff --git a/mindspore/context.py b/mindspore/context.py
index 4a2ebeac7dc..9d9eaf4eb5e 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -489,6 +489,7 @@ def _check_target_specific_cfgs(device, arg_key):
         'enable_dump': ['Ascend'],
         'save_dump_path': ['Ascend'],
         'enable_graph_kernel': ['Ascend', 'GPU'],
+        'graph_kernel_flags': ['Ascend', 'GPU'],
         'enable_reduce_precision': ['Ascend'],
         'enable_profiling': ['Ascend'],
         'profiling_options': ['Ascend'],
@@ -513,7 +514,7 @@ def _check_target_specific_cfgs(device, arg_key):
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                  enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
                  enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
-                 enable_sparse=bool, max_call_depth=int, env_config_path=str)
+                 enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str)
 def set_context(**kwargs):
     """
     Set context for running environment.
@@ -540,14 +541,14 @@ def set_context(**kwargs):
     =========================== =========================== =================
     check_bprop                 print_file_path             max_device_memory
     device_id                   enable_dump                 enable_graph_kernel
-    device_target               save_dump_path
+    device_target               save_dump_path              graph_kernel_flags
     enable_sparse               enable_graph_kernel
     max_call_depth              enable_reduce_precision
     mode                        enable_profiling
     reserve_class_name_in_scope profiling_options
     save_graphs                 variable_memory_max_size
     save_graphs_path            auto_tune_mode
-    env_config_path
+    env_config_path             graph_kernel_flags
     grad_for_scalar
     =========================== =========================== =================

@@ -566,6 +567,7 @@ def set_context(**kwargs):
            `context.set_context(save_graphs_path="path/to/ir/files"+device_id)`.
        enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be
            compiled into a fused kernel automatically. Default: False.
+       graph_kernel_flags (str): Optimization options of graph kernel, a space-separated "--key=value" flag string.
        reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
        enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
        enable_dump (bool): Whether to enable dump. Default: False.
diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc
index 06290c3ae1d..6ea55c9ba62 100644
--- a/mindspore/core/utils/ms_context.cc
+++ b/mindspore/core/utils/ms_context.cc
@@ -39,6 +39,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
   set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
   set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
+  set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");

   set_param<uint32_t>(MS_CTX_TSD_REF, 0);
   set_param<uint32_t>(MS_CTX_GE_REF, 0);
diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h
index 3e685109829..4f39b2ea50d 100644
--- a/mindspore/core/utils/ms_context.h
+++ b/mindspore/core/utils/ms_context.h
@@ -112,6 +112,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_PYTHON_EXE_PATH,
   MS_CTX_ENV_CONFIG_PATH,
   MS_CTX_TUNE_MODE,
+  MS_CTX_GRAPH_KERNEL_FLAGS,
   MS_CTX_TYPE_STRING_END,

   // parameter numbers of each type
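
Putting the pieces together, the flag string reaches GraphKernelFlags either through the MS_GRAPH_KERNEL_FLAGS environment variable or through the new context parameter. A rough sketch of that wiring, assuming the templated MsContext::set_param setter used elsewhere in the code base; the function name and flag values are examples:

    #include <string>
    #include "utils/context/graph_kernel_flags.h"
    #include "utils/ms_context.h"

    std::string DumpEffectiveGraphKernelFlags() {
      auto ctx = mindspore::MsContext::GetInstance();
      // Equivalent to context.set_context(graph_kernel_flags="...") on the Python side;
      // the MS_GRAPH_KERNEL_FLAGS environment variable, when set, takes priority over this value.
      ctx->set_param<std::string>(mindspore::MS_CTX_GRAPH_KERNEL_FLAGS, "--opt_level=1 --disable_expand_ops=GeLU");
      // GetInstance() re-parses because the cached flag string changed.
      const auto &flags = mindspore::context::GraphKernelFlags::GetInstance();
      return flags.DumpAllFlags();  // JSON dump of every flag declared in graph_kernel_flags.h
    }
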