diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt
index 5626b2b5de5..5cb24655c7d 100644
--- a/.jenkins/check/config/whitelizard.txt
+++ b/.jenkins/check/config/whitelizard.txt
@@ -13,6 +13,7 @@ mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mi
 mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/data_queue_op.cc:mindspore::dataset::DataQueueOp::SendDataToAscend
 mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform
 mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
+mindspore/mindspore/lite/tools/converter/config_parser/config_file_parser.cc:mindspore::lite::ConfigFileParser::SetParamByConfigfile
 mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetPrimitiveInferMap
 mindspore/mindspore/core/ir/tensor.cc:mindspore::tensor::MakeTensorData
 mindspore/mindspore/core/ir/dtype_extends.cc:mindspore::UnitSizeInBytes
diff --git a/mindspore/lite/src/common/common.h b/mindspore/lite/src/common/common.h
index d1380ee960e..34b5c505a02 100644
--- a/mindspore/lite/src/common/common.h
+++ b/mindspore/lite/src/common/common.h
@@ -108,6 +108,7 @@ static const char *const kGraphCompilerCacheDirKey = "model_cache_dir";
 static const char *const kModifyMixList = "mixprecision_list_path";
 static const char *const kEnableCustomOp = "enable_custom_op";
 static const char *const kPluginCustomOps = "plugin_custom_ops";
+static const char *const kOpAttrs = "op_attrs";
 static const char *const kAoeMode = "aoe_mode";
 static const char *const kProvider = "provider";
 static const char *const kAscendProviderGe = "ge";
diff --git a/mindspore/lite/tools/converter/adapter/acl/common/acl_types.h b/mindspore/lite/tools/converter/adapter/acl/common/acl_types.h
index 04c989dba9e..91fe8a59440 100644
--- a/mindspore/lite/tools/converter/adapter/acl/common/acl_types.h
+++ b/mindspore/lite/tools/converter/adapter/acl/common/acl_types.h
@@ -56,6 +56,7 @@ struct AclModelOptionCfg {
   std::map<std::string, std::string> build_options_map;
   std::map<std::string, std::string> aoe_global_options_map;
   std::map<std::string, std::string> aoe_tuning_options_map;
+  std::map<std::string, std::map<std::string, std::string>> op_attrs_map;
 };
 }  // namespace acl
 }  // namespace lite
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 05186725b25..e8e6f024a53 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -821,7 +821,8 @@ bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> &param)
                                                        param->aclModelOptionCfgParam.disable_custom_fusion_pattern),
       false},
      {"MakeListPass", std::make_shared(), true},
-     {"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(), false},
+     {"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(param->aclModelOptionCfgParam.op_attrs_map),
+      false},
      {"GroupNormSiluFusion", std::make_shared(), false},
      {"GeGluV2Fusion", std::make_shared(), false},
      {"AddLayerNormFusion", std::make_shared(), false},
diff --git a/mindspore/lite/tools/converter/anf_transform_for_ge.cc b/mindspore/lite/tools/converter/anf_transform_for_ge.cc
index d2722e61596..f4b2e392055 100644
--- a/mindspore/lite/tools/converter/anf_transform_for_ge.cc
+++ b/mindspore/lite/tools/converter/anf_transform_for_ge.cc
@@ -52,23 +52,24 @@
 #include "tools/optimizer/graph/padv3_ge_pass.h"
 
 namespace mindspore::lite {
-void EnableKVCacheFusion(std::vector<opt::PassPtr> *fusions) {
+void EnableKVCacheFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
   fusions->push_back(std::make_shared());
   fusions->push_back(std::make_shared());
   fusions->push_back(std::make_shared());
   fusions->push_back(std::make_shared());
 }
 
-void EnableMatMulAllReduceFusion(std::vector<opt::PassPtr> *fusions) {
+void EnableMatMulAllReduceFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
   fusions->push_back(std::make_shared());
   fusions->push_back(std::make_shared());
 }
 
-void EnableFlashAttentionFusion(std::vector<opt::PassPtr> *fusions) {
-  fusions->push_back(std::make_shared<opt::FlashAttentionFusion>());
+void EnableFlashAttentionFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
+  fusions->push_back(std::make_shared<opt::FlashAttentionFusion>(param->ascendGeOptionCfg.op_attrs_map));
 }
 
-void EnableFlashAttentionAntiquantFusion(std::vector<opt::PassPtr> *fusions) {
+void EnableFlashAttentionAntiquantFusion(std::vector<opt::PassPtr> *fusions,
+                                         const std::shared_ptr<ConverterPara> &param) {
   fusions->push_back(std::make_shared());
 }
 
@@ -84,25 +85,32 @@ int AnfTransformForGe::RunGeFusionPass(const FuncGraphPtr &old_graph, const std:
   std::vector<opt::PassPtr> fusions{std::make_shared(), std::make_shared(),
                                     std::make_shared()};
-  std::map<std::string, std::function<void(std::vector<opt::PassPtr> *)>> fusion_mappings = {
-    {kFusionNameMatMulAllReduce, std::function<void(std::vector<opt::PassPtr> *)>(EnableMatMulAllReduceFusion)},
-    {kFusionNameKVCache, std::function<void(std::vector<opt::PassPtr> *)>(EnableKVCacheFusion)},
-    {kFusionNameFlashAttention, std::function<void(std::vector<opt::PassPtr> *)>(EnableFlashAttentionFusion)},
-    {kFusionNameFlashAttentionAntiquant,
-     std::function<void(std::vector<opt::PassPtr> *)>(EnableFlashAttentionAntiquantFusion)}};
+  std::map<std::string, std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>>
+    fusion_mappings = {
+      {kFusionNameMatMulAllReduce,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
+         EnableMatMulAllReduceFusion)},
+      {kFusionNameKVCache,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(EnableKVCacheFusion)},
+      {kFusionNameFlashAttention,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
+         EnableFlashAttentionFusion)},
+      {kFusionNameFlashAttentionAntiquant,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
+         EnableFlashAttentionAntiquantFusion)}};
 
   auto plugin_custom_ops = param->ascendGeOptionCfg.plugin_custom_ops;
   MS_LOG(INFO) << "plugin_custom_ops: " << plugin_custom_ops;
   if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), "All") != plugin_custom_ops.end()) {
     MS_LOG(INFO) << "using all fusion";
-    EnableFlashAttentionFusion(&fusions);
-    EnableKVCacheFusion(&fusions);
-    EnableFlashAttentionAntiquantFusion(&fusions);
+    EnableFlashAttentionFusion(&fusions, param);
+    EnableKVCacheFusion(&fusions, param);
+    EnableFlashAttentionAntiquantFusion(&fusions, param);
     // MatMulAllReduce has performance degradation in incremental inference scenarios,
     // and is not controlled by "All" temporarily.
     if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), kFusionNameMatMulAllReduce) !=
         plugin_custom_ops.end()) {
-      EnableMatMulAllReduceFusion(&fusions);
+      EnableMatMulAllReduceFusion(&fusions, param);
     }
   } else {
     for (uint i = 0; i < plugin_custom_ops.size(); i++) {
@@ -110,7 +118,7 @@ int AnfTransformForGe::RunGeFusionPass(const FuncGraphPtr &old_graph, const std:
       auto plugin_func = fusion_mappings[plugin_name];
       if (plugin_func != nullptr) {
         MS_LOG(INFO) << "using " << plugin_name;
-        plugin_func(&fusions);
+        plugin_func(&fusions, param);
       }
     }
   }
diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
index be03a1edfaf..b72068c1106 100644
--- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
+++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc
@@ -42,6 +42,9 @@ constexpr auto kDynamicQuantParam = "dynamic_quant_param";
 constexpr auto kGraphKernelParam = "graph_kernel_param";
 constexpr int kNumSize3 = 3;
 constexpr int kNumSize2 = 2;
+constexpr size_t kNumIndex0 = 0;
+constexpr size_t kNumIndex1 = 1;
+constexpr size_t kNumIndex2 = 2;
 }  // namespace
 using ShapeVector = std::vector<int64_t>;
 const int kBatchDim = 0;
@@ -431,6 +434,32 @@ bool ConfigFileParser::SetParamByConfigfile(const std::shared_ptr
+  std::vector<std::string> op_attrs_vec = {};
+  if (!op_attrs_str.empty()) {
+    MS_LOG(INFO) << "op_attrs_str: " << op_attrs_str;
+    op_attrs_vec = mindspore::lite::SplitStringToVector(op_attrs_str, ";");
+    std::map<std::string, std::string> attr;
+    for (auto op_attr_str : op_attrs_vec) {
+      MS_LOG(INFO) << "op_attr: " << op_attr_str;
+      auto op_attr = mindspore::lite::SplitStringToVector(op_attr_str, ":");
+      if (op_attr.size() != kNumSize3) {
+        return false;
+      }
+      auto op_type = op_attr[kNumIndex0];
+      auto attr_key = op_attr[kNumIndex1];
+      auto attr_value = op_attr[kNumIndex2];
+      param->aclModelOptionCfgParam.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
+      param->ascendGeOptionCfg.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
+    }
+  }
+  for (auto item : param->aclModelOptionCfgParam.op_attrs_map) {
+    for (auto attr : item.second) {
+      MS_LOG(INFO) << "op type: " << item.first << ", key: " << attr.first << ", value: " << attr.second;
+    }
+  }
+
   auto it = ascend_map.find("input_shape");
   if (it != ascend_map.end()) {
     param->aclModelOptionCfgParam.input_shape = RemoveInputShapeBrackets(it->second);
diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h
index 31abbdbef1c..40d5b07643f 100644
--- a/mindspore/lite/tools/converter/cxx_api/converter_para.h
+++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h
@@ -47,6 +47,7 @@ struct GraphKernelCfg {
 
 struct AscendGeOptionCfg {
   std::vector<std::string> plugin_custom_ops;
+  std::map<std::string, std::map<std::string, std::string>> op_attrs_map;
   std::vector<std::string> inputs_to_variable;
   std::vector<std::string> outputs_to_variable;
 };
diff --git a/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.cc b/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.cc
index 1a5a3fd8c15..ff0c2d723a7 100644
--- a/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.cc
@@ -17,6 +17,7 @@
 #include "tools/optimizer/fusion/flash_attention_fusion.h"
 #include 
 #include 
+#include 
 #include "ops/op_utils.h"
 #include "ops/array_ops.h"
 #include "ops/nn_ops.h"
@@ -60,6 +61,7 @@ constexpr int64_t kNumDValue = 40;
 constexpr int64_t kNumPadSize = 8;
 constexpr int kNumPowerTwo = 2;
 constexpr float kNumPowerHalf = 0.5;
+bool goBSH = false;
 
 bool IsDivNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
@@ -2108,12 +2110,26 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
   auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
                << " , v shape: " << input_tensor_v_shape;
+  auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(q_trans_reshape != nullptr, nullptr);
+  auto k_trans_reshape = k_trans->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(k_trans_reshape != nullptr, nullptr);
+  auto v_trans_reshape = v_trans->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(v_trans_reshape != nullptr, nullptr);
+
+  auto top_matmul_q = q_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(top_matmul_q != nullptr, nullptr);
+  auto top_matmul_k = k_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(top_matmul_k != nullptr, nullptr);
+  auto top_matmul_v = v_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(top_matmul_v != nullptr, nullptr);
   float scale_value = 0;
   int64_t num_head = 0;
   int64_t next_tokens = kNumMaxNextTokenSize;
   int64_t d_value = 0;
   auto mul_const_input = mul->input(kNumIndex2);
+  bool actual_BSH = false;
   if (input_tensor_q_shape.size() != kNumShapeSize4) {
     scale_value = GetScaleValueForDynamicShape(mul_const_input);
@@ -2123,7 +2139,6 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
     auto output_shape_node = node->cast<CNodePtr>();
     output_shape_node->set_input(kNumIndex2, shape_node);
-    auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
     num_head = GetNumHeadForSD(q_trans_reshape);
   } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
     MS_LOG(INFO) << "get flash attention param for static shape.";
@@ -2142,16 +2157,23 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
     if (d_value == kNumDValue) {
       fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value);
     }
+  } else if (goBSH) {
+    fa_node = CreatePromptFlashAttentionCnodeForBSH(func_graph, node, top_matmul_q, top_matmul_k, top_matmul_v, nullptr,
+                                                    num_head, next_tokens, scale_value);
+    actual_BSH = true;
   } else {
     fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
                                                      next_tokens, scale_value, num_head);
   }
-  if (fa_node == nullptr) {
-    return nullptr;
-  }
+  MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
   auto manager = Manage(func_graph);
-  (void)manager->Replace(cnode, fa_node);
-  MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
+  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
+  if (actual_BSH) {
+    (void)manager->Replace(node, fa_node);
+  } else {
+    (void)manager->Replace(cnode, fa_node);
+  }
+  MS_LOG(INFO) << "create prompt flash attention success for without cast, BSH: " << actual_BSH;
   return nullptr;
 }
 
@@ -2412,6 +2434,15 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDEinsum(const std::st
 AnfNodePtr FlashAttentionFusion::Process(const std::string &patten_name, const FuncGraphPtr &func_graph,
                                          const AnfNodePtr &node, const EquivPtr &equiv) const {
   MS_LOG(INFO) << "do flash attention fusion, pattern name: " << patten_name;
+  if (op_attrs_map_.find("FlashAttention") != op_attrs_map_.end()) {
+    if (op_attrs_map_.at("FlashAttention").find("input_layout") != op_attrs_map_.at("FlashAttention").end()) {
+      auto layout_str = op_attrs_map_.at("FlashAttention").at("input_layout");
+      if (strcmp(layout_str.c_str(), "BSH") == 0) {
+        MS_LOG(INFO) << "Use user config, FA layout is BSH";
+        goBSH = true;
+      }
+    }
+  }
   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
     MS_LOG(ERROR) << "function graph, node or equiv is nullptr.";
     return nullptr;
diff --git a/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.h b/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.h
index 180e4288208..71fbee8ccda 100644
--- a/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/flash_attention_fusion.h
@@ -20,6 +20,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include "tools/optimizer/common/multiple_pattern_process_pass.h"
 #include "tools/optimizer/common/gllo_utils.h"
@@ -41,8 +42,11 @@ namespace opt {
  */
 class FlashAttentionFusion : public MultiplePatternProcessPass {
  public:
-  explicit FlashAttentionFusion(const std::string &name = "FlashAttentionFusion", bool multigraph = true)
-      : MultiplePatternProcessPass(name, multigraph) {}
+  explicit FlashAttentionFusion(std::map<std::string, std::map<std::string, std::string>> op_attrs_map,
+                                const std::string &name = "FlashAttentionFusion", bool multigraph = true)
+      : MultiplePatternProcessPass(name, multigraph) {
+    op_attrs_map_ = op_attrs_map;
+  }
 
   ~FlashAttentionFusion() override = default;
 
@@ -51,6 +55,8 @@ class FlashAttentionFusion : public MultiplePatternProcessPass {
   AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
+  std::map<std::string, std::map<std::string, std::string>> op_attrs_map_;
+
   CNodePtr CreatePromptFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                   const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                   const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token,
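
For reference, the op_attrs value that SetParamByConfigfile parses in this patch is a semicolon-separated list of op_type:attr_key:attr_value triples; an entry that does not split into exactly three fields makes the parser return false. Assuming the new kOpAttrs key is read from the [ascend_context] section of the converter config file like the other ascend_map options shown above (the section name is not visible in this diff), a sketch of a config entry that drives the new BSH FlashAttention branch would be:

    [ascend_context]
    op_attrs=FlashAttention:input_layout:BSH

Multiple attributes would be given as additional semicolon-separated triples in the same op_attrs line.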