forked from mindspore-Ecosystem/mindspore
attention kernel bit exact
parent ca371e1531
commit 3ae59185b2

@@ -8,6 +8,9 @@ mindspore_add_pkg(fast_transformers
URL ${REQ_URL}
MD5 ${MD5}
LIBS ${ft_libs}
LIB_PATH output/lib
LIB_PATH lib
PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/fast_transformer/001-fast_transformer.patch
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DEXAMPLES=off)
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DEXAMPLES=off)
include_directories(${fast_transformers_INC})

add_library(mindspore::fast_transformers ALIAS fast_transformers::transformer-shared)

@@ -58,9 +58,6 @@ endif()

if(ENABLE_GPU)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/cub.cmake)
if(NOT MSVC)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/fast_transformers.cmake)
endif()
if(ENABLE_MPI)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/nccl.cmake)
endif()

@@ -852,6 +852,8 @@ else()
if(MSLITE_GPU_BACKEND STREQUAL tensorrt)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/extendrt/delegate/tensorrt/libtensorrt_plugin.so
DESTINATION ${RUNTIME_LIB_DIR} COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${fast_transformers_LIBPATH}/libtransformer-shared.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
endif()
else()
install(FILES ${TOP_DIR}/mindspore/lite/build/src/${MINDSPORE_LITE_LIB_NAME}.so DESTINATION ${RUNTIME_LIB_DIR}

@@ -133,9 +133,10 @@ EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv)
(*equiv)[var] = expr;
MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
return equiv;
} else {
MS_LOG(DEBUG) << "pattern not match: " + pattern.ToString() + ", " + expr.ToString();
}
}

return nullptr;
}

@@ -222,6 +223,7 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorR
return nullptr;
}
}
if ((values_expr_len == 0) && (values_pattern_len == 0)) return equiv;
if (values_expr_len < values_pattern_len - 1) {
MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
return nullptr;

@@ -109,6 +109,7 @@ class VarHasher {
class CondVar : public Var {
public:
explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
explicit CondVar(const ConditionFunc &cond, std::string tag) : Var(tag), cond_fn_(cond) {}
~CondVar() override = default;
MS_DECLARE_PARENT(CondVar, Var);
bool matches(const BaseRef &value) override {

@@ -18,6 +18,13 @@

#include "nnacl/op_base.h"

typedef struct AttentionParameter {
OpParameter op_parameter_;
int head_num_;
int head_size_;
bool cross_;
} AttentionParameter;

typedef struct RelativePositionAttentionParameter {
// Primitive parameter
OpParameter op_parameter_;

@@ -16,6 +16,7 @@

#include "nnacl/infer/attention_infer.h"
#include "nnacl/infer/infer_register.h"
#include "nnacl/attention_parameter.h"

int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
OpParameter *parameter) {

@@ -23,27 +24,44 @@ int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
if (check_ret != NNACL_OK) {
return check_ret;
}
const TensorC *q_input = inputs[0];
TensorC *output = outputs[0];
SetDataTypeFormat(output, q_input);
AttentionParameter *param = (AttentionParameter *)parameter;
const TensorC *q_input = inputs[FIRST_INPUT];
const TensorC *k_input = inputs[SECOND_INPUT];
TensorC *output0 = outputs[FIRST_INPUT];
SetDataTypeFormat(output0, q_input);
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
const TensorC *q_weight = inputs[3];
if (q_input->shape_size_ != 2 && q_input->shape_size_ != 3) {
const TensorC *q_weight = inputs[FOURTH_INPUT];
if (q_input->shape_size_ != C2NUM && q_input->shape_size_ != C3NUM) {
return NNACL_ERR;
}
if (q_weight->shape_size_ != 2) {
if (q_weight->shape_size_ != C2NUM) {
return NNACL_ERR;
}
int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0];
int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1];
int d_model = q_weight->shape_[1];

output->shape_[0] = batch;
output->shape_[1] = f_seq;
output->shape_[2] = d_model;
output->shape_size_ = 3;
int t_seq_len = k_input->shape_[1];
output0->shape_[FIRST_INPUT] = batch;
output0->shape_[SECOND_INPUT] = f_seq;
output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_;
output0->shape_size_ = C3NUM;
if (outputs_size >= C3NUM) {
TensorC *output1 = outputs[SECOND_INPUT];
SetDataTypeFormat(output1, q_input);
output1->shape_[FIRST_INPUT] = batch;
output1->shape_[SECOND_INPUT] = param->head_num_;
output1->shape_[THIRD_INPUT] = param->head_size_;
output1->shape_[FOURTH_INPUT] = t_seq_len;
output1->shape_size_ = C4NUM;
TensorC *output2 = outputs[THIRD_INPUT];
SetDataTypeFormat(output2, q_input);
output2->shape_[FIRST_INPUT] = batch;
output2->shape_[SECOND_INPUT] = param->head_num_;
output2->shape_[THIRD_INPUT] = t_seq_len;
output2->shape_[FOURTH_INPUT] = param->head_size_;
output2->shape_size_ = C4NUM;
}
return NNACL_OK;
}

@@ -17,9 +17,39 @@

#include "ops/attention.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"

namespace mindspore::ops {
MIND_API_OPERATOR_IMPL(Attention, BaseOperator);

void Attention::set_head_num(int64_t head_num) { (void)this->AddAttr(kAttentionNumHeads, api::MakeValue(head_num)); }

void Attention::set_head_size(int64_t head_size) {
(void)this->AddAttr(kAttentionSizePerHead, api::MakeValue(head_size));
}

void Attention::set_cross(bool cross) { (void)this->AddAttr(kCross, api::MakeValue(cross)); }

int64_t Attention::get_head_num() const {
auto value_ptr = this->GetAttr(kAttentionNumHeads);
return GetValue<int64_t>(value_ptr);
}

int64_t Attention::get_head_size() const {
auto value_ptr = this->GetAttr(kAttentionSizePerHead);
return GetValue<int64_t>(value_ptr);
}

bool Attention::get_cross() const {
auto value_ptr = this->GetAttr(kCross);
return GetValue<bool>(value_ptr);
}

void Attention::Init(int64_t head_num, int64_t head_size, bool cross) {
this->set_head_num(head_num);
this->set_head_size(head_size);
this->set_cross(cross);
}
REGISTER_PRIMITIVE_C(kNameAttention, Attention);
} // namespace mindspore::ops

@@ -37,7 +37,16 @@ class MIND_API Attention : public BaseOperator {
{"output"});
}
/// \brief Initialize Attention op.
void Init() const {}
/// \param[in] head_num Define head number.
/// \param[in] head_size Define size per head.
/// \param[in] cross Define is cross attention. Default false.
void Init(int64_t head_num, int64_t head_size, bool cross = false);
void set_head_num(int64_t head_num);
void set_head_size(int64_t head_size);
void set_cross(bool cross);
int64_t get_head_num() const;
int64_t get_head_size() const;
bool get_cross() const;
};
} // namespace ops
} // namespace mindspore

