From 3086f9a4941f7e166af01731a1bb11c6bc021d2b Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Thu, 1 Jul 2021 15:40:45 +0800 Subject: [PATCH] modify multihead fusion --- .../cpu/nnacl/infer/attention_infer.c | 50 +++++ .../cpu/nnacl/infer/attention_infer.h | 31 +++ .../cpu/nnacl/infer/infer_register.h | 3 +- mindspore/core/ops/attention.cc | 22 ++ .../tools/converter => core}/ops/attention.h | 16 +- mindspore/core/ops/op_utils.h | 11 +- mindspore/lite/schema/ops.fbs | 4 + mindspore/lite/src/ops/ops_def.cc | 4 + mindspore/lite/src/ops/ops_func_declare.h | 2 + mindspore/lite/src/ops/ops_utils.cc | 6 + .../lite/src/ops/populate/common_populate.cc | 2 + mindspore/lite/test/CMakeLists.txt | 5 +- mindspore/lite/tools/converter/CMakeLists.txt | 4 +- .../lite/tools/converter/anf_transform.cc | 5 +- .../lite/tools/converter/ops/attention.cc | 45 ---- .../common/multiple_pattern_process_pass.cc | 49 +++++ .../common/multiple_pattern_process_pass.h | 47 ++++ ...sion.cc => multi_head_attention_fusion.cc} | 203 +++++++++++++----- ...fusion.h => multi_head_attention_fusion.h} | 41 ++-- .../fusion/reshape_reshape_fusion.cc | 57 +++++ .../optimizer/fusion/reshape_reshape_fusion.h | 41 ++++ ...ite_rel_pos_multi_head_attention_fusion.cc | 35 +-- ...lite_rel_pos_multi_head_attention_fusion.h | 14 +- 23 files changed, 532 insertions(+), 165 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.c create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.h create mode 100644 mindspore/core/ops/attention.cc rename mindspore/{lite/tools/converter => core}/ops/attention.h (71%) delete mode 100644 mindspore/lite/tools/converter/ops/attention.cc create mode 100644 mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.cc create mode 100644 mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.h rename mindspore/lite/tools/optimizer/fusion/{tf_multi_head_attention_fusion.cc => multi_head_attention_fusion.cc} (51%) rename mindspore/lite/tools/optimizer/fusion/{tf_multi_head_attention_fusion.h => multi_head_attention_fusion.h} (58%) create mode 100644 mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.cc create mode 100644 mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.c new file mode 100644 index 00000000000..b8f830df9b2 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.c @@ -0,0 +1,50 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "nnacl/infer/attention_infer.h" +#include "nnacl/infer/infer_register.h" + +int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { + int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 7, 1); + if (check_ret != NNACL_OK) { + return check_ret; + } + const TensorC *q_input = inputs[0]; + TensorC *output = outputs[0]; + SetDataTypeFormat(output, q_input); + if (!InferFlag(inputs, inputs_size)) { + return NNACL_INFER_INVALID; + } + const TensorC *q_weight = inputs[3]; + if (q_input->shape_size_ != 2 && q_input->shape_size_ != 3) { + return NNACL_ERR; + } + if (q_weight->shape_size_ != 2) { + return NNACL_ERR; + } + int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0]; + int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1]; + int d_model = q_weight->shape_[1]; + + output->shape_[0] = batch; + output->shape_[1] = f_seq; + output->shape_[2] = d_model; + output->shape_size_ = 3; + return NNACL_OK; +} + +REG_INFER(Attention, PrimType_Attention, AttentionInferShape) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.h new file mode 100644 index 00000000000..ebb602ade01 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/attention_infer.h @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
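A standalone sketch of the shape rule AttentionInferShape implements above, using plain arrays in place of nnacl's TensorC (the sample shapes in the trailing comment are hypothetical): the query is either [seq, hidden] or [batch, seq, hidden], the query weight is [hidden, d_model], and the output is always rank 3.

    /* Illustration only, not the nnacl kernel itself. */
    void InferAttentionOutputShape(const int *q_shape, int q_rank, const int *wq_shape, int out_shape[3]) {
      out_shape[0] = (q_rank == 2) ? 1 : q_shape[0];          /* batch        */
      out_shape[1] = (q_rank == 2) ? q_shape[0] : q_shape[1]; /* from_seq_len */
      out_shape[2] = wq_shape[1];                             /* d_model      */
    }
    /* e.g. q = [2, 128, 512] with w_q = [512, 768] gives an output of [2, 128, 768]. */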
+ */ +#ifndef MINDSPORE_NNACL_ATTENTION_INFER_H +#define MINDSPORE_NNACL_ATTENTION_INFER_H + +#include "nnacl/infer/common_infer.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter); + +#ifdef __cplusplus +} +#endif +#endif // MINDSPORE_NNACL_ATTENTION_INFER_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h index 0088de7483a..0499b7f367e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/infer_register.h @@ -225,8 +225,9 @@ enum PrimType { PrimType_TensorArrayRead = 198, PrimType_TensorArrayWrite = 199, PrimType_Affine = 200, + PrimType_Attention = 201, PrimType_MIN = PrimType_NONE, - PrimType_MAX = PrimType_Affine + 1 + PrimType_MAX = PrimType_Attention + 1 }; void RegInfer(int prim_type, InferShape func); diff --git a/mindspore/core/ops/attention.cc b/mindspore/core/ops/attention.cc new file mode 100644 index 00000000000..89533ba520a --- /dev/null +++ b/mindspore/core/ops/attention.cc @@ -0,0 +1,22 @@ + +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
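The REG_INFER line in attention_infer.c ties the new PrimType_Attention value to AttentionInferShape through the RegInfer hook declared above. A plausible sketch of the mechanics it relies on (assumed for illustration; the verbatim macro lives in infer_register.h/.c and may differ):

    // Assumed mechanics: a table indexed by PrimType, filled before main() by
    // static registrar objects that REG_INFER declares.
    static InferShape g_infer_func_table[PrimType_MAX];
    void RegInfer(int prim_type, InferShape func) { g_infer_func_table[prim_type] = func; }
    // REG_INFER(Attention, PrimType_Attention, AttentionInferShape) would then expand to a
    // static object whose constructor runs RegInfer(PrimType_Attention, AttentionInferShape).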
+ */ + +#include "ops/attention.h" + +namespace mindspore::ops { +REGISTER_PRIMITIVE_C(kNameAttention, Attention); +} // namespace mindspore::ops diff --git a/mindspore/lite/tools/converter/ops/attention.h b/mindspore/core/ops/attention.h similarity index 71% rename from mindspore/lite/tools/converter/ops/attention.h rename to mindspore/core/ops/attention.h index 19e83e487e3..1d74fe0dfc6 100644 --- a/mindspore/lite/tools/converter/ops/attention.h +++ b/mindspore/core/ops/attention.h @@ -19,26 +19,24 @@ #include #include #include +#include "utils/check_convert_utils.h" #include "ops/primitive_c.h" #include "abstract/abstract_value.h" namespace mindspore { namespace ops { constexpr auto kNameAttention = "Attention"; +// Attention MultiHeadAttention class Attention : public PrimitiveC { public: Attention() : PrimitiveC(kNameAttention) { - InitIOName({"query", "key", "value", "w_q", "b_q", "w_k", "b_k", "w_v", "b_v", "w_o", "b_o"}, {"output"}); + InitIOName( + {"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"}, + {"output"}); } - ~Attention() = default; + ~Attention() override = default; MS_DECLARE_PARENT(Attention, PrimitiveC); - void Init(const int64_t number_heads = 0, const int64_t key_dim = 0, const int64_t value_dim = 0); - void set_num_heads(const int64_t num_heads); - void set_key_dim(const int64_t key_dim); - void set_value_dim(const int64_t value_dim); - int64_t get_num_heads() const; - int64_t get_key_dim() const; - int64_t get_value_dim() const; + void Init() {} }; } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 8fd2c168f35..ffb47bb9d5f 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -28,6 +28,9 @@ namespace mindspore::ops { constexpr auto kAlpha = "alpha"; constexpr auto kActivation = "activation"; constexpr auto kActivationType = "activation_type"; +constexpr auto kAttentionQActType = "attention_q_act_type"; +constexpr auto kAttentionKActType = "attention_k_act_type"; +constexpr auto kAttentionVActType = "attention_v_act_type"; constexpr auto kAddress = "address"; constexpr auto kAlignCorners = "align_corners"; constexpr auto kAttr = "attr"; @@ -85,6 +88,7 @@ constexpr auto kGradX = "grad_x"; constexpr auto kGradY = "grad_y"; constexpr auto kGroup = "group"; constexpr auto kHasBias = "has_bias"; +constexpr auto kAttentionHasMask = "attention_has_mask"; constexpr auto kHiddenSize = "hidden_size"; constexpr auto kId = "id"; constexpr auto kImageSizeH = "image_size_h"; @@ -124,9 +128,10 @@ constexpr auto kNumElements = "num_elements"; constexpr auto kNumBits = "num_bits"; constexpr auto kNumDirections = "num_directions"; constexpr auto kNumProj = "num_proj"; -constexpr auto kNumHeads = "num_heads"; -constexpr auto kKeyDim = "key_dim"; -constexpr auto kValueDim = "value_dim"; +constexpr auto kAttentionNumHeads = "attention_num_heads"; +constexpr auto kAttentionSizePerHead = "attention_size_per_head"; +constexpr auto kAttentionFromSeqLen = "attention_from_seq_len"; +constexpr auto kAttentionToSeqLen = "attention_to_seq_len"; constexpr auto kOffset = "offset"; constexpr auto kNmsIouThreshold = "nms_iou_threshold"; constexpr auto kNmsScoreThreshold = "nms_score_threshold"; diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 662a6a37592..e0a79fe4b95 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -218,6 +218,7 @@ union PrimitiveType { TensorArrayRead, 
TensorArrayWrite, Affine, + Attention, } table Abs { @@ -1194,3 +1195,6 @@ table Affine { transpose_a: bool = false; transpose_b: bool = false; } + +table Attention { +} diff --git a/mindspore/lite/src/ops/ops_def.cc b/mindspore/lite/src/ops/ops_def.cc index 633aae02364..a64de36d47a 100644 --- a/mindspore/lite/src/ops/ops_def.cc +++ b/mindspore/lite/src/ops/ops_def.cc @@ -218,6 +218,7 @@ OP_TYPE(TensorArrayRead) OP_TYPE(TensorArrayWrite) // kaldi affine op OP_TYPE(Affine) +OP_TYPE(Attention) OP_TYPE_DEF_END(PrimitiveType) OP_SCHEMA_DEF(Abs) @@ -1194,3 +1195,6 @@ OP_ATTR_ENUM_WITH_VALUE(activation_type, ActivationType, 0) OP_ATTR_WITH_VALUE(transpose_a, bool, false) OP_ATTR_WITH_VALUE(transpose_b, bool, false) OP_SCHEMA_DEF_END(Affine) + +OP_SCHEMA_DEF(Attention) +OP_SCHEMA_DEF_END(Attention) diff --git a/mindspore/lite/src/ops/ops_func_declare.h b/mindspore/lite/src/ops/ops_func_declare.h index 620b2c3eced..c6b7b5fdc9f 100644 --- a/mindspore/lite/src/ops/ops_func_declare.h +++ b/mindspore/lite/src/ops/ops_func_declare.h @@ -32,6 +32,7 @@ #include "ops/assert.h" #include "ops/assign.h" #include "ops/assign_add.h" +#include "ops/attention.h" #include "ops/atan.h" #include "ops/audio_spectrogram.h" #include "ops/avg_pool.h" @@ -458,6 +459,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(TensorArray) FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayRead) FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayWrite) FUNC_MSOP2SCHEMAOP_DECLARE(Affine) +FUNC_MSOP2SCHEMAOP_DECLARE(Attention) #endif } // namespace mindspore::lite::ops #else diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index b8dad4cb96b..0eb61c703de 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -800,6 +800,11 @@ std::unique_ptr AffinePrimitiveCreator(const AnfNodePtr &nod return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +std::unique_ptr AttentionPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} + RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator); RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator); RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator); @@ -1023,6 +1028,7 @@ RegistryMSOps g_TensorArrayCreatorRegistry("TensorArray", TensorArrayPrimitiveCr RegistryMSOps g_TensorArrayReadCreatorRegistry("TensorArrayRead", TensorArrayReadPrimitiveCreator); RegistryMSOps g_TensorArrayWriteCreatorRegistry("TensorArrayWrite", TensorArrayWritePrimitiveCreator); RegistryMSOps g_AffineCreatorRegistry("Affine", AffinePrimitiveCreator); +RegistryMSOps g_AttentionCreatorRegistry("Attention", AttentionPrimitiveCreator); std::unique_ptr CustomPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); diff --git a/mindspore/lite/src/ops/populate/common_populate.cc b/mindspore/lite/src/ops/populate/common_populate.cc index 3f268b420b3..b689ea1784a 100644 --- a/mindspore/lite/src/ops/populate/common_populate.cc +++ b/mindspore/lite/src/ops/populate/common_populate.cc @@ -14,6 +14,7 @@ * limitations under the License. 
*/ #include "src/ops/populate/populate_register.h" +using mindspore::schema::PrimitiveType_Attention; using mindspore::schema::PrimitiveType_Depend; using mindspore::schema::PrimitiveType_ZerosLike; @@ -35,5 +36,6 @@ OpParameter *PopulateCommonParameter(const void *prim) { } REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR) REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR) +REG_POPULATE(PrimitiveType_Attention, PopulateCommonParameter, SCHEMA_CUR) } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 23a3e0cdf9a..1c1edc92648 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -174,7 +174,6 @@ if(MSLITE_ENABLE_CONVERTER) ${TEST_LITE_SRC} ${TEST_CASE_TFLITE_PARSERS_SRC} ${LITE_DIR}/tools/converter/ops/while.cc - ${LITE_DIR}/tools/converter/ops/attention.cc ${LITE_DIR}/tools/common/protobuf_utils.cc ${LITE_DIR}/tools/converter/optimizer.cc ${LITE_DIR}/tools/converter/anf_transform.cc @@ -188,6 +187,7 @@ if(MSLITE_ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc ${LITE_DIR}/tools/optimizer/common/gllo_utils.cc ${LITE_DIR}/tools/optimizer/common/format_utils.cc + ${LITE_DIR}/tools/optimizer/common/multiple_pattern_process_pass.cc ${LITE_DIR}/tools/optimizer/fusion/affine_activation_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/affine_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -196,6 +196,8 @@ if(MSLITE_ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/conv_transform_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/multi_head_attention_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/reshape_reshape_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/norm_fusion.cc @@ -208,7 +210,6 @@ if(MSLITE_ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc - ${LITE_DIR}/tools/optimizer/fusion/tf_multi_head_attention_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/glu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 91292567bb1..10e7756d80b 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -44,6 +44,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/common/pass_manager_extends.cc ../optimizer/common/gllo_utils.cc ../optimizer/common/format_utils.cc + ../optimizer/common/multiple_pattern_process_pass.cc ../optimizer/fusion/affine_activation_fusion.cc ../optimizer/fusion/affine_fusion.cc ../optimizer/fusion/conv_biasadd_fusion.cc @@ -62,7 +63,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/tflite_lstm_cell_fusion.cc ../optimizer/fusion/tf_lstm_cell_fusion.cc ../optimizer/fusion/tf_bidirection_gru_fusion.cc - ../optimizer/fusion/tf_multi_head_attention_fusion.cc + ../optimizer/fusion/multi_head_attention_fusion.cc + ../optimizer/fusion/reshape_reshape_fusion.cc 
../optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc ../optimizer/fusion/glu_fusion.cc ../optimizer/fusion/matmul_add_fusion.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 05d3b9132ac..53bb04e0a79 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -37,13 +37,14 @@ #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h" #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h" #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h" -#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h" +#include "tools/optimizer/fusion/multi_head_attention_fusion.h" #include "tools/optimizer/fusion/glu_fusion.h" #include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h" #include "tools/optimizer/fusion/matmul_add_fusion.h" #include "tools/optimizer/fusion/tf_gelu_fusion.h" #include "tools/optimizer/fusion/onnx_gelu_fusion.h" #include "tools/optimizer/fusion/squeeze_fusion.h" +#include "tools/optimizer/fusion/reshape_reshape_fusion.h" #include "tools/optimizer/graph/add_tensor_array.h" #include "tools/optimizer/graph/redundant_op_remove_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h" @@ -81,6 +82,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: if (!config->trainModel) { // remove quantdtype when awaretraining fusion_pm->AddPass(std::make_shared()); + fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); auto conv_bn_pass = std::make_shared(); conv_bn_pass->SetFmkType(config->fmk); @@ -100,7 +102,6 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter:: fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); - fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared()); fusion_pm->AddPass(std::make_shared(config->fmk)); diff --git a/mindspore/lite/tools/converter/ops/attention.cc b/mindspore/lite/tools/converter/ops/attention.cc deleted file mode 100644 index 8fa8ebdee08..00000000000 --- a/mindspore/lite/tools/converter/ops/attention.cc +++ /dev/null @@ -1,45 +0,0 @@ - -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -#include "tools/converter/ops/attention.h" -#include "ops/op_utils.h" - -namespace mindspore::ops { -void Attention::Init(int64_t number_heads, int64_t key_dim, int64_t value_dim) { - this->set_num_heads(number_heads); - this->set_key_dim(key_dim); - this->set_value_dim(value_dim); -} -void Attention::set_num_heads(const int64_t num_heads) { this->AddAttr(kNumHeads, MakeValue(num_heads)); } -void Attention::set_key_dim(const int64_t key_dim) { this->AddAttr(kKeyDim, MakeValue(key_dim)); } -void Attention::set_value_dim(const int64_t value_dim) { this->AddAttr(kValueDim, MakeValue(value_dim)); } -int64_t Attention::get_num_heads() const { - auto value_ptr = this->GetAttr(kNumHeads); - return GetValue(value_ptr); -} -int64_t Attention::get_key_dim() const { - auto value_ptr = this->GetAttr(kKeyDim); - return GetValue(value_ptr); -} -int64_t Attention::get_value_dim() const { - auto value_ptr = this->GetAttr(kValueDim); - return GetValue(value_ptr); -} - -REGISTER_PRIMITIVE_C(kNameAttention, Attention); -} // namespace mindspore::ops diff --git a/mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.cc b/mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.cc new file mode 100644 index 00000000000..80a76a7c9db --- /dev/null +++ b/mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "tools/optimizer/common/multiple_pattern_process_pass.h"
+
+namespace mindspore::opt {
+MultiplePatternProcessPass::MultiplePatternProcessPass(const std::string &name, bool multigraph)
+    : NodePass(name), multigraph_(multigraph), pattern_engine_(PatternEngine(std::make_shared<Visitor>())) {}
+
+AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
+  if (patterns_.empty()) {
+    VarPtr fg = std::make_shared<Var>("RootG");
+    auto patterns = DefinePatterns();
+    for (const auto &pattern : patterns) {
+      auto primitive_var = std::make_shared<PrimitiveVarMap>();
+      this->patterns_[pattern.first] = SexpToNode(pattern.second, fg, primitive_var.get(), multigraph_);
+      this->primitive_var_maps_[pattern.first] = primitive_var;
+    }
+  }
+
+  auto empty_equiv = std::make_shared<Equiv>();
+  MS_ASSERT(primitive_var_maps_.size() == patterns_.size());
+  for (const auto &iter : primitive_var_maps_) {
+    auto name = iter.first;
+    auto primitive_var = iter.second;
+    auto pattern = this->patterns_[name];
+    MS_EXCEPTION_IF_NULL(primitive_var);
+    MS_EXCEPTION_IF_NULL(pattern);
+    EquivPtr equiv = pattern_engine_.Match(pattern, node, *primitive_var, empty_equiv);
+    if (equiv != nullptr && !equiv->empty()) {
+      return Process(name, func_graph, node, equiv);
+    }
+  }
+  return nullptr;
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.h b/mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.h
new file mode 100644
index 00000000000..a8aecd0290c
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/common/multiple_pattern_process_pass.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
+#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include "backend/optimizer/common/node_pass.h"
+#include "backend/optimizer/common/pattern_engine.h"
+#include "backend/optimizer/common/helper.h"
+
+namespace mindspore {
+namespace opt {
+class MultiplePatternProcessPass : public NodePass {
+ public:
+  explicit MultiplePatternProcessPass(const std::string &name = "", bool multigraph = true);
+  ~MultiplePatternProcessPass() override = default;
+  virtual AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
+  virtual std::unordered_map<std::string, VectorRef> DefinePatterns() const = 0;
+  AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
+
+ private:
+  std::unordered_map<std::string, AnfNodePtr> patterns_;
+  std::unordered_map<std::string, std::shared_ptr<PrimitiveVarMap>> primitive_var_maps_;
+  bool multigraph_ = true;
+  PatternEngine pattern_engine_;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
diff --git a/mindspore/lite/tools/optimizer/fusion/tf_multi_head_attention_fusion.cc b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.cc
similarity index 51%
rename from mindspore/lite/tools/optimizer/fusion/tf_multi_head_attention_fusion.cc
rename to mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.cc
index a484e210b87..eb48e8c14c8 100644
--- a/mindspore/lite/tools/optimizer/fusion/tf_multi_head_attention_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.cc
@@ -13,8 +13,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
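The contract of the class above: DefinePatterns is called once to build the named patterns, after which Run tries each pattern against the visited node and forwards the first match to Process together with the matching pattern's name. A minimal hypothetical subclass (BuildPatternA/B and FuseA/B are placeholder helpers, not APIs from this patch):

    // Hypothetical two-pattern fusion pass sketch.
    class TwoPatternFusion : public MultiplePatternProcessPass {
     public:
      TwoPatternFusion() : MultiplePatternProcessPass("two_pattern_fusion", true) {}
      std::unordered_map<std::string, VectorRef> DefinePatterns() const override {
        // built once by Run(), then cached; the keys become pattern_name in Process
        return {{"PatternA", BuildPatternA()}, {"PatternB", BuildPatternB()}};
      }
      AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &graph, const AnfNodePtr &node,
                         const EquivPtr &equiv) const override {
        // Run() passes the name of whichever pattern matched; dispatch on it here
        return pattern_name == "PatternA" ? FuseA(graph, node, equiv) : FuseB(graph, node, equiv);
      }

     private:
      VectorRef BuildPatternA() const;  // placeholder pattern builders
      VectorRef BuildPatternB() const;
      AnfNodePtr FuseA(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const;  // placeholder rewrites
      AnfNodePtr FuseB(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const;
    };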
*/ -#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h" +#include "tools/optimizer/fusion/multi_head_attention_fusion.h" #include +#include #include "tools/optimizer/common/gllo_utils.h" namespace mindspore::opt { @@ -22,27 +23,108 @@ namespace { const auto &p1 = std::placeholders::_1; } // namespace -TfMultiHeadAttentionFusion::TfMultiHeadAttentionFusion(const string &name, bool multigraph) - : PatternProcessPass(name, multigraph) { +MultiHeadAttentionFusion::MultiHeadAttentionFusion(const string &name, bool multigraph) + : MultiplePatternProcessPass(name, multigraph) { input_q_ = std::make_shared(); input_k_ = std::make_shared(); input_v_ = std::make_shared(); - weight_q_ = std::make_shared(); - weight_k_ = std::make_shared(); - weight_v_ = std::make_shared(); - weight_o_ = std::make_shared(); + weight_q_ = std::make_shared(IsParamNode); + weight_k_ = std::make_shared(IsParamNode); + weight_v_ = std::make_shared(IsParamNode); + weight_o_ = std::make_shared(IsParamNode); - bias_q_ = std::make_shared(); - bias_k_ = std::make_shared(); - bias_v_ = std::make_shared(); - bias_o_ = std::make_shared(); + bias_q_ = std::make_shared(IsParamNode); + bias_k_ = std::make_shared(IsParamNode); + bias_v_ = std::make_shared(IsParamNode); + bias_o_ = std::make_shared(IsParamNode); + + mask_ = std::make_shared(); reshape_k_ = std::make_shared(); reshape_v_ = std::make_shared(); } -const BaseRef TfMultiHeadAttentionFusion::DefinePattern() const { +namespace { +VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) { + auto dense = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMul)), input, weight, bias}); + auto reshape = + VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), dense, std::make_shared()}); + return VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), reshape, + std::make_shared(IsParamNode)}); +} + +VectorRef DefineMask(const BaseRef &mask_input) { + auto expand_dims = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimExpandDims)), mask_input, + std::make_shared(IsParamNode)}); + auto sub = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimSubFusion)), + std::make_shared(IsParamNode), expand_dims}); + return VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMulFusion)), sub, + std::make_shared(IsParamNode)}); +} +} // namespace + +VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern() const { + auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_); + auto k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_); + auto v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_); + auto q2k = + VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMul)), q_embedding, k_embedding}); + auto q2k_normed = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMulFusion)), q2k, + std::make_shared(IsParamNode)}); + auto mask = DefineMask(mask_); + auto q2k_normed_masked = + VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimAddFusion)), q2k_normed, mask}); + auto softmax = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimSoftmax)), q2k_normed_masked}); + auto softmax2v = + VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMul)), softmax, v_embedding}); + auto softmax2v_transposed = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), + softmax2v, std::make_shared(IsParamNode)}); + auto softmax2v_transposed_reshaped = + 
VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), softmax2v_transposed, + std::make_shared()}); + return VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMul)), + softmax2v_transposed_reshaped, weight_o_, bias_o_}); +} + +namespace { +VectorRef DefineDensePattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) { + auto transpose = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), input, + std::make_shared(IsParamNode)}); + auto reshape1 = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose, + std::make_shared(IsParamNode)}); + auto matmul = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMul)), reshape1, weight}); + auto reshape2 = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), matmul, + std::make_shared(IsParamNode)}); + if (bias == nullptr) { + return reshape2; + } + auto bias_add = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimBiasAdd)), reshape2, bias}); + return bias_add; +} + +VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, + const BaseRef &reshape_shape, bool transpose = false) { + auto input_after_dense = DefineDensePattern(input, weight, bias); + auto result = VectorRef( + {std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), input_after_dense, reshape_shape}); + if (transpose) { + result = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), result, + std::make_shared(IsParamNode)}); + } + return result; +} + +VectorRef DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) { + auto transpose = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), input, + std::make_shared(IsParamNode)}); + auto reshape = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose, + std::make_shared(IsParamNode)}); + return DefineDensePattern(reshape, weight, bias); +} +} // namespace + +VectorRef MultiHeadAttentionFusion::DefineMPWithoutMaskPattern() const { auto query = DefineProcessInputPattern(input_q_, weight_q_, bias_q_, std::make_shared(IsParamNode)); auto query_div = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimDivFusion)), query, std::make_shared(IsParamNode)}); @@ -59,52 +141,27 @@ const BaseRef TfMultiHeadAttentionFusion::DefinePattern() const { return DefineProcessOutputPattern(softmax_mul_val, weight_o_, bias_o_); } -const AnfNodePtr TfMultiHeadAttentionFusion::Process(const mindspore::FuncGraphPtr &func_graph, - const mindspore::AnfNodePtr &node, - const mindspore::EquivPtr &equiv) const { - return CreateMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0); +std::unordered_map MultiHeadAttentionFusion::DefinePatterns() const { + std::unordered_map patterns; + patterns[kMPAWithoutMaskPatternName] = DefineMPWithoutMaskPattern(); + patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern(); + return patterns; } -const VectorRef TfMultiHeadAttentionFusion::DefineDensePattern(const BaseRef &input, const BaseRef &weight, - const BaseRef &bias) const { - auto transpose = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), input, - std::make_shared(IsParamNode)}); - auto reshape1 = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose, - std::make_shared(IsParamNode)}); - auto matmul = VectorRef({std::make_shared(std::bind(IsOpType, p1, 
prim::kPrimMatMul)), reshape1, weight}); - auto reshape2 = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), matmul, - std::make_shared(IsParamNode)}); - if (bias == nullptr) { - return reshape2; +AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph, + const mindspore::AnfNodePtr &node, + const mindspore::EquivPtr &equiv) const { + if (pattern_name == kMPAWithoutMaskPatternName) { + return CreateMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0); + } else if (pattern_name == kMPAWithMaskPatternName) { + return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0); + } else { + return nullptr; } - auto bias_add = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimBiasAdd)), reshape2, bias}); - return bias_add; } -const VectorRef TfMultiHeadAttentionFusion::DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, - const BaseRef &bias, const BaseRef &reshape_shape, - bool transpose) const { - auto input_after_dense = DefineDensePattern(input, weight, bias); - auto result = VectorRef( - {std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), input_after_dense, reshape_shape}); - if (transpose) { - result = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), result, - std::make_shared(IsParamNode)}); - } - return result; -} - -const VectorRef TfMultiHeadAttentionFusion::DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight, - const BaseRef &bias) const { - auto transpose = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), input, - std::make_shared(IsParamNode)}); - auto reshape = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose, - std::make_shared(IsParamNode)}); - return DefineDensePattern(reshape, weight, bias); -} - -CNodePtr TfMultiHeadAttentionFusion::CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, - const std::string &base_name, int var_offset) const { +CNodePtr MultiHeadAttentionFusion::CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const std::string &base_name, int var_offset) const { MS_ASSERT(func_graph != nullptr); MS_ASSERT(equiv != nullptr); auto attention_prim = BuildAttentionPrim(equiv); @@ -151,14 +208,14 @@ STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector *resu } auto ptr = reinterpret_cast(default_param_ptr->data_c()); int64_t shape_size = - std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies()); + std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>()); for (int64_t i = 0; i < shape_size; i++) { result->emplace_back(ptr[i]); } return RET_OK; } -std::shared_ptr TfMultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const { +std::shared_ptr MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const { auto attention_prim = std::make_shared(); if (attention_prim == nullptr) { MS_LOG(ERROR) << "Build attention primitive failed."; @@ -191,9 +248,39 @@ std::shared_ptr TfMultiHeadAttentionFusion::BuildAttentionPrim(c MS_LOG(ERROR) << "Shape k or shape v is invalid."; return nullptr; } - attention_prim->set_num_heads(shape_k.at(shape_k.size() - 2)); - attention_prim->set_key_dim(shape_k.at(shape_k.size() - 1)); - attention_prim->set_value_dim(shape_v.at(shape_v.size() - 1)); return 
attention_prim; } + +CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, + const EquivPtr &equiv, const string &base_name, + int var_offset) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(equiv != nullptr); + auto attention_prim = std::make_shared(); + if (attention_prim == nullptr) { + MS_LOG(ERROR) << "Build attention primitive failed."; + return nullptr; + } + auto value_node = NewValueNode(attention_prim); + auto input_q = utils::cast((*equiv)[input_q_]); + auto input_k = utils::cast((*equiv)[input_k_]); + auto input_v = utils::cast((*equiv)[input_v_]); + + auto weight_q = utils::cast((*equiv)[weight_q_]); + auto weight_k = utils::cast((*equiv)[weight_k_]); + auto weight_v = utils::cast((*equiv)[weight_v_]); + auto weight_o = utils::cast((*equiv)[weight_o_]); + + auto bias_q = utils::cast((*equiv)[bias_q_]); + auto bias_k = utils::cast((*equiv)[bias_k_]); + auto bias_v = utils::cast((*equiv)[bias_v_]); + auto bias_o = utils::cast((*equiv)[bias_o_]); + auto mask = utils::cast((*equiv)[mask_]); + + std::vector new_node_inputs = {value_node, input_q, input_k, input_v, weight_q, weight_k, weight_v, + weight_o, bias_q, bias_k, bias_v, bias_o, mask}; + auto new_node = func_graph->NewCNode(new_node_inputs); + new_node->set_fullname_with_scope(base_name); + return new_node; +} } // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/tf_multi_head_attention_fusion.h b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.h similarity index 58% rename from mindspore/lite/tools/optimizer/fusion/tf_multi_head_attention_fusion.h rename to mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.h index fae54edab2e..b1d9b1cf5b0 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_multi_head_attention_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.h @@ -19,31 +19,40 @@ #include #include #include -#include "backend/optimizer/common/optimizer.h" +#include +#include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "utils/utils.h" #include "include/errorcode.h" -#include "tools/converter/ops/attention.h" +#include "ops/attention.h" namespace mindspore { namespace opt { -class TfMultiHeadAttentionFusion : public PatternProcessPass { +class MultiHeadAttentionFusion : public MultiplePatternProcessPass { public: - explicit TfMultiHeadAttentionFusion(const std::string &name = "tflite_multi_head_attention_fusion", - bool multigraph = true); - ~TfMultiHeadAttentionFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + explicit MultiHeadAttentionFusion(const std::string &name = "multi_head_attention_fusion", bool multigraph = true); + + ~MultiHeadAttentionFusion() override = default; + + std::unordered_map DefinePatterns() const override; + + AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &, + const EquivPtr &) const override; protected: - const VectorRef DefineDensePattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) const; - virtual const VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, - const BaseRef &reshape_shape, bool transpose = false) const; - virtual const VectorRef DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight, - const BaseRef &bias) const; - + // define patterns + VectorRef 
DefineMPWithMaskPattern() const; + VectorRef DefineMPWithoutMaskPattern() const; + // create multi-head-attention without mask + virtual std::shared_ptr BuildAttentionPrim(const EquivPtr &equiv) const; CNodePtr CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const std::string &base_name, int var_offset) const; - virtual std::shared_ptr BuildAttentionPrim(const EquivPtr &equiv) const; + // create masked-multi-head-attention + CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, + const std::string &base_name, int var_offset) const; + + protected: + const std::string kMPAWithoutMaskPatternName = "MPAWithoutMaskPattern"; + const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern"; VarPtr input_q_; VarPtr input_k_; @@ -58,6 +67,8 @@ class TfMultiHeadAttentionFusion : public PatternProcessPass { VarPtr bias_v_; VarPtr bias_o_; + VarPtr mask_; + VarPtr reshape_k_; VarPtr reshape_v_; }; diff --git a/mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.cc b/mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.cc new file mode 100644 index 00000000000..562c99fea36 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.cc @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
+#include "ops/op_utils.h"
+#include "ops/reshape.h"
+#include "tools/optimizer/common/gllo_utils.h"
+
+namespace mindspore::opt {
+namespace {
+const auto &p1 = std::placeholders::_1;
+}  // namespace
+
+const BaseRef ReshapeReshapeFusion::DefinePattern() const {
+  auto reshape1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), reshape_input_,
+                             std::make_shared<CondVar>(IsParamNode)});
+  return VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), reshape1, reshape_shape_});
+}
+
+const AnfNodePtr ReshapeReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                               const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  auto reshape_prim = std::make_shared<ops::Reshape>();
+  if (reshape_prim == nullptr) {
+    MS_LOG(ERROR) << "Build reshape primitive failed.";
+    return nullptr;
+  }
+  auto value_node = NewValueNode(reshape_prim);
+  auto input = utils::cast<AnfNodePtr>((*equiv)[reshape_input_]);
+  auto shape = utils::cast<AnfNodePtr>((*equiv)[reshape_shape_]);
+  if (input == nullptr || shape == nullptr) {
+    MS_LOG(ERROR) << "Cannot find reshape input and shape.";
+    return nullptr;
+  }
+  // create the single fused reshape op
+  auto new_reshape = func_graph->NewCNode({value_node, input, shape});
+  if (new_reshape == nullptr) {
+    MS_LOG(ERROR) << "Create new reshape cnode failed.";
+    return nullptr;
+  }
+  return new_reshape;
+}
+}  // namespace mindspore::opt
diff --git a/mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.h b/mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.h
new file mode 100644
index 00000000000..cd7d0de2bd1
--- /dev/null
+++ b/mindspore/lite/tools/optimizer/fusion/reshape_reshape_fusion.h
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
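The rewrite is sound because reshape(reshape(x, s1), s2) equals reshape(x, s2) whenever the element counts agree, which any well-formed graph already guarantees; only the outer target shape survives. A self-contained illustration of that invariant (the shapes are made up):

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Element count of a shape; reshape preserves it, which is exactly why the
    // inner reshape can be dropped and only the final target shape kept.
    int64_t NumElements(const std::vector<int64_t> &shape) {
      return std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<>());
    }

    int main() {
      const std::vector<int64_t> x = {8, 16, 32};  // hypothetical input shape
      const std::vector<int64_t> s1 = {128, 32};   // inner reshape target
      const std::vector<int64_t> s2 = {8, 512};    // outer reshape target
      assert(NumElements(x) == NumElements(s1) && NumElements(s1) == NumElements(s2));
      return 0;  // hence reshape(reshape(x, s1), s2) == reshape(x, s2)
    }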
+ */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace opt { +class ReshapeReshapeFusion : public PatternProcessPass { + public: + explicit ReshapeReshapeFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion") + : PatternProcessPass(name, multigraph) {} + ~ReshapeReshapeFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr reshape_input_ = std::make_shared(); + VarPtr reshape_shape_ = std::make_shared(); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc index 26c25d82b06..139d1110e95 100644 --- a/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc @@ -15,6 +15,7 @@ */ #include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h" #include +#include #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quant_param_holder.h" #include "mindspore/core/ops/transpose.h" @@ -22,10 +23,11 @@ namespace mindspore::opt { namespace { const auto &p1 = std::placeholders::_1; + } // namespace TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const string &name, bool multigraph) - : TfMultiHeadAttentionFusion(name, multigraph) { + : MultiHeadAttentionFusion(name, multigraph) { query_u_ = std::make_shared(); query_v_ = std::make_shared(); input_p_ = std::make_shared(); @@ -44,7 +46,7 @@ TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const } } -const BaseRef TfliteRelPosMultiHeadAttentionFusion::DefinePattern() const { +std::unordered_map TfliteRelPosMultiHeadAttentionFusion::DefinePatterns() const { auto query = DefineProcessInputPattern(input_q_, weight_q_, bias_q_, query_stack_params_, query_prim_); auto query_with_bias_u = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimAddFusion)), query, query_u_}); @@ -77,11 +79,15 @@ const BaseRef TfliteRelPosMultiHeadAttentionFusion::DefinePattern() const { auto value = DefineProcessInputPattern(input_v_, weight_v_, bias_v_, value_stack_params_, value_prim_, true); auto output = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMul)), logits_softmax, value}); - return DefineProcessOutputPattern(output, weight_o_, bias_o_); + auto pattern = DefineProcessOutputPattern(output, weight_o_, bias_o_); + std::unordered_map patterns; + patterns.insert(std::make_pair(kRPMHAttentionPatternName, pattern)); + return patterns; } -const AnfNodePtr TfliteRelPosMultiHeadAttentionFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { +AnfNodePtr TfliteRelPosMultiHeadAttentionFusion::Process(const std::string &pattern_name, + const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { return CreateRelPosMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope()); } @@ -139,25 +145,6 @@ std::shared_ptr 
TfliteRelPosMultiHeadAttentionFusion::BuildAtten } shape_k.emplace_back(dim); } - - std::vector shape_v; - for (auto &value_stack_param : value_stack_params_) { - auto reshape_v = utils::cast((*equiv)[value_stack_param]); - int dim; - if (RET_OK != GetIntParameterData(reshape_v, &dim)) { - MS_LOG(ERROR) << "Get reshape k data failed"; - return nullptr; - } - shape_v.emplace_back(dim); - } - - if (shape_k.size() < 2 || shape_v.size() < 2 || shape_k.at(shape_k.size() - 2) != shape_v.at(shape_v.size() - 2)) { - MS_LOG(ERROR) << "Shape k or shape v is invalid."; - return nullptr; - } - attention_prim->set_num_heads(shape_k.at(shape_k.size() - 2)); - attention_prim->set_key_dim(shape_k.at(shape_k.size() - 1)); - attention_prim->set_value_dim(shape_v.at(shape_v.size() - 1)); return attention_prim; } diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h b/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h index 7ce78509209..e9d9cbb296c 100644 --- a/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h @@ -19,22 +19,24 @@ #include #include #include +#include #include "backend/optimizer/common/optimizer.h" #include "utils/utils.h" #include "include/errorcode.h" -#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h" +#include "tools/optimizer/fusion/multi_head_attention_fusion.h" namespace mindspore::opt { -class TfliteRelPosMultiHeadAttentionFusion : public TfMultiHeadAttentionFusion { +class TfliteRelPosMultiHeadAttentionFusion : public MultiHeadAttentionFusion { public: explicit TfliteRelPosMultiHeadAttentionFusion(const std::string &name = "tflite_rel_pos_multi_head_attention_fusion", bool multigraph = true); ~TfliteRelPosMultiHeadAttentionFusion() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + std::unordered_map DefinePatterns() const override; + AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &, + const EquivPtr &) const override; protected: - std::shared_ptr BuildAttentionPrim(const EquivPtr &equiv) const; + std::shared_ptr BuildAttentionPrim(const EquivPtr &equiv) const override; const VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const std::vector &stack_params, const VarPtr &full_connect_prim, @@ -45,6 +47,8 @@ class TfliteRelPosMultiHeadAttentionFusion : public TfMultiHeadAttentionFusion { const std::string &base_name) const; const VectorRef DefineRelativeShiftPattern(const BaseRef &input) const; + private: + const std::string kRPMHAttentionPatternName = "RPMHAttentionPattern"; VarPtr query_u_; VarPtr query_v_; VarPtr query_prim_;
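For context on the masked multi-head pattern added earlier in this patch: DefineMask corresponds to the usual additive attention mask, where a 0/1 mask is expanded, inverted by subtraction from a constant, and scaled by a large negative constant before being added to the query-key logits. In scalar form (the two constants are free parameters in the pattern; 1.0f and -10000.0f below are only the conventional choices, not values read from the graph):

    // Scalar form of the matched ExpandDims -> Sub -> Mul mask subgraph.
    float AdditiveMask(float mask_bit) {     // 1.0f = attend, 0.0f = masked out
      return (1.0f - mask_bit) * -10000.0f;  // added to q-k logits; softmax then
    }                                        // drives masked positions to ~0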