Adding encoder layer fusion

nizzan 2023-01-18 15:36:22 +02:00
parent f4752a8ab8
commit 72a341b321
32 changed files with 5460 additions and 1874 deletions

View File

@ -125,7 +125,7 @@ static BaseRef GetVar(const BaseRef &x) {
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
if (equiv == nullptr) MS_EXCEPTION_IF_NULL(equiv);
MS_EXCEPTION_IF_NULL(equiv);
if (utils::isa<VarPtr>(pattern)) {
VarPtr var = utils::cast<VarPtr>(pattern);
if (var->matches(expr)) {
@ -222,7 +222,9 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorR
return nullptr;
}
}
if ((values_expr_len == 0) && (values_pattern_len == 0)) return equiv;
if ((values_expr_len == 0) && (values_pattern_len == 0)) {
return equiv;
}
if (values_expr_len < values_pattern_len - 1) {
MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
return nullptr;

View File

@ -0,0 +1,36 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include "nnacl/infer/encoder_layer_infer.h"
#include "nnacl/infer/infer_register.h"
int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
int check_ret = CheckAugmentWithMinSize(inputs, inputs_size, outputs, outputs_size, parameter, C9NUM, C1NUM);
if (check_ret != NNACL_OK) {
return check_ret;
}
const TensorC *input = inputs[FIRST_INPUT];
TensorC *output0 = outputs[FIRST_INPUT];
SetDataTypeFormat(output0, input);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
SetShapeTensor(output0, input);
return NNACL_OK;
}
REG_INFER(EncoderLayer, PrimType_Inner_EncoderLayer, EncoderLayerInferShape)

View File

@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_
#include "nnacl/infer/common_infer.h"
#ifdef __cplusplus
extern "C" {
#endif
int EncoderLayerInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_INFER_ENCODER_LAYER_INFER_H_

View File

@ -31,6 +31,7 @@
#include "nnacl/infer/assign_add_infer.h"
#include "nnacl/infer/assign_infer.h"
#include "nnacl/infer/attention_infer.h"
#include "nnacl/infer/encoder_layer_infer.h"
#include "nnacl/infer/audio_spectrogram_infer.h"
#include "nnacl/infer/batch_to_space_infer.h"
#include "nnacl/infer/bias_grad_infer.h"
@ -402,6 +403,8 @@ void RegAllInferFunc5() {
g_inner_op_infer_func[PrimType_Inner_Identity - PrimType_InnerOpMin] = NULL;
#ifndef RUNTIME_PASS_CLIP
g_inner_op_infer_func[PrimType_Inner_ShapeFusion - PrimType_InnerOpMin] = ShapeFusionInferShape;
g_inner_op_infer_func[PrimType_Inner_EncoderLayer - PrimType_InnerOpMin] = EncoderLayerInferShape;
#endif
g_inner_op_infer_func[PrimType_Inner_ToFormat - PrimType_InnerOpMin] = NULL;
}

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_OP_BASE_H_
#define MINDSPORE_NNACL_OP_BASE_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_
#include <stdint.h>
#include <stdlib.h>
@ -37,8 +37,10 @@
#define C8NUM 8
#define C9NUM 9
#define C10NUM 10
#define C11NUM 11
#define C12NUM 12
#define C13NUM 13
#define C14NUM 14
#define C16NUM 16
#define C20NUM 20
#define C21NUM 21
@ -533,6 +535,7 @@ enum PrimType {
PrimType_Inner_ShapeFusion = 10003,
PrimType_Inner_GraphKernel = 10004,
PrimType_Inner_SplitReduceConcatFusion = 10005,
PrimType_Inner_EncoderLayer = 10006,
PrimType_InnerOpMax,
PrimType_InnerOpMin = PrimType_Inner_ToFormat
};
@ -660,4 +663,4 @@ typedef enum CalFixedMultiplierMode {
Method_DoublePrecision
} CalFixedMultiplierMode;
#endif // MINDSPORE_NNACL_OP_BASE_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_NNACL_OP_BASE_H_

View File

@ -31,6 +31,10 @@ void Attention::set_head_size(int64_t head_size) {
void Attention::set_cross(bool cross) { (void)this->AddAttr(kCross, api::MakeValue(cross)); }
void Attention::set_position_bias(bool position_bias) {
(void)this->AddAttr(kPositionBias, api::MakeValue(position_bias));
}
int64_t Attention::get_head_num() const {
auto value_ptr = this->GetAttr(kAttentionNumHeads);
return GetValue<int64_t>(value_ptr);
@ -46,10 +50,16 @@ bool Attention::get_cross() const {
return GetValue<bool>(value_ptr);
}
void Attention::Init(int64_t head_num, int64_t head_size, bool cross) {
bool Attention::get_position_bias() const {
auto value_ptr = this->GetAttr(kPositionBias);
return GetValue<bool>(value_ptr);
}
void Attention::Init(int64_t head_num, int64_t head_size, bool position_bias, bool cross) {
this->set_head_num(head_num);
this->set_head_size(head_size);
this->set_cross(cross);
this->set_position_bias(position_bias);
}
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
} // namespace mindspore::ops

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ATTENTION_H_
#define LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ATTENTION_H_
#ifndef MINDSPORE_CORE_OPS_ATTENTION_H_
#define MINDSPORE_CORE_OPS_ATTENTION_H_
#include <map>
#include <vector>
#include <string>
@ -40,14 +40,17 @@ class MIND_API Attention : public BaseOperator {
/// \param[in] head_num Define head number.
/// \param[in] head_size Define size per head.
/// \param[in] cross Define is cross attention. Default false.
void Init(int64_t head_num, int64_t head_size, bool cross = false);
/// \param[in] position_bias Define is position bias attention.
void Init(int64_t head_num, int64_t head_size, bool position_bias, bool cross = false);
void set_head_num(int64_t head_num);
void set_head_size(int64_t head_size);
void set_cross(bool cross);
void set_position_bias(bool position_bias);
int64_t get_head_num() const;
int64_t get_head_size() const;
bool get_cross() const;
bool get_position_bias() const;
};
} // namespace ops
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_TOOLS_CONVERTER_OPS_ATTENTION_H_
#endif // MINDSPORE_CORE_OPS_ATTENTION_H_

View File

@ -0,0 +1,92 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/encoder_layer.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore::ops {
MIND_API_OPERATOR_IMPL(EncoderLayer, BaseOperator);
void EncoderLayer::set_head_num(int64_t head_num) {
(void)this->AddAttr(kEncoderLayerNumHeads, api::MakeValue(head_num));
}
void EncoderLayer::set_head_size(int64_t head_size) {
(void)this->AddAttr(kEncoderLayerSizePerHead, api::MakeValue(head_size));
}
void EncoderLayer::set_post_layernorm(bool post_layernorm) {
(void)this->AddAttr(kEncoderLayerPostLayernorm, api::MakeValue(post_layernorm));
}
void EncoderLayer::set_eps_layernorm1(float eps_layernorm1) {
(void)this->AddAttr(kEncoderLayerEpsLayerNorm1, api::MakeValue(eps_layernorm1));
}
void EncoderLayer::set_eps_layernorm2(float eps_layernorm2) {
(void)this->AddAttr(kEncoderLayerEpsLayerNorm2, api::MakeValue(eps_layernorm2));
}
void EncoderLayer::set_ffn_hidden_size(int64_t ffn_hidden_size) {
(void)this->AddAttr(kEncoderLayerFfnHiddenSize, api::MakeValue(ffn_hidden_size));
}
void EncoderLayer::set_position_bias(bool position_bias) {
(void)this->AddAttr(kPositionBias, api::MakeValue(position_bias));
}
int64_t EncoderLayer::get_head_num() const {
auto value_ptr = this->GetAttr(kEncoderLayerNumHeads);
return GetValue<int64_t>(value_ptr);
}
int64_t EncoderLayer::get_head_size() const {
auto value_ptr = this->GetAttr(kEncoderLayerSizePerHead);
return GetValue<int64_t>(value_ptr);
}
bool EncoderLayer::get_post_layernorm() const {
auto value_ptr = this->GetAttr(kEncoderLayerPostLayernorm);
return GetValue<bool>(value_ptr);
}
float EncoderLayer::get_eps_layernorm1() const {
auto value_ptr = this->GetAttr(kEncoderLayerEpsLayerNorm1);
return GetValue<float>(value_ptr);
}
float EncoderLayer::get_eps_layernorm2() const {
auto value_ptr = this->GetAttr(kEncoderLayerEpsLayerNorm2);
return GetValue<float>(value_ptr);
}
int64_t EncoderLayer::get_ffn_hidden_size() const {
auto value_ptr = this->GetAttr(kEncoderLayerFfnHiddenSize);
return GetValue<int64_t>(value_ptr);
}
bool EncoderLayer::get_position_bias() const {
auto value_ptr = this->GetAttr(kPositionBias);
return GetValue<bool>(value_ptr);
}
void EncoderLayer::Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2,
int64_t ffn_hidden_size, bool position_bias, bool post_layernorm = false) {
this->set_head_num(head_num);
this->set_head_size(head_size);
this->set_post_layernorm(post_layernorm);
this->set_eps_layernorm1(eps_layernorm1);
this->set_eps_layernorm2(eps_layernorm2);
this->set_ffn_hidden_size(ffn_hidden_size);
this->set_position_bias(position_bias);
}
REGISTER_PRIMITIVE_C(kNameEncoderLayer, EncoderLayer);
} // namespace mindspore::ops

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
#define MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameEncoderLayer = "EncoderLayer";
/// \brief EncoderLayer op in MindIR.
class MIND_API EncoderLayer : public BaseOperator {
public:
MIND_API_BASE_MEMBER(EncoderLayer);
/// \brief Constructor.
EncoderLayer() : BaseOperator(kNameEncoderLayer) {
InitIOName({"input", "gamma1", "beta1", "weight_attn_qkv", "bias_attn_qkv", "mask", "weight_attn_o", "bias_attn_o",
"gamma2", "beta2", "weight_m", "bias_m", "weight_p", "bias_p"},
{"output"});
}
/// \brief Initialize EncoderLayer op.
/// \param[in] head_num Define head number.
/// \param[in] head_size Define size per head.
/// \param[in] eps_layernorm1 Define eps layernorm1.
/// \param[in] eps_layernorm2 Define eps layernorm2.
/// \param[in] ffn_hidden_size Define ffn hidden size.
/// \param[in] position_bias Define ffn position_bias.
void Init(int64_t head_num, int64_t head_size, float eps_layernorm1, float eps_layernorm2, int64_t ffn_hidden_size,
bool position_bias, bool post_layernorm);
void set_head_num(int64_t head_num);
void set_head_size(int64_t head_size);
void set_post_layernorm(bool post_layernorm);
void set_eps_layernorm1(float eps_layernorm1);
void set_eps_layernorm2(float eps_layernorm2);
void set_ffn_hidden_size(int64_t ffn_hidden_size);
void set_position_bias(bool position_bias);
int64_t get_head_num() const;
int64_t get_head_size() const;
bool get_post_layernorm() const;
float get_eps_layernorm1() const;
float get_eps_layernorm2() const;
int64_t get_ffn_hidden_size() const;
bool get_position_bias() const;
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_ENCODER_LAYER_H_
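
The doc comments above list the fusion attributes the new EncoderLayer primitive carries. As a quick illustration of how it is configured through Init, here is a minimal sketch written against the ops/encoder_layer.h added in this commit; the numeric values are placeholders, not defaults from the source:

#include <cstdint>
#include "ops/encoder_layer.h"

// Sketch only: exercises the setters/getters declared above with illustrative values.
void BuildEncoderLayerExample() {
  mindspore::ops::EncoderLayer encoder;
  encoder.Init(/*head_num=*/8, /*head_size=*/64,
               /*eps_layernorm1=*/1e-5f, /*eps_layernorm2=*/1e-5f,
               /*ffn_hidden_size=*/2048,
               /*position_bias=*/false,   // T5-style relative position bias disabled
               /*post_layernorm=*/false); // pre-layernorm variant
  // Downstream code derives hidden_size as head_num * head_size (see encoder_tensorrt.cc).
  int64_t hidden_size = encoder.get_head_num() * encoder.get_head_size();
  (void)hidden_size;
}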

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_OP_NAME_H
#define MINDSPORE_CORE_OPS_OP_NAME_H
#ifndef MINDSPORE_CORE_OPS_OP_NAME_H_
#define MINDSPORE_CORE_OPS_OP_NAME_H_
#include <cstddef>
namespace mindspore::ops {
@ -378,6 +378,13 @@ constexpr auto kSampleNum = "sample_num";
constexpr auto kRoiEndMode = "roi_end_mode";
constexpr auto kUpper = "upper";
constexpr auto kConjugate = "conjugate";
constexpr auto kEncoderLayerNumHeads = "head_num";
constexpr auto kEncoderLayerSizePerHead = "head_size";
constexpr auto kEncoderLayerPostLayernorm = "post_layernorm";
constexpr auto kEncoderLayerFfnHiddenSize = "ffn_hidden_size";
constexpr auto kEncoderLayerEpsLayerNorm1 = "eps_layernorm1";
constexpr auto kEncoderLayerEpsLayerNorm2 = "eps_layernorm2";
constexpr auto kPositionBias = "position_bias";
constexpr auto KExclusive = "exclusive";
constexpr auto KReverse = "reverse";
constexpr auto KComputeEigenvectors = "compute_eigenvectors";
@ -414,4 +421,4 @@ constexpr size_t kFormatNC1HWC0IndexW = 3;
constexpr size_t kFormatNC1HWC0IndexC0 = 4;
enum Dims : size_t { kDim0 = 0, kDim1, kDim2, kDim3, kDim4, kDim5, kDim6, kDim7, kDim8 };
} // namespace mindspore::ops
#endif // MINDSPORE_CORE_OPS_OP_NAME_H
#endif // MINDSPORE_CORE_OPS_OP_NAME_H_

View File

@ -83,6 +83,9 @@ class MS_API Converter {
void SetNoFusion(bool no_fusion);
bool GetNoFusion();
void SetOptimizeTransformer(bool optimize_transformer);
bool GetOptimizeTransformer();
inline void SetDevice(const std::string &device);
inline std::string GetDevice();

View File

@ -97,6 +97,16 @@ OpParameter *PopulateCustomParameter(const void *prim) {
param->op_parameter_.type_ = PrimType_Inner_SplitReduceConcatFusion;
return reinterpret_cast<OpParameter *>(param);
} else if (type == "EncoderLayer") {
std::cout << "EncoderLayer populate" << std::endl;
auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc EncoderLayer failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
param->type_ = PrimType_Inner_EncoderLayer;
return reinterpret_cast<OpParameter *>(param);
} else {
MS_LOG(ERROR) << "Unsupported custom type: " << type;
}

View File

@ -57,10 +57,10 @@ class DelegateRegistrar {
~DelegateRegistrar() = default;
};
#define REG_DELEGATE(device_type, provider, creator) \
static DelegateCreator func = [=](const std::shared_ptr<Context> &context, const ConfigInfos &config_infos) { \
return creator(context, config_infos); \
}; \
#define REG_DELEGATE(device_type, provider, creator) \
static DelegateCreator func = [](const std::shared_ptr<Context> &context, const ConfigInfos &config_infos) { \
return creator(context, config_infos); \
}; \
static DelegateRegistrar g_##device_type##provider##Delegate(device_type, provider, &func);
} // namespace mindspore

View File

@ -0,0 +1,255 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/delegate/tensorrt/op/encoder_tensorrt.h"
#include <cuda_runtime.h>
#include <numeric>
#include <memory>
#include <vector>
#include <functional>
#include <unordered_map>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "NvInferRuntimeCommon.h"
#include "ops/encoder_layer.h"
#include "src/fastertransformer/kernels/unfused_attention_kernels.h"
#include "src/fastertransformer/kernels/activation_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/allocator.h"
#include "src/fastertransformer/kernels/layernorm_kernels.h"
namespace mindspore::lite {
namespace {
constexpr std::size_t kTwo = 2;
constexpr std::size_t kThree = 3;
} // namespace
// Encoder layer TensorRT op
int EncoderTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() != C14NUM) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
nvinfer1::ITensor *EncoderTensorRT::castTensor(TensorRTContext *ctx, const TensorInfo &ms_tensor,
const std::string &op_name) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is null for ConvertConstantTensor";
return nullptr;
}
nvinfer1::Dims dims = ConvertCudaDims(ms_tensor.Shape());
if (dims.nbDims == -1) {
MS_LOG(INFO) << ms_tensor.Name() << " ConvertCudaDims failed, convert as scalar.";
dims.nbDims = 1;
dims.d[0] = 1;
}
nvinfer1::DataType data_type = ConvertDataType(ms_tensor.DataType());
if (!ms_tensor.IsConst()) {
MS_LOG(ERROR) << "ConvertConstantTensor from a MSTensor with nullptr data: " << ms_tensor.Name();
return nullptr;
}
nvinfer1::Weights weights{data_type, ms_tensor.Data(), ms_tensor.ElementNum()};
if (data_type == nvinfer1::DataType::kFLOAT && is_ffn_fp16_) {
void *data_float16 = malloc(ms_tensor.ElementNum() * sizeof(float));
if (data_float16 == nullptr) {
MS_LOG(ERROR) << "Malloc buffer failed.";
return nullptr;
}
auto src = static_cast<const float *>(ms_tensor.Data());
auto dst = static_cast<half *>(data_float16);
for (int i = 0; i < ms_tensor.ElementNum(); i++) {
dst[i] = static_cast<half>(src[i]);
}
weights.values = data_float16;
}
nvinfer1::IConstantLayer *constant_tensor = ctx->network()->addConstant(dims, weights);
if (constant_tensor == nullptr) {
MS_LOG(ERROR) << "create constant_tensor failed.";
return nullptr;
}
ctx->RegisterLayer(constant_tensor, ms_tensor.Name() + "_" + op_name);
auto tensor_ptr = constant_tensor->getOutput(0);
return tensor_ptr;
}
int EncoderTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
auto encoder_op = AsOps<ops::EncoderLayer>();
if (encoder_op == nullptr) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
fastertransformer::encoderParamT params;
memset_s(&params, sizeof(params), 0, sizeof(params));
params.head_num = encoder_op->get_head_num();
params.head_size = encoder_op->get_head_size();
params.layernorm_post = encoder_op->get_post_layernorm();
params.eps1 = encoder_op->get_eps_layernorm1();
params.eps2 = encoder_op->get_eps_layernorm2();
params.ffn_hidden_size = encoder_op->get_ffn_hidden_size();
params.is_cross = false;
params.ffn_fp16 = is_ffn_fp16_;
params.position_bias = encoder_op->get_position_bias();
params.cublas_handle = GetCublasHandle();
params.qkv_bias = !params.position_bias;
params.projection_bias = !params.position_bias;
params.hidden_size = params.head_num * params.head_size;
auto compute_type = runtime_->GetRuntimePrecisionMode();
if (is_ffn_fp16_) {
size_t start_fp16 = (params.layernorm_post) ? C7NUM : C9NUM;
size_t end_fp16 = (params.layernorm_post) ? C11NUM : C13NUM;
for (size_t i = 0; i < in_tensors_.size(); i++) {
auto in_tensor = input(ctx, i);
if (in_tensors_[i].IsConst() || in_tensor.trt_tensor_ == nullptr) {
if (i > start_fp16 && i < end_fp16) {
in_tensor.trt_tensor_ = castTensor(ctx, in_tensors_[i], op_name_);
ctx->RegisterTensor(in_tensor, in_tensors_[i].Name());
} else {
in_tensor.trt_tensor_ = lite::ConvertConstantTensor(ctx, in_tensors_[i], op_name_);
ctx->RegisterTensor(in_tensor, in_tensors_[i].Name());
}
}
}
}
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
auto plugin =
std::make_shared<EncoderPlugin>(input_tensor->getName(), compute_type, params, GetCublasLtHandle(), device_id_);
const int input_number = inputs().size();
nvinfer1::ITensor *inputTensors[input_number];
for (int i = 0; i < input_number; i++) {
inputTensors[i] = input(ctx, i).trt_tensor_;
}
nvinfer1::IPluginV2Layer *encoder_layer = ctx->network()->addPluginV2(inputTensors, input_number, *plugin);
if (encoder_layer == nullptr) {
MS_LOG(ERROR) << "add encoder op failed for TensorRT.";
return RET_ERROR;
}
encoder_layer->setName((op_name_ + "plugin_encoder_layer").c_str());
nvinfer1::ITensor *encoder_tensor = encoder_layer->getOutput(0);
ctx->RegisterTensor(ITensorHelper{encoder_tensor, Format::NCHW, true}, out_tensors_[0].Name());
this->layer_ = encoder_layer;
return RET_OK;
}
REGISTER_TENSORRT_PLUGIN(EncoderPluginCreater);
template class TensorRTPluginCreater<EncoderPlugin>;
template <class T>
nvinfer1::PluginFieldCollection TensorRTPluginCreater<T>::field_collection_{};
template <class T>
std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
int EncoderPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace,
cudaStream_t stream) noexcept {
if (compute_type_ == RuntimePrecisionMode_FP16) {
return RunCudaEncoder<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
return RunCudaEncoder<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
}
template <typename T>
int EncoderPlugin::RunCudaEncoder(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream, cublasGemmAlgo_t algoId) {
params_.stream = stream;
params_.algo = algoId;
void *inputs_forward[] = {
const_cast<void *>(inputs[0]), const_cast<void *>(inputs[1]), const_cast<void *>(inputs[2]),
const_cast<void *>(inputs[3]), const_cast<void *>(inputs[4]), const_cast<void *>(inputs[5]),
const_cast<void *>(inputs[6]), const_cast<void *>(inputs[7]), const_cast<void *>(inputs[8]),
const_cast<void *>(inputs[9]), const_cast<void *>(inputs[10]), const_cast<void *>(inputs[11]),
const_cast<void *>(inputs[12]), const_cast<void *>(inputs[13])};
void *outputs_forward[] = {outputs[0]};
fastertransformer::forwardEncoder<T>(inputs_forward, num_of_inputs_, outputs_forward, num_of_outputs_, &params_,
workspace);
return RET_OK;
}
bool EncoderPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept {
auto type = (compute_type_ == RuntimePrecisionMode_FP16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
for (int i = 0; i < pos; i++) {
if (tensorsDesc[pos].type != tensorsDesc[i].type) return false;
}
bool res = (tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) && (tensorsDesc[pos].type == type);
return res;
}
void EncoderPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {
const int request_batch_size = static_cast<const int>(in[0].desc.dims.d[0]);
const int request_src_seq_len = static_cast<const int>(in[0].desc.dims.d[1]);
const int request_tgt_seq_len = request_src_seq_len;
params_.batch_size = request_batch_size;
params_.src_seq_len = request_src_seq_len;
params_.tgt_seq_len = request_tgt_seq_len;
num_of_inputs_ = nbInputs;
num_of_outputs_ = nbOutputs;
}
size_t EncoderPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
if (compute_type_ == RuntimePrecisionMode_FP16) {
return fastertransformer::GetEncoderLayerWorkspaceSize<half>(&params_);
} else {
return fastertransformer::GetEncoderLayerWorkspaceSize<float>(&params_);
}
}
nvinfer1::DimsExprs EncoderPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs,
int nbInputDims, nvinfer1::IExprBuilder &exprBuilder) noexcept {
nvinfer1::DimsExprs dims;
if (index == 0) {
int num_dims = inputs[0].nbDims;
dims.nbDims = num_dims;
if (num_dims == INPUT_SIZE2) {
dims.d[0] = exprBuilder.constant(inputs[0].d[0]->getConstantValue());
dims.d[1] = exprBuilder.constant(inputs[0].d[1]->getConstantValue());
} else if (num_dims == INPUT_SIZE3) {
dims.d[0] = exprBuilder.constant(inputs[0].d[0]->getConstantValue());
dims.d[1] = exprBuilder.constant(inputs[0].d[1]->getConstantValue());
dims.d[kTwo] = exprBuilder.constant(inputs[0].d[kTwo]->getConstantValue());
}
}
return dims;
}
nvinfer1::IPluginV2DynamicExt *EncoderPlugin::clone() const noexcept {
auto *plugin = new EncoderPlugin(*this);
if (plugin == nullptr) {
MS_LOG(ERROR) << "plugin is null";
return nullptr;
}
plugin->setPluginNamespace(name_space_.c_str());
return plugin;
}
size_t EncoderPlugin::getSerializationSize() const noexcept {
return sizeof(int) + sizeof(fastertransformer::encoderParamT);
}
void EncoderPlugin::serialize(void *buffer) const noexcept {
SerializeValue(&buffer, &compute_type_, sizeof(int));
SerializeValue(&buffer, &params_, sizeof(fastertransformer::encoderParamT));
}
REGISTER_TENSORRT_CREATOR(ops::kNameEncoderLayer, EncoderTensorRT)
} // namespace mindspore::lite
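
castTensor above repacks constant FP32 weights as FP16 when is_ffn_fp16_ is set. The core of that step is just an element-wise float-to-half copy; the following standalone sketch shows the same conversion with plain cuda_fp16.h (hypothetical helper name, buffer lifetime left to the caller):

#include <cuda_fp16.h>
#include <cstdint>
#include <cstdlib>

// Sketch of the FP32 -> FP16 repack done inside castTensor. The original keeps the
// float-sized allocation, which is larger than the half data strictly needs.
void *PackFloatAsHalf(const float *src, int64_t elem_num) {
  void *buf = malloc(elem_num * sizeof(float));
  if (buf == nullptr) {
    return nullptr;  // mirrors the "Malloc buffer failed" path above
  }
  auto *dst = static_cast<__half *>(buf);
  for (int64_t i = 0; i < elem_num; ++i) {
    dst[i] = static_cast<__half>(src[i]);  // narrow each element to half precision
  }
  return buf;  // handed to nvinfer1::Weights::values in the original code
}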

View File

@ -0,0 +1,106 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ENCODER_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ENCODER_TENSORRT_H_
#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h"
#include "src/fastertransformer/layers/encoder_layers/encoder.h"
namespace mindspore::lite {
class EncoderTensorRT : public TensorRTOp {
public:
EncoderTensorRT(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors, std::string name)
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}
~EncoderTensorRT() override = default;
bool IsWeightInputHanledInner() const override { return is_ffn_fp16_; }
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) override;
private:
nvinfer1::ITensor *castTensor(TensorRTContext *ctx, const TensorInfo &ms_tensor, const std::string &op_name);
bool is_ffn_fp16_ = false;
};
constexpr auto ENCODER_PLUGIN_NAME{"EncoderPlugin"};
class EncoderPlugin : public TensorRTPlugin {
public:
EncoderPlugin(const std::string name, int compute_type, fastertransformer::encoderParamT params,
cublasLtHandle_t cublaslt_handle, uint32_t device_id)
: TensorRTPlugin(name, std::string(ENCODER_PLUGIN_NAME), device_id),
compute_type_(compute_type),
params_(params),
cublaslt_handle_(cublaslt_handle) {}
EncoderPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
: TensorRTPlugin(std::string(name), std::string(ENCODER_PLUGIN_NAME)) {
const nvinfer1::PluginField *fields = fc->fields;
compute_type_ = static_cast<const int *>(fields[0].data)[0];
params_ = static_cast<const fastertransformer::encoderParamT *>(fields[1].data)[0];
cublaslt_handle_ = static_cast<const cublasLtHandle_t *>(fields[2].data)[0];
}
EncoderPlugin(const char *name, const void *serialData, size_t serialLength)
: TensorRTPlugin(std::string(name), std::string(ENCODER_PLUGIN_NAME)) {
DeserializeValue(&serialData, &serialLength, &compute_type_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &params_, sizeof(fastertransformer::encoderParamT));
}
EncoderPlugin() = delete;
~EncoderPlugin() override {}
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept override;
private:
const std::string layer_name_;
std::string name_space_;
int compute_type_;
mutable fastertransformer::encoderParamT params_;
cublasLtHandle_t cublaslt_handle_;
int num_of_inputs_;
int num_of_outputs_;
template <typename T>
int RunCudaEncoder(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
cublasGemmAlgo_t algoId);
};
class EncoderPluginCreater : public TensorRTPluginCreater<EncoderPlugin> {
public:
EncoderPluginCreater() : TensorRTPluginCreater(std::string(ENCODER_PLUGIN_NAME)) {}
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_ENCODER_TENSORRT_H_

View File

@ -45,27 +45,11 @@ std::ostream &operator<<(std::ostream &s, const nvinfer1::ITensor &t) {
}
} // namespace
#define SET_GEMM_PARAMS(gemm_ops_, gemm_lds_, gemm_op1_, gemm_op2_, gemm_ld1_, gemm_ld2_, gemm_ld3_) \
do { \
gemm_ops_[0] = gemm_op1_; \
gemm_ops_[1] = gemm_op2_; \
gemm_lds_[0] = gemm_ld1_; \
gemm_lds_[1] = gemm_ld2_; \
gemm_lds_[2] = gemm_ld3_; \
} while (0)
#define SET_GEMM_DIMS(gemm_dims_, gemm_dim1_, gemm_dim2_, gemm_dim3_) \
do { \
gemm_dims_[0] = gemm_dim1_; \
gemm_dims_[1] = gemm_dim2_; \
gemm_dims_[2] = gemm_dim3_; \
} while (0)
// Multi Head Attention TensorRT op
int MhaTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() < 7 || in_tensors.size() > 9) { // T5 has 6 or 7 inputs, other models have 8 or 9 inputs
MS_LOG(ERROR) << "Unsupported number of inputs, size is " << in_tensors.size();
if (in_tensors.size() < C7NUM || in_tensors.size() > C9NUM) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
@ -81,16 +65,25 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
// get attribute for Attn op - TODO - add attribute in op
int head_number = mha_op->get_head_num();
int head_size = mha_op->get_head_size();
auto compute_type = runtime_->GetRuntimePrecisionMode(); // mha_op->get_compute_type();
auto compute_type = runtime_->GetRuntimePrecisionMode();
bool is_cross = mha_op->get_cross();
const int input_number = inputs().size();
bool is_position_bias = (((input_number == 8) && is_cross) || ((input_number == 7) && !is_cross)) ? true : false;
bool is_position_bias = mha_op->get_position_bias();
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
is_position_bias, GetCublasHandle(), GetCublasLtHandle(), device_id_);
fastertransformer::encoderParamT params;
memset_s(&params, sizeof(params), 0, sizeof(params));
params.head_num = head_number;
params.head_size = head_size;
params.hidden_size = head_number * head_size;
params.cublas_handle = GetCublasHandle();
params.qkv_bias = !is_position_bias;
params.projection_bias = !is_position_bias;
params.is_cross = is_cross;
params.position_bias = is_position_bias;
auto plugin =
std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, params, GetCublasLtHandle(), device_id_);
const int input_number = inputs().size();
nvinfer1::ITensor *inputTensors[input_number];
for (int i = 0; i < input_number; i++) {
inputTensors[i] = input(ctx, i).trt_tensor_;
@ -101,17 +94,12 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
return RET_ERROR;
}
mha_layer->setName((op_name_ + "plugin_attention").c_str());
// TODO(haim) one output
nvinfer1::ITensor *attn_tensor = mha_layer->getOutput(0);
#ifndef TEST_
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name());
#else /* TEST_ */
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name() + "attn");
#endif /* TEST_ */
// nvinfer1::ITensor *key_tensor = mha_layer->getOutput(1);
// ctx->RegisterTensor(ITensorHelper{key_tensor, Format::NCHW, true}, out_tensors_[1].Name());
// nvinfer1::ITensor *value_tensor = mha_layer->getOutput(kTwo);
// ctx->RegisterTensor(ITensorHelper{value_tensor, Format::NCHW, true}, out_tensors_[kTwo].Name());
this->layer_ = mha_layer;
#ifdef TEST_
auto weight_projection = input(ctx, 4).trt_tensor_;
@ -154,166 +142,51 @@ std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
int MhaPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
if (compute_type_ == RuntimePrecisionMode_FP16) {
return RunCudaMha<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
return RunCudaMha<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
return RunCudaMha<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
}
}
template <typename T>
void MhaPlugin::SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
size_t extra_size) {
size_t qkv_len = size_q + (size_k * 2); // size_v is equal to size_k
size_t q_buf_2_len = size_q;
auto buff_size =
qkv_len + q_buf_2_len + qk_buf_len + (qkv_buf_2_len * 2); // qkv_buf_3_ len is equal to qkv_buf_2_len
qkv_buf_ = workspace;
q_buf_2_ = static_cast<T *>(qkv_buf_) + qkv_len;
qk_buf_ = static_cast<T *>(q_buf_2_) + q_buf_2_len;
qkv_buf_2_ = static_cast<T *>(qk_buf_) + qk_buf_len;
qkv_buf_3_ = static_cast<T *>(qkv_buf_2_) + qkv_buf_2_len;
output1_ = static_cast<T *>(workspace) + buff_size;
output2_ = static_cast<T *>(output1_) + extra_size;
}
template <typename T>
void MhaPlugin::RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha,
void *beta, cublasGemmAlgo_t algoId, cudaStream_t stream) {
int cross_tensor_offset = 0;
if (is_cross_) cross_tensor_offset = 1;
const int from_tensor_idx = 0, encoder_tensor_idx = 1, weight_qkv_tensor_idx = 3;
const int weight_qkv_tensor_idx_cross = 3 + cross_tensor_offset;
const int bias_qkv_tensor_idx = 5 + cross_tensor_offset;
const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;
auto from_tensor = static_cast<const T *>(inputs[from_tensor_idx]);
auto encoder_output_tensor = static_cast<const T *>(inputs[encoder_tensor_idx]);
auto weight_q = static_cast<const T *>(inputs[weight_qkv_tensor_idx]);
auto weight_kv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
auto weight_qkv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
auto bias_qkv = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_qkv_tensor_idx]);
auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
if (is_cross_) {
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
CublasGemmWrapper(weight_q, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_);
SET_GEMM_DIMS(gemm_dims, C2NUM * hidden_size, request_batch_size * request_tgt_seq_len, hidden_size);
gemm_lds[0] = gemm_lds[THIRD_INPUT] = C2NUM * hidden_size;
CublasGemmWrapper(weight_kv, encoder_output_tensor,
static_cast<T *>(qkv_buf_) + (request_batch_size * request_src_seq_len) * hidden_size, gemm_dims,
gemm_lds, gemm_ops, const_cast<const cudaDataType *>(gemm_data_types), alpha, beta,
cublas_handle_);
fastertransformer::invokeCrossAddFusedQKVBiasTranspose(
static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
bias_qkv, request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_, head_size_, stream);
} else {
CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_, algoId);
fastertransformer::invokeAddFusedQKVBiasTranspose(
static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
bias_qkv, request_batch_size, request_src_seq_len, head_number_, head_size_, 0, stream);
return RunCudaMha<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
}
template <typename T>
int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
cublasGemmAlgo_t *algoId) {
// inputs order:
// 0] Q
// 1] K
// 2] V
// 3] W
// 4] PW
// 5] B
// 6] PB
// 7] AttnMask
// inputs order cross:
// 0] Q
// 1] K enco output
// 2] V
// 3] Wq
// 4] Wkv
// 5] PW
// 6] Bqkv
// 7] PB
// 8] AttnMask
int cross_tensor_offset = 0;
cublasSetStream(cublas_handle_, stream);
if (is_cross_) cross_tensor_offset = 1;
cublasGemmAlgo_t algoId) {
int cross_tensor_offset = (params_.is_cross) ? 1 : 0;
const int weight_projection_tensor_idx = 4 + cross_tensor_offset;
const int bias_projection_tensor_idx = 6 + cross_tensor_offset;
const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;
const int bias_position_tensor_idx = 5 + cross_tensor_offset;
auto attention_mask = static_cast<const T *>(inputs[attn_mask_tensor_idx]);
auto weight_projection = static_cast<const T *>(inputs[weight_projection_tensor_idx]);
auto bias_projection = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_projection_tensor_idx]);
auto bias_position = (is_position_bias_) ? static_cast<const T *>(inputs[bias_position_tensor_idx]) : nullptr;
auto output0 = static_cast<T *>(outputs[0]);
auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
auto extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
SetInnerAddr<T>(workspace, size_q, size_k, qk_buf_len, qkv_buf_2_len, extra_tmp_size);
cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N};
cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
if constexpr (std::is_same<T, half>::value)
std::fill(std::begin(gemm_data_types), std::end(gemm_data_types), CUDA_R_16F);
float alpha = 1.0f, beta = 0.0f;
int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
int gemm_lds[] = {3 * hidden_size, hidden_size, 3 * hidden_size};
RunPhase1GEMM<T>(inputDesc, inputs, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta, algoId[0], stream);
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_T, CUBLAS_OP_N, head_size_, head_size_, request_tgt_seq_len);
SET_GEMM_DIMS(gemm_dims, request_tgt_seq_len, request_src_seq_len, head_size_);
int gemm_strides[] = {request_tgt_seq_len * head_size_, request_src_seq_len * head_size_,
request_src_seq_len * request_tgt_seq_len};
CublasGemmStridedBatchedWrapper(output1_, q_buf_2_, qk_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
request_batch_size * head_number_, cublas_handle_, algoId[1]);
T scalar = static_cast<T>(1.0f / sqrtf(head_size_ * 1.0f));
fastertransformer::invokeMixMaskedSoftMax(static_cast<T *>(qk_buf_), attention_mask, bias_position,
request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_,
scalar, stream);
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, head_size_, request_tgt_seq_len, head_size_);
SET_GEMM_DIMS(gemm_dims, head_size_, request_src_seq_len, request_tgt_seq_len);
gemm_strides[1] = request_src_seq_len * request_tgt_seq_len;
gemm_strides[THIRD_INPUT] = request_src_seq_len * head_size_;
CublasGemmStridedBatchedWrapper(output2_, qk_buf_, qkv_buf_2_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
request_batch_size * head_number_, cublas_handle_, algoId[2]);
fastertransformer::invokeTransposeQKV(static_cast<T *>(qkv_buf_3_), static_cast<T *>(qkv_buf_2_), request_batch_size,
request_src_seq_len, head_number_, head_size_, stream);
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops,
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta, cublas_handle_, algoId[3]);
if (!is_position_bias_) {
int len = request_batch_size * request_src_seq_len;
fastertransformer::invokeAddBias(reinterpret_cast<T *>(output0), reinterpret_cast<const T *>(bias_projection), len,
hidden_size, stream);
const int attn_mask_tensor_idx = 7 + cross_tensor_offset;
const int bias_qkv_tensor_idx = 5 + cross_tensor_offset;
const int weight_qkv_tensor_idx = 3;
const int position_bias_tensor_idx = 6 + cross_tensor_offset;
params_.stream = stream;
params_.algo = algoId;
void *inputs_attn[num_of_inputs_];
int index = 0;
inputs_attn[index++] = const_cast<void *>(inputs[0]);
if (params_.is_cross) {
inputs_attn[index++] = const_cast<void *>(inputs[1]);
inputs_attn[index++] = const_cast<void *>(inputs[weight_qkv_tensor_idx]);
inputs_attn[index++] = const_cast<void *>(inputs[weight_qkv_tensor_idx + 1]);
} else {
inputs_attn[index++] = const_cast<void *>(inputs[weight_qkv_tensor_idx]);
}
if (params_.qkv_bias) {
inputs_attn[index++] = const_cast<void *>(inputs[bias_qkv_tensor_idx]);
}
if (params_.position_bias) {
inputs_attn[index++] = const_cast<void *>(inputs[position_bias_tensor_idx]);
inputs_attn[index++] = const_cast<void *>(inputs[attn_mask_tensor_idx - C2NUM]);
} else {
inputs_attn[index++] = const_cast<void *>(inputs[attn_mask_tensor_idx]);
}
inputs_attn[index++] = const_cast<void *>(inputs[weight_projection_tensor_idx]);
if (params_.projection_bias) {
inputs_attn[index++] = const_cast<void *>(inputs[bias_projection_tensor_idx]);
}
void *outputs_attn[] = {outputs[0]};
fastertransformer::forward_attn<T>(reinterpret_cast<T **>(inputs_attn), num_of_inputs_,
reinterpret_cast<T **>(outputs_attn), num_of_outputs_, &params_, workspace);
return RET_OK;
}
@ -328,48 +201,33 @@ bool MhaPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorD
}
void MhaPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {
int cross_tensor_offset = 0;
int position_bias_tensor_offsets = 0;
if (params_.is_cross) cross_tensor_offset = 1;
if (params_.position_bias) position_bias_tensor_offsets = 1;
const int attn_mask_tensor_idx = 7 + cross_tensor_offset - position_bias_tensor_offsets;
const int request_batch_size = static_cast<const int>(in[attn_mask_tensor_idx].desc.dims.d[0]);
const int request_src_seq_len = static_cast<const int>(in[attn_mask_tensor_idx].desc.dims.d[1]);
const int request_tgt_seq_len = static_cast<const int>(in[attn_mask_tensor_idx].desc.dims.d[2]);
params_.batch_size = request_batch_size;
params_.src_seq_len = request_src_seq_len;
params_.tgt_seq_len = request_tgt_seq_len;
num_of_inputs_ = nbInputs;
num_of_outputs_ = nbOutputs;
}
size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
auto attn_dim_size = inputs[nbInputs - 1].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputs[nbInputs - 1].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
// TODO(NIZZAN) Fix efficient allocator
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t size_v = size_k;
size_t qkv_len = size_q + size_k + size_v;
size_t q_buf_2_len = size_q;
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
size_t qkv_buf_3_len = qkv_buf_2_len;
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
int elem_size = sizeof(float);
if (compute_type_ == RuntimePrecisionMode_FP16) {
elem_size = sizeof(half);
return fastertransformer::GetAttnWorkspaceSize<half>(&params_);
} else {
return fastertransformer::GetAttnWorkspaceSize<float>(&params_);
}
return (buff_size + extra_tmp_size + extra_tmp_size) * elem_size;
}
nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept {
// MHA inputs:
// from_tensor [batch_size, src_seq_len, hidden_size_] or [batch_size * src_seq_len, hidden_size_]
// encoder_output [batch_size, tgt_seq_len, hidden_size_] or [batch_size * tgt_seq_len, hidden_size_]--> only in
// cross MHA attention_mask [batch_size, 1, src_seq_len, tgt_seq_len] or [batch_size, src_seq_len, tgt_seq_len]
// MHA output_tensors:
// attention_out [batch_size, src_seq_len, hidden_size_]
// key_cache [batch, head_num, size_per_head]
// value_cache [batch, head_num, tgt_seq_len, size_per_head]
nvinfer1::DimsExprs dims;
if (index == 0) {
#ifndef TEST_
@ -378,21 +236,20 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
if (num_dims == INPUT_SIZE2) {
dims.d[0] = exprBuilder.constant(inputs[nbInputDims - 1].d[0]->getConstantValue() *
inputs[nbInputDims - 1].d[1]->getConstantValue());
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
auto hidden_size = exprBuilder.constant(params_.head_size * params_.head_num);
dims.d[1] = hidden_size;
} else if (num_dims == INPUT_SIZE3) {
dims.d[0] = inputs[nbInputDims - 1].d[0]; // batch
dims.d[1] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
auto hidden_size = exprBuilder.constant(params_.head_size * params_.head_num);
dims.d[kTwo] = hidden_size;
}
} else {
// TODO(Haim) - Fix size in case of 2d input
dims.nbDims = INPUT_SIZE4;
dims.d[0] = inputs[nbInputDims - 1].d[0]; // batch
dims.d[1] = exprBuilder.constant(head_number_);
dims.d[1] = exprBuilder.constant(params_.head_num);
dims.d[kTwo] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
dims.d[kThree] = exprBuilder.constant(head_size_);
dims.d[kThree] = exprBuilder.constant(params_.head_size);
}
#else
dims.nbDims = C2NUM;
@ -405,7 +262,7 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
}
nvinfer1::IPluginV2DynamicExt *MhaPlugin::clone() const noexcept {
auto *plugin = new MhaPlugin(*this); // TODO(haim) CopyConstructor
auto *plugin = new MhaPlugin(*this);
if (plugin == nullptr) {
MS_LOG(ERROR) << "plugin is null";
return nullptr;
@ -418,13 +275,13 @@ int MhaPlugin::initialize() noexcept { return 0; }
void MhaPlugin::terminate() noexcept {}
size_t MhaPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); }
size_t MhaPlugin::getSerializationSize() const noexcept {
return sizeof(int) + sizeof(fastertransformer::encoderParamT);
}
void MhaPlugin::serialize(void *buffer) const noexcept {
SerializeValue(&buffer, &compute_type_, sizeof(int));
SerializeValue(&buffer, &head_number_, sizeof(int));
SerializeValue(&buffer, &head_size_, sizeof(int));
SerializeValue(&buffer, &is_cross_, sizeof(int));
SerializeValue(&buffer, &params_, sizeof(fastertransformer::encoderParamT));
}
REGISTER_TENSORRT_CREATOR(ops::kNameAttention, MhaTensorRT)
} // namespace mindspore::lite

View File

@ -22,6 +22,7 @@
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h"
#include "src/fastertransformer/layers/encoder_layers/encoder.h"
namespace mindspore::lite {
class MhaTensorRT : public TensorRTOp {
@ -42,41 +43,29 @@ class MhaTensorRT : public TensorRTOp {
constexpr auto MHA_PLUGIN_NAME{"AttentionPlugin"};
class MhaPlugin : public TensorRTPlugin {
public:
MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, bool is_cross,
bool is_position_bias, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
MhaPlugin(const std::string name, int compute_type, fastertransformer::encoderParamT params,
cublasLtHandle_t cublaslt_handle, uint32_t device_id)
: TensorRTPlugin(name, std::string(MHA_PLUGIN_NAME), device_id),
compute_type_(compute_type),
head_number_(head_number),
head_size_(head_size),
is_cross_(is_cross),
is_position_bias_(is_position_bias),
cublas_handle_(cublas_handle),
params_(params),
cublaslt_handle_(cublaslt_handle) {}
MhaPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
: TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
const nvinfer1::PluginField *fields = fc->fields;
compute_type_ = static_cast<const int *>(fields[0].data)[0];
head_number_ = static_cast<const int *>(fields[1].data)[0];
head_size_ = static_cast<const int *>(fields[2].data)[0];
is_cross_ = static_cast<const int *>(fields[3].data)[0];
is_position_bias_ = static_cast<const int *>(fields[4].data)[0];
params_ = static_cast<const fastertransformer::encoderParamT *>(fields[1].data)[0];
}
MhaPlugin(const char *name, const void *serialData, size_t serialLength)
: TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
DeserializeValue(&serialData, &serialLength, &compute_type_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &head_number_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &head_size_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &is_cross_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &is_position_bias_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &params_, sizeof(fastertransformer::encoderParamT));
}
MhaPlugin() = delete;
~MhaPlugin() override {
// std::cout << "~MhaPlugin" << std::endl;
}
~MhaPlugin() override {}
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
@ -95,38 +84,17 @@ class MhaPlugin : public TensorRTPlugin {
int initialize() noexcept override;
private:
bool needResize(const int *current_dims, const int *last_dims);
template <typename T>
int RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
cublasGemmAlgo_t *algoId);
template <typename T>
void SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
size_t extra_size);
template <typename T>
void RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha, void *beta,
cublasGemmAlgo_t algoId, cudaStream_t stream);
cublasGemmAlgo_t algoId);
const std::string layer_name_;
std::string name_space_;
int compute_type_;
int head_number_;
int head_size_;
bool is_cross_;
bool is_position_bias_;
cublasGemmAlgo_t fast_algo_gemm[4] = {CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP,
CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP};
cublasHandle_t cublas_handle_;
mutable fastertransformer::encoderParamT params_;
cublasLtHandle_t cublaslt_handle_;
void *qkv_buf_{nullptr};
void *q_buf_2_{nullptr};
void *qk_buf_{nullptr};
void *qkv_buf_2_{nullptr};
void *qkv_buf_3_{nullptr};
void *output1_{nullptr};
void *output2_{nullptr};
int num_of_inputs_;
int num_of_outputs_;
};
class MhaPluginCreater : public TensorRTPluginCreater<MhaPlugin> {
public:

View File

@ -106,7 +106,8 @@ int TensorRTAllocator::SyncMemDeviceToHost(tensor::Tensor *host_tensor, const st
return SyncMemInHostAndDevice(host_tensor->data_c(), device_tensor_name, host_tensor->Size(), false, sync);
}
int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name) {
int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name,
bool sync) {
if (dst_data == nullptr) {
MS_LOG(ERROR) << " dst host data cannot be nullptr.";
return RET_ERROR;
@ -124,7 +125,11 @@ int TensorRTAllocator::SyncMemDeviceToHost(void *dst_data, size_t data_size, con
MS_LOG(ERROR) << "device_ptr is null for " << device_tensor_name;
return RET_ERROR;
}
auto cuda_ret = cudaMemcpy(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost);
cudaError_t cuda_ret;
if (sync) {
cuda_ret = cudaMemcpy(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost);
} else {
cuda_ret = cudaMemcpyAsync(dst_data, device_ptr, data_size, cudaMemcpyDeviceToHost, stream_);
}
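// In the async case the copy is only enqueued on stream_; the caller must synchronize the stream before reading dst_data.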
if (cuda_ret != cudaSuccess) {
MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret);
return RET_ERROR;
@ -176,7 +181,11 @@ int TensorRTAllocator::SyncMemInHostAndDevice(void *host_data, const std::string
void *src_ptr = is_host2device ? host_data : device_ptr;
void *dst_ptr = is_host2device ? device_ptr : host_data;
cudaMemcpyKind kind = is_host2device ? cudaMemcpyHostToDevice : cudaMemcpyDeviceToHost;
auto cuda_ret = cudaMemcpy(dst_ptr, src_ptr, data_size, kind);
cudaError_t cuda_ret;
if (sync) {
cuda_ret = cudaMemcpy(dst_ptr, src_ptr, data_size, kind);
} else {
cuda_ret = cudaMemcpyAsync(dst_ptr, src_ptr, data_size, kind, stream_);
}
if (cuda_ret != cudaSuccess) {
MS_LOG(ERROR) << "copy mem failed,ret " << cudaGetErrorName(cuda_ret);
return RET_ERROR;

View File

@ -54,7 +54,7 @@ class TensorRTAllocator {
int SyncMemHostToDevice(const tensor::Tensor &host_tensor, const std::string &device_tensor_name, bool sync = true);
int SyncMemDeviceToHost(tensor::Tensor *host_tensor, const std::string &device_tensor_name, bool sync = true);
int SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name);
int SyncMemDeviceToHost(void *dst_data, size_t data_size, const std::string &device_tensor_name, bool sync = true);
int ClearDeviceMem();

View File

@ -26,6 +26,7 @@
#include <fstream>
#include <limits>
#include <unordered_map>
#include <iomanip>
#include "src/extendrt/delegate/delegate_utils.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "src/common/utils.h"
@ -719,8 +720,8 @@ int TensorRTSubGraph::OnNewInputShapes(const std::vector<ShapeVector> &new_shape
return RET_OK;
}
int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs,
const std::vector<tensor::Tensor> &outputs) {
int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs,
bool sync) {
if (inputs_.size() != inputs.size()) {
MS_LOG(ERROR) << "Graph inputs size " << inputs_.size() << " != execute inputs size " << inputs.size();
return RET_ERROR;
@ -748,7 +749,7 @@ int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs,
MS_LOG(ERROR) << "realloc for input tensor device memory failed.";
return RET_ERROR;
}
ret = runtime_->GetAllocator()->SyncMemHostToDevice(inputs[i], trt_tensor_name);
ret = runtime_->GetAllocator()->SyncMemHostToDevice(inputs[i], trt_tensor_name, sync);
if (ret != RET_OK) {
MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_tensor_name;
return RET_ERROR;
@ -788,7 +789,7 @@ int TensorRTSubGraph::PreExecute(const std::vector<tensor::Tensor> &inputs,
return RET_OK;
}
int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs, bool sync) {
if (!outputs->empty() && outputs->size() != outputs_.size()) {
MS_LOG(ERROR) << "Graph outputs size " << outputs_.size() << " != execute outputs size " << outputs->size();
return RET_ERROR;
@ -819,8 +820,8 @@ int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
MS_LOG(ERROR) << "Specified output device or host address cannot be nullptr";
return RET_ERROR;
}
int sync_ret =
runtime_->GetAllocator()->SyncMemDeviceToHost(host_address, outputs_[i].DataSize(), trt_out_tensor_name);
int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(host_address, outputs_[i].DataSize(),
trt_out_tensor_name, sync);
if (sync_ret != RET_OK) {
MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name;
return sync_ret;
@ -828,7 +829,7 @@ int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
}
} else {
tensor::Tensor output_tensor(static_cast<enum TypeId>(outputs_[i].DataType()), new_shape);
int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(&output_tensor, trt_out_tensor_name);
int sync_ret = runtime_->GetAllocator()->SyncMemDeviceToHost(&output_tensor, trt_out_tensor_name, sync);
if (sync_ret != RET_OK) {
MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name;
return sync_ret;
@ -854,19 +855,38 @@ bool TensorRTSubGraph::ValidInputResizeDims(const nvinfer1::Dims &construct_dims
}
int TensorRTSubGraph::Execute(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs) {
#ifdef ASYNC_INFER
bool sync = false;
#else
bool sync = true;
#endif
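// With ASYNC_INFER defined, host/device copies in PreExecute/PostExecute are issued with cudaMemcpyAsync, the engine
// runs via enqueueV2 on stream_, and a single cudaStreamSynchronize at the end of Execute publishes the outputs.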
int ret = lite::SetCudaDevice(device_info_);
if (ret != RET_OK) {
return ret;
}
ret = PreExecute(inputs, *outputs);
ret = PreExecute(inputs, *outputs, sync);
if (ret != RET_OK) {
return ret;
}
if (!this->trt_context_->executeV2(tensor_bindings_)) {
MS_LOG(ERROR) << "TensorRT execute failed.";
return RET_ERROR;
if (sync) {
if (!this->trt_context_->executeV2(tensor_bindings_)) {
MS_LOG(ERROR) << "TensorRT execute failed.";
return RET_ERROR;
}
} else {
if (!this->trt_context_->enqueueV2(tensor_bindings_, stream_, nullptr)) {
MS_LOG(ERROR) << "TensorRT execute failed.";
return RET_ERROR;
}
}
return PostExecute(outputs);
ret = PostExecute(outputs, sync);
if (ret != RET_OK) {
return ret;
}
if (!sync) {
cudaStreamSynchronize(stream_);
}
return RET_OK;
}
int TensorRTSubGraph::Resize(const std::vector<tensor::Tensor> &, const std::vector<ShapeVector> &new_shapes) {

View File

@ -90,8 +90,9 @@ class TensorRTSubGraph {
nvinfer1::Dims SetInputDimsProfile(const TensorInfo &in_tensor, int index);
int ParseInputsProfile();
int PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs);
int PostExecute(std::vector<tensor::Tensor> *outputs);
int PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs,
bool sync = true);
int PostExecute(std::vector<tensor::Tensor> *outputs, bool sync = true);
int OnNewInputShapes(const std::vector<ShapeVector> &inputs);

View File

@ -51,6 +51,7 @@
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tensor_dot_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/encoder_layer_fusion.h"
#include "tools/optimizer/fusion/glu_fusion.h"
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
@ -319,9 +320,10 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
std::make_shared<opt::AddActivationFusion>(),
std::make_shared<opt::ExpandDimsReshapeFusion>(),
std::make_shared<opt::SqueezeExpandDimsFusion>()};
#ifdef ENABLE_CLOUD_FUSION_INFERENCE
fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
#endif
if (param->optimize_transformer) {
fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
fusions.push_back(std::make_shared<opt::EncoderLayerFusion>());
}
for (size_t index = 0; index < fusions.size(); index++) {
auto pass_ptr = fusions.at(index);
MS_CHECK_TRUE_RET(pass_ptr != nullptr, RET_ERROR);

View File

@ -89,6 +89,7 @@ Flags::Flags() {
"Set the target device, support Ascend, Ascend310 and Ascend310P will be deprecated.", "");
AddFlag(&Flags::saveTypeStr, "saveType", "The type of saved model. MINDIR | MINDIR_LITE", "MINDIR_LITE");
AddFlag(&Flags::optimizeStr, "optimize", "The type of optimization. none | general | ascend_oriented", "");
AddFlag(&Flags::optimizeTransformerStr, "optimizeTransformer", "Enable Fast-Transformer fusion. true | false", "false");
}
int Flags::InitInputOutputDataType() {
@ -274,7 +275,7 @@ int Flags::InitExportMindIR() {
return RET_INPUT_PARAM_INVALID;
}
if (this->exportMindIR == "MINDIR") {
if ((this->exportMindIR == "MINDIR") && (this->optimizeTransformer == false)) {
this->disableFusion = true;
}
return RET_OK;
@ -311,6 +312,18 @@ int Flags::InitSaveType() {
return RET_OK;
}
int Flags::InitOptimizeTransformer() {
if (this->optimizeTransformerStr == "true") {
this->optimizeTransformer = true;
} else if (this->optimizeTransformerStr == "false") {
this->optimizeTransformer = false;
} else {
std::cerr << "INPUT ILLEGAL: optimizeTransformer must be true|false " << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int Flags::PreInit(int argc, const char **argv) {
if (argc == 1) {
std::cout << this->Usage() << std::endl;
@ -383,6 +396,13 @@ int Flags::Init(int argc, const char **argv) {
std::cerr << "Init encrypt failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
ret = InitOptimizeTransformer();
if (ret != RET_OK) {
std::cerr << "Init optimize transformers failed." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
ret = InitPreInference();
if (ret != RET_OK) {
std::cerr << "Init pre inference failed." << std::endl;

View File

@ -43,6 +43,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
int InitOptimize();
int InitExportMindIR();
int InitSaveType();
int InitOptimizeTransformer();
int Init(int argc, const char **argv);
int PreInit(int argc, const char **argv);
@ -85,6 +87,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
bool encryption = false;
#endif
std::string device;
std::string optimizeTransformerStr;
bool optimizeTransformer = false;
};
} // namespace converter
} // namespace mindspore

View File

@ -68,6 +68,7 @@ int main(int argc, const char **argv) {
converter.SetTrainModel(flags.trainModel);
converter.SetNoFusion(flags.disableFusion);
converter.SetDevice(flags.device);
converter.SetOptimizeTransformer(flags.optimizeTransformer);
auto status = converter.Convert();
if (status != mindspore::kSuccess) {

View File

@ -270,6 +270,20 @@ bool Converter::GetNoFusion() {
}
}
void Converter::SetOptimizeTransformer(bool optimizeTransformer) {
if (data_ != nullptr) {
data_->optimize_transformer = optimizeTransformer;
}
}
bool Converter::GetOptimizeTransformer() {
if (data_ != nullptr) {
return data_->optimize_transformer;
} else {
return false;
}
}
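A usage sketch of the new option through the C++ converter API; SetOptimizeTransformer and Convert come from this patch and converter_main above, while the header path, constructor signature and kFmkTypeMs value are assumptions for illustration:
#include "include/converter.h"  // assumed header location
mindspore::Converter converter(mindspore::converter::kFmkTypeMs, "encoder.mindir", "encoder_opt");  // assumed ctor
converter.SetOptimizeTransformer(true);  // enable the encoder-layer / multi-head-attention fusion added in this commit
if (converter.Convert() != mindspore::kSuccess) {
// handle conversion failure
}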
void Converter::SetDevice(const std::vector<char> &device) {
if (data_ != nullptr) {
data_->device = CharToString(device);

View File

@ -58,6 +58,7 @@ struct ConverterPara {
bool pre_infer = false;
bool train_model = false;
bool no_fusion = false;
bool optimize_transformer = false;
bool is_runtime_converter = false;
std::set<std::string> fusion_blacklists;

View File

@ -0,0 +1,411 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/encoder_layer_fusion.h"
#include <functional>
#include <utility>
#include <vector>
#include <algorithm>
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"
#include "ops/op_utils.h"
namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
} // namespace
bool EncoderLayerFusion::Init() const {
input_ = std::make_shared<Var>("input");
MS_CHECK_TRUE_RET(input_ != nullptr, false);
beta1_ = std::make_shared<Var>("beta1");
MS_CHECK_TRUE_RET(beta1_ != nullptr, false);
gamma1_ = std::make_shared<Var>("gamma1");
MS_CHECK_TRUE_RET(gamma1_ != nullptr, false);
beta2_ = std::make_shared<Var>("beta2");
MS_CHECK_TRUE_RET(beta2_ != nullptr, false);
gamma2_ = std::make_shared<Var>("gamma2");
MS_CHECK_TRUE_RET(gamma2_ != nullptr, false);
weight_attn_qkv_ = std::make_shared<Var>("weight_attn_qkv");
MS_CHECK_TRUE_RET(weight_attn_qkv_ != nullptr, false);
weight_attn_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_o");
MS_CHECK_TRUE_RET(weight_attn_o_ != nullptr, false);
weight_m_ = std::make_shared<CondVar>(IsParamNode, "weight_m");
MS_CHECK_TRUE_RET(weight_m_ != nullptr, false);
weight_p_ = std::make_shared<CondVar>(IsParamNode, "weight_p");
MS_CHECK_TRUE_RET(weight_p_ != nullptr, false);
bias_attn_qkv_ = std::make_shared<Var>("bias_attn_qkv");
MS_CHECK_TRUE_RET(bias_attn_qkv_ != nullptr, false);
bias_attn_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_o");
MS_CHECK_TRUE_RET(bias_attn_o_ != nullptr, false);
bias_m_ = std::make_shared<CondVar>(IsParamNode, "bias_m");
MS_CHECK_TRUE_RET(bias_m_ != nullptr, false);
bias_p_ = std::make_shared<CondVar>(IsParamNode, "bias_p");
MS_CHECK_TRUE_RET(bias_p_ != nullptr, false);
mask_ = std::make_shared<Var>("mask");
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
is_attention_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention");
MS_CHECK_TRUE_RET(is_attention_ != nullptr, false);
is_layernorm1_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm1");
MS_CHECK_TRUE_RET(is_layernorm1_ != nullptr, false);
is_layernorm2_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm2");
MS_CHECK_TRUE_RET(is_layernorm2_ != nullptr, false);
position_bias_ = std::make_shared<Var>("position_bias");
MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
is_act_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "activation");
MS_CHECK_TRUE_RET(is_act_ != nullptr, {});
return true;
}
VectorRef EncoderLayerFusion::getTuple(bool post_layernorm, bool layernorm_fusion = false,
bool is_position_bias = false) const {
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto var1 = std::make_shared<Var>("var1-reshape");
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto reshape1 = VectorRef({is_reshape1, input_, var1});
if (post_layernorm) {
return reshape1;
}
if (layernorm_fusion) {
return DefineLayerNorm(is_position_bias, reshape1, gamma1_, beta1_);
}
auto layer_norm = VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
auto var_tuple = std::make_shared<Var>("var_tuple");
auto tuple = VectorRef({is_tuple, layer_norm, var_tuple});
return tuple;
}
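// DefineLayerNorm matches a decomposed LayerNorm: ReduceFusion(mean) -> SubFusion -> Square -> ReduceFusion(variance)
// -> AddFusion(epsilon) -> Sqrt -> RealDiv, closed by MulFusion(gamma) in the position-bias (T5) case or
// ScaleFusion(gamma, beta) otherwise.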
VectorRef EncoderLayerFusion::DefineLayerNorm(bool is_position_bias, VectorRef input, VarPtr gamma, VarPtr beta) const {
auto var1 = std::make_shared<Var>("var1");
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce");
MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
auto reduce1 = VectorRef({is_reduce, input, var1});
auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion), "sub-f");
MS_CHECK_TRUE_RET(is_sub != nullptr, {});
auto sub = VectorRef({is_sub, input, reduce1});
auto is_sqr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSquare), "sqr");
MS_CHECK_TRUE_RET(is_sqr != nullptr, {});
auto sqr = (is_position_bias) ? VectorRef({is_sqr, input}) : VectorRef({is_sqr, sub});
auto var2 = std::make_shared<Var>("var2");
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto is_reduce2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce2");
MS_CHECK_TRUE_RET(is_reduce2 != nullptr, {});
auto reduce2 = VectorRef({is_reduce2, sqr, var2});
auto var3 = std::make_shared<Var>("var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add");
MS_CHECK_TRUE_RET(is_add != nullptr, {});
auto add = VectorRef({is_add, reduce2, var3});
auto is_sqr2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSqrt), "sqr2");
MS_CHECK_TRUE_RET(is_sqr2 != nullptr, {});
auto sqr2 = VectorRef({is_sqr2, add});
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "real-div");
MS_CHECK_TRUE_RET(is_div != nullptr, {});
if (is_position_bias) {
auto real_div = VectorRef({is_div, input, sqr2});
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul");
MS_CHECK_TRUE_RET(is_mul != nullptr, {});
auto mul = VectorRef({is_mul, real_div, gamma});
return mul;
} else {
auto real_div = VectorRef({is_div, sub, sqr2});
auto is_scale = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimScaleFusion), "scale");
MS_CHECK_TRUE_RET(is_scale != nullptr, {});
auto scale = VectorRef({is_scale, real_div, gamma, beta});
return scale;
}
}
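// DefinePatternEncoderLayer describes one encoder block: (optional) LayerNorm, fused Attention, residual AddFusion,
// a second LayerNorm, the FFN (MatMulFusion -> Activation -> MatMulFusion) and the closing residual AddFusion;
// post_layernorm, layernorm_fusion and is_position_bias select the BERT-style, decomposed-LayerNorm and T5 variants.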
VectorRef EncoderLayerFusion::DefinePatternEncoderLayer(bool post_layernorm = true, bool layernorm_fusion = false,
bool is_position_bias = false) const {
VectorRef attention, tuple, tuple2, tuple3, reshape2, matmul1;
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto var1 = std::make_shared<Var>("var1");
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto reshape1 = VectorRef({is_reshape1, input_, var1});
if (!is_position_bias) {
attention = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
getTuple(post_layernorm, layernorm_fusion, is_position_bias),
getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_,
weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_});
} else {
attention = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
getTuple(post_layernorm, layernorm_fusion, is_position_bias),
getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_,
weight_attn_o_, position_bias_, mask_});
}
if (!is_position_bias) {
auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
auto var_tuple = std::make_shared<Var>("var_tuple");
tuple = VectorRef({is_tuple, attention, var_tuple});
} else {
tuple = attention;
}
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
auto add = VectorRef({is_add, reshape1, tuple});
if (layernorm_fusion) {
tuple2 = DefineLayerNorm(is_position_bias, add, gamma2_, beta2_);
} else {
auto layer_norm2 = VectorRef({is_layernorm2_, add, gamma2_, beta2_});
auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
auto var_tuple2 = std::make_shared<Var>("var_tuple2");
tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
}
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>("var2");
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
if (is_position_bias) {
reshape2 = VectorRef({is_reshape2, add, var2});
matmul1 = VectorRef({is_matmul1, tuple2, weight_m_});
} else if (post_layernorm || layernorm_fusion) {
reshape2 = VectorRef({is_reshape2, tuple2, var2});
matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
} else {
reshape2 = VectorRef({is_reshape2, add, var2});
matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
}
auto act = VectorRef({is_act_, matmul1});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 =
(is_position_bias) ? VectorRef({is_matmul2, matmul1, weight_p_}) : VectorRef({is_matmul2, act, weight_p_, bias_p_});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var3 = std::make_shared<Var>("var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
auto add3 = VectorRef({is_add3, reshape2, reshape3});
if (!post_layernorm || layernorm_fusion) {
return add3;
}
auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
auto var4 = std::make_shared<Var>("var4");
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape4 = VectorRef({is_reshape4, add3, var4});
if (layernorm_fusion) {
tuple3 = DefineLayerNorm(is_position_bias, reshape4, gamma1_, beta1_);
} else {
auto layer_norm = VectorRef({is_layernorm1_, reshape4, gamma1_, beta1_});
auto is_tuple3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item3");
auto var_tuple3 = std::make_shared<Var>("var_tuple3");
tuple3 = VectorRef({is_tuple3, layer_norm, var_tuple3});
}
auto is_reshape5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
MS_CHECK_TRUE_RET(is_reshape5 != nullptr, {});
auto var5 = std::make_shared<Var>("var5");
MS_CHECK_TRUE_RET(var5 != nullptr, {});
auto reshape5 = VectorRef({is_reshape5, tuple3, var5});
return reshape5;
}
std::unordered_map<std::string, VectorRef> EncoderLayerFusion::DefinePatterns() const {
std::unordered_map<std::string, VectorRef> patterns;
if (!Init()) {
MS_LOG(ERROR) << "initial member failed.";
return patterns;
}
patterns[kPatternEncoderLayerPre] = DefinePatternEncoderLayer(false);
patterns[kPatternEncoderLayerPost] = DefinePatternEncoderLayer(true);
patterns[kPatternEncoderLayerPostNorm] = DefinePatternEncoderLayer(true, true);
patterns[kPatternEncoderLayerPreNorm] = DefinePatternEncoderLayer(false, true);
patterns[kPatternEncoderLayerT5] = DefinePatternEncoderLayer(false, true, true);
return patterns;
}
AnfNodePtr EncoderLayerFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
if (pattern_name == kPatternEncoderLayerPost || pattern_name == kPatternEncoderLayerPostNorm) {
return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, true);
} else if (pattern_name == kPatternEncoderLayerPre || pattern_name == kPatternEncoderLayerPreNorm) {
return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, false);
} else if (pattern_name == kPatternEncoderLayerT5) {
is_position_bias_ = true;
return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, false);
}
return nullptr;
}
bool EncoderLayerFusion::IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const VarPtr &input_prim) const {
auto act_input = GetAttribute(func_graph, equiv, input_prim);
MS_ASSERT(act_input != nullptr);
auto act_primitive = ops::GetOperator<ops::Activation>(act_input);
MS_CHECK_TRUE_RET(act_primitive != nullptr, false);
auto act_primitive_c = act_primitive->GetPrim();
if (act_primitive_c->GetAttr(ops::kActivationType) == nullptr ||
act_primitive->get_activation_type() != mindspore::GELU) {
return false;
}
return true;
}
AnfNodePtr EncoderLayerFusion::GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
VarPtr node_name) const {
if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
MS_LOG(ERROR) << node_name << " is not an AnfNodePtr";
return nullptr;
}
AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
MS_ASSERT(node != nullptr);
if (node == nullptr || !utils::isa<CNodePtr>(node)) {
auto manager = func_graph->manager();
if (manager == nullptr) {
return nullptr;
}
auto users = manager->node_users();
auto it = users.find(node);
if (it != users.end()) {
node = it->second.front().first;
}
if (node == nullptr || !utils::isa<CNodePtr>(node)) {
return nullptr;
}
}
auto cnode = utils::cast<CNodePtr>(node);
MS_ASSERT(cnode != nullptr);
auto input = cnode->input(0);
return input;
}
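// CheckPattern pulls head_num / head_size from the matched Attention primitive and the epsilon values from the two
// LayerNormFusion primitives, and rejects the match unless the FFN activation is GELU.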
STATUS EncoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num,
int *head_size, float *eps1, float *eps2) const {
auto attn_input = GetAttribute(func_graph, equiv, is_attention_);
MS_ASSERT(attn_input != nullptr);
auto attn_prim = ops::GetOperator<ops::Attention>(attn_input);
if (attn_prim->GetAttr(ops::kEncoderLayerNumHeads) != nullptr) {
*head_num = attn_prim->get_head_num();
}
if (attn_prim->GetAttr(ops::kAttentionSizePerHead) != nullptr) {
*head_size = attn_prim->get_head_size();
}
if (attn_prim->GetAttr(ops::kPositionBias) != nullptr) {
is_position_bias_ = attn_prim->get_position_bias();
}
auto layrn1_input = GetAttribute(func_graph, equiv, is_layernorm1_);
auto layrn1_prim = ops::GetOperator<ops::LayerNormFusion>(layrn1_input);
if (layrn1_prim->GetAttr(ops::kEpsilon) != nullptr) {
*eps1 = layrn1_prim->get_epsilon();
}
auto layrn2_input = GetAttribute(func_graph, equiv, is_layernorm2_);
auto layrn2_prim = ops::GetOperator<ops::LayerNormFusion>(layrn2_input);
if (layrn2_prim->GetAttr(ops::kEpsilon) != nullptr) {
*eps2 = layrn2_prim->get_epsilon();
}
if (!IsActGELU(func_graph, equiv, is_act_)) {
return RET_ERROR;
}
return RET_OK;
}
std::shared_ptr<ops::EncoderLayer> EncoderLayerFusion::CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
bool post_layernorm, int64_t ffn_hidden_size) const {
auto encoder_layer_prim = std::make_shared<ops::EncoderLayer>();
if (encoder_layer_prim == nullptr) {
MS_LOG(ERROR) << "Build enoder layer primitive failed.";
return nullptr;
}
int head_num = 0;
int head_size = 0;
float eps1 = 1e-6;
float eps2 = 1e-6;
if (CheckPattern(func_graph, equiv, &head_num, &head_size, &eps1, &eps2) != RET_OK) {
return nullptr;
}
encoder_layer_prim->Init(head_num, head_size, eps1, eps2, ffn_hidden_size, is_position_bias_, post_layernorm);
return encoder_layer_prim;
}
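// CreateMaskedEncoderLayerFusionNode gathers the matched weights, biases and gamma/beta parameters (plus the position
// bias for T5), derives ffn_hidden_size from the first FFN weight, and replaces the whole sub-graph with a single
// EncoderLayer CNode whose input order depends on the post_layernorm / position-bias variant.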
CNodePtr EncoderLayerFusion::CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const AnfNodePtr &node,
bool post_layernorm = true) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
MS_ASSERT(node != nullptr);
auto input = utils::cast<AnfNodePtr>((*equiv)[input_]);
AnfNodePtr position_bias, input_mask, bias_attn_o, bias_attn_qkv, beta1, beta2, bias_m, bias_p;
auto weight_qkv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_qkv_]);
auto weight_attn_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_o_]);
auto weight_m = utils::cast<AnfNodePtr>((*equiv)[weight_m_]);
auto weight_p = utils::cast<AnfNodePtr>((*equiv)[weight_p_]);
if (!is_position_bias_) {
bias_attn_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_qkv_]);
bias_attn_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_o_]);
bias_m = utils::cast<AnfNodePtr>((*equiv)[bias_m_]);
bias_p = utils::cast<AnfNodePtr>((*equiv)[bias_p_]);
beta1 = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
beta2 = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
}
auto gamma1 = utils::cast<AnfNodePtr>((*equiv)[gamma1_]);
auto gamma2 = utils::cast<AnfNodePtr>((*equiv)[gamma2_]);
if (mask_) {
input_mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
}
auto base_shape_ptr = weight_m->Shape();
MS_EXCEPTION_IF_NULL(base_shape_ptr);
auto input_shape_ptr = base_shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(input_shape_ptr);
auto input_shape = input_shape_ptr->shape();
MS_ASSERT(input_shape.size() > 1);
int64_t ffn_hidden_size = input_shape[1];
auto encoder_layer_prim = CreatePrim(func_graph, equiv, post_layernorm, ffn_hidden_size);
MS_CHECK_TRUE_RET(encoder_layer_prim != nullptr, nullptr);
auto encoder_layer_prim_c = encoder_layer_prim->GetPrim();
MS_CHECK_TRUE_RET(encoder_layer_prim_c != nullptr, nullptr);
auto value_node = NewValueNode(encoder_layer_prim_c);
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
std::vector<AnfNodePtr> new_node_inputs;
ParameterPtr c_bias_m_param, c_weight_p_param, c_bias_p_param, c_weight_m_param;
if (is_position_bias_) {
position_bias = utils::cast<AnfNodePtr>((*equiv)[position_bias_]);
if (!post_layernorm) {
new_node_inputs = {value_node, input, gamma1, weight_qkv, input_mask,
weight_attn_o, gamma2, weight_m, weight_p, position_bias};
} else {
new_node_inputs = {value_node, input, weight_qkv, input_mask, weight_attn_o,
gamma1, weight_m, weight_p, gamma2, position_bias};
}
} else {
if (!post_layernorm) {
new_node_inputs = {value_node, input, gamma1, beta1, weight_qkv, bias_attn_qkv, input_mask, weight_attn_o,
bias_attn_o, gamma2, beta2, weight_m, bias_m, weight_p, bias_p};
} else {
new_node_inputs = {value_node, input, weight_qkv, bias_attn_qkv, input_mask,
weight_attn_o, bias_attn_o, gamma1, beta1, weight_m,
bias_m, weight_p, bias_p, gamma2, beta2};
}
}
auto new_node = func_graph->NewCNode(new_node_inputs);
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
auto old_node = node->cast<CNodePtr>();
MS_CHECK_TRUE_RET(old_node != nullptr, nullptr);
MS_CHECK_TRUE_RET(old_node->abstract() != nullptr, nullptr);
new_node->set_abstract(old_node->abstract()->Clone());
new_node->set_fullname_with_scope(node->fullname_with_scope() + "/encoder_layer");
return new_node;
}
} // namespace mindspore::opt

View File

@ -0,0 +1,90 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"
#include "ops/encoder_layer.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/activation.h"
namespace mindspore {
namespace opt {
class EncoderLayerFusion : public MultiplePatternProcessPass {
public:
explicit EncoderLayerFusion(const std::string &name = "EncoderLayerFusion", bool multigraph = true)
: MultiplePatternProcessPass(name, multigraph) {}
~EncoderLayerFusion() override = default;
AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
const EquivPtr &) const override;
std::unordered_map<std::string, VectorRef> DefinePatterns() const override;
protected:
virtual bool Init() const;
private:
const std::string kPatternEncoderLayerPost = "PatternTEncoderLayerPost";
const std::string kPatternEncoderLayerPre = "PatternTEncoderLayerPre";
const std::string kPatternEncoderLayerPostNorm = "PatternTEncoderLayerPostNorm";
const std::string kPatternEncoderLayerPreNorm = "PatternTEncoderLayerPreNorm";
const std::string kPatternEncoderLayerT5 = "PatternEncoderLayerT5";
VectorRef DefinePatternEncoderLayer(bool post_layernorm, bool layernorm_fusion, bool is_position_bias_) const;
VectorRef getTuple(bool post_layernorm, bool layernorm_fusion, bool is_position_bias) const;
VectorRef DefineLayerNorm(bool is_position_bias, VectorRef input, VarPtr gamma, VarPtr beta) const;
CNodePtr CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const AnfNodePtr &node, bool post_layernorm) const;
AnfNodePtr GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv, VarPtr node_name) const;
bool IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const VarPtr &input_prim) const;
lite::STATUS CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num, int *head_size,
float *eps1, float *eps2) const;
std::shared_ptr<ops::EncoderLayer> CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
bool post_layernorm, int64_t ffn_hidden_size) const;
protected:
mutable VarPtr input_{nullptr};
mutable VarPtr position_bias_{nullptr};
mutable VarPtr beta1_{nullptr};
mutable VarPtr gamma1_{nullptr};
mutable VarPtr beta2_{nullptr};
mutable VarPtr gamma2_{nullptr};
mutable VarPtr weight_attn_qkv_{nullptr};
mutable VarPtr weight_attn_qkv_cross_{nullptr};
mutable VarPtr weight_attn_o_{nullptr};
mutable VarPtr weight_m_{nullptr};
mutable VarPtr weight_p_{nullptr};
mutable VarPtr bias_attn_qkv_{nullptr};
mutable VarPtr bias_attn_o_{nullptr};
mutable VarPtr bias_m_{nullptr};
mutable VarPtr bias_p_{nullptr};
mutable VarPtr mask_{nullptr};
mutable VarPtr is_attention_{nullptr};
mutable VarPtr is_layernorm1_{nullptr};
mutable VarPtr is_layernorm2_{nullptr};
mutable bool is_position_bias_{false};
mutable VarPtr is_act_{nullptr};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_

View File

@ -39,30 +39,34 @@ bool MultiHeadAttentionFusion::Init() const {
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
input_v_ = std::make_shared<Var>("input_v");
MS_CHECK_TRUE_RET(input_v_ != nullptr, false);
input1_ = std::make_shared<Var>("input1_");
MS_CHECK_TRUE_RET(input1_ != nullptr, false);
input2_ = std::make_shared<Var>("input2_");
MS_CHECK_TRUE_RET(input2_ != nullptr, false);
position_bias_ = std::make_shared<Var>("position_bias_");
MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
weight_q_ = std::make_shared<CondVar>(IsParamNode, "weight_q");
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
weight_k_ = std::make_shared<CondVar>(IsParamNode, "weight_k");
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
weight_v_ = std::make_shared<CondVar>(IsParamNode, "weight_v");
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
weight_o_ = std::make_shared<CondVar>(IsParamNode);
weight_o_ = std::make_shared<CondVar>(IsParamNode, "weight_o");
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);
weight_o2_ = std::make_shared<CondVar>(IsParamNode, "weight_o2");
MS_CHECK_TRUE_RET(weight_o2_ != nullptr, false);
bias_q_ = std::make_shared<CondVar>(IsParamNode, "bias_q");
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
bias_k_ = std::make_shared<CondVar>(IsParamNode, "bias_k");
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
bias_v_ = std::make_shared<CondVar>(IsParamNode, "bias_v");
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
bias_o_ = std::make_shared<CondVar>(IsParamNode);
bias_o_ = std::make_shared<CondVar>(IsParamNode, "bias_o");
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);
bias_o2_ = std::make_shared<CondVar>(IsParamNode, "bias_o2");
MS_CHECK_TRUE_RET(bias_o2_ != nullptr, false);
mask_ = std::make_shared<Var>("mask");
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
reshape_k_ = std::make_shared<Var>("reshape_k");
MS_CHECK_TRUE_RET(reshape_k_ != nullptr, false);
reshape_v_ = std::make_shared<Var>("reshape_v");
@ -78,7 +82,7 @@ bool MultiHeadAttentionFusion::Init() const {
namespace {
VectorRef DefineMask(const BaseRef &mask_input) {
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims));
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "m-expand");
MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});
auto var1 = std::make_shared<Var>("m-var1");
MS_CHECK_TRUE_RET(var1 != nullptr, {});
@ -94,6 +98,7 @@ VectorRef DefineMask(const BaseRef &mask_input) {
MS_CHECK_TRUE_RET(var3 != nullptr, {});
return VectorRef({is_mul, sub, var3});
}
STATUS GetIntParameterData(const ParameterPtr &param_ptr, std::vector<int> *result) {
if (param_ptr == nullptr || !param_ptr->has_default()) {
MS_LOG(DEBUG) << "param not have default";
@ -139,29 +144,40 @@ STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
const BaseRef &axis, const BaseRef &transpose_var, bool test_div,
bool transpose) const {
bool transpose, bool mul) const {
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "e-matmul");
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
auto dense = VectorRef({is_matmul, input, weight, bias});
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
auto reshape = VectorRef({is_reshape, dense, axis});
auto var2 = std::make_shared<Var>();
VectorRef conn;
if (transpose) {
conn = VectorRef({transpose_var, reshape, var2});
auto var2 = std::make_shared<Var>("var2");
MS_CHECK_TRUE_RET(var2 != nullptr, {});
VectorRef conn1;
if (mul) {
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "e-mul");
MS_CHECK_TRUE_RET(is_mul != nullptr, {});
conn1 = VectorRef({is_mul, reshape, var2});
} else {
conn = reshape;
conn1 = reshape;
}
auto var3 = std::make_shared<Var>("var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
VectorRef conn2;
if (transpose) {
conn2 = VectorRef({transpose_var, conn1, var3});
} else {
conn2 = conn1;
}
if (test_div) {
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
MS_CHECK_TRUE_RET(is_div != nullptr, {});
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto div = VectorRef({is_div, conn, var3});
auto var4 = std::make_shared<Var>("var4");
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto div = VectorRef({is_div, conn2, var4});
return div;
}
return conn;
return conn2;
}
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
@ -172,7 +188,7 @@ VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
auto reshape = VectorRef({is_reshape, dense, axis});
auto var2 = std::make_shared<Var>();
auto var2 = std::make_shared<Var>("var2");
VectorRef conn;
if (transpose) {
conn = VectorRef({transpose_var, reshape, var2});
@ -182,7 +198,7 @@ VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const
if (test_div) {
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
MS_CHECK_TRUE_RET(is_div != nullptr, {});
auto var3 = std::make_shared<Var>();
auto var3 = std::make_shared<Var>("var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto div = VectorRef({is_div, conn, var3});
return div;
@ -196,20 +212,20 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool mask) const {
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape1");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>();
auto var1 = std::make_shared<Var>("var1");
MS_CHECK_TRUE_RET(var1 != nullptr, {});
VectorRef reshape1;
if (mask) {
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
MS_CHECK_TRUE_RET(is_add != nullptr, {});
auto mask = DefineMask(mask_);
MS_CHECK_TRUE_RET(!mask.empty(), {});
@ -218,28 +234,28 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool mask) const {
} else {
reshape1 = VectorRef({is_reshape1, matmul1, var1});
}
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax));
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "is_softmax");
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
auto softmax = VectorRef({is_softmax, reshape1});
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>();
auto var2 = std::make_shared<Var>("var2");
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "is_transpose");
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
auto var3 = std::make_shared<Var>();
auto var3 = std::make_shared<Var>("var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto transpose = VectorRef({is_transpose, matmul2, var3});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>();
auto var4 = std::make_shared<Var>("var4");
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, transpose, var4});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
@ -290,18 +306,26 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5() const {
return matmul3;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose) const {
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose, bool no_div_flag) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
VectorRef q_embedding;
if (transpose) {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true);
if (no_div_flag) {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, false, true);
} else {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true);
}
} else {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, false);
}
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, true, true);
if (no_div_flag) {
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, false, true);
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, true, true);
}
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, reshape_v_, v_transpose_, false, true);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
@ -323,7 +347,12 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose)
auto reshape1 = VectorRef({is_reshape1, add2, var1});
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "softmax");
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
auto softmax = VectorRef({is_softmax, reshape1});
VectorRef softmax, matmul2;
if (no_div_flag) {
softmax = VectorRef({is_softmax, add2});
} else {
softmax = VectorRef({is_softmax, reshape1});
}
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape2");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>("var2");
@ -331,7 +360,11 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose)
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
if (no_div_flag) {
matmul2 = VectorRef({is_matmul2, softmax, v_embedding});
} else {
matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
}
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>("var4");
@ -347,12 +380,12 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose)
} else {
reshape3 = VectorRef({is_reshape3, matmul2, var4});
}
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_});
return matmul3;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
@ -366,8 +399,6 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
MS_CHECK_TRUE_RET(is_add != nullptr, {});
auto mask = DefineMask(mask_);
@ -395,6 +426,66 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
return matmul3;
}
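// DefineMPPatternSwin matches Swin-style attention without a mask input: a scaled Q embedding (MulFusion), K/V
// embeddings, QK MatMul plus a relative-position-bias AddFusion, an optional reshape/Add/reshape window-mask block
// selected by `flag`, Softmax, MatMul with V, Transpose/Reshape and the output projection MatMul.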
VectorRef MultiHeadAttentionFusion::DefineMPPatternSwin(bool flag) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, false, true, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, false, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_add1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
auto add1 = VectorRef({is_add1, matmul1, var1});
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax));
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
VectorRef softmax;
if (flag) {
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto var2 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto reshape1 = VectorRef({is_reshape1, add1, var2});
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion));
MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
auto add2 = VectorRef({is_add2, reshape1, var3});
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var4 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape2 = VectorRef({is_reshape2, add2, var4});
softmax = VectorRef({is_softmax, reshape2});
} else {
softmax = VectorRef({is_softmax, add1});
}
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, softmax, v_embedding});
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
auto var5 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var5 != nullptr, {});
auto transpose = VectorRef({is_transpose, matmul2, var5});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var6 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var6 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, transpose, var6});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
}
namespace {
template <typename T>
STATUS TransposeMatrix(std::shared_ptr<tensor::Tensor> src, std::shared_ptr<tensor::Tensor> dst) {
@ -442,7 +533,6 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
for (std::size_t i = 1; i < base_shape_size; ++i) {
new_shape.push_back(base_shape.at(i));
}
// calculate data
auto concat_tensor = std::make_shared<tensor::Tensor>(base_data_type, new_shape);
MS_CHECK_TRUE_RET(concat_tensor != nullptr, nullptr);
@ -481,13 +571,15 @@ std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatte
MS_LOG(ERROR) << "initial member failed.";
return patterns;
}
patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern();
patterns[kMPAPatternName] = DefineMPWithMaskPattern(false);
patterns[kMPAWithMaskPatternNamePA] = DefineMPWithMaskPatternPA();
patterns[kMPAWithMaskPatternNameT5] = DefineMPWithMaskPatternT5();
patterns[kMPAWithMaskPatternNameT5New] = DefineMPWithMaskPatternT5New(false);
patterns[kMPAWithMaskPatternNameT5New2] = DefineMPWithMaskPatternT5New(true, true);
patterns[kMPAWithMaskTransposePatternNameT5New] = DefineMPWithMaskPatternT5New();
patterns[kMPAPatternNameSwin1] = DefineMPPatternSwin();
patterns[kMPAPatternNameSwin2] = DefineMPPatternSwin(false);
return patterns;
}
@ -496,7 +588,6 @@ bool MultiHeadAttentionFusion::CheckPattern(const EquivPtr &equiv, int *head_num
MS_ASSERT(head_num != nullptr);
MS_ASSERT(head_size != nullptr);
std::vector<int> reshape_axes;
// UNDO !!!!!!!!!
if (GetAxis((*equiv)[reshape_axis_], &reshape_axes) != lite::RET_OK) {
MS_LOG(ERROR) << "cannot figure out reshape";
return false;
@ -523,44 +614,20 @@ AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, co
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
++match_count_;
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA) ||
(pattern_name == kMPAWithMaskPatternNameT5) ||
(pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New)) {
if (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New) {
(pattern_name == kMPAWithMaskPatternNameT5) || (pattern_name == kMPAWithMaskPatternNameT5New) ||
(pattern_name == kMPAWithMaskTransposePatternNameT5New) || (pattern_name == kMPAWithMaskPatternNameT5New2)) {
if (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New ||
pattern_name == kMPAWithMaskPatternNameT5New2) {
t5_x_ = true;
}
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true);
}
if (pattern_name == kMPAPatternName)
if (pattern_name == kMPAPatternName || pattern_name == kMPAPatternNameSwin1 || pattern_name == kMPAPatternNameSwin2)
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false);
return nullptr;
}
// STATUS GetIntParameterData(const ParameterPtr &param_ptr, std::vector<int> *result) {
// if (param_ptr == nullptr || !param_ptr->has_default()) {
// MS_LOG(DEBUG) << "param not have default";
// return RET_ERROR;
// }
// auto default_param = param_ptr->default_param();
// if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
// MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
// return RET_ERROR;
// }
// auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
// if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
// MS_LOG(DEBUG) << "default param is not int";
// return RET_ERROR;
// }
// auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
// int64_t shape_size =
// std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
// for (int64_t i = 0; i < shape_size; i++) {
// result->emplace_back(ptr[i]);
// }
// return RET_OK;
// }
std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
MS_ASSERT(equiv != nullptr);
auto attention_prim = std::make_shared<ops::Attention>();
@ -572,19 +639,16 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
MS_LOG(ERROR) << "Reshape k is not a parameter";
return nullptr;
}
if (!utils::isa<ParameterPtr>((*equiv)[reshape_v_])) {
MS_LOG(ERROR) << "Reshape v is not a parameter";
return nullptr;
}
auto reshape_k = utils::cast<ParameterPtr>((*equiv)[reshape_k_]);
std::vector<int> shape_k;
if (RET_OK != GetIntParameterData(reshape_k, &shape_k)) {
MS_LOG(ERROR) << "Get reshape k data failed";
return nullptr;
}
auto reshape_v = utils::cast<ParameterPtr>((*equiv)[reshape_v_]);
std::vector<int> shape_v;
if (RET_OK != GetIntParameterData(reshape_v, &shape_v)) {
@ -694,7 +758,7 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::CreatePrim(const Equiv
if (!CheckPattern(equiv, &head_num, &head_size)) {
return nullptr;
}
attention_prim->Init(head_num, head_size, cross);
attention_prim->Init(head_num, head_size, t5_x_, cross);
return attention_prim;
}
@ -726,7 +790,7 @@ bool MultiHeadAttentionFusion::IsCross(const EquivPtr &equiv) const {
ret = FetchShapeFromAbstract(input_v->abstract(), &inputv_shape);
MS_CHECK_TRUE_RET(ret == RET_OK, false);
if ((inputq_shape != inputv_shape) || ((match_count_ > 1) && (input_q != input_v))) {
if ((inputq_shape != inputv_shape) || (input_q != input_v)) {
return true;
}
return false;
@ -795,9 +859,8 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
bool cross = IsCross(equiv);
std::vector<AnfNodePtr> redundant;
auto [weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor] = GetAttentionNodeWeights(equiv, &redundant);
AnfNodePtr bias_q;
AnfNodePtr bias_q, bias_o;
ParameterPtr c_bias_param;
AnfNodePtr bias_o;
if (!t5_x_) {
bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
@ -824,9 +887,7 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
if (!cross && !t5_x_) {
redundant.push_back(bias_q);
}
tensor::TensorPtr c_weights;
tensor::TensorPtr q_weight_t;
tensor::TensorPtr c_weights, q_weight_t;
if (cross) {
c_weights = ConcatTensors({weight_k_tensor, weight_v_tensor}, true);
q_weight_t = ConcatTensors({weight_q_tensor}, true);
@ -852,7 +913,6 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
}
std::vector<AnfNodePtr> new_node_inputs =
GetNewNodeInputs(equiv, q_weight_param, c_weight_param, weight_o, c_bias_param, bias_o, mask, cross);
auto new_node = func_graph->NewCNode(new_node_inputs);
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
if (vnode) {

View File

@ -50,9 +50,11 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
VectorRef DefineMPWithMaskPattern(bool mask = true) const;
VectorRef DefineMPWithMaskPatternPA() const;
VectorRef DefineMPWithMaskPatternT5() const;
VectorRef DefineMPWithMaskPatternT5New(bool transpose = true) const;
VectorRef DefineMPWithMaskPatternT5New(bool transpose = true, bool no_div_flag = false) const;
VectorRef DefineMPPatternSwin(bool flag = true) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div = false, bool transpose = true) const;
const BaseRef &transpose_var, bool test_div = false, bool transpose = true,
bool mul = false) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div, bool transpose) const;
@ -85,21 +87,28 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
const std::string kMPAPatternName = "MPAPattern";
const std::string kMPAWithMaskPatternNameT5 = "MPAWithMaskPatternT5";
const std::string kMPAWithMaskPatternNameT5New = "MPAWithMaskPatternT5New";
const std::string kMPAWithMaskPatternNameT5New2 = "MPAWithMaskPatternT5New2";
const std::string kMPAWithMaskTransposePatternNameT5New = "MPAWithMaskTransposePatternT5New";
const std::string kMPAPatternNameSwin1 = "MPAPatternNameSwin1";
const std::string kMPAPatternNameSwin2 = "MPAPatternNameSwin2";
mutable VarPtr input_q_{nullptr};
mutable VarPtr input_k_{nullptr};
mutable VarPtr input_v_{nullptr};
mutable VarPtr input1_{nullptr};
mutable VarPtr input2_{nullptr};
mutable VarPtr position_bias_{nullptr};
mutable VarPtr weight_q_{nullptr};
mutable VarPtr weight_k_{nullptr};
mutable VarPtr weight_v_{nullptr};
mutable VarPtr weight_o_{nullptr};
mutable VarPtr weight_o2_{nullptr};
mutable VarPtr bias_q_{nullptr};
mutable VarPtr bias_k_{nullptr};
mutable VarPtr bias_v_{nullptr};
mutable VarPtr bias_o_{nullptr};
mutable VarPtr bias_o2_{nullptr};
mutable VarPtr mask_{nullptr};

File diff suppressed because it is too large