@@ -686,6 +686,7 @@ GVAR_DEF(PrimitivePtr, kPrimSoftmaxV2WithDropoutDoMaskV3, std::make_shared<Primi
GVAR_DEF(PrimitivePtr, kPrimLogSoftmax, std::make_shared<Primitive>("LogSoftmax"));
GVAR_DEF(PrimitivePtr, kPrimLogSoftmaxGrad, std::make_shared<Primitive>("LogSoftmaxGrad"));
GVAR_DEF(PrimitivePtr, kPrimLstm, std::make_shared<Primitive>("LSTM"));
GVAR_DEF(PrimitivePtr, kPrimAttention, std::make_shared<Primitive>("Attention"));
GVAR_DEF(PrimitivePtr, kPrimLstmGradData, std::make_shared<Primitive>("LSTMGradData"));
GVAR_DEF(PrimitivePtr, kPrimLstmGradWeight, std::make_shared<Primitive>("LSTMGradWeight"));
GVAR_DEF(PrimitivePtr, kPrimTan, std::make_shared<Primitive>("Tan"));

@@ -145,10 +145,10 @@ constexpr auto kNumElements = "num_elements";
constexpr auto kNumBits = "num_bits";
constexpr auto kNumDirections = "num_directions";
constexpr auto kNumProj = "num_proj";
constexpr auto kAttentionNumHeads = "attention_num_heads";
constexpr auto kAttentionSizePerHead = "attention_size_per_head";
constexpr auto kAttentionFromSeqLen = "attention_from_seq_len";
constexpr auto kAttentionToSeqLen = "attention_to_seq_len";
constexpr auto kAttentionNumHeads = "head_num";
constexpr auto kAttentionSizePerHead = "head_size";
constexpr auto kAttentionFromSeqLen = "from_seq_len";
constexpr auto kAttentionToSeqLen = "to_seq_len";
constexpr auto kOffset = "offset";
constexpr auto kNmsIouThreshold = "nms_iou_threshold";
constexpr auto kNmsScoreThreshold = "nms_score_threshold";

@@ -345,6 +345,7 @@ constexpr auto kSeed0 = "seed0";
constexpr auto kSeed1 = "seed1";
constexpr auto kHandle = "handle";
constexpr auto kBatchSize = "batch_size";
constexpr auto kCross = "cross";
constexpr auto kDeviceNum = "device_num";

constexpr size_t kInputIndex0 = 0;

@@ -749,6 +749,9 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
set(MSLITE_DEPS_LIBEVENT on)
set(MSLITE_DEPS_PYBIND11 on)
set(MSLITE_DEPS_OPENSSL on)
if(SUPPORT_TENSORRT)
set(MSLITE_DEPS_FAST_TRANSFORMERS on)
endif()
endif()

if(MSLITE_ENABLE_MODEL_ENCRYPTION)

@@ -34,6 +34,10 @@ if(MSLITE_DEPS_OPENCV)
include(${TOP_DIR}/cmake/external_libs/opencv.cmake)
endif()

if(MSLITE_DEPS_FAST_TRANSFORMERS)
include(${TOP_DIR}/cmake/external_libs/fast_transformers.cmake)
endif()

if(MSLITE_DEPS_MKLDNN)
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(USE_MS_THREADPOOL_FOR_DNNL ON)

@@ -390,6 +390,9 @@ table Concat {
}

table Attention {
head_num: long;
head_size: long;
cross: bool;
}

table Conv2DBackpropFilterFusion {

@@ -390,6 +390,9 @@ OP_ATTR(axis, long)
OP_SCHEMA_DEF_END(Concat)

OP_SCHEMA_DEF(Attention)
OP_ATTR(head_num, long)
OP_ATTR(head_size, long)
OP_ATTR(cross, bool)
OP_SCHEMA_DEF_END(Attention)

OP_SCHEMA_DEF(Conv2DBackpropFilterFusion)

@@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/common/ops/populate/populate_register.h"
#include "nnacl/attention_parameter.h"

using mindspore::schema::PrimitiveType_Attention;

namespace mindspore {
namespace lite {
OpParameter *PopulateAttentionParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
auto value = primitive->value_as_Attention();
MS_CHECK_TRUE_MSG(value != nullptr, nullptr, "value is nullptr.");
auto *param = reinterpret_cast<AttentionParameter *>(malloc(sizeof(AttentionParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc AttentionParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(AttentionParameter));
param->op_parameter_.type_ = primitive->value_type();
param->head_num_ = value->head_num();
param->head_size_ = value->head_size();
param->cross_ = value->cross();
return reinterpret_cast<OpParameter *>(param);
}

REG_POPULATE(PrimitiveType_Attention, PopulateAttentionParameter, SCHEMA_CUR)
} // namespace lite
} // namespace mindspore

@@ -41,7 +41,6 @@ OpParameter *PopulateCommonParameter(const void *prim) {
REG_POPULATE(PrimitiveType_AddN, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_ZerosLike, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Depend, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Attention, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_SwitchLayer, PopulateCommonParameter, SCHEMA_CUR)
REG_POPULATE(PrimitiveType_Log1p, PopulateCommonParameter, SCHEMA_CUR)
} // namespace lite

@@ -97,4 +97,4 @@ target_link_libraries(
add_subdirectory(cuda_impl)

target_link_libraries(tensorrt_plugin cuda_kernel_mid gpu_distribution_collective)
target_link_libraries(tensorrt_plugin mindspore-extendrt mindspore_core)
target_link_libraries(tensorrt_plugin mindspore-extendrt mindspore_core mindspore::fast_transformers)

@@ -45,6 +45,7 @@ void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addr, type_b, ldb, a_addr, type_a,
lda, &beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
const cublasOperation_t *operations, const cudaDataType *data_types,
cublasHandle_t cublas_handle) {

@@ -67,4 +68,49 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *
type_a, lda, &beta, c_addrs, type_c, ldc, batch, compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}

void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
cublasHandle_t cublas_handle) {
const int m = params[0];
const int n = params[1];
const int k = params[2];
cublasOperation_t trans_a = operations[0];
cublasOperation_t trans_b = operations[1];
const int lda = lds[0];
const int ldb = lds[1];
const int ldc = lds[2];
cudaDataType type_a = data_types[0];
cudaDataType type_b = data_types[1];
cudaDataType type_c = data_types[2];
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;

CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b,
ldb, beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT));
}

void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
const int *lds, const cublasOperation_t *operations, const int *strides,
const cudaDataType *data_types, void *alpha, void *beta, int batch,
cublasHandle_t cublas_handle) {
const int m = params[0];
const int n = params[1];
const int k = params[2];
cublasOperation_t trans_a = operations[0];
cublasOperation_t trans_b = operations[1];
const int lda = lds[0];
const int ldb = lds[1];
const int ldc = lds[2];
cudaDataType type_a = data_types[0];
cudaDataType type_b = data_types[1];
cudaDataType type_c = data_types[2];
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
const int stride_a = strides[0];
const int stride_b = strides[1];
const int stride_c = strides[2];

CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda,
stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc,
stride_c, batch, compute_type, CUBLAS_GEMM_DEFAULT));
}
} // namespace mindspore::lite

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include "src/extendrt/delegate/tensorrt/cuda_impl/cuda_helper.h"
#include "src/common/log_util.h"

