Adding support for FP16, cross & T5 MHA

This commit is contained in:
nizzan 2022-12-06 09:56:12 +02:00
parent d870c9090c
commit dff877dbd3
15 changed files with 3713 additions and 613 deletions

View File

@ -2,6 +2,9 @@ set(REQ_URL "https://github.com/NVIDIA/FasterTransformer/archive/refs/tags/relea
set(SHA256 "7adffe2d53b3c1544295a6b7d1887e59b044eba25dd3e150bc909168d5e99081")
set(ft_libs "transformer-shared")
if(DEFINED ENV{MSLITE_GPU_ARCH})
set(arch_opt -DSM=$ENV{MSLITE_GPU_ARCH})
endif()
mindspore_add_pkg(fast_transformers
VER 0.5.0
@ -10,7 +13,7 @@ mindspore_add_pkg(fast_transformers
LIBS ${ft_libs}
LIB_PATH lib
PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/fast_transformer/001-fast_transformer.patch
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DEXAMPLES=off)
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release ${arch_opt} -DEXAMPLES=off)
include_directories(${fast_transformers_INC})
add_library(mindspore::fast_transformers ALIAS fast_transformers::transformer-shared)

View File

@ -124,7 +124,8 @@ static BaseRef GetVar(const BaseRef &x) {
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
MS_EXCEPTION_IF_NULL(equiv);
if (equiv == nullptr) MS_EXCEPTION_IF_NULL(equiv);
if (utils::isa<VarPtr>(pattern)) {
VarPtr var = utils::cast<VarPtr>(pattern);
if (var->matches(expr)) {

View File

@ -39,13 +39,19 @@ int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
if (q_weight->shape_size_ != C2NUM) {
return NNACL_ERR;
}
int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0];
int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1];
int batch = (q_input->shape_size_ == C2NUM) ? 1 : q_input->shape_[0];
int f_seq = (q_input->shape_size_ == C2NUM) ? q_input->shape_[0] : q_input->shape_[1];
int t_seq_len = k_input->shape_[1];
if (q_input->shape_size_ == C2NUM) {
output0->shape_[FIRST_INPUT] = batch * f_seq;
output0->shape_[SECOND_INPUT] = param->head_num_ * param->head_size_;
output0->shape_size_ = C2NUM;
} else {
output0->shape_[FIRST_INPUT] = batch;
output0->shape_[SECOND_INPUT] = f_seq;
output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_;
output0->shape_size_ = C3NUM;
}
if (outputs_size >= C3NUM) {
TensorC *output1 = outputs[SECOND_INPUT];
SetDataTypeFormat(output1, q_input);

View File

@ -71,7 +71,7 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
cublasHandle_t cublas_handle) {
cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
const int m = params[0];
const int n = params[1];
const int k = params[2];
@ -84,15 +84,17 @@ void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, con
cudaDataType type_b = data_types[1];
cudaDataType type_c = data_types[2];
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
compute_type = CUBLAS_COMPUTE_16F;
}
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b,
ldb, beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT));
ldb, beta, c_addr, type_c, ldc, compute_type, algo));
}
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
const int *lds, const cublasOperation_t *operations, const int *strides,
const cudaDataType *data_types, void *alpha, void *beta, int batch,
cublasHandle_t cublas_handle) {
cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
const int m = params[0];
const int n = params[1];
const int k = params[2];
@ -105,12 +107,62 @@ void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, voi
cudaDataType type_b = data_types[1];
cudaDataType type_c = data_types[2];
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
compute_type = CUBLAS_COMPUTE_16F;
}
const int stride_a = strides[0];
const int stride_b = strides[1];
const int stride_c = strides[2];
CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda,
stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc,
stride_c, batch, compute_type, CUBLAS_GEMM_DEFAULT));
stride_c, batch, compute_type, algo));
}
void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle) {
cublasOperation_t trans_a = operations[0];
cublasOperation_t trans_b = operations[1];
cudaDataType type_a = data_types[0];
cudaDataType type_b = data_types[1];
cudaDataType type_c = data_types[2];
const int m = params[0];
const int n = params[1];
const int k = params[2];
const int lda = lds[0];
const int ldb = lds[1];
const int ldc = lds[2];
cublasLtMatrixLayout_t mat_a_desc = NULL;
cublasLtMatrixLayoutCreate(&mat_a_desc, type_a, (trans_a == CUBLAS_OP_N) ? m : k, (trans_a == CUBLAS_OP_N) ? k : m,
lda);
cublasLtMatrixLayout_t mat_b_desc = NULL;
cublasLtMatrixLayoutCreate(&mat_b_desc, type_b, (trans_b == CUBLAS_OP_N) ? k : n, (trans_b == CUBLAS_OP_N) ? n : k,
ldb);
cublasLtMatrixLayout_t mat_c_desc = NULL;
cublasLtMatrixLayoutCreate(&mat_c_desc, type_c, m, n, ldc);
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
compute_type = CUBLAS_COMPUTE_16F;
}
cublasLtMatmulDesc_t mat_operation_desc = NULL;
cublasLtMatmulDescCreate(&mat_operation_desc, compute_type, type_a);
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(cublasOperation_t));
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(cublasOperation_t));
if (bias != nullptr) {
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void *));
}
cublasLtMatmul(cublaslt_handle, mat_operation_desc, alpha, a_addr, mat_a_desc, b_addr, mat_b_desc, beta, c_addr,
mat_c_desc, c_addr, mat_c_desc, NULL, NULL, 0, stream);
cublasLtMatrixLayoutDestroy(mat_a_desc);
cublasLtMatrixLayoutDestroy(mat_b_desc);
cublasLtMatrixLayoutDestroy(mat_c_desc);
cublasLtMatmulDescDestroy(mat_operation_desc);
}
} // namespace mindspore::lite

View File

@ -62,10 +62,14 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
cublasHandle_t cublas_handle);
cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
const int *lds, const cublasOperation_t *operations, const int *strides,
const cudaDataType *data_types, void *alpha, void *beta, int batch,
cublasHandle_t cublas_handle);
cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_

View File

