!69520 add an env to control the behavior of internal kernel
Merge pull request !69520 from chengbin/fix_master_bkb
This commit is contained in:
commit
b53afe9f32
|
@ -102,8 +102,8 @@ const AnfNodePtr AddLayernormFusion::Process(const FuncGraphPtr &graph, const An
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::string fusion_op_name = "AddLayerNorm";
|
||||
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
const std::string fusion_op_name = "AddLayerNorm";
|
||||
auto &enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
bool enable_add_layernorm =
|
||||
(std::find(enable_op_list.begin(), enable_op_list.end(), fusion_op_name) != enable_op_list.end());
|
||||
if (!enable_add_layernorm) {
|
||||
|
|
|
@ -69,8 +69,8 @@ const AnfNodePtr AddRmsNormFusion::Process(const FuncGraphPtr &graph, const AnfN
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::string fusion_op_name = "AddRmsNorm";
|
||||
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
const std::string fusion_op_name = "AddRmsNorm";
|
||||
auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
bool enable_add_rmsnorm =
|
||||
(std::find(enable_op_list.begin(), enable_op_list.end(), fusion_op_name) != enable_op_list.end());
|
||||
if (!enable_add_rmsnorm) {
|
||||
|
|
|
@ -109,10 +109,9 @@ const AnfNodePtr MatMulAllReduceFusion::Process(const mindspore::FuncGraphPtr &f
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::string fusion_op_name = kMatMulAllReduceOpName;
|
||||
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
bool enable_matmul_allreduce =
|
||||
(std::find(enable_op_list.begin(), enable_op_list.end(), fusion_op_name) != enable_op_list.end());
|
||||
(std::find(enable_op_list.begin(), enable_op_list.end(), kMatMulAllReduceOpName) != enable_op_list.end());
|
||||
if (!enable_matmul_allreduce) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ bool MultiMatmulsFusion::Run(const FuncGraphPtr &graph) {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
|
||||
bool enable_matmul_qkv =
|
||||
(std::find(enable_op_list.begin(), enable_op_list.end(), kMatmulQkvOpName) != enable_op_list.end());
|
||||
bool enable_matmul_ffn =
|
||||
|
|
|
@ -591,25 +591,44 @@ void MsContext::SetAscendConfig() {
|
|||
set_param<std::string>(MS_CTX_GE_OPTIONS, "");
|
||||
}
|
||||
|
||||
std::vector<std::string> SplitString(const std::string &str, char delim) {
|
||||
inline void SplitString(const std::string &str, char delim, std::set<std::string> *output_list) {
|
||||
std::stringstream ss(str);
|
||||
std::string item;
|
||||
std::vector<std::string> elems;
|
||||
while (std::getline(ss, item, delim)) {
|
||||
if (!item.empty()) {
|
||||
elems.push_back(item);
|
||||
output_list->emplace(item);
|
||||
}
|
||||
}
|
||||
return elems;
|
||||
}
|
||||
|
||||
inline std::string SetToString(const std::set<std::string> &kernel_list) {
|
||||
std::string out = "";
|
||||
for (auto &name : kernel_list) {
|
||||
out.append(name).append(", ");
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void MsContext::SetMsInternalEnableCustomKernelList() {
|
||||
const std::string kDefaultEnabledOpList =
|
||||
"MatMul,RmsNorm,Add,Sub,FlashAttentionScore,PagedAttention,AddRmsNorm,AddLayerNorm";
|
||||
auto internal_op_boost_env = common::GetEnv("MS_ENABLE_INTERNAL_BOOST");
|
||||
bool is_enalbe_internal_op = true;
|
||||
if (internal_op_boost_env == "off") {
|
||||
is_enalbe_internal_op = false;
|
||||
}
|
||||
|
||||
ms_internal_enable_custom_kernel_list_.clear();
|
||||
if (is_enalbe_internal_op) {
|
||||
SplitString(kDefaultEnabledOpList, ',', &ms_internal_enable_custom_kernel_list_);
|
||||
}
|
||||
|
||||
std::string env = common::GetEnv("MS_INTERNAL_ENABLE_CUSTOM_KERNEL_LIST");
|
||||
if (!env.empty()) {
|
||||
ms_internal_enable_custom_kernel_list_ = SplitString(env, ',');
|
||||
MS_LOG(INFO) << "MS internal enable custom kernel list is " << ms_internal_enable_custom_kernel_list_;
|
||||
return;
|
||||
SplitString(env, ',', &ms_internal_enable_custom_kernel_list_);
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Enable internal kernel list: " << SetToString(ms_internal_enable_custom_kernel_list_);
|
||||
}
|
||||
|
||||
bool MsContext::IsEnableInferBoost() {
|
||||
|
@ -637,7 +656,7 @@ bool MsContext::IsEnableInferBoost() {
|
|||
return enable_infer_boost_.value();
|
||||
}
|
||||
|
||||
std::vector<std::string> MsContext::ms_internal_enable_custom_kernel_list() const {
|
||||
const std::set<std::string> &MsContext::ms_internal_enable_custom_kernel_list() const {
|
||||
return ms_internal_enable_custom_kernel_list_;
|
||||
}
|
||||
|
||||
|
|
|
@ -237,7 +237,7 @@ class MS_CORE_API MsContext {
|
|||
|
||||
bool IsEnableInferBoost();
|
||||
void SetMsInternalEnableCustomKernelList();
|
||||
std::vector<std::string> ms_internal_enable_custom_kernel_list() const;
|
||||
const std::set<std::string> &ms_internal_enable_custom_kernel_list() const;
|
||||
|
||||
void RegisterSetEnv(const EnvFunc &func);
|
||||
void RegisterCheckEnv(const EnvFunc &func);
|
||||
|
@ -327,7 +327,7 @@ class MS_CORE_API MsContext {
|
|||
bool not_convert_jit_{false};
|
||||
|
||||
std::optional<bool> enable_infer_boost_ = std::nullopt;
|
||||
std::vector<std::string> ms_internal_enable_custom_kernel_list_;
|
||||
std::set<std::string> ms_internal_enable_custom_kernel_list_;
|
||||
};
|
||||
|
||||
// set method implementation for type bool/int/uint32_t/float/std::string
|
||||
|
|
Loading…
Reference in New Issue