@@ -58,5 +59,13 @@ void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const
// data_types: type_a, type_b, type_c, compute type
void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *params,
const cublasOperation_t *operations, const cudaDataType *data_types, cublasHandle_t cublas_handle);

void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
cublasHandle_t cublas_handle);
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
const int *lds, const cublasOperation_t *operations, const int *strides,
const cudaDataType *data_types, void *alpha, void *beta, int batch,
cublasHandle_t cublas_handle);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_

@@ -0,0 +1,330 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_runtime.h>
#include <numeric>
#include <memory>
#include <vector>
#include <functional>
#include <unordered_map>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "NvInferRuntimeCommon.h"
#include "src/extendrt/delegate/tensorrt/op/mha_tensorrt.h"
#include "ops/attention.h"
#include "src/fastertransformer/kernels/unfused_attention_kernels.h"
#include "src/fastertransformer/kernels/activation_kernels.h"
#include "src/fastertransformer/utils/cuda_utils.h"
#include "src/fastertransformer/utils/allocator.h"

namespace mindspore::lite {
namespace {
constexpr std::size_t kTwo = 2;
constexpr std::size_t kThree = 3;
} // namespace

// Multi Head Attention TensorRT op
int MhaTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() != 8 && in_tensors.size() != 6) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
}

int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
if (ctx == nullptr || ctx->network() == nullptr) {
MS_LOG(ERROR) << "context or network is invalid";
return RET_ERROR;
}
auto mha_op = AsOps<ops::Attention>();
if (mha_op == nullptr) {
MS_LOG(ERROR) << "op action convert failed";
return RET_ERROR;
}
// get attribute for Attn op - TODO - add attribute in op
int head_number = mha_op->get_head_num();
int head_size = mha_op->get_head_size();
int compute_type = 1; // mha_op->get_compute_type();
int is_cross = mha_op->get_cross();
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;

auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
GetCublasHandle(), GetCublasLtHandle(), device_id_);
const int input_number = inputs().size();
nvinfer1::ITensor *inputTensors[input_number];
for (int i = 0; i < input_number; i++) {
inputTensors[i] = input(ctx, i).trt_tensor_;
}
nvinfer1::IPluginV2Layer *mha_layer = ctx->network()->addPluginV2(inputTensors, input_number, *plugin);
if (mha_layer == nullptr) {
MS_LOG(ERROR) << "add mha op failed for TensorRT.";
return RET_ERROR;
}
mha_layer->setName(op_name_.c_str());
// TODO(haim) one output
nvinfer1::ITensor *attn_tensor = mha_layer->getOutput(0);
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name());
// nvinfer1::ITensor *key_tensor = mha_layer->getOutput(1);
// ctx->RegisterTensor(ITensorHelper{key_tensor, Format::NCHW, true}, out_tensors_[1].Name());
// nvinfer1::ITensor *value_tensor = mha_layer->getOutput(kTwo);
// ctx->RegisterTensor(ITensorHelper{value_tensor, Format::NCHW, true}, out_tensors_[kTwo].Name());
this->layer_ = mha_layer;

return RET_OK;
}

// PLUGIN of Multi Head Attention Layer
REGISTER_TENSORRT_PLUGIN(MhaPluginCreater);
template class TensorRTPluginCreater<MhaPlugin>;
template <class T>
nvinfer1::PluginFieldCollection TensorRTPluginCreater<T>::field_collection_{};
template <class T>
std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;

int MhaPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
return RunCudaMha(inputDesc, outputDesc, inputs, outputs, workspace, stream);
}

int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) {
// inputs order:
// 0] Q
// 1] K
// 2] V
// 3] W
// 4] PW
// 5] B
// 6] PB
// 7] AttnMask

cublasSetStream(cublas_handle_, stream);

// TODO(Haim) - Fix tensor ids according to cross flag
const int from_tensor_idx = 0;
// const int encoder_tensor_idx = 1;
const int weight_qkv_tensor_idx = 3;
const int weight_projection_tensor_idx = 4;
const int bias_qkv_tensor_idx = 5;
const int bias_projection_tensor_idx = 6;
const int attn_mask_tensor_idx = 7;

auto from_tensor = static_cast<const float *>(inputs[from_tensor_idx]);
auto attention_mask = static_cast<const float *>(inputs[attn_mask_tensor_idx]);
auto weight_qkv = static_cast<const float *>(inputs[weight_qkv_tensor_idx]);
auto bias_qkv = static_cast<const float *>(inputs[bias_qkv_tensor_idx]);
auto weight_projection = static_cast<const float *>(inputs[weight_projection_tensor_idx]);
auto bias_projection = static_cast<const float *>(inputs[bias_projection_tensor_idx]);

auto output0 = static_cast<float *>(outputs[0]);
// auto output1 = static_cast<float *>(outputs[1]);
// auto output2 = static_cast<float *>(outputs[2]);

auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);

// TODO(NIZZAN): fix allocator
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;

size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t size_v = size_k;

size_t qkv_len = size_q + size_k + size_v;
size_t q_buf_2_len = size_q;
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
size_t qkv_buf_3_len = qkv_buf_2_len;
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
qkv_buf_ = workspace;
q_buf_2_ = static_cast<float *>(qkv_buf_) + qkv_len;
qk_buf_ = static_cast<float *>(q_buf_2_) + q_buf_2_len;
qkv_buf_2_ = static_cast<float *>(qk_buf_) + qk_buf_len;
qkv_buf_3_ = static_cast<float *>(qkv_buf_2_) + qkv_buf_2_len;
output1_ = static_cast<float *>(workspace) + buff_size;
output2_ = static_cast<float *>(output1_) + extra_tmp_size;

int gemm_dims[3] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
int gemm_lds[3] = {3 * hidden_size, hidden_size, 3 * hidden_size};

cublasOperation_t gemm_ops[2] = {CUBLAS_OP_N, CUBLAS_OP_N};
const cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
float alpha = 1.0f;
float beta = 0.0f;

CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta,
cublas_handle_);

fastertransformer::invokeAddFusedQKVBiasTranspose(static_cast<float *>(q_buf_2_), static_cast<float *>(output1_),
static_cast<float *>(output2_), static_cast<float *>(qkv_buf_),
bias_qkv, request_batch_size, request_src_seq_len, head_number_,
head_size_, 0, stream);
gemm_ops[0] = CUBLAS_OP_T;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = request_tgt_seq_len;
gemm_dims[1] = request_src_seq_len;
gemm_dims[THIRD_INPUT] = head_size_;

gemm_lds[0] = head_size_;
gemm_lds[1] = head_size_;
gemm_lds[THIRD_INPUT] = request_tgt_seq_len;

int gemm_strides[] = {request_tgt_seq_len * head_size_, request_src_seq_len * head_size_,
request_src_seq_len * request_tgt_seq_len};
CublasGemmStridedBatchedWrapper(output1_, q_buf_2_, qk_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);