@ -34,13 +34,38 @@ namespace mindspore::lite {
namespace {
constexpr std::size_t kTwo = 2;
constexpr std::size_t kThree = 3;
std::ostream &operator<<(std::ostream &s, const nvinfer1::ITensor &t) {
const auto &dims = t.getDimensions();
s << "ndims=" << dims.nbDims << " [";
for (int i = 0; i < dims.nbDims; ++i) {
s << dims.d[i] << " ";
}
s << "]";
return s;
}
} // namespace
#define SET_GEMM_PARAMS(gemm_ops_, gemm_lds_, gemm_op1_, gemm_op2_, gemm_ld1_, gemm_ld2_, gemm_ld3_) \
do { \
gemm_ops_[0] = gemm_op1_; \
gemm_ops_[1] = gemm_op2_; \
gemm_lds_[0] = gemm_ld1_; \
gemm_lds_[1] = gemm_ld2_; \
gemm_lds_[2] = gemm_ld3_; \
} while (0)
#define SET_GEMM_DIMS(gemm_dims_, gemm_dim1_, gemm_dim2_, gemm_dim3_) \
do { \
gemm_dims_[0] = gemm_dim1_; \
gemm_dims_[1] = gemm_dim2_; \
gemm_dims_[2] = gemm_dim3_; \
} while (0)
// Multi Head Attention TensorRT op
int MhaTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
const std::vector<TensorInfo> &out_tensors) {
if (in_tensors.size() != 8 && in_tensors.size() != 6) {
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
if (in_tensors.size() < 7 || in_tensors.size() > 9) { // T5 has 6 or 7 inputs, other models have 8 or 9 inputs
MS_LOG(ERROR) << "Unsupported number of inputs, size is " << in_tensors.size();
return RET_ERROR;
}
return RET_OK;
@ -59,13 +84,13 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
// get attribute for Attn op - TODO - add attribute in op
int head_number = mha_op->get_head_num();
int head_size = mha_op->get_head_size();
int compute_type = 1; // mha_op->get_compute_type();
int is_cross = mha_op->get_cross();
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
GetCublasHandle(), GetCublasLtHandle(), device_id_);
auto compute_type = runtime_->GetRuntimePrecisionMode(); // mha_op->get_compute_type();
bool is_cross = mha_op->get_cross();
const int input_number = inputs().size();
bool is_position_bias = (((input_number == 8) && is_cross) || ((input_number == 7) && !is_cross)) ? true : false;
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
is_position_bias, GetCublasHandle(), GetCublasLtHandle(), device_id_);
nvinfer1::ITensor *inputTensors[input_number];
for (int i = 0; i < input_number; i++) {
inputTensors[i] = input(ctx, i).trt_tensor_;
@ -75,16 +100,46 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
MS_LOG(ERROR) << "add mha op failed for TensorRT.";
return RET_ERROR;
}
mha_layer->setName(op_name_.c_str());
mha_layer->setName((op_name_ + "plugin_attention").c_str());
// TODO(haim) one output
nvinfer1::ITensor *attn_tensor = mha_layer->getOutput(0);
#ifndef TEST_
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name());
#else /* TEST_ */
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name() + "attn");
#endif /* TEST_ */
// nvinfer1::ITensor *key_tensor = mha_layer->getOutput(1);
// ctx->RegisterTensor(ITensorHelper{key_tensor, Format::NCHW, true}, out_tensors_[1].Name());
// nvinfer1::ITensor *value_tensor = mha_layer->getOutput(kTwo);
// ctx->RegisterTensor(ITensorHelper{value_tensor, Format::NCHW, true}, out_tensors_[kTwo].Name());
this->layer_ = mha_layer;
#ifdef TEST_
auto weight_projection = input(ctx, 4).trt_tensor_;
auto bias_projection = input(ctx, 6).trt_tensor_;
#endif /* TEST_ */
#ifdef TEST_
auto matmul_layer = ctx->network()->addMatrixMultiply(*attn_tensor, nvinfer1::MatrixOperation::kNONE,
*weight_projection, nvinfer1::MatrixOperation::kNONE);
if (matmul_layer == nullptr) {
MS_LOG(ERROR) << "failed to add matmul layer";
return RET_ERROR;
}
matmul_layer->setName((op_name_ + "_matmul").c_str());
auto matmul_tensor = matmul_layer->getOutput(0);
auto shuffle_layer = ctx->network()->addShuffle(*bias_projection);
const auto size = bias_projection->getDimensions().d[0];
shuffle_layer->setReshapeDimensions(nvinfer1::Dims{2, {1, size}});
auto shuffle_tensor = shuffle_layer->getOutput(0);
auto addbias = ctx->network()->addElementWise(*matmul_tensor, *shuffle_tensor, nvinfer1::ElementWiseOperation::kSUM);
if (addbias == nullptr) {
MS_LOG(ERROR) << "failed to add bias layer";
return RET_ERROR;
}
addbias->setName((op_name_ + "_bias").c_str());
auto bias_out = addbias->getOutput(0);
ctx->RegisterTensor(ITensorHelper{bias_out, Format::NCHW, true}, out_tensors_[0].Name());
#endif /* TEST_ */
return RET_OK;
}
@ -98,11 +153,80 @@ std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
int MhaPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
return RunCudaMha(inputDesc, outputDesc, inputs, outputs, workspace, stream);
if (compute_type_ == RuntimePrecisionMode_FP16) {
return RunCudaMha<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
} else {
return RunCudaMha<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
}
}
template <typename T>
void MhaPlugin::SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
size_t extra_size) {
size_t qkv_len = size_q + (size_k * 2); // size_v is equal to size_k
size_t q_buf_2_len = size_q;
auto buff_size =
qkv_len + q_buf_2_len + qk_buf_len + (qkv_buf_2_len * 2); // qkv_buf_3_ len is equal to qkv_buf_2_len
qkv_buf_ = workspace;
q_buf_2_ = static_cast<T *>(qkv_buf_) + qkv_len;
qk_buf_ = static_cast<T *>(q_buf_2_) + q_buf_2_len;
qkv_buf_2_ = static_cast<T *>(qk_buf_) + qk_buf_len;
qkv_buf_3_ = static_cast<T *>(qkv_buf_2_) + qkv_buf_2_len;
output1_ = static_cast<T *>(workspace) + buff_size;
output2_ = static_cast<T *>(output1_) + extra_size;
}
template <typename T>
void MhaPlugin::RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha,
void *beta, cublasGemmAlgo_t algoId, cudaStream_t stream) {
int cross_tensor_offset = 0;
if (is_cross_) cross_tensor_offset = 1;
const int from_tensor_idx = 0, encoder_tensor_idx = 1, weight_qkv_tensor_idx = 3;
const int weight_qkv_tensor_idx_cross = 3 + cross_tensor_offset;
const int bias_qkv_tensor_idx = 5 + cross_tensor_offset;
const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;
auto from_tensor = static_cast<const T *>(inputs[from_tensor_idx]);
auto encoder_output_tensor = static_cast<const T *>(inputs[encoder_tensor_idx]);
auto weight_q = static_cast<const T *>(inputs[weight_qkv_tensor_idx]);
auto weight_kv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
auto weight_qkv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
auto bias_qkv = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_qkv_tensor_idx]);
auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
if (is_cross_) {
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
CublasGemmWrapper(weight_q, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_);
SET_GEMM_DIMS(gemm_dims, C2NUM * hidden_size, request_batch_size * request_tgt_seq_len, hidden_size);
gemm_lds[0] = gemm_lds[THIRD_INPUT] = C2NUM * hidden_size;
CublasGemmWrapper(weight_kv, encoder_output_tensor,
static_cast<T *>(qkv_buf_) + (request_batch_size * request_src_seq_len) * hidden_size, gemm_dims,
gemm_lds, gemm_ops, const_cast<const cudaDataType *>(gemm_data_types), alpha, beta,
cublas_handle_);
fastertransformer::invokeCrossAddFusedQKVBiasTranspose(
static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
bias_qkv, request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_, head_size_, stream);
} else {
CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_, algoId);
fastertransformer::invokeAddFusedQKVBiasTranspose(
static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
bias_qkv, request_batch_size, request_src_seq_len, head_number_, head_size_, 0, stream);
}
}
template <typename T>
int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) {
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
cublasGemmAlgo_t *algoId) {
// inputs order:
// 0] Q
// 1] K
@ -112,134 +236,99 @@ int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvi
// 5] B
// 6] PB
// 7] AttnMask
// inputs order cross:
// 0] Q
// 1] K enco output
// 2] V
// 3] Wq
// 4] Wkv
// 5] PW
// 6] Bqkv
// 7] PB
// 8] AttnMask
int cross_tensor_offset = 0;
cublasSetStream(cublas_handle_, stream);
if (is_cross_) cross_tensor_offset = 1;
const int weight_projection_tensor_idx = 4 + cross_tensor_offset;
const int bias_projection_tensor_idx = 6 + cross_tensor_offset;
const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;
const int bias_position_tensor_idx = 5 + cross_tensor_offset;
// TODO(Haim) - Fix tensor ids according to cross flag
const int from_tensor_idx = 0;
// const int encoder_tensor_idx = 1;
const int weight_qkv_tensor_idx = 3;
const int weight_projection_tensor_idx = 4;
const int bias_qkv_tensor_idx = 5;
const int bias_projection_tensor_idx = 6;
const int attn_mask_tensor_idx = 7;
auto from_tensor = static_cast<const float *>(inputs[from_tensor_idx]);
auto attention_mask = static_cast<const float *>(inputs[attn_mask_tensor_idx]);
auto weight_qkv = static_cast<const float *>(inputs[weight_qkv_tensor_idx]);
auto bias_qkv = static_cast<const float *>(inputs[bias_qkv_tensor_idx]);
auto weight_projection = static_cast<const float *>(inputs[weight_projection_tensor_idx]);
auto bias_projection = static_cast<const float *>(inputs[bias_projection_tensor_idx]);
auto output0 = static_cast<float *>(outputs[0]);
// auto output1 = static_cast<float *>(outputs[1]);
// auto output2 = static_cast<float *>(outputs[2]);
auto attention_mask = static_cast<const T *>(inputs[attn_mask_tensor_idx]);
auto weight_projection = static_cast<const T *>(inputs[weight_projection_tensor_idx]);
auto bias_projection = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_projection_tensor_idx]);
auto bias_position = (is_position_bias_) ? static_cast<const T *>(inputs[bias_position_tensor_idx]) : nullptr;
auto output0 = static_cast<T *>(outputs[0]);
auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
// TODO(NIZZAN): fix allocator
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
auto extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t size_v = size_k;
size_t qkv_len = size_q + size_k + size_v;
size_t q_buf_2_len = size_q;
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
size_t qkv_buf_3_len = qkv_buf_2_len;
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
qkv_buf_ = workspace;
q_buf_2_ = static_cast<float *>(qkv_buf_) + qkv_len;
qk_buf_ = static_cast<float *>(q_buf_2_) + q_buf_2_len;
qkv_buf_2_ = static_cast<float *>(qk_buf_) + qk_buf_len;
qkv_buf_3_ = static_cast<float *>(qkv_buf_2_) + qkv_buf_2_len;
output1_ = static_cast<float *>(workspace) + buff_size;
output2_ = static_cast<float *>(output1_) + extra_tmp_size;
SetInnerAddr<T>(workspace, size_q, size_k, qk_buf_len, qkv_buf_2_len, extra_tmp_size);
int gemm_dims[3] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
int gemm_lds[3] = {3 * hidden_size, hidden_size, 3 * hidden_size};
cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N};
cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
if constexpr (std::is_same<T, half>::value)
std::fill(std::begin(gemm_data_types), std::end(gemm_data_types), CUDA_R_16F);
float alpha = 1.0f, beta = 0.0f;
int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
int gemm_lds[] = {3 * hidden_size, hidden_size, 3 * hidden_size};
cublasOperation_t gemm_ops[2] = {CUBLAS_OP_N, CUBLAS_OP_N};
const cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
float alpha = 1.0f;
float beta = 0.0f;
CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta,
cublas_handle_);
fastertransformer::invokeAddFusedQKVBiasTranspose(static_cast<float *>(q_buf_2_), static_cast<float *>(output1_),
static_cast<float *>(output2_), static_cast<float *>(qkv_buf_),
bias_qkv, request_batch_size, request_src_seq_len, head_number_,
head_size_, 0, stream);
gemm_ops[0] = CUBLAS_OP_T;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = request_tgt_seq_len;
gemm_dims[1] = request_src_seq_len;
gemm_dims[THIRD_INPUT] = head_size_;
gemm_lds[0] = head_size_;
gemm_lds[1] = head_size_;
gemm_lds[THIRD_INPUT] = request_tgt_seq_len;
RunPhase1GEMM<T>(inputDesc, inputs, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta, algoId[0], stream);
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_T, CUBLAS_OP_N, head_size_, head_size_, request_tgt_seq_len);
SET_GEMM_DIMS(gemm_dims, request_tgt_seq_len, request_src_seq_len, head_size_);
int gemm_strides[] = {request_tgt_seq_len * head_size_, request_src_seq_len * head_size_,
request_src_seq_len * request_tgt_seq_len};
CublasGemmStridedBatchedWrapper(output1_, q_buf_2_, qk_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
request_batch_size * head_number_, cublas_handle_, algoId[1]);
float scalar = (1.0f / sqrtf(static_cast<float>(head_size_) * 1.0f));
fastertransformer::invokeMixMaskedSoftMax(static_cast<float *>(qk_buf_), attention_mask, request_batch_size,
request_src_seq_len, request_tgt_seq_len, head_number_, scalar, stream);
gemm_ops[0] = CUBLAS_OP_N;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = head_size_;
gemm_dims[1] = request_src_seq_len;
gemm_dims[THIRD_INPUT] = request_tgt_seq_len;
gemm_lds[0] = head_size_;
gemm_lds[1] = request_tgt_seq_len;
gemm_lds[THIRD_INPUT] = head_size_;
gemm_strides[0] = request_tgt_seq_len * head_size_;
T scalar = static_cast<T>(1.0f / sqrtf(head_size_ * 1.0f));
fastertransformer::invokeMixMaskedSoftMax(static_cast<T *>(qk_buf_), attention_mask, bias_position,
request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_,
scalar, stream);
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, head_size_, request_tgt_seq_len, head_size_);
SET_GEMM_DIMS(gemm_dims, head_size_, request_src_seq_len, request_tgt_seq_len);
gemm_strides[1] = request_src_seq_len * request_tgt_seq_len;
gemm_strides[THIRD_INPUT] = request_src_seq_len * head_size_;
CublasGemmStridedBatchedWrapper(output2_, qk_buf_, qkv_buf_2_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);
fastertransformer::invokeTransposeQKV(static_cast<float *>(qkv_buf_3_), static_cast<float *>(qkv_buf_2_),
request_batch_size, request_src_seq_len, head_number_, head_size_, stream);
gemm_ops[0] = CUBLAS_OP_N;
gemm_ops[1] = CUBLAS_OP_N;
gemm_dims[0] = hidden_size;
gemm_dims[1] = request_batch_size * request_src_seq_len;
gemm_dims[THIRD_INPUT] = hidden_size;
gemm_lds[0] = hidden_size;
gemm_lds[1] = hidden_size;
gemm_lds[THIRD_INPUT] = hidden_size;
CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha,
&beta, cublas_handle_);
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
request_batch_size * head_number_, cublas_handle_, algoId[2]);
fastertransformer::invokeTransposeQKV(static_cast<T *>(qkv_buf_3_), static_cast<T *>(qkv_buf_2_), request_batch_size,
request_src_seq_len, head_number_, head_size_, stream);
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops,
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta, cublas_handle_, algoId[3]);
if (!is_position_bias_) {
int len = request_batch_size * request_src_seq_len;
fastertransformer::invokeAddBias(reinterpret_cast<float *>(output0), reinterpret_cast<const float *>(bias_projection),
len, hidden_size, stream);
fastertransformer::invokeAddBias(reinterpret_cast<T *>(output0), reinterpret_cast<const T *>(bias_projection), len,
hidden_size, stream);
}
return RET_OK;
}
int MhaPlugin::RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
// Add Cross Mha Layer here
return RET_OK;
bool MhaPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept {
auto type = (compute_type_ == RuntimePrecisionMode_FP16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
for (int i = 0; i < pos; i++) {
if (tensorsDesc[pos].type != tensorsDesc[i].type) return false;
}
bool res = (tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) && (tensorsDesc[pos].type == type);
return res;
}
void MhaPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
@ -250,9 +339,6 @@ size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
// TODO(NIZZAN) Fix efficient allocator
// size_t buff_size = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len +
// request_batch_size * request_src_seq_len * hidden_size;
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
size_t size_v = size_k;
@ -265,8 +351,12 @@ size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
int elem_size = sizeof(float);
if (compute_type_ == RuntimePrecisionMode_FP16) {
elem_size = sizeof(half);
}
return (buff_size + extra_tmp_size + extra_tmp_size) * sizeof(float);
return (buff_size + extra_tmp_size + extra_tmp_size) * elem_size;
}
nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
@ -282,14 +372,15 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
// value_cache [batch, head_num, tgt_seq_len, size_per_head]
nvinfer1::DimsExprs dims;
if (index == 0) {
// if (inputs[0].nbDims == 2) {
// dims.nbDims = INPUT_SIZE2;
// dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
// auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
// dims.d[1] = hidden_size;
// } else
{
dims.nbDims = INPUT_SIZE3;
#ifndef TEST_
int num_dims = inputs[0].nbDims;
dims.nbDims = num_dims;
if (num_dims == INPUT_SIZE2) {
dims.d[0] = exprBuilder.constant(inputs[nbInputDims - 1].d[0]->getConstantValue() *
inputs[nbInputDims - 1].d[1]->getConstantValue());
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
dims.d[1] = hidden_size;
} else if (num_dims == INPUT_SIZE3) {
dims.d[0] = inputs[nbInputDims - 1].d[0]; // batch
dims.d[1] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
@ -303,6 +394,13 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
dims.d[kTwo] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
dims.d[kThree] = exprBuilder.constant(head_size_);
}
#else
dims.nbDims = C2NUM;
dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
dims.d[1] = hidden_size;
}
#endif
return dims;
}
@ -316,6 +414,8 @@ nvinfer1::IPluginV2DynamicExt *MhaPlugin::clone() const noexcept {
return plugin;
}
int MhaPlugin::initialize() noexcept { return 0; }
void MhaPlugin::terminate() noexcept {}
size_t MhaPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); }

