modify multihead fusion
This commit is contained in:
parent
5b98330f8d
commit
3086f9a494
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/infer/attention_infer.h"
|
||||
#include "nnacl/infer/infer_register.h"
|
||||
|
||||
int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter) {
|
||||
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, 7, 1);
|
||||
if (check_ret != NNACL_OK) {
|
||||
return check_ret;
|
||||
}
|
||||
const TensorC *q_input = inputs[0];
|
||||
TensorC *output = outputs[0];
|
||||
SetDataTypeFormat(output, q_input);
|
||||
if (!InferFlag(inputs, inputs_size)) {
|
||||
return NNACL_INFER_INVALID;
|
||||
}
|
||||
const TensorC *q_weight = inputs[3];
|
||||
if (q_input->shape_size_ != 2 && q_input->shape_size_ != 3) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
if (q_weight->shape_size_ != 2) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0];
|
||||
int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1];
|
||||
int d_model = q_weight->shape_[1];
|
||||
|
||||
output->shape_[0] = batch;
|
||||
output->shape_[1] = f_seq;
|
||||
output->shape_[2] = d_model;
|
||||
output->shape_size_ = 3;
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
REG_INFER(Attention, PrimType_Attention, AttentionInferShape)
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_NNACL_ATTENTION_INFER_H
|
||||
#define MINDSPORE_NNACL_ATTENTION_INFER_H
|
||||
|
||||
#include "nnacl/infer/common_infer.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/**
 * Shape inference entry point for the Attention operator.
 *
 * @param inputs        array of input tensors; the implementation reads inputs[0] and inputs[3]
 * @param inputs_size   number of entries in inputs
 * @param outputs       array of output tensors; outputs[0] receives the inferred 3-D shape
 * @param outputs_size  number of entries in outputs
 * @param parameter     operator parameter, forwarded to the argument validator
 * @return NNACL_OK on success, NNACL_INFER_INVALID when shapes cannot be inferred yet,
 *         NNACL_ERR on an unsupported rank, or the validator's own error status
 */
