!69520 add an env to control the behavior of internal kernel

Merge pull request !69520 from chengbin/fix_master_bkb
This commit is contained in:
zhengzuohe 2024-05-17 02:00:43 +00:00 committed by Gitee
commit b53afe9f32
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 36 additions and 18 deletions

View File

@ -102,8 +102,8 @@ const AnfNodePtr AddLayernormFusion::Process(const FuncGraphPtr &graph, const An
return nullptr;
}
std::string fusion_op_name = "AddLayerNorm";
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
const std::string fusion_op_name = "AddLayerNorm";
auto &enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
bool enable_add_layernorm =
(std::find(enable_op_list.begin(), enable_op_list.end(), fusion_op_name) != enable_op_list.end());
if (!enable_add_layernorm) {

View File

@ -69,8 +69,8 @@ const AnfNodePtr AddRmsNormFusion::Process(const FuncGraphPtr &graph, const AnfN
return nullptr;
}
std::string fusion_op_name = "AddRmsNorm";
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
const std::string fusion_op_name = "AddRmsNorm";
auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
bool enable_add_rmsnorm =
(std::find(enable_op_list.begin(), enable_op_list.end(), fusion_op_name) != enable_op_list.end());
if (!enable_add_rmsnorm) {

View File

@ -109,10 +109,9 @@ const AnfNodePtr MatMulAllReduceFusion::Process(const mindspore::FuncGraphPtr &f
return nullptr;
}
std::string fusion_op_name = kMatMulAllReduceOpName;
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
bool enable_matmul_allreduce =
(std::find(enable_op_list.begin(), enable_op_list.end(), fusion_op_name) != enable_op_list.end());
(std::find(enable_op_list.begin(), enable_op_list.end(), kMatMulAllReduceOpName) != enable_op_list.end());
if (!enable_matmul_allreduce) {
return nullptr;
}

View File

@ -27,7 +27,7 @@ bool MultiMatmulsFusion::Run(const FuncGraphPtr &graph) {
return false;
}
std::vector<std::string> enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
bool enable_matmul_qkv =
(std::find(enable_op_list.begin(), enable_op_list.end(), kMatmulQkvOpName) != enable_op_list.end());
bool enable_matmul_ffn =

View File

@ -591,25 +591,44 @@ void MsContext::SetAscendConfig() {
set_param<std::string>(MS_CTX_GE_OPTIONS, "");
}
std::vector<std::string> SplitString(const std::string &str, char delim) {
inline void SplitString(const std::string &str, char delim, std::set<std::string> *output_list) {
std::stringstream ss(str);
std::string item;
std::vector<std::string> elems;
while (std::getline(ss, item, delim)) {
if (!item.empty()) {
elems.push_back(item);
output_list->emplace(item);
}
}
return elems;
}
inline std::string SetToString(const std::set<std::string> &kernel_list) {
std::string out = "";
for (auto &name : kernel_list) {
out.append(name).append(", ");
}
return out;
}
void MsContext::SetMsInternalEnableCustomKernelList() {
const std::string kDefaultEnabledOpList =
"MatMul,RmsNorm,Add,Sub,FlashAttentionScore,PagedAttention,AddRmsNorm,AddLayerNorm";
auto internal_op_boost_env = common::GetEnv("MS_ENABLE_INTERNAL_BOOST");
bool is_enalbe_internal_op = true;
if (internal_op_boost_env == "off") {
is_enalbe_internal_op = false;
}
ms_internal_enable_custom_kernel_list_.clear();
if (is_enalbe_internal_op) {
SplitString(kDefaultEnabledOpList, ',', &ms_internal_enable_custom_kernel_list_);
}
std::string env = common::GetEnv("MS_INTERNAL_ENABLE_CUSTOM_KERNEL_LIST");
if (!env.empty()) {
ms_internal_enable_custom_kernel_list_ = SplitString(env, ',');
MS_LOG(INFO) << "MS internal enable custom kernel list is " << ms_internal_enable_custom_kernel_list_;
return;
SplitString(env, ',', &ms_internal_enable_custom_kernel_list_);
}
MS_LOG(INFO) << "Enable internal kernel list: " << SetToString(ms_internal_enable_custom_kernel_list_);
}
bool MsContext::IsEnableInferBoost() {
@ -637,7 +656,7 @@ bool MsContext::IsEnableInferBoost() {
return enable_infer_boost_.value();
}
std::vector<std::string> MsContext::ms_internal_enable_custom_kernel_list() const {
const std::set<std::string> &MsContext::ms_internal_enable_custom_kernel_list() const {
return ms_internal_enable_custom_kernel_list_;
}

View File

@ -237,7 +237,7 @@ class MS_CORE_API MsContext {
bool IsEnableInferBoost();
void SetMsInternalEnableCustomKernelList();
std::vector<std::string> ms_internal_enable_custom_kernel_list() const;
const std::set<std::string> &ms_internal_enable_custom_kernel_list() const;
void RegisterSetEnv(const EnvFunc &func);
void RegisterCheckEnv(const EnvFunc &func);
@ -327,7 +327,7 @@ class MS_CORE_API MsContext {
bool not_convert_jit_{false};
std::optional<bool> enable_infer_boost_ = std::nullopt;
std::vector<std::string> ms_internal_enable_custom_kernel_list_;
std::set<std::string> ms_internal_enable_custom_kernel_list_;
};
// set method implementation for type bool/int/uint32_t/float/std::string