View File

@ -31,6 +31,8 @@ class MhaTensorRT : public TensorRTOp {
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}
~MhaTensorRT() override = default;
// bool IsWeightInputHanledInner() const override { return true; }
int AddInnerOp(TensorRTContext *ctx) override;
int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
@ -40,13 +42,14 @@ class MhaTensorRT : public TensorRTOp {
constexpr auto MHA_PLUGIN_NAME{"AttentionPlugin"};
class MhaPlugin : public TensorRTPlugin {
public:
MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, int is_cross,
cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, bool is_cross,
bool is_position_bias, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
: TensorRTPlugin(name, std::string(MHA_PLUGIN_NAME), device_id),
compute_type_(compute_type),
head_number_(head_number),
head_size_(head_size),
is_cross_(is_cross),
is_position_bias_(is_position_bias),
cublas_handle_(cublas_handle),
cublaslt_handle_(cublaslt_handle) {}
@ -57,6 +60,7 @@ class MhaPlugin : public TensorRTPlugin {
head_number_ = static_cast<const int *>(fields[1].data)[0];
head_size_ = static_cast<const int *>(fields[2].data)[0];
is_cross_ = static_cast<const int *>(fields[3].data)[0];
is_position_bias_ = static_cast<const int *>(fields[4].data)[0];
}
MhaPlugin(const char *name, const void *serialData, size_t serialLength)
@ -65,6 +69,7 @@ class MhaPlugin : public TensorRTPlugin {
DeserializeValue(&serialData, &serialLength, &head_number_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &head_size_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &is_cross_, sizeof(int));
DeserializeValue(&serialData, &serialLength, &is_position_bias_, sizeof(int));
}
MhaPlugin() = delete;
@ -82,21 +87,37 @@ class MhaPlugin : public TensorRTPlugin {
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
int nbOutputs) noexcept override;
void terminate() noexcept override;
int initialize() noexcept override;
private:
bool needResize(const int *current_dims, const int *last_dims);
template <typename T>
int RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);
int RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
cublasGemmAlgo_t *algoId);
template <typename T>
void SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
size_t extra_size);
template <typename T>
void RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha, void *beta,
cublasGemmAlgo_t algoId, cudaStream_t stream);
const std::string layer_name_;
std::string name_space_;
int compute_type_;
int head_number_;
int head_size_;
int is_cross_;
bool is_cross_;
bool is_position_bias_;
cublasGemmAlgo_t fast_algo_gemm[4] = {CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP,
CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP};
cublasHandle_t cublas_handle_;
cublasLtHandle_t cublaslt_handle_;
void *qkv_buf_{nullptr};