float scalar = (1.0f / sqrtf(static_cast<float>(head_size_) * 1.0f));
fastertransformer::invokeMixMaskedSoftMax(static_cast<float *>(qk_buf_), attention_mask, request_batch_size,
request_src_seq_len, request_tgt_seq_len, head_number_, scalar, stream);
gemm_ops[0] = CUBLAS_OP_N;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = head_size_;
gemm_dims[1] = request_src_seq_len;
gemm_dims[THIRD_INPUT] = request_tgt_seq_len;

gemm_lds[0] = head_size_;
gemm_lds[1] = request_tgt_seq_len;
gemm_lds[THIRD_INPUT] = head_size_;

gemm_strides[0] = request_tgt_seq_len * head_size_;
gemm_strides[1] = request_src_seq_len * request_tgt_seq_len;
gemm_strides[THIRD_INPUT] = request_src_seq_len * head_size_;

CublasGemmStridedBatchedWrapper(output2_, qk_buf_, qkv_buf_2_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);

fastertransformer::invokeTransposeQKV(static_cast<float *>(qkv_buf_3_), static_cast<float *>(qkv_buf_2_),
request_batch_size, request_src_seq_len, head_number_, head_size_, stream);

gemm_ops[0] = CUBLAS_OP_N;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = hidden_size;
gemm_dims[1] = request_batch_size * request_src_seq_len;
gemm_dims[THIRD_INPUT] = hidden_size;

gemm_lds[0] = hidden_size;
gemm_lds[1] = hidden_size;
gemm_lds[THIRD_INPUT] = hidden_size;

CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha,
&beta, cublas_handle_);
int len = request_batch_size * request_src_seq_len;
fastertransformer::invokeAddBias(reinterpret_cast<float *>(output0), reinterpret_cast<const float *>(bias_projection),
len, hidden_size, stream);

return RET_OK;
}

int MhaPlugin::RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
// Add Cross Mha Layer here
return RET_OK;
}

size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
auto attn_dim_size = inputs[nbInputs - 1].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputs[nbInputs - 1].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputs[nbInputs - 1].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);

// TODO(NIZZAN) Fix efficient allocator
// size_t buff_size = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len +
// request_batch_size * request_src_seq_len * hidden_size;

size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t size_v = size_k;

size_t qkv_len = size_q + size_k + size_v;
size_t q_buf_2_len = size_q;
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
size_t qkv_buf_3_len = qkv_buf_2_len;
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;

size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;

return (buff_size + extra_tmp_size + extra_tmp_size) * sizeof(float);
}

nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept {
// MHA inputs:
// from_tensor [batch_size, src_seq_len, hidden_size_] or [batch_size * src_seq_len, hidden_size_]
// encoder_output [batch_size, tgt_seq_len, hidden_size_] or [batch_size * tgt_seq_len, hidden_size_]--> only in
// cross MHA attention_mask [batch_size, 1, src_seq_len, tgt_seq_len] or [batch_size, src_seq_len, tgt_seq_len]

// MHA output_tensors:
// attention_out [batch_size, src_seq_len, hidden_size_]
// key_cache [batch, head_num, size_per_head]
// value_cache [batch, head_num, tgt_seq_len, size_per_head]
nvinfer1::DimsExprs dims;
if (index == 0) {
// if (inputs[0].nbDims == 2) {
//   dims.nbDims = INPUT_SIZE2;
//   dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
//   auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
//   dims.d[1] = hidden_size;
// } else
{
dims.nbDims = INPUT_SIZE3;
dims.d[0] = inputs[nbInputDims - 1].d[0];  // batch
dims.d[1] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
dims.d[kTwo] = hidden_size;
}
} else {
// TODO(Haim) - Fix size in case of 2d input
dims.nbDims = INPUT_SIZE4;
dims.d[0] = inputs[nbInputDims - 1].d[0];  // batch
dims.d[1] = exprBuilder.constant(head_number_);
dims.d[kTwo] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
dims.d[kThree] = exprBuilder.constant(head_size_);
}
return dims;
}

nvinfer1::IPluginV2DynamicExt *MhaPlugin::clone() const noexcept {
auto *plugin = new MhaPlugin(*this);  // TODO(haim) CopyConstructor
if (plugin == nullptr) {
MS_LOG(ERROR) << "plugin is null";
return nullptr;
}
plugin->setPluginNamespace(name_space_.c_str());
return plugin;
}

void MhaPlugin::terminate() noexcept {}

size_t MhaPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); }

void MhaPlugin::serialize(void *buffer) const noexcept {
SerializeValue(&buffer, &compute_type_, sizeof(int));
SerializeValue(&buffer, &head_number_, sizeof(int));
SerializeValue(&buffer, &head_size_, sizeof(int));
SerializeValue(&buffer, &is_cross_, sizeof(int));
}
REGISTER_TENSORRT_CREATOR(ops::kNameAttention, MhaTensorRT)
} // namespace mindspore::lite

@@ -0,0 +1,115 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MHA_TENSORRT_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MHA_TENSORRT_H_

#include <string>
#include <vector>
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_plugin.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cudnn_utils.h"

namespace mindspore::lite {
class MhaTensorRT : public TensorRTOp {
public:
MhaTensorRT(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors, std::string name)
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}

~MhaTensorRT() override = default;
int AddInnerOp(TensorRTContext *ctx) override;

int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) override;
};

constexpr auto MHA_PLUGIN_NAME{"AttentionPlugin"};
class MhaPlugin : public TensorRTPlugin {
public:
MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, int is_cross,
cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
: TensorRTPlugin(name, std::string(MHA_PLUGIN_NAME), device_id),
compute_type_(compute_type),
head_number_(head_number),
head_size_(head_size),
is_cross_(is_cross),
cublas_handle_(cublas_handle),
cublaslt_handle_(cublaslt_handle) {}

MhaPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
: TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
const nvinfer1::PluginField *fields = fc->fields;
compute_type_ = static_cast<const int *>(fields[0].data)[0];
head_number_ = static_cast<const int *>(fields[1].data)[0];
head_size_ = static_cast<const int *>(fields[2].data)[0];
is_cross_ = static_cast<const int *>(fields[3].data)[0];
}

MhaPlugin(const char *name, const void *serialData, size_t serialLength)
: TensorRTPlugin(std::string(name), std::string(MHA_PLUGIN_NAME)) {
DeserializeValue(&serialData, &serialLength, &compute_type_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &head_number_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &head_size_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &is_cross_, sizeof(int));
}

MhaPlugin() = delete;

~MhaPlugin() override {
// std::cout << "~MhaPlugin" << std::endl;
}

nvinfer1::IPluginV2DynamicExt *clone() const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void *buffer) const noexcept override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
void terminate() noexcept override;

private:
bool needResize(const int *current_dims, const int *last_dims);
int RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);
int RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);

