forked from mindspore-Ecosystem/mindspore
add context graph_kernel_flags
Use the flag "opt_level" to control GraphKernel: 0 means disabled, while a non-zero value means enabled. The default value is derived from the context "enable_graph_kernel", but if "opt_level" is also set in "graph_kernel_flags", the flag prevails. Support whitelist and blacklist operators for GraphKernelExpander via "enable_expand_ops", "enable_expand_ops_only" and "disable_expand_ops".
This commit is contained in: commit 11ee3b1624 (parent b1c86b6a22)
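A minimal usage sketch of the new behaviour from the Python side (not part of the commit; it only uses the set_context keys touched by this diff):

import mindspore.context as context

# enable_graph_kernel alone: opt_level defaults to 1, so GraphKernel is enabled.
context.set_context(device_target="GPU", enable_graph_kernel=True)

# If opt_level is also given in graph_kernel_flags, the flag prevails:
# GraphKernel ends up disabled here although enable_graph_kernel is True.
context.set_context(device_target="GPU", enable_graph_kernel=True, graph_kernel_flags="--opt_level=0")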
@@ -21,6 +21,7 @@
 #include <utility>
 #include <vector>
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
 #include "backend/kernel_compiler/common_utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
@@ -67,6 +68,29 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimAssignAdd,
 #endif
   };
+  auto new_prim = [](const std::string &name) { return std::make_shared<Primitive>(name); };
+  auto &flags = context::GraphKernelFlags::GetInstance();
+  auto &enable_ops_only = flags.enable_expand_ops_only;
+  if (!enable_ops_only.empty()) {
+    expand_ops.clear();
+    std::transform(enable_ops_only.begin(), enable_ops_only.end(), std::inserter(expand_ops, expand_ops.end()),
+                   new_prim);
+  } else {
+    auto &enable_ops = flags.enable_expand_ops;
+    auto &disable_ops = flags.disable_expand_ops;
+    if (!enable_ops.empty()) {
+      std::transform(enable_ops.begin(), enable_ops.end(), std::inserter(expand_ops, expand_ops.end()), new_prim);
+    }
+    if (!disable_ops.empty()) {
+      for (auto iter = expand_ops.begin(); iter != expand_ops.end();) {
+        if (std::find(disable_ops.begin(), disable_ops.end(), (*iter)->name()) != disable_ops.end()) {
+          expand_ops.erase(iter++);
+        } else {
+          ++iter;
+        }
+      }
+    }
+  }
   return expand_ops;
 }
 }  // namespace
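For readers skimming the hunk above: "enable_expand_ops_only" replaces the default expanding-operator list entirely, otherwise "enable_expand_ops" extends it and "disable_expand_ops" filters it by name. An illustrative Python sketch of that resolution (not the shipped C++; the op names are examples only):

def resolve_expand_ops(default_ops, enable_only, enable, disable):
    # Mirrors GetExpandOps(): the "only" list wins; otherwise defaults + enable - disable.
    if enable_only:
        return set(enable_only)
    return (set(default_ops) | set(enable)) - set(disable)

print(sorted(resolve_expand_ops({"AssignAdd", "BiasAdd"}, [], ["Square"], ["BiasAdd"])))
# ['AssignAdd', 'Square']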
@@ -17,6 +17,7 @@
 #include "backend/optimizer/mem_reuse/mem_reuse.h"
 #include <algorithm>
 #include <memory>
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
 #include "backend/optimizer/common/helper.h"
@@ -462,9 +463,7 @@ void MemReuseUtil::SetAllInfo(const KernelGraph *graph) {
   MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, graph);
 #endif
 
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  enable_visit_kernel_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
+  enable_visit_kernel_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();
 }
 
 uint8_t *MemReuseUtil::GetNodeOutputPtr(const AnfNodePtr &node, size_t index) const {
@@ -46,6 +46,7 @@
 #include "runtime/device/ascend/ascend_stream_assign.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/ms_utils.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "backend/optimizer/common/helper.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "utils/config_manager.h"
@@ -846,9 +847,7 @@ void AscendSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_
 }
 
 void AscendSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     return;
   }
   opt::GraphKernelOptimize(kernel_graph);
@@ -69,6 +69,7 @@
 #include "utils/ms_utils.h"
 #include "utils/config_manager.h"
 #include "utils/ms_context.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "utils/utils.h"
 #if ENABLE_CPU && ENABLE_GPU
 #include "ps/util.h"
@@ -127,8 +128,6 @@ void GPUSession::StartKernelRT() const {
 
 void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
@@ -136,7 +135,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
   pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
   pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayFusion>());
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
   }
   pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
@@ -181,9 +180,7 @@ void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kerne
 }
 
 void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     return;
   }
   opt::GraphKernelOptimize(kernel_graph);
@@ -40,6 +40,7 @@
 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
 #include "frontend/optimizer/recompute.h"
 #include "utils/log_adapter.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "pipeline/jit/pipeline_split.h"
 #include "pipeline/jit/static_analysis/auto_monad.h"
 #include "frontend/optimizer/irpass/gradient_eliminate.h"
@@ -354,9 +355,7 @@ void InitOpt(const ResourcePtr &res) {
   g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
   g_pass_opts["opt_after_recompute"] =
     Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  if (!(context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL))) {
+  if (!context::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
     g_pass_opts["opt_graph_kernel_a"]->set_enable(false);
     g_pass_opts["opt_graph_kernel_b"]->set_enable(false);
   }
@@ -97,6 +97,7 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
     .value("tune_mode", MsCtxParam::MS_CTX_TUNE_MODE)
     .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
     .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
+    .value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
     .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
   (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
     .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
@@ -24,6 +24,7 @@
 #include "runtime/device/gpu/distribution/collective_init.h"
 #include "utils/convert_utils.h"
 #include "utils/ms_context.h"
+#include "utils/context/graph_kernel_flags.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "runtime/device/gpu/gpu_common.h"
 #include "utils/ms_utils.h"
@@ -66,9 +67,7 @@ bool GPUKernelRuntime::SyncStream() {
 }
 
 bool GPUKernelRuntime::Init() {
-  auto context_ptr = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context_ptr);
-  enable_relation_cache_ = context_ptr->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL);
+  enable_relation_cache_ = context::GraphKernelFlags::GetInstance().IsEnableGraphKernel();
 
   if (device_init_ == true) {
     GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
@@ -0,0 +1,196 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/context/graph_kernel_flags.h"
+
+#include <map>
+#include <string>
+#include <cstring>
+#include <vector>
+#include <utility>
+#include "nlohmann/json.hpp"
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace context {
+namespace {
+// Split string to tokens
+std::vector<std::string> GetTokens(const std::string &str, const std::string &delim) {
+  std::vector<std::string> tokens;
+  std::vector<char> c_str(str.begin(), str.end());
+  c_str.push_back('\0');
+  char *saveptr;
+  char *pch = strtok_r(&c_str[0], delim.c_str(), &saveptr);
+  while (pch != NULL) {
+    tokens.emplace_back(pch);
+    pch = strtok_r(NULL, delim.c_str(), &saveptr);
+  }
+  return tokens;
+}
+
+// Parse flag string to key-value pair.
+// Flag format: "--key=value", bool flag's value can be implicit, the "--key" means "--key=true"
+std::pair<std::string, std::string> ParseFlag(const std::string &flag) {
+  auto i = flag.find("--");
+  // check the string starts with "--".
+  if (i != 0 || flag.size() == 2) {
+    return std::pair<std::string, std::string>();
+  }
+  i += 2;
+
+  auto j = flag.find('=', i + 1);  // the key should not be empty, "--=" is invalid
+  if (j == std::string::npos) {
+    // no value, treated as bool flag.
+    return std::make_pair(flag.substr(i), "");
+  } else if (j + 1 != flag.size() && flag.find('=', j + 1) == std::string::npos) {
+    // normal "--key=value" format
+    return std::make_pair(flag.substr(i, j - i), flag.substr(j + 1));
+  }
+  // string with two "=" is invalid.
+  return std::pair<std::string, std::string>();
+}
+
+std::map<std::string, std::string> ParseFlags(const std::string &flags) {
+  std::map<std::string, std::string> flag_map;
+  auto tokens = GetTokens(flags, " ");
+  for (const auto &token : tokens) {
+    auto flag = ParseFlag(token);
+    if (flag.first != "") {
+      if (!flag_map.insert(flag).second) {
+        MS_LOG(WARNING) << "Repeated GraphKernel flag: " << flag.first;
+      }
+    } else {
+      MS_LOG(WARNING) << "Invalid GraphKernel flag: " << token;
+    }
+  }
+  return flag_map;
+}
+
+class FlagRegister {
+ public:
+  explicit FlagRegister(std::map<std::string, std::string> *flag_map) : flag_map_(*flag_map) {}
+  ~FlagRegister() = default;
+
+  template <typename T>
+  void AddFlag(std::string flag_name, T *flag_var) {
+    auto iter = flag_map_.find(flag_name);
+    if (iter != flag_map_.end()) {
+      T var;
+      bool ret = ParseValue(iter->second, &var);
+      if (ret) {
+        *flag_var = std::move(var);
+      } else {
+        if (iter->second.empty()) {
+          MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first;
+        } else {
+          MS_LOG(WARNING) << "Invalid GraphKernel flag: --" << iter->first << "=" << iter->second;
+        }
+      }
+      flag_map_.erase(iter);
+    }
+  }
+
+ private:
+  bool ParseValue(const std::string &s, std::vector<std::string> *result) {
+    *result = GetTokens(s, ",");
+    return !result->empty();
+  }
+
+  bool ParseValue(const std::string &s, bool *result) {
+    *result = (s.empty() || s == "true" || s == "on" || s == "1");
+    return *result || s == "false" || s == "off" || s == "0";
+  }
+
+  template <typename T>
+  bool ParseValue(const std::string &s, T *result) {
+    if (s.empty()) {
+      return false;
+    }
+    std::istringstream iss(s);
+    iss >> (*result);
+    return iss.eof();
+  }
+
+  template <typename T>
+  bool ParseValue(const std::string &s, std::vector<T> *result) {
+    result->clear();
+    auto tokens = GetTokens(s, ",");
+    if (tokens.empty()) {
+      return false;
+    }
+    for (const auto &tok : tokens) {
+      T temp;
+      if (!ParseValue(tok, &temp)) {
+        result->clear();
+        return false;
+      }
+      result->emplace_back(temp);
+    }
+    return true;
+  }
+
+  std::map<std::string, std::string> &flag_map_;
+};
+}  // namespace
+
+void GraphKernelFlags::Refresh() {
+  auto flag_map = ParseFlags(flags_cache_);
+  RegisterFlags(&flag_map);
+  for (auto &item : flag_map) {
+    MS_LOG(WARNING) << "Unknown GraphKernel flag: " << item.first;
+  }
+}
+
+void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_map) {
+  FlagRegister reg(flag_map);
+
+  reg.AddFlag("dump_as_text", &dump_as_text);
+
+  reg.AddFlag("opt_level", &opt_level);
+  reg.AddFlag("auto_tune", &auto_tune);
+  reg.AddFlag("cluster_limit", &cluster_limit);
+
+  reg.AddFlag("enable_expand_ops", &enable_expand_ops);
+  reg.AddFlag("enable_expand_ops_only", &enable_expand_ops_only);
+  reg.AddFlag("disable_expand_ops", &disable_expand_ops);
+  reg.AddFlag("enable_cluster_ops", &enable_cluster_ops);
+  reg.AddFlag("enable_cluster_ops_only", &enable_cluster_ops_only);
+  reg.AddFlag("disable_cluster_ops", &disable_cluster_ops);
+  reg.AddFlag("enable_pass_only", &enable_pass_only);
+  reg.AddFlag("disable_pass", &disable_pass);
+}
+
+std::string GraphKernelFlags::DumpAllFlags() const {
+  nlohmann::json json;
+  json["dump_as_text"] = dump_as_text;
+
+  json["opt_level"] = opt_level;
+  json["auto_tune"] = auto_tune;
+  json["cluster_limit"] = cluster_limit;
+
+  json["enable_expand_ops"] = enable_expand_ops;
+  json["enable_expand_ops_only"] = enable_expand_ops_only;
+  json["disable_expand_ops"] = disable_expand_ops;
+  json["enable_cluster_ops"] = enable_cluster_ops;
+  json["enable_cluster_ops_only"] = enable_cluster_ops_only;
+  json["disable_cluster_ops"] = disable_cluster_ops;
+  json["enable_pass_only"] = enable_pass_only;
+  json["disable_pass"] = disable_pass;
+
+  return json.dump();
+}
+}  // namespace context
+}  // namespace mindspore
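The flag string consumed by this parser is a space-separated list of "--key=value" items; a bare "--key" is shorthand for "--key=true", and list-valued flags take comma-separated values. A loose Python model of that grammar (illustrative only; edge cases may differ slightly from the C++):

def parse_flags(flag_string):
    # Rough model of ParseFlags()/ParseFlag() in graph_kernel_flags.cc.
    result = {}
    for token in flag_string.split():
        key, sep, value = token[2:].partition("=")
        valid = token.startswith("--") and key and "=" not in value and not (sep and not value)
        if not valid:
            print("Invalid GraphKernel flag:", token)
        elif key in result:
            print("Repeated GraphKernel flag:", key)
        else:
            result[key] = value  # "" for a bare "--key"; the bool parser reads that as true
    return result

print(parse_flags("--opt_level=2 --dump_as_text --disable_expand_ops=GeLU,Softmax"))
# {'opt_level': '2', 'dump_as_text': '', 'disable_expand_ops': 'GeLU,Softmax'}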
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
+#define MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+#include <utility>
+#include "utils/ms_context.h"
+
+namespace mindspore {
+namespace context {
+class GraphKernelFlags {
+ public:
+  static const GraphKernelFlags &GetInstance() {
+    static std::unique_ptr<GraphKernelFlags> flags(nullptr);
+    auto contexts = GetGraphKernelContext();
+    if (flags == nullptr || contexts.first != flags->flags_cache_ || contexts.second != flags->enable_cache_) {
+      flags.reset(new GraphKernelFlags(contexts.first, contexts.second));
+      flags->Refresh();
+    }
+    return *flags;
+  }
+
+  // Dump all flags to json-format string
+  std::string DumpAllFlags() const;
+
+  // Check whether graph_kernel is enabled
+  bool IsEnableGraphKernel() const { return opt_level > 0; }
+
+  GraphKernelFlags(const GraphKernelFlags &flags) = delete;
+  ~GraphKernelFlags() = default;
+
+ public:
+  /**
+   * dump_as_text, unsupported now.
+   */
+  bool dump_as_text{false};
+
+  /**
+   * opt_level, value from 0 to 3.
+   * 0: GraphKernel disabled
+   * 1: GraphKernel enabled
+   * 2 and 3 are not supported now.
+   * the default value is controlled by context `enable_graph_kernel`,
+   * but if it's also set in `graph_kernel_flags`, then the flag will prevail.
+   */
+  unsigned int opt_level{0};
+
+  /**
+   * auto_tune, unsupported now.
+   */
+  unsigned int auto_tune{0};
+
+  /**
+   * cluster_limit, unsupported now.
+   */
+  unsigned int cluster_limit{30};
+
+  /**
+   * Additional expanding operators (case sensitive).
+   * The operators to be added into the default expanding operator list.
+   */
+  std::vector<std::string> enable_expand_ops;
+
+  /**
+   * Expanding operators to be enabled (case sensitive).
+   * Unlike the "enable_expand_ops", the default list will be overwritten by this list.
+   * Note that the "enable_expand_ops" and "disable_expand_ops" will be ignored if this flag is set.
+   */
+  std::vector<std::string> enable_expand_ops_only;
+
+  /**
+   * Expanding operators to be disabled (case sensitive).
+   * The behavior is undefined when this list overlaps with "enable_expand_ops".
+   */
+  std::vector<std::string> disable_expand_ops;
+
+  /**
+   * enable_cluster_ops, unsupported now.
+   */
+  std::vector<std::string> enable_cluster_ops;
+
+  /**
+   * enable_cluster_ops_only, unsupported now.
+   */
+  std::vector<std::string> enable_cluster_ops_only;
+
+  /**
+   * disable_cluster_ops, unsupported now.
+   */
+  std::vector<std::string> disable_cluster_ops;
+
+  /**
+   * enable_pass_only, unsupported now.
+   */
+  std::vector<std::string> enable_pass_only;
+
+  /**
+   * disable_pass, unsupported now.
+   */
+  std::vector<std::string> disable_pass;
+
+ private:
+  GraphKernelFlags(const std::string &graph_kernel_flags, bool enable_graph_kernel)
+      : flags_cache_(graph_kernel_flags), enable_cache_(enable_graph_kernel) {
+    opt_level = enable_graph_kernel ? 1 : 0;
+  }
+
+  // get the `graph_kernel_flags` and `enable_graph_kernel`
+  static std::pair<std::string, bool> GetGraphKernelContext() {
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    // Use the environment variable in priority
+    auto env_flags = std::getenv("MS_GRAPH_KERNEL_FLAGS");
+    std::string flags = env_flags ? std::string(env_flags) : context->get_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS);
+    return std::make_pair(flags, context->get_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL));
+  }
+
+  // parse and refresh the flags
+  void Refresh();
+  // register the flags defined above
+  void RegisterFlags(std::map<std::string, std::string> *flag_map);
+
+  // cache the flag string to check whether the flags is changed.
+  std::string flags_cache_;
+  // cache the enable_graph_kernel value to check whether the context is changed.
+  bool enable_cache_;
+};
+}  // namespace context
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_UTILS_GRAPH_KERNEL_FLAGS_H
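One detail of GetGraphKernelContext() worth calling out: the environment variable MS_GRAPH_KERNEL_FLAGS takes priority over the context string. A hedged sketch of how that can be used (script layout is illustrative):

import os

# Per GetGraphKernelContext(), the env var wins over set_context(graph_kernel_flags=...).
os.environ["MS_GRAPH_KERNEL_FLAGS"] = "--opt_level=1 --dump_as_text"

from mindspore import context
context.set_context(device_target="GPU")
# A graph_kernel_flags value passed to set_context would be ignored while the env var is set.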
@@ -489,6 +489,7 @@ def _check_target_specific_cfgs(device, arg_key):
         'enable_dump': ['Ascend'],
         'save_dump_path': ['Ascend'],
         'enable_graph_kernel': ['Ascend', 'GPU'],
+        'graph_kernel_flags': ['Ascend', 'GPU'],
         'enable_reduce_precision': ['Ascend'],
         'enable_profiling': ['Ascend'],
         'profiling_options': ['Ascend'],
@@ -513,7 +514,7 @@ def _check_target_specific_cfgs(device, arg_key):
                  save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                  enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
                  enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
-                 enable_sparse=bool, max_call_depth=int, env_config_path=str)
+                 enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str)
 def set_context(**kwargs):
     """
     Set context for running environment.
@@ -540,14 +541,14 @@
     =========================== =========================== =================
     check_bprop                 print_file_path             max_device_memory
     device_id                   enable_dump                 enable_graph_kernel
-    device_target               save_dump_path
+    device_target               save_dump_path              graph_kernel_flags
     enable_sparse               enable_graph_kernel
     max_call_depth              enable_reduce_precision
     mode                        enable_profiling
     reserve_class_name_in_scope profiling_options
     save_graphs                 variable_memory_max_size
     save_graphs_path            auto_tune_mode
-    env_config_path
+    env_config_path             graph_kernel_flags
     grad_for_scalar
     =========================== =========================== =================
@@ -566,6 +567,7 @@
             `context.set_context(save_graphs_path="path/to/ir/files"+device_id)`.
         enable_graph_kernel (bool): Whether to enable composition of basic primitives. These primitives would be
             compiled into a fused kernel automatically. Default: False.
+        graph_kernel_flags (str): Optimization flags of graph kernel fusion, for example "--opt_level=1". Default: ''.
         reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
         enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
         enable_dump (bool): Whether to enable dump. Default: False.
@@ -39,6 +39,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
   set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
   set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
+  set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");
   set_param<uint32_t>(MS_CTX_TSD_REF, 0);
   set_param<uint32_t>(MS_CTX_GE_REF, 0);
@@ -112,6 +112,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_PYTHON_EXE_PATH,
   MS_CTX_ENV_CONFIG_PATH,
   MS_CTX_TUNE_MODE,
+  MS_CTX_GRAPH_KERNEL_FLAGS,
   MS_CTX_TYPE_STRING_END,
 
   // parameter numbers of each type
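Finally, a hedged end-to-end example combining the operator-list flags registered above with the Python context changes (op names such as Square or GeLU are illustrative; whether an op is expandable depends on the backend):

import mindspore.context as context

# Replace the default expanding-operator list entirely:
context.set_context(device_target="GPU", enable_graph_kernel=True,
                    graph_kernel_flags="--enable_expand_ops_only=Square,GeLU")

# Or keep the defaults, add one op and drop another:
context.set_context(device_target="GPU", enable_graph_kernel=True,
                    graph_kernel_flags="--enable_expand_ops=Square --disable_expand_ops=GeLU")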