View File

@ -27,6 +27,7 @@
#include <limits>
#include <unordered_map>
#include "src/extendrt/delegate/delegate_utils.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "src/common/utils.h"
#include "ops/transpose.h"
@ -68,6 +69,11 @@ TensorRTSubGraph::~TensorRTSubGraph() {
config_->destroy();
config_ = nullptr;
}
#ifdef PROFILER_
auto profile = dynamic_cast<SimpleProfiler *>(trt_context_->getProfiler());
if (profile != nullptr) std::cout << *profile << std::endl;
delete profile;
#endif
if (trt_context_ != nullptr) {
trt_context_->destroy();
trt_context_ = nullptr;
@ -494,6 +500,15 @@ int TensorRTSubGraph::Prepare() {
MS_LOG(ERROR) << "TensorRTSubGraph create context failed.";
return RET_ERROR;
}
#ifdef PROFILER_
auto profiler = new SimpleProfiler("myprofiler");
if (profiler == nullptr) {
MS_LOG(WARNING) << "Cannot create profiler";
}
this->trt_context_->setProfiler(profiler);
#endif
int binding_num = this->engine_->getNbBindings();
if (binding_num <= 0) {
MS_LOG(ERROR) << "TensorRTSubGraph binding num < 0.";
@ -504,7 +519,6 @@ int TensorRTSubGraph::Prepare() {
MS_LOG(ERROR) << "malloc tensor binding array failed.";
return RET_ERROR;
}
profile_index_ = MaxVolumnProfileIndex();
if (this->trt_context_->setOptimizationProfile(profile_index_)) {
MS_LOG(INFO) << "setOptimizationProfile: " << profile_index_;
@ -527,9 +541,6 @@ int TensorRTSubGraph::Prepare() {
MS_LOG(INFO) << "device index " << index << " for tensor : " << tensor_name << " attr: " << device_ptr;
tensor_bindings_[index] = device_ptr;
nvinfer1::Dims input_dims = ConvertCudaDims(profile.inputs[i].max_dims);
for (int od = 0; od < input_dims.nbDims; od++) {
MS_LOG(DEBUG) << "in tensor " << tensor.Name() << " dims at " << od << " is " << input_dims.d[od];
}
if (!this->trt_context_->setBindingDimensions(index, input_dims)) {
MS_LOG(ERROR) << "invalid input dims of " << tensor.Name();
return RET_ERROR;
@ -781,7 +792,6 @@ int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
// actual output tensor dims
auto out_dims = this->trt_context_->getBindingDimensions(index);
std::vector<int64_t> new_shape = lite::ConvertMSShape(out_dims);
outputs_[i].SetShape(new_shape);
for (int od = 0; od < out_dims.nbDims; od++) {
MS_LOG(DEBUG) << "out tensor " << trt_out_tensor_name << " dims at " << od << " is " << new_shape[od];
}

View File

@ -20,6 +20,8 @@
#include <unordered_set>
#include <numeric>
#include <functional>
#include <iomanip>
#include <algorithm>
#include "src/extendrt/delegate/tensorrt/op/cast_plugin.h"
#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h"
@ -176,6 +178,7 @@ nvinfer1::DataType ConvertDataType(DataType type_id) {
{DataType::kNumberTypeInt32, nvinfer1::DataType::kINT32},
{DataType::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
{DataType::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
{DataType::kNumberTypeInt64, nvinfer1::DataType::kINT32},
};
auto iter = data_type_map.find(type_id);
nvinfer1::DataType data_type;
@ -895,4 +898,66 @@ nvinfer1::DataType GetNvinferDataType<int>() {
template nvinfer1::DataType GetNvinferDataType<float>();
template nvinfer1::DataType GetNvinferDataType<int>();
#ifdef PROFILER_
void SimpleProfiler::reportLayerTime(const char *layerName, float ms) noexcept {
mProfile_[layerName].count++;
mProfile_[layerName].time += ms;
if (std::find(mLayerNames_.begin(), mLayerNames_.end(), layerName) == mLayerNames_.end()) {
mLayerNames_.push_back(layerName);
}
}
SimpleProfiler::SimpleProfiler(const char *name, const std::vector<SimpleProfiler> &srcProfilers) : mName_(name) {
for (const auto &srcProfiler : srcProfilers) {
for (const auto &rec : srcProfiler.mProfile_) {
auto it = mProfile_.find(rec.first);
if (it == mProfile_.end()) {
mProfile_.insert(rec);
} else {
it->second.time += rec.second.time;
it->second.count += rec.second.count;
}
}
}
}
std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value) {
out << "========== " << value.mName_ << " profile ==========" << std::endl;
float totalTime = 0;
std::string layerNameStr = "TensorRT layer name";
int maxLayerNameLength = std::max(static_cast<int>(layerNameStr.size()), 70);
for (const auto &elem : value.mProfile_) {
totalTime += elem.second.time;
maxLayerNameLength = std::max(maxLayerNameLength, static_cast<int>(elem.first.size()));
}
auto old_settings = out.flags();
auto old_precision = out.precision();
// Output header
{
out << std::setw(maxLayerNameLength) << layerNameStr << " ";
out << std::setw(C12NUM) << "Runtime, "
<< "%"
<< " ";
out << std::setw(C12NUM) << "Invocations"
<< " ";
out << std::setw(C12NUM) << "Runtime, ms" << std::endl;
}
for (size_t i = 0; i < value.mLayerNames_.size(); i++) {
const std::string layerName = value.mLayerNames_[i];
auto elem = value.mProfile_.at(layerName);
out << std::setw(maxLayerNameLength) << layerName << " ";
out << std::setw(C12NUM) << std::fixed << std::setprecision(1) << (elem.time * 100.0F / totalTime) << "%"
<< " ";
out << std::setw(C12NUM) << elem.count << " ";
out << std::setw(C12NUM) << std::fixed << std::setprecision(C2NUM) << elem.time << std::endl;
}
out.flags(old_settings);
out.precision(old_precision);
out << "========== " << value.mName_ << " total runtime = " << totalTime << " ms ==========" << std::endl;
return out;
}
#endif // PROFILER_
} // namespace mindspore::lite

View File

@ -21,6 +21,7 @@
#include <NvInferVersion.h>
#include <memory>
#include <string>
#include <map>
#include "src/extendrt/delegate/tensorrt/tensorrt_context.h"
#include "src/extendrt/delegate/tensorrt/tensor_info.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
@ -57,6 +58,28 @@ typedef union float32_bits {
float f;
} float32_bits;
// #define PROFILER_
#ifdef PROFILER_
struct SimpleProfiler : public nvinfer1::IProfiler {
struct Record {
float time{0};
int count{0};
};
virtual void reportLayerTime(const char *layerName, float ms) noexcept;
explicit SimpleProfiler(const char *name,
const std::vector<SimpleProfiler> &srcProfilers = std::vector<SimpleProfiler>());
friend std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value);
private:
std::string mName_;
std::vector<std::string> mLayerNames_;
std::map<std::string, Record> mProfile_;
};
#endif
// Convert Tensor data to Cuda dims.
nvinfer1::Dims ConvertCudaDims(const std::vector<int> &data);
@ -194,4 +217,4 @@ void Data2Vector(std::vector<float> *dst, const void *src) {
}
}
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_UTILS_H_
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_UTILS_H_

View File

@ -64,7 +64,7 @@ Status ContextUtils::AddGpuDevice(bool enable_fp16, uint32_t device_id, int rank
Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_context) {
lite::DeviceInfo device_info;
device_info.npu_device_info_ = {false, frequency};
device_info.npu_device_info_.frequency_ = frequency;
inner_context->device_list_.push_back({lite::DT_NPU, device_info});
return kSuccess;
}

View File

@ -42,6 +42,7 @@
#include <thread>
#include "src/common/config_file.h"
#endif
namespace mindspore {
constexpr size_t kDataToStringMaxNum = 40;
constexpr int kPrintDataNum = 20;

View File

@ -19,6 +19,7 @@
#include <functional>
#include <utility>
#include <vector>
#include <algorithm>
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/tuple_get_item.h"
@ -32,32 +33,34 @@ const int kAttentionOutputs = 3;
} // namespace
bool MultiHeadAttentionFusion::Init() const {
input_q_ = std::make_shared<Var>();
input_q_ = std::make_shared<Var>("input_q");
MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
input_k_ = std::make_shared<Var>();
input_k_ = std::make_shared<Var>("input_k");
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
input_v_ = std::make_shared<Var>();
input_v_ = std::make_shared<Var>("input_v");
MS_CHECK_TRUE_RET(input_v_ != nullptr, false);
position_bias_ = std::make_shared<Var>("position_bias_");
MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
weight_q_ = std::make_shared<CondVar>(IsParamNode);
weight_q_ = std::make_shared<CondVar>(IsParamNode, "weight_q");
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
weight_k_ = std::make_shared<CondVar>(IsParamNode);
weight_k_ = std::make_shared<CondVar>(IsParamNode, "weight_k");
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
weight_v_ = std::make_shared<CondVar>(IsParamNode);
weight_v_ = std::make_shared<CondVar>(IsParamNode, "weight_v");
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
weight_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);
bias_q_ = std::make_shared<CondVar>(IsParamNode);
bias_q_ = std::make_shared<CondVar>(IsParamNode, "bias_q");
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
bias_k_ = std::make_shared<CondVar>(IsParamNode);
bias_k_ = std::make_shared<CondVar>(IsParamNode, "bias_k");
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
bias_v_ = std::make_shared<CondVar>(IsParamNode);
bias_v_ = std::make_shared<CondVar>(IsParamNode, "bias_v");
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
bias_o_ = std::make_shared<CondVar>(IsParamNode);
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);
mask_ = std::make_shared<Var>();
mask_ = std::make_shared<Var>("mask");
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
reshape_k_ = std::make_shared<Var>("reshape_k");
@ -77,20 +80,43 @@ namespace {
VectorRef DefineMask(const BaseRef &mask_input) {
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims));
MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});
auto var1 = std::make_shared<Var>();
auto var1 = std::make_shared<Var>("m-var1");
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto expand_dims = VectorRef({is_expand_dims, mask_input, var1});
auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion));
auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion), "m-sub");
MS_CHECK_TRUE_RET(is_sub != nullptr, {});
auto var2 = std::make_shared<Var>();
auto var2 = std::make_shared<Var>("m-var2");
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto sub = VectorRef({is_sub, var2, expand_dims});
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion));
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "m-mul");
MS_CHECK_TRUE_RET(is_mul != nullptr, {});
auto var3 = std::make_shared<Var>();
auto var3 = std::make_shared<Var>("m-var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
return VectorRef({is_mul, sub, var3});
}
STATUS GetIntParameterData(const ParameterPtr &param_ptr, std::vector<int> *result) {
if (param_ptr == nullptr || !param_ptr->has_default()) {
MS_LOG(DEBUG) << "param not have default";
return RET_ERROR;
}
auto default_param = param_ptr->default_param();
if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
return RET_ERROR;
}
auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
MS_LOG(DEBUG) << "default param is not int";
return RET_ERROR;
}
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
int64_t shape_size =
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
for (int64_t i = 0; i < shape_size; i++) {
result->emplace_back(ptr[i]);
}
return RET_OK;
}
STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
MS_ASSERT(axes != nullptr);
@ -99,8 +125,14 @@ STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
*axes = CastToInt(axes_value_node->value());
return lite::RET_OK;
} else {
MS_LOG(ERROR) << "GetAxis supports only value node";
auto reshape = utils::cast<ParameterPtr>(n);
if (reshape != nullptr) {
if (GetIntParameterData(reshape, axes) == lite::RET_OK) {
return lite::RET_OK;
}
}
}
MS_LOG(ERROR) << " cannot get axes data";
return lite::RET_ERROR;
}
} // namespace
@ -132,23 +164,42 @@ VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const
return conn;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mask) const {
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div, bool transpose) const {
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "e-matmul");
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
auto dense = VectorRef({is_matmul, input, weight});
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
auto reshape = VectorRef({is_reshape, dense, axis});
auto var2 = std::make_shared<Var>();
VectorRef conn;
if (transpose) {
conn = VectorRef({transpose_var, reshape, var2});
} else {
conn = reshape;
}
if (test_div) {
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
MS_CHECK_TRUE_RET(is_div != nullptr, {});
auto var3 = std::make_shared<Var>();
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto div = VectorRef({is_div, conn, var3});
return div;
}
return conn;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool mask) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
@ -194,23 +245,16 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mas
return matmul3;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const {
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5() const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true, false);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_, false, false);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1");
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1");
@ -246,23 +290,79 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const
return matmul3;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const {
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose) const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
VectorRef q_embedding;
if (transpose) {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true);
} else {
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, false);
}
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, true, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, reshape_v_, v_transpose_, false, true);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1");
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1");
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
auto var1 = std::make_shared<Var>("var1");
MS_CHECK_TRUE_RET(var1 != nullptr, {});
auto is_add1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add1");
MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
auto add1 = VectorRef({is_add1, matmul1, position_bias_});
auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add2");
MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
auto mask2 = DefineMask(mask_);
MS_CHECK_TRUE_RET(!mask2.empty(), {});
auto add2 = VectorRef({is_add1, mask2, add1});
auto reshape1 = VectorRef({is_reshape1, add2, var1});
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "softmax");
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
auto softmax = VectorRef({is_softmax, reshape1});
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape2");
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
auto var2 = std::make_shared<Var>("var2");
MS_CHECK_TRUE_RET(var2 != nullptr, {});
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul2");
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape3");
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
auto var4 = std::make_shared<Var>("var4");
MS_CHECK_TRUE_RET(var4 != nullptr, {});
VectorRef reshape3;
if (transpose) {
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "transpose");
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
auto var3 = std::make_shared<Var>("var3");
MS_CHECK_TRUE_RET(var3 != nullptr, {});
auto transpose1 = VectorRef({is_transpose, matmul2, var3});
reshape3 = VectorRef({is_reshape3, transpose1, var4});
} else {
reshape3 = VectorRef({is_reshape3, matmul2, var4});
}
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul3");
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_});
return matmul3;
}
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
VectorRef k_embedding, v_embedding;
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
if (!cross) {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
} else {
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
}
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
@ -296,13 +396,14 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const
}
namespace {
template <typename T>
STATUS TransposeMatrix(std::shared_ptr<tensor::Tensor> src, std::shared_ptr<tensor::Tensor> dst) {
MS_CHECK_TRUE_RET(src->shape().size() == C2NUM, RET_ERROR);
MS_CHECK_TRUE_RET(dst->shape().size() == C2NUM, RET_ERROR);
int rows = src->shape().at(0);
int cols = src->shape().at(1);
auto src_ptr = reinterpret_cast<float *>(src->data_c());
auto dst_ptr = reinterpret_cast<float *>(dst->data_c());
auto src_ptr = reinterpret_cast<T *>(src->data_c());
auto dst_ptr = reinterpret_cast<T *>(dst->data_c());
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
auto val = src_ptr[r * cols + c];
@ -354,8 +455,20 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
if (transpose) {
std::vector<int64_t> tshape = {new_shape[1], new_shape[0]};
auto transposed_tensor = std::make_shared<tensor::Tensor>(base_data_type, tshape);
auto status = TransposeMatrix(concat_tensor, transposed_tensor);
switch (base_data_type) {
case kNumberTypeFloat32: {
auto status = TransposeMatrix<float>(concat_tensor, transposed_tensor);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
break;
}
case kNumberTypeFloat16: {
auto status = TransposeMatrix<float16>(concat_tensor, transposed_tensor);
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
break;
}
default:
MS_LOG(ERROR) << "unsupported data type " << base_data_type << std::endl;
}
return transposed_tensor;
}
return concat_tensor;
@ -368,14 +481,13 @@ std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatte
MS_LOG(ERROR) << "initial member failed.";
return patterns;
}
patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern();
patterns[kMPAXWithMaskPatternName] = DefineMPWithMaskPattern(true);
patterns[kMPAPatternName] = DefineMPWithMaskPattern(false, false);
patterns[kMPAXPatternName] = DefineMPWithMaskPattern(true, false);
patterns[kMPAPatternName] = DefineMPWithMaskPattern(false);
patterns[kMPAWithMaskPatternNamePA] = DefineMPWithMaskPatternPA();
patterns[kMPAXWithMaskPatternNamePA] = DefineMPWithMaskPatternPA(true);
patterns[kMPAWithMaskPatternNameT5] = DefineMPWithMaskPatternT5();
patterns[kMPAXWithMaskPatternNameT5] = DefineMPWithMaskPatternT5(true);
patterns[kMPAWithMaskPatternNameT5New] = DefineMPWithMaskPatternT5New(false);
patterns[kMPAWithMaskTransposePatternNameT5New] = DefineMPWithMaskPatternT5New();
return patterns;
}
@ -384,14 +496,24 @@ bool MultiHeadAttentionFusion::CheckPattern(const EquivPtr &equiv, int *head_num
MS_ASSERT(head_num != nullptr);
MS_ASSERT(head_size != nullptr);
std::vector<int> reshape_axes;
// UNDO !!!!!!!!!
if (GetAxis((*equiv)[reshape_axis_], &reshape_axes) != lite::RET_OK) {
MS_LOG(ERROR) << "cannot figure out reshape";
return false;
}
if (reshape_axes.size() != C4NUM) {
return false;
}
*head_num = reshape_axes.at(C2NUM);
*head_size = reshape_axes.at(C3NUM);
std::vector<int> out;
std::for_each(reshape_axes.begin() + 1, reshape_axes.end(), [&out](const auto &x) {
if (x != -1) out.push_back(x);
});
if (out.size() < C2NUM) {
MS_LOG(ERROR) << "cannot find head_num or head_size";
return false;
}
*head_num = out.at(0);
*head_size = out.at(1);
return true;
}
@ -401,44 +523,43 @@ AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, co
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
return nullptr;
}
++match_count_;
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA) ||
(pattern_name == kMPAWithMaskPatternNameT5)) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope());
} else if ((pattern_name == kMPAXWithMaskPatternName) || (pattern_name == kMPAXWithMaskPatternNamePA) ||
(pattern_name == kMPAXWithMaskPatternNameT5)) {
(pattern_name == kMPAWithMaskPatternNameT5) ||
(pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New)) {
if (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New) {
t5_x_ = true;
}
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true);
} else if (pattern_name == kMPAPatternName) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false, false);
} else if (pattern_name == kMPAXPatternName) {
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true, false);
}
if (pattern_name == kMPAPatternName)
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false);
return nullptr;
}
{ return nullptr; }
}
STATUS GetIntParameterData(const ParameterPtr &param_ptr, std::vector<int> *result) {
if (param_ptr == nullptr || !param_ptr->has_default()) {
MS_LOG(DEBUG) << "param not have default";
return RET_ERROR;
}
auto default_param = param_ptr->default_param();
if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
return RET_ERROR;
}
auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
MS_LOG(DEBUG) << "default param is not int";
return RET_ERROR;
}
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
int64_t shape_size =
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
for (int64_t i = 0; i < shape_size; i++) {
result->emplace_back(ptr[i]);
}
return RET_OK;
}
// STATUS GetIntParameterData(const ParameterPtr &param_ptr, std::vector<int> *result) {
// if (param_ptr == nullptr || !param_ptr->has_default()) {
// MS_LOG(DEBUG) << "param not have default";
// return RET_ERROR;
// }
// auto default_param = param_ptr->default_param();
// if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
// MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
// return RET_ERROR;
// }
// auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
// if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
// MS_LOG(DEBUG) << "default param is not int";
// return RET_ERROR;
// }
// auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
// int64_t shape_size =
// std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
// for (int64_t i = 0; i < shape_size; i++) {
// result->emplace_back(ptr[i]);
// }
// return RET_OK;
// }
std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
MS_ASSERT(equiv != nullptr);
@ -595,50 +716,115 @@ CNodePtr MultiHeadAttentionFusion::MakeGetTuple(const FuncGraphPtr &func_graph,
return get_item_node;
}
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
const EquivPtr &equiv, const string &base_name,
bool cross, bool mask) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
std::vector<AnfNodePtr> redundant;
auto attention_prim = CreatePrim(equiv, cross);
MS_CHECK_TRUE_RET(attention_prim != nullptr, nullptr);
auto attention_prim_c = attention_prim->GetPrim();
MS_CHECK_TRUE_RET(attention_prim_c != nullptr, nullptr);
auto value_node = NewValueNode(attention_prim_c);
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
bool MultiHeadAttentionFusion::IsCross(const EquivPtr &equiv) const {
auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
ShapeVector inputq_shape, inputv_shape;
auto ret = FetchShapeFromAbstract(input_q->abstract(), &inputq_shape);
MS_CHECK_TRUE_RET(ret == RET_OK, false);
ret = FetchShapeFromAbstract(input_v->abstract(), &inputv_shape);
MS_CHECK_TRUE_RET(ret == RET_OK, false);
if ((inputq_shape != inputv_shape) || ((match_count_ > 1) && (input_q != input_v))) {
return true;
}
return false;
}
using tup = std::tuple<AnfNodePtr, std::shared_ptr<tensor::Tensor>, std::shared_ptr<tensor::Tensor>,
std::shared_ptr<tensor::Tensor>>;
tup MultiHeadAttentionFusion::GetAttentionNodeWeights(const EquivPtr &equiv, std::vector<AnfNodePtr> *redundant) const {
auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
redundant->push_back(weight_q);
redundant->push_back(weight_k);
redundant->push_back(weight_v);
auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);
std::shared_ptr<tensor::Tensor> weight_q_tensor = GetTensorInfo(weight_q);
std::shared_ptr<tensor::Tensor> weight_k_tensor = GetTensorInfo(weight_k);
std::shared_ptr<tensor::Tensor> weight_v_tensor = GetTensorInfo(weight_v);
return make_tuple(weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor);
}
std::vector<AnfNodePtr> MultiHeadAttentionFusion::GetNewNodeInputs(const EquivPtr &equiv, ParameterPtr q_weight_param,
ParameterPtr c_weight_param, AnfNodePtr weight_o,
ParameterPtr c_bias_param, AnfNodePtr bias_o,
bool mask, bool cross) const {
auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
auto input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
AnfNodePtr input_mask;
auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
redundant.push_back(weight_q);
auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
redundant.push_back(weight_k);
redundant.push_back(weight_v);
auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);
auto bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
if (!cross) {
redundant.push_back(bias_q);
AnfNodePtr input_mask = (mask) ? utils::cast<AnfNodePtr>((*equiv)[mask_]) : nullptr;
AnfNodePtr position_bias = (t5_x_) ? utils::cast<AnfNodePtr>((*equiv)[position_bias_]) : nullptr;
auto attention_prim = CreatePrim(equiv, cross);
MS_CHECK_TRUE_RET(attention_prim != nullptr, {});
auto attention_prim_c = attention_prim->GetPrim();
MS_CHECK_TRUE_RET(attention_prim_c != nullptr, {});
auto value_node = NewValueNode(attention_prim_c);
MS_CHECK_TRUE_RET(value_node != nullptr, {});
std::vector<AnfNodePtr> new_node_inputs;
if (cross) {
if (t5_x_) {
new_node_inputs = {value_node, input_q, input_k, input_v,
q_weight_param, c_weight_param, weight_o, position_bias};
} else {
new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param,
c_weight_param, weight_o, c_bias_param, bias_o};
}
} else {
if (t5_x_) {
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, position_bias};
} else {
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o};
}
}
if (mask) {
new_node_inputs.push_back(input_mask);
}
return new_node_inputs;
}
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
const EquivPtr &equiv, const string &base_name,
bool mask) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(equiv != nullptr);
bool cross = IsCross(equiv);
std::vector<AnfNodePtr> redundant;
auto [weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor] = GetAttentionNodeWeights(equiv, &redundant);
AnfNodePtr bias_q;
ParameterPtr c_bias_param;
AnfNodePtr bias_o;
if (!t5_x_) {
bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
auto bias_v = utils::cast<AnfNodePtr>((*equiv)[bias_v_]);
redundant.push_back(bias_k);
redundant.push_back(bias_v);
auto bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
auto knode = utils::cast<AnfNodePtr>((*equiv)[k_transpose_]);
auto vnode = utils::cast<AnfNodePtr>((*equiv)[v_transpose_]);
if (mask) {
input_mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
}
std::shared_ptr<tensor::Tensor> weight_q_tensor = GetTensorInfo(weight_q);
std::shared_ptr<tensor::Tensor> weight_k_tensor = GetTensorInfo(weight_k);
std::shared_ptr<tensor::Tensor> weight_v_tensor = GetTensorInfo(weight_v);
bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
std::shared_ptr<tensor::Tensor> bias_q_tensor = GetTensorInfo(bias_q);
std::shared_ptr<tensor::Tensor> bias_k_tensor = GetTensorInfo(bias_k);
std::shared_ptr<tensor::Tensor> bias_v_tensor = GetTensorInfo(bias_v);
auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor});
c_bias_param = func_graph->add_parameter();
MS_CHECK_TRUE_RET(c_bias_param != nullptr, nullptr);
c_bias_param->set_name(base_name + "/bias_qkv");
if (lite::InitParameterFromTensorInfo(c_bias_param, c_bias) != lite::RET_OK) {
MS_LOG(ERROR) << "Init parameter from tensor info failed.";
return nullptr;
}
}
auto knode = utils::cast<AnfNodePtr>((*equiv)[k_transpose_]);
AnfNodePtr vnode;
auto it_vnode = (*equiv).find(v_transpose_);
if (it_vnode != (*equiv).end() && !t5_x_) vnode = utils::cast<AnfNodePtr>(it_vnode->second);
if (!cross && !t5_x_) {
redundant.push_back(bias_q);
}
tensor::TensorPtr c_weights;
tensor::TensorPtr q_weight_t;
if (cross) {
@ -647,7 +833,6 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
} else {
c_weights = ConcatTensors({weight_q_tensor, weight_k_tensor, weight_v_tensor}, true);
}
auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor});
// convert tensors to params
auto c_weight_param = func_graph->add_parameter();
MS_CHECK_TRUE_RET(c_weight_param != nullptr, nullptr);
@ -656,13 +841,6 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
return nullptr;
}
c_weight_param->set_name(base_name + "/weight_qkv");
auto c_bias_param = func_graph->add_parameter();
MS_CHECK_TRUE_RET(c_bias_param != nullptr, nullptr);
if (lite::InitParameterFromTensorInfo(c_bias_param, c_bias) != lite::RET_OK) {
MS_LOG(ERROR) << "Init parameter from tensor info failed.";
return nullptr;
}
c_bias_param->set_name(base_name + "/bias_qkv");
ParameterPtr q_weight_param;
if (cross) {
q_weight_param = func_graph->add_parameter();
@ -672,25 +850,26 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
return nullptr;
}
}
std::vector<AnfNodePtr> new_node_inputs;
if (cross) {
new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param,
c_weight_param, weight_o, c_bias_param, bias_o};
} else {
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o};
}
if (mask) {
new_node_inputs.push_back(input_mask);
}
std::vector<AnfNodePtr> new_node_inputs =
GetNewNodeInputs(equiv, q_weight_param, c_weight_param, weight_o, c_bias_param, bias_o, mask, cross);
auto new_node = func_graph->NewCNode(new_node_inputs);
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
if (vnode) {
auto status = SetAbstractTuple(new_node, kAttentionOutputs);
if (status != RET_OK) {
return nullptr;
}
}
new_node->set_fullname_with_scope(base_name + "/attention");
CNodePtr ret_node;
if (vnode) {
auto get_item_node = MakeGetTuple(func_graph, new_node, knode, vnode);
ret_node = get_item_node;
} else {
ret_node = new_node;
}
RemoveRedundantInput(func_graph, redundant);
return get_item_node;
return ret_node;
}
} // namespace mindspore::opt

