modify multi-head attention fusion

hangangqiang 2021-07-01 15:40:45 +08:00
parent 5b98330f8d
commit 3086f9a494
23 changed files with 532 additions and 165 deletions

View File

@@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/infer/attention_infer.h"
#include "nnacl/infer/infer_register.h"
int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 7, 1);
if (check_ret != NNACL_OK) {
return check_ret;
}
const TensorC *q_input = inputs[0];
TensorC *output = outputs[0];
SetDataTypeFormat(output, q_input);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
const TensorC *q_weight = inputs[3];
if (q_input->shape_size_ != 2 && q_input->shape_size_ != 3) {
return NNACL_ERR;
}
if (q_weight->shape_size_ != 2) {
return NNACL_ERR;
}
int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0];
int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1];
int d_model = q_weight->shape_[1];
output->shape_[0] = batch;
output->shape_[1] = f_seq;
output->shape_[2] = d_model;
output->shape_size_ = 3;
return NNACL_OK;
}
REG_INFER(Attention, PrimType_Attention, AttentionInferShape)
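The rule above maps a query of [batch, from_seq, hidden] (rank-2 inputs get an implicit batch of 1) and a query weight of [hidden, d_model] to an output of [batch, from_seq, d_model]. A minimal standalone sketch of that arithmetic, with illustrative shapes that are not taken from the commit:

#include <array>
#include <cstdio>

int main() {
  // Illustrative shapes: query [batch, from_seq, hidden], weight_q [hidden, d_model].
  std::array<int, 3> q_input = {8, 128, 768};
  std::array<int, 2> q_weight = {768, 768};
  int batch = q_input[0];
  int f_seq = q_input[1];
  int d_model = q_weight[1];
  // Mirrors AttentionInferShape: the output is always rank 3.
  std::printf("output shape: [%d, %d, %d]\n", batch, f_seq, d_model);  // [8, 128, 768]
  return 0;
}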

View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_ATTENTION_INFER_H
#define MINDSPORE_NNACL_ATTENTION_INFER_H
#include "nnacl/infer/common_infer.h"
#ifdef __cplusplus
extern "C" {
#endif
int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_ATTENTION_INFER_H

View File

@@ -225,8 +225,9 @@ enum PrimType {
PrimType_TensorArrayRead = 198,
PrimType_TensorArrayWrite = 199,
PrimType_Affine = 200,
PrimType_Attention = 201,
PrimType_MIN = PrimType_NONE,
PrimType_MAX = PrimType_Affine + 1
PrimType_MAX = PrimType_Attention + 1
};
void RegInfer(int prim_type, InferShape func);

View File

@@ -0,0 +1,22 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/attention.h"
namespace mindspore::ops {
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
} // namespace mindspore::ops

View File

@@ -19,26 +19,24 @@
#include <vector>
#include <string>
#include <memory>
#include "utils/check_convert_utils.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace ops {
constexpr auto kNameAttention = "Attention";
// Attention: fused multi-head attention primitive
class Attention : public PrimitiveC {
public:
Attention() : PrimitiveC(kNameAttention) {
InitIOName({"query", "key", "value", "w_q", "b_q", "w_k", "b_k", "w_v", "b_v", "w_o", "b_o"}, {"output"});
InitIOName(
{"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"},
{"output"});
}
~Attention() = default;
~Attention() override = default;
MS_DECLARE_PARENT(Attention, PrimitiveC);
void Init(const int64_t number_heads = 0, const int64_t key_dim = 0, const int64_t value_dim = 0);
void set_num_heads(const int64_t num_heads);
void set_key_dim(const int64_t key_dim);
void set_value_dim(const int64_t value_dim);
int64_t get_num_heads() const;
int64_t get_key_dim() const;
int64_t get_value_dim() const;
void Init() {}
};
} // namespace ops
} // namespace mindspore
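With this change the primitive carries no attributes; num_heads and the per-head dimensions are recovered later from the weight and reshape shapes. A minimal sketch of constructing it (MakeAttentionPrim is a hypothetical helper, not part of the commit):

#include <memory>
#include "ops/attention.h"

std::shared_ptr<mindspore::ops::Attention> MakeAttentionPrim() {
  auto attn = std::make_shared<mindspore::ops::Attention>();
  // Node inputs must follow the InitIOName order declared above:
  //   q, k, v, weight_q, weight_k, weight_v, weight_o,
  //   bias_q, bias_k, bias_v, bias_o, mask
  attn->Init();  // no attributes remain to set
  return attn;
}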

View File

@@ -28,6 +28,9 @@ namespace mindspore::ops {
constexpr auto kAlpha = "alpha";
constexpr auto kActivation = "activation";
constexpr auto kActivationType = "activation_type";
constexpr auto kAttentionQActType = "attention_q_act_type";
constexpr auto kAttentionKActType = "attention_k_act_type";
constexpr auto kAttentionVActType = "attention_v_act_type";
constexpr auto kAddress = "address";
constexpr auto kAlignCorners = "align_corners";
constexpr auto kAttr = "attr";
@@ -85,6 +88,7 @@ constexpr auto kGradX = "grad_x";
constexpr auto kGradY = "grad_y";
constexpr auto kGroup = "group";
constexpr auto kHasBias = "has_bias";
constexpr auto kAttentionHasMask = "attention_has_mask";
constexpr auto kHiddenSize = "hidden_size";
constexpr auto kId = "id";
constexpr auto kImageSizeH = "image_size_h";
@@ -124,9 +128,10 @@ constexpr auto kNumElements = "num_elements";
constexpr auto kNumBits = "num_bits";
constexpr auto kNumDirections = "num_directions";
constexpr auto kNumProj = "num_proj";
constexpr auto kNumHeads = "num_heads";
constexpr auto kKeyDim = "key_dim";
constexpr auto kValueDim = "value_dim";
constexpr auto kAttentionNumHeads = "attention_num_heads";
constexpr auto kAttentionSizePerHead = "attention_size_per_head";
constexpr auto kAttentionFromSeqLen = "attention_from_seq_len";
constexpr auto kAttentionToSeqLen = "attention_to_seq_len";
constexpr auto kOffset = "offset";
constexpr auto kNmsIouThreshold = "nms_iou_threshold";
constexpr auto kNmsScoreThreshold = "nms_score_threshold";

View File

@@ -218,6 +218,7 @@ union PrimitiveType {
TensorArrayRead,
TensorArrayWrite,
Affine,
Attention,
}
table Abs {
@@ -1194,3 +1195,6 @@ table Affine {
transpose_a: bool = false;
transpose_b: bool = false;
}
table Attention {
}

View File

@@ -218,6 +218,7 @@ OP_TYPE(TensorArrayRead)
OP_TYPE(TensorArrayWrite)
// kaldi affine op
OP_TYPE(Affine)
OP_TYPE(Attention)
OP_TYPE_DEF_END(PrimitiveType)
OP_SCHEMA_DEF(Abs)
@@ -1194,3 +1195,6 @@ OP_ATTR_ENUM_WITH_VALUE(activation_type, ActivationType, 0)
OP_ATTR_WITH_VALUE(transpose_a, bool, false)
OP_ATTR_WITH_VALUE(transpose_b, bool, false)
OP_SCHEMA_DEF_END(Affine)
OP_SCHEMA_DEF(Attention)
OP_SCHEMA_DEF_END(Attention)

View File

@@ -32,6 +32,7 @@
#include "ops/assert.h"
#include "ops/assign.h"
#include "ops/assign_add.h"
#include "ops/attention.h"
#include "ops/atan.h"
#include "ops/audio_spectrogram.h"
#include "ops/avg_pool.h"
@@ -458,6 +459,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(TensorArray)
FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayRead)
FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayWrite)
FUNC_MSOP2SCHEMAOP_DECLARE(Affine)
FUNC_MSOP2SCHEMAOP_DECLARE(Attention)
#endif
} // namespace mindspore::lite::ops
#else

View File

@@ -800,6 +800,11 @@ std::unique_ptr<schema::PrimitiveT> AffinePrimitiveCreator(const AnfNodePtr &nod
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
std::unique_ptr<schema::PrimitiveT> AttentionPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Attention>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
@@ -1023,6 +1028,7 @@ RegistryMSOps g_TensorArrayCreatorRegistry("TensorArray", TensorArrayPrimitiveCr
RegistryMSOps g_TensorArrayReadCreatorRegistry("TensorArrayRead", TensorArrayReadPrimitiveCreator);
RegistryMSOps g_TensorArrayWriteCreatorRegistry("TensorArrayWrite", TensorArrayWritePrimitiveCreator);
RegistryMSOps g_AffineCreatorRegistry("Affine", AffinePrimitiveCreator);
RegistryMSOps g_AttentionCreatorRegistry("Attention", AttentionPrimitiveCreator);
std::unique_ptr<schema::PrimitiveT> CustomPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);

View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
using mindspore::schema::PrimitiveType_Attention;
using mindspore::schema::PrimitiveType_Depend;
using mindspore::schema::PrimitiveType_ZerosLike;
@@ -35,5 +36,6 @@ OpParameter *PopulateCommonParameter(const void *prim) {
}
REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Attention, PopulateCommonParameter, SCHEMA_CUR)
} // namespace lite
} // namespace mindspore
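Attention can reuse PopulateCommonParameter because its schema table is empty: only the generic OpParameter header needs filling. A hedged sketch of what that amounts to (PopulateAttentionParameter is a hypothetical name; the commit registers the common path instead):

#include <cstdlib>
#include <cstring>
#include "nnacl/op_base.h"

// Hypothetical, for illustration: allocate the base struct and record the type.
OpParameter *PopulateAttentionParameter(int prim_type) {
  auto *param = static_cast<OpParameter *>(malloc(sizeof(OpParameter)));
  if (param == nullptr) {
    return nullptr;
  }
  memset(param, 0, sizeof(OpParameter));
  param->type_ = prim_type;  // e.g. PrimitiveType_Attention
  return param;
}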

View File

@@ -174,7 +174,6 @@ if(MSLITE_ENABLE_CONVERTER)
${TEST_LITE_SRC}
${TEST_CASE_TFLITE_PARSERS_SRC}
${LITE_DIR}/tools/converter/ops/while.cc
${LITE_DIR}/tools/converter/ops/attention.cc
${LITE_DIR}/tools/common/protobuf_utils.cc
${LITE_DIR}/tools/converter/optimizer.cc
${LITE_DIR}/tools/converter/anf_transform.cc
@@ -188,6 +187,7 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
${LITE_DIR}/tools/optimizer/common/gllo_utils.cc
${LITE_DIR}/tools/optimizer/common/format_utils.cc
${LITE_DIR}/tools/optimizer/common/multiple_pattern_process_pass.cc
${LITE_DIR}/tools/optimizer/fusion/affine_activation_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/affine_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc
@@ -196,6 +196,8 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/conv_transform_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/multi_head_attention_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/reshape_reshape_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/norm_fusion.cc
@@ -208,7 +210,6 @@ if(MSLITE_ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tf_multi_head_attention_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/glu_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc

View File

@@ -44,6 +44,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/common/pass_manager_extends.cc
../optimizer/common/gllo_utils.cc
../optimizer/common/format_utils.cc
../optimizer/common/multiple_pattern_process_pass.cc
../optimizer/fusion/affine_activation_fusion.cc
../optimizer/fusion/affine_fusion.cc
../optimizer/fusion/conv_biasadd_fusion.cc
@@ -62,7 +63,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/tflite_lstm_cell_fusion.cc
../optimizer/fusion/tf_lstm_cell_fusion.cc
../optimizer/fusion/tf_bidirection_gru_fusion.cc
../optimizer/fusion/tf_multi_head_attention_fusion.cc
../optimizer/fusion/multi_head_attention_fusion.cc
../optimizer/fusion/reshape_reshape_fusion.cc
../optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc
../optimizer/fusion/glu_fusion.cc
../optimizer/fusion/matmul_add_fusion.cc

View File

@@ -37,13 +37,14 @@
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/glu_fusion.h"
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
#include "tools/optimizer/fusion/squeeze_fusion.h"
#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
#include "tools/optimizer/graph/add_tensor_array.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
@@ -81,6 +82,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
if (!config->trainModel) {
// remove quantdtype when awaretraining
fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
conv_bn_pass->SetFmkType(config->fmk);
@@ -100,7 +102,6 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfMultiHeadAttentionFusion>());
fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));

View File

@@ -1,45 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "tools/converter/ops/attention.h"
#include "ops/op_utils.h"
namespace mindspore::ops {
void Attention::Init(int64_t number_heads, int64_t key_dim, int64_t value_dim) {
this->set_num_heads(number_heads);
this->set_key_dim(key_dim);
this->set_value_dim(value_dim);
}
void Attention::set_num_heads(const int64_t num_heads) { this->AddAttr(kNumHeads, MakeValue(num_heads)); }
void Attention::set_key_dim(const int64_t key_dim) { this->AddAttr(kKeyDim, MakeValue(key_dim)); }
void Attention::set_value_dim(const int64_t value_dim) { this->AddAttr(kValueDim, MakeValue(value_dim)); }
int64_t Attention::get_num_heads() const {
auto value_ptr = this->GetAttr(kNumHeads);
return GetValue<int64_t>(value_ptr);
}
int64_t Attention::get_key_dim() const {
auto value_ptr = this->GetAttr(kKeyDim);
return GetValue<int64_t>(value_ptr);
}
int64_t Attention::get_value_dim() const {
auto value_ptr = this->GetAttr(kValueDim);
return GetValue<int64_t>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
} // namespace mindspore::ops

View File

@@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
namespace mindspore::opt {
MultiplePatternProcessPass::MultiplePatternProcessPass(const std::string &name, bool multigraph)
: NodePass(name), multigraph_(multigraph), pattern_engine_(PatternEngine(std::make_shared<Visitor>())) {}
AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (patterns_.empty()) {
VarPtr fg = std::make_shared<Var>("RootG");
auto patterns = std::move(DefinePatterns());
for (const auto &pattern : patterns) {
auto primitive_var = std::make_shared<PrimitiveVarMap>();
this->patterns_[pattern.first] = (SexpToNode(pattern.second, fg, primitive_var.get(), multigraph_));
this->primitive_var_maps_[pattern.first] = primitive_var;
}
}
auto empty_equiv = std::make_shared<Equiv>();
MS_ASSERT(primitive_var_maps_.size() == patterns_.size());
for (const auto &iter : primitive_var_maps_) {
auto name = iter.first;
auto primitive_var = iter.second;
auto pattern = this->patterns_[name];
MS_EXCEPTION_IF_NULL(primitive_var);
MS_EXCEPTION_IF_NULL(pattern);
EquivPtr equiv = pattern_engine_.Match(pattern, node, *primitive_var, empty_equiv);
if (equiv != nullptr && !equiv->empty()) {
return Process(name, func_graph, node, equiv);
}
}
return nullptr;
}
} // namespace mindspore::opt

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
#include <string>
#include <utility>
#include <memory>
#include <unordered_map>
#include "backend/optimizer/common/node_pass.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/helper.h"
namespace mindspore {
namespace opt {
class MultiplePatternProcessPass : public NodePass {
public:
explicit MultiplePatternProcessPass(const std::string &name = "", bool multigraph = true);
~MultiplePatternProcessPass() override = default;
virtual AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
virtual std::unordered_map<std::string, VectorRef> DefinePatterns() const = 0;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
private:
std::unordered_map<std::string, BaseRef> patterns_;
std::unordered_map<std::string, PrimitiveVarMapPtr> primitive_var_maps_;
bool multigraph_ = true;
PatternEngine pattern_engine_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
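A subclass returns named patterns from DefinePatterns() and dispatches on the matched name in Process(); Run() compiles each pattern once on first use and hands the first match to Process(). A minimal sketch of the contract, with a hypothetical pass and placeholder helpers:

#include <string>
#include <unordered_map>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"

namespace mindspore::opt {
// Hypothetical two-pattern pass, for illustration only.
class TwoPatternPass : public MultiplePatternProcessPass {
 public:
  TwoPatternPass() : MultiplePatternProcessPass("two_pattern_pass", true) {}
  std::unordered_map<std::string, VectorRef> DefinePatterns() const override {
    return {{"with_bias", DefineWithBiasPattern()}, {"no_bias", DefineNoBiasPattern()}};
  }
  AnfNodePtr Process(const std::string &name, const FuncGraphPtr &graph,
                     const AnfNodePtr &node, const EquivPtr &equiv) const override {
    return name == "with_bias" ? FuseWithBias(graph, node, equiv)
                               : FuseNoBias(graph, node, equiv);
  }

 private:
  // Pattern builders and rewriters are placeholders, not real repo functions.
  VectorRef DefineWithBiasPattern() const;
  VectorRef DefineNoBiasPattern() const;
  AnfNodePtr FuseWithBias(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const;
  AnfNodePtr FuseNoBias(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const;
};
}  // namespace mindspore::opt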

View File

@@ -13,8 +13,9 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include <functional>
#include <utility>
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
@@ -22,27 +23,108 @@ namespace {
const auto &p1 = std::placeholders::_1;
} // namespace
TfMultiHeadAttentionFusion::TfMultiHeadAttentionFusion(const string &name, bool multigraph)
: PatternProcessPass(name, multigraph) {
MultiHeadAttentionFusion::MultiHeadAttentionFusion(const string &name, bool multigraph)
: MultiplePatternProcessPass(name, multigraph) {
input_q_ = std::make_shared<Var>();
input_k_ = std::make_shared<Var>();
input_v_ = std::make_shared<Var>();
weight_q_ = std::make_shared<Var>();
weight_k_ = std::make_shared<Var>();
weight_v_ = std::make_shared<Var>();
weight_o_ = std::make_shared<Var>();
weight_q_ = std::make_shared<CondVar>(IsParamNode);
weight_k_ = std::make_shared<CondVar>(IsParamNode);
weight_v_ = std::make_shared<CondVar>(IsParamNode);
weight_o_ = std::make_shared<CondVar>(IsParamNode);
bias_q_ = std::make_shared<Var>();
bias_k_ = std::make_shared<Var>();
bias_v_ = std::make_shared<Var>();
bias_o_ = std::make_shared<Var>();
bias_q_ = std::make_shared<CondVar>(IsParamNode);
bias_k_ = std::make_shared<CondVar>(IsParamNode);
bias_v_ = std::make_shared<CondVar>(IsParamNode);
bias_o_ = std::make_shared<CondVar>(IsParamNode);
mask_ = std::make_shared<Var>();
reshape_k_ = std::make_shared<Var>();
reshape_v_ = std::make_shared<Var>();
}
const BaseRef TfMultiHeadAttentionFusion::DefinePattern() const {
namespace {
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) {
auto dense = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), input, weight, bias});
auto reshape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), dense, std::make_shared<Var>()});
return VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), reshape,
std::make_shared<CondVar>(IsParamNode)});
}
VectorRef DefineMask(const BaseRef &mask_input) {
auto expand_dims = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims)), mask_input,
std::make_shared<CondVar>(IsParamNode)});
auto sub = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion)),
std::make_shared<CondVar>(IsParamNode), expand_dims});
return VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion)), sub,
std::make_shared<CondVar>(IsParamNode)});
}
} // namespace
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern() const {
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_);
auto k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_);
auto v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_);
auto q2k =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), q_embedding, k_embedding});
auto q2k_normed = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion)), q2k,
std::make_shared<CondVar>(IsParamNode)});
auto mask = DefineMask(mask_);
auto q2k_normed_masked =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion)), q2k_normed, mask});
auto softmax = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax)), q2k_normed_masked});
auto softmax2v =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), softmax, v_embedding});
auto softmax2v_transposed = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)),
softmax2v, std::make_shared<CondVar>(IsParamNode)});
auto softmax2v_transposed_reshaped =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), softmax2v_transposed,
std::make_shared<Var>()});
return VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)),
softmax2v_transposed_reshaped, weight_o_, bias_o_});
}
namespace {
VectorRef DefineDensePattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) {
auto transpose = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), input,
std::make_shared<CondVar>(IsParamNode)});
auto reshape1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose,
std::make_shared<CondVar>(IsParamNode)});
auto matmul = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), reshape1, weight});
auto reshape2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), matmul,
std::make_shared<CondVar>(IsParamNode)});
if (bias == nullptr) {
return reshape2;
}
auto bias_add = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimBiasAdd)), reshape2, bias});
return bias_add;
}
VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
const BaseRef &reshape_shape, bool transpose = false) {
auto input_after_dense = DefineDensePattern(input, weight, bias);
auto result = VectorRef(
{std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), input_after_dense, reshape_shape});
if (transpose) {
result = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), result,
std::make_shared<CondVar>(IsParamNode)});
}
return result;
}
VectorRef DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) {
auto transpose = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), input,
std::make_shared<CondVar>(IsParamNode)});
auto reshape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose,
std::make_shared<CondVar>(IsParamNode)});
return DefineDensePattern(reshape, weight, bias);
}
} // namespace
VectorRef MultiHeadAttentionFusion::DefineMPWithoutMaskPattern() const {
auto query = DefineProcessInputPattern(input_q_, weight_q_, bias_q_, std::make_shared<CondVar>(IsParamNode));
auto query_div = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDivFusion)), query,
std::make_shared<CondVar>(IsParamNode)});
@@ -59,52 +141,27 @@ const BaseRef TfMultiHeadAttentionFusion::DefinePattern() const {
return DefineProcessOutputPattern(softmax_mul_val, weight_o_, bias_o_);
}
const AnfNodePtr TfMultiHeadAttentionFusion::Process(const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node,
const mindspore::EquivPtr &equiv) const {
return CreateMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0);
std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatterns() const {
std::unordered_map<std::string, VectorRef> patterns;
patterns[kMPAWithoutMaskPatternName] = DefineMPWithoutMaskPattern();
patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern();
return patterns;
}
const VectorRef TfMultiHeadAttentionFusion::DefineDensePattern(const BaseRef &input, const BaseRef &weight,
const BaseRef &bias) const {
auto transpose = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), input,
std::make_shared<CondVar>(IsParamNode)});
auto reshape1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose,
std::make_shared<CondVar>(IsParamNode)});
auto matmul = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), reshape1, weight});
auto reshape2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), matmul,
std::make_shared<CondVar>(IsParamNode)});
if (bias == nullptr) {
return reshape2;
AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node,
const mindspore::EquivPtr &equiv) const {
if (pattern_name == kMPAWithoutMaskPatternName) {
return CreateMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0);
} else if (pattern_name == kMPAWithMaskPatternName) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0);
} else {
return nullptr;
}
auto bias_add = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimBiasAdd)), reshape2, bias});
return bias_add;
}
const VectorRef TfMultiHeadAttentionFusion::DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight,
const BaseRef &bias, const BaseRef &reshape_shape,
bool transpose) const {
auto input_after_dense = DefineDensePattern(input, weight, bias);
auto result = VectorRef(
{std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), input_after_dense, reshape_shape});
if (transpose) {
result = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), result,
std::make_shared<CondVar>(IsParamNode)});
}
return result;
}
const VectorRef TfMultiHeadAttentionFusion::DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight,
const BaseRef &bias) const {
auto transpose = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), input,
std::make_shared<CondVar>(IsParamNode)});
auto reshape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose,
std::make_shared<CondVar>(IsParamNode)});
return DefineDensePattern(reshape, weight, bias);
}
CNodePtr TfMultiHeadAttentionFusion::CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::string &base_name, int var_offset) const {
CNodePtr MultiHeadAttentionFusion::CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::string &base_name, int var_offset) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
auto attention_prim = BuildAttentionPrim(equiv);
@@ -151,14 +208,14 @@ STATUS GetIntParameterData(const ParameterPtr &param_ptr, std::vector<int> *resu
}
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
int64_t shape_size =
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
for (int64_t i = 0; i < shape_size; i++) {
result->emplace_back(ptr[i]);
}
return RET_OK;
}
std::shared_ptr<ops::Attention> TfMultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
auto attention_prim = std::make_shared<ops::Attention>();
if (attention_prim == nullptr) {
MS_LOG(ERROR) << "Build attention primitive failed.";
@@ -191,9 +248,39 @@ std::shared_ptr<ops::Attention> TfMultiHeadAttentionFusion::BuildAttentionPrim(c
MS_LOG(ERROR) << "Shape k or shape v is invalid.";
return nullptr;
}
attention_prim->set_num_heads(shape_k.at(shape_k.size() - 2));
attention_prim->set_key_dim(shape_k.at(shape_k.size() - 1));
attention_prim->set_value_dim(shape_v.at(shape_v.size() - 1));
return attention_prim;
}
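BuildAttentionPrim recovers the head layout from the constant reshape shapes: with a reshape target of [batch, seq_len, num_heads, size_per_head], the second-to-last entry is the head count and the last the per-head width, and k and v must agree on the head count. A standalone sketch of that check, with assumed values:

#include <cstdio>
#include <vector>

int main() {
  // Assumed reshape targets: [batch, seq_len, num_heads, size_per_head].
  std::vector<int> shape_k = {1, 128, 12, 64};
  std::vector<int> shape_v = {1, 128, 12, 64};
  bool valid = shape_k.size() >= 2 && shape_v.size() >= 2 &&
               shape_k[shape_k.size() - 2] == shape_v[shape_v.size() - 2];
  int num_heads = shape_k[shape_k.size() - 2];  // 12
  int key_dim = shape_k[shape_k.size() - 1];    // 64
  int value_dim = shape_v[shape_v.size() - 1];  // 64
  std::printf("valid=%d num_heads=%d key_dim=%d value_dim=%d\n",
              valid, num_heads, key_dim, value_dim);
  return 0;
}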
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
const EquivPtr &equiv, const string &base_name,
int var_offset) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
auto attention_prim = std::make_shared<ops::Attention>();
if (attention_prim == nullptr) {
MS_LOG(ERROR) << "Build attention primitive failed.";
return nullptr;
}
auto value_node = NewValueNode(attention_prim);
auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
auto input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);
auto bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
auto bias_v = utils::cast<AnfNodePtr>((*equiv)[bias_v_]);
auto bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
std::vector<AnfNodePtr> new_node_inputs = {value_node, input_q, input_k, input_v, weight_q, weight_k, weight_v,
weight_o, bias_q, bias_k, bias_v, bias_o, mask};
auto new_node = func_graph->NewCNode(new_node_inputs);
new_node->set_fullname_with_scope(base_name);
return new_node;
}
} // namespace mindspore::opt

View File

@@ -19,31 +19,40 @@
#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include <unordered_map>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "utils/utils.h"
#include "include/errorcode.h"
#include "tools/converter/ops/attention.h"
#include "ops/attention.h"
namespace mindspore {
namespace opt {
class TfMultiHeadAttentionFusion : public PatternProcessPass {
class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
public:
explicit TfMultiHeadAttentionFusion(const std::string &name = "tflite_multi_head_attention_fusion",
bool multigraph = true);
~TfMultiHeadAttentionFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
explicit MultiHeadAttentionFusion(const std::string &name = "multi_head_attention_fusion", bool multigraph = true);
~MultiHeadAttentionFusion() override = default;
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
const EquivPtr &) const override;
protected:
const VectorRef DefineDensePattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) const;
virtual const VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
const BaseRef &reshape_shape, bool transpose = false) const;
virtual const VectorRef DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight,
const BaseRef &bias) const;
// define patterns
VectorRef DefineMPWithMaskPattern() const;
VectorRef DefineMPWithoutMaskPattern() const;
// create multi-head-attention without mask
virtual std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const;
CNodePtr CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::string &base_name, int var_offset) const;
virtual std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const;
// create masked-multi-head-attention
CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::string &base_name, int var_offset) const;
protected:
const std::string kMPAWithoutMaskPatternName = "MPAWithoutMaskPattern";
const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern";
VarPtr input_q_;
VarPtr input_k_;
@@ -58,6 +67,8 @@ class TfMultiHeadAttentionFusion : public PatternProcessPass {
VarPtr bias_v_;
VarPtr bias_o_;
VarPtr mask_;
VarPtr reshape_k_;
VarPtr reshape_v_;
};

View File

@@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
#include "ops/op_utils.h"
#include "ops/reshape.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
} // namespace
const BaseRef ReshapeReshapeFusion::DefinePattern() const {
auto reshape1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), reshape_input_,
std::make_shared<CondVar>(IsParamNode)});
return VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), reshape1, reshape_shape_});
}
const AnfNodePtr ReshapeReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
auto reshape_prim = std::make_shared<ops::Reshape>();
if (reshape_prim == nullptr) {
MS_LOG(ERROR) << "Build reshape primitive failed.";
return nullptr;
}
auto value_node = NewValueNode(reshape_prim);
auto input = utils::cast<AnfNodePtr>((*equiv)[reshape_input_]);
auto shape = utils::cast<AnfNodePtr>((*equiv)[reshape_shape_]);
if (input == nullptr || shape == nullptr) {
MS_LOG(ERROR) << "Cannot find reshape input and weight.";
return nullptr;
}
// build the single fused reshape op
auto new_reshape = func_graph->NewCNode({value_node, input, shape});
if (new_reshape == nullptr) {
MS_LOG(ERROR) << "Create new reshape cnode failed.";
return nullptr;
}
return new_reshape;
}
} // namespace mindspore::opt
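The pass relies on the identity Reshape(Reshape(x, s1), s2) == Reshape(x, s2): only the final shape determines the output layout, provided every shape describes the same element count. A quick standalone illustration with assumed shapes:

#include <cassert>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> in_shape = {2, 3, 4};  // 24 elements
  std::vector<int> s1 = {6, 4};           // intermediate shape, dropped by the fusion
  std::vector<int> s2 = {4, 6};           // final shape, kept
  auto count = [](const std::vector<int> &s) {
    return std::accumulate(s.begin(), s.end(), 1, std::multiplies<int>());
  };
  // The fusion is valid because both paths describe the same element count.
  assert(count(in_shape) == count(s1) && count(s1) == count(s2));
  return 0;
}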

View File

@@ -0,0 +1,41 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_
#include <string>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace opt {
class ReshapeReshapeFusion : public PatternProcessPass {
public:
explicit ReshapeReshapeFusion(bool multigraph = true, const std::string &name = "reshape_reshape_fusion")
: PatternProcessPass(name, multigraph) {}
~ReshapeReshapeFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr reshape_input_ = std::make_shared<Var>();
VarPtr reshape_shape_ = std::make_shared<Var>();
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_

View File

@@ -15,6 +15,7 @@
*/
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
#include <functional>
#include <utility>
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/quant_param_holder.h"
#include "mindspore/core/ops/transpose.h"
@@ -22,10 +23,11 @@
namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
} // namespace
TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const string &name, bool multigraph)
: TfMultiHeadAttentionFusion(name, multigraph) {
: MultiHeadAttentionFusion(name, multigraph) {
query_u_ = std::make_shared<Var>();
query_v_ = std::make_shared<Var>();
input_p_ = std::make_shared<Var>();
@@ -44,7 +46,7 @@ TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const
}
}
const BaseRef TfliteRelPosMultiHeadAttentionFusion::DefinePattern() const {
std::unordered_map<std::string, VectorRef> TfliteRelPosMultiHeadAttentionFusion::DefinePatterns() const {
auto query = DefineProcessInputPattern(input_q_, weight_q_, bias_q_, query_stack_params_, query_prim_);
auto query_with_bias_u =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion)), query, query_u_});
@@ -77,11 +79,15 @@ const BaseRef TfliteRelPosMultiHeadAttentionFusion::DefinePattern() const {
auto value = DefineProcessInputPattern(input_v_, weight_v_, bias_v_, value_stack_params_, value_prim_, true);
auto output =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), logits_softmax, value});
return DefineProcessOutputPattern(output, weight_o_, bias_o_);
auto pattern = DefineProcessOutputPattern(output, weight_o_, bias_o_);
std::unordered_map<std::string, VectorRef> patterns;
patterns.insert(std::make_pair(kRPMHAttentionPatternName, pattern));
return patterns;
}
const AnfNodePtr TfliteRelPosMultiHeadAttentionFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
AnfNodePtr TfliteRelPosMultiHeadAttentionFusion::Process(const std::string &pattern_name,
const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
return CreateRelPosMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope());
}
@@ -139,25 +145,6 @@ std::shared_ptr<ops::Attention> TfliteRelPosMultiHeadAttentionFusion::BuildAtten
}
shape_k.emplace_back(dim);
}
std::vector<int> shape_v;
for (auto &value_stack_param : value_stack_params_) {
auto reshape_v = utils::cast<ParameterPtr>((*equiv)[value_stack_param]);
int dim;
if (RET_OK != GetIntParameterData(reshape_v, &dim)) {
MS_LOG(ERROR) << "Get reshape k data failed";
return nullptr;
}
shape_v.emplace_back(dim);
}
if (shape_k.size() < 2 || shape_v.size() < 2 || shape_k.at(shape_k.size() - 2) != shape_v.at(shape_v.size() - 2)) {
MS_LOG(ERROR) << "Shape k or shape v is invalid.";
return nullptr;
}
attention_prim->set_num_heads(shape_k.at(shape_k.size() - 2));
attention_prim->set_key_dim(shape_k.at(shape_k.size() - 1));
attention_prim->set_value_dim(shape_v.at(shape_v.size() - 1));
return attention_prim;
}

View File

@@ -19,22 +19,24 @@
#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "include/errorcode.h"
#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
namespace mindspore::opt {
class TfliteRelPosMultiHeadAttentionFusion : public TfMultiHeadAttentionFusion {
class TfliteRelPosMultiHeadAttentionFusion : public MultiHeadAttentionFusion {
public:
explicit TfliteRelPosMultiHeadAttentionFusion(const std::string &name = "tflite_rel_pos_multi_head_attention_fusion",
bool multigraph = true);
~TfliteRelPosMultiHeadAttentionFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
const EquivPtr &) const override;
protected:
std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const;
std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const override;
const VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
const std::vector<VarPtr> &stack_params, const VarPtr &full_connect_prim,
@@ -45,6 +47,8 @@ class TfliteRelPosMultiHeadAttentionFusion : public TfMultiHeadAttentionFusion {
const std::string &base_name) const;
const VectorRef DefineRelativeShiftPattern(const BaseRef &input) const;
private:
const std::string kRPMHAttentionPatternName = "RPMHAttentionPattern";
VarPtr query_u_;
VarPtr query_v_;
VarPtr query_prim_;