!66967 [cherrypick from 2.3] lite parse convert cfg

Merge pull request !66967 from emmmmtang/r2.3-parse-cfg
i-robot 2024-03-22 06:39:14 +00:00 committed by Gitee
commit 6f846d599c
9 changed files with 104 additions and 25 deletions

View File

@ -13,6 +13,7 @@ mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mi
mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/data_queue_op.cc:mindspore::dataset::DataQueueOp::SendDataToAscend
mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform
mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
mindspore/mindspore/lite/tools/converter/config_parser/config_file_parser.cc:mindspore::lite::ConfigFileParser::SetParamByConfigfile
mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetPrimitiveInferMap
mindspore/mindspore/core/ir/tensor.cc:mindspore::tensor::MakeTensorData
mindspore/mindspore/core/ir/dtype_extends.cc:mindspore::UnitSizeInBytes

View File

@ -108,6 +108,7 @@ static const char *const kGraphCompilerCacheDirKey = "model_cache_dir";
static const char *const kModifyMixList = "mixprecision_list_path";
static const char *const kEnableCustomOp = "enable_custom_op";
static const char *const kPluginCustomOps = "plugin_custom_ops";
static const char *const kOpAttrs = "op_attrs";
static const char *const kAoeMode = "aoe_mode";
static const char *const kProvider = "provider";
static const char *const kAscendProviderGe = "ge";

View File

@ -56,6 +56,7 @@ struct AclModelOptionCfg {
std::map<std::string, std::string> build_options_map;
std::map<std::string, std::string> aoe_global_options_map;
std::map<std::string, std::string> aoe_tuning_options_map;
std::map<std::string, std::map<std::string, std::string>> op_attrs_map;
};
} // namespace acl
} // namespace lite
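
For reference, the new op_attrs_map field nests attributes per operator type: the outer key is the op type and the inner map holds attribute name/value pairs. Below is a minimal self-contained sketch of that layout, using only the one entry this change actually consumes (the FlashAttention input_layout setting); it is an illustration, not code from the patch:

#include <map>
#include <string>

// op type -> { attribute name -> attribute value }
using OpAttrsMap = std::map<std::string, std::map<std::string, std::string>>;

OpAttrsMap MakeExampleOpAttrs() {
  OpAttrsMap op_attrs_map;
  // Entry used elsewhere in this change: pin FlashAttention to the BSH input layout.
  op_attrs_map["FlashAttention"]["input_layout"] = "BSH";
  return op_attrs_map;
}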

View File

@ -821,7 +821,8 @@ bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> &param)
param->aclModelOptionCfgParam.disable_custom_fusion_pattern),
false},
{"MakeListPass", std::make_shared<opt::MakeListPass>(), true},
{"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(), false},
{"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(param->aclModelOptionCfgParam.op_attrs_map),
false},
{"GroupNormSiluFusion", std::make_shared<opt::GroupNormSiluFusion>(), false},
{"GeGluV2Fusion", std::make_shared<opt::GeGluV2Fusion>(), false},
{"AddLayerNormFusion", std::make_shared<opt::AddLayerNormFusion>(), false},

View File

@ -52,23 +52,24 @@
#include "tools/optimizer/graph/padv3_ge_pass.h"
namespace mindspore::lite {
void EnableKVCacheFusion(std::vector<opt::PassPtr> *fusions) {
void EnableKVCacheFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
fusions->push_back(std::make_shared<opt::KVCacheMgrOneBranchFusion>());
fusions->push_back(std::make_shared<opt::KVCacheMgrConcatFusion>());
fusions->push_back(std::make_shared<opt::KVCacheMgrLoadFusion>());
fusions->push_back(std::make_shared<opt::KVCacheMgrAssignFusion>());
}
void EnableMatMulAllReduceFusion(std::vector<opt::PassPtr> *fusions) {
void EnableMatMulAllReduceFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
fusions->push_back(std::make_shared<opt::MatMulAllReduceFusion>());
fusions->push_back(std::make_shared<opt::QuantFusionXOffsetToBias>());
}
void EnableFlashAttentionFusion(std::vector<opt::PassPtr> *fusions) {
fusions->push_back(std::make_shared<opt::FlashAttentionFusion>());
void EnableFlashAttentionFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
fusions->push_back(std::make_shared<opt::FlashAttentionFusion>(param->ascendGeOptionCfg.op_attrs_map));
}
void EnableFlashAttentionAntiquantFusion(std::vector<opt::PassPtr> *fusions) {
void EnableFlashAttentionAntiquantFusion(std::vector<opt::PassPtr> *fusions,
const std::shared_ptr<ConverterPara> &param) {
fusions->push_back(std::make_shared<opt::FlashAttentionAntiquantFusion>());
}
@ -84,25 +85,32 @@ int AnfTransformForGe::RunGeFusionPass(const FuncGraphPtr &old_graph, const std:
std::vector<opt::PassPtr> fusions{std::make_shared<opt::MakeListPass>(), std::make_shared<opt::ScalarOpPass>(),
std::make_shared<opt::PadV3GePass>()};
std::map<std::string, std::function<void(std::vector<opt::PassPtr> *)>> fusion_mappings = {
{kFusionNameMatMulAllReduce, std::function<void(std::vector<opt::PassPtr> *)>(EnableMatMulAllReduceFusion)},
{kFusionNameKVCache, std::function<void(std::vector<opt::PassPtr> *)>(EnableKVCacheFusion)},
{kFusionNameFlashAttention, std::function<void(std::vector<opt::PassPtr> *)>(EnableFlashAttentionFusion)},
{kFusionNameFlashAttentionAntiquant,
std::function<void(std::vector<opt::PassPtr> *)>(EnableFlashAttentionAntiquantFusion)}};
std::map<std::string, std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>>
fusion_mappings = {
{kFusionNameMatMulAllReduce,
std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
EnableMatMulAllReduceFusion)},
{kFusionNameKVCache,
std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(EnableKVCacheFusion)},
{kFusionNameFlashAttention,
std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
EnableFlashAttentionFusion)},
{kFusionNameFlashAttentionAntiquant,
std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
EnableFlashAttentionAntiquantFusion)}};
auto plugin_custom_ops = param->ascendGeOptionCfg.plugin_custom_ops;
MS_LOG(INFO) << "plugin_custom_ops: " << plugin_custom_ops;
if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), "All") != plugin_custom_ops.end()) {
MS_LOG(INFO) << "using all fusion";
EnableFlashAttentionFusion(&fusions);
EnableKVCacheFusion(&fusions);
EnableFlashAttentionAntiquantFusion(&fusions);
EnableFlashAttentionFusion(&fusions, param);
EnableKVCacheFusion(&fusions, param);
EnableFlashAttentionAntiquantFusion(&fusions, param);
// MatMulAllReduce has performance degradation in incremental inference scenarios,
// and is not controlled by "All" temporarily.
if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), kFusionNameMatMulAllReduce) !=
plugin_custom_ops.end()) {
EnableMatMulAllReduceFusion(&fusions);
EnableMatMulAllReduceFusion(&fusions, param);
}
} else {
for (uint i = 0; i < plugin_custom_ops.size(); i++) {
@ -110,7 +118,7 @@ int AnfTransformForGe::RunGeFusionPass(const FuncGraphPtr &old_graph, const std:
auto plugin_func = fusion_mappings[plugin_name];
if (plugin_func != nullptr) {
MS_LOG(INFO) << "using " << plugin_name;
plugin_func(&fusions);
plugin_func(&fusions, param);
}
}
}
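
The hunk above widens every fusion-enabler callback to also take the ConverterPara, so fusion_mappings now stores std::function objects with the two-argument signature and RunGeFusionPass forwards param at each call site. A stripped-down sketch of that dispatch pattern follows; every type name below is a stand-in for illustration, not the real MindSpore class:

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Pass {};                       // stand-in for an optimizer pass
using PassPtr = std::shared_ptr<Pass>;
struct ConverterPara {};              // stand-in for the converter parameter struct

// Enablers now receive the ConverterPara so they can forward user config
// (such as op_attrs_map) into the passes they construct.
using FusionEnabler =
    std::function<void(std::vector<PassPtr> *, const std::shared_ptr<ConverterPara> &)>;

void EnableExampleFusion(std::vector<PassPtr> *fusions,
                         const std::shared_ptr<ConverterPara> &param) {
  (void)param;  // a real enabler reads fields of param here
  fusions->push_back(std::make_shared<Pass>());
}

void Dispatch(const std::string &name, std::vector<PassPtr> *fusions,
              const std::shared_ptr<ConverterPara> &param) {
  static const std::map<std::string, FusionEnabler> mappings = {
      {"ExampleFusion", EnableExampleFusion}};
  auto it = mappings.find(name);
  if (it != mappings.end()) {
    it->second(fusions, param);  // forward the converter parameters to the enabler
  }
}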

View File

@ -42,6 +42,9 @@ constexpr auto kDynamicQuantParam = "dynamic_quant_param";
constexpr auto kGraphKernelParam = "graph_kernel_param";
constexpr int kNumSize3 = 3;
constexpr int kNumSize2 = 2;
constexpr size_t kNumIndex0 = 0;
constexpr size_t kNumIndex1 = 1;
constexpr size_t kNumIndex2 = 2;
} // namespace
using ShapeVector = std::vector<int64_t>;
const int kBatchDim = 0;
@ -431,6 +434,32 @@ bool ConfigFileParser::SetParamByConfigfile(const std::shared_ptr<mindspore::Con
return false;
}
}
auto op_attrs_str = FindInAscendMap(kOpAttrs, ascend_map);
std::vector<std::string> op_attrs_vec = {};
if (!op_attrs_str.empty()) {
MS_LOG(INFO) << "op_attrs_str: " << op_attrs_str;
op_attrs_vec = mindspore::lite::SplitStringToVector(op_attrs_str, ";");
std::map<std::string, std::string> attr;
for (auto op_attr_str : op_attrs_vec) {
MS_LOG(INFO) << "op_attr: " << op_attr_str;
auto op_attr = mindspore::lite::SplitStringToVector(op_attr_str, ":");
if (op_attr.size() != kNumSize3) {
return false;
}
auto op_type = op_attr[kNumIndex0];
auto attr_key = op_attr[kNumIndex1];
auto attr_value = op_attr[kNumIndex2];
param->aclModelOptionCfgParam.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
param->ascendGeOptionCfg.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
}
}
for (auto item : param->aclModelOptionCfgParam.op_attrs_map) {
for (auto attr : item.second) {
MS_LOG(INFO) << "op type: " << item.first << ", key: " << attr.first << ", value: " << attr.second;
}
}
auto it = ascend_map.find("input_shape");
if (it != ascend_map.end()) {
param->aclModelOptionCfgParam.input_shape = RemoveInputShapeBrackets(it->second);
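
The new block parses the op_attrs config value (key name from kOpAttrs above) as ';'-separated entries, each of the form op_type:attr_key:attr_value; an entry such as op_attrs=FlashAttention:input_layout:BSH would therefore land in both op_attrs_map fields. Which config section the key lives in (presumably the Ascend one, given FindInAscendMap) is not shown in this diff. A self-contained sketch of the same parsing logic follows; Split is a stand-in for mindspore::lite::SplitStringToVector so the sketch compiles on its own:

#include <map>
#include <sstream>
#include <string>
#include <vector>

// Stand-in splitter (the real code uses mindspore::lite::SplitStringToVector).
std::vector<std::string> Split(const std::string &s, char sep) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, sep)) {
    if (!item.empty()) {
      out.push_back(item);
    }
  }
  return out;
}

// "op_type:attr_key:attr_value;..." -> { op_type -> { attr_key -> attr_value } }.
// Returns false on a malformed entry, mirroring the early return above.
bool ParseOpAttrs(const std::string &op_attrs_str,
                  std::map<std::string, std::map<std::string, std::string>> *op_attrs_map) {
  for (const auto &entry : Split(op_attrs_str, ';')) {
    auto fields = Split(entry, ':');
    if (fields.size() != 3) {
      return false;
    }
    (*op_attrs_map)[fields[0]].insert({fields[1], fields[2]});
  }
  return true;
}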

View File

@ -47,6 +47,7 @@ struct GraphKernelCfg {
struct AscendGeOptionCfg {
std::vector<std::string> plugin_custom_ops;
std::map<std::string, std::map<std::string, std::string>> op_attrs_map;
std::vector<int64_t> inputs_to_variable;
std::vector<int64_t> outputs_to_variable;
};

View File

@ -17,6 +17,7 @@
#include "tools/optimizer/fusion/flash_attention_fusion.h"
#include <memory>
#include <utility>
#include <string>
#include "ops/op_utils.h"
#include "ops/array_ops.h"
#include "ops/nn_ops.h"
@ -60,6 +61,7 @@ constexpr int64_t kNumDValue = 40;
constexpr int64_t kNumPadSize = 8;
constexpr int kNumPowerTwo = 2;
constexpr float kNumPowerHalf = 0.5;
bool goBSH = false;
bool IsDivNode(const BaseRef &n) {
if (utils::isa<AnfNodePtr>(n)) {
@ -2108,12 +2110,26 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
<< " , v shape: " << input_tensor_v_shape;
auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
MS_CHECK_TRUE_RET(q_trans_reshape != nullptr, nullptr);
auto k_trans_reshape = k_trans->cast<CNodePtr>()->input(kNumIndex1);
MS_CHECK_TRUE_RET(k_trans_reshape != nullptr, nullptr);
auto v_trans_reshape = v_trans->cast<CNodePtr>()->input(kNumIndex1);
MS_CHECK_TRUE_RET(v_trans_reshape != nullptr, nullptr);
auto top_matmul_q = q_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
MS_CHECK_TRUE_RET(top_matmul_q != nullptr, nullptr);
auto top_matmul_k = k_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
MS_CHECK_TRUE_RET(top_matmul_k != nullptr, nullptr);
auto top_matmul_v = v_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
MS_CHECK_TRUE_RET(top_matmul_v != nullptr, nullptr);
float scale_value = 0;
int64_t num_head = 0;
int64_t next_tokens = kNumMaxNextTokenSize;
int64_t d_value = 0;
auto mul_const_input = mul->input(kNumIndex2);
bool actual_BSH = false;
if (input_tensor_q_shape.size() != kNumShapeSize4) {
scale_value = GetScaleValueForDynamicShape(mul_const_input);
@ -2123,7 +2139,6 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
auto output_shape_node = node->cast<CNodePtr>();
output_shape_node->set_input(kNumIndex2, shape_node);
auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
num_head = GetNumHeadForSD(q_trans_reshape);
} else if (input_tensor_q_shape.size() == kNumShapeSize4) {
MS_LOG(INFO) << "get flash attention param for static shape.";
@ -2142,16 +2157,23 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
if (d_value == kNumDValue) {
fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value);
}
} else if (goBSH) {
fa_node = CreatePromptFlashAttentionCnodeForBSH(func_graph, node, top_matmul_q, top_matmul_k, top_matmul_v, nullptr,
num_head, next_tokens, scale_value);
actual_BSH = true;
} else {
fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
next_tokens, scale_value, num_head);
}
if (fa_node == nullptr) {
return nullptr;
}
MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
auto manager = Manage(func_graph);
(void)manager->Replace(cnode, fa_node);
MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
if (actual_BSH) {
(void)manager->Replace(node, fa_node);
} else {
(void)manager->Replace(cnode, fa_node);
}
MS_LOG(INFO) << "create prompt flash attention success for without cast, BSH: " << actual_BSH;
return nullptr;
}
@ -2412,6 +2434,15 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDEinsum(const std::st
AnfNodePtr FlashAttentionFusion::Process(const std::string &patten_name, const FuncGraphPtr &func_graph,
const AnfNodePtr &node, const EquivPtr &equiv) const {
MS_LOG(INFO) << "do flash attention fusion, pattern name: " << patten_name;
if (op_attrs_map_.find("FlashAttention") != op_attrs_map_.end()) {
if (op_attrs_map_.at("FlashAttention").find("input_layout") != op_attrs_map_.at("FlashAttention").end()) {
auto layout_str = op_attrs_map_.at("FlashAttention").at("input_layout");
if (strcmp(layout_str.c_str(), "BSH") == 0) {
MS_LOG(INFO) << "Use user config, FA layout is BSH";
goBSH = true;
}
}
}
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
MS_LOG(ERROR) << "function graph, node or equiv is nullptr.";
return nullptr;
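
Process() now consults op_attrs_map_, which the pass receives at construction (e.g. std::make_shared<opt::FlashAttentionFusion>(param->aclModelOptionCfgParam.op_attrs_map) in anf_transform.cc above), and switches to the BSH prompt-flash-attention path when the user pins FlashAttention's input_layout to BSH. A minimal sketch of just that lookup, assuming nothing beyond what this hunk shows; the in-tree code stores the result in the file-scope goBSH flag:

#include <map>
#include <string>

// True when user config pins FlashAttention's input_layout to BSH.
bool UseBshLayout(const std::map<std::string, std::map<std::string, std::string>> &op_attrs_map) {
  auto op_it = op_attrs_map.find("FlashAttention");
  if (op_it == op_attrs_map.end()) {
    return false;
  }
  auto attr_it = op_it->second.find("input_layout");
  return attr_it != op_it->second.end() && attr_it->second == "BSH";
}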

View File

@ -20,6 +20,7 @@
#include <string>
#include <memory>
#include <vector>
#include <map>
#include <unordered_map>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "tools/optimizer/common/gllo_utils.h"
@ -41,8 +42,11 @@ namespace opt {
*/
class FlashAttentionFusion : public MultiplePatternProcessPass {
public:
explicit FlashAttentionFusion(const std::string &name = "FlashAttentionFusion", bool multigraph = true)
: MultiplePatternProcessPass(name, multigraph) {}
explicit FlashAttentionFusion(std::map<std::string, std::map<std::string, std::string>> op_attrs_map,
const std::string &name = "FlashAttentionFusion", bool multigraph = true)
: MultiplePatternProcessPass(name, multigraph) {
op_attrs_map_ = op_attrs_map;
}
~FlashAttentionFusion() override = default;
@ -51,6 +55,8 @@ class FlashAttentionFusion : public MultiplePatternProcessPass {
AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
std::map<std::string, std::map<std::string, std::string>> op_attrs_map_;
CNodePtr CreatePromptFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token,