View File

@ -20,6 +20,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <tuple>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"
@ -46,15 +47,18 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
private:
// define patterns
VectorRef DefineMPWithMaskPattern(bool cross = false, bool mask = true) const;
VectorRef DefineMPWithMaskPatternPA(bool cross = false) const;
VectorRef DefineMPWithMaskPatternT5(bool cross = false) const;
VectorRef DefineMPWithMaskPattern(bool mask = true) const;
VectorRef DefineMPWithMaskPatternPA() const;
VectorRef DefineMPWithMaskPatternT5() const;
VectorRef DefineMPWithMaskPatternT5New(bool transpose = true) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div = false, bool transpose = true) const;
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
const BaseRef &transpose_var, bool test_div, bool transpose) const;
// create masked-multi-head-attention
CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
const std::string &base_name, bool cross = false, bool mask = true) const;
const std::string &base_name, bool mask = true) const;
// check pattern
bool CheckPattern(const EquivPtr &equiv, int *head_num, int *head_size) const;
CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index) const;
@ -66,20 +70,27 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
CNodePtr MakeGetTuple(const FuncGraphPtr &func_graph, const CNodePtr &new_node, const AnfNodePtr &knode,
const AnfNodePtr &vnode) const;
std::shared_ptr<ops::Attention> CreatePrim(const EquivPtr &equiv, bool cross) const;
bool IsCross(const EquivPtr &equiv) const;
std::vector<AnfNodePtr> GetNewNodeInputs(const EquivPtr &equiv, ParameterPtr q_weight_param,
ParameterPtr c_weight_param, AnfNodePtr weight_o, ParameterPtr c_bias_param,
AnfNodePtr bias_o, bool mask, bool cross) const;
std::tuple<AnfNodePtr, std::shared_ptr<tensor::Tensor>, std::shared_ptr<tensor::Tensor>,
std::shared_ptr<tensor::Tensor> >
GetAttentionNodeWeights(const EquivPtr &equiv, std::vector<AnfNodePtr> *redundant) const;
mutable int match_count_ = 0;
protected:
const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern";
const std::string kMPAXWithMaskPatternName = "MPAXWithMaskPattern";
const std::string kMPAWithMaskPatternNamePA = "MPAWithMaskPatternPA";
const std::string kMPAXWithMaskPatternNamePA = "MPAXWithMaskPatternPA";
const std::string kMPAPatternName = "MPAPattern";
const std::string kMPAXPatternName = "MPAXPattern";
const std::string kMPAWithMaskPatternNameT5 = "MPAWithMaskPatternT5";
const std::string kMPAXWithMaskPatternNameT5 = "MPAXWithMaskPatternT5";
const std::string kMPAWithMaskPatternNameT5New = "MPAWithMaskPatternT5New";
const std::string kMPAWithMaskTransposePatternNameT5New = "MPAWithMaskTransposePatternT5New";
mutable VarPtr input_q_{nullptr};
mutable VarPtr input_k_{nullptr};
mutable VarPtr input_v_{nullptr};
mutable VarPtr position_bias_{nullptr};
mutable VarPtr weight_q_{nullptr};
mutable VarPtr weight_k_{nullptr};
@ -98,8 +109,9 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
mutable VarPtr reshape_axis_{nullptr};
mutable VarPtr v_transpose_{nullptr};
mutable VarPtr k_transpose_{nullptr};
};
mutable bool t5_x_{false};
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MULTI_HEAD_ATTENTION_FUSION_H_

File diff suppressed because it is too large Load Diff