attention kernel bit exact

yoni 2022-09-12 16:46:34 +03:00
parent ca371e1531
commit 3ae59185b2
33 changed files with 1086 additions and 137 deletions

View File

@ -8,6 +8,9 @@ mindspore_add_pkg(fast_transformers
URL ${REQ_URL}
MD5 ${MD5}
LIBS ${ft_libs}
LIB_PATH output/lib
LIB_PATH lib
PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/fast_transformer/001-fast_transformer.patch
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DEXAMPLES=off)
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DEXAMPLES=off)
include_directories(${fast_transformers_INC})
add_library(mindspore::fast_transformers ALIAS fast_transformers::transformer-shared)

View File

@ -58,9 +58,6 @@ endif()
if(ENABLE_GPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cub.cmake)
if(NOT MSVC)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/fast_transformers.cmake)
endif()
if(ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
endif()

View File

@ -852,6 +852,8 @@ else()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${fast_transformers_LIBPATH}/libtransformer-shared.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
else()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}

View File

@ -133,9 +133,10 @@ EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv)
(*equiv)[var] = expr;
MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
return equiv;
} else {
MS_LOG(DEBUG) << "pattern not match: " + pattern.ToString() + ", " + expr.ToString();
}
}
return nullptr;
}
@ -222,6 +223,7 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorR
return nullptr;
}
}
if ((values_expr_len == 0) && (values_pattern_len == 0)) return equiv;
if (values_expr_len < values_pattern_len - 1) {
MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
return nullptr;

View File

@ -109,6 +109,7 @@ class VarHasher {
class CondVar : public Var {
public:
explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
explicit CondVar(const ConditionFunc &cond, std::string tag) : Var(tag), cond_fn_(cond) {}
~CondVar() override = default;
MS_DECLARE_PARENT(CondVar, Var);
bool matches(const BaseRef &value) override {

View File

@ -18,6 +18,13 @@
#include "nnacl/op_base.h"
typedef struct AttentionParameter {
OpParameter op_parameter_;
int head_num_;
int head_size_;
bool cross_;
} AttentionParameter;
typedef struct RelativePositionAttentionParameter {
// Primitive parameter
OpParameter op_parameter_;

View File

@ -16,6 +16,7 @@
#include "nnacl/infer/attention_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/attention_parameter.h"
int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {
@ -23,27 +24,44 @@ int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
if (check_ret != NNACL_OK) {
return check_ret;
}
const TensorC *q_input = inputs[0];
TensorC *output = outputs[0];
SetDataTypeFormat(output, q_input);
AttentionParameter *param = (AttentionParameter *)parameter;
const TensorC *q_input = inputs[FIRST_INPUT];
const TensorC *k_input = inputs[SECOND_INPUT];
TensorC *output0 = outputs[FIRST_INPUT];
SetDataTypeFormat(output0, q_input);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
const TensorC *q_weight = inputs[3];
if (q_input->shape_size_ != 2 && q_input->shape_size_ != 3) {
const TensorC *q_weight = inputs[FOURTH_INPUT];
if (q_input->shape_size_ != C2NUM && q_input->shape_size_ != C3NUM) {
return NNACL_ERR;
}
if (q_weight->shape_size_ != 2) {
if (q_weight->shape_size_ != C2NUM) {
return NNACL_ERR;
}
int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0];
int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1];
int d_model = q_weight->shape_[1];
output->shape_[0] = batch;
output->shape_[1] = f_seq;
output->shape_[2] = d_model;
output->shape_size_ = 3;
int t_seq_len = k_input->shape_[1];
output0->shape_[FIRST_INPUT] = batch;
output0->shape_[SECOND_INPUT] = f_seq;
output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_;
output0->shape_size_ = C3NUM;
if (outputs_size >= C3NUM) {
TensorC *output1 = outputs[SECOND_INPUT];
SetDataTypeFormat(output1, q_input);
output1->shape_[FIRST_INPUT] = batch;
output1->shape_[SECOND_INPUT] = param->head_num_;
output1->shape_[THIRD_INPUT] = param->head_size_;
output1->shape_[FOURTH_INPUT] = t_seq_len;
output1->shape_size_ = C4NUM;
TensorC *output2 = outputs[THIRD_INPUT];
SetDataTypeFormat(output2, q_input);
output2->shape_[FIRST_INPUT] = batch;
output2->shape_[SECOND_INPUT] = param->head_num_;
output2->shape_[THIRD_INPUT] = t_seq_len;
output2->shape_[FOURTH_INPUT] = param->head_size_;
output2->shape_size_ = C4NUM;
}
return NNACL_OK;
}
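For reference, a standalone sketch of the shape arithmetic this infer function performs. The dimensions below (batch, sequence lengths, head count, head size) are hypothetical example values, not taken from the commit:

#include <cstdio>

// Minimal sketch of the shape math in AttentionInferShape, under assumed example dims.
int main() {
  const int batch = 2, f_seq = 16, t_seq = 20;   // hypothetical batch and from/to sequence lengths
  const int head_num = 8, head_size = 64;        // hypothetical Attention attributes

  // output0: attention result  [batch, from_seq_len, head_num * head_size]
  int out0[3] = {batch, f_seq, head_num * head_size};
  // output1: key cache         [batch, head_num, head_size, to_seq_len]
  int out1[4] = {batch, head_num, head_size, t_seq};
  // output2: value cache       [batch, head_num, to_seq_len, head_size]
  int out2[4] = {batch, head_num, t_seq, head_size};

  std::printf("output0: [%d, %d, %d]\n", out0[0], out0[1], out0[2]);
  std::printf("output1: [%d, %d, %d, %d]\n", out1[0], out1[1], out1[2], out1[3]);
  std::printf("output2: [%d, %d, %d, %d]\n", out2[0], out2[1], out2[2], out2[3]);
  return 0;
}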

View File

@ -17,9 +17,39 @@
#include "ops/attention.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
namespace mindspore::ops {
MIND_API_OPERATOR_IMPL(Attention, BaseOperator);
void Attention::set_head_num(int64_t head_num) { (void)this->AddAttr(kAttentionNumHeads, api::MakeValue(head_num)); }
void Attention::set_head_size(int64_t head_size) {
(void)this->AddAttr(kAttentionSizePerHead, api::MakeValue(head_size));
}
void Attention::set_cross(bool cross) { (void)this->AddAttr(kCross, api::MakeValue(cross)); }
int64_t Attention::get_head_num() const {
auto value_ptr = this->GetAttr(kAttentionNumHeads);
return GetValue<int64_t>(value_ptr);
}
int64_t Attention::get_head_size() const {
auto value_ptr = this->GetAttr(kAttentionSizePerHead);
return GetValue<int64_t>(value_ptr);
}
bool Attention::get_cross() const {
auto value_ptr = this->GetAttr(kCross);
return GetValue<bool>(value_ptr);
}
void Attention::Init(int64_t head_num, int64_t head_size, bool cross) {
this->set_head_num(head_num);
this->set_head_size(head_size);
this->set_cross(cross);
}
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
} // namespace mindspore::ops

View File

@ -37,7 +37,16 @@ class MIND_API Attention : public BaseOperator {
{"output"});
}
/// \brief Initialize Attention op.
void Init() const {}
/// \param[in] head_num Define head number.
/// \param[in] head_size Define size per head.
/// \param[in] cross Define is cross attention. Default false.
void Init(int64_t head_num, int64_t head_size, bool cross = false);
void set_head_num(int64_t head_num);
void set_head_size(int64_t head_size);
void set_cross(bool cross);
int64_t get_head_num() const;
int64_t get_head_size() const;
bool get_cross() const;
};
} // namespace ops
} // namespace mindspore

View File

@ -686,6 +686,7 @@ GVAR_DEF(PrimitivePtr, kPrimSoftmaxV2WithDropoutDoMaskV3, std::make_shared<Primi
GVAR_DEF(PrimitivePtr, kPrimLogSoftmax, std::make_shared<Primitive>("LogSoftmax"));
GVAR_DEF(PrimitivePtr, kPrimLogSoftmaxGrad, std::make_shared<Primitive>("LogSoftmaxGrad"));
GVAR_DEF(PrimitivePtr, kPrimLstm, std::make_shared<Primitive>("LSTM"));
GVAR_DEF(PrimitivePtr, kPrimAttention, std::make_shared<Primitive>("Attention"));
GVAR_DEF(PrimitivePtr, kPrimLstmGradData, std::make_shared<Primitive>("LSTMGradData"));
GVAR_DEF(PrimitivePtr, kPrimLstmGradWeight, std::make_shared<Primitive>("LSTMGradWeight"));
GVAR_DEF(PrimitivePtr, kPrimTan, std::make_shared<Primitive>("Tan"));

View File

@ -145,10 +145,10 @@ constexpr auto kNumElements = "num_elements";
constexpr auto kNumBits = "num_bits";
constexpr auto kNumDirections = "num_directions";
constexpr auto kNumProj = "num_proj";
constexpr auto kAttentionNumHeads = "attention_num_heads";
constexpr auto kAttentionSizePerHead = "attention_size_per_head";
constexpr auto kAttentionFromSeqLen = "attention_from_seq_len";
constexpr auto kAttentionToSeqLen = "attention_to_seq_len";
constexpr auto kAttentionNumHeads = "head_num";
constexpr auto kAttentionSizePerHead = "head_size";
constexpr auto kAttentionFromSeqLen = "from_seq_len";
constexpr auto kAttentionToSeqLen = "to_seq_len";
constexpr auto kOffset = "offset";
constexpr auto kNmsIouThreshold = "nms_iou_threshold";
constexpr auto kNmsScoreThreshold = "nms_score_threshold";
@ -345,6 +345,7 @@ constexpr auto kSeed0 = "seed0";
constexpr auto kSeed1 = "seed1";
constexpr auto kHandle = "handle";
constexpr auto kBatchSize = "batch_size";
constexpr auto kCross = "cross";
constexpr auto kDeviceNum = "device_num";
constexpr size_t kInputIndex0 = 0;

View File

@ -749,6 +749,9 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
set(MSLITE_DEPS_LIBEVENT on)
set(MSLITE_DEPS_PYBIND11 on)
set(MSLITE_DEPS_OPENSSL on)
if(SUPPORT_TENSORRT)
set(MSLITE_DEPS_FAST_TRANSFORMERS on)
endif()
endif()
if(MSLITE_ENABLE_MODEL_ENCRYPTION)

View File

@ -34,6 +34,10 @@ if(MSLITE_DEPS_OPENCV)
include(${TOP_DIR}/cmake/external_libs/opencv.cmake)
endif()
if(MSLITE_DEPS_FAST_TRANSFORMERS)
include(${TOP_DIR}/cmake/external_libs/fast_transformers.cmake)
endif()
if(MSLITE_DEPS_MKLDNN)
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(USE_MS_THREADPOOL_FOR_DNNL ON)

View File

@ -390,6 +390,9 @@ table Concat {
}
table Attention {
head_num: long;
head_size: long;
cross: bool;
}
table Conv2DBackpropFilterFusion {

View File

@ -390,6 +390,9 @@ OP_ATTR(axis, long)
OP_SCHEMA_DEF_END(Concat)
OP_SCHEMA_DEF(Attention)
OP_ATTR(head_num, long)
OP_ATTR(head_size, long)
OP_ATTR(cross, bool)
OP_SCHEMA_DEF_END(Attention)
OP_SCHEMA_DEF(Conv2DBackpropFilterFusion)

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/populate/populate_register.h"
#include "nnacl/attention_parameter.h"
using mindspore::schema::PrimitiveType_Attention;
namespace mindspore {
namespace lite {
OpParameter *PopulateAttentionParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
auto value = primitive->value_as_Attention();
MS_CHECK_TRUE_MSG(value != nullptr, nullptr, "value is nullptr.");
auto *param = reinterpret_cast<AttentionParameter *>(malloc(sizeof(AttentionParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc AttentionParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(AttentionParameter));
param->op_parameter_.type_ = primitive->value_type();
param->head_num_ = value->head_num();
param->head_size_ = value->head_size();
param->cross_ = value->cross();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_Attention, PopulateAttentionParameter, SCHEMA_CUR)
} // namespace lite
} // namespace mindspore

View File

@ -41,7 +41,6 @@ OpParameter *PopulateCommonParameter(const void *prim) {
REG_POPULATE(PrimitiveType_AddN, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Attention, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_SwitchLayer, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Log1p, PopulateCommonParameter, SCHEMA_CUR)
} // namespace lite

View File

@ -97,4 +97,4 @@ target_link_libraries(
add_subdirectory(cuda_impl)
target_link_libraries(tensorrt_plugin cuda_kernel_mid gpu_distribution_collective)
target_link_libraries(tensorrt_plugin mindspore-extendrt mindspore_core)
target_link_libraries(tensorrt_plugin mindspore-extendrt mindspore_core mindspore::fast_transformers)

View File

@ -45,6 +45,7 @@ void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addr, type_b, ldb, a_addr, type_a,
lda, &beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
const cublasOperation_t *operations, const cudaDataType *data_types,
cublasHandle_t cublas_handle) {
@ -67,4 +68,49 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *
type_a, lda, &beta, c_addrs, type_c, ldc, batch, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
cublasHandle_t cublas_handle) {
const int m = params[0];
const int n = params[1];
const int k = params[2];
cublasOperation_t trans_a = operations[0];
cublasOperation_t trans_b = operations[1];
const int lda = lds[0];
const int ldb = lds[1];
const int ldc = lds[2];
cudaDataType type_a = data_types[0];
cudaDataType type_b = data_types[1];
cudaDataType type_c = data_types[2];
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b,
ldb, beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT));
}
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
const int *lds, const cublasOperation_t *operations, const int *strides,
const cudaDataType *data_types, void *alpha, void *beta, int batch,
cublasHandle_t cublas_handle) {
const int m = params[0];
const int n = params[1];
const int k = params[2];
cublasOperation_t trans_a = operations[0];
cublasOperation_t trans_b = operations[1];
const int lda = lds[0];
const int ldb = lds[1];
const int ldc = lds[2];
cudaDataType type_a = data_types[0];
cudaDataType type_b = data_types[1];
cudaDataType type_c = data_types[2];
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
const int stride_a = strides[0];
const int stride_b = strides[1];
const int stride_c = strides[2];
CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda,
stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc,
stride_c, batch, compute_type, CUBLAS_GEMM_DEFAULT));
}
} // namespace mindspore::lite
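As a plain-CPU illustration of the argument convention these wrappers use (params = {m, n, k}, lds = {lda, ldb, ldc}, column-major storage as in cuBLAS), here is a sketch of the non-transposed case. It is only a reading aid under those assumptions, not part of the commit; the real wrappers dispatch to cublasGemmEx on the GPU.

#include <vector>

// Reference of the N/N case CublasGemmWrapper requests from cuBLAS:
// C[m x n] = alpha * A[m x k] * B[k x n] + beta * C, column-major, leading dims lda/ldb/ldc.
void ReferenceGemmNN(const float *a, const float *b, float *c, const int *params, const int *lds,
                     float alpha, float beta) {
  const int m = params[0], n = params[1], k = params[2];
  const int lda = lds[0], ldb = lds[1], ldc = lds[2];
  for (int j = 0; j < n; ++j) {      // columns of C
    for (int i = 0; i < m; ++i) {    // rows of C
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) {
        acc += a[i + p * lda] * b[p + j * ldb];   // column-major indexing
      }
      c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
    }
  }
}

int main() {
  const int params[3] = {2, 2, 2};   // m, n, k (hypothetical)
  const int lds[3] = {2, 2, 2};      // lda, ldb, ldc
  std::vector<float> a = {1, 2, 3, 4}, b = {1, 0, 0, 1}, c(4, 0.0f);  // column-major 2x2 matrices
  ReferenceGemmNN(a.data(), b.data(), c.data(), params, lds, 1.0f, 0.0f);  // c == a, since b is identity
  return c[0] == 1.0f ? 0 : 1;
}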

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
#include <cublasLt.h>
#include <cublas_v2.h>
#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h"
#include "src/common/log_util.h"
@ -58,5 +59,13 @@ void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const
// data_types: type_a, type_b, type_c, compute type
void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle);
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
cublasHandle_t cublas_handle);
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
const int *lds, const cublasOperation_t *operations, const int *strides,
const cudaDataType *data_types, void *alpha, void *beta, int batch,
cublasHandle_t cublas_handle);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_

View File

@ -0,0 +1,330 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include <numeric>
#include <memory>
#include <vector>
#include <functional>
#include <unordered_map>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "NvInferRuntimeCommon.h"
#include "src/extendrt/delegate/tensorrt/op/mha_tensorrt.h"
#include "ops/attention.h"
#include "src/fastertransformer/kernels/unfused_attention_kernels.h"
#include "src/fastertransformer/kernels/activation_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/allocator.h"
namespace mindspore::lite {
namespace {
constexpr std::size_t kTwo = 2;
constexpr std::size_t kThree = 3;
} // namespace
// Multi Head Attention TensorRT op
int MhaTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() != 8 && in_tensors.size() != 6) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
}
int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
auto mha_op = AsOps<ops::Attention>();
if (mha_op == nullptr) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
// get attribute for Attn op - TODO - add attribute in op
int head_number = mha_op->get_head_num();
int head_size = mha_op->get_head_size();
int compute_type = 1; // mha_op->get_compute_type();
int is_cross = mha_op->get_cross();
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
GetCublasHandle(), GetCublasLtHandle(), device_id_);
const int input_number = inputs().size();
nvinfer1::ITensor *inputTensors[input_number];
for (int i = 0; i < input_number; i++) {
inputTensors[i] = input(ctx, i).trt_tensor_;
}
nvinfer1::IPluginV2Layer *mha_layer = ctx->network()->addPluginV2(inputTensors, input_number, *plugin);
if (mha_layer == nullptr) {
MS_LOG(ERROR) << "add mha op failed for TensorRT.";
return RET_ERROR;
}
mha_layer->setName(op_name_.c_str());
// TODO(haim) one output
nvinfer1::ITensor *attn_tensor = mha_layer->getOutput(0);
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name());
// nvinfer1::ITensor *key_tensor = mha_layer->getOutput(1);
// ctx->RegisterTensor(ITensorHelper{key_tensor, Format::NCHW, true}, out_tensors_[1].Name());
// nvinfer1::ITensor *value_tensor = mha_layer->getOutput(kTwo);
// ctx->RegisterTensor(ITensorHelper{value_tensor, Format::NCHW, true}, out_tensors_[kTwo].Name());
this->layer_ = mha_layer;
return RET_OK;
}
// PLUGIN of Multi Head Attention Layer
REGISTER_TENSORRT_PLUGIN(MhaPluginCreater);
template class TensorRTPluginCreater<MhaPlugin>;
template <class T>
nvinfer1::PluginFieldCollection TensorRTPluginCreater<T>::field_collection_{};
template <class T>
std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
int MhaPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
return RunCudaMha(inputDesc, outputDesc, inputs, outputs, workspace, stream);
}
int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) {
// inputs order:
// 0] Q
// 1] K
// 2] V
// 3] W
// 4] PW
// 5] B
// 6] PB
// 7] AttnMask
cublasSetStream(cublas_handle_, stream);
// TODO(Haim) - Fix tensor ids according to cross flag
const int from_tensor_idx = 0;
// const int encoder_tensor_idx = 1;
const int weight_qkv_tensor_idx = 3;
const int weight_projection_tensor_idx = 4;
const int bias_qkv_tensor_idx = 5;
const int bias_projection_tensor_idx = 6;
const int attn_mask_tensor_idx = 7;
auto from_tensor = static_cast<const float *>(inputs[from_tensor_idx]);
auto attention_mask = static_cast<const float *>(inputs[attn_mask_tensor_idx]);
auto weight_qkv = static_cast<const float *>(inputs[weight_qkv_tensor_idx]);
auto bias_qkv = static_cast<const float *>(inputs[bias_qkv_tensor_idx]);
auto weight_projection = static_cast<const float *>(inputs[weight_projection_tensor_idx]);
auto bias_projection = static_cast<const float *>(inputs[bias_projection_tensor_idx]);
auto output0 = static_cast<float *>(outputs[0]);
// auto output1 = static_cast<float *>(outputs[1]);
// auto output2 = static_cast<float *>(outputs[2]);
auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
// TODO(NIZZAN): fix allocator
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t size_v = size_k;
size_t qkv_len = size_q + size_k + size_v;
size_t q_buf_2_len = size_q;
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
size_t qkv_buf_3_len = qkv_buf_2_len;
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
qkv_buf_ = workspace;
q_buf_2_ = static_cast<float *>(qkv_buf_) + qkv_len;
qk_buf_ = static_cast<float *>(q_buf_2_) + q_buf_2_len;
qkv_buf_2_ = static_cast<float *>(qk_buf_) + qk_buf_len;
qkv_buf_3_ = static_cast<float *>(qkv_buf_2_) + qkv_buf_2_len;
output1_ = static_cast<float *>(workspace) + buff_size;
output2_ = static_cast<float *>(output1_) + extra_tmp_size;
int gemm_dims[3] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
int gemm_lds[3] = {3 * hidden_size, hidden_size, 3 * hidden_size};
cublasOperation_t gemm_ops[2] = {CUBLAS_OP_N, CUBLAS_OP_N};
const cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
float alpha = 1.0f;
float beta = 0.0f;
CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta,
cublas_handle_);
fastertransformer::invokeAddFusedQKVBiasTranspose(static_cast<float *>(q_buf_2_), static_cast<float *>(output1_),
static_cast<float *>(output2_), static_cast<float *>(qkv_buf_),
bias_qkv, request_batch_size, request_src_seq_len, head_number_,
head_size_, 0, stream);
gemm_ops[0] = CUBLAS_OP_T;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = request_tgt_seq_len;
gemm_dims[1] = request_src_seq_len;
gemm_dims[THIRD_INPUT] = head_size_;
gemm_lds[0] = head_size_;
gemm_lds[1] = head_size_;
gemm_lds[THIRD_INPUT] = request_tgt_seq_len;
int gemm_strides[] = {request_tgt_seq_len * head_size_, request_src_seq_len * head_size_,
request_src_seq_len * request_tgt_seq_len};
CublasGemmStridedBatchedWrapper(output1_, q_buf_2_, qk_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);
float scalar = (1.0f / sqrtf(static_cast<float>(head_size_) * 1.0f));
fastertransformer::invokeMixMaskedSoftMax(static_cast<float *>(qk_buf_), attention_mask, request_batch_size,
request_src_seq_len, request_tgt_seq_len, head_number_, scalar, stream);
gemm_ops[0] = CUBLAS_OP_N;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = head_size_;
gemm_dims[1] = request_src_seq_len;
gemm_dims[THIRD_INPUT] = request_tgt_seq_len;
gemm_lds[0] = head_size_;
gemm_lds[1] = request_tgt_seq_len;
gemm_lds[THIRD_INPUT] = head_size_;
gemm_strides[0] = request_tgt_seq_len * head_size_;
gemm_strides[1] = request_src_seq_len * request_tgt_seq_len;
gemm_strides[THIRD_INPUT] = request_src_seq_len * head_size_;
CublasGemmStridedBatchedWrapper(output2_, qk_buf_, qkv_buf_2_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);
fastertransformer::invokeTransposeQKV(static_cast<float *>(qkv_buf_3_), static_cast<float *>(qkv_buf_2_),
request_batch_size, request_src_seq_len, head_number_, head_size_, stream);
gemm_ops[0] = CUBLAS_OP_N;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = hidden_size;
gemm_dims[1] = request_batch_size * request_src_seq_len;
gemm_dims[THIRD_INPUT] = hidden_size;
gemm_lds[0] = hidden_size;
gemm_lds[1] = hidden_size;
gemm_lds[THIRD_INPUT] = hidden_size;
CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha,
&beta, cublas_handle_);
int len = request_batch_size * request_src_seq_len;
fastertransformer::invokeAddBias(reinterpret_cast<float *>(output0), reinterpret_cast<const float *>(bias_projection),
len, hidden_size, stream);
return RET_OK;
}
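To make the kernel sequence above easier to follow, here is a compact CPU sketch of the computation it performs, reduced to a single batch and a single head in FP32: QKV projection with bias, scaled Q·K^T plus mask, row softmax, weighting of V, and output projection with bias. The row-major layout, the additive-mask form, and all sizes are assumptions of this sketch; the plugin itself runs the batched, multi-head, column-major cuBLAS/FasterTransformer path.

#include <algorithm>
#include <cmath>
#include <vector>

// Single-batch, single-head CPU sketch of the attention pipeline driven above.
struct TinyAttention {
  int seq_len, d;                      // sequence length, head size (== hidden size here)
  std::vector<float> wq, wk, wv, wo;   // [d x d] projection weights, row-major
  std::vector<float> bq, bk, bv, bo;   // [d] biases

  static void MatMul(const std::vector<float> &x, const std::vector<float> &w, const std::vector<float> &b,
                     std::vector<float> *y, int rows, int inner, int cols) {
    y->assign(rows * cols, 0.0f);
    for (int r = 0; r < rows; ++r) {
      for (int c = 0; c < cols; ++c) {
        float acc = b.empty() ? 0.0f : b[c];
        for (int i = 0; i < inner; ++i) acc += x[r * inner + i] * w[i * cols + c];
        (*y)[r * cols + c] = acc;
      }
    }
  }

  // x: [seq_len x d], mask: [seq_len x seq_len] additive bias, out: [seq_len x d]
  void Forward(const std::vector<float> &x, const std::vector<float> &mask, std::vector<float> *out) const {
    std::vector<float> q, k, v, scores(seq_len * seq_len), ctx;
    MatMul(x, wq, bq, &q, seq_len, d, d);            // Q = x*Wq + bq
    MatMul(x, wk, bk, &k, seq_len, d, d);            // K = x*Wk + bk
    MatMul(x, wv, bv, &v, seq_len, d, d);            // V = x*Wv + bv
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    for (int i = 0; i < seq_len; ++i) {              // scores = softmax(scale * Q*K^T + mask), row-wise
      float max_v = -1e30f;
      for (int j = 0; j < seq_len; ++j) {
        float s = 0.0f;
        for (int p = 0; p < d; ++p) s += q[i * d + p] * k[j * d + p];
        s = s * scale + mask[i * seq_len + j];
        scores[i * seq_len + j] = s;
        max_v = std::max(max_v, s);
      }
      float sum = 0.0f;
      for (int j = 0; j < seq_len; ++j) {
        scores[i * seq_len + j] = std::exp(scores[i * seq_len + j] - max_v);
        sum += scores[i * seq_len + j];
      }
      for (int j = 0; j < seq_len; ++j) scores[i * seq_len + j] /= sum;
    }
    ctx.assign(seq_len * d, 0.0f);                   // ctx = scores * V
    for (int i = 0; i < seq_len; ++i)
      for (int j = 0; j < seq_len; ++j)
        for (int p = 0; p < d; ++p) ctx[i * d + p] += scores[i * seq_len + j] * v[j * d + p];
    MatMul(ctx, wo, bo, out, seq_len, d, d);         // out = ctx*Wo + bo
  }
};

int main() {
  const int n = 2, dim = 2;
  TinyAttention att{n, dim,
                    {1, 0, 0, 1}, {1, 0, 0, 1}, {1, 0, 0, 1}, {1, 0, 0, 1},   // identity weights
                    {0, 0}, {0, 0}, {0, 0}, {0, 0}};                          // zero biases
  std::vector<float> x = {1, 0, 0, 1}, mask(n * n, 0.0f), out;
  att.Forward(x, mask, &out);   // out is the softmax-weighted mixture of the rows of x
  return out.size() == static_cast<size_t>(n * dim) ? 0 : 1;
}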
int MhaPlugin::RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
// Add Cross Mha Layer here
return RET_OK;
}
size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
auto attn_dim_size = inputs[nbInputs - 1].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputs[nbInputs - 1].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
// TODO(NIZZAN) Fix efficient allocator
// size_t buff_size = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len +
// request_batch_size * request_src_seq_len * hidden_size;
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t size_v = size_k;
size_t qkv_len = size_q + size_k + size_v;
size_t q_buf_2_len = size_q;
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
size_t qkv_buf_3_len = qkv_buf_2_len;
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
return (buff_size + extra_tmp_size + extra_tmp_size) * sizeof(float);
}
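A worked example of this workspace arithmetic with hypothetical sizes (batch 1, src/tgt sequence length 128, 12 heads of size 64), just to show the order of magnitude the formula yields:

#include <cstdio>

// Hypothetical sizes; the formula mirrors getWorkspaceSize above.
int main() {
  const size_t batch = 1, src = 128, tgt = 128, heads = 12, head_size = 64;
  const size_t hidden = heads * head_size;                       // 768
  const size_t size_q = batch * src * hidden;                    // Q buffer elements
  const size_t size_k = batch * tgt * hidden;                    // K buffer elements
  const size_t size_v = size_k;                                  // V buffer elements
  const size_t qkv = size_q + size_k + size_v;
  const size_t q_buf_2 = size_q;
  const size_t qk = batch * heads * src * tgt;                   // attention score matrix
  const size_t qkv_buf_2 = batch * src * hidden;
  const size_t qkv_buf_3 = qkv_buf_2;
  const size_t extra = batch * heads * head_size * tgt;          // per extra output scratch
  const size_t total = (qkv + q_buf_2 + qk + qkv_buf_2 + qkv_buf_3 + 2 * extra) * sizeof(float);
  std::printf("workspace bytes: %zu\n", total);                  // about 3.9 MB for these sizes
  return 0;
}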
nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept {
// MHA inputs:
// from_tensor [batch_size, src_seq_len, hidden_size_] or [batch_size * src_seq_len, hidden_size_]
// encoder_output [batch_size, tgt_seq_len, hidden_size_] or [batch_size * tgt_seq_len, hidden_size_] --> only in cross MHA
// attention_mask [batch_size, 1, src_seq_len, tgt_seq_len] or [batch_size, src_seq_len, tgt_seq_len]
// MHA output_tensors:
// attention_out [batch_size, src_seq_len, hidden_size_]
// key_cache [batch, head_num, size_per_head]
// value_cache [batch, head_num, tgt_seq_len, size_per_head]
nvinfer1::DimsExprs dims;
if (index == 0) {
// if (inputs[0].nbDims == 2) {
// dims.nbDims = INPUT_SIZE2;
// dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
// auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
// dims.d[1] = hidden_size;
// } else
{
dims.nbDims = INPUT_SIZE3;
dims.d[0] = inputs[nbInputDims - 1].d[0]; // batch
dims.d[1] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
dims.d[kTwo] = hidden_size;
}
} else {
// TODO(Haim) - Fix size in case of 2d input
dims.nbDims = INPUT_SIZE4;
dims.d[0] = inputs[nbInputDims - 1].d[0]; // batch
dims.d[1] = exprBuilder.constant(head_number_);
dims.d[kTwo] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
dims.d[kThree] = exprBuilder.constant(head_size_);
}
return dims;
}
nvinfer1::IPluginV2DynamicExt *MhaPlugin::clone() const noexcept {
auto *plugin = new MhaPlugin(*this); // TODO(haim) CopyConstructor
if (plugin == nullptr) {
MS_LOG(ERROR) << "plugin is null";
return nullptr;
}
plugin->setPluginNamespace(name_space_.c_str());
return plugin;
}
void MhaPlugin::terminate() noexcept {}
size_t MhaPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); }
void MhaPlugin::serialize(void *buffer) const noexcept {
SerializeValue(&buffer, &compute_type_, sizeof(int));
SerializeValue(&buffer, &head_number_, sizeof(int));
SerializeValue(&buffer, &head_size_, sizeof(int));
SerializeValue(&buffer, &is_cross_, sizeof(int));
}
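The serialized plugin state is four ints (compute type, head number, head size, cross flag), which is why getSerializationSize() returns INPUT_SIZE4 * sizeof(int). A standalone sketch of the equivalent byte-level round trip, using plain memcpy as a stand-in for the SerializeValue/DeserializeValue helpers:

#include <cassert>
#include <cstring>

// Illustrative stand-in: write then read compute_type, head_number, head_size
// and is_cross as four consecutive ints (values are hypothetical).
int main() {
  int fields[4] = {1, 12, 64, 0};
  char buffer[4 * sizeof(int)];
  char *w = buffer;
  for (int f : fields) { std::memcpy(w, &f, sizeof(int)); w += sizeof(int); }

  int restored[4];
  const char *r = buffer;
  for (int &f : restored) { std::memcpy(&f, r, sizeof(int)); r += sizeof(int); }
  assert(std::memcmp(fields, restored, sizeof(fields)) == 0);
  return 0;
}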
REGISTER_TENSORRT_CREATOR(ops::kNameAttention, MhaTensorRT)
} // namespace mindspore::lite

View File

@ -0,0 +1,115 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MHA_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MHA_TENSORRT_H_
#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h"
namespace mindspore::lite {
class MhaTensorRT : public TensorRTOp {
public:
MhaTensorRT(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors, std::string name)
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}
~MhaTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) override;
};
constexpr auto MHA_PLUGIN_NAME{"AttentionPlugin"};
class MhaPlugin : public TensorRTPlugin {
public:
MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, int is_cross,
cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
: TensorRTPlugin(name, std::string(MHA_PLUGIN_NAME), device_id),
compute_type_(compute_type),
head_number_(head_number),
head_size_(head_size),
is_cross_(is_cross),
cublas_handle_(cublas_handle),
cublaslt_handle_(cublaslt_handle) {}
MhaPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
: TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
const nvinfer1::PluginField *fields = fc->fields;
compute_type_ = static_cast<const int *>(fields[0].data)[0];
head_number_ = static_cast<const int *>(fields[1].data)[0];
head_size_ = static_cast<const int *>(fields[2].data)[0];
is_cross_ = static_cast<const int *>(fields[3].data)[0];
}
MhaPlugin(const char *name, const void *serialData, size_t serialLength)
: TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
DeserializeValue(&serialData, &serialLength, &compute_type_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &head_number_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &head_size_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &is_cross_, sizeof(int));
}
MhaPlugin() = delete;
~MhaPlugin() override {
// std::cout << "~MhaPlugin" << std::endl;
}
nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
void terminate() noexcept override;
private:
bool needResize(const int *current_dims, const int *last_dims);
int RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);
int RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);
const std::string layer_name_;
std::string name_space_;
int compute_type_;
int head_number_;
int head_size_;
int is_cross_;
cublasHandle_t cublas_handle_;
cublasLtHandle_t cublaslt_handle_;
void *qkv_buf_{nullptr};
void *q_buf_2_{nullptr};
void *qk_buf_{nullptr};
void *qkv_buf_2_{nullptr};
void *qkv_buf_3_{nullptr};
void *output1_{nullptr};
void *output2_{nullptr};
};
class MhaPluginCreater : public TensorRTPluginCreater<MhaPlugin> {
public:
MhaPluginCreater() : TensorRTPluginCreater(std::string(MHA_PLUGIN_NAME)) {}
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MHA_TENSORRT_H_

View File

@ -28,6 +28,7 @@
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "src/extendrt/delegate/tensorrt/op_registration_factory.h"
#include "src/extendrt/delegate/tensorrt/tensor_info.h"
// #include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
#include "src/common/log_util.h"
#include "ops/base_operator.h"
#include "ops/op_name.h"
@ -106,6 +107,8 @@ class TensorRTOp {
const std::vector<TensorRTOp *> &out_ops() const;
void SetRuntime(TensorRTRuntime *runtime);
cublasHandle_t GetCublasHandle() { return runtime_ ? runtime_->GetCublasHandle() : nullptr; }
cublasLtHandle_t GetCublasLtHandle() { return runtime_ ? runtime_->GetCublasLtHandle() : nullptr; }
DynamicShapeParams GetDynamicShapeParams() const;

View File

@ -360,6 +360,14 @@ TensorRTExecutor::~TensorRTExecutor() {
if (stream_ != nullptr) {
cudaStreamDestroy(stream_);
}
if (cublas_handle_ != nullptr) {
cublasDestroy(cublas_handle_);
cublas_handle_ = nullptr;
}
if (cublaslt_handle_ != nullptr) {
cublasLtDestroy(cublaslt_handle_);
cublaslt_handle_ = nullptr;
}
}
bool IsHardwareSupport() {
int driver_version = 0;
@ -415,6 +423,19 @@ Status TensorRTExecutor::Init() {
MS_LOG(ERROR) << "Cuda create stream failed";
return mindspore::kLiteError;
}
auto cublas_ret = cublasCreate(&cublas_handle_);
if (cublas_ret != CUBLAS_STATUS_SUCCESS) {
MS_LOG(ERROR) << "Cuda create cublas handle failed";
return mindspore::kLiteError;
}
auto cublaslt_ret = cublasLtCreate(&cublaslt_handle_);
if (cublaslt_ret != CUBLAS_STATUS_SUCCESS) {
MS_LOG(ERROR) << "Cuda create cublaslt handle failed";
return mindspore::kLiteError;
}
return mindspore::kSuccess;
}
@ -569,7 +590,7 @@ std::shared_ptr<TensorRTSubGraph> TensorRTExecutor::CreateTensorRTGraph(const st
FindPreNextOps<TensorRTOp>(ops);
// 2. Init TensorRT SubGraph.
auto ret = tensorrt_graph->Init(stream_);
auto ret = tensorrt_graph->Init(stream_, cublas_handle_, cublaslt_handle_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "TensorRTGraph init failed.";
return nullptr;

View File

@ -83,6 +83,9 @@ class TensorRTExecutor : public LiteGraphExecutor {
size_t device_cache_size_{0};
std::string serialize_path_;
cudaStream_t stream_{nullptr};
cublasHandle_t cublas_handle_{nullptr};
cublasLtHandle_t cublaslt_handle_{nullptr};
std::vector<kernel::Kernel> kernel_list_;
std::vector<TrtGraphContext> tensorrt_graph_list_;

View File

@ -18,6 +18,7 @@
#include <NvInfer.h>
#include "include/errorcode.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_allocator.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
#include "src/common/log_adapter.h"
#define MAX_BATCH_SIZE 64
@ -55,7 +56,11 @@ class TensorRTRuntime {
void SetBatchSize(int batch_size) { batch_size_ = batch_size; }
void SetCudaStream(cudaStream_t stream) { allocator_->SetCudaStream(stream); }
void SetCudaStream(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle) {
allocator_->SetCudaStream(stream);
cublas_handle_ = cublas_handle;
cublaslt_handle_ = cublaslt_handle;
}
RuntimePrecisionMode GetRuntimePrecisionMode() { return runtime_percision_mode_; }
@ -68,6 +73,8 @@ class TensorRTRuntime {
void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
uint32_t GetDeviceID() { return device_id_; }
cublasHandle_t GetCublasHandle() { return cublas_handle_; }
cublasLtHandle_t GetCublasLtHandle() { return cublaslt_handle_; }
private:
bool is_init_ = false;
@ -77,6 +84,8 @@ class TensorRTRuntime {
int batch_size_{0};
uint32_t device_id_{0};
RuntimePrecisionMode runtime_percision_mode_{RuntimePrecisionMode::RuntimePrecisionMode_FP32};
cublasHandle_t cublas_handle_{nullptr};
cublasLtHandle_t cublaslt_handle_{nullptr};
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_

View File

@ -86,7 +86,7 @@ TensorRTSubGraph::~TensorRTSubGraph() {
}
}
int TensorRTSubGraph::Init(cudaStream_t stream) {
int TensorRTSubGraph::Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle) {
auto ret = GetGraphInOutOps(inputs_, outputs_, &in_ops_, &out_ops_, all_ops_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get TensorRT subgraph input and output ops failed.";
@ -107,7 +107,7 @@ int TensorRTSubGraph::Init(cudaStream_t stream) {
MS_LOG(ERROR) << "New TensorRTContext failed.";
return RET_OK;
}
if (SetDeviceConfig(stream) != RET_OK) {
if (SetDeviceConfig(stream, cublas_handle, cublaslt_handle) != RET_OK) {
MS_LOG(WARNING) << "set tensorrt config failed.";
}
serializer_ = std::make_shared<TensorRTSerializer>(serialize_file_path_);
@ -151,7 +151,8 @@ int TensorRTSubGraph::BuildEngine() {
return RET_OK;
}
int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle,
cublasLtHandle_t cublaslt_handle) {
if (config_ == nullptr) {
this->config_ = runtime_->GetBuilder()->createBuilderConfig();
if (this->config_ == nullptr) {
@ -176,9 +177,10 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
} else {
MS_LOG(INFO) << "inputs no quant params or platform not support int8.";
}
runtime_->SetCudaStream(stream);
runtime_->SetCudaStream(stream, cublas_handle, cublaslt_handle);
config_->setProfileStream(stream);
stream_ = stream;
MS_LOG(INFO) << GetRankID() << " tensorrt subgraph stream: " << stream_;
// config setMaxWorkspaceSize to 2047 MB for max limit

View File

@ -54,7 +54,7 @@ class TensorRTSubGraph {
int BuildTensorRTGraph();
int Init(cudaStream_t stream);
int Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);
void SetSerializePath(const std::string &path) { serialize_file_path_ = std::move(path); }
@ -65,7 +65,7 @@ class TensorRTSubGraph {
private:
int BuildEngine();
int SetDeviceConfig(cudaStream_t stream);
int SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);
bool IsInt8Mode();

View File

@ -247,6 +247,9 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
std::make_shared<opt::TensorDotFusion>(),
std::make_shared<opt::MatMulActivationFusion>(param),
std::make_shared<opt::MulActivationFusion>()};
#ifdef ENABLE_CLOUD_FUSION_INFERENCE
fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
#endif
for (size_t index = 0; index < fusions.size(); index++) {
auto pass_ptr = fusions.at(index);
auto pass_name = pass_ptr->name();

View File

@ -568,6 +568,7 @@ int AnfExporter::SetMetaGraphInput(const FuncGraphPtr &func_graph,
for (const auto &input : func_graph->get_inputs()) {
auto iter = graph_inputs_map_.find(input);
if (iter == graph_inputs_map_.end()) {
MS_LOG(ERROR) << "input " << input->ToString() << " not found in graph" << std::endl;
return RET_ERROR;
}
meta_graphT->inputIndex.emplace_back(iter->second);

View File

@ -21,39 +21,59 @@
#include <vector>
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"
namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
const size_t kWeightShapeSize = 2;
const int kAttentionOutputs = 3;
} // namespace
namespace {
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, bool test_div = false) {
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
auto dense = VectorRef({is_matmul, input, weight, bias});
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto reshape = VectorRef({is_reshape, dense, var1});
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
auto var2 = std::make_shared<Var>();
auto transpose = VectorRef({is_transpose, reshape, var2});
if (test_div) {
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv));
MS_CHECK_TRUE_RET(is_div != nullptr, {});
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto div = VectorRef({is_div, transpose, var3});
return div;
}
return transpose;
bool MultiHeadAttentionFusion::Init() const {
input_q_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
input_k_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
input_v_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_v_ != nullptr, false);
weight_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
weight_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
weight_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
weight_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);
bias_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
bias_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
bias_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
bias_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);
mask_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
reshape_k_ = std::make_shared<Var>("reshape_k");
MS_CHECK_TRUE_RET(reshape_k_ != nullptr, false);
reshape_v_ = std::make_shared<Var>("reshape_v");
MS_CHECK_TRUE_RET(reshape_v_ != nullptr, false);
reshape_axis_ = std::make_shared<Var>("reshape_axis");
MS_CHECK_TRUE_RET(reshape_axis_ != nullptr, false);
v_transpose_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "v_transpose");
MS_CHECK_TRUE_RET(v_transpose_ != nullptr, false);
k_transpose_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "k_transpose");
MS_CHECK_TRUE_RET(k_transpose_ != nullptr, false);
return true;
}
namespace {
VectorRef DefineMask(const BaseRef &mask_input) {
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims));
MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});
@ -71,21 +91,62 @@ VectorRef DefineMask(const BaseRef &mask_input) {
MS_CHECK_TRUE_RET(var3 != nullptr, {});
return VectorRef({is_mul, sub, var3});
}
STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
MS_ASSERT(axes != nullptr);
if (utils::isa<ValueNodePtr>(n)) {
auto axes_value_node = utils::cast<ValueNodePtr>(n);
*axes = CastToInt(axes_value_node->value());
return lite::RET_OK;
} else {
MS_LOG(ERROR) << "GetAxis supports only value node";
}
return lite::RET_ERROR;
}
} // namespace
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
const BaseRef &axis, const BaseRef &transpose_var, bool test_div,
bool transpose) const {
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "e-matmul");
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
auto dense = VectorRef({is_matmul, input, weight, bias});
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
auto reshape = VectorRef({is_reshape, dense, axis});
auto var2 = std::make_shared<Var>();
VectorRef conn;
if (transpose) {
conn = VectorRef({transpose_var, reshape, var2});
} else {
conn = reshape;
}
if (test_div) {
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
MS_CHECK_TRUE_RET(is_div != nullptr, {});
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto div = VectorRef({is_div, conn, var3});
return div;
}
return conn;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mask) const {
VectorRef k_embedding, v_embedding;
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, true);
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_q_, weight_k_, bias_k_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_q_, weight_v_, bias_v_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_k_, weight_v_, bias_v_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
@ -133,19 +194,73 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mas
return matmul3;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const {
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const {
VectorRef k_embedding, v_embedding;
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, true);
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true, false);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_q_, weight_k_, bias_k_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_q_, weight_v_, bias_v_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_k_, weight_v_, bias_v_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1");
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add");
MS_CHECK_TRUE_RET(is_add != nullptr, {});
auto mask = DefineMask(mask_);
MS_CHECK_TRUE_RET(!mask.empty(), {});
auto add = VectorRef({is_add, mask, matmul1});
auto reshape1 = VectorRef({is_reshape1, add, var1});
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "softmax");
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
auto softmax = VectorRef({is_softmax, reshape1});
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape2");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, matmul2, var4});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
@ -169,18 +284,36 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto transpose = VectorRef({is_transpose, matmul2, var3});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, transpose, var4});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
}
namespace {
std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<tensor::Tensor>> &tensors) {
STATUS TransposeMatrix(std::shared_ptr<tensor::Tensor> src, std::shared_ptr<tensor::Tensor> dst) {
MS_CHECK_TRUE_RET(src->shape().size() == C2NUM, RET_ERROR);
MS_CHECK_TRUE_RET(dst->shape().size() == C2NUM, RET_ERROR);
int rows = src->shape().at(0);
int cols = src->shape().at(1);
auto src_ptr = reinterpret_cast<float *>(src->data_c());
auto dst_ptr = reinterpret_cast<float *>(dst->data_c());
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
auto val = src_ptr[r * cols + c];
dst_ptr[c * rows + r] = val;
}
}
return RET_OK;
}
std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<tensor::Tensor>> &tensors,
bool transpose = false) {
const std::vector<int64_t> &base_shape = tensors.at(0)->shape();
auto base_shape_size = base_shape.size();
auto base_data_type = tensors.at(0)->data_type();
@ -205,9 +338,10 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
auto sum = std::accumulate(tensors.begin(), tensors.end(), 0,
[](int sum, const tensor::TensorPtr &tensor) { return sum + tensor->shape().at(0); });
new_shape.push_back(sum);
for (std::size_t i = 1; i < base_shape_size; i++) {
for (std::size_t i = 1; i < base_shape_size; ++i) {
new_shape.push_back(base_shape.at(i));
}
// calculate data
auto concat_tensor = std::make_shared<tensor::Tensor>(base_data_type, new_shape);
MS_CHECK_TRUE_RET(concat_tensor != nullptr, nullptr);
@ -217,46 +351,17 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
memcpy_s(ptr, concat_tensor->Size() - offset, tensor->data_c(), tensor->Size());
offset += tensor->Size();
}
if (transpose) {
std::vector<int64_t> tshape = {new_shape[1], new_shape[0]};
auto transposed_tensor = std::make_shared<tensor::Tensor>(base_data_type, tshape);
auto status = TransposeMatrix(concat_tensor, transposed_tensor);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
return transposed_tensor;
}
return concat_tensor;
}
} // namespace
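A small standalone sketch of what the concat-then-transpose helpers above do, with the tensor API replaced by raw row-major float buffers; the 2x2 matrices and the function name are hypothetical:

#include <vector>

// Stack matrices with identical column counts along dim 0, then optionally
// transpose the result, mirroring ConcatTensors(..., transpose=true) above.
std::vector<float> ConcatRowsThenTranspose(const std::vector<std::vector<float>> &parts,
                                           int cols, bool transpose) {
  std::vector<float> packed;
  for (const auto &p : parts) packed.insert(packed.end(), p.begin(), p.end());
  if (!transpose) return packed;
  const int rows = static_cast<int>(packed.size()) / cols;
  std::vector<float> t(packed.size());
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c) t[c * rows + r] = packed[r * cols + c];
  return t;
}

int main() {
  // e.g. packing hypothetical 2x2 Q/K/V weights into a single [6 x 2] matrix, then transposing to [2 x 6].
  auto packed = ConcatRowsThenTranspose({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, 2, true);
  return packed.size() == 12 ? 0 : 1;
}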
bool MultiHeadAttentionFusion::Init() const {
input_q_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
input_k_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
// input_v_ = std::make_shared<Var>();
// MS_CHECK_TRUE_RET(input_v_ != nullptr, false);
weight_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
weight_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
weight_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
weight_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);
bias_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
bias_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
bias_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
bias_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);
mask_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
reshape_k_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(reshape_k_ != nullptr, false);
reshape_v_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(reshape_v_ != nullptr, false);
return true;
}
std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatterns() const {
std::unordered_map<std::string, VectorRef> patterns;
if (!Init()) {
@ -269,18 +374,38 @@ std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatte
patterns[kMPAXPatternName] = DefineMPWithMaskPattern(true, false);
patterns[kMPAWithMaskPatternNamePA] = DefineMPWithMaskPatternPA();
patterns[kMPAXWithMaskPatternNamePA] = DefineMPWithMaskPatternPA(true);
patterns[kMPAWithMaskPatternNameT5] = DefineMPWithMaskPatternT5();
patterns[kMPAXWithMaskPatternNameT5] = DefineMPWithMaskPatternT5(true);
return patterns;
}
bool MultiHeadAttentionFusion::CheckPattern(const EquivPtr &equiv, int *head_num, int *head_size) const {
MS_ASSERT(equiv != nullptr);
MS_ASSERT(head_num != nullptr);
MS_ASSERT(head_size != nullptr);
std::vector<int> reshape_axes;
if (GetAxis((*equiv)[reshape_axis_], &reshape_axes) != lite::RET_OK) {
return false;
}
if (reshape_axes.size() != C4NUM) {
return false;
}
*head_num = reshape_axes.at(C2NUM);
*head_size = reshape_axes.at(C3NUM);
return true;
}
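// Editor's note (assumption, not stated in the commit): the reshape bound to reshape_axis_ is the
// one that splits the hidden dimension into heads, i.e. a 4-element target shape of the form
// {batch, seq_len, head_num, head_size}. CheckPattern above therefore reads head_num and head_size
// from indices 2 and 3 of that shape.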
AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node,
const mindspore::EquivPtr &equiv) const {
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA)) {
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA) ||
(pattern_name == kMPAWithMaskPatternNameT5)) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope());
} else if ((pattern_name == kMPAXWithMaskPatternName) || (pattern_name == kMPAXWithMaskPatternNamePA)) {
} else if ((pattern_name == kMPAXWithMaskPatternName) || (pattern_name == kMPAXWithMaskPatternNamePA) ||
(pattern_name == kMPAXWithMaskPatternNameT5)) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true);
} else if (pattern_name == kMPAPatternName) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false, false);
@ -328,7 +453,7 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
}
if (!utils::isa<ParameterPtr>((*equiv)[reshape_v_])) {
MS_LOG(ERROR) << "Reshape k is not a parameter";
MS_LOG(ERROR) << "Reshape v is not a parameter";
return nullptr;
}
@ -353,37 +478,158 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
return attention_prim;
}
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
const EquivPtr &equiv, const string &base_name,
bool cross, bool mask) const {
STATUS MultiHeadAttentionFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &attention,
int index, const AnfNodePtr &node) const {
auto manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr";
return RET_ERROR;
}
auto transpose_users = manager->node_users()[node];
auto user_node = transpose_users.front();
if (!CheckPrimitiveType(user_node.first, prim::kPrimTranspose)) {
MS_LOG(ERROR) << " missing transpose node for branch " << index << std::endl;
return RET_ERROR;
}
// connect get item to it
transpose_users = manager->node_users()[user_node.first];
auto get_item = CreateOutputGetItem(func_graph, attention, index);
MS_CHECK_TRUE_RET(get_item != nullptr, RET_ERROR);
if (transpose_users.size() == 1) {
auto &snode = transpose_users.front();
manager->SetEdge(snode.first, snode.second, get_item);
} else {
for (auto &snode : transpose_users) {
if (CheckPrimitiveType(snode.first, prim::kPrimMakeTuple)) {
manager->SetEdge(snode.first, snode.second, get_item);
break;
}
}
}
return RET_OK;
}
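// Editor's note: CreateOutputGetItem below wraps output item_index of the fused multi-output
// attention cnode in a TupleGetItem node so that single-tensor consumers can be re-wired to it.
// The empty shape passed to CreateTensorAbstract appears to be a placeholder that later shape
// inference refines (an assumption, not confirmed by the commit).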
CNodePtr MultiHeadAttentionFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
const int item_index) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
auto get_item_value = NewValueNode(MakeValue<int>(item_index));
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
MS_LOG(ERROR) << "NewValueNode is nullptr";
return nullptr;
}
auto tuple_get_item_prim_c = tuple_get_item_prim->GetPrim();
MS_ASSERT(tuple_get_item_prim_c != nullptr);
CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim_c, {node, get_item_value});
MS_CHECK_TRUE_RET(get_item_cnode != nullptr, nullptr);
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstract failed";
return nullptr;
}
get_item_cnode->set_abstract(abstract);
get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
std::to_string(item_index));
return get_item_cnode;
}
STATUS MultiHeadAttentionFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) const {
MS_ASSERT(cnode != nullptr);
AbstractBasePtrList abstract_list;
for (int i = 0; i < output_num; ++i) {
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstract failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract);
}
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
if (abstract_tuple == nullptr) {
MS_LOG(ERROR) << "create abstract_tuple failed";
return RET_ERROR;
}
cnode->set_abstract(abstract_tuple);
return RET_OK;
}
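// Editor's note: once the separate Q/K/V weights and biases have been folded into the packed
// parameters of the fused node, RemoveRedundantInput below drops the original parameter nodes
// from the function graph so the old per-projection tensors are no longer referenced.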
STATUS MultiHeadAttentionFusion::RemoveRedundantInput(const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr> &redundant) const {
for (auto &node : redundant) {
func_graph->DropNode(node);
}
return RET_OK;
}
std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::CreatePrim(const EquivPtr &equiv, bool cross) const {
auto attention_prim = std::make_shared<ops::Attention>();
if (attention_prim == nullptr) {
MS_LOG(ERROR) << "Build attention primitive failed.";
return nullptr;
}
int head_num = 0;
int head_size = 0;
if (!CheckPattern(equiv, &head_num, &head_size)) {
return nullptr;
}
attention_prim->Init(head_num, head_size, cross);
return attention_prim;
}
CNodePtr MultiHeadAttentionFusion::MakeGetTuple(const FuncGraphPtr &func_graph, const CNodePtr &new_node,
const AnfNodePtr &knode, const AnfNodePtr &vnode) const {
auto get_item_node = CreateOutputGetItem(func_graph, new_node, 0);
if (get_item_node == nullptr) {
MS_LOG(ERROR) << "create attention output get_item node failed";
return nullptr;
}
if (knode != nullptr) {
auto status = AdjustOtherGetItems(func_graph, new_node, 1, knode);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
}
if (vnode != nullptr) {
auto status = AdjustOtherGetItems(func_graph, new_node, 2, vnode);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
}
return get_item_node;
}
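// Editor's note (inferred from the indices used above, not stated explicitly in the commit): the
// fused attention cnode is treated as having three outputs. Item 0 replaces the original attention
// output, while items 1 and 2 take over the transposed key and value branches that other parts of
// the graph still consume.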
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
const EquivPtr &equiv, const string &base_name,
bool cross, bool mask) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
std::vector<AnfNodePtr> redundant;
auto attention_prim = CreatePrim(equiv, cross);
MS_CHECK_TRUE_RET(attention_prim != nullptr, nullptr);
auto attention_prim_c = attention_prim->GetPrim();
MS_CHECK_TRUE_RET(attention_prim_c != nullptr, nullptr);
auto value_node = NewValueNode(attention_prim_c);
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
AnfNodePtr input_k, input_mask;
if (cross) {
input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
}
// auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
auto input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
AnfNodePtr input_mask;
auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
redundant.push_back(weight_q);
auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
redundant.push_back(weight_k);
redundant.push_back(weight_v);
auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);
auto bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
if (!cross) {
redundant.push_back(bias_q);
}
auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
auto bias_v = utils::cast<AnfNodePtr>((*equiv)[bias_v_]);
redundant.push_back(bias_k);
redundant.push_back(bias_v);
auto bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
auto knode = utils::cast<AnfNodePtr>((*equiv)[k_transpose_]);
auto vnode = utils::cast<AnfNodePtr>((*equiv)[v_transpose_]);
if (mask) {
input_mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
}
@ -394,10 +640,12 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
std::shared_ptr<tensor::Tensor> bias_k_tensor = GetTensorInfo(bias_k);
std::shared_ptr<tensor::Tensor> bias_v_tensor = GetTensorInfo(bias_v);
tensor::TensorPtr c_weights;
tensor::TensorPtr q_weight_t;
if (cross) {
c_weights = ConcatTensors({weight_k_tensor, weight_v_tensor});
c_weights = ConcatTensors({weight_k_tensor, weight_v_tensor}, true);
q_weight_t = ConcatTensors({weight_q_tensor}, true);
} else {
c_weights = ConcatTensors({weight_q_tensor, weight_k_tensor, weight_v_tensor});
c_weights = ConcatTensors({weight_q_tensor, weight_k_tensor, weight_v_tensor}, true);
}
auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor});
// convert tensors to params
@ -415,18 +663,34 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
return nullptr;
}
c_bias_param->set_name(base_name + "/bias_qkv");
ParameterPtr q_weight_param;
if (cross) {
q_weight_param = func_graph->add_parameter();
MS_CHECK_TRUE_RET(q_weight_param != nullptr, nullptr);
if (lite::InitParameterFromTensorInfo(q_weight_param, q_weight_t) != lite::RET_OK) {
MS_LOG(ERROR) << "Init parameter from tensor info failed.";
return nullptr;
}
}
std::vector<AnfNodePtr> new_node_inputs;
if (cross) {
new_node_inputs = {value_node, input_q, input_k, input_k, weight_q, c_weight_param, weight_o, c_bias_param, bias_o};
new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param,
c_weight_param, weight_o, c_bias_param, bias_o};
} else {
new_node_inputs = {value_node, input_q, input_q, input_q, c_weight_param, weight_o, c_bias_param, bias_o};
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o};
}
if (mask) {
new_node_inputs.push_back(input_mask);
}
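// Editor's note, summarizing the two branches above (the packed-tensor names are the editor's):
//   self attention : {prim, q, k, v, packed qkv weight, output weight, packed qkv bias, output bias[, mask]}
//   cross attention: {prim, q, k, v, transposed q weight, packed kv weight, output weight, packed qkv bias, output bias[, mask]}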
auto new_node = func_graph->NewCNode(new_node_inputs);
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
new_node->set_fullname_with_scope(base_name);
return new_node;
auto status = SetAbstractTuple(new_node, kAttentionOutputs);
if (status != RET_OK) {
return nullptr;
}
new_node->set_fullname_with_scope(base_name + "/attention");
auto get_item_node = MakeGetTuple(func_graph, new_node, knode, vnode);
MS_CHECK_TRUE_RET(get_item_node != nullptr, nullptr);
(void)RemoveRedundantInput(func_graph, redundant);
return get_item_node;
}
} // namespace mindspore::opt

View File

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"
@ -48,10 +49,24 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
// define patterns
VectorRef DefineMPWithMaskPattern(bool cross = false, bool mask = true) const;
VectorRef DefineMPWithMaskPatternPA(bool cross = false) const;
VectorRef DefineMPWithMaskPatternT5(bool cross = false) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div = false, bool transpose = true) const;
// create masked-multi-head-attention
CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::string &base_name, bool cross = false, bool mask = true) const;
// check pattern
bool CheckPattern(const EquivPtr &equiv, int *head_num, int *head_size) const;
CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index) const;
lite::STATUS SetAbstractTuple(const CNodePtr &cnode, const int output_num) const;
lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &attention, int index,
const AnfNodePtr &node) const;
lite::STATUS RemoveRedundantInput(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &redundant) const;
CNodePtr MakeGetTuple(const FuncGraphPtr &func_graph, const CNodePtr &new_node, const AnfNodePtr &knode,
const AnfNodePtr &vnode) const;
std::shared_ptr<ops::Attention> CreatePrim(const EquivPtr &equiv, bool cross) const;
protected:
const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern";
@ -60,6 +75,8 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
const std::string kMPAXWithMaskPatternNamePA = "MPAXWithMaskPatternPA";
const std::string kMPAPatternName = "MPAPattern";
const std::string kMPAXPatternName = "MPAXPattern";
const std::string kMPAWithMaskPatternNameT5 = "MPAWithMaskPatternT5";
const std::string kMPAXWithMaskPatternNameT5 = "MPAXWithMaskPatternT5";
mutable VarPtr input_q_{nullptr};
mutable VarPtr input_k_{nullptr};
@ -78,6 +95,10 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
mutable VarPtr reshape_k_{nullptr};
mutable VarPtr reshape_v_{nullptr};
mutable VarPtr reshape_axis_{nullptr};
mutable VarPtr v_transpose_{nullptr};
mutable VarPtr k_transpose_{nullptr};
};
} // namespace opt

View File

@ -128,7 +128,7 @@ index 8707220..aea35e6 100644
target_link_libraries(trt_fused_multi_head_attention PUBLIC -lcublas -lcudart)
set_property(TARGET trt_fused_multi_head_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ea21014..3098d8a 100644
index ea21014..97b842e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,7 +14,9 @@
@ -255,7 +255,7 @@ index ea21014..3098d8a 100644
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/FasterTransformerConfig.cmake.in
@@ -402,52 +392,23 @@ configure_package_config_file(
@@ -402,28 +392,14 @@ configure_package_config_file(
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)
@ -281,16 +281,12 @@ index ea21014..3098d8a 100644
- FasterTransformerTargets.cmake
- DESTINATION
- ${INSTALL_CONFIGDIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/output/lib
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/output/lib
+ LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib
)
file(GLOB_RECURSE HEADER_FILES "*.h" "*.hpp" "*.cuh")
foreach ( file ${HEADER_FILES} )
file( RELATIVE_PATH rfile ${CMAKE_CURRENT_SOURCE_DIR} ${file} )
get_filename_component( dir ${rfile} DIRECTORY )
- install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir} )
+ install( FILES ${file} DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/output/include )
@@ -434,20 +410,5 @@ foreach ( file ${HEADER_FILES} )
endforeach()