const std::string layer_name_;
std::string name_space_;
int compute_type_;
int head_number_;
int head_size_;
int is_cross_;
cublasHandle_t cublas_handle_;
cublasLtHandle_t cublaslt_handle_;
void *qkv_buf_{nullptr};
void *q_buf_2_{nullptr};
void *qk_buf_{nullptr};
void *qkv_buf_2_{nullptr};
void *qkv_buf_3_{nullptr};
void *output1_{nullptr};
void *output2_{nullptr};
};
class MhaPluginCreater : public TensorRTPluginCreater<MhaPlugin> {
public:
MhaPluginCreater() : TensorRTPluginCreater(std::string(MHA_PLUGIN_NAME)) {}
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_MHA_TENSORRT_H_

@@ -28,6 +28,7 @@
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "src/extendrt/delegate/tensorrt/op_registration_factory.h"
#include "src/extendrt/delegate/tensorrt/tensor_info.h"
// #include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
#include "src/common/log_util.h"
#include "ops/base_operator.h"
#include "ops/op_name.h"

@@ -106,6 +107,8 @@ class TensorRTOp {
const std::vector<TensorRTOp *> &out_ops() const;

void SetRuntime(TensorRTRuntime *runtime);
cublasHandle_t GetCublasHandle() { return runtime_ ? runtime_->GetCublasHandle() : nullptr; }
cublasLtHandle_t GetCublasLtHandle() { return runtime_ ? runtime_->GetCublasLtHandle() : nullptr; }

DynamicShapeParams GetDynamicShapeParams() const;

@@ -360,6 +360,14 @@ TensorRTExecutor::~TensorRTExecutor() {
if (stream_ != nullptr) {
cudaStreamDestroy(stream_);
}
if (cublas_handle_ != nullptr) {
cublasDestroy(cublas_handle_);
cublas_handle_ = nullptr;
}
if (cublaslt_handle_ != nullptr) {
cublasLtDestroy(cublaslt_handle_);
cublaslt_handle_ = nullptr;
}
}
bool IsHardwareSupport() {
int driver_version = 0;

@@ -415,6 +423,19 @@ Status TensorRTExecutor::Init() {
MS_LOG(ERROR) << "Cuda create stream failed";
return mindspore::kLiteError;
}

auto cublas_ret = cublasCreate(&cublas_handle_);
if (cublas_ret != CUBLAS_STATUS_SUCCESS) {
MS_LOG(ERROR) << "Cuda create cublas handle failed";
return mindspore::kLiteError;
}

auto cublaslt_ret = cublasLtCreate(&cublaslt_handle_);
if (cublaslt_ret != CUBLAS_STATUS_SUCCESS) {
MS_LOG(ERROR) << "Cuda create cublaslt handle failed";
return mindspore::kLiteError;
}

return mindspore::kSuccess;
}

@@ -569,7 +590,7 @@ std::shared_ptr<TensorRTSubGraph> TensorRTExecutor::CreateTensorRTGraph(const st
FindPreNextOps<TensorRTOp>(ops);

// 2. Init TensorRT SubGraph.
auto ret = tensorrt_graph->Init(stream_);
auto ret = tensorrt_graph->Init(stream_, cublas_handle_, cublaslt_handle_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "TensorRTGraph init failed.";
return nullptr;

@@ -83,6 +83,9 @@ class TensorRTExecutor : public LiteGraphExecutor {
size_t device_cache_size_{0};
std::string serialize_path_;
cudaStream_t stream_{nullptr};
cublasHandle_t cublas_handle_{nullptr};
cublasLtHandle_t cublaslt_handle_{nullptr};

std::vector<kernel::Kernel> kernel_list_;

std::vector<TrtGraphContext> tensorrt_graph_list_;

@@ -18,6 +18,7 @@
#include <NvInfer.h>
#include "include/errorcode.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_allocator.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
#include "src/common/log_adapter.h"
#define MAX_BATCH_SIZE 64

@@ -55,7 +56,11 @@ class TensorRTRuntime {

void SetBatchSize(int batch_size) { batch_size_ = batch_size; }

void SetCudaStream(cudaStream_t stream) { allocator_->SetCudaStream(stream); }
void SetCudaStream(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle) {
allocator_->SetCudaStream(stream);
cublas_handle_ = cublas_handle;
cublaslt_handle_ = cublaslt_handle;
}

RuntimePrecisionMode GetRuntimePrecisionMode() { return runtime_percision_mode_; }

@@ -68,6 +73,8 @@ class TensorRTRuntime {
void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }

uint32_t GetDeviceID() { return device_id_; }
cublasHandle_t GetCublasHandle() { return cublas_handle_; }
cublasLtHandle_t GetCublasLtHandle() { return cublaslt_handle_; }

private:
bool is_init_ = false;

@@ -77,6 +84,8 @@ class TensorRTRuntime {
int batch_size_{0};
uint32_t device_id_{0};
RuntimePrecisionMode runtime_percision_mode_{RuntimePrecisionMode::RuntimePrecisionMode_FP32};
cublasHandle_t cublas_handle_{nullptr};
cublasLtHandle_t cublaslt_handle_{nullptr};
};
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_

@@ -86,7 +86,7 @@ TensorRTSubGraph::~TensorRTSubGraph() {
}
}

int TensorRTSubGraph::Init(cudaStream_t stream) {
int TensorRTSubGraph::Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle) {
auto ret = GetGraphInOutOps(inputs_, outputs_, &in_ops_, &out_ops_, all_ops_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get TensorRT subgraph input and output ops failed.";

@@ -107,7 +107,7 @@ int TensorRTSubGraph::Init(cudaStream_t stream) {
MS_LOG(ERROR) << "New TensorRTContext failed.";
return RET_OK;
}
if (SetDeviceConfig(stream) != RET_OK) {
if (SetDeviceConfig(stream, cublas_handle, cublaslt_handle) != RET_OK) {
MS_LOG(WARNING) << "set tensorrt config failed.";
}
serializer_ = std::make_shared<TensorRTSerializer>(serialize_file_path_);

@@ -151,7 +151,8 @@ int TensorRTSubGraph::BuildEngine() {
return RET_OK;
}

int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle,
cublasLtHandle_t cublaslt_handle) {
if (config_ == nullptr) {
this->config_ = runtime_->GetBuilder()->createBuilderConfig();
if (this->config_ == nullptr) {

@@ -176,9 +177,10 @@ int TensorRTSubGraph::SetDeviceConfig(cudaStream_t stream) {
} else {
MS_LOG(INFO) << "inputs no quant params or platform not support int8.";
}
runtime_->SetCudaStream(stream);
runtime_->SetCudaStream(stream, cublas_handle, cublaslt_handle);
config_->setProfileStream(stream);
stream_ = stream;

MS_LOG(INFO) << GetRankID() << " tensorrt subgraph stream: " << stream_;

// config setMaxWorkspaceSize to 2047 MB for max limit

@@ -54,7 +54,7 @@ class TensorRTSubGraph {

int BuildTensorRTGraph();

int Init(cudaStream_t stream);
int Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);

void SetSerializePath(const std::string &path) { serialize_file_path_ = std::move(path); }

@@ -65,7 +65,7 @@ class TensorRTSubGraph {
private:
int BuildEngine();

int SetDeviceConfig(cudaStream_t stream);
int SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);

bool IsInt8Mode();

@@ -247,6 +247,9 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
std::make_shared<opt::TensorDotFusion>(),
std::make_shared<opt::MatMulActivationFusion>(param),
std::make_shared<opt::MulActivationFusion>()};
#ifdef ENABLE_CLOUD_FUSION_INFERENCE
fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
#endif
for (size_t index = 0; index < fusions.size(); index++) {
auto pass_ptr = fusions.at(index);
auto pass_name = pass_ptr->name();

@@ -568,6 +568,7 @@ int AnfExporter::SetMetaGraphInput(const FuncGraphPtr &func_graph,
for (const auto &input : func_graph->get_inputs()) {
auto iter = graph_inputs_map_.find(input);
if (iter == graph_inputs_map_.end()) {
MS_LOG(ERROR) << "input " << input->ToString() << " not found in graph" << std::endl;
return RET_ERROR;
}
meta_graphT->inputIndex.emplace_back(iter->second);

@@ -21,39 +21,59 @@
#include <vector>
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"

namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
const size_t kWeightShapeSize = 2;
const int kAttentionOutputs = 3;
} // namespace

namespace {
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, bool test_div = false) {
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
auto dense = VectorRef({is_matmul, input, weight, bias});
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto reshape = VectorRef({is_reshape, dense, var1});
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
auto var2 = std::make_shared<Var>();
auto transpose = VectorRef({is_transpose, reshape, var2});
if (test_div) {
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv));
MS_CHECK_TRUE_RET(is_div != nullptr, {});
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto div = VectorRef({is_div, transpose, var3});
return div;
}
return transpose;
bool MultiHeadAttentionFusion::Init() const {
input_q_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
input_k_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
input_v_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_v_ != nullptr, false);

weight_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
weight_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
weight_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
weight_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);

bias_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
bias_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
bias_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
bias_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);

mask_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(mask_ != nullptr, false);

reshape_k_ = std::make_shared<Var>("reshape_k");
MS_CHECK_TRUE_RET(reshape_k_ != nullptr, false);
reshape_v_ = std::make_shared<Var>("reshape_v");
MS_CHECK_TRUE_RET(reshape_v_ != nullptr, false);
reshape_axis_ = std::make_shared<Var>("reshape_axis");
MS_CHECK_TRUE_RET(reshape_axis_ != nullptr, false);
v_transpose_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "v_transpose");
MS_CHECK_TRUE_RET(v_transpose_ != nullptr, false);
k_transpose_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "k_transpose");
MS_CHECK_TRUE_RET(k_transpose_ != nullptr, false);
return true;
}

namespace {
VectorRef DefineMask(const BaseRef &mask_input) {
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims));
MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});

|
|||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
return VectorRef({is_mul, sub, var3});
|
||||
}
|
||||
|
||||
STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
|
||||
MS_ASSERT(axes != nullptr);
|
||||
if (utils::isa<ValueNodePtr>(n)) {
|
||||
auto axes_value_node = utils::cast<ValueNodePtr>(n);
|
||||
*axes = CastToInt(axes_value_node->value());
|
||||
return lite::RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "GetAxis supports only value node";
|
||||
}
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias,
|
||||
const BaseRef &axis, const BaseRef &transpose_var, bool test_div,
|
||||
bool transpose) const {
|
||||
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "e-matmul");
|
||||
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
|
||||
auto dense = VectorRef({is_matmul, input, weight, bias});
|
||||
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
|
||||
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
|
||||
auto reshape = VectorRef({is_reshape, dense, axis});
|
||||
auto var2 = std::make_shared<Var>();
|
||||
VectorRef conn;
|
||||
if (transpose) {
|
||||
conn = VectorRef({transpose_var, reshape, var2});
|
||||
} else {
|
||||
conn = reshape;
|
||||
}
|
||||
if (test_div) {
|
||||
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
|
||||
MS_CHECK_TRUE_RET(is_div != nullptr, {});
|
||||
auto var3 = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
auto div = VectorRef({is_div, conn, var3});
|
||||
return div;
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mask) const {
|
||||
VectorRef k_embedding, v_embedding;
|
||||
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, true);
|
||||
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
|
||||
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
|
||||
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
|
||||
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
|
||||
if (!cross) {
|
||||
k_embedding = DefineEmbedding(input_q_, weight_k_, bias_k_, true);
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_q_, weight_v_, bias_v_);
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
} else {
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, true);
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_k_, weight_v_, bias_v_);
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
}
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
|
||||
|
@ -133,19 +194,73 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mas
|
|||
return matmul3;
|
||||
}
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const {
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const {
VectorRef k_embedding, v_embedding;
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, true);
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true, false);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_q_, weight_k_, bias_k_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_q_, weight_v_, bias_v_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_k_, weight_v_, bias_v_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1");
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add");
MS_CHECK_TRUE_RET(is_add != nullptr, {});
auto mask = DefineMask(mask_);
MS_CHECK_TRUE_RET(!mask.empty(), {});
auto add = VectorRef({is_add, mask, matmul1});
auto reshape1 = VectorRef({is_reshape1, add, var1});
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "softmax");
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
auto softmax = VectorRef({is_softmax, reshape1});
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape2");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, matmul2, var4});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
}

VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));

@ -169,18 +284,36 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto transpose = VectorRef({is_transpose, matmul2, var3});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var4 != nullptr, {});
auto reshape3 = VectorRef({is_reshape3, transpose, var4});
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_, bias_o_});
return matmul3;
}

namespace {
std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<tensor::Tensor>> &tensors) {
STATUS TransposeMatrix(std::shared_ptr<tensor::Tensor> src, std::shared_ptr<tensor::Tensor> dst) {
MS_CHECK_TRUE_RET(src->shape().size() == C2NUM, RET_ERROR);
MS_CHECK_TRUE_RET(dst->shape().size() == C2NUM, RET_ERROR);
int rows = src->shape().at(0);
int cols = src->shape().at(1);
auto src_ptr = reinterpret_cast<float *>(src->data_c());
auto dst_ptr = reinterpret_cast<float *>(dst->data_c());
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
auto val = src_ptr[r * cols + c];
dst_ptr[c * rows + r] = val;
}
}
return RET_OK;
}

std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<tensor::Tensor>> &tensors,
bool transpose = false) {
const std::vector<int64_t> &base_shape = tensors.at(0)->shape();
auto base_shape_size = base_shape.size();
auto base_data_type = tensors.at(0)->data_type();

@ -205,9 +338,10 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
auto sum = std::accumulate(tensors.begin(), tensors.end(), 0,
[](int sum, const tensor::TensorPtr &tensor) { return sum + tensor->shape().at(0); });
new_shape.push_back(sum);
for (std::size_t i = 1; i < base_shape_size; i++) {
for (std::size_t i = 1; i < base_shape_size; ++i) {
new_shape.push_back(base_shape.at(i));
}

// calculate data
auto concat_tensor = std::make_shared<tensor::Tensor>(base_data_type, new_shape);
MS_CHECK_TRUE_RET(concat_tensor != nullptr, nullptr);

@ -217,46 +351,17 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
memcpy_s(ptr, concat_tensor->Size() - offset, tensor->data_c(), tensor->Size());
offset += tensor->Size();
}
if (transpose) {
std::vector<int64_t> tshape = {new_shape[1], new_shape[0]};
auto transposed_tensor = std::make_shared<tensor::Tensor>(base_data_type, tshape);
auto status = TransposeMatrix(concat_tensor, transposed_tensor);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
return transposed_tensor;
}
return concat_tensor;
}
} // namespace

