Adding encoder layer fusion
This commit is contained in:
parent f4752a8ab8
commit 72a341b321
@@ -125,7 +125,7 @@ static BaseRef GetVar(const BaseRef &x) {
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
  MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();

  if (equiv == nullptr) MS_EXCEPTION_IF_NULL(equiv);
  MS_EXCEPTION_IF_NULL(equiv);
  if (utils::isa<VarPtr>(pattern)) {
    VarPtr var = utils::cast<VarPtr>(pattern);
    if (var->matches(expr)) {

@@ -222,7 +222,9 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorR
      return nullptr;
    }
  }
  if ((values_expr_len == 0) && (values_pattern_len == 0)) return equiv;
  if ((values_expr_len == 0) && (values_pattern_len == 0)) {
    return equiv;
  }
  if (values_expr_len < values_pattern_len - 1) {
    MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
    return nullptr;
@@ -0,0 +1,36 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <stdio.h>
#include "nnacl/infer/encoder_layer_infer.h"
#include "nnacl/infer/infer_register.h"

int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                           OpParameter *parameter) {
  int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C9NUM, C1NUM);
  if (check_ret != NNACL_OK) {
    return check_ret;
  }
  const TensorC *input = inputs[FIRST_INPUT];
  TensorC *output0 = outputs[FIRST_INPUT];
  SetDataTypeFormat(output0, input);
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(output0, input);
  return NNACL_OK;
}

REG_INFER(EncoderLayer, PrimType_Inner_EncoderLayer, EncoderLayerInferShape)
@@ -0,0 +1,31 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_

#include "nnacl/infer/common_infer.h"

#ifdef __cplusplus
extern "C" {
#endif

int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                           OpParameter *parameter);

#ifdef __cplusplus
}
#endif
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_
@@ -31,6 +31,7 @@
#include "nnacl/infer/assign_add_infer.h"
#include "nnacl/infer/assign_infer.h"
#include "nnacl/infer/attention_infer.h"
#include "nnacl/infer/encoder_layer_infer.h"
#include "nnacl/infer/audio_spectrogram_infer.h"
#include "nnacl/infer/batch_to_space_infer.h"
#include "nnacl/infer/bias_grad_infer.h"

@@ -402,6 +403,8 @@ void RegAllInferFunc5() {
  g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL;
#ifndef RUNTIME_PASS_CLIP
  g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape;
  g_inner_op_infer_func[PrimType_Inner_EncoderLayer - PrimType_InnerOpMin] = EncoderLayerInferShape;

#endif
  g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL;
}
@@ -14,8 +14,8 @@
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_OP_BASE_H_
#define MINDSPORE_NNACL_OP_BASE_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_

#include <stdint.h>
#include <stdlib.h>

@@ -37,8 +37,10 @@
#define C8NUM 8
#define C9NUM 9
#define C10NUM 10
#define C11NUM 11
#define C12NUM 12
#define C13NUM 13
#define C14NUM 14
#define C16NUM 16
#define C20NUM 20
#define C21NUM 21

@@ -533,6 +535,7 @@ enum PrimType {
  PrimType_Inner_ShapeFusion = 10003,
  PrimType_Inner_GraphKernel = 10004,
  PrimType_Inner_SplitReduceConcatFusion = 10005,
  PrimType_Inner_EncoderLayer = 10006,
  PrimType_InnerOpMax,
  PrimType_InnerOpMin = PrimType_Inner_ToFormat
};

@@ -660,4 +663,4 @@ typedef enum CalFixedMultiplierMode {
  Method_DoublePrecision
} CalFixedMultiplierMode;

#endif  // MINDSPORE_NNACL_OP_BASE_H_
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
@@ -31,6 +31,10 @@ void Attention::set_head_size(int64_t head_size) {

void Attention::set_cross(bool cross) { (void)this->AddAttr(kCross, api::MakeValue(cross)); }

void Attention::set_position_bias(bool position_bias) {
  (void)this->AddAttr(kPositionBias, api::MakeValue(position_bias));
}

int64_t Attention::get_head_num() const {
  auto value_ptr = this->GetAttr(kAttentionNumHeads);
  return GetValue<int64_t>(value_ptr);

@@ -46,10 +50,16 @@ bool Attention::get_cross() const {
  return GetValue<bool>(value_ptr);
}

void Attention::Init(int64_t head_num, int64_t head_size, bool cross) {
bool Attention::get_position_bias() const {
  auto value_ptr = this->GetAttr(kPositionBias);
  return GetValue<bool>(value_ptr);
}

void Attention::Init(int64_t head_num, int64_t head_size, bool position_bias, bool cross) {
  this->set_head_num(head_num);
  this->set_head_size(head_size);
  this->set_cross(cross);
  this->set_position_bias(position_bias);
}
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
}  // namespace mindspore::ops
@@ -13,8 +13,8 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ATTENTION_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ATTENTION_H_
#ifndef MINDSPORE_CORE_OPS_ATTENTION_H_
#define MINDSPORE_CORE_OPS_ATTENTION_H_
#include <map>
#include <vector>
#include <string>

@@ -40,14 +40,17 @@ class MIND_API Attention : public BaseOperator {
  /// \param[in] head_num Define head number.
  /// \param[in] head_size Define size per head.
  /// \param[in] cross Define is cross attention. Default false.
  void Init(int64_t head_num, int64_t head_size, bool cross = false);
  /// \param[in] position_bias Define is position bias attention.
  void Init(int64_t head_num, int64_t head_size, bool position_bias, bool cross = false);
  void set_head_num(int64_t head_num);
  void set_head_size(int64_t head_size);
  void set_cross(bool cross);
  void set_position_bias(bool position_bias);
  int64_t get_head_num() const;
  int64_t get_head_size() const;
  bool get_cross() const;
  bool get_position_bias() const;
};
}  // namespace ops
}  // namespace mindspore
#endif  // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ATTENTION_H_
#endif  // MINDSPORE_CORE_OPS_ATTENTION_H_
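Editor's hedged usage sketch (not part of this commit): the reworked Attention::Init now takes position_bias before the optional cross flag. The values and the include path below are illustrative and assume the header's new location under core/ops.

#include "ops/attention.h"

void InitAttentionExample() {
  mindspore::ops::Attention attn;
  // head_num = 12, head_size = 64; enable relative position bias, no cross attention.
  attn.Init(12, 64, /*position_bias=*/true, /*cross=*/false);
  bool position_bias = attn.get_position_bias();  // reads the kPositionBias attribute back
  (void)position_bias;
}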
@@ -0,0 +1,92 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ops/encoder_layer.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore::ops {
MIND_API_OPERATOR_IMPL(EncoderLayer, BaseOperator);

void EncoderLayer::set_head_num(int64_t head_num) {
  (void)this->AddAttr(kEncoderLayerNumHeads, api::MakeValue(head_num));
}

void EncoderLayer::set_head_size(int64_t head_size) {
  (void)this->AddAttr(kEncoderLayerSizePerHead, api::MakeValue(head_size));
}

void EncoderLayer::set_post_layernorm(bool post_layernorm) {
  (void)this->AddAttr(kEncoderLayerPostLayernorm, api::MakeValue(post_layernorm));
}
void EncoderLayer::set_eps_layernorm1(float eps_layernorm1) {
  (void)this->AddAttr(kEncoderLayerEpsLayerNorm1, api::MakeValue(eps_layernorm1));
}
void EncoderLayer::set_eps_layernorm2(float eps_layernorm2) {
  (void)this->AddAttr(kEncoderLayerEpsLayerNorm2, api::MakeValue(eps_layernorm2));
}
void EncoderLayer::set_ffn_hidden_size(int64_t ffn_hidden_size) {
  (void)this->AddAttr(kEncoderLayerFfnHiddenSize, api::MakeValue(ffn_hidden_size));
}
void EncoderLayer::set_position_bias(bool position_bias) {
  (void)this->AddAttr(kPositionBias, api::MakeValue(position_bias));
}

int64_t EncoderLayer::get_head_num() const {
  auto value_ptr = this->GetAttr(kEncoderLayerNumHeads);
  return GetValue<int64_t>(value_ptr);
}

int64_t EncoderLayer::get_head_size() const {
  auto value_ptr = this->GetAttr(kEncoderLayerSizePerHead);
  return GetValue<int64_t>(value_ptr);
}

bool EncoderLayer::get_post_layernorm() const {
  auto value_ptr = this->GetAttr(kEncoderLayerPostLayernorm);
  return GetValue<bool>(value_ptr);
}
float EncoderLayer::get_eps_layernorm1() const {
  auto value_ptr = this->GetAttr(kEncoderLayerEpsLayerNorm1);
  return GetValue<float>(value_ptr);
}
float EncoderLayer::get_eps_layernorm2() const {
  auto value_ptr = this->GetAttr(kEncoderLayerEpsLayerNorm2);
  return GetValue<float>(value_ptr);
}
int64_t EncoderLayer::get_ffn_hidden_size() const {
  auto value_ptr = this->GetAttr(kEncoderLayerFfnHiddenSize);
  return GetValue<int64_t>(value_ptr);
}
bool EncoderLayer::get_position_bias() const {
  auto value_ptr = this->GetAttr(kPositionBias);
  return GetValue<bool>(value_ptr);
}

void EncoderLayer::Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2,
                        int64_t ffn_hidden_size, bool position_bias, bool post_layernorm = false) {
  this->set_head_num(head_num);
  this->set_head_size(head_size);
  this->set_post_layernorm(post_layernorm);
  this->set_eps_layernorm1(eps_layernorm1);
  this->set_eps_layernorm2(eps_layernorm2);
  this->set_ffn_hidden_size(ffn_hidden_size);
  this->set_position_bias(position_bias);
}
REGISTER_PRIMITIVE_C(kNameEncoderLayer, EncoderLayer);
}  // namespace mindspore::ops
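Editor's hedged usage sketch (illustrative values, not part of this commit): how the attribute plumbing above is exercised end to end, assuming the same mindspore/core/ops build environment as the file itself.

#include "ops/encoder_layer.h"

void InitEncoderLayerExample() {
  mindspore::ops::EncoderLayer encoder;
  // head_num = 8, head_size = 64, both layernorm eps = 1e-5f,
  // ffn_hidden_size = 2048, position_bias = false, post_layernorm = false.
  encoder.Init(8, 64, 1e-5f, 1e-5f, 2048, false, false);
  // The getters read the stored attrs back; hidden size mirrors params.hidden_size in the TensorRT op below.
  int64_t hidden_size = encoder.get_head_num() * encoder.get_head_size();  // 512
  (void)hidden_size;
}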
@@ -0,0 +1,65 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
#define MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
#include <map>
#include <vector>
#include <string>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameEncoderLayer = "EncoderLayer";
/// \brief EncoderLayer op in MindIR.
class MIND_API EncoderLayer : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(EncoderLayer);
  /// \brief Constructor.
  EncoderLayer() : BaseOperator(kNameEncoderLayer) {
    InitIOName({"input", "gamma1", "beta1", "weight_attn_qkv", "bias_attn_qkv", "mask", "weight_attn_o", "bias_attn_o",
                "gamma2", "beta2", "weight_m", "bias_m", "weight_p", "bias_p"},
               {"output"});
  }
  /// \brief Initialize EncoderLayer op.
  /// \param[in] head_num Define head number.
  /// \param[in] head_size Define size per head.
  /// \param[in] eps_layernorm1 Define eps layernorm1.
  /// \param[in] eps_layernorm2 Define eps layernorm2.
  /// \param[in] ffn_hidden_size Define ffn hidden size.
  /// \param[in] position_bias Define whether position bias is used.
  /// \param[in] post_layernorm Define whether post-layernorm is used.
  void Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2, int64_t ffn_hidden_size,
            bool position_bias, bool post_layernorm);
  void set_head_num(int64_t head_num);
  void set_head_size(int64_t head_size);
  void set_post_layernorm(bool post_layernorm);
  void set_eps_layernorm1(float eps_layernorm1);
  void set_eps_layernorm2(float eps_layernorm2);
  void set_ffn_hidden_size(int64_t ffn_hidden_size);
  void set_position_bias(bool position_bias);
  int64_t get_head_num() const;
  int64_t get_head_size() const;
  bool get_post_layernorm() const;
  float get_eps_layernorm1() const;
  float get_eps_layernorm2() const;
  int64_t get_ffn_hidden_size() const;
  bool get_position_bias() const;
};
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
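Editor's hedged reference sketch: the InitIOName call in the constructor fixes the operand order of the fused node, which is also what the C14NUM input check in the TensorRT op further down assumes. The enum below is illustrative shorthand added here, not part of the commit.

// input, gamma1, beta1, weight_attn_qkv, bias_attn_qkv, mask, weight_attn_o,
// bias_attn_o, gamma2, beta2, weight_m, bias_m, weight_p, bias_p -> indices 0..13
enum EncoderLayerInputIndex {
  kEncoderInput = 0, kGamma1, kBeta1, kWeightAttnQkv, kBiasAttnQkv, kMask, kWeightAttnO,
  kBiasAttnO, kGamma2, kBeta2, kWeightM, kBiasM, kWeightP, kBiasP,  // kBiasP == 13, i.e. C14NUM tensors in total
};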
@@ -14,8 +14,8 @@
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_OP_NAME_H
#define MINDSPORE_CORE_OPS_OP_NAME_H
#ifndef MINDSPORE_CORE_OPS_OP_NAME_H_
#define MINDSPORE_CORE_OPS_OP_NAME_H_
#include <cstddef>

namespace mindspore::ops {

@@ -378,6 +378,13 @@ constexpr auto kSampleNum = "sample_num";
constexpr auto kRoiEndMode = "roi_end_mode";
constexpr auto kUpper = "upper";
constexpr auto kConjugate = "conjugate";
constexpr auto kEncoderLayerNumHeads = "head_num";
constexpr auto kEncoderLayerSizePerHead = "head_size";
constexpr auto kEncoderLayerPostLayernorm = "post_layernorm";
constexpr auto kEncoderLayerFfnHiddenSize = "ffn_hidden_size";
constexpr auto kEncoderLayerEpsLayerNorm1 = "eps_layernorm1";
constexpr auto kEncoderLayerEpsLayerNorm2 = "eps_layernorm2";
constexpr auto kPositionBias = "position_bias";
constexpr auto KExclusive = "exclusive";
constexpr auto KReverse = "reverse";
constexpr auto KComputeEigenvectors = "compute_eigenvectors";

@@ -414,4 +421,4 @@ constexpr size_t kFormatNC1HWC0IndexW = 3;
constexpr size_t kFormatNC1HWC0IndexC0 = 4;
enum Dims : size_t { kDim0 = 0, kDim1, kDim2, kDim3, kDim4, kDim5, kDim6, kDim7, kDim8 };
}  // namespace mindspore::ops
#endif  // MINDSPORE_CORE_OPS_OP_NAME_H
#endif  // MINDSPORE_CORE_OPS_OP_NAME_H_

@@ -83,6 +83,9 @@ class MS_API Converter {
  void SetNoFusion(bool no_fusion);
  bool GetNoFusion();

  void SetOptimizeTransformer(bool optimize_transformer);
  bool GetOptimizeTransformer();

  inline void SetDevice(const std::string &device);
  inline std::string GetDevice();
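Editor's hedged sketch of how the new converter flag would be toggled from user code; only SetOptimizeTransformer/GetOptimizeTransformer come from this change, the helper and the way the Converter instance is obtained are illustrative.

void EnableTransformerFusion(mindspore::Converter *converter) {
  // Turn on the transformer (encoder/attention) fusion path added in this commit.
  converter->SetOptimizeTransformer(true);
  bool enabled = converter->GetOptimizeTransformer();
  (void)enabled;
}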
@@ -97,6 +97,16 @@ OpParameter *PopulateCustomParameter(const void *prim) {
    param->op_parameter_.type_ = PrimType_Inner_SplitReduceConcatFusion;
    return reinterpret_cast<OpParameter *>(param);
  } else if (type == "EncoderLayer") {
    std::cout << "EncoderLayer populate" << std::endl;
    auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
    if (param == nullptr) {
      MS_LOG(ERROR) << "malloc EncoderLayer failed.";
      return nullptr;
    }
    memset(param, 0, sizeof(OpParameter));
    param->type_ = PrimType_Inner_EncoderLayer;
    return reinterpret_cast<OpParameter *>(param);
  } else {
    MS_LOG(ERROR) << "Unsupported custom type: " << type;
  }

@@ -57,10 +57,10 @@ class DelegateRegistrar {
  ~DelegateRegistrar() = default;
};

#define REG_DELEGATE(device_type, provider, creator) \
  static DelegateCreator func = [=](const std::shared_ptr<Context> &context, const ConfigInfos &config_infos) { \
    return creator(context, config_infos); \
  }; \
#define REG_DELEGATE(device_type, provider, creator) \
  static DelegateCreator func = [](const std::shared_ptr<Context> &context, const ConfigInfos &config_infos) { \
    return creator(context, config_infos); \
  }; \
  static DelegateRegistrar g_##device_type##provider##Delegate(device_type, provider, &func);
}  // namespace mindspore

@@ -0,0 +1,255 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/extendrt/delegate/tensorrt/op/encoder_tensorrt.h"
#include <cuda_runtime.h>
#include <numeric>
#include <memory>
#include <vector>
#include <functional>
#include <unordered_map>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "NvInferRuntimeCommon.h"
#include "ops/encoder_layer.h"
#include "src/fastertransformer/kernels/unfused_attention_kernels.h"
#include "src/fastertransformer/kernels/activation_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/allocator.h"
#include "src/fastertransformer/kernels/layernorm_kernels.h"

namespace mindspore::lite {
namespace {
constexpr std::size_t kTwo = 2;
constexpr std::size_t kThree = 3;
}  // namespace

// Encoder layer TensorRT op
int EncoderTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
                               const std::vector<TensorInfo> &out_tensors) {
  if (in_tensors.size() != C14NUM) {
    MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
    return RET_ERROR;
  }
  return RET_OK;
}
nvinfer1::ITensor *EncoderTensorRT::castTensor(TensorRTContext *ctx, const TensorInfo &ms_tensor,
                                               const std::string &op_name) {
  if (ctx == nullptr || ctx->network() == nullptr) {
    MS_LOG(ERROR) << "context or network is null for ConvertConstantTensor";
    return nullptr;
  }
  nvinfer1::Dims dims = ConvertCudaDims(ms_tensor.Shape());
  if (dims.nbDims == -1) {
    MS_LOG(INFO) << ms_tensor.Name() << " ConvertCudaDims failed, convert as scalar.";
    dims.nbDims = 1;
    dims.d[0] = 1;
  }
  nvinfer1::DataType data_type = ConvertDataType(ms_tensor.DataType());
  if (!ms_tensor.IsConst()) {
    MS_LOG(ERROR) << "ConvertConstantTensor from a MSTensor with nullptr data: " << ms_tensor.Name();
    return nullptr;
  }
  nvinfer1::Weights weights{data_type, ms_tensor.Data(), ms_tensor.ElementNum()};
  if (data_type == nvinfer1::DataType::kFLOAT && is_ffn_fp16_) {
    void *data_float16 = malloc(ms_tensor.ElementNum() * sizeof(float));
    if (data_float16 == nullptr) {
      MS_LOG(ERROR) << "Malloc buffer failed.";
      return nullptr;
    }
    auto src = static_cast<const float *>(ms_tensor.Data());
    auto dst = static_cast<half *>(data_float16);
    for (int i = 0; i < ms_tensor.ElementNum(); i++) {
      dst[i] = static_cast<half>(src[i]);
    }
    weights.values = data_float16;
  }
  nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights);
  if (constant_tensor == nullptr) {
    MS_LOG(ERROR) << "create constant_tensor failed.";
    return nullptr;
  }
  ctx->RegisterLayer(constant_tensor, ms_tensor.Name() + "_" + op_name);
  auto tensor_ptr = constant_tensor->getOutput(0);
  return tensor_ptr;
}

int EncoderTensorRT::AddInnerOp(TensorRTContext *ctx) {
  if (ctx == nullptr || ctx->network() == nullptr) {
    MS_LOG(ERROR) << "context or network is invalid";
    return RET_ERROR;
  }
  auto encoder_op = AsOps<ops::EncoderLayer>();
  if (encoder_op == nullptr) {
    MS_LOG(ERROR) << "op action convert failed";
    return RET_ERROR;
  }
  fastertransformer::encoderParamT params;
  memset_s(&params, sizeof(params), 0, sizeof(params));
  params.head_num = encoder_op->get_head_num();
  params.head_size = encoder_op->get_head_size();
  params.layernorm_post = encoder_op->get_post_layernorm();
  params.eps1 = encoder_op->get_eps_layernorm1();
  params.eps2 = encoder_op->get_eps_layernorm2();
  params.ffn_hidden_size = encoder_op->get_ffn_hidden_size();
  params.is_cross = false;
  params.ffn_fp16 = is_ffn_fp16_;
  params.position_bias = encoder_op->get_position_bias();
  params.cublas_handle = GetCublasHandle();
  params.qkv_bias = !params.position_bias;
  params.projection_bias = !params.position_bias;
  params.hidden_size = params.head_num * params.head_size;
  auto compute_type = runtime_->GetRuntimePrecisionMode();
  if (is_ffn_fp16_) {
    size_t start_fp16 = (params.layernorm_post) ? C7NUM : C9NUM;
    size_t end_fp16 = (params.layernorm_post) ? C11NUM : C13NUM;
    for (size_t i = 0; i < in_tensors_.size(); i++) {
      auto in_tensor = input(ctx, i);
      if (in_tensors_[i].IsConst() || in_tensor.trt_tensor_ == nullptr) {
        if (i > start_fp16 && i < end_fp16) {
          in_tensor.trt_tensor_ = castTensor(ctx, in_tensors_[i], op_name_);
          ctx->RegisterTensor(in_tensor, in_tensors_[i].Name());
        } else {
          in_tensor.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[i], op_name_);
          ctx->RegisterTensor(in_tensor, in_tensors_[i].Name());
        }
      }
    }
  }
  nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
  auto plugin =
    std::make_shared<EncoderPlugin>(input_tensor->getName(), compute_type, params, GetCublasLtHandle(), device_id_);
  const int input_number = inputs().size();
  nvinfer1::ITensor *inputTensors[input_number];
  for (int i = 0; i < input_number; i++) {
    inputTensors[i] = input(ctx, i).trt_tensor_;
  }
  nvinfer1::IPluginV2Layer *encoder_layer = ctx->network()->addPluginV2(inputTensors, input_number, *plugin);
  if (encoder_layer == nullptr) {
    MS_LOG(ERROR) << "add encoder op failed for TensorRT.";
    return RET_ERROR;
  }
  encoder_layer->setName((op_name_ + "plugin_encoder_layer").c_str());
  nvinfer1::ITensor *encoder_tensor = encoder_layer->getOutput(0);
  ctx->RegisterTensor(ITensorHelper{encoder_tensor, Format::NCHW, true}, out_tensors_[0].Name());
  this->layer_ = encoder_layer;
  return RET_OK;
}

REGISTER_TENSORRT_PLUGIN(EncoderPluginCreater);
template class TensorRTPluginCreater<EncoderPlugin>;
template <class T>
nvinfer1::PluginFieldCollection TensorRTPluginCreater<T>::field_collection_{};
template <class T>
std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;

int EncoderPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
                           const void *const *inputs, void *const *outputs, void *workspace,
                           cudaStream_t stream) noexcept {
  if (compute_type_ == RuntimePrecisionMode_FP16) {
    return RunCudaEncoder<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream,
                                CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  } else {
    return RunCudaEncoder<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream,
                                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  }
}

template <typename T>
int EncoderPlugin::RunCudaEncoder(const nvinfer1::PluginTensorDesc *inputDesc,
                                  const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
                                  void *const *outputs, void *workspace, cudaStream_t stream, cublasGemmAlgo_t algoId) {
  params_.stream = stream;
  params_.algo = algoId;
  void *inputs_forward[] = {
    const_cast<void *>(inputs[0]), const_cast<void *>(inputs[1]), const_cast<void *>(inputs[2]),
    const_cast<void *>(inputs[3]), const_cast<void *>(inputs[4]), const_cast<void *>(inputs[5]),
    const_cast<void *>(inputs[6]), const_cast<void *>(inputs[7]), const_cast<void *>(inputs[8]),
    const_cast<void *>(inputs[9]), const_cast<void *>(inputs[10]), const_cast<void *>(inputs[11]),
    const_cast<void *>(inputs[12]), const_cast<void *>(inputs[13])};
  void *outputs_forward[] = {outputs[0]};
  fastertransformer::forwardEncoder<T>(inputs_forward, num_of_inputs_, outputs_forward, num_of_outputs_, &params_,
                                       workspace);
  return RET_OK;
}

bool EncoderPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
                                              int nbOutputs) noexcept {
  auto type = (compute_type_ == RuntimePrecisionMode_FP16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
  for (int i = 0; i < pos; i++) {
    if (tensorsDesc[pos].type != tensorsDesc[i].type) return false;
  }
  bool res = (tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) && (tensorsDesc[pos].type == type);
  return res;
}
void EncoderPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                                    const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {
  const int request_batch_size = static_cast<const int>(in[0].desc.dims.d[0]);
  const int request_src_seq_len = static_cast<const int>(in[0].desc.dims.d[1]);
  const int request_tgt_seq_len = request_src_seq_len;
  params_.batch_size = request_batch_size;
  params_.src_seq_len = request_src_seq_len;
  params_.tgt_seq_len = request_tgt_seq_len;
  num_of_inputs_ = nbInputs;
  num_of_outputs_ = nbOutputs;
}
size_t EncoderPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                                       const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
  if (compute_type_ == RuntimePrecisionMode_FP16) {
    return fastertransformer::GetEncoderLayerWorkspaceSize<half>(&params_);
  } else {
    return fastertransformer::GetEncoderLayerWorkspaceSize<float>(&params_);
  }
}

nvinfer1::DimsExprs EncoderPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs,
                                                       int nbInputDims, nvinfer1::IExprBuilder &exprBuilder) noexcept {
  nvinfer1::DimsExprs dims;
  if (index == 0) {
    int num_dims = inputs[0].nbDims;
    dims.nbDims = num_dims;
    if (num_dims == INPUT_SIZE2) {
      dims.d[0] = exprBuilder.constant(inputs[0].d[0]->getConstantValue());
      dims.d[1] = exprBuilder.constant(inputs[0].d[1]->getConstantValue());
    } else if (num_dims == INPUT_SIZE3) {
      dims.d[0] = exprBuilder.constant(inputs[0].d[0]->getConstantValue());
      dims.d[1] = exprBuilder.constant(inputs[0].d[1]->getConstantValue());
      dims.d[kTwo] = exprBuilder.constant(inputs[0].d[kTwo]->getConstantValue());
    }
  }
  return dims;
}

nvinfer1::IPluginV2DynamicExt *EncoderPlugin::clone() const noexcept {
  auto *plugin = new EncoderPlugin(*this);
  if (plugin == nullptr) {
    MS_LOG(ERROR) << "plugin is null";
    return nullptr;
  }
  plugin->setPluginNamespace(name_space_.c_str());
  return plugin;
}

size_t EncoderPlugin::getSerializationSize() const noexcept {
  return sizeof(int) + sizeof(fastertransformer::encoderParamT);
}

void EncoderPlugin::serialize(void *buffer) const noexcept {
  SerializeValue(&buffer, &compute_type_, sizeof(int));
  SerializeValue(&buffer, &params_, sizeof(fastertransformer::encoderParamT));
}
REGISTER_TENSORRT_CREATOR(ops::kNameEncoderLayer, EncoderTensorRT)
}  // namespace mindspore::lite
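Editor's hedged sketch (not part of the commit): castTensor() above rewrites constant FFN weights from fp32 to fp16 when is_ffn_fp16_ is set. The same element-wise conversion as a standalone helper, assuming CUDA's half type from <cuda_fp16.h>; the helper name is the editor's.

#include <cuda_fp16.h>
#include <vector>

std::vector<half> CastWeightsToHalf(const float *src, size_t element_num) {
  std::vector<half> dst(element_num);
  for (size_t i = 0; i < element_num; ++i) {
    dst[i] = __float2half(src[i]);  // same per-element narrowing that castTensor() applies
  }
  return dst;
}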
@@ -0,0 +1,106 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ENCODER_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ENCODER_TENSORRT_H_

#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h"
#include "src/fastertransformer/layers/encoder_layers/encoder.h"
namespace mindspore::lite {
class EncoderTensorRT : public TensorRTOp {
 public:
  EncoderTensorRT(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
                  const std::vector<TensorInfo> &out_tensors, std::string name)
      : TensorRTOp(base_operator, in_tensors, out_tensors, name) {}

  ~EncoderTensorRT() override = default;
  bool IsWeightInputHanledInner() const override { return is_ffn_fp16_; }
  int AddInnerOp(TensorRTContext *ctx) override;

  int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
                const std::vector<TensorInfo> &out_tensors) override;

 private:
  nvinfer1::ITensor *castTensor(TensorRTContext *ctx, const TensorInfo &ms_tensor, const std::string &op_name);
  bool is_ffn_fp16_ = false;
};

constexpr auto ENCODER_PLUGIN_NAME{"EncoderPlugin"};
class EncoderPlugin : public TensorRTPlugin {
 public:
  EncoderPlugin(const std::string name, int compute_type, fastertransformer::encoderParamT params,
                cublasLtHandle_t cublaslt_handle, uint32_t device_id)
      : TensorRTPlugin(name, std::string(ENCODER_PLUGIN_NAME), device_id),
        compute_type_(compute_type),
        params_(params),
        cublaslt_handle_(cublaslt_handle) {}

  EncoderPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      : TensorRTPlugin(std::string(name), std::string(ENCODER_PLUGIN_NAME)) {
    const nvinfer1::PluginField *fields = fc->fields;
    compute_type_ = static_cast<const int *>(fields[0].data)[0];
    params_ = static_cast<const fastertransformer::encoderParamT *>(fields[1].data)[0];
    cublaslt_handle_ = static_cast<const cublasLtHandle_t *>(fields[2].data)[0];
  }

  EncoderPlugin(const char *name, const void *serialData, size_t serialLength)
      : TensorRTPlugin(std::string(name), std::string(ENCODER_PLUGIN_NAME)) {
    DeserializeValue(&serialData, &serialLength, &compute_type_, sizeof(int));
    DeserializeValue(&serialData, &serialLength, &params_, sizeof(fastertransformer::encoderParamT));
  }

  EncoderPlugin() = delete;

  ~EncoderPlugin() override {}

  nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
              const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
  size_t getSerializationSize() const noexcept override;
  void serialize(void *buffer) const noexcept override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
  nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
                                          nvinfer1::IExprBuilder &exprBuilder) noexcept override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
                                 int nbOutputs) noexcept override;

 private:
  const std::string layer_name_;
  std::string name_space_;
  int compute_type_;
  mutable fastertransformer::encoderParamT params_;
  cublasLtHandle_t cublaslt_handle_;
  int num_of_inputs_;
  int num_of_outputs_;

  template <typename T>
  int RunCudaEncoder(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
                     const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
                     cublasGemmAlgo_t algoId);
};
class EncoderPluginCreater : public TensorRTPluginCreater<EncoderPlugin> {
 public:
  EncoderPluginCreater() : TensorRTPluginCreater(std::string(ENCODER_PLUGIN_NAME)) {}
};
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ENCODER_TENSORRT_H_
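Editor's hedged sketch (illustration only, not from the commit): the getSerializationSize()/serialize() pair and the (name, serialData, serialLength) constructor declared above form the usual TensorRT round trip; the blob layout is compute_type_ followed by the whole encoderParamT struct.

#include <vector>

EncoderPlugin RoundTripEncoderPlugin(const EncoderPlugin &plugin) {
  std::vector<char> blob(plugin.getSerializationSize());  // sizeof(int) + sizeof(encoderParamT)
  plugin.serialize(blob.data());                          // writes compute_type_ then params_
  return EncoderPlugin("encoder_restored", blob.data(), blob.size());  // reads them back in the same order
}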
@@ -45,27 +45,11 @@ std::ostream &operator<<(std::ostream &s, const nvinfer1::ITensor &t) {
}
}  // namespace

#define SET_GEMM_PARAMS(gemm_ops_, gemm_lds_, gemm_op1_, gemm_op2_, gemm_ld1_, gemm_ld2_, gemm_ld3_) \
  do {                                                                                               \
    gemm_ops_[0] = gemm_op1_;                                                                        \
    gemm_ops_[1] = gemm_op2_;                                                                        \
    gemm_lds_[0] = gemm_ld1_;                                                                        \
    gemm_lds_[1] = gemm_ld2_;                                                                        \
    gemm_lds_[2] = gemm_ld3_;                                                                        \
  } while (0)

#define SET_GEMM_DIMS(gemm_dims_, gemm_dim1_, gemm_dim2_, gemm_dim3_) \
  do {                                                                \
    gemm_dims_[0] = gemm_dim1_;                                       \
    gemm_dims_[1] = gemm_dim2_;                                       \
    gemm_dims_[2] = gemm_dim3_;                                       \
  } while (0)

// Multi Head Attention TensorRT op
int MhaTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
                           const std::vector<TensorInfo> &out_tensors) {
  if (in_tensors.size() < 7 || in_tensors.size() > 9) {  // T5 has 6 or 7 inputs, other models have 8 or 9 inputs
    MS_LOG(ERROR) << "Unsupported number of inputs, size is " << in_tensors.size();
  if (in_tensors.size() < C7NUM || in_tensors.size() > C9NUM) {
    MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
    return RET_ERROR;
  }
  return RET_OK;

@@ -81,16 +65,25 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
    MS_LOG(ERROR) << "op action convert failed";
    return RET_ERROR;
  }
  // get attribute for Attn op - TODO - add attribute in op
  int head_number = mha_op->get_head_num();
  int head_size = mha_op->get_head_size();
  auto compute_type = runtime_->GetRuntimePrecisionMode();  // mha_op->get_compute_type();
  auto compute_type = runtime_->GetRuntimePrecisionMode();
  bool is_cross = mha_op->get_cross();
  const int input_number = inputs().size();
  bool is_position_bias = (((input_number == 8) && is_cross) || ((input_number == 7) && !is_cross)) ? true : false;
  bool is_position_bias = mha_op->get_position_bias();
  nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
  auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
                                            is_position_bias, GetCublasHandle(), GetCublasLtHandle(), device_id_);
  fastertransformer::encoderParamT params;
  memset_s(&params, sizeof(params), 0, sizeof(params));
  params.head_num = head_number;
  params.head_size = head_size;
  params.hidden_size = head_number * head_size;
  params.cublas_handle = GetCublasHandle();
  params.qkv_bias = !is_position_bias;
  params.projection_bias = !is_position_bias;
  params.is_cross = is_cross;
  params.position_bias = is_position_bias;
  auto plugin =
    std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, params, GetCublasLtHandle(), device_id_);
  const int input_number = inputs().size();
  nvinfer1::ITensor *inputTensors[input_number];
  for (int i = 0; i < input_number; i++) {
    inputTensors[i] = input(ctx, i).trt_tensor_;

@@ -101,17 +94,12 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
    return RET_ERROR;
  }
  mha_layer->setName((op_name_ + "plugin_attention").c_str());
  // TODO(haim) one output
  nvinfer1::ITensor *attn_tensor = mha_layer->getOutput(0);
#ifndef TEST_
  ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name());
#else /* TEST_ */
  ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name() + "attn");
#endif /* TEST_ */
  // nvinfer1::ITensor *key_tensor = mha_layer->getOutput(1);
  // ctx->RegisterTensor(ITensorHelper{key_tensor, Format::NCHW, true}, out_tensors_[1].Name());
  // nvinfer1::ITensor *value_tensor = mha_layer->getOutput(kTwo);
  // ctx->RegisterTensor(ITensorHelper{value_tensor, Format::NCHW, true}, out_tensors_[kTwo].Name());
  this->layer_ = mha_layer;
#ifdef TEST_
  auto weight_projection = input(ctx, 4).trt_tensor_;
@@ -154,166 +142,51 @@ std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
int MhaPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
                       const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
  if (compute_type_ == RuntimePrecisionMode_FP16) {
    return RunCudaMha<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
    return RunCudaMha<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  } else {
    return RunCudaMha<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
  }
}

template <typename T>
void MhaPlugin::SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
                             size_t extra_size) {
  size_t qkv_len = size_q + (size_k * 2);  // size_v is equal to size_k
  size_t q_buf_2_len = size_q;
  auto buff_size =
    qkv_len + q_buf_2_len + qk_buf_len + (qkv_buf_2_len * 2);  // qkv_buf_3_ len is equal to qkv_buf_2_len
  qkv_buf_ = workspace;
  q_buf_2_ = static_cast<T *>(qkv_buf_) + qkv_len;
  qk_buf_ = static_cast<T *>(q_buf_2_) + q_buf_2_len;
  qkv_buf_2_ = static_cast<T *>(qk_buf_) + qk_buf_len;
  qkv_buf_3_ = static_cast<T *>(qkv_buf_2_) + qkv_buf_2_len;
  output1_ = static_cast<T *>(workspace) + buff_size;
  output2_ = static_cast<T *>(output1_) + extra_size;
}

template <typename T>
void MhaPlugin::RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
                              int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha,
                              void *beta, cublasGemmAlgo_t algoId, cudaStream_t stream) {
  int cross_tensor_offset = 0;
  if (is_cross_) cross_tensor_offset = 1;
  const int from_tensor_idx = 0, encoder_tensor_idx = 1, weight_qkv_tensor_idx = 3;
  const int weight_qkv_tensor_idx_cross = 3 + cross_tensor_offset;
  const int bias_qkv_tensor_idx = 5 + cross_tensor_offset;
  const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;

  auto from_tensor = static_cast<const T *>(inputs[from_tensor_idx]);
  auto encoder_output_tensor = static_cast<const T *>(inputs[encoder_tensor_idx]);
  auto weight_q = static_cast<const T *>(inputs[weight_qkv_tensor_idx]);
  auto weight_kv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
  auto weight_qkv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
  auto bias_qkv = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_qkv_tensor_idx]);

  auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
  const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
  const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
  const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
  auto hidden_size = static_cast<const int>(head_number_ * head_size_);
  if (is_cross_) {
    SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
    SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
    CublasGemmWrapper(weight_q, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
                      const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_);
    SET_GEMM_DIMS(gemm_dims, C2NUM * hidden_size, request_batch_size * request_tgt_seq_len, hidden_size);
    gemm_lds[0] = gemm_lds[THIRD_INPUT] = C2NUM * hidden_size;

    CublasGemmWrapper(weight_kv, encoder_output_tensor,
                      static_cast<T *>(qkv_buf_) + (request_batch_size * request_src_seq_len) * hidden_size, gemm_dims,
                      gemm_lds, gemm_ops, const_cast<const cudaDataType *>(gemm_data_types), alpha, beta,
                      cublas_handle_);
    fastertransformer::invokeCrossAddFusedQKVBiasTranspose(
      static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
      bias_qkv, request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_, head_size_, stream);
  } else {
    CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
                      const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_, algoId);
    fastertransformer::invokeAddFusedQKVBiasTranspose(
      static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
      bias_qkv, request_batch_size, request_src_seq_len, head_number_, head_size_, 0, stream);
    return RunCudaMha<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  }
}

template <typename T>
int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
                          const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
                          cublasGemmAlgo_t *algoId) {
  // inputs order:
  // 0] Q
  // 1] K
  // 2] V
  // 3] W
  // 4] PW
  // 5] B
  // 6] PB
  // 7] AttnMask
  // inputs order cross:
  // 0] Q
  // 1] K enco output
  // 2] V
  // 3] Wq
  // 4] Wkv
  // 5] PW
  // 6] Bqkv
  // 7] PB
  // 8] AttnMask
  int cross_tensor_offset = 0;
  cublasSetStream(cublas_handle_, stream);
  if (is_cross_) cross_tensor_offset = 1;
                          cublasGemmAlgo_t algoId) {
  int cross_tensor_offset = (params_.is_cross) ? 1 : 0;
  const int weight_projection_tensor_idx = 4 + cross_tensor_offset;
  const int bias_projection_tensor_idx = 6 + cross_tensor_offset;
  const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;
  const int bias_position_tensor_idx = 5 + cross_tensor_offset;

  auto attention_mask = static_cast<const T *>(inputs[attn_mask_tensor_idx]);
  auto weight_projection = static_cast<const T *>(inputs[weight_projection_tensor_idx]);
  auto bias_projection = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_projection_tensor_idx]);
  auto bias_position = (is_position_bias_) ? static_cast<const T *>(inputs[bias_position_tensor_idx]) : nullptr;
  auto output0 = static_cast<T *>(outputs[0]);
  auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
  const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
  const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
  const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
  auto hidden_size = static_cast<const int>(head_number_ * head_size_);
  auto extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;

  size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
  size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
  size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
  size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
  SetInnerAddr<T>(workspace, size_q, size_k, qk_buf_len, qkv_buf_2_len, extra_tmp_size);

  cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N};
  cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
  if constexpr (std::is_same<T, half>::value)
    std::fill(std::begin(gemm_data_types), std::end(gemm_data_types), CUDA_R_16F);
  float alpha = 1.0f, beta = 0.0f;
  int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
  int gemm_lds[] = {3 * hidden_size, hidden_size, 3 * hidden_size};

  RunPhase1GEMM<T>(inputDesc, inputs, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta, algoId[0], stream);

  SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_T, CUBLAS_OP_N, head_size_, head_size_, request_tgt_seq_len);
  SET_GEMM_DIMS(gemm_dims, request_tgt_seq_len, request_src_seq_len, head_size_);
  int gemm_strides[] = {request_tgt_seq_len * head_size_, request_src_seq_len * head_size_,
                        request_src_seq_len * request_tgt_seq_len};

  CublasGemmStridedBatchedWrapper(output1_, q_buf_2_, qk_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
                                  const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
                                  request_batch_size * head_number_, cublas_handle_, algoId[1]);

  T scalar = static_cast<T>(1.0f / sqrtf(head_size_ * 1.0f));
  fastertransformer::invokeMixMaskedSoftMax(static_cast<T *>(qk_buf_), attention_mask, bias_position,
                                            request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_,
                                            scalar, stream);
  SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, head_size_, request_tgt_seq_len, head_size_);
  SET_GEMM_DIMS(gemm_dims, head_size_, request_src_seq_len, request_tgt_seq_len);
  gemm_strides[1] = request_src_seq_len * request_tgt_seq_len;
  gemm_strides[THIRD_INPUT] = request_src_seq_len * head_size_;

  CublasGemmStridedBatchedWrapper(output2_, qk_buf_, qkv_buf_2_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
                                  const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
                                  request_batch_size * head_number_, cublas_handle_, algoId[2]);
  fastertransformer::invokeTransposeQKV(static_cast<T *>(qkv_buf_3_), static_cast<T *>(qkv_buf_2_), request_batch_size,
                                        request_src_seq_len, head_number_, head_size_, stream);
  SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
  SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
  CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops,
                    const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta, cublas_handle_, algoId[3]);
  if (!is_position_bias_) {
    int len = request_batch_size * request_src_seq_len;
    fastertransformer::invokeAddBias(reinterpret_cast<T *>(output0), reinterpret_cast<const T *>(bias_projection), len,
                                     hidden_size, stream);
  const int attn_mask_tensor_idx = 7 + cross_tensor_offset;
  const int bias_qkv_tensor_idx = 5 + cross_tensor_offset;
  const int weight_qkv_tensor_idx = 3;
  const int position_bias_tensor_idx = 6 + cross_tensor_offset;
  params_.stream = stream;
  params_.algo = algoId;
  void *inputs_attn[num_of_inputs_];
  int index = 0;
  inputs_attn[index++] = const_cast<void *>(inputs[0]);
  if (params_.is_cross) {
    inputs_attn[index++] = const_cast<void *>(inputs[1]);
    inputs_attn[index++] = const_cast<void *>(inputs[weight_qkv_tensor_idx]);
    inputs_attn[index++] = const_cast<void *>(inputs[weight_qkv_tensor_idx + 1]);
  } else {
    inputs_attn[index++] = const_cast<void *>(inputs[weight_qkv_tensor_idx]);
  }
  if (params_.qkv_bias) {
    inputs_attn[index++] = const_cast<void *>(inputs[bias_qkv_tensor_idx]);
  }
  if (params_.position_bias) {
    inputs_attn[index++] = const_cast<void *>(inputs[position_bias_tensor_idx]);
    inputs_attn[index++] = const_cast<void *>(inputs[attn_mask_tensor_idx - C2NUM]);
  } else {
    inputs_attn[index++] = const_cast<void *>(inputs[attn_mask_tensor_idx]);
  }
  inputs_attn[index++] = const_cast<void *>(inputs[weight_projection_tensor_idx]);
  if (params_.projection_bias) {
    inputs_attn[index++] = const_cast<void *>(inputs[bias_projection_tensor_idx]);
  }
  void *outputs_attn[] = {outputs[0]};
  fastertransformer::forward_attn<T>(reinterpret_cast<T **>(inputs_attn), num_of_inputs_,
                                     reinterpret_cast<T **>(outputs_attn), num_of_outputs_, &params_, workspace);
  return RET_OK;
}
@ -328,48 +201,33 @@ bool MhaPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorD
|
|||
}
|
||||
|
||||
void MhaPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
|
||||
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {
|
||||
int cross_tensor_offset = 0;
|
||||
int position_bias_tensor_offsets = 0;
|
||||
if (params_.is_cross) cross_tensor_offset = 1;
|
||||
if (params_.position_bias) position_bias_tensor_offsets = 1;
|
||||
const int attn_mask_tensor_idx = 7 + cross_tensor_offset - position_bias_tensor_offsets;
|
||||
const int request_batch_size = static_cast<const int>(in[attn_mask_tensor_idx].desc.dims.d[0]);
|
||||
const int request_src_seq_len = static_cast<const int>(in[attn_mask_tensor_idx].desc.dims.d[1]);
|
||||
const int request_tgt_seq_len = static_cast<const int>(in[attn_mask_tensor_idx].desc.dims.d[2]);
|
||||
params_.batch_size = request_batch_size;
|
||||
params_.src_seq_len = request_src_seq_len;
|
||||
params_.tgt_seq_len = request_tgt_seq_len;
|
||||
num_of_inputs_ = nbInputs;
|
||||
num_of_outputs_ = nbOutputs;
|
||||
}
|
||||
|
||||
size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
|
||||
auto attn_dim_size = inputs[nbInputs - 1].dims.nbDims;
|
||||
const int request_batch_size = static_cast<const int>(inputs[nbInputs - 1].dims.d[0]);
|
||||
const int request_src_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 2]);
|
||||
const int request_tgt_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 1]);
|
||||
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
|
||||
|
||||
// TODO(NIZZAN) Fix efficient allocator
|
||||
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
|
||||
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
|
||||
size_t size_v = size_k;
|
||||
|
||||
size_t qkv_len = size_q + size_k + size_v;
|
||||
size_t q_buf_2_len = size_q;
|
||||
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
|
||||
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
|
||||
size_t qkv_buf_3_len = qkv_buf_2_len;
|
||||
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
|
||||
|
||||
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
|
||||
int elem_size = sizeof(float);
|
||||
if (compute_type_ == RuntimePrecisionMode_FP16) {
|
||||
elem_size = sizeof(half);
|
||||
return fastertransformer::GetAttnWorkspaceSize<half>(¶ms_);
|
||||
} else {
|
||||
return fastertransformer::GetAttnWorkspaceSize<float>(¶ms_);
|
||||
}
|
||||
|
||||
return (buff_size + extra_tmp_size + extra_tmp_size) * elem_size;
|
||||
}

nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
                                                   nvinfer1::IExprBuilder &exprBuilder) noexcept {
  // MHA inputs:
  //   from_tensor [batch_size, src_seq_len, hidden_size_] or [batch_size * src_seq_len, hidden_size_]
  //   encoder_output [batch_size, tgt_seq_len, hidden_size_] or [batch_size * tgt_seq_len, hidden_size_] --> only in cross MHA
  //   attention_mask [batch_size, 1, src_seq_len, tgt_seq_len] or [batch_size, src_seq_len, tgt_seq_len]

  // MHA output_tensors:
  //   attention_out [batch_size, src_seq_len, hidden_size_]
  //   key_cache [batch, head_num, size_per_head]
  //   value_cache [batch, head_num, tgt_seq_len, size_per_head]
  nvinfer1::DimsExprs dims;
  if (index == 0) {
#ifndef TEST_
@@ -378,21 +236,20 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
    if (num_dims == INPUT_SIZE2) {
      dims.d[0] = exprBuilder.constant(inputs[nbInputDims - 1].d[0]->getConstantValue() *
                                       inputs[nbInputDims - 1].d[1]->getConstantValue());
      auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
      auto hidden_size = exprBuilder.constant(params_.head_size * params_.head_num);
      dims.d[1] = hidden_size;
    } else if (num_dims == INPUT_SIZE3) {
      dims.d[0] = inputs[nbInputDims - 1].d[0];  // batch
      dims.d[1] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
      auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
      auto hidden_size = exprBuilder.constant(params_.head_size * params_.head_num);
      dims.d[kTwo] = hidden_size;
    }
  } else {
    // TODO(Haim) - Fix size in case of 2d input
    dims.nbDims = INPUT_SIZE4;
    dims.d[0] = inputs[nbInputDims - 1].d[0];  // batch
    dims.d[1] = exprBuilder.constant(head_number_);
    dims.d[1] = exprBuilder.constant(params_.head_num);
    dims.d[kTwo] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
    dims.d[kThree] = exprBuilder.constant(head_size_);
    dims.d[kThree] = exprBuilder.constant(params_.head_size);
  }
#else
  dims.nbDims = C2NUM;
@@ -405,7 +262,7 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
}

nvinfer1::IPluginV2DynamicExt *MhaPlugin::clone() const noexcept {
  auto *plugin = new MhaPlugin(*this);  // TODO(haim) CopyConstructor
  auto *plugin = new MhaPlugin(*this);
  if (plugin == nullptr) {
    MS_LOG(ERROR) << "plugin is null";
    return nullptr;

@@ -418,13 +275,13 @@ int MhaPlugin::initialize() noexcept { return 0; }

void MhaPlugin::terminate() noexcept {}

size_t MhaPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); }
size_t MhaPlugin::getSerializationSize() const noexcept {
  return sizeof(int) + sizeof(fastertransformer::encoderParamT);
}

void MhaPlugin::serialize(void *buffer) const noexcept {
  SerializeValue(&buffer, &compute_type_, sizeof(int));
  SerializeValue(&buffer, &head_number_, sizeof(int));
  SerializeValue(&buffer, &head_size_, sizeof(int));
  SerializeValue(&buffer, &is_cross_, sizeof(int));
  SerializeValue(&buffer, &params_, sizeof(fastertransformer::encoderParamT));
}
REGISTER_TENSORRT_CREATOR(ops::kNameAttention, MhaTensorRT)
}  // namespace mindspore::lite
@@ -22,6 +22,7 @@
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h"
#include "src/fastertransformer/layers/encoder_layers/encoder.h"

namespace mindspore::lite {
class MhaTensorRT : public TensorRTOp {

@@ -42,41 +43,29 @@ class MhaTensorRT : public TensorRTOp {
constexpr auto MHA_PLUGIN_NAME{"AttentionPlugin"};
class MhaPlugin : public TensorRTPlugin {
 public:
  MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, bool is_cross,
            bool is_position_bias, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
  MhaPlugin(const std::string name, int compute_type, fastertransformer::encoderParamT params,
            cublasLtHandle_t cublaslt_handle, uint32_t device_id)
      : TensorRTPlugin(name, std::string(MHA_PLUGIN_NAME), device_id),
        compute_type_(compute_type),
        head_number_(head_number),
        head_size_(head_size),
        is_cross_(is_cross),
        is_position_bias_(is_position_bias),
        cublas_handle_(cublas_handle),
        params_(params),
        cublaslt_handle_(cublaslt_handle) {}

  MhaPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      : TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
    const nvinfer1::PluginField *fields = fc->fields;
    compute_type_ = static_cast<const int *>(fields[0].data)[0];
    head_number_ = static_cast<const int *>(fields[1].data)[0];
    head_size_ = static_cast<const int *>(fields[2].data)[0];
    is_cross_ = static_cast<const int *>(fields[3].data)[0];
    is_position_bias_ = static_cast<const int *>(fields[4].data)[0];
    params_ = static_cast<const fastertransformer::encoderParamT *>(fields[1].data)[0];
  }

  MhaPlugin(const char *name, const void *serialData, size_t serialLength)
      : TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
    DeserializeValue(&serialData, &serialLength, &compute_type_, sizeof(int));
    DeserializeValue(&serialData, &serialLength, &head_number_, sizeof(int));
    DeserializeValue(&serialData, &serialLength, &head_size_, sizeof(int));
    DeserializeValue(&serialData, &serialLength, &is_cross_, sizeof(int));
    DeserializeValue(&serialData, &serialLength, &is_position_bias_, sizeof(int));
    DeserializeValue(&serialData, &serialLength, &params_, sizeof(fastertransformer::encoderParamT));
  }

  MhaPlugin() = delete;

  ~MhaPlugin() override {
    // std::cout << "~MhaPlugin" << std::endl;
  }
  ~MhaPlugin() override {}

  nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,

@@ -95,38 +84,17 @@ class MhaPlugin : public TensorRTPlugin {
  int initialize() noexcept override;

 private:
  bool needResize(const int *current_dims, const int *last_dims);
  template <typename T>
  int RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
                 const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
                 cublasGemmAlgo_t *algoId);
  template <typename T>
  void SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
                    size_t extra_size);
  template <typename T>
  void RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
                     int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha, void *beta,
                     cublasGemmAlgo_t algoId, cudaStream_t stream);

                 cublasGemmAlgo_t algoId);
  const std::string layer_name_;
  std::string name_space_;
  int compute_type_;
  int head_number_;
  int head_size_;
  bool is_cross_;
  bool is_position_bias_;
  cublasGemmAlgo_t fast_algo_gemm[4] = {CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP,
                                        CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP};

  cublasHandle_t cublas_handle_;
  mutable fastertransformer::encoderParamT params_;
  cublasLtHandle_t cublaslt_handle_;
  void *qkv_buf_{nullptr};
  void *q_buf_2_{nullptr};
  void *qk_buf_{nullptr};
  void *qkv_buf_2_{nullptr};
  void *qkv_buf_3_{nullptr};
  void *output1_{nullptr};
  void *output2_{nullptr};
  int num_of_inputs_;
  int num_of_outputs_;
};
class MhaPluginCreater : public TensorRTPluginCreater<MhaPlugin> {
 public:
@@ -106,7 +106,8 @@ int TensorRTAllocator::SyncMemDeviceToHost(tensor::Tensor *host_tensor, const st
  return SyncMemInHostAndDevice(host_tensor->data_c(), device_tensor_name, host_tensor->Size(), false, sync);
}

int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name) {
int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name,
                                           bool sync) {
  if (dst_data == nullptr) {
    MS_LOG(ERROR) << " dst host data cannot be nullptr.";
    return RET_ERROR;

@@ -124,7 +125,11 @@ int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, con
    MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name;
    return RET_ERROR;
  }
  auto cuda_ret = cudaMemcpy(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost);
  cudaError_t cuda_ret;
  if (sync)
    cuda_ret = cudaMemcpy(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost);
  else
    cuda_ret = cudaMemcpyAsync(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost, stream_);
  if (cuda_ret != cudaSuccess) {
    MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret);
    return RET_ERROR;

@@ -176,7 +181,11 @@ int TensorRTAllocator::SyncMemInHostAndDevice(void *host_data, const std::string
  void *src_ptr = is_host2device ? host_data : device_ptr;
  void *dst_ptr = is_host2device ? device_ptr : host_data;
  cudaMemcpyKind kind = is_host2device ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToHost;
  auto cuda_ret = cudaMemcpy(dst_ptr, src_ptr, data_size, kind);
  cudaError_t cuda_ret;
  if (sync)
    cuda_ret = cudaMemcpy(dst_ptr, src_ptr, data_size, kind);
  else
    cuda_ret = cudaMemcpyAsync(dst_ptr, src_ptr, data_size, kind, stream_);
  if (cuda_ret != cudaSuccess) {
    MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret);
    return RET_ERROR;

@@ -54,7 +54,7 @@ class TensorRTAllocator {

  int SyncMemHostToDevice(const tensor::Tensor &host_tensor, const std::string &device_tensor_name, bool sync = true);
  int SyncMemDeviceToHost(tensor::Tensor *host_tensor, const std::string &device_tensor_name, bool sync = true);
  int SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name);
  int SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name, bool sync = true);

  int ClearDeviceMem();
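The allocator hunks above switch between a blocking copy and a stream-ordered copy based on the new sync flag. A self-contained sketch of that pattern (illustrative helper, not part of the commit; stream ownership and synchronization stay with the caller):

#include <cuda_runtime.h>

// When sync is false the copy is only enqueued on `stream`, so the host must call
// cudaStreamSynchronize(stream) before reading `dst`.
static cudaError_t CopyDeviceToHost(void *dst, const void *src, size_t size, bool sync, cudaStream_t stream) {
  if (sync) {
    return cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost);
  }
  return cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, stream);
}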
@@ -26,6 +26,7 @@
#include <fstream>
#include <limits>
#include <unordered_map>
#include <iomanip>
#include "src/extendrt/delegate/delegate_utils.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "src/common/utils.h"

@@ -719,8 +720,8 @@ int TensorRTSubGraph::OnNewInputShapes(const std::vector<ShapeVector> &new_shape
  return RET_OK;
}

int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs,
                                 const std::vector<tensor::Tensor> &outputs) {
int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs,
                                 bool sync) {
  if (inputs_.size() != inputs.size()) {
    MS_LOG(ERROR) << "Graph inputs size " << inputs_.size() << " != execute inputs size " << inputs.size();
    return RET_ERROR;

@@ -748,7 +749,7 @@ int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs,
      MS_LOG(ERROR) << "realloc for input tensor device memory failed.";
      return RET_ERROR;
    }
    ret = runtime_->GetAllocator()->SyncMemHostToDevice(inputs[i], trt_tensor_name);
    ret = runtime_->GetAllocator()->SyncMemHostToDevice(inputs[i], trt_tensor_name, sync);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_tensor_name;
      return RET_ERROR;

@@ -788,7 +789,7 @@ int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs,
  return RET_OK;
}

int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs, bool sync) {
  if (!outputs->empty() && outputs->size() != outputs_.size()) {
    MS_LOG(ERROR) << "Graph outputs size " << outputs_.size() << " != execute outputs size " << outputs->size();
    return RET_ERROR;

@@ -819,8 +820,8 @@ int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
        MS_LOG(ERROR) << "Specified output device or host address cannot be nullptr";
        return RET_ERROR;
      }
      int sync_ret =
        runtime_->GetAllocator()->SyncMemDeviceToHost(host_address, outputs_[i].DataSize(), trt_out_tensor_name);
      int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(host_address, outputs_[i].DataSize(),
                                                                   trt_out_tensor_name, sync);
      if (sync_ret != RET_OK) {
        MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name;
        return sync_ret;

@@ -828,7 +829,7 @@ int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
      }
    } else {
      tensor::Tensor output_tensor(static_cast<enum TypeId>(outputs_[i].DataType()), new_shape);
      int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(&output_tensor, trt_out_tensor_name);
      int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(&output_tensor, trt_out_tensor_name, sync);
      if (sync_ret != RET_OK) {
        MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name;
        return sync_ret;

@@ -854,19 +855,38 @@ bool TensorRTSubGraph::ValidInputResizeDims(const nvinfer1::Dims &construct_dims
}

int TensorRTSubGraph::Execute(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs) {
#ifdef ASYNC_INFER
  bool sync = false;
#else
  bool sync = true;
#endif
  int ret = lite::SetCudaDevice(device_info_);
  if (ret != RET_OK) {
    return ret;
  }
  ret = PreExecute(inputs, *outputs);
  ret = PreExecute(inputs, *outputs, sync);
  if (ret != RET_OK) {
    return ret;
  }
  if (!this->trt_context_->executeV2(tensor_bindings_)) {
    MS_LOG(ERROR) << "TensorRT execute failed.";
    return RET_ERROR;
  if (sync) {
    if (!this->trt_context_->executeV2(tensor_bindings_)) {
      MS_LOG(ERROR) << "TensorRT execute failed.";
      return RET_ERROR;
    }
  } else {
    if (!this->trt_context_->enqueueV2(tensor_bindings_, stream_, nullptr)) {
      MS_LOG(ERROR) << "TensorRT execute failed.";
      return RET_ERROR;
    }
  }
  return PostExecute(outputs);
  ret = PostExecute(outputs, sync);
  if (ret != RET_OK) {
    return ret;
  }
  if (!sync) {
    cudaStreamSynchronize(stream_);
  }
  return RET_OK;
}

int TensorRTSubGraph::Resize(const std::vector<tensor::Tensor> &, const std::vector<ShapeVector> &new_shapes) {

@@ -90,8 +90,9 @@ class TensorRTSubGraph {
  nvinfer1::Dims SetInputDimsProfile(const TensorInfo &in_tensor, int index);
  int ParseInputsProfile();

  int PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs);
  int PostExecute(std::vector<tensor::Tensor> *outputs);
  int PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs,
                 bool sync = true);
  int PostExecute(std::vector<tensor::Tensor> *outputs, bool sync = true);

  int OnNewInputShapes(const std::vector<ShapeVector> &inputs);
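With ASYNC_INFER defined, Execute() above enqueues the whole pipeline (host-to-device copies, enqueueV2, device-to-host copies) on stream_ and synchronizes once at the end. A reduced sketch of that control flow (illustrative only; the real code goes through PreExecute/PostExecute and the allocator):

#include <NvInfer.h>
#include <cuda_runtime.h>

// Asynchronous TensorRT execution followed by a single synchronization point.
static int ExecuteOnStream(nvinfer1::IExecutionContext *ctx, void *const *bindings, cudaStream_t stream) {
  if (!ctx->enqueueV2(bindings, stream, nullptr)) {
    return -1;  // corresponds to RET_ERROR above
  }
  // Outputs copied with cudaMemcpyAsync on the same stream are only host-visible
  // after the stream has been synchronized.
  return cudaStreamSynchronize(stream) == cudaSuccess ? 0 : -1;
}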
@@ -51,6 +51,7 @@
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tensor_dot_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/encoder_layer_fusion.h"
#include "tools/optimizer/fusion/glu_fusion.h"
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"

@@ -319,9 +320,10 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
                                           std::make_shared<opt::AddActivationFusion>(),
                                           std::make_shared<opt::ExpandDimsReshapeFusion>(),
                                           std::make_shared<opt::SqueezeExpandDimsFusion>()};
#ifdef ENABLE_CLOUD_FUSION_INFERENCE
  fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
#endif
  if (param->optimize_transformer) {
    fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
    fusions.push_back(std::make_shared<opt::EncoderLayerFusion>());
  }
  for (size_t index = 0; index < fusions.size(); index++) {
    auto pass_ptr = fusions.at(index);
    MS_CHECK_TRUE_RET(pass_ptr != nullptr, RET_ERROR);
@@ -89,6 +89,7 @@ Flags::Flags() {
          "Set the target device, support Ascend, Ascend310 and Ascend310P will be deprecated.", "");
  AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR_LITE");
  AddFlag(&Flags::optimizeStr, "optimize", "The type of optimization. none | general | ascend_oriented", "");
  AddFlag(&Flags::optimizeTransformerStr, "optimizeTransformer", "Enable Fast-Transformer fusion true|false", "false");
}

int Flags::InitInputOutputDataType() {

@@ -274,7 +275,7 @@ int Flags::InitExportMindIR() {
    return RET_INPUT_PARAM_INVALID;
  }

  if (this->exportMindIR == "MINDIR") {
  if ((this->exportMindIR == "MINDIR") && (this->optimizeTransformer == false)) {
    this->disableFusion = true;
  }
  return RET_OK;

@@ -311,6 +312,18 @@ int Flags::InitSaveType() {
  return RET_OK;
}

int Flags::InitOptimizeTransformer() {
  if (this->optimizeTransformerStr == "true") {
    this->optimizeTransformer = true;
  } else if (this->optimizeTransformerStr == "false") {
    this->optimizeTransformer = false;
  } else {
    std::cerr << "INPUT ILLEGAL: optimizeTransformer must be true|false " << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int Flags::PreInit(int argc, const char **argv) {
  if (argc == 1) {
    std::cout << this->Usage() << std::endl;

@@ -383,6 +396,13 @@ int Flags::Init(int argc, const char **argv) {
    std::cerr << "Init encrypt failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitOptimizeTransformer();
  if (ret != RET_OK) {
    std::cerr << "Init optimize transformers failed." << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }

  ret = InitPreInference();
  if (ret != RET_OK) {
    std::cerr << "Init pre inference failed." << std::endl;

@@ -43,6 +43,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
  int InitOptimize();
  int InitExportMindIR();
  int InitSaveType();
  int InitOptimizeTransformer();

  int Init(int argc, const char **argv);
  int PreInit(int argc, const char **argv);

@@ -85,6 +87,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
  bool encryption = false;
#endif
  std::string device;
  std::string optimizeTransformerStr;
  bool optimizeTransformer = false;
};
}  // namespace converter
}  // namespace mindspore
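The new converter option is a plain string flag validated in InitOptimizeTransformer(); only the literals true and false are accepted (for example --optimizeTransformer=true on the converter command line). A standalone sketch of that validation, with return codes simplified:

#include <string>

// Mirrors the accepted values of --optimizeTransformer; 0 stands for RET_OK,
// -1 for RET_INPUT_PARAM_INVALID in the converter.
static int ParseOptimizeTransformer(const std::string &value, bool *enabled) {
  if (value == "true") {
    *enabled = true;
    return 0;
  }
  if (value == "false") {
    *enabled = false;
    return 0;
  }
  return -1;
}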
@@ -68,6 +68,7 @@ int main(int argc, const char **argv) {
  converter.SetTrainModel(flags.trainModel);
  converter.SetNoFusion(flags.disableFusion);
  converter.SetDevice(flags.device);
  converter.SetOptimizeTransformer(flags.optimizeTransformer);

  auto status = converter.Convert();
  if (status != mindspore::kSuccess) {

@@ -270,6 +270,20 @@ bool Converter::GetNoFusion() {
  }
}

void Converter::SetOptimizeTransformer(bool optimizeTransformer) {
  if (data_ != nullptr) {
    data_->optimize_transformer = optimizeTransformer;
  }
}

bool Converter::GetOptimizeTransformer() {
  if (data_ != nullptr) {
    return data_->optimize_transformer;
  } else {
    return false;
  }
}

void Converter::SetDevice(const std::vector<char> &device) {
  if (data_ != nullptr) {
    data_->device = CharToString(device);

@@ -58,6 +58,7 @@ struct ConverterPara {
  bool pre_infer = false;
  bool train_model = false;
  bool no_fusion = false;
  bool optimize_transformer = false;
  bool is_runtime_converter = false;
  std::set<std::string> fusion_blacklists;
@@ -0,0 +1,411 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define USE_DEPRECATED_API
|
||||
#include "tools/optimizer/fusion/encoder_layer_fusion.h"
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "ops/tuple_get_item.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
const auto &p1 = std::placeholders::_1;
|
||||
} // namespace
|
||||
|
||||
bool EncoderLayerFusion::Init() const {
|
||||
input_ = std::make_shared<Var>("input");
|
||||
MS_CHECK_TRUE_RET(input_ != nullptr, false);
|
||||
beta1_ = std::make_shared<Var>("beta1");
|
||||
MS_CHECK_TRUE_RET(beta1_ != nullptr, false);
|
||||
gamma1_ = std::make_shared<Var>("gamma1");
|
||||
MS_CHECK_TRUE_RET(gamma1_ != nullptr, false);
|
||||
beta2_ = std::make_shared<Var>("beta2");
|
||||
MS_CHECK_TRUE_RET(beta2_ != nullptr, false);
|
||||
gamma2_ = std::make_shared<Var>("gamma2");
|
||||
MS_CHECK_TRUE_RET(gamma2_ != nullptr, false);
|
||||
weight_attn_qkv_ = std::make_shared<Var>("weight_attn_qkv");
|
||||
MS_CHECK_TRUE_RET(weight_attn_qkv_ != nullptr, false);
|
||||
weight_attn_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_o");
|
||||
MS_CHECK_TRUE_RET(weight_attn_o_ != nullptr, false);
|
||||
weight_m_ = std::make_shared<CondVar>(IsParamNode, "weight_m");
|
||||
MS_CHECK_TRUE_RET(weight_m_ != nullptr, false);
|
||||
weight_p_ = std::make_shared<CondVar>(IsParamNode, "weight_p");
|
||||
MS_CHECK_TRUE_RET(weight_p_ != nullptr, false);
|
||||
bias_attn_qkv_ = std::make_shared<Var>("bias_attn_qkv");
|
||||
MS_CHECK_TRUE_RET(bias_attn_qkv_ != nullptr, false);
|
||||
bias_attn_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_o");
|
||||
MS_CHECK_TRUE_RET(bias_attn_o_ != nullptr, false);
|
||||
bias_m_ = std::make_shared<CondVar>(IsParamNode, "bias_m");
|
||||
MS_CHECK_TRUE_RET(bias_m_ != nullptr, false);
|
||||
bias_p_ = std::make_shared<CondVar>(IsParamNode, "bias_p");
|
||||
MS_CHECK_TRUE_RET(bias_p_ != nullptr, false);
|
||||
mask_ = std::make_shared<Var>("mask");
|
||||
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
|
||||
is_attention_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention");
|
||||
MS_CHECK_TRUE_RET(is_attention_ != nullptr, false);
|
||||
is_layernorm1_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm1");
|
||||
MS_CHECK_TRUE_RET(is_layernorm1_ != nullptr, false);
|
||||
is_layernorm2_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm2");
|
||||
MS_CHECK_TRUE_RET(is_layernorm2_ != nullptr, false);
|
||||
position_bias_ = std::make_shared<Var>("position_bias");
|
||||
MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
|
||||
is_act_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "activation");
|
||||
MS_CHECK_TRUE_RET(is_act_ != nullptr, {});
|
||||
return true;
|
||||
}
|
||||
|
||||
VectorRef EncoderLayerFusion::getTuple(bool post_layernorm, bool layernorm_fusion = false,
|
||||
bool is_position_bias = false) const {
|
||||
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
|
||||
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
|
||||
auto var1 = std::make_shared<Var>("var1-reshape");
|
||||
MS_CHECK_TRUE_RET(var1 != nullptr, {});
|
||||
auto reshape1 = VectorRef({is_reshape1, input_, var1});
|
||||
if (post_layernorm) {
|
||||
return reshape1;
|
||||
}
|
||||
if (layernorm_fusion) {
|
||||
return DefineLayerNorm(is_position_bias, reshape1, gamma1_, beta1_);
|
||||
}
|
||||
auto layer_norm = VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
|
||||
auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
|
||||
auto var_tuple = std::make_shared<Var>("var_tuple");
|
||||
auto tuple = VectorRef({is_tuple, layer_norm, var_tuple});
|
||||
return tuple;
|
||||
}
|
||||
|
||||
VectorRef EncoderLayerFusion::DefineLayerNorm(bool is_position_bias, VectorRef input, VarPtr gamma, VarPtr beta) const {
|
||||
auto var1 = std::make_shared<Var>("var1");
|
||||
MS_CHECK_TRUE_RET(var1 != nullptr, {});
|
||||
auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce");
|
||||
MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
|
||||
auto reduce1 = VectorRef({is_reduce, input, var1});
|
||||
auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion), "sub-f");
|
||||
MS_CHECK_TRUE_RET(is_sub != nullptr, {});
|
||||
auto sub = VectorRef({is_sub, input, reduce1});
|
||||
auto is_sqr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSquare), "sqr");
|
||||
MS_CHECK_TRUE_RET(is_sqr != nullptr, {});
|
||||
auto sqr = (is_position_bias) ? VectorRef({is_sqr, input}) : VectorRef({is_sqr, sub});
|
||||
auto var2 = std::make_shared<Var>("var2");
|
||||
MS_CHECK_TRUE_RET(var2 != nullptr, {});
|
||||
auto is_reduce2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce2");
|
||||
MS_CHECK_TRUE_RET(is_reduce2 != nullptr, {});
|
||||
auto reduce2 = VectorRef({is_reduce2, sqr, var2});
|
||||
auto var3 = std::make_shared<Var>("var3");
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add");
|
||||
MS_CHECK_TRUE_RET(is_add != nullptr, {});
|
||||
auto add = VectorRef({is_add, reduce2, var3});
|
||||
auto is_sqr2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSqrt), "sqr2");
|
||||
MS_CHECK_TRUE_RET(is_sqr2 != nullptr, {});
|
||||
auto sqr2 = VectorRef({is_sqr2, add});
|
||||
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "real-div");
|
||||
MS_CHECK_TRUE_RET(is_div != nullptr, {});
|
||||
if (is_position_bias) {
|
||||
auto real_div = VectorRef({is_div, input, sqr2});
|
||||
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul");
|
||||
MS_CHECK_TRUE_RET(is_mul != nullptr, {});
|
||||
auto mul = VectorRef({is_mul, real_div, gamma});
|
||||
return mul;
|
||||
} else {
|
||||
auto real_div = VectorRef({is_div, sub, sqr2});
|
||||
auto is_scale = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimScaleFusion), "scale");
|
||||
MS_CHECK_TRUE_RET(is_scale != nullptr, {});
|
||||
auto scale = VectorRef({is_scale, real_div, gamma, beta});
|
||||
return scale;
|
||||
}
|
||||
}
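DefineLayerNorm above matches the decomposed form of layer normalization; with is_position_bias it matches the T5-style RMS norm, which skips the mean subtraction and has no beta. In formulas (the epsilon is the constant captured by var3):

\mathrm{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\mathrm{Var}(x) + \epsilon}} + \beta,
\qquad
\mathrm{RMSNorm}(x) = \gamma \cdot \frac{x}{\sqrt{\tfrac{1}{n}\sum_{i} x_i^{2} + \epsilon}}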
|
||||
|
||||
VectorRef EncoderLayerFusion::DefinePatternEncoderLayer(bool post_layernorm = true, bool layernorm_fusion = false,
|
||||
bool is_position_bias = false) const {
|
||||
VectorRef attention, tuple, tuple2, tuple3, reshape2, matmul1;
|
||||
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
|
||||
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
|
||||
auto var1 = std::make_shared<Var>("var1");
|
||||
MS_CHECK_TRUE_RET(var1 != nullptr, {});
|
||||
auto reshape1 = VectorRef({is_reshape1, input_, var1});
|
||||
if (!is_position_bias) {
|
||||
attention = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
|
||||
getTuple(post_layernorm, layernorm_fusion, is_position_bias),
|
||||
getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_,
|
||||
weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_});
|
||||
} else {
|
||||
attention = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
|
||||
getTuple(post_layernorm, layernorm_fusion, is_position_bias),
|
||||
getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_,
|
||||
weight_attn_o_, position_bias_, mask_});
|
||||
}
|
||||
if (!is_position_bias) {
|
||||
auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
|
||||
auto var_tuple = std::make_shared<Var>("var_tuple");
|
||||
tuple = VectorRef({is_tuple, attention, var_tuple});
|
||||
} else {
|
||||
tuple = attention;
|
||||
}
|
||||
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
|
||||
auto add = VectorRef({is_add, reshape1, tuple});
|
||||
if (layernorm_fusion) {
|
||||
tuple2 = DefineLayerNorm(is_position_bias, add, gamma2_, beta2_);
|
||||
} else {
|
||||
auto layer_norm2 = VectorRef({is_layernorm2_, add, gamma2_, beta2_});
|
||||
auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
|
||||
auto var_tuple2 = std::make_shared<Var>("var_tuple2");
|
||||
tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
|
||||
}
|
||||
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
|
||||
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
|
||||
auto var2 = std::make_shared<Var>("var2");
|
||||
MS_CHECK_TRUE_RET(var2 != nullptr, {});
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
|
||||
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
|
||||
if (is_position_bias) {
|
||||
reshape2 = VectorRef({is_reshape2, add, var2});
|
||||
matmul1 = VectorRef({is_matmul1, tuple2, weight_m_});
|
||||
} else if (post_layernorm || layernorm_fusion) {
|
||||
reshape2 = VectorRef({is_reshape2, tuple2, var2});
|
||||
matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
|
||||
} else {
|
||||
reshape2 = VectorRef({is_reshape2, add, var2});
|
||||
matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
|
||||
}
|
||||
auto act = VectorRef({is_act_, matmul1});
|
||||
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
|
||||
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
|
||||
auto matmul2 =
|
||||
(is_position_bias) ? VectorRef({is_matmul2, matmul1, weight_p_}) : VectorRef({is_matmul2, act, weight_p_, bias_p_});
|
||||
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder3");
|
||||
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
|
||||
auto var3 = std::make_shared<Var>("var3");
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
|
||||
auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
|
||||
auto add3 = VectorRef({is_add3, reshape2, reshape3});
|
||||
if (!post_layernorm || layernorm_fusion) {
|
||||
return add3;
|
||||
}
|
||||
auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
|
||||
MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
|
||||
auto var4 = std::make_shared<Var>("var4");
|
||||
MS_CHECK_TRUE_RET(var4 != nullptr, {});
|
||||
auto reshape4 = VectorRef({is_reshape4, add3, var4});
|
||||
if (layernorm_fusion) {
|
||||
tuple3 = DefineLayerNorm(is_position_bias, reshape4, gamma1_, beta1_);
|
||||
} else {
|
||||
auto layer_norm = VectorRef({is_layernorm1_, reshape4, gamma1_, beta1_});
|
||||
auto is_tuple3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item3");
|
||||
auto var_tuple3 = std::make_shared<Var>("var_tuple3");
|
||||
tuple3 = VectorRef({is_tuple3, layer_norm, var_tuple3});
|
||||
}
|
||||
auto is_reshape5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
|
||||
MS_CHECK_TRUE_RET(is_reshape5 != nullptr, {});
|
||||
auto var5 = std::make_shared<Var>("var5");
|
||||
MS_CHECK_TRUE_RET(var5 != nullptr, {});
|
||||
auto reshape5 = VectorRef({is_reshape5, tuple3, var5});
|
||||
return reshape5;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, VectorRef> EncoderLayerFusion::DefinePatterns() const {
|
||||
std::unordered_map<std::string, VectorRef> patterns;
|
||||
if (!Init()) {
|
||||
MS_LOG(ERROR) << "initial member failed.";
|
||||
return patterns;
|
||||
}
|
||||
patterns[kPatternEncoderLayerPre] = DefinePatternEncoderLayer(false);
|
||||
patterns[kPatternEncoderLayerPost] = DefinePatternEncoderLayer(true);
|
||||
patterns[kPatternEncoderLayerPostNorm] = DefinePatternEncoderLayer(true, true);
|
||||
patterns[kPatternEncoderLayerPreNorm] = DefinePatternEncoderLayer(false, true);
|
||||
patterns[kPatternEncoderLayerT5] = DefinePatternEncoderLayer(false, true, true);
|
||||
return patterns;
|
||||
}
|
||||
|
||||
AnfNodePtr EncoderLayerFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
|
||||
const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
|
||||
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (pattern_name == kPatternEncoderLayerPost || pattern_name == kPatternEncoderLayerPostNorm) {
|
||||
return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, true);
|
||||
} else if (pattern_name == kPatternEncoderLayerPre || pattern_name == kPatternEncoderLayerPreNorm) {
|
||||
return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, false);
|
||||
} else if (pattern_name == kPatternEncoderLayerT5) {
|
||||
is_position_bias_ = true;
|
||||
return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, false);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool EncoderLayerFusion::IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const VarPtr &input_prim) const {
|
||||
auto act_input = GetAttribute(func_graph, equiv, is_act_);
|
||||
MS_ASSERT(act_input != nullptr);
|
||||
auto act_primitive = ops::GetOperator<ops::Activation>(act_input);
|
||||
MS_CHECK_TRUE_RET(act_primitive != nullptr, false);
|
||||
auto act_primitive_c = act_primitive->GetPrim();
|
||||
if (act_primitive_c->GetAttr(ops::kActivationType) == nullptr ||
|
||||
act_primitive->get_activation_type() != mindspore::GELU) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
AnfNodePtr EncoderLayerFusion::GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
VarPtr node_name) const {
|
||||
if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
|
||||
MS_LOG(ERROR) << node_name << " is not AnfNodePtr";
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
|
||||
MS_ASSERT(node != nullptr);
|
||||
if (node == nullptr || !utils::isa<CNodePtr>(node)) {
|
||||
auto manager = func_graph->manager();
|
||||
if (manager == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto users = manager->node_users();
|
||||
auto it = users.find(node);
|
||||
if (it != users.end()) {
|
||||
node = it->second.front().first;
|
||||
}
|
||||
if (node == nullptr || !utils::isa<CNodePtr>(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto cnode = utils::cast<CNodePtr>(node);
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto input = cnode->input(0);
|
||||
return input;
|
||||
}
|
||||
STATUS EncoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num,
|
||||
int *head_size, float *eps1, float *eps2) const {
|
||||
auto attn_input = GetAttribute(func_graph, equiv, is_attention_);
|
||||
MS_ASSERT(attn_input != nullptr);
|
||||
auto attn_prim = ops::GetOperator<ops::Attention>(attn_input);
|
||||
if (attn_prim->GetAttr(ops::kEncoderLayerNumHeads) != nullptr) {
|
||||
*head_num = attn_prim->get_head_num();
|
||||
}
|
||||
if (attn_prim->GetAttr(ops::kAttentionSizePerHead) != nullptr) {
|
||||
*head_size = attn_prim->get_head_size();
|
||||
}
|
||||
if (attn_prim->GetAttr(ops::kPositionBias) != nullptr) {
|
||||
is_position_bias_ = attn_prim->get_position_bias();
|
||||
}
|
||||
auto layrn1_input = GetAttribute(func_graph, equiv, is_layernorm1_);
|
||||
auto layrn1_prim = ops::GetOperator<ops::LayerNormFusion>(layrn1_input);
|
||||
if (layrn1_prim->GetAttr(ops::kEpsilon) != nullptr) {
|
||||
*eps1 = layrn1_prim->get_epsilon();
|
||||
}
|
||||
auto layrn2_input = GetAttribute(func_graph, equiv, is_layernorm2_);
|
||||
auto layrn2_prim = ops::GetOperator<ops::LayerNormFusion>(layrn2_input);
|
||||
if (layrn2_prim->GetAttr(ops::kEpsilon) != nullptr) {
|
||||
*eps2 = layrn2_prim->get_epsilon();
|
||||
}
|
||||
if (!IsActGELU(func_graph, equiv, is_act_)) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::shared_ptr<ops::EncoderLayer> EncoderLayerFusion::CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
bool post_layernorm, int64_t ffn_hidden_size) const {
|
||||
auto encoder_layer_prim = std::make_shared<ops::EncoderLayer>();
|
||||
if (encoder_layer_prim == nullptr) {
|
||||
MS_LOG(ERROR) << "Build enoder layer primitive failed.";
|
||||
return nullptr;
|
||||
}
|
||||
int head_num = 0;
|
||||
int head_size = 0;
|
||||
float eps1 = 1e-6;
|
||||
float eps2 = 1e-6;
|
||||
if (CheckPattern(func_graph, equiv, &head_num, &head_size, &eps1, &eps2)) {
|
||||
return nullptr;
|
||||
}
|
||||
encoder_layer_prim->Init(head_num, head_size, eps1, eps2, ffn_hidden_size, is_position_bias_, post_layernorm);
|
||||
return encoder_layer_prim;
|
||||
}
|
||||
|
||||
CNodePtr EncoderLayerFusion::CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const AnfNodePtr &node,
|
||||
bool post_layernorm = true) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
auto input = utils::cast<AnfNodePtr>((*equiv)[input_]);
|
||||
AnfNodePtr position_bias, input_mask, bias_attn_o, bias_attn_qkv, beta1, beta2, bias_m, bias_p;
|
||||
auto weight_qkv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_qkv_]);
|
||||
auto weight_attn_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_o_]);
|
||||
auto weight_m = utils::cast<AnfNodePtr>((*equiv)[weight_m_]);
|
||||
auto weight_p = utils::cast<AnfNodePtr>((*equiv)[weight_p_]);
|
||||
if (!is_position_bias_) {
|
||||
bias_attn_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_qkv_]);
|
||||
bias_attn_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_o_]);
|
||||
bias_m = utils::cast<AnfNodePtr>((*equiv)[bias_m_]);
|
||||
bias_p = utils::cast<AnfNodePtr>((*equiv)[bias_p_]);
|
||||
beta1 = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
|
||||
beta2 = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
|
||||
}
|
||||
auto gamma1 = utils::cast<AnfNodePtr>((*equiv)[gamma1_]);
|
||||
auto gamma2 = utils::cast<AnfNodePtr>((*equiv)[gamma2_]);
|
||||
if (mask_) {
|
||||
input_mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
|
||||
}
|
||||
auto base_shape_ptr = weight_m->Shape();
|
||||
MS_EXCEPTION_IF_NULL(base_shape_ptr);
|
||||
auto input_shape_ptr = base_shape_ptr->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_shape_ptr);
|
||||
auto input_shape = input_shape_ptr->shape();
|
||||
MS_ASSERT(input_shape != nullptr);
|
||||
int ffn_hidden_size = (int64_t)input_shape[1];
|
||||
auto encoder_layer_prim = CreatePrim(func_graph, equiv, post_layernorm, ffn_hidden_size);
|
||||
MS_CHECK_TRUE_RET(encoder_layer_prim != nullptr, nullptr);
|
||||
auto encoder_layer_prim_c = encoder_layer_prim->GetPrim();
|
||||
MS_CHECK_TRUE_RET(encoder_layer_prim_c != nullptr, nullptr);
|
||||
auto value_node = NewValueNode(encoder_layer_prim_c);
|
||||
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
|
||||
std::vector<AnfNodePtr> new_node_inputs;
|
||||
ParameterPtr c_bias_m_param, c_weight_p_param, c_bias_p_param, c_weight_m_param;
|
||||
if (is_position_bias_) {
|
||||
position_bias = utils::cast<AnfNodePtr>((*equiv)[position_bias_]);
|
||||
if (!post_layernorm)
|
||||
new_node_inputs = {value_node, input, gamma1, weight_qkv, input_mask,
|
||||
weight_attn_o, gamma2, weight_m, weight_p, position_bias};
|
||||
else
|
||||
new_node_inputs = {value_node, input, weight_qkv, input_mask, weight_attn_o,
|
||||
gamma1, weight_m, weight_p, gamma2, position_bias};
|
||||
} else {
|
||||
if (!post_layernorm) {
|
||||
new_node_inputs = {value_node, input, gamma1, beta1, weight_qkv, bias_attn_qkv, input_mask, weight_attn_o,
|
||||
bias_attn_o, gamma2, beta2, weight_m, bias_m, weight_p, bias_p};
|
||||
} else {
|
||||
new_node_inputs = {value_node, input, weight_qkv, bias_attn_qkv, input_mask,
|
||||
weight_attn_o, bias_attn_o, gamma1, beta1, weight_m,
|
||||
bias_m, weight_p, bias_p, gamma2, beta2};
|
||||
}
|
||||
}
|
||||
auto new_node = func_graph->NewCNode(new_node_inputs);
|
||||
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
|
||||
auto old_node = node->cast<CNodePtr>();
|
||||
MS_CHECK_TRUE_RET(old_node->abstract() != nullptr, nullptr);
|
||||
new_node->set_abstract(old_node->abstract()->Clone());
|
||||
new_node->set_fullname_with_scope(node->fullname_with_scope() + "/encoder_layer");
|
||||
return new_node;
|
||||
}
|
||||
} // namespace mindspore::opt
|
|
@@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "ops/encoder_layer.h"
|
||||
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
|
||||
#include "ops/fusion/layer_norm_fusion.h"
|
||||
#include "ops/fusion/activation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class EncoderLayerFusion : public MultiplePatternProcessPass {
|
||||
public:
|
||||
explicit EncoderLayerFusion(const std::string &name = "EncoderLayerFusion", bool multigraph = true)
|
||||
: MultiplePatternProcessPass(name, multigraph) {}
|
||||
|
||||
~EncoderLayerFusion() override = default;
|
||||
|
||||
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
|
||||
const EquivPtr &) const override;
|
||||
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
|
||||
|
||||
protected:
|
||||
virtual bool Init() const;
|
||||
|
||||
private:
|
||||
const std::string kPatternEncoderLayerPost = "PatternTEncoderLayerPost";
|
||||
const std::string kPatternEncoderLayerPre = "PatternTEncoderLayerPre";
|
||||
const std::string kPatternEncoderLayerPostNorm = "PatternTEncoderLayerPostNorm";
|
||||
const std::string kPatternEncoderLayerPreNorm = "PatternTEncoderLayerPreNorm";
|
||||
const std::string kPatternEncoderLayerT5 = "PatternEncoderLayerT5";
|
||||
VectorRef DefinePatternEncoderLayer(bool post_layernorm, bool layernorm_fusion, bool is_position_bias_) const;
|
||||
VectorRef getTuple(bool post_layernorm, bool layernorm_fusion, bool is_position_bias) const;
|
||||
VectorRef DefineLayerNorm(bool is_position_bias, VectorRef input, VarPtr gamma, VarPtr beta) const;
|
||||
CNodePtr CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const AnfNodePtr &node, bool post_layernorm) const;
|
||||
AnfNodePtr GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv, VarPtr node_name) const;
|
||||
bool IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const VarPtr &input_prim) const;
|
||||
lite::STATUS CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num, int *head_size,
|
||||
float *eps1, float *eps2) const;
|
||||
std::shared_ptr<ops::EncoderLayer> CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
bool post_layernorm, int64_t ffn_hidden_size) const;
|
||||
|
||||
protected:
|
||||
mutable VarPtr input_{nullptr};
|
||||
mutable VarPtr position_bias_{nullptr};
|
||||
mutable VarPtr beta1_{nullptr};
|
||||
mutable VarPtr gamma1_{nullptr};
|
||||
mutable VarPtr beta2_{nullptr};
|
||||
mutable VarPtr gamma2_{nullptr};
|
||||
mutable VarPtr weight_attn_qkv_{nullptr};
|
||||
mutable VarPtr weight_attn_qkv_cross_{nullptr};
|
||||
mutable VarPtr weight_attn_o_{nullptr};
|
||||
mutable VarPtr weight_m_{nullptr};
|
||||
mutable VarPtr weight_p_{nullptr};
|
||||
mutable VarPtr bias_attn_qkv_{nullptr};
|
||||
mutable VarPtr bias_attn_o_{nullptr};
|
||||
mutable VarPtr bias_m_{nullptr};
|
||||
mutable VarPtr bias_p_{nullptr};
|
||||
mutable VarPtr mask_{nullptr};
|
||||
mutable VarPtr is_attention_{nullptr};
|
||||
mutable VarPtr is_layernorm1_{nullptr};
|
||||
mutable VarPtr is_layernorm2_{nullptr};
|
||||
mutable bool is_position_bias_{false};
|
||||
mutable VarPtr is_act_{nullptr};
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_
|
|
@@ -39,30 +39,34 @@ bool MultiHeadAttentionFusion::Init() const {
|
|||
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
|
||||
input_v_ = std::make_shared<Var>("input_v");
|
||||
MS_CHECK_TRUE_RET(input_v_ != nullptr, false);
|
||||
input1_ = std::make_shared<Var>("input1_");
|
||||
MS_CHECK_TRUE_RET(input1_ != nullptr, false);
|
||||
input2_ = std::make_shared<Var>("input2_");
|
||||
MS_CHECK_TRUE_RET(input2_ != nullptr, false);
|
||||
position_bias_ = std::make_shared<Var>("position_bias_");
|
||||
MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
|
||||
|
||||
weight_q_ = std::make_shared<CondVar>(IsParamNode, "weight_q");
|
||||
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
|
||||
weight_k_ = std::make_shared<CondVar>(IsParamNode, "weight_k");
|
||||
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
|
||||
weight_v_ = std::make_shared<CondVar>(IsParamNode, "weight_v");
|
||||
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
|
||||
weight_o_ = std::make_shared<CondVar>(IsParamNode);
|
||||
weight_o_ = std::make_shared<CondVar>(IsParamNode, "weight_o");
|
||||
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);
|
||||
|
||||
weight_o2_ = std::make_shared<CondVar>(IsParamNode, "weight_o2");
|
||||
MS_CHECK_TRUE_RET(weight_o2_ != nullptr, false);
|
||||
bias_q_ = std::make_shared<CondVar>(IsParamNode, "bias_q");
|
||||
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
|
||||
bias_k_ = std::make_shared<CondVar>(IsParamNode, "bias_k");
|
||||
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
|
||||
bias_v_ = std::make_shared<CondVar>(IsParamNode, "bias_v");
|
||||
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
|
||||
bias_o_ = std::make_shared<CondVar>(IsParamNode);
|
||||
bias_o_ = std::make_shared<CondVar>(IsParamNode, "bias_o");
|
||||
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);
|
||||
|
||||
bias_o2_ = std::make_shared<CondVar>(IsParamNode, "bias_o2");
|
||||
MS_CHECK_TRUE_RET(bias_o2_ != nullptr, false);
|
||||
mask_ = std::make_shared<Var>("mask");
|
||||
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
|
||||
|
||||
reshape_k_ = std::make_shared<Var>("reshape_k");
|
||||
MS_CHECK_TRUE_RET(reshape_k_ != nullptr, false);
|
||||
reshape_v_ = std::make_shared<Var>("reshape_v");
|
||||
|
@ -78,7 +82,7 @@ bool MultiHeadAttentionFusion::Init() const {
|
|||
|
||||
namespace {
|
||||
VectorRef DefineMask(const BaseRef &mask_input) {
|
||||
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims));
|
||||
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "m-expand");
|
||||
MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});
|
||||
auto var1 = std::make_shared<Var>("m-var1");
|
||||
MS_CHECK_TRUE_RET(var1 != nullptr, {});
|
||||
|
@ -94,6 +98,7 @@ VectorRef DefineMask(const BaseRef &mask_input) {
|
|||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
return VectorRef({is_mul, sub, var3});
|
||||
}
|
||||
|
||||
STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector<int> *result) {
|
||||
if (param_ptr == nullptr || !param_ptr->has_default()) {
|
||||
MS_LOG(DEBUG) << "param not have default";
|
||||
|
@ -139,29 +144,40 @@ STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
|
|||
|
||||
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
|
||||
const BaseRef &axis, const BaseRef &transpose_var, bool test_div,
|
||||
bool transpose) const {
|
||||
bool transpose, bool mul) const {
|
||||
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "e-matmul");
|
||||
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
|
||||
auto dense = VectorRef({is_matmul, input, weight, bias});
|
||||
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
|
||||
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
|
||||
auto reshape = VectorRef({is_reshape, dense, axis});
|
||||
auto var2 = std::make_shared<Var>();
|
||||
VectorRef conn;
|
||||
if (transpose) {
|
||||
conn = VectorRef({transpose_var, reshape, var2});
|
||||
auto var2 = std::make_shared<Var>("var2");
|
||||
MS_CHECK_TRUE_RET(var2 != nullptr, {});
|
||||
VectorRef conn1;
|
||||
if (mul) {
|
||||
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "e-mul");
|
||||
MS_CHECK_TRUE_RET(is_mul != nullptr, {});
|
||||
conn1 = VectorRef({is_mul, reshape, var2});
|
||||
} else {
|
||||
conn = reshape;
|
||||
conn1 = reshape;
|
||||
}
|
||||
auto var3 = std::make_shared<Var>("var3");
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
VectorRef conn2;
|
||||
if (transpose) {
|
||||
conn2 = VectorRef({transpose_var, conn1, var3});
|
||||
} else {
|
||||
conn2 = conn1;
|
||||
}
|
||||
if (test_div) {
|
||||
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
|
||||
MS_CHECK_TRUE_RET(is_div != nullptr, {});
|
||||
auto var3 = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
auto div = VectorRef({is_div, conn, var3});
|
||||
auto var4 = std::make_shared<Var>("var4");
|
||||
MS_CHECK_TRUE_RET(var4 != nullptr, {});
|
||||
auto div = VectorRef({is_div, conn2, var4});
|
||||
return div;
|
||||
}
|
||||
return conn;
|
||||
return conn2;
|
||||
}
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
|
||||
|
@ -172,7 +188,7 @@ VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const
|
|||
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
|
||||
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
|
||||
auto reshape = VectorRef({is_reshape, dense, axis});
|
||||
auto var2 = std::make_shared<Var>();
|
||||
auto var2 = std::make_shared<Var>("var2");
|
||||
VectorRef conn;
|
||||
if (transpose) {
|
||||
conn = VectorRef({transpose_var, reshape, var2});
|
||||
|
@ -182,7 +198,7 @@ VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const
|
|||
if (test_div) {
|
||||
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
|
||||
MS_CHECK_TRUE_RET(is_div != nullptr, {});
|
||||
auto var3 = std::make_shared<Var>();
|
||||
auto var3 = std::make_shared<Var>("var3");
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
auto div = VectorRef({is_div, conn, var3});
|
||||
return div;
|
||||
|
@ -196,20 +212,20 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool mask) const {
|
|||
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
|
||||
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
|
||||
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
|
||||
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
|
||||
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
|
||||
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape1");
|
||||
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
|
||||
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
|
||||
auto var1 = std::make_shared<Var>();
|
||||
auto var1 = std::make_shared<Var>("var1");
|
||||
MS_CHECK_TRUE_RET(var1 != nullptr, {});
|
||||
VectorRef reshape1;
|
||||
if (mask) {
|
||||
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
|
||||
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
|
||||
MS_CHECK_TRUE_RET(is_add != nullptr, {});
|
||||
auto mask = DefineMask(mask_);
|
||||
MS_CHECK_TRUE_RET(!mask.empty(), {});
|
||||
|
@@ -218,28 +234,28 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool mask) const {
} else {
reshape1 = VectorRef({is_reshape1, matmul1, var1});
}
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax));
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "is_softmax");
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
auto softmax = VectorRef({is_softmax, reshape1});
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>();
auto var2 = std::make_shared<Var>("var2");
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "is_transpose");
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
auto var3 = std::make_shared<Var>();
auto var3 = std::make_shared<Var>("var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto transpose = VectorRef({is_transpose, matmul2, var3});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshpae");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>();
auto var4 = std::make_shared<Var>("var4");
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, transpose, var4});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
@@ -290,18 +306,26 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5() const {
return matmul3;
}

VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose) const {
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose, bool no_div_flag) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
VectorRef q_embedding;
if (transpose) {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true);
if (no_div_flag) {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, false, true);
} else {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true);
}
} else {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, false);
}
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, true, true);
if (no_div_flag) {
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, false, true);
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, true, true);
}
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, reshape_v_, v_transpose_, false, true);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
@@ -323,7 +347,12 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose)
auto reshape1 = VectorRef({is_reshape1, add2, var1});
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "softmax");
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
auto softmax = VectorRef({is_softmax, reshape1});
VectorRef softmax, matmul2;
if (no_div_flag) {
softmax = VectorRef({is_softmax, add2});
} else {
softmax = VectorRef({is_softmax, reshape1});
}
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape2");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>("var2");
@@ -331,7 +360,11 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose)
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
if (no_div_flag) {
matmul2 = VectorRef({is_matmul2, softmax, v_embedding});
} else {
matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
}
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>("var4");
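The new no_div_flag argument selects between two T5-style shapes: when it is set, the Q/K embeddings are built without the RealDiv scaling step (test_div = false) and both the Softmax and the second MatMul consume the Add node directly instead of the intermediate Reshape. A condensed sketch of the two alternatives, using illustrative names rather than literal diff content:

// Condensed view of the branches above ("scores"/"context" are illustrative names only):
//   no_div_flag == false: scores = Softmax(reshape1);  context = MatMulFusion(reshape2, v_embedding)
//   no_div_flag == true : scores = Softmax(add2);      context = MatMulFusion(scores, v_embedding)
VectorRef scores = no_div_flag ? VectorRef({is_softmax, add2}) : VectorRef({is_softmax, reshape1});
VectorRef context = no_div_flag ? VectorRef({is_matmul2, scores, v_embedding})
                                : VectorRef({is_matmul2, reshape2, v_embedding});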
@@ -347,12 +380,12 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose)
} else {
reshape3 = VectorRef({is_reshape3, matmul2, var4});
}

auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_});
return matmul3;
}

VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
@@ -366,8 +399,6 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
MS_CHECK_TRUE_RET(is_add != nullptr, {});
auto mask = DefineMask(mask_);
@@ -395,6 +426,66 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
return matmul3;
}

VectorRef MultiHeadAttentionFusion::DefineMPPatternSwin(bool flag) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, false, true, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, false, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_add1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
auto add1 = VectorRef({is_add1, matmul1, var1});
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax));
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
VectorRef softmax;
if (flag) {
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto var2 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto reshape1 = VectorRef({is_reshape1, add1, var2});
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
auto add2 = VectorRef({is_add2, reshape1, var3});
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var4 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape2 = VectorRef({is_reshape2, add2, var4});
softmax = VectorRef({is_softmax, reshape2});
} else {
softmax = VectorRef({is_softmax, add1});
}
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, softmax, v_embedding});
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
auto var5 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var5 != nullptr, {});
auto transpose = VectorRef({is_transpose, matmul2, var5});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var6 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var6 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, transpose, var6});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
}

namespace {
template <typename T>
STATUS TransposeMatrix(std::shared_ptr<tensor::Tensor> src, std::shared_ptr<tensor::Tensor> dst) {
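DefineMPPatternSwin targets the Swin-style attention block: the first AddFusion after MatMul(Q, K) carries an extra additive term (in Swin this would be the relative position bias, an interpretation not stated in the diff), and when flag is true the pattern additionally expects a Reshape -> AddFusion -> Reshape chain (e.g. a window attention mask) before the Softmax; with flag false the Softmax consumes the biased scores directly. A comment-only summary of the chain built above, as a reading aid:

// Reading aid for DefineMPPatternSwin (names as defined above):
//   add1 = AddFusion(MatMulFusion(q_embedding, k_embedding), var1)
//   flag == true : softmax = Softmax(Reshape(AddFusion(Reshape(add1, var2), var3), var4))
//   flag == false: softmax = Softmax(add1)
//   output = MatMulFusion(Reshape(Transpose(MatMulFusion(softmax, v_embedding), var5), var6), weight_o_, bias_o_)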
@@ -442,7 +533,6 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
for (std::size_t i = 1; i < base_shape_size; ++i) {
new_shape.push_back(base_shape.at(i));
}

// calculate data
auto concat_tensor = std::make_shared<tensor::Tensor>(base_data_type, new_shape);
MS_CHECK_TRUE_RET(concat_tensor != nullptr, nullptr);
@@ -481,13 +571,15 @@ std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatte
MS_LOG(ERROR) << "initial member failed.";
return patterns;
}

patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern();
patterns[kMPAPatternName] = DefineMPWithMaskPattern(false);
patterns[kMPAWithMaskPatternNamePA] = DefineMPWithMaskPatternPA();
patterns[kMPAWithMaskPatternNameT5] = DefineMPWithMaskPatternT5();
patterns[kMPAWithMaskPatternNameT5New] = DefineMPWithMaskPatternT5New(false);
patterns[kMPAWithMaskPatternNameT5New2] = DefineMPWithMaskPatternT5New(true, true);
patterns[kMPAWithMaskTransposePatternNameT5New] = DefineMPWithMaskPatternT5New();
patterns[kMPAPatternNameSwin1] = DefineMPPatternSwin();
patterns[kMPAPatternNameSwin2] = DefineMPPatternSwin(false);
return patterns;
}
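Every Define* helper is registered under its own key, so the multiple-pattern pass tries each variant independently during matching. As a hypothetical sketch (placeholder identifiers, not commit content), adding a further variant would follow the same name-to-pattern mapping:

// Hypothetical example: kMPAPatternNameNewVariant and DefineNewVariantPattern are placeholders.
// The name constant would live in the header next to the existing kMPA* strings,
// and Process() would dispatch on it to CreateMaskedMultiHeadAttentionNode(...).
patterns[kMPAPatternNameNewVariant] = DefineNewVariantPattern();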
@@ -496,7 +588,6 @@ bool MultiHeadAttentionFusion::CheckPattern(const EquivPtr &equiv, int *head_num
MS_ASSERT(head_num != nullptr);
MS_ASSERT(head_size != nullptr);
std::vector<int> reshape_axes;
// UNDO !!!!!!!!!
if (GetAxis((*equiv)[reshape_axis_], &reshape_axes) != lite::RET_OK) {
MS_LOG(ERROR) << "cannot figure out reshape";
return false;
@@ -523,44 +614,20 @@ AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, co
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
++match_count_;
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA) ||
(pattern_name == kMPAWithMaskPatternNameT5) ||
(pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New)) {
if (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New) {
(pattern_name == kMPAWithMaskPatternNameT5) || (pattern_name == kMPAWithMaskPatternNameT5New) ||
(pattern_name == kMPAWithMaskTransposePatternNameT5New) || (pattern_name == kMPAWithMaskPatternNameT5New2)) {
if (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New ||
pattern_name == kMPAWithMaskPatternNameT5New2) {
t5_x_ = true;
}
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true);
}
if (pattern_name == kMPAPatternName)
if (pattern_name == kMPAPatternName || pattern_name == kMPAPatternNameSwin1 || pattern_name == kMPAPatternNameSwin2)
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false);
return nullptr;
}

// STATUS GetIntParameterData(const ParameterPtr &param_ptr, std::vector<int> *result) {
// if (param_ptr == nullptr || !param_ptr->has_default()) {
// MS_LOG(DEBUG) << "param not have default";
// return RET_ERROR;
// }
// auto default_param = param_ptr->default_param();
// if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
// MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
// return RET_ERROR;
// }
// auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
// if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
// MS_LOG(DEBUG) << "default param is not int";
// return RET_ERROR;
// }
// auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
// int64_t shape_size =
// std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
// for (int64_t i = 0; i < shape_size; i++) {
// result->emplace_back(ptr[i]);
// }
// return RET_OK;
// }

std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
MS_ASSERT(equiv != nullptr);
auto attention_prim = std::make_shared<ops::Attention>();
@@ -572,19 +639,16 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
MS_LOG(ERROR) << "Reshape k is not a parameter";
return nullptr;
}

if (!utils::isa<ParameterPtr>((*equiv)[reshape_v_])) {
MS_LOG(ERROR) << "Reshape v is not a parameter";
return nullptr;
}

auto reshape_k = utils::cast<ParameterPtr>((*equiv)[reshape_k_]);
std::vector<int> shape_k;
if (RET_OK != GetIntParameterData(reshape_k, &shape_k)) {
MS_LOG(ERROR) << "Get reshape k data failed";
return nullptr;
}

auto reshape_v = utils::cast<ParameterPtr>((*equiv)[reshape_v_]);
std::vector<int> shape_v;
if (RET_OK != GetIntParameterData(reshape_v, &shape_v)) {
@@ -694,7 +758,7 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::CreatePrim(const Equiv
if (!CheckPattern(equiv, &head_num, &head_size)) {
return nullptr;
}
attention_prim->Init(head_num, head_size, cross);
attention_prim->Init(head_num, head_size, t5_x_, cross);
return attention_prim;
}
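The extra argument threaded into Attention::Init is the t5_x_ member that Process sets whenever one of the T5 pattern names matched; presumably it tags the fused node as the T5-style (position-bias, unscaled) variant, an inference from the pattern names rather than something stated in the diff. Condensed flow, using only identifiers from the hunks above:

// In Process(): t5_x_ = true for kMPAWithMaskPatternNameT5New, ...T5New2 and the transpose variant.
// In CreatePrim(): the flag is forwarded to the fused primitive ahead of the cross flag.
attention_prim->Init(head_num, head_size, t5_x_, cross);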
@@ -726,7 +790,7 @@ bool MultiHeadAttentionFusion::IsCross(const EquivPtr &equiv) const {
ret = FetchShapeFromAbstract(input_v->abstract(), &inputv_shape);
MS_CHECK_TRUE_RET(ret == RET_OK, false);

if ((inputq_shape != inputv_shape) || ((match_count_ > 1) && (input_q != input_v))) {
if ((inputq_shape != inputv_shape) || (input_q != input_v)) {
return true;
}
return false;
@@ -795,9 +859,8 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
bool cross = IsCross(equiv);
std::vector<AnfNodePtr> redundant;
auto [weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor] = GetAttentionNodeWeights(equiv, &redundant);
AnfNodePtr bias_q;
AnfNodePtr bias_q, bias_o;
ParameterPtr c_bias_param;
AnfNodePtr bias_o;
if (!t5_x_) {
bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
@@ -824,9 +887,7 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
if (!cross && !t5_x_) {
redundant.push_back(bias_q);
}

tensor::TensorPtr c_weights;
tensor::TensorPtr q_weight_t;
tensor::TensorPtr c_weights, q_weight_t;
if (cross) {
c_weights = ConcatTensors({weight_k_tensor, weight_v_tensor}, true);
q_weight_t = ConcatTensors({weight_q_tensor}, true);
@@ -852,7 +913,6 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
}
std::vector<AnfNodePtr> new_node_inputs =
GetNewNodeInputs(equiv, q_weight_param, c_weight_param, weight_o, c_bias_param, bias_o, mask, cross);

auto new_node = func_graph->NewCNode(new_node_inputs);
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
if (vnode) {
@@ -50,9 +50,11 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
VectorRef DefineMPWithMaskPattern(bool mask = true) const;
VectorRef DefineMPWithMaskPatternPA() const;
VectorRef DefineMPWithMaskPatternT5() const;
VectorRef DefineMPWithMaskPatternT5New(bool transpose = true) const;
VectorRef DefineMPWithMaskPatternT5New(bool transpose = true, bool no_div_flag = false) const;
VectorRef DefineMPPatternSwin(bool flag = true) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div = false, bool transpose = true) const;
const BaseRef &transpose_var, bool test_div = false, bool transpose = true,
bool mul = false) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div, bool transpose) const;
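The class now declares two DefineEmbedding overloads: the original bias-taking form gains a trailing mul flag (used by the Swin pattern), and a bias-free form serves the T5 patterns. The mapping below pairs each overload with a call site that appears earlier in this diff; it is a summary, not new declarations:

// Bias-taking overload (..., bool test_div, bool transpose, bool mul), e.g. the Swin query embedding:
//   DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, false, true, true);
// Bias-free overload (..., bool test_div, bool transpose), e.g. the T5New query embedding:
//   DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true);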
@@ -85,21 +87,28 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
const std::string kMPAPatternName = "MPAPattern";
const std::string kMPAWithMaskPatternNameT5 = "MPAWithMaskPatternT5";
const std::string kMPAWithMaskPatternNameT5New = "MPAWithMaskPatternT5New";
const std::string kMPAWithMaskPatternNameT5New2 = "MPAWithMaskPatternT5New2";
const std::string kMPAWithMaskTransposePatternNameT5New = "MPAWithMaskTransposePatternT5New";
const std::string kMPAPatternNameSwin1 = "MPAPatternNameSwin1";
const std::string kMPAPatternNameSwin2 = "MPAPatternNameSwin2";

mutable VarPtr input_q_{nullptr};
mutable VarPtr input_k_{nullptr};
mutable VarPtr input_v_{nullptr};
mutable VarPtr input1_{nullptr};
mutable VarPtr input2_{nullptr};
mutable VarPtr position_bias_{nullptr};

mutable VarPtr weight_q_{nullptr};
mutable VarPtr weight_k_{nullptr};
mutable VarPtr weight_v_{nullptr};
mutable VarPtr weight_o_{nullptr};
mutable VarPtr weight_o2_{nullptr};
mutable VarPtr bias_q_{nullptr};
mutable VarPtr bias_k_{nullptr};
mutable VarPtr bias_v_{nullptr};
mutable VarPtr bias_o_{nullptr};
mutable VarPtr bias_o2_{nullptr};

mutable VarPtr mask_{nullptr};
File diff suppressed because it is too large