int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
|
||||
OpParameter *parameter);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif // MINDSPORE_NNACL_ATTENTION_INFER_H
|
|
@ -225,8 +225,9 @@ enum PrimType {
|
|||
PrimType_TensorArrayRead = 198,
|
||||
PrimType_TensorArrayWrite = 199,
|
||||
PrimType_Affine = 200,
|
||||
PrimType_Attention = 201,
|
||||
PrimType_MIN = PrimType_NONE,
|
||||
PrimType_MAX = PrimType_Affine + 1
|
||||
PrimType_MAX = PrimType_Attention + 1
|
||||
};
|
||||
|
||||
void RegInfer(int prim_type, InferShape func);
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/attention.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
|
||||
} // namespace mindspore::ops
|
|
@ -19,26 +19,24 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameAttention = "Attention";
|
||||
// Attention MultiHeadAttention
|
||||
class Attention : public PrimitiveC {
|
||||
public:
|
||||
Attention() : PrimitiveC(kNameAttention) {
|
||||
InitIOName({"query", "key", "value", "w_q", "b_q", "w_k", "b_k", "w_v", "b_v", "w_o", "b_o"}, {"output"});
|
||||
InitIOName(
|
||||
{"q", "k", "v", "weight_q", "weight_k", "weight_v", "weight_o", "bias_q", "bias_k", "bias_v", "bias_o", "mask"},
|
||||
{"output"});
|
||||
}
|
||||
~Attention() = default;
|
||||
~Attention() override = default;
|
||||
MS_DECLARE_PARENT(Attention, PrimitiveC);
|
||||
void Init(const int64_t number_heads = 0, const int64_t key_dim = 0, const int64_t value_dim = 0);
|
||||
void set_num_heads(const int64_t num_heads);
|
||||
void set_key_dim(const int64_t key_dim);
|
||||
void set_value_dim(const int64_t value_dim);
|
||||
int64_t get_num_heads() const;
|
||||
int64_t get_key_dim() const;
|
||||
int64_t get_value_dim() const;
|
||||
void Init() {}
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -28,6 +28,9 @@ namespace mindspore::ops {
|
|||
constexpr auto kAlpha = "alpha";
|
||||
constexpr auto kActivation = "activation";
|
||||
constexpr auto kActivationType = "activation_type";
|
||||
constexpr auto kAttentionQActType = "attention_q_act_type";
|
||||
constexpr auto kAttentionKActType = "attention_k_act_type";
|
||||
constexpr auto kAttentionVActType = "attention_v_act_type";
|
||||
constexpr auto kAddress = "address";
|
||||
constexpr auto kAlignCorners = "align_corners";
|
||||
constexpr auto kAttr = "attr";
|
||||
|
@ -85,6 +88,7 @@ constexpr auto kGradX = "grad_x";
|
|||
constexpr auto kGradY = "grad_y";
|
||||
constexpr auto kGroup = "group";
|
||||
constexpr auto kHasBias = "has_bias";
|
||||
constexpr auto kAttentionHasMask = "attention_has_mask";
|
||||
constexpr auto kHiddenSize = "hidden_size";
|
||||
constexpr auto kId = "id";
|
||||
constexpr auto kImageSizeH = "image_size_h";
|
||||
|
@ -124,9 +128,10 @@ constexpr auto kNumElements = "num_elements";
|
|||
constexpr auto kNumBits = "num_bits";
|
||||
constexpr auto kNumDirections = "num_directions";
|
||||
constexpr auto kNumProj = "num_proj";
|
||||
constexpr auto kNumHeads = "num_heads";
|
||||
constexpr auto kKeyDim = "key_dim";
|
||||
constexpr auto kValueDim = "value_dim";
|
||||
constexpr auto kAttentionNumHeads = "attention_num_heads";
|
||||
constexpr auto kAttentionSizePerHead = "attention_size_per_head";
|
||||
constexpr auto kAttentionFromSeqLen = "attention_from_seq_len";
|
||||
constexpr auto kAttentionToSeqLen = "attention_to_seq_len";
|
||||
constexpr auto kOffset = "offset";
|
||||
constexpr auto kNmsIouThreshold = "nms_iou_threshold";
|
||||
constexpr auto kNmsScoreThreshold = "nms_score_threshold";
|
||||
|
|
|
@ -218,6 +218,7 @@ union PrimitiveType {
|
|||
TensorArrayRead,
|
||||
TensorArrayWrite,
|
||||
Affine,
|
||||
Attention,
|
||||
}
|
||||
|
||||
table Abs {
|
||||
|
@ -1194,3 +1195,6 @@ table Affine {
|
|||
transpose_a: bool = false;
|
||||
transpose_b: bool = false;
|
||||
}
|
||||
|
||||
table Attention {
|
||||
}
|
||||
|
|
|
@ -218,6 +218,7 @@ OP_TYPE(TensorArrayRead)
|
|||
OP_TYPE(TensorArrayWrite)
|
||||
// kaldi affine op
|
||||
OP_TYPE(Affine)
|
||||
OP_TYPE(Attention)
|
||||
OP_TYPE_DEF_END(PrimitiveType)
|
||||
|
||||
OP_SCHEMA_DEF(Abs)
|
||||
|
@ -1194,3 +1195,6 @@ OP_ATTR_ENUM_WITH_VALUE(activation_type, ActivationType, 0)
|
|||
OP_ATTR_WITH_VALUE(transpose_a, bool, false)
|
||||
OP_ATTR_WITH_VALUE(transpose_b, bool, false)
|
||||
OP_SCHEMA_DEF_END(Affine)
|
||||
|
||||
OP_SCHEMA_DEF(Attention)
|
||||
OP_SCHEMA_DEF_END(Attention)
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "ops/assert.h"
|
||||
#include "ops/assign.h"
|
||||
#include "ops/assign_add.h"
|
||||
#include "ops/attention.h"
|
||||
#include "ops/atan.h"
|
||||
#include "ops/audio_spectrogram.h"
|
||||
#include "ops/avg_pool.h"
|
||||
|
@ -458,6 +459,7 @@ FUNC_MSOP2SCHEMAOP_DECLARE(TensorArray)
|
|||
FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayRead)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(TensorArrayWrite)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Affine)
|
||||
FUNC_MSOP2SCHEMAOP_DECLARE(Attention)
|
||||
#endif
|
||||
} // namespace mindspore::lite::ops
|
||||
#else
|
||||
|
|
|
@ -800,6 +800,11 @@ std::unique_ptr<schema::PrimitiveT> AffinePrimitiveCreator(const AnfNodePtr &nod
|
|||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> AttentionPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Attention>>(node);
|
||||
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
|
||||
}
|
||||
|
||||
RegistryMSOps g_absPrimitiveCreatorRegistry("Abs", AbsPrimitiveCreator);
|
||||
RegistryMSOps g_absGradPrimitiveCreatorRegistry("AbsGrad", AbsGradPrimitiveCreator);
|
||||
RegistryMSOps g_activationPrimitiveCreatorRegistry("Activation", ActivationPrimitiveCreator);
|
||||
|
@ -1023,6 +1028,7 @@ RegistryMSOps g_TensorArrayCreatorRegistry("TensorArray", TensorArrayPrimitiveCr
|
|||
RegistryMSOps g_TensorArrayReadCreatorRegistry("TensorArrayRead", TensorArrayReadPrimitiveCreator);
|
||||
RegistryMSOps g_TensorArrayWriteCreatorRegistry("TensorArrayWrite", TensorArrayWritePrimitiveCreator);
|
||||
RegistryMSOps g_AffineCreatorRegistry("Affine", AffinePrimitiveCreator);
|
||||
RegistryMSOps g_AttentionCreatorRegistry("Attention", AttentionPrimitiveCreator);
|
||||
|
||||
std::unique_ptr<schema::PrimitiveT> CustomPrimitiveCreator(const AnfNodePtr &node) {
|
||||
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Custom>>(node);
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
using mindspore::schema::PrimitiveType_Attention;
|
||||
using mindspore::schema::PrimitiveType_Depend;
|
||||
using mindspore::schema::PrimitiveType_ZerosLike;
|
||||
|
||||
|
@ -35,5 +36,6 @@ OpParameter *PopulateCommonParameter(const void *prim) {
|
|||
}
|
||||
REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR)
|
||||
REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR)
|
||||
REG_POPULATE(PrimitiveType_Attention, PopulateCommonParameter, SCHEMA_CUR)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -174,7 +174,6 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${TEST_LITE_SRC}
|
||||
${TEST_CASE_TFLITE_PARSERS_SRC}
|
||||
${LITE_DIR}/tools/converter/ops/while.cc
|
||||
${LITE_DIR}/tools/converter/ops/attention.cc
|
||||
${LITE_DIR}/tools/common/protobuf_utils.cc
|
||||
${LITE_DIR}/tools/converter/optimizer.cc
|
||||
${LITE_DIR}/tools/converter/anf_transform.cc
|
||||
|
@ -188,6 +187,7 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/common/pass_manager_extends.cc
|
||||
${LITE_DIR}/tools/optimizer/common/gllo_utils.cc
|
||||
${LITE_DIR}/tools/optimizer/common/format_utils.cc
|
||||
${LITE_DIR}/tools/optimizer/common/multiple_pattern_process_pass.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/affine_activation_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/affine_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/conv_biasadd_fusion.cc
|
||||
|
@ -196,6 +196,8 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/fusion/conv_transform_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/multi_head_attention_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/reshape_reshape_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/norm_fusion.cc
|
||||
|
@ -208,7 +210,6 @@ if(MSLITE_ENABLE_CONVERTER)
|
|||
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tf_multi_head_attention_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/glu_fusion.cc
|
||||
${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
|
||||
|
|
|
@ -44,6 +44,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/common/pass_manager_extends.cc
|
||||
../optimizer/common/gllo_utils.cc
|
||||
../optimizer/common/format_utils.cc
|
||||
../optimizer/common/multiple_pattern_process_pass.cc
|
||||
../optimizer/fusion/affine_activation_fusion.cc
|
||||
../optimizer/fusion/affine_fusion.cc
|
||||
../optimizer/fusion/conv_biasadd_fusion.cc
|
||||
|
@ -62,7 +63,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
../optimizer/fusion/tflite_lstm_cell_fusion.cc
|
||||
../optimizer/fusion/tf_lstm_cell_fusion.cc
|
||||
../optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||
../optimizer/fusion/tf_multi_head_attention_fusion.cc
|
||||
../optimizer/fusion/multi_head_attention_fusion.cc
|
||||
../optimizer/fusion/reshape_reshape_fusion.cc
|
||||
../optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.cc
|
||||
../optimizer/fusion/glu_fusion.cc
|
||||
../optimizer/fusion/matmul_add_fusion.cc
|
||||
|
|
|
@ -37,13 +37,14 @@
|
|||
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h"
|
||||
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
|
||||
#include "tools/optimizer/fusion/glu_fusion.h"
|
||||
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
|
||||
#include "tools/optimizer/fusion/matmul_add_fusion.h"
|
||||
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
|
||||
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
|
||||
#include "tools/optimizer/fusion/squeeze_fusion.h"
|
||||
#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
|
||||
#include "tools/optimizer/graph/add_tensor_array.h"
|
||||
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
|
||||
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
|
||||
|
@ -81,6 +82,7 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
if (!config->trainModel) {
|
||||
// remove quantdtype when awaretraining
|
||||
fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||
auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
|
||||
conv_bn_pass->SetFmkType(config->fmk);
|
||||
|
@ -100,7 +102,6 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::
|
|||
fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfMultiHeadAttentionFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
|
||||
#include "tools/converter/ops/attention.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore::ops {
|
||||
void Attention::Init(int64_t number_heads, int64_t key_dim, int64_t value_dim) {
|
||||
this->set_num_heads(number_heads);
|
||||
this->set_key_dim(key_dim);
|
||||
this->set_value_dim(value_dim);
|
||||
}
|
||||
void Attention::set_num_heads(const int64_t num_heads) { this->AddAttr(kNumHeads, MakeValue(num_heads)); }
|
||||
void Attention::set_key_dim(const int64_t key_dim) { this->AddAttr(kKeyDim, MakeValue(key_dim)); }
|
||||
void Attention::set_value_dim(const int64_t value_dim) { this->AddAttr(kValueDim, MakeValue(value_dim)); }
|
||||
int64_t Attention::get_num_heads() const {
|
||||
auto value_ptr = this->GetAttr(kNumHeads);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
int64_t Attention::get_key_dim() const {
|
||||
auto value_ptr = this->GetAttr(kKeyDim);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
int64_t Attention::get_value_dim() const {
|
||||
auto value_ptr = this->GetAttr(kValueDim);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
|
||||
} // namespace mindspore::ops
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
// Stores the pass name and multigraph flag and creates the pattern engine with a default Visitor.
// The pattern tables start empty; Run() fills them on its first call.
MultiplePatternProcessPass::MultiplePatternProcessPass(const std::string &name, bool multigraph)
    : NodePass(name), multigraph_(multigraph), pattern_engine_(std::make_shared<Visitor>()) {}
|
||||
|
||||
// Matches `node` against every pattern returned by DefinePatterns() and hands the first
// successful match to Process().
//
// On the first call the patterns are converted with SexpToNode and cached in patterns_ together
// with one PrimitiveVarMap per pattern; later calls reuse the cache.
//
// @param func_graph graph that owns `node`; forwarded unchanged to Process()
// @param node       candidate root node to match
// @return whatever Process() returns for the first pattern whose match is non-empty, or nullptr
//         when no pattern matches. Patterns are tried in unordered_map iteration order.
AnfNodePtr MultiplePatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  if (patterns_.empty()) {
    VarPtr fg = std::make_shared<Var>("RootG");
    // Bind the returned map directly: wrapping a returned temporary in std::move only
    // prevents copy elision.
    const auto patterns = DefinePatterns();
    for (const auto &[pattern_name, sexp] : patterns) {
      auto primitive_var = std::make_shared<PrimitiveVarMap>();
      this->patterns_[pattern_name] = SexpToNode(sexp, fg, primitive_var.get(), multigraph_);
      this->primitive_var_maps_[pattern_name] = primitive_var;
    }
  }

  MS_ASSERT(primitive_var_maps_.size() == patterns_.size());
  // Iterate by reference so no string / shared_ptr copy is made per pattern.
  for (const auto &[pattern_name, primitive_var] : primitive_var_maps_) {
    const auto &pattern = this->patterns_[pattern_name];
    MS_EXCEPTION_IF_NULL(primitive_var);
    MS_EXCEPTION_IF_NULL(pattern);
    // Each attempt gets its own Equiv: the engine receives it through a shared pointer, so
    // anything a failed attempt may have written into it must not be visible to the next
    // pattern or to Process().
    EquivPtr equiv = pattern_engine_.Match(pattern, node, *primitive_var, std::make_shared<Equiv>());
    if (equiv != nullptr && !equiv->empty()) {
      return Process(pattern_name, func_graph, node, equiv);
    }
  }
  return nullptr;
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include "backend/optimizer/common/node_pass.h"
|
||||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// NodePass that tries several named patterns against a node and dispatches the first match
// to Process(), passing along the name of the pattern that matched.
class MultiplePatternProcessPass : public NodePass {
|
||||
public:
|
||||
// `name` is forwarded to NodePass; `multigraph` is forwarded to SexpToNode when the
// patterns are built.
explicit MultiplePatternProcessPass(const std::string &name = "", bool multigraph = true);
|
||||
~MultiplePatternProcessPass() override = default;
|
||||
// Called by Run() with the matched pattern's name, the graph, the matched node and the
// match result; Run() returns this function's result unchanged.
virtual AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const = 0;
|
||||
// Returns the patterns to try, keyed by name. Run() calls this only while its pattern
// cache is empty.
virtual std::unordered_map<std::string, VectorRef> DefinePatterns() const = 0;
|
||||
// Tries each cached pattern against `node`; returns nullptr when none matches.
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
// Converted form of each pattern from DefinePatterns(), keyed by pattern name.
std::unordered_map<std::string, BaseRef> patterns_;
|
||||
// PrimitiveVarMap filled while converting the pattern with the same key.
std::unordered_map<std::string, PrimitiveVarMapPtr> primitive_var_maps_;
|
||||
// Forwarded to SexpToNode when patterns are converted.
bool multigraph_ = true;
|
||||
// Engine used by Run() to match a pattern against a node.
PatternEngine pattern_engine_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_COMMON_MULTIPLE_PATTERN_PROCESS_PASS_H_
|
|
@ -13,8 +13,9 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h"
|
||||
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
@ -22,27 +23,108 @@ namespace {
|
|||
const auto &p1 = std::placeholders::_1;
|
||||
} // namespace
|
||||
|
||||
TfMultiHeadAttentionFusion::TfMultiHeadAttentionFusion(const string &name, bool multigraph)
|
||||
: PatternProcessPass(name, multigraph) {
|
||||
MultiHeadAttentionFusion::MultiHeadAttentionFusion(const string &name, bool multigraph)
|
||||
: MultiplePatternProcessPass(name, multigraph) {
|
||||
input_q_ = std::make_shared<Var>();
|
||||
input_k_ = std::make_shared<Var>();
|
||||
input_v_ = std::make_shared<Var>();
|
||||
|
||||
weight_q_ = std::make_shared<Var>();
|
||||
weight_k_ = std::make_shared<Var>();
|
||||
weight_v_ = std::make_shared<Var>();
|
||||
weight_o_ = std::make_shared<Var>();
|
||||
weight_q_ = std::make_shared<CondVar>(IsParamNode);
|
||||
weight_k_ = std::make_shared<CondVar>(IsParamNode);
|
||||
weight_v_ = std::make_shared<CondVar>(IsParamNode);
|
||||
weight_o_ = std::make_shared<CondVar>(IsParamNode);
|
||||
|
||||
bias_q_ = std::make_shared<Var>();
|
||||
bias_k_ = std::make_shared<Var>();
|
||||
bias_v_ = std::make_shared<Var>();
|
||||
bias_o_ = std::make_shared<Var>();
|
||||
bias_q_ = std::make_shared<CondVar>(IsParamNode);
|
||||
bias_k_ = std::make_shared<CondVar>(IsParamNode);
|
||||
bias_v_ = std::make_shared<CondVar>(IsParamNode);
|
||||
bias_o_ = std::make_shared<CondVar>(IsParamNode);
|
||||
|
||||
mask_ = std::make_shared<Var>();
|
||||
|
||||
reshape_k_ = std::make_shared<Var>();
|
||||
reshape_v_ = std::make_shared<Var>();
|
||||
}
|
||||
|
||||
const BaseRef TfMultiHeadAttentionFusion::DefinePattern() const {
|
||||
namespace {
|
||||
// Builds the pattern Transpose(Reshape(MatMul(input, weight, bias), <any>), <param>).
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) {
  const auto op = [](const auto &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); };

  auto projected = VectorRef({op(prim::kPrimMatMul), input, weight, bias});
  auto reshaped = VectorRef({op(prim::kPrimReshape), projected, std::make_shared<Var>()});
  auto transposed = VectorRef({op(prim::kPrimTranspose), reshaped, std::make_shared<CondVar>(IsParamNode)});
  return transposed;
}
|
||||
|
||||
// Builds the pattern MulFusion(SubFusion(<param>, ExpandDims(mask_input, <param>)), <param>).
VectorRef DefineMask(const BaseRef &mask_input) {
  const auto op = [](const auto &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); };
  const auto param = [] { return std::make_shared<CondVar>(IsParamNode); };

  auto expanded = VectorRef({op(prim::kPrimExpandDims), mask_input, param()});
  auto subtracted = VectorRef({op(prim::kPrimSubFusion), param(), expanded});
  auto scaled = VectorRef({op(prim::kPrimMulFusion), subtracted, param()});
  return scaled;
}
|
||||
} // namespace
|
||||
|
||||
// Builds the masked multi-head attention pattern:
//   MatMul(
//     Reshape(Transpose(MatMul(Softmax(AddFusion(MulFusion(MatMul(Q, K), <param>), Mask)), V), <param>), <any>),
//     weight_o_, bias_o_)
// where Q, K and V are DefineEmbedding(...) of the q/k/v inputs with their weights and biases,
// and Mask is DefineMask(mask_).
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern() const {
  const auto op = [](const auto &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); };
  const auto param = [] { return std::make_shared<CondVar>(IsParamNode); };

  auto query = DefineEmbedding(input_q_, weight_q_, bias_q_);
  auto key = DefineEmbedding(input_k_, weight_k_, bias_k_);
  auto value = DefineEmbedding(input_v_, weight_v_, bias_v_);

  auto scores = VectorRef({op(prim::kPrimMatMul), query, key});
  auto scaled_scores = VectorRef({op(prim::kPrimMulFusion), scores, param()});
  auto mask_term = DefineMask(mask_);
  auto masked_scores = VectorRef({op(prim::kPrimAddFusion), scaled_scores, mask_term});
  auto weights = VectorRef({op(prim::kPrimSoftmax), masked_scores});

  auto weighted = VectorRef({op(prim::kPrimMatMul), weights, value});
  auto weighted_transposed = VectorRef({op(prim::kPrimTranspose), weighted, param()});
  auto weighted_reshaped = VectorRef({op(prim::kPrimReshape), weighted_transposed, std::make_shared<Var>()});

  return VectorRef({op(prim::kPrimMatMul), weighted_reshaped, weight_o_, bias_o_});
}
|
||||
|
||||
namespace {
|
||||
// Builds the pattern Reshape(MatMul(Reshape(Transpose(input, <param>), <param>), weight), <param>),
// wrapped in BiasAdd(..., bias) unless `bias` compares equal to nullptr.
VectorRef DefineDensePattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) {
  const auto op = [](const auto &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); };
  const auto param = [] { return std::make_shared<CondVar>(IsParamNode); };

  auto transposed = VectorRef({op(prim::kPrimTranspose), input, param()});
  auto flattened = VectorRef({op(prim::kPrimReshape), transposed, param()});
  auto product = VectorRef({op(prim::kPrimMatMul), flattened, weight});
  auto restored = VectorRef({op(prim::kPrimReshape), product, param()});

  // Without a bias the pattern ends at the second Reshape.
  if (bias == nullptr) {
    return restored;
  }
  return VectorRef({op(prim::kPrimBiasAdd), restored, bias});
}
|
||||
|
||||
// Builds Reshape(DefineDensePattern(input, weight, bias), reshape_shape); when `transpose` is true
// the result is additionally wrapped in Transpose(..., <param>).
VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
                                    const BaseRef &reshape_shape, bool transpose = false) {
  const auto op = [](const auto &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); };

  auto dense = DefineDensePattern(input, weight, bias);
  auto reshaped = VectorRef({op(prim::kPrimReshape), dense, reshape_shape});
  if (!transpose) {
    return reshaped;
  }
  return VectorRef({op(prim::kPrimTranspose), reshaped, std::make_shared<CondVar>(IsParamNode)});
}
|
||||
|
||||
// Builds DefineDensePattern(Reshape(Transpose(input, <param>), <param>), weight, bias).
VectorRef DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) {
  const auto op = [](const auto &prim) { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim)); };
  const auto param = [] { return std::make_shared<CondVar>(IsParamNode); };

  auto transposed = VectorRef({op(prim::kPrimTranspose), input, param()});
  auto merged = VectorRef({op(prim::kPrimReshape), transposed, param()});
  return DefineDensePattern(merged, weight, bias);
}
|
||||
} // namespace
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithoutMaskPattern() const {
|
||||
auto query = DefineProcessInputPattern(input_q_, weight_q_, bias_q_, std::make_shared<CondVar>(IsParamNode));
|
||||
auto query_div = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDivFusion)), query,
|
||||
std::make_shared<CondVar>(IsParamNode)});
|
||||
|
@ -59,52 +141,27 @@ const BaseRef TfMultiHeadAttentionFusion::DefinePattern() const {
|
|||
return DefineProcessOutputPattern(softmax_mul_val, weight_o_, bias_o_);
|
||||
}
|
||||
|
||||
const AnfNodePtr TfMultiHeadAttentionFusion::Process(const mindspore::FuncGraphPtr &func_graph,
|
||||
const mindspore::AnfNodePtr &node,
|
||||
const mindspore::EquivPtr &equiv) const {
|
||||
return CreateMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0);
|
||||
std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatterns() const {
|
||||
std::unordered_map<std::string, VectorRef> patterns;
|
||||
patterns[kMPAWithoutMaskPatternName] = DefineMPWithoutMaskPattern();
|
||||
patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern();
|
||||
return patterns;
|
||||
}
|
||||
|
||||
const VectorRef TfMultiHeadAttentionFusion::DefineDensePattern(const BaseRef &input, const BaseRef &weight,
|
||||
const BaseRef &bias) const {
|
||||
auto transpose = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), input,
|
||||
std::make_shared<CondVar>(IsParamNode)});
|
||||
auto reshape1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), transpose,
|
||||
std::make_shared<CondVar>(IsParamNode)});
|
||||
auto matmul = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), reshape1, weight});
|
||||
auto reshape2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), matmul,
|
||||
std::make_shared<CondVar>(IsParamNode)});
|
||||
if (bias == nullptr) {
|
||||
return reshape2;
|
||||
AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
|
||||
const mindspore::AnfNodePtr &node,
|
||||
const mindspore::EquivPtr &equiv) const {
|
||||
if (pattern_name == kMPAWithoutMaskPatternName) {
|
||||
return CreateMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0);
|
||||
} else if (pattern_name == kMPAWithMaskPatternName) {
|
||||
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), 0);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
auto bias_add = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimBiasAdd)), reshape2, bias});
|
||||
return bias_add;
|
||||
}
|
||||
|
||||
// Pattern for one projected attention input:
//   Reshape(Dense(input, weight, bias), reshape_shape)
// and, when `transpose` is true, that result wrapped in a Transpose whose
// second operand is any node accepted by IsParamNode.
//
// input/weight/bias: forwarded unchanged to DefineDensePattern.
// reshape_shape:     pattern element used as the Reshape's shape operand.
// transpose:         whether to append the trailing Transpose.
const VectorRef TfMultiHeadAttentionFusion::DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight,
                                                                      const BaseRef &bias,
                                                                      const BaseRef &reshape_shape,
                                                                      bool transpose) const {
  auto dense_out = DefineDensePattern(input, weight, bias);
  VectorRef reshaped(
    {std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), dense_out, reshape_shape});
  if (!transpose) {
    return reshaped;
  }
  return VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), reshaped,
                    std::make_shared<CondVar>(IsParamNode)});
}
|
||||
|
||||
// Pattern for the attention output projection:
//   Dense(Reshape(Transpose(input, <param>), <param>), weight, bias)
// where each <param> is any node accepted by IsParamNode.
const VectorRef TfMultiHeadAttentionFusion::DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight,
                                                                       const BaseRef &bias) const {
  VectorRef transposed({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), input,
                        std::make_shared<CondVar>(IsParamNode)});
  VectorRef flattened({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)), transposed,
                       std::make_shared<CondVar>(IsParamNode)});
  // The final projection reuses the shared dense (matmul + optional bias) pattern.
  return DefineDensePattern(flattened, weight, bias);
}
|
||||
|
||||
CNodePtr TfMultiHeadAttentionFusion::CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const std::string &base_name, int var_offset) const {
|
||||
CNodePtr MultiHeadAttentionFusion::CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const std::string &base_name, int var_offset) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
auto attention_prim = BuildAttentionPrim(equiv);
|
||||
|
@ -151,14 +208,14 @@ STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector<int> *resu
|
|||
}
|
||||
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
|
||||
int64_t shape_size =
|
||||
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<int>());
|
||||
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
|
||||
for (int64_t i = 0; i < shape_size; i++) {
|
||||
result->emplace_back(ptr[i]);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::shared_ptr<ops::Attention> TfMultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
|
||||
std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
|
||||
auto attention_prim = std::make_shared<ops::Attention>();
|
||||
if (attention_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Build attention primitive failed.";
|
||||
|
@ -191,9 +248,39 @@ std::shared_ptr<ops::Attention> TfMultiHeadAttentionFusion::BuildAttentionPrim(c
|
|||
MS_LOG(ERROR) << "Shape k or shape v is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
attention_prim->set_num_heads(shape_k.at(shape_k.size() - 2));
|
||||
attention_prim->set_key_dim(shape_k.at(shape_k.size() - 1));
|
||||
attention_prim->set_value_dim(shape_v.at(shape_v.size() - 1));
|
||||
return attention_prim;
|
||||
}
|
||||
|
||||
// Creates the fused Attention CNode for the masked multi-head-attention pattern.
//
// The node's inputs are, in order: the Attention primitive, q/k/v inputs,
// q/k/v/o weights, q/k/v/o biases and the mask, all looked up in `equiv`
// through the pattern variables held by this pass.
//
// func_graph: graph that will own the new node.
// equiv:      match result mapping pattern variables to graph nodes.
// base_name:  full name (with scope) assigned to the new node.
// var_offset: not used by this function.
//
// Returns the new node, or nullptr if the primitive or the node cannot be created.
//
// NOTE(review): unlike CreateMultiHeadAttentionNode, this variant does not go
// through BuildAttentionPrim(equiv), so no attributes are set on the primitive
// here. Confirm that this is intended for the masked pattern.
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
                                                                      const EquivPtr &equiv,
                                                                      const std::string &base_name,
                                                                      int var_offset) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  auto attention_prim = std::make_shared<ops::Attention>();
  if (attention_prim == nullptr) {
    MS_LOG(ERROR) << "Build attention primitive failed.";
    return nullptr;
  }
  auto value_node = NewValueNode(attention_prim);
  auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
  auto input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
  auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);

  auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
  auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
  auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
  auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);

  auto bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
  auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
  auto bias_v = utils::cast<AnfNodePtr>((*equiv)[bias_v_]);
  auto bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
  auto mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);

  std::vector<AnfNodePtr> new_node_inputs = {value_node, input_q, input_k, input_v, weight_q, weight_k, weight_v,
                                             weight_o,   bias_q,  bias_k,  bias_v,  bias_o,   mask};
  auto new_node = func_graph->NewCNode(new_node_inputs);
  // Guard before dereferencing, matching the check done in ReshapeReshapeFusion::Process.
  if (new_node == nullptr) {
    MS_LOG(ERROR) << "Create masked attention cnode failed.";
    return nullptr;
  }
  new_node->set_fullname_with_scope(base_name);
  return new_node;
}
|
||||
} // namespace mindspore::opt
|
|
@ -19,31 +19,40 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include <unordered_map>
|
||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||
#include "utils/utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "tools/converter/ops/attention.h"
|
||||
#include "ops/attention.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TfMultiHeadAttentionFusion : public PatternProcessPass {
|
||||
class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
|
||||
public:
|
||||
explicit TfMultiHeadAttentionFusion(const std::string &name = "tflite_multi_head_attention_fusion",
|
||||
bool multigraph = true);
|
||||
~TfMultiHeadAttentionFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
explicit MultiHeadAttentionFusion(const std::string &name = "multi_head_attention_fusion", bool multigraph = true);
|
||||
|
||||
~MultiHeadAttentionFusion() override = default;
|
||||
|
||||
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
|
||||
|
||||
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
|
||||
const EquivPtr &) const override;
|
||||
|
||||
protected:
|
||||
const VectorRef DefineDensePattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias) const;
|
||||
virtual const VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
|
||||
const BaseRef &reshape_shape, bool transpose = false) const;
|
||||
virtual const VectorRef DefineProcessOutputPattern(const BaseRef &input, const BaseRef &weight,
|
||||
const BaseRef &bias) const;
|
||||
|
||||
// define patterns
|
||||
VectorRef DefineMPWithMaskPattern() const;
|
||||
VectorRef DefineMPWithoutMaskPattern() const;
|
||||
// create multi-head-attention without mask
|
||||
virtual std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const;
|
||||
CNodePtr CreateMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const std::string &base_name, int var_offset) const;
|
||||
virtual std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const;
|
||||
// create masked-multi-head-attention
|
||||
CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const std::string &base_name, int var_offset) const;
|
||||
|
||||
protected:
|
||||
const std::string kMPAWithoutMaskPatternName = "MPAWithoutMaskPattern";
|
||||
const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern";
|
||||
|
||||
VarPtr input_q_;
|
||||
VarPtr input_k_;
|
||||
|
@ -58,6 +67,8 @@ class TfMultiHeadAttentionFusion : public PatternProcessPass {
|
|||
VarPtr bias_v_;
|
||||
VarPtr bias_o_;
|
||||
|
||||
VarPtr mask_;
|
||||
|
||||
VarPtr reshape_k_;
|
||||
VarPtr reshape_v_;
|
||||
};
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/reshape.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
const auto &p1 = std::placeholders::_1;
|
||||
} // namespace
|
||||
|
||||
// Pattern matched by this pass: two chained Reshape ops,
//   Reshape(Reshape(reshape_input_, <param>), reshape_shape_)
// where <param> is any node accepted by IsParamNode. Only the inner input and
// the outer shape operand are captured as pattern variables.
const BaseRef ReshapeReshapeFusion::DefinePattern() const {
  // Each call yields a fresh condition variable matching a Reshape primitive.
  auto reshape_cond = [] { return std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape)); };

  VectorRef inner_reshape({reshape_cond(), reshape_input_, std::make_shared<CondVar>(IsParamNode)});
  return VectorRef({reshape_cond(), inner_reshape, reshape_shape_});
}
|
||||
|
||||
// Replaces a matched Reshape(Reshape(x, s1), s2) chain with a single
// Reshape(x, s2) node.
//
// func_graph: graph that will own the new node.
// node:       root of the matched pattern (the outer Reshape).
// equiv:      match result; must bind reshape_input_ and reshape_shape_.
//
// Returns the new Reshape node, or nullptr when an argument is null, a bound
// node cannot be found, or node creation fails.
const AnfNodePtr ReshapeReshapeFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                               const EquivPtr &equiv) const {
  // Explicit runtime guard: `equiv` is dereferenced below and `func_graph` is
  // used to build the node, so this does not rely on MS_ASSERT being active.
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    MS_LOG(ERROR) << "Input func_graph, node or equiv is nullptr.";
    return nullptr;
  }
  auto reshape_prim = std::make_shared<ops::Reshape>();
  if (reshape_prim == nullptr) {
    MS_LOG(ERROR) << "Build reshape primitive failed.";
    return nullptr;
  }
  auto value_node = NewValueNode(reshape_prim);
  auto input = utils::cast<AnfNodePtr>((*equiv)[reshape_input_]);
  auto shape = utils::cast<AnfNodePtr>((*equiv)[reshape_shape_]);
  if (input == nullptr || shape == nullptr) {
    MS_LOG(ERROR) << "Cannot find reshape input and shape.";
    return nullptr;
  }
  // Create the single fused Reshape: original input, shape of the outer Reshape.
  auto new_reshape = func_graph->NewCNode({value_node, input, shape});
  if (new_reshape == nullptr) {
    MS_LOG(ERROR) << "Create new reshape cnode failed.";
    return nullptr;
  }
  return new_reshape;
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ReshapeReshapeFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit ReshapeReshapeFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
|
||||
: PatternProcessPass(name, multigraph) {}
|
||||
~ReshapeReshapeFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
VarPtr reshape_input_ = std::make_shared<Var>();
|
||||
VarPtr reshape_shape_ = std::make_shared<Var>();
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_RESHAPE_RESHAPE_FUSION_H_
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/quant_param_holder.h"
|
||||
#include "mindspore/core/ops/transpose.h"
|
||||
|
@ -22,10 +23,11 @@
|
|||
namespace mindspore::opt {
|
||||
namespace {
|
||||
const auto &p1 = std::placeholders::_1;
|
||||
|
||||
} // namespace
|
||||
|
||||
TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const string &name, bool multigraph)
|
||||
: TfMultiHeadAttentionFusion(name, multigraph) {
|
||||
: MultiHeadAttentionFusion(name, multigraph) {
|
||||
query_u_ = std::make_shared<Var>();
|
||||
query_v_ = std::make_shared<Var>();
|
||||
input_p_ = std::make_shared<Var>();
|
||||
|
@ -44,7 +46,7 @@ TfliteRelPosMultiHeadAttentionFusion::TfliteRelPosMultiHeadAttentionFusion(const
|
|||
}
|
||||
}
|
||||
|
||||
const BaseRef TfliteRelPosMultiHeadAttentionFusion::DefinePattern() const {
|
||||
std::unordered_map<std::string, VectorRef> TfliteRelPosMultiHeadAttentionFusion::DefinePatterns() const {
|
||||
auto query = DefineProcessInputPattern(input_q_, weight_q_, bias_q_, query_stack_params_, query_prim_);
|
||||
auto query_with_bias_u =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion)), query, query_u_});
|
||||
|
@ -77,11 +79,15 @@ const BaseRef TfliteRelPosMultiHeadAttentionFusion::DefinePattern() const {
|
|||
auto value = DefineProcessInputPattern(input_v_, weight_v_, bias_v_, value_stack_params_, value_prim_, true);
|
||||
auto output =
|
||||
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMul)), logits_softmax, value});
|
||||
return DefineProcessOutputPattern(output, weight_o_, bias_o_);
|
||||
auto pattern = DefineProcessOutputPattern(output, weight_o_, bias_o_);
|
||||
std::unordered_map<std::string, VectorRef> patterns;
|
||||
patterns.insert(std::make_pair(kRPMHAttentionPatternName, pattern));
|
||||
return patterns;
|
||||
}
|
||||
|
||||
const AnfNodePtr TfliteRelPosMultiHeadAttentionFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
AnfNodePtr TfliteRelPosMultiHeadAttentionFusion::Process(const std::string &pattern_name,
|
||||
const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
return CreateRelPosMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope());
|
||||
}
|
||||
|
||||
|
@ -139,25 +145,6 @@ std::shared_ptr<ops::Attention> TfliteRelPosMultiHeadAttentionFusion::BuildAtten
|
|||
}
|
||||
shape_k.emplace_back(dim);
|
||||
}
|
||||
|
||||
std::vector<int> shape_v;
|
||||
for (auto &value_stack_param : value_stack_params_) {
|
||||
auto reshape_v = utils::cast<ParameterPtr>((*equiv)[value_stack_param]);
|
||||
int dim;
|
||||
if (RET_OK != GetIntParameterData(reshape_v, &dim)) {
|
||||
MS_LOG(ERROR) << "Get reshape k data failed";
|
||||
return nullptr;
|
||||
}
|
||||
shape_v.emplace_back(dim);
|
||||
}
|
||||
|
||||
if (shape_k.size() < 2 || shape_v.size() < 2 || shape_k.at(shape_k.size() - 2) != shape_v.at(shape_v.size() - 2)) {
|
||||
MS_LOG(ERROR) << "Shape k or shape v is invalid.";
|
||||
return nullptr;
|
||||
}
|
||||
attention_prim->set_num_heads(shape_k.at(shape_k.size() - 2));
|
||||
attention_prim->set_key_dim(shape_k.at(shape_k.size() - 1));
|
||||
attention_prim->set_value_dim(shape_v.at(shape_v.size() - 1));
|
||||
return attention_prim;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,22 +19,24 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "utils/utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "tools/optimizer/fusion/tf_multi_head_attention_fusion.h"
|
||||
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
class TfliteRelPosMultiHeadAttentionFusion : public TfMultiHeadAttentionFusion {
|
||||
class TfliteRelPosMultiHeadAttentionFusion : public MultiHeadAttentionFusion {
|
||||
public:
|
||||
explicit TfliteRelPosMultiHeadAttentionFusion(const std::string &name = "tflite_rel_pos_multi_head_attention_fusion",
|
||||
bool multigraph = true);
|
||||
~TfliteRelPosMultiHeadAttentionFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
|
||||
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
|
||||
const EquivPtr &) const override;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const;
|
||||
std::shared_ptr<ops::Attention> BuildAttentionPrim(const EquivPtr &equiv) const override;
|
||||
|
||||
const VectorRef DefineProcessInputPattern(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
|
||||
const std::vector<VarPtr> &stack_params, const VarPtr &full_connect_prim,
|
||||
|
@ -45,6 +47,8 @@ class TfliteRelPosMultiHeadAttentionFusion : public TfMultiHeadAttentionFusion {
|
|||
const std::string &base_name) const;
|
||||
const VectorRef DefineRelativeShiftPattern(const BaseRef &input) const;
|
||||
|
||||
private:
|
||||
const std::string kRPMHAttentionPatternName = "RPMHAttentionPattern";
|
||||
VarPtr query_u_;
|
||||
VarPtr query_v_;
|
||||
VarPtr query_prim_;
|
||||
|
|
Loading…
Reference in New Issue