bool MultiHeadAttentionFusion::Init() const {
input_q_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
input_k_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
// input_v_ = std::make_shared<Var>();
// MS_CHECK_TRUE_RET(input_v_ != nullptr, false);

weight_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
weight_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
weight_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
weight_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);

bias_q_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
bias_k_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
bias_v_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
bias_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);

mask_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(mask_ != nullptr, false);

reshape_k_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(reshape_k_ != nullptr, false);
reshape_v_ = std::make_shared<Var>();
MS_CHECK_TRUE_RET(reshape_v_ != nullptr, false);
return true;
}

std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatterns() const {
std::unordered_map<std::string, VectorRef> patterns;
if (!Init()) {

@ -269,18 +374,38 @@ std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatte
patterns[kMPAXPatternName] = DefineMPWithMaskPattern(true, false);
patterns[kMPAWithMaskPatternNamePA] = DefineMPWithMaskPatternPA();
patterns[kMPAXWithMaskPatternNamePA] = DefineMPWithMaskPatternPA(true);
patterns[kMPAWithMaskPatternNameT5] = DefineMPWithMaskPatternT5();
patterns[kMPAXWithMaskPatternNameT5] = DefineMPWithMaskPatternT5(true);
return patterns;
}

bool MultiHeadAttentionFusion::CheckPattern(const EquivPtr &equiv, int *head_num, int *head_size) const {
MS_ASSERT(equiv != nullptr);
MS_ASSERT(head_num != nullptr);
MS_ASSERT(head_size != nullptr);
std::vector<int> reshape_axes;
if (GetAxis((*equiv)[reshape_axis_], &reshape_axes) != lite::RET_OK) {
return false;
}
if (reshape_axes.size() != C4NUM) {
return false;
}
*head_num = reshape_axes.at(C2NUM);
*head_size = reshape_axes.at(C3NUM);
return true;
}

AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
const mindspore::AnfNodePtr &node,
const mindspore::EquivPtr &equiv) const {
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA)) {
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA) ||
(pattern_name == kMPAWithMaskPatternNameT5)) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope());
} else if ((pattern_name == kMPAXWithMaskPatternName) || (pattern_name == kMPAXWithMaskPatternNamePA)) {
} else if ((pattern_name == kMPAXWithMaskPatternName) || (pattern_name == kMPAXWithMaskPatternNamePA) ||
(pattern_name == kMPAXWithMaskPatternNameT5)) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true);
} else if (pattern_name == kMPAPatternName) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false, false);

@ -328,7 +453,7 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
}

if (!utils::isa<ParameterPtr>((*equiv)[reshape_v_])) {
MS_LOG(ERROR) << "Reshape k is not a parameter";
MS_LOG(ERROR) << "Reshape v is not a parameter";
return nullptr;
}

@ -353,37 +478,158 @@ std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(con
return attention_prim;
}

CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
const EquivPtr &equiv, const string &base_name,
bool cross, bool mask) const {
STATUS MultiHeadAttentionFusion::AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &attention,
int index, const AnfNodePtr &node) const {
auto manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr";
return RET_ERROR;
}
auto transpose_users = manager->node_users()[node];
auto user_node = transpose_users.front();
if (!CheckPrimitiveType(user_node.first, prim::kPrimTranspose)) {
MS_LOG(ERROR) << " missing transpose node for branch " << index << std::endl;
return RET_ERROR;
}
// connect get item to it
transpose_users = manager->node_users()[user_node.first];
auto get_item = CreateOutputGetItem(func_graph, attention, index);
MS_ASSERT(get_item != nullptr);
if (transpose_users.size() == 1) {
auto &snode = transpose_users.front();
manager->SetEdge(snode.first, snode.second, get_item);
} else {
for (auto &snode : transpose_users) {
if (CheckPrimitiveType(snode.first, prim::kPrimMakeTuple)) {
manager->SetEdge(snode.first, snode.second, get_item);
break;
}
}
}
return RET_OK;
}

CNodePtr MultiHeadAttentionFusion::CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node,
const int item_index) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
MS_ASSERT(node != nullptr);
auto tuple_get_item_prim = std::make_shared<ops::TupleGetItem>();
auto get_item_value = NewValueNode(MakeValue<int>(item_index));
if (tuple_get_item_prim == nullptr || get_item_value == nullptr) {
MS_LOG(ERROR) << "NewValueNode is nullptr";
return nullptr;
}
auto tuple_get_item_prim_c = tuple_get_item_prim->GetPrim();
MS_ASSERT(tuple_get_item_prim_c != nullptr);
CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim_c, {node, get_item_value});
MS_CHECK_TRUE_RET(get_item_cnode != nullptr, nullptr);
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstract failed";
return nullptr;
}
get_item_cnode->set_abstract(abstract);
get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
std::to_string(item_index));
return get_item_cnode;
}

STATUS MultiHeadAttentionFusion::SetAbstractTuple(const CNodePtr &cnode, const int output_num) const {
MS_ASSERT(cnode != nullptr);
AbstractBasePtrList abstract_list;
for (int i = 0; i < output_num; ++i) {
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
if (abstract == nullptr) {
MS_LOG(ERROR) << "Create tensor abstract failed";
return RET_ERROR;
}
abstract_list.emplace_back(abstract);
}
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
if (abstract_tuple == nullptr) {
MS_LOG(ERROR) << "create abstract_tuple failed";
return RET_ERROR;
}
cnode->set_abstract(abstract_tuple);
return RET_OK;
}

STATUS MultiHeadAttentionFusion::RemoveRedundantInput(const FuncGraphPtr &func_graph,
const std::vector<AnfNodePtr> &redundant) const {
for (auto &node : redundant) {
func_graph->DropNode(node);
}
return RET_OK;
}

std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::CreatePrim(const EquivPtr &equiv, bool cross) const {
auto attention_prim = std::make_shared<ops::Attention>();
if (attention_prim == nullptr) {
MS_LOG(ERROR) << "Build attention primitive failed.";
return nullptr;
}
int head_num = 0;
int head_size = 0;
if (!CheckPattern(equiv, &head_num, &head_size)) {
return nullptr;
}
attention_prim->Init(head_num, head_size, cross);
return attention_prim;
}

CNodePtr MultiHeadAttentionFusion::MakeGetTuple(const FuncGraphPtr &func_graph, const CNodePtr &new_node,
const AnfNodePtr &knode, const AnfNodePtr &vnode) const {
auto get_item_node = CreateOutputGetItem(func_graph, new_node, 0);
if (get_item_node == nullptr) {
MS_LOG(ERROR) << "create attention output get_item node failed";
return nullptr;
}
if (knode != nullptr) {
auto status = AdjustOtherGetItems(func_graph, new_node, 1, knode);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
}
if (vnode != nullptr) {
auto status = AdjustOtherGetItems(func_graph, new_node, 2, vnode);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
}
return get_item_node;
}

CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
const EquivPtr &equiv, const string &base_name,
bool cross, bool mask) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
std::vector<AnfNodePtr> redundant;
auto attention_prim = CreatePrim(equiv, cross);
MS_CHECK_TRUE_RET(attention_prim != nullptr, nullptr);
auto attention_prim_c = attention_prim->GetPrim();
MS_CHECK_TRUE_RET(attention_prim_c != nullptr, nullptr);
auto value_node = NewValueNode(attention_prim_c);
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);

auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
AnfNodePtr input_k, input_mask;

if (cross) {
input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
}
// auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);

auto input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
AnfNodePtr input_mask;
auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
redundant.push_back(weight_q);
auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
redundant.push_back(weight_k);
redundant.push_back(weight_v);
auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);

auto bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
if (!cross) {
redundant.push_back(bias_q);
}
auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
auto bias_v = utils::cast<AnfNodePtr>((*equiv)[bias_v_]);
redundant.push_back(bias_k);
redundant.push_back(bias_v);
auto bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
auto knode = utils::cast<AnfNodePtr>((*equiv)[k_transpose_]);
auto vnode = utils::cast<AnfNodePtr>((*equiv)[v_transpose_]);
if (mask) {
input_mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
}

@ -394,10 +640,12 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
std::shared_ptr<tensor::Tensor> bias_k_tensor = GetTensorInfo(bias_k);
std::shared_ptr<tensor::Tensor> bias_v_tensor = GetTensorInfo(bias_v);
tensor::TensorPtr c_weights;
tensor::TensorPtr q_weight_t;
if (cross) {
c_weights = ConcatTensors({weight_k_tensor, weight_v_tensor});
c_weights = ConcatTensors({weight_k_tensor, weight_v_tensor}, true);
q_weight_t = ConcatTensors({weight_q_tensor}, true);
} else {
c_weights = ConcatTensors({weight_q_tensor, weight_k_tensor, weight_v_tensor});
c_weights = ConcatTensors({weight_q_tensor, weight_k_tensor, weight_v_tensor}, true);
}
auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor});
// convert tensors to params

@ -415,18 +663,34 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
return nullptr;
}
c_bias_param->set_name(base_name + "/bias_qkv");
ParameterPtr q_weight_param;
if (cross) {
q_weight_param = func_graph->add_parameter();
MS_CHECK_TRUE_RET(q_weight_param != nullptr, nullptr);
if (lite::InitParameterFromTensorInfo(q_weight_param, q_weight_t) != lite::RET_OK) {
MS_LOG(ERROR) << "Init parameter from tensor info failed.";
return nullptr;
}
}
std::vector<AnfNodePtr> new_node_inputs;
if (cross) {
new_node_inputs = {value_node, input_q, input_k, input_k, weight_q, c_weight_param, weight_o, c_bias_param, bias_o};
new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param,
c_weight_param, weight_o, c_bias_param, bias_o};
} else {
new_node_inputs = {value_node, input_q, input_q, input_q, c_weight_param, weight_o, c_bias_param, bias_o};
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o};
}
if (mask) {
new_node_inputs.push_back(input_mask);
}
auto new_node = func_graph->NewCNode(new_node_inputs);
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
new_node->set_fullname_with_scope(base_name);
return new_node;
auto status = SetAbstractTuple(new_node, kAttentionOutputs);
if (status != RET_OK) {
return nullptr;
}
new_node->set_fullname_with_scope(base_name + "/attention");
auto get_item_node = MakeGetTuple(func_graph, new_node, knode, vnode);
RemoveRedundantInput(func_graph, redundant);
return get_item_node;
}
} // namespace mindspore::opt

@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"

@ -48,10 +49,24 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
// define patterns
VectorRef DefineMPWithMaskPattern(bool cross = false, bool mask = true) const;
VectorRef DefineMPWithMaskPatternPA(bool cross = false) const;
VectorRef DefineMPWithMaskPatternT5(bool cross = false) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div = false, bool transpose = true) const;

// create masked-multi-head-attention
CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::string &base_name, bool cross = false, bool mask = true) const;
// check pattern
bool CheckPattern(const EquivPtr &equiv, int *head_num, int *head_size) const;
CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index) const;
lite::STATUS SetAbstractTuple(const CNodePtr &cnode, const int output_num) const;
lite::STATUS AdjustOtherGetItems(const FuncGraphPtr &func_graph, const CNodePtr &attention, int index,
const AnfNodePtr &node) const;
lite::STATUS RemoveRedundantInput(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &redundant) const;
std::shared_ptr<ops::Attention> CreatePrim() const;
CNodePtr MakeGetTuple(const FuncGraphPtr &func_graph, const CNodePtr &new_node, const AnfNodePtr &knode,
const AnfNodePtr &vnode) const;
std::shared_ptr<ops::Attention> CreatePrim(const EquivPtr &equiv, bool cross) const;

protected:
const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern";

@ -60,6 +75,8 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
const std::string kMPAXWithMaskPatternNamePA = "MPAXWithMaskPatternPA";
const std::string kMPAPatternName = "MPAPattern";
const std::string kMPAXPatternName = "MPAXPattern";
const std::string kMPAWithMaskPatternNameT5 = "MPAWithMaskPatternT5";
const std::string kMPAXWithMaskPatternNameT5 = "MPAXWithMaskPatternT5";

mutable VarPtr input_q_{nullptr};
mutable VarPtr input_k_{nullptr};

@ -78,6 +95,10 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {

mutable VarPtr reshape_k_{nullptr};
mutable VarPtr reshape_v_{nullptr};

mutable VarPtr reshape_axis_{nullptr};
mutable VarPtr v_transpose_{nullptr};
mutable VarPtr k_transpose_{nullptr};
};

} // namespace opt

@ -128,7 +128,7 @@ index 8707220..aea35e6 100644
target_link_libraries(trt_fused_multi_head_attention PUBLIC -lcublas -lcudart)
set_property(TARGET trt_fused_multi_head_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ea21014..3098d8a 100644
index ea21014..97b842e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,7 +14,9 @@

@ -255,7 +255,7 @@ index ea21014..3098d8a 100644
include(CMakePackageConfigHelpers)
configure_package_config_file(
${CMAKE_CURRENT_LIST_DIR}/cmake/FasterTransformerConfig.cmake.in
@@ -402,52 +392,23 @@ configure_package_config_file(
@@ -402,28 +392,14 @@ configure_package_config_file(
INSTALL_DESTINATION ${INSTALL_CONFIGDIR}
)

@ -281,16 +281,12 @@ index ea21014..3098d8a 100644
- FasterTransformerTargets.cmake
- DESTINATION
- ${INSTALL_CONFIGDIR}
+ LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/output/lib
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/output/lib
+ LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/lib
)

file(GLOB_RECURSE HEADER_FILES "*.h" "*.hpp" "*.cuh")
foreach ( file ${HEADER_FILES} )
file( RELATIVE_PATH rfile ${CMAKE_CURRENT_SOURCE_DIR} ${file} )
get_filename_component( dir ${rfile} DIRECTORY )
- install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir} )
+ install( FILES ${file} DESTINATION ${CMAKE_CURRENT_SOURCE_DIR}/output/include )
@@ -434,20 +410,5 @@ foreach ( file ${HEADER_FILES} )
endforeach()