!66967 [cherrypick from 2.3] lite parse convert cfg
Merge pull request !66967 from emmmmtang/r2.3-parse-cfg
commit 6f846d599c
@@ -13,6 +13,7 @@ mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc:mi
 mindspore/mindspore/ccsrc/minddata/dataset/engine/datasetops/data_queue_op.cc:mindspore::dataset::DataQueueOp::SendDataToAscend
 mindspore/mindspore/lite/tools/converter/graphdef_transform.cc:mindspore::lite::GraphDefTransform::Transform
 mindspore/mindspore/lite/providers/nnie_proposal/src/proposal.cc:mindspore::proposal::Rpn
+mindspore/mindspore/lite/tools/converter/config_parser/config_file_parser.cc:mindspore::lite::ConfigFileParser::SetParamByConfigfile
 mindspore/mindspore/core/abstract/ops/primitive_infer_map.cc:mindspore::abstract::GetPrimitiveInferMap
 mindspore/mindspore/core/ir/tensor.cc:mindspore::tensor::MakeTensorData
 mindspore/mindspore/core/ir/dtype_extends.cc:mindspore::UnitSizeInBytes
@@ -108,6 +108,7 @@ static const char *const kGraphCompilerCacheDirKey = "model_cache_dir";
 static const char *const kModifyMixList = "mixprecision_list_path";
 static const char *const kEnableCustomOp = "enable_custom_op";
 static const char *const kPluginCustomOps = "plugin_custom_ops";
+static const char *const kOpAttrs = "op_attrs";
 static const char *const kAoeMode = "aoe_mode";
 static const char *const kProvider = "provider";
 static const char *const kAscendProviderGe = "ge";
@@ -56,6 +56,7 @@ struct AclModelOptionCfg {
   std::map<std::string, std::string> build_options_map;
   std::map<std::string, std::string> aoe_global_options_map;
   std::map<std::string, std::string> aoe_tuning_options_map;
+  std::map<std::string, std::map<std::string, std::string>> op_attrs_map;
 };
 } // namespace acl
 } // namespace lite
@@ -821,7 +821,8 @@ bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> &param)
        param->aclModelOptionCfgParam.disable_custom_fusion_pattern),
      false},
     {"MakeListPass", std::make_shared<opt::MakeListPass>(), true},
-    {"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(), false},
+    {"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(param->aclModelOptionCfgParam.op_attrs_map),
+     false},
     {"GroupNormSiluFusion", std::make_shared<opt::GroupNormSiluFusion>(), false},
     {"GeGluV2Fusion", std::make_shared<opt::GeGluV2Fusion>(), false},
     {"AddLayerNormFusion", std::make_shared<opt::AddLayerNormFusion>(), false},
@@ -52,23 +52,24 @@
 #include "tools/optimizer/graph/padv3_ge_pass.h"
 
 namespace mindspore::lite {
-void EnableKVCacheFusion(std::vector<opt::PassPtr> *fusions) {
+void EnableKVCacheFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
   fusions->push_back(std::make_shared<opt::KVCacheMgrOneBranchFusion>());
   fusions->push_back(std::make_shared<opt::KVCacheMgrConcatFusion>());
   fusions->push_back(std::make_shared<opt::KVCacheMgrLoadFusion>());
   fusions->push_back(std::make_shared<opt::KVCacheMgrAssignFusion>());
 }
 
-void EnableMatMulAllReduceFusion(std::vector<opt::PassPtr> *fusions) {
+void EnableMatMulAllReduceFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
   fusions->push_back(std::make_shared<opt::MatMulAllReduceFusion>());
   fusions->push_back(std::make_shared<opt::QuantFusionXOffsetToBias>());
 }
 
-void EnableFlashAttentionFusion(std::vector<opt::PassPtr> *fusions) {
-  fusions->push_back(std::make_shared<opt::FlashAttentionFusion>());
+void EnableFlashAttentionFusion(std::vector<opt::PassPtr> *fusions, const std::shared_ptr<ConverterPara> &param) {
+  fusions->push_back(std::make_shared<opt::FlashAttentionFusion>(param->ascendGeOptionCfg.op_attrs_map));
 }
 
-void EnableFlashAttentionAntiquantFusion(std::vector<opt::PassPtr> *fusions) {
+void EnableFlashAttentionAntiquantFusion(std::vector<opt::PassPtr> *fusions,
+                                         const std::shared_ptr<ConverterPara> &param) {
   fusions->push_back(std::make_shared<opt::FlashAttentionAntiquantFusion>());
 }
 
@@ -84,25 +85,32 @@ int AnfTransformForGe::RunGeFusionPass(const FuncGraphPtr &old_graph, const std:
 
   std::vector<opt::PassPtr> fusions{std::make_shared<opt::MakeListPass>(), std::make_shared<opt::ScalarOpPass>(),
                                     std::make_shared<opt::PadV3GePass>()};
-  std::map<std::string, std::function<void(std::vector<opt::PassPtr> *)>> fusion_mappings = {
-    {kFusionNameMatMulAllReduce, std::function<void(std::vector<opt::PassPtr> *)>(EnableMatMulAllReduceFusion)},
-    {kFusionNameKVCache, std::function<void(std::vector<opt::PassPtr> *)>(EnableKVCacheFusion)},
-    {kFusionNameFlashAttention, std::function<void(std::vector<opt::PassPtr> *)>(EnableFlashAttentionFusion)},
-    {kFusionNameFlashAttentionAntiquant,
-     std::function<void(std::vector<opt::PassPtr> *)>(EnableFlashAttentionAntiquantFusion)}};
+  std::map<std::string, std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>>
+    fusion_mappings = {
+      {kFusionNameMatMulAllReduce,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
+         EnableMatMulAllReduceFusion)},
+      {kFusionNameKVCache,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(EnableKVCacheFusion)},
+      {kFusionNameFlashAttention,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
+         EnableFlashAttentionFusion)},
+      {kFusionNameFlashAttentionAntiquant,
+       std::function<void(std::vector<opt::PassPtr> *, const std::shared_ptr<ConverterPara> &)>(
+         EnableFlashAttentionAntiquantFusion)}};
 
   auto plugin_custom_ops = param->ascendGeOptionCfg.plugin_custom_ops;
   MS_LOG(INFO) << "plugin_custom_ops: " << plugin_custom_ops;
   if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), "All") != plugin_custom_ops.end()) {
     MS_LOG(INFO) << "using all fusion";
-    EnableFlashAttentionFusion(&fusions);
-    EnableKVCacheFusion(&fusions);
-    EnableFlashAttentionAntiquantFusion(&fusions);
+    EnableFlashAttentionFusion(&fusions, param);
+    EnableKVCacheFusion(&fusions, param);
+    EnableFlashAttentionAntiquantFusion(&fusions, param);
     // MatMulAllReduce has performance degradation in incremental inference scenarios,
     // and is not controlled by "All" temporarily.
     if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), kFusionNameMatMulAllReduce) !=
         plugin_custom_ops.end()) {
-      EnableMatMulAllReduceFusion(&fusions);
+      EnableMatMulAllReduceFusion(&fusions, param);
     }
   } else {
     for (uint i = 0; i < plugin_custom_ops.size(); i++) {
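For orientation, here is a standalone sketch of the registry-plus-shared-parameter pattern that the reworked fusion_mappings uses: callbacks keyed by fusion name, each now taking the shared ConverterPara as a second argument. PassPtr, Para, the enabler bodies, and the names below are simplified stand-ins for illustration, not the converter's real types.

// Sketch of the dispatch pattern above: a name-to-callback map whose callbacks receive
// the shared parameter object, mirroring plugin_func(&fusions, param).
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

using PassPtr = std::string;          // stand-in for opt::PassPtr
struct Para { std::string layout; };  // stand-in for ConverterPara

void EnableFlashAttention(std::vector<PassPtr> *fusions, const std::shared_ptr<Para> &param) {
  fusions->push_back("FlashAttentionFusion(layout=" + param->layout + ")");
}
void EnableKVCache(std::vector<PassPtr> *fusions, const std::shared_ptr<Para> &param) {
  fusions->push_back("KVCacheMgrFusion");
}

int main() {
  using Enabler = std::function<void(std::vector<PassPtr> *, const std::shared_ptr<Para> &)>;
  std::map<std::string, Enabler> mappings = {{"FlashAttention", Enabler(EnableFlashAttention)},
                                             {"KVCacheFusion", Enabler(EnableKVCache)}};
  auto param = std::make_shared<Para>();
  param->layout = "BSH";

  std::vector<PassPtr> fusions;
  for (const auto &name : {std::string("FlashAttention"), std::string("KVCacheFusion")}) {
    auto it = mappings.find(name);  // mirrors the fusion_mappings[plugin_name] lookup plus null check
    if (it != mappings.end()) {
      it->second(&fusions, param);  // mirrors plugin_func(&fusions, param)
    }
  }
  for (const auto &f : fusions) std::cout << f << "\n";
  return 0;
}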
@@ -110,7 +118,7 @@ int AnfTransformForGe::RunGeFusionPass(const FuncGraphPtr &old_graph, const std:
       auto plugin_func = fusion_mappings[plugin_name];
       if (plugin_func != nullptr) {
         MS_LOG(INFO) << "using " << plugin_name;
-        plugin_func(&fusions);
+        plugin_func(&fusions, param);
       }
     }
   }
@@ -42,6 +42,9 @@ constexpr auto kDynamicQuantParam = "dynamic_quant_param";
 constexpr auto kGraphKernelParam = "graph_kernel_param";
 constexpr int kNumSize3 = 3;
 constexpr int kNumSize2 = 2;
+constexpr size_t kNumIndex0 = 0;
+constexpr size_t kNumIndex1 = 1;
+constexpr size_t kNumIndex2 = 2;
 } // namespace
 using ShapeVector = std::vector<int64_t>;
 const int kBatchDim = 0;
@@ -431,6 +434,32 @@ bool ConfigFileParser::SetParamByConfigfile(const std::shared_ptr<mindspore::Con
       return false;
     }
   }
+
+  auto op_attrs_str = FindInAscendMap(kOpAttrs, ascend_map);
+  std::vector<std::string> op_attrs_vec = {};
+  if (!op_attrs_str.empty()) {
+    MS_LOG(INFO) << "op_attrs_str: " << op_attrs_str;
+    op_attrs_vec = mindspore::lite::SplitStringToVector(op_attrs_str, ";");
+    std::map<std::string, std::string> attr;
+    for (auto op_attr_str : op_attrs_vec) {
+      MS_LOG(INFO) << "op_attr: " << op_attr_str;
+      auto op_attr = mindspore::lite::SplitStringToVector(op_attr_str, ":");
+      if (op_attr.size() != kNumSize3) {
+        return false;
+      }
+      auto op_type = op_attr[kNumIndex0];
+      auto attr_key = op_attr[kNumIndex1];
+      auto attr_value = op_attr[kNumIndex2];
+      param->aclModelOptionCfgParam.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
+      param->ascendGeOptionCfg.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
+    }
+  }
+  for (auto item : param->aclModelOptionCfgParam.op_attrs_map) {
+    for (auto attr : item.second) {
+      MS_LOG(INFO) << "op type: " << item.first << ", key: " << attr.first << ", value: " << attr.second;
+    }
+  }
+
   auto it = ascend_map.find("input_shape");
   if (it != ascend_map.end()) {
     param->aclModelOptionCfgParam.input_shape = RemoveInputShapeBrackets(it->second);
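For illustration, a minimal standalone sketch of the format the new block above parses: a semicolon-separated list of op_type:key:value triples (an op_attrs value such as FlashAttention:input_layout:BSH, used here only as an assumed example). Split() is a local stand-in for mindspore::lite::SplitStringToVector; this is not the converter's actual code path.

// Sketch only: parse "OpType:key:value;OpType:key:value" into the nested attribute map,
// mirroring the op_attrs handling added to SetParamByConfigfile.
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

static std::vector<std::string> Split(const std::string &s, char delim) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  std::string piece;
  while (std::getline(ss, piece, delim)) {
    if (!piece.empty()) out.push_back(piece);
  }
  return out;
}

int main() {
  // Assumed example value of the new "op_attrs" config key.
  std::string op_attrs_str = "FlashAttention:input_layout:BSH";
  std::map<std::string, std::map<std::string, std::string>> op_attrs_map;
  for (const auto &entry : Split(op_attrs_str, ';')) {
    auto fields = Split(entry, ':');
    if (fields.size() != 3) {  // same shape check as the kNumSize3 test above
      std::cerr << "malformed op_attrs entry: " << entry << "\n";
      return 1;
    }
    op_attrs_map[fields[0]].insert({fields[1], fields[2]});
  }
  for (const auto &op : op_attrs_map) {
    for (const auto &kv : op.second) {
      std::cout << "op type: " << op.first << ", key: " << kv.first << ", value: " << kv.second << "\n";
    }
  }
  return 0;
}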
@@ -47,6 +47,7 @@ struct GraphKernelCfg {
 
 struct AscendGeOptionCfg {
   std::vector<std::string> plugin_custom_ops;
+  std::map<std::string, std::map<std::string, std::string>> op_attrs_map;
   std::vector<int64_t> inputs_to_variable;
   std::vector<int64_t> outputs_to_variable;
 };
@@ -17,6 +17,7 @@
 #include "tools/optimizer/fusion/flash_attention_fusion.h"
 #include <memory>
 #include <utility>
+#include <string>
 #include "ops/op_utils.h"
 #include "ops/array_ops.h"
 #include "ops/nn_ops.h"
@@ -60,6 +61,7 @@ constexpr int64_t kNumDValue = 40;
 constexpr int64_t kNumPadSize = 8;
 constexpr int kNumPowerTwo = 2;
 constexpr float kNumPowerHalf = 0.5;
+bool goBSH = false;
 
 bool IsDivNode(const BaseRef &n) {
   if (utils::isa<AnfNodePtr>(n)) {
@@ -2108,12 +2110,26 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
   auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
                << " , v shape: " << input_tensor_v_shape;
+  auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(q_trans_reshape != nullptr, nullptr);
+  auto k_trans_reshape = k_trans->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(k_trans_reshape != nullptr, nullptr);
+  auto v_trans_reshape = v_trans->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(v_trans_reshape != nullptr, nullptr);
+
+  auto top_matmul_q = q_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(top_matmul_q != nullptr, nullptr);
+  auto top_matmul_k = k_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(top_matmul_k != nullptr, nullptr);
+  auto top_matmul_v = v_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
+  MS_CHECK_TRUE_RET(top_matmul_v != nullptr, nullptr);
 
   float scale_value = 0;
   int64_t num_head = 0;
   int64_t next_tokens = kNumMaxNextTokenSize;
   int64_t d_value = 0;
   auto mul_const_input = mul->input(kNumIndex2);
+  bool actual_BSH = false;
 
   if (input_tensor_q_shape.size() != kNumShapeSize4) {
     scale_value = GetScaleValueForDynamicShape(mul_const_input);
@@ -2123,7 +2139,6 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
     auto output_shape_node = node->cast<CNodePtr>();
     output_shape_node->set_input(kNumIndex2, shape_node);
-    auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
     num_head = GetNumHeadForSD(q_trans_reshape);
   } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
     MS_LOG(INFO) << "get flash attention param for static shape.";
@@ -2142,16 +2157,23 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(const st
     if (d_value == kNumDValue) {
       fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value);
     }
+  } else if (goBSH) {
+    fa_node = CreatePromptFlashAttentionCnodeForBSH(func_graph, node, top_matmul_q, top_matmul_k, top_matmul_v, nullptr,
+                                                    num_head, next_tokens, scale_value);
+    actual_BSH = true;
   } else {
     fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
                                                      next_tokens, scale_value, num_head);
   }
-  if (fa_node == nullptr) {
-    return nullptr;
-  }
+  MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
   auto manager = Manage(func_graph);
-  (void)manager->Replace(cnode, fa_node);
-  MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
+  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
+  if (actual_BSH) {
+    (void)manager->Replace(node, fa_node);
+  } else {
+    (void)manager->Replace(cnode, fa_node);
+  }
+  MS_LOG(INFO) << "create prompt flash attention success for without cast, BSH: " << actual_BSH;
   return nullptr;
 }
 
@@ -2412,6 +2434,15 @@ CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDEinsum(const std::st
 AnfNodePtr FlashAttentionFusion::Process(const std::string &patten_name, const FuncGraphPtr &func_graph,
                                          const AnfNodePtr &node, const EquivPtr &equiv) const {
   MS_LOG(INFO) << "do flash attention fusion, pattern name: " << patten_name;
+  if (op_attrs_map_.find("FlashAttention") != op_attrs_map_.end()) {
+    if (op_attrs_map_.at("FlashAttention").find("input_layout") != op_attrs_map_.at("FlashAttention").end()) {
+      auto layout_str = op_attrs_map_.at("FlashAttention").at("input_layout");
+      if (strcmp(layout_str.c_str(), "BSH") == 0) {
+        MS_LOG(INFO) << "Use user config, FA layout is BSH";
+        goBSH = true;
+      }
+    }
+  }
   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
     MS_LOG(ERROR) << "function graph, node or equiv is nullptr.";
     return nullptr;
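A minimal restatement of the two-level lookup Process performs on op_attrs_map_ (op type first, then attribute key). UserRequestedBshLayout is a hypothetical helper name used only in this sketch; the pass itself inlines the checks as shown above.

// Sketch of the nested-map lookup: "FlashAttention" -> "input_layout" -> "BSH".
#include <iostream>
#include <map>
#include <string>

using OpAttrsMap = std::map<std::string, std::map<std::string, std::string>>;

bool UserRequestedBshLayout(const OpAttrsMap &op_attrs_map) {
  auto op_it = op_attrs_map.find("FlashAttention");
  if (op_it == op_attrs_map.end()) {
    return false;
  }
  auto attr_it = op_it->second.find("input_layout");
  // Corresponds to the strcmp(layout_str.c_str(), "BSH") check in the diff above.
  return attr_it != op_it->second.end() && attr_it->second == "BSH";
}

int main() {
  OpAttrsMap attrs = {{"FlashAttention", {{"input_layout", "BSH"}}}};
  std::cout << std::boolalpha << UserRequestedBshLayout(attrs) << "\n";  // prints: true
  return 0;
}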
@@ -20,6 +20,7 @@
 #include <string>
 #include <memory>
 #include <vector>
+#include <map>
 #include <unordered_map>
 #include "tools/optimizer/common/multiple_pattern_process_pass.h"
 #include "tools/optimizer/common/gllo_utils.h"
@@ -41,8 +42,11 @@ namespace opt {
  */
 class FlashAttentionFusion : public MultiplePatternProcessPass {
  public:
-  explicit FlashAttentionFusion(const std::string &name = "FlashAttentionFusion", bool multigraph = true)
-      : MultiplePatternProcessPass(name, multigraph) {}
+  explicit FlashAttentionFusion(std::map<std::string, std::map<std::string, std::string>> op_attrs_map,
+                                const std::string &name = "FlashAttentionFusion", bool multigraph = true)
+      : MultiplePatternProcessPass(name, multigraph) {
+    op_attrs_map_ = op_attrs_map;
+  }
 
   ~FlashAttentionFusion() override = default;
 
@@ -51,6 +55,8 @@ class FlashAttentionFusion : public MultiplePatternProcessPass {
   AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
+  std::map<std::string, std::map<std::string, std::string>> op_attrs_map_;
+
   CNodePtr CreatePromptFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                   const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                   const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token,
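As the anf_transform.cc hunk above shows, callers now construct the pass with the parsed attribute map. Below is a minimal, self-contained usage sketch; the FlashAttentionFusion stub only mirrors the new constructor signature so the snippet compiles on its own, it is not the real class from tools/optimizer/fusion/flash_attention_fusion.h.

// Illustrative wiring only: build the fusion pass from a user-supplied attribute map,
// as anf_transform.cc does with param->aclModelOptionCfgParam.op_attrs_map.
#include <map>
#include <memory>
#include <string>
#include <utility>

namespace opt {
// Stub standing in for the real opt::FlashAttentionFusion declared in the header diff above.
class FlashAttentionFusion {
 public:
  explicit FlashAttentionFusion(std::map<std::string, std::map<std::string, std::string>> op_attrs_map)
      : op_attrs_map_(std::move(op_attrs_map)) {}

 private:
  std::map<std::string, std::map<std::string, std::string>> op_attrs_map_;
};
}  // namespace opt

int main() {
  // e.g. parsed from an op_attrs config value of FlashAttention:input_layout:BSH
  std::map<std::string, std::map<std::string, std::string>> op_attrs_map = {
    {"FlashAttention", {{"input_layout", "BSH"}}}};
  auto fusion_pass = std::make_shared<opt::FlashAttentionFusion>(op_attrs_map);
  (void)fusion_pass;
  return 0;
}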