From dff877dbd3307617a21ff0cc4390a13842a7e23d Mon Sep 17 00:00:00 2001 From: nizzan Date: Tue, 6 Dec 2022 09:56:12 +0200 Subject: [PATCH] Adding support for FP16, cross & T5 MHA --- cmake/external_libs/fast_transformers.cmake | 5 +- .../common/optimizer/pattern_engine.cc | 3 +- .../cpu/kernel/nnacl/infer/attention_infer.c | 18 +- .../tensorrt/cuda_impl/cublas_utils.cc | 62 +- .../tensorrt/cuda_impl/cublas_utils.h | 8 +- .../delegate/tensorrt/op/mha_tensorrt.cc | 348 +- .../delegate/tensorrt/op/mha_tensorrt.h | 33 +- .../delegate/tensorrt/tensorrt_subgraph.cc | 20 +- .../delegate/tensorrt/tensorrt_utils.cc | 65 + .../delegate/tensorrt/tensorrt_utils.h | 25 +- .../extendrt/mock/lite_runtime/converters.cc | 2 +- .../tools/benchmark/benchmark_unified_api.cc | 1 + .../fusion/multi_head_attention_fusion.cc | 487 ++- .../fusion/multi_head_attention_fusion.h | 30 +- .../001-fast_transformer.patch | 3219 +++++++++++++++-- 15 files changed, 3713 insertions(+), 613 deletions(-) diff --git a/cmake/external_libs/fast_transformers.cmake b/cmake/external_libs/fast_transformers.cmake index bad2f0f2c86..7823783c4c7 100644 --- a/cmake/external_libs/fast_transformers.cmake +++ b/cmake/external_libs/fast_transformers.cmake @@ -2,6 +2,9 @@ set(REQ_URL "https://github.com/NVIDIA/FasterTransformer/archive/refs/tags/relea set(SHA256 "7adffe2d53b3c1544295a6b7d1887e59b044eba25dd3e150bc909168d5e99081") set(ft_libs "transformer-shared") +if(DEFINED ENV{MSLITE_GPU_ARCH}) + set(arch_opt -DSM=$ENV{MSLITE_GPU_ARCH}) +endif() mindspore_add_pkg(fast_transformers VER 0.5.0 @@ -10,7 +13,7 @@ mindspore_add_pkg(fast_transformers LIBS ${ft_libs} LIB_PATH lib PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/fast_transformer/001-fast_transformer.patch - CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DEXAMPLES=off) + CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release ${arch_opt} -DEXAMPLES=off) include_directories(${fast_transformers_INC}) add_library(mindspore::fast_transformers ALIAS fast_transformers::transformer-shared) diff --git a/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc b/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc index a2f724da93c..1062fcdd01f 100644 --- a/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc +++ b/mindspore/ccsrc/backend/common/optimizer/pattern_engine.cc @@ -124,7 +124,8 @@ static BaseRef GetVar(const BaseRef &x) { EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); - MS_EXCEPTION_IF_NULL(equiv); + + if (equiv == nullptr) MS_EXCEPTION_IF_NULL(equiv); if (utils::isa(pattern)) { VarPtr var = utils::cast(pattern); if (var->matches(expr)) { diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/attention_infer.c b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/attention_infer.c index 2d110d8695f..56e30b3bbb6 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/attention_infer.c +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/attention_infer.c @@ -39,13 +39,19 @@ int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor if (q_weight->shape_size_ != C2NUM) { return NNACL_ERR; } - int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0]; - int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1]; + int batch = (q_input->shape_size_ == C2NUM) ? 1 : q_input->shape_[0]; + int f_seq = (q_input->shape_size_ == C2NUM) ? 
q_input->shape_[0] : q_input->shape_[1]; int t_seq_len = k_input->shape_[1]; - output0->shape_[FIRST_INPUT] = batch; - output0->shape_[SECOND_INPUT] = f_seq; - output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_; - output0->shape_size_ = C3NUM; + if (q_input->shape_size_ == C2NUM) { + output0->shape_[FIRST_INPUT] = batch * f_seq; + output0->shape_[SECOND_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C2NUM; + } else { + output0->shape_[FIRST_INPUT] = batch; + output0->shape_[SECOND_INPUT] = f_seq; + output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_; + output0->shape_size_ = C3NUM; + } if (outputs_size >= C3NUM) { TensorC *output1 = outputs[SECOND_INPUT]; SetDataTypeFormat(output1, q_input); diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.cc index d3b4519cd79..9a5a98fd046 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.cc @@ -71,7 +71,7 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int * void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta, - cublasHandle_t cublas_handle) { + cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) { const int m = params[0]; const int n = params[1]; const int k = params[2]; @@ -84,15 +84,17 @@ void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, con cudaDataType type_b = data_types[1]; cudaDataType type_c = data_types[2]; cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; - + if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { + compute_type = CUBLAS_COMPUTE_16F; + } CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b, - ldb, beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT)); + ldb, beta, c_addr, type_c, ldc, compute_type, algo)); } void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, const cublasOperation_t *operations, const int *strides, const cudaDataType *data_types, void *alpha, void *beta, int batch, - cublasHandle_t cublas_handle) { + cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) { const int m = params[0]; const int n = params[1]; const int k = params[2]; @@ -105,12 +107,62 @@ void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, voi cudaDataType type_b = data_types[1]; cudaDataType type_c = data_types[2]; cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { + compute_type = CUBLAS_COMPUTE_16F; + } const int stride_a = strides[0]; const int stride_b = strides[1]; const int stride_c = strides[2]; CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc, - stride_c, batch, compute_type, CUBLAS_GEMM_DEFAULT)); + stride_c, batch, compute_type, algo)); +} + +void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, + const cublasOperation_t *operations, const cudaDataType *data_types, void 
*alpha, void *beta, + const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle) { + cublasOperation_t trans_a = operations[0]; + cublasOperation_t trans_b = operations[1]; + cudaDataType type_a = data_types[0]; + cudaDataType type_b = data_types[1]; + cudaDataType type_c = data_types[2]; + const int m = params[0]; + const int n = params[1]; + const int k = params[2]; + const int lda = lds[0]; + const int ldb = lds[1]; + const int ldc = lds[2]; + + cublasLtMatrixLayout_t mat_a_desc = NULL; + cublasLtMatrixLayoutCreate(&mat_a_desc, type_a, (trans_a == CUBLAS_OP_N) ? m : k, (trans_a == CUBLAS_OP_N) ? k : m, + lda); + cublasLtMatrixLayout_t mat_b_desc = NULL; + cublasLtMatrixLayoutCreate(&mat_b_desc, type_b, (trans_b == CUBLAS_OP_N) ? k : n, (trans_b == CUBLAS_OP_N) ? n : k, + ldb); + cublasLtMatrixLayout_t mat_c_desc = NULL; + cublasLtMatrixLayoutCreate(&mat_c_desc, type_c, m, n, ldc); + + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { + compute_type = CUBLAS_COMPUTE_16F; + } + + cublasLtMatmulDesc_t mat_operation_desc = NULL; + cublasLtMatmulDescCreate(&mat_operation_desc, compute_type, type_a); + cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(cublasOperation_t)); + cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(cublasOperation_t)); + if (bias != nullptr) { + cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS; + cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t)); + cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void *)); + } + + cublasLtMatmul(cublaslt_handle, mat_operation_desc, alpha, a_addr, mat_a_desc, b_addr, mat_b_desc, beta, c_addr, + mat_c_desc, c_addr, mat_c_desc, NULL, NULL, 0, stream); + cublasLtMatrixLayoutDestroy(mat_a_desc); + cublasLtMatrixLayoutDestroy(mat_b_desc); + cublasLtMatrixLayoutDestroy(mat_c_desc); + cublasLtMatmulDescDestroy(mat_operation_desc); } } // namespace mindspore::lite diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h b/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h index e68dd40d9b1..528daf8fafb 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h @@ -62,10 +62,14 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int * void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta, - cublasHandle_t cublas_handle); + cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP); void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, const cublasOperation_t *operations, const int *strides, const cudaDataType *data_types, void *alpha, void *beta, int batch, - cublasHandle_t cublas_handle); + cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP); + +void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds, + const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta, + const void *bias, 
cudaStream_t stream, cublasLtHandle_t cublaslt_handle); } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_ diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.cc index 075505122f7..5ee63e71e14 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.cc @@ -34,13 +34,38 @@ namespace mindspore::lite { namespace { constexpr std::size_t kTwo = 2; constexpr std::size_t kThree = 3; +std::ostream &operator<<(std::ostream &s, const nvinfer1::ITensor &t) { + const auto &dims = t.getDimensions(); + s << "ndims=" << dims.nbDims << " ["; + for (int i = 0; i < dims.nbDims; ++i) { + s << dims.d[i] << " "; + } + s << "]"; + return s; +} } // namespace +#define SET_GEMM_PARAMS(gemm_ops_, gemm_lds_, gemm_op1_, gemm_op2_, gemm_ld1_, gemm_ld2_, gemm_ld3_) \ + do { \ + gemm_ops_[0] = gemm_op1_; \ + gemm_ops_[1] = gemm_op2_; \ + gemm_lds_[0] = gemm_ld1_; \ + gemm_lds_[1] = gemm_ld2_; \ + gemm_lds_[2] = gemm_ld3_; \ + } while (0) + +#define SET_GEMM_DIMS(gemm_dims_, gemm_dim1_, gemm_dim2_, gemm_dim3_) \ + do { \ + gemm_dims_[0] = gemm_dim1_; \ + gemm_dims_[1] = gemm_dim2_; \ + gemm_dims_[2] = gemm_dim3_; \ + } while (0) + // Multi Head Attention TensorRT op int MhaTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, const std::vector &out_tensors) { - if (in_tensors.size() != 8 && in_tensors.size() != 6) { - MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size(); + if (in_tensors.size() < 7 || in_tensors.size() > 9) { // T5 has 6 or 7 inputs, other models have 8 or 9 inputs + MS_LOG(ERROR) << "Unsupported number of inputs, size is " << in_tensors.size(); return RET_ERROR; } return RET_OK; @@ -59,13 +84,13 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) { // get attribute for Attn op - TODO - add attribute in op int head_number = mha_op->get_head_num(); int head_size = mha_op->get_head_size(); - int compute_type = 1; // mha_op->get_compute_type(); - int is_cross = mha_op->get_cross(); - nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; - - auto plugin = std::make_shared(input_tensor->getName(), compute_type, head_number, head_size, is_cross, - GetCublasHandle(), GetCublasLtHandle(), device_id_); + auto compute_type = runtime_->GetRuntimePrecisionMode(); // mha_op->get_compute_type(); + bool is_cross = mha_op->get_cross(); const int input_number = inputs().size(); + bool is_position_bias = (((input_number == 8) && is_cross) || ((input_number == 7) && !is_cross)) ? 
true : false; + nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_; + auto plugin = std::make_shared(input_tensor->getName(), compute_type, head_number, head_size, is_cross, + is_position_bias, GetCublasHandle(), GetCublasLtHandle(), device_id_); nvinfer1::ITensor *inputTensors[input_number]; for (int i = 0; i < input_number; i++) { inputTensors[i] = input(ctx, i).trt_tensor_; @@ -75,16 +100,46 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) { MS_LOG(ERROR) << "add mha op failed for TensorRT."; return RET_ERROR; } - mha_layer->setName(op_name_.c_str()); + mha_layer->setName((op_name_ + "plugin_attention").c_str()); // TODO(haim) one output nvinfer1::ITensor *attn_tensor = mha_layer->getOutput(0); +#ifndef TEST_ ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name()); +#else /* TEST_ */ + ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name() + "attn"); +#endif /* TEST_ */ // nvinfer1::ITensor *key_tensor = mha_layer->getOutput(1); // ctx->RegisterTensor(ITensorHelper{key_tensor, Format::NCHW, true}, out_tensors_[1].Name()); // nvinfer1::ITensor *value_tensor = mha_layer->getOutput(kTwo); // ctx->RegisterTensor(ITensorHelper{value_tensor, Format::NCHW, true}, out_tensors_[kTwo].Name()); this->layer_ = mha_layer; +#ifdef TEST_ + auto weight_projection = input(ctx, 4).trt_tensor_; + auto bias_projection = input(ctx, 6).trt_tensor_; +#endif /* TEST_ */ +#ifdef TEST_ + auto matmul_layer = ctx->network()->addMatrixMultiply(*attn_tensor, nvinfer1::MatrixOperation::kNONE, + *weight_projection, nvinfer1::MatrixOperation::kNONE); + if (matmul_layer == nullptr) { + MS_LOG(ERROR) << "failed to add matmul layer"; + return RET_ERROR; + } + matmul_layer->setName((op_name_ + "_matmul").c_str()); + auto matmul_tensor = matmul_layer->getOutput(0); + auto shuffle_layer = ctx->network()->addShuffle(*bias_projection); + const auto size = bias_projection->getDimensions().d[0]; + shuffle_layer->setReshapeDimensions(nvinfer1::Dims{2, {1, size}}); + auto shuffle_tensor = shuffle_layer->getOutput(0); + auto addbias = ctx->network()->addElementWise(*matmul_tensor, *shuffle_tensor, nvinfer1::ElementWiseOperation::kSUM); + if (addbias == nullptr) { + MS_LOG(ERROR) << "failed to add bias layer"; + return RET_ERROR; + } + addbias->setName((op_name_ + "_bias").c_str()); + auto bias_out = addbias->getOutput(0); + ctx->RegisterTensor(ITensorHelper{bias_out, Format::NCHW, true}, out_tensors_[0].Name()); +#endif /* TEST_ */ return RET_OK; } @@ -98,11 +153,80 @@ std::vector TensorRTPluginCreater::fields_; int MhaPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept { - return RunCudaMha(inputDesc, outputDesc, inputs, outputs, workspace, stream); + if (compute_type_ == RuntimePrecisionMode_FP16) { + return RunCudaMha(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm); + } else { + return RunCudaMha(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm); + } } +template +void MhaPlugin::SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len, + size_t extra_size) { + size_t qkv_len = size_q + (size_k * 2); // size_v is equal to size_k + size_t q_buf_2_len = size_q; + auto buff_size = + qkv_len + q_buf_2_len + qk_buf_len + (qkv_buf_2_len * 2); // qkv_buf_3_ len is equal to qkv_buf_2_len + qkv_buf_ = workspace; + 
q_buf_2_ = static_cast(qkv_buf_) + qkv_len; + qk_buf_ = static_cast(q_buf_2_) + q_buf_2_len; + qkv_buf_2_ = static_cast(qk_buf_) + qk_buf_len; + qkv_buf_3_ = static_cast(qkv_buf_2_) + qkv_buf_2_len; + output1_ = static_cast(workspace) + buff_size; + output2_ = static_cast(output1_) + extra_size; +} + +template +void MhaPlugin::RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims, + int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha, + void *beta, cublasGemmAlgo_t algoId, cudaStream_t stream) { + int cross_tensor_offset = 0; + if (is_cross_) cross_tensor_offset = 1; + const int from_tensor_idx = 0, encoder_tensor_idx = 1, weight_qkv_tensor_idx = 3; + const int weight_qkv_tensor_idx_cross = 3 + cross_tensor_offset; + const int bias_qkv_tensor_idx = 5 + cross_tensor_offset; + const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset; + + auto from_tensor = static_cast(inputs[from_tensor_idx]); + auto encoder_output_tensor = static_cast(inputs[encoder_tensor_idx]); + auto weight_q = static_cast(inputs[weight_qkv_tensor_idx]); + auto weight_kv = static_cast(inputs[weight_qkv_tensor_idx_cross]); + auto weight_qkv = static_cast(inputs[weight_qkv_tensor_idx_cross]); + auto bias_qkv = (is_position_bias_) ? nullptr : static_cast(inputs[bias_qkv_tensor_idx]); + + auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims; + const int request_batch_size = static_cast(inputDesc[attn_mask_tensor_idx].dims.d[0]); + const int request_src_seq_len = static_cast(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]); + const int request_tgt_seq_len = static_cast(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]); + auto hidden_size = static_cast(head_number_ * head_size_); + if (is_cross_) { + SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size); + SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size); + CublasGemmWrapper(weight_q, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops, + const_cast(gemm_data_types), alpha, beta, cublas_handle_); + SET_GEMM_DIMS(gemm_dims, C2NUM * hidden_size, request_batch_size * request_tgt_seq_len, hidden_size); + gemm_lds[0] = gemm_lds[THIRD_INPUT] = C2NUM * hidden_size; + + CublasGemmWrapper(weight_kv, encoder_output_tensor, + static_cast(qkv_buf_) + (request_batch_size * request_src_seq_len) * hidden_size, gemm_dims, + gemm_lds, gemm_ops, const_cast(gemm_data_types), alpha, beta, + cublas_handle_); + fastertransformer::invokeCrossAddFusedQKVBiasTranspose( + static_cast(q_buf_2_), static_cast(output1_), static_cast(output2_), static_cast(qkv_buf_), + bias_qkv, request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_, head_size_, stream); + } else { + CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops, + const_cast(gemm_data_types), alpha, beta, cublas_handle_, algoId); + fastertransformer::invokeAddFusedQKVBiasTranspose( + static_cast(q_buf_2_), static_cast(output1_), static_cast(output2_), static_cast(qkv_buf_), + bias_qkv, request_batch_size, request_src_seq_len, head_number_, head_size_, 0, stream); + } +} + +template int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) { + const void *const *inputs, void *const *outputs, void *workspace, 
cudaStream_t stream, + cublasGemmAlgo_t *algoId) { // inputs order: // 0] Q // 1] K @@ -112,135 +236,100 @@ int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvi // 5] B // 6] PB // 7] AttnMask - + // inputs order cross: + // 0] Q + // 1] K enco output + // 2] V + // 3] Wq + // 4] Wkv + // 5] PW + // 6] Bqkv + // 7] PB + // 8] AttnMask + int cross_tensor_offset = 0; cublasSetStream(cublas_handle_, stream); + if (is_cross_) cross_tensor_offset = 1; + const int weight_projection_tensor_idx = 4 + cross_tensor_offset; + const int bias_projection_tensor_idx = 6 + cross_tensor_offset; + const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset; + const int bias_position_tensor_idx = 5 + cross_tensor_offset; - // TODO(Haim) - Fix tensor ids according to cross flag - const int from_tensor_idx = 0; - // const int encoder_tensor_idx = 1; - const int weight_qkv_tensor_idx = 3; - const int weight_projection_tensor_idx = 4; - const int bias_qkv_tensor_idx = 5; - const int bias_projection_tensor_idx = 6; - const int attn_mask_tensor_idx = 7; - - auto from_tensor = static_cast(inputs[from_tensor_idx]); - auto attention_mask = static_cast(inputs[attn_mask_tensor_idx]); - auto weight_qkv = static_cast(inputs[weight_qkv_tensor_idx]); - auto bias_qkv = static_cast(inputs[bias_qkv_tensor_idx]); - auto weight_projection = static_cast(inputs[weight_projection_tensor_idx]); - auto bias_projection = static_cast(inputs[bias_projection_tensor_idx]); - - auto output0 = static_cast(outputs[0]); - // auto output1 = static_cast(outputs[1]); - // auto output2 = static_cast(outputs[2]); - + auto attention_mask = static_cast(inputs[attn_mask_tensor_idx]); + auto weight_projection = static_cast(inputs[weight_projection_tensor_idx]); + auto bias_projection = (is_position_bias_) ? nullptr : static_cast(inputs[bias_projection_tensor_idx]); + auto bias_position = (is_position_bias_) ? 
static_cast(inputs[bias_position_tensor_idx]) : nullptr; + auto output0 = static_cast(outputs[0]); auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims; const int request_batch_size = static_cast(inputDesc[attn_mask_tensor_idx].dims.d[0]); const int request_src_seq_len = static_cast(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]); const int request_tgt_seq_len = static_cast(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]); auto hidden_size = static_cast(head_number_ * head_size_); - - // TODO(NIZZAN): fix allocator - size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len; + auto extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len; size_t size_q = request_batch_size * request_src_seq_len * hidden_size; size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size; - size_t size_v = size_k; - - size_t qkv_len = size_q + size_k + size_v; - size_t q_buf_2_len = size_q; size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len; size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size; - size_t qkv_buf_3_len = qkv_buf_2_len; - size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len; - qkv_buf_ = workspace; - q_buf_2_ = static_cast(qkv_buf_) + qkv_len; - qk_buf_ = static_cast(q_buf_2_) + q_buf_2_len; - qkv_buf_2_ = static_cast(qk_buf_) + qk_buf_len; - qkv_buf_3_ = static_cast(qkv_buf_2_) + qkv_buf_2_len; - output1_ = static_cast(workspace) + buff_size; - output2_ = static_cast(output1_) + extra_tmp_size; + SetInnerAddr(workspace, size_q, size_k, qk_buf_len, qkv_buf_2_len, extra_tmp_size); - int gemm_dims[3] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size}; - int gemm_lds[3] = {3 * hidden_size, hidden_size, 3 * hidden_size}; + cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; + cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; + if constexpr (std::is_same::value) + std::fill(std::begin(gemm_data_types), std::end(gemm_data_types), CUDA_R_16F); + float alpha = 1.0f, beta = 0.0f; + int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size}; + int gemm_lds[] = {3 * hidden_size, hidden_size, 3 * hidden_size}; - cublasOperation_t gemm_ops[2] = {CUBLAS_OP_N, CUBLAS_OP_N}; - const cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; - float alpha = 1.0f; - float beta = 0.0f; - - CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta, - cublas_handle_); - - fastertransformer::invokeAddFusedQKVBiasTranspose(static_cast(q_buf_2_), static_cast(output1_), - static_cast(output2_), static_cast(qkv_buf_), - bias_qkv, request_batch_size, request_src_seq_len, head_number_, - head_size_, 0, stream); - gemm_ops[0] = CUBLAS_OP_T; - gemm_ops[1] = CUBLAS_OP_N; - gemm_dims[0] = request_tgt_seq_len; - gemm_dims[1] = request_src_seq_len; - gemm_dims[THIRD_INPUT] = head_size_; - - gemm_lds[0] = head_size_; - gemm_lds[1] = head_size_; - gemm_lds[THIRD_INPUT] = request_tgt_seq_len; + RunPhase1GEMM(inputDesc, inputs, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta, algoId[0], stream); + SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_T, CUBLAS_OP_N, head_size_, head_size_, request_tgt_seq_len); + SET_GEMM_DIMS(gemm_dims, request_tgt_seq_len, request_src_seq_len, head_size_); int gemm_strides[] = {request_tgt_seq_len * head_size_, request_src_seq_len * 
head_size_, request_src_seq_len * request_tgt_seq_len}; + CublasGemmStridedBatchedWrapper(output1_, q_buf_2_, qk_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_strides, - gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_); + const_cast(gemm_data_types), &alpha, &beta, + request_batch_size * head_number_, cublas_handle_, algoId[1]); - float scalar = (1.0f / sqrtf(static_cast(head_size_) * 1.0f)); - fastertransformer::invokeMixMaskedSoftMax(static_cast(qk_buf_), attention_mask, request_batch_size, - request_src_seq_len, request_tgt_seq_len, head_number_, scalar, stream); - gemm_ops[0] = CUBLAS_OP_N; - gemm_ops[1] = CUBLAS_OP_N; - gemm_dims[0] = head_size_; - gemm_dims[1] = request_src_seq_len; - gemm_dims[THIRD_INPUT] = request_tgt_seq_len; - - gemm_lds[0] = head_size_; - gemm_lds[1] = request_tgt_seq_len; - gemm_lds[THIRD_INPUT] = head_size_; - - gemm_strides[0] = request_tgt_seq_len * head_size_; + T scalar = static_cast(1.0f / sqrtf(head_size_ * 1.0f)); + fastertransformer::invokeMixMaskedSoftMax(static_cast(qk_buf_), attention_mask, bias_position, + request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_, + scalar, stream); + SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, head_size_, request_tgt_seq_len, head_size_); + SET_GEMM_DIMS(gemm_dims, head_size_, request_src_seq_len, request_tgt_seq_len); gemm_strides[1] = request_src_seq_len * request_tgt_seq_len; gemm_strides[THIRD_INPUT] = request_src_seq_len * head_size_; CublasGemmStridedBatchedWrapper(output2_, qk_buf_, qkv_buf_2_, gemm_dims, gemm_lds, gemm_ops, gemm_strides, - gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_); - - fastertransformer::invokeTransposeQKV(static_cast(qkv_buf_3_), static_cast(qkv_buf_2_), - request_batch_size, request_src_seq_len, head_number_, head_size_, stream); - - gemm_ops[0] = CUBLAS_OP_N; - gemm_ops[1] = CUBLAS_OP_N; - gemm_dims[0] = hidden_size; - gemm_dims[1] = request_batch_size * request_src_seq_len; - gemm_dims[THIRD_INPUT] = hidden_size; - - gemm_lds[0] = hidden_size; - gemm_lds[1] = hidden_size; - gemm_lds[THIRD_INPUT] = hidden_size; - - CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, - &beta, cublas_handle_); - int len = request_batch_size * request_src_seq_len; - fastertransformer::invokeAddBias(reinterpret_cast(output0), reinterpret_cast(bias_projection), - len, hidden_size, stream); - + const_cast(gemm_data_types), &alpha, &beta, + request_batch_size * head_number_, cublas_handle_, algoId[2]); + fastertransformer::invokeTransposeQKV(static_cast(qkv_buf_3_), static_cast(qkv_buf_2_), request_batch_size, + request_src_seq_len, head_number_, head_size_, stream); + SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size); + SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size); + CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops, + const_cast(gemm_data_types), &alpha, &beta, cublas_handle_, algoId[3]); + if (!is_position_bias_) { + int len = request_batch_size * request_src_seq_len; + fastertransformer::invokeAddBias(reinterpret_cast(output0), reinterpret_cast(bias_projection), len, + hidden_size, stream); + } return RET_OK; } -int MhaPlugin::RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, 
cudaStream_t stream) { - // Add Cross Mha Layer here - return RET_OK; +bool MhaPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, + int nbOutputs) noexcept { + auto type = (compute_type_ == RuntimePrecisionMode_FP16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; + for (int i = 0; i < pos; i++) { + if (tensorsDesc[pos].type != tensorsDesc[i].type) return false; + } + bool res = (tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) && (tensorsDesc[pos].type == type); + return res; } +void MhaPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {} + size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept { auto attn_dim_size = inputs[nbInputs - 1].dims.nbDims; @@ -250,9 +339,6 @@ size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int auto hidden_size = static_cast(head_number_ * head_size_); // TODO(NIZZAN) Fix efficient allocator - // size_t buff_size = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len + - // request_batch_size * request_src_seq_len * hidden_size; - size_t size_q = request_batch_size * request_src_seq_len * hidden_size; size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size; size_t size_v = size_k; @@ -265,8 +351,12 @@ size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len; size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len; + int elem_size = sizeof(float); + if (compute_type_ == RuntimePrecisionMode_FP16) { + elem_size = sizeof(half); + } - return (buff_size + extra_tmp_size + extra_tmp_size) * sizeof(float); + return (buff_size + extra_tmp_size + extra_tmp_size) * elem_size; } nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims, @@ -282,14 +372,15 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1 // value_cache [batch, head_num, tgt_seq_len, size_per_head] nvinfer1::DimsExprs dims; if (index == 0) { - // if (inputs[0].nbDims == 2) { - // dims.nbDims = INPUT_SIZE2; - // dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1]; - // auto hidden_size = exprBuilder.constant(head_size_ * head_number_); - // dims.d[1] = hidden_size; - // } else - { - dims.nbDims = INPUT_SIZE3; +#ifndef TEST_ + int num_dims = inputs[0].nbDims; + dims.nbDims = num_dims; + if (num_dims == INPUT_SIZE2) { + dims.d[0] = exprBuilder.constant(inputs[nbInputDims - 1].d[0]->getConstantValue() * + inputs[nbInputDims - 1].d[1]->getConstantValue()); + auto hidden_size = exprBuilder.constant(head_size_ * head_number_); + dims.d[1] = hidden_size; + } else if (num_dims == INPUT_SIZE3) { dims.d[0] = inputs[nbInputDims - 1].d[0]; // batch dims.d[1] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1]; auto hidden_size = exprBuilder.constant(head_size_ * head_number_); @@ -303,6 +394,13 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1 dims.d[kTwo] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1]; dims.d[kThree] = exprBuilder.constant(head_size_); } +#else + dims.nbDims = C2NUM; + dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 
1].nbDims) - 1]; + auto hidden_size = exprBuilder.constant(head_size_ * head_number_); + dims.d[1] = hidden_size; + } +#endif return dims; } @@ -316,6 +414,8 @@ nvinfer1::IPluginV2DynamicExt *MhaPlugin::clone() const noexcept { return plugin; } +int MhaPlugin::initialize() noexcept { return 0; } + void MhaPlugin::terminate() noexcept {} size_t MhaPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); } diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.h b/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.h index a5d0ce68235..c3f2c279179 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/op/mha_tensorrt.h @@ -31,6 +31,8 @@ class MhaTensorRT : public TensorRTOp { : TensorRTOp(base_operator, in_tensors, out_tensors, name) {} ~MhaTensorRT() override = default; + + // bool IsWeightInputHanledInner() const override { return true; } int AddInnerOp(TensorRTContext *ctx) override; int IsSupport(const BaseOperatorPtr &base_operator, const std::vector &in_tensors, @@ -40,13 +42,14 @@ class MhaTensorRT : public TensorRTOp { constexpr auto MHA_PLUGIN_NAME{"AttentionPlugin"}; class MhaPlugin : public TensorRTPlugin { public: - MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, int is_cross, - cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id) + MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, bool is_cross, + bool is_position_bias, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id) : TensorRTPlugin(name, std::string(MHA_PLUGIN_NAME), device_id), compute_type_(compute_type), head_number_(head_number), head_size_(head_size), is_cross_(is_cross), + is_position_bias_(is_position_bias), cublas_handle_(cublas_handle), cublaslt_handle_(cublaslt_handle) {} @@ -57,6 +60,7 @@ class MhaPlugin : public TensorRTPlugin { head_number_ = static_cast(fields[1].data)[0]; head_size_ = static_cast(fields[2].data)[0]; is_cross_ = static_cast(fields[3].data)[0]; + is_position_bias_ = static_cast(fields[4].data)[0]; } MhaPlugin(const char *name, const void *serialData, size_t serialLength) @@ -65,6 +69,7 @@ class MhaPlugin : public TensorRTPlugin { DeserializeValue(&serialData, &serialLength, &head_number_, sizeof(int)); DeserializeValue(&serialData, &serialLength, &head_size_, sizeof(int)); DeserializeValue(&serialData, &serialLength, &is_cross_, sizeof(int)); + DeserializeValue(&serialData, &serialLength, &is_position_bias_, sizeof(int)); } MhaPlugin() = delete; @@ -82,21 +87,37 @@ class MhaPlugin : public TensorRTPlugin { const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override; nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims, nvinfer1::IExprBuilder &exprBuilder) noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs, + int nbOutputs) noexcept override; void terminate() noexcept override; + int initialize() noexcept override; private: bool needResize(const int *current_dims, const int *last_dims); + template int RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void 
*const *outputs, void *workspace, cudaStream_t stream); - int RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream); + const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream, + cublasGemmAlgo_t *algoId); + template + void SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len, + size_t extra_size); + template + void RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims, + int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha, void *beta, + cublasGemmAlgo_t algoId, cudaStream_t stream); const std::string layer_name_; std::string name_space_; int compute_type_; int head_number_; int head_size_; - int is_cross_; + bool is_cross_; + bool is_position_bias_; + cublasGemmAlgo_t fast_algo_gemm[4] = {CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP, + CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP}; + cublasHandle_t cublas_handle_; cublasLtHandle_t cublaslt_handle_; void *qkv_buf_{nullptr}; diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc index c1e28e7a37e..a52b3c53595 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_subgraph.cc @@ -27,6 +27,7 @@ #include #include #include "src/extendrt/delegate/delegate_utils.h" +#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h" #include "src/common/utils.h" #include "ops/transpose.h" @@ -68,6 +69,11 @@ TensorRTSubGraph::~TensorRTSubGraph() { config_->destroy(); config_ = nullptr; } +#ifdef PROFILER_ + auto profile = dynamic_cast(trt_context_->getProfiler()); + if (profile != nullptr) std::cout << *profile << std::endl; + delete profile; +#endif if (trt_context_ != nullptr) { trt_context_->destroy(); trt_context_ = nullptr; @@ -494,6 +500,15 @@ int TensorRTSubGraph::Prepare() { MS_LOG(ERROR) << "TensorRTSubGraph create context failed."; return RET_ERROR; } + +#ifdef PROFILER_ + auto profiler = new SimpleProfiler("myprofiler"); + if (profiler == nullptr) { + MS_LOG(WARNING) << "Cannot create profiler"; + } + this->trt_context_->setProfiler(profiler); +#endif + int binding_num = this->engine_->getNbBindings(); if (binding_num <= 0) { MS_LOG(ERROR) << "TensorRTSubGraph binding num < 0."; @@ -504,7 +519,6 @@ int TensorRTSubGraph::Prepare() { MS_LOG(ERROR) << "malloc tensor binding array failed."; return RET_ERROR; } - profile_index_ = MaxVolumnProfileIndex(); if (this->trt_context_->setOptimizationProfile(profile_index_)) { MS_LOG(INFO) << "setOptimizationProfile: " << profile_index_; @@ -527,9 +541,6 @@ int TensorRTSubGraph::Prepare() { MS_LOG(INFO) << "device index " << index << " for tensor : " << tensor_name << " attr: " << device_ptr; tensor_bindings_[index] = device_ptr; nvinfer1::Dims input_dims = ConvertCudaDims(profile.inputs[i].max_dims); - for (int od = 0; od < input_dims.nbDims; od++) { - MS_LOG(DEBUG) << "in tensor " << tensor.Name() << " dims at " << od << " is " << input_dims.d[od]; - } if (!this->trt_context_->setBindingDimensions(index, input_dims)) { MS_LOG(ERROR) << "invalid input dims of " << tensor.Name(); return RET_ERROR; @@ -781,7 +792,6 @@ int TensorRTSubGraph::PostExecute(std::vector *outputs) { // actual output tensor dims 
auto out_dims = this->trt_context_->getBindingDimensions(index); std::vector new_shape = lite::ConvertMSShape(out_dims); - outputs_[i].SetShape(new_shape); for (int od = 0; od < out_dims.nbDims; od++) { MS_LOG(DEBUG) << "out tensor " << trt_out_tensor_name << " dims at " << od << " is " << new_shape[od]; } diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.cc b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.cc index 2fcace8e43e..07b3e9e3241 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.cc +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.cc @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include "src/extendrt/delegate/tensorrt/op/cast_plugin.h" #include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h" @@ -176,6 +178,7 @@ nvinfer1::DataType ConvertDataType(DataType type_id) { {DataType::kNumberTypeInt32, nvinfer1::DataType::kINT32}, {DataType::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT}, {DataType::kNumberTypeFloat16, nvinfer1::DataType::kHALF}, + {DataType::kNumberTypeInt64, nvinfer1::DataType::kINT32}, }; auto iter = data_type_map.find(type_id); nvinfer1::DataType data_type; @@ -895,4 +898,66 @@ nvinfer1::DataType GetNvinferDataType() { template nvinfer1::DataType GetNvinferDataType(); template nvinfer1::DataType GetNvinferDataType(); + +#ifdef PROFILER_ +void SimpleProfiler::reportLayerTime(const char *layerName, float ms) noexcept { + mProfile_[layerName].count++; + mProfile_[layerName].time += ms; + if (std::find(mLayerNames_.begin(), mLayerNames_.end(), layerName) == mLayerNames_.end()) { + mLayerNames_.push_back(layerName); + } +} + +SimpleProfiler::SimpleProfiler(const char *name, const std::vector &srcProfilers) : mName_(name) { + for (const auto &srcProfiler : srcProfilers) { + for (const auto &rec : srcProfiler.mProfile_) { + auto it = mProfile_.find(rec.first); + if (it == mProfile_.end()) { + mProfile_.insert(rec); + } else { + it->second.time += rec.second.time; + it->second.count += rec.second.count; + } + } + } +} + +std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value) { + out << "========== " << value.mName_ << " profile ==========" << std::endl; + float totalTime = 0; + std::string layerNameStr = "TensorRT layer name"; + int maxLayerNameLength = std::max(static_cast(layerNameStr.size()), 70); + for (const auto &elem : value.mProfile_) { + totalTime += elem.second.time; + maxLayerNameLength = std::max(maxLayerNameLength, static_cast(elem.first.size())); + } + + auto old_settings = out.flags(); + auto old_precision = out.precision(); + // Output header + { + out << std::setw(maxLayerNameLength) << layerNameStr << " "; + out << std::setw(C12NUM) << "Runtime, " + << "%" + << " "; + out << std::setw(C12NUM) << "Invocations" + << " "; + out << std::setw(C12NUM) << "Runtime, ms" << std::endl; + } + for (size_t i = 0; i < value.mLayerNames_.size(); i++) { + const std::string layerName = value.mLayerNames_[i]; + auto elem = value.mProfile_.at(layerName); + out << std::setw(maxLayerNameLength) << layerName << " "; + out << std::setw(C12NUM) << std::fixed << std::setprecision(1) << (elem.time * 100.0F / totalTime) << "%" + << " "; + out << std::setw(C12NUM) << elem.count << " "; + out << std::setw(C12NUM) << std::fixed << std::setprecision(C2NUM) << elem.time << std::endl; + } + out.flags(old_settings); + out.precision(old_precision); + out << "========== " << value.mName_ << " total runtime = " << totalTime << " ms ==========" << 
std::endl; + + return out; +} +#endif // PROFILER_ } // namespace mindspore::lite diff --git a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h index 9389c793c42..18baf21654f 100644 --- a/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h +++ b/mindspore/lite/src/extendrt/delegate/tensorrt/tensorrt_utils.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "src/extendrt/delegate/tensorrt/tensorrt_context.h" #include "src/extendrt/delegate/tensorrt/tensor_info.h" #include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h" @@ -57,6 +58,28 @@ typedef union float32_bits { float f; } float32_bits; +// #define PROFILER_ +#ifdef PROFILER_ +struct SimpleProfiler : public nvinfer1::IProfiler { + struct Record { + float time{0}; + int count{0}; + }; + + virtual void reportLayerTime(const char *layerName, float ms) noexcept; + + explicit SimpleProfiler(const char *name, + const std::vector &srcProfilers = std::vector()); + + friend std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value); + + private: + std::string mName_; + std::vector mLayerNames_; + std::map mProfile_; +}; +#endif + // Convert Tensor data to Cuda dims. nvinfer1::Dims ConvertCudaDims(const std::vector &data); @@ -194,4 +217,4 @@ void Data2Vector(std::vector *dst, const void *src) { } } } // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_UTILS_H_ +#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_UTILS_H_ diff --git a/mindspore/lite/src/extendrt/mock/lite_runtime/converters.cc b/mindspore/lite/src/extendrt/mock/lite_runtime/converters.cc index 2bc937d3f58..5daa1288b2c 100644 --- a/mindspore/lite/src/extendrt/mock/lite_runtime/converters.cc +++ b/mindspore/lite/src/extendrt/mock/lite_runtime/converters.cc @@ -64,7 +64,7 @@ Status ContextUtils::AddGpuDevice(bool enable_fp16, uint32_t device_id, int rank Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_context) { lite::DeviceInfo device_info; - device_info.npu_device_info_ = {false, frequency}; + device_info.npu_device_info_.frequency_ = frequency; inner_context->device_list_.push_back({lite::DT_NPU, device_info}); return kSuccess; } diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc index 398a476bbbf..e98f9552697 100644 --- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc +++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc @@ -42,6 +42,7 @@ #include #include "src/common/config_file.h" #endif + namespace mindspore { constexpr size_t kDataToStringMaxNum = 40; constexpr int kPrintDataNum = 20; diff --git a/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.cc b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.cc index 75361725224..c9315442c8f 100644 --- a/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "tools/optimizer/common/gllo_utils.h" #include "nnacl/op_base.h" #include "ops/tuple_get_item.h" @@ -32,32 +33,34 @@ const int kAttentionOutputs = 3; } // namespace bool MultiHeadAttentionFusion::Init() const { - input_q_ = std::make_shared(); + input_q_ = std::make_shared("input_q"); MS_CHECK_TRUE_RET(input_q_ != nullptr, false); - input_k_ = std::make_shared(); + input_k_ = std::make_shared("input_k"); 
MS_CHECK_TRUE_RET(input_k_ != nullptr, false); - input_v_ = std::make_shared(); + input_v_ = std::make_shared("input_v"); MS_CHECK_TRUE_RET(input_v_ != nullptr, false); + position_bias_ = std::make_shared("position_bias_"); + MS_CHECK_TRUE_RET(position_bias_ != nullptr, false); - weight_q_ = std::make_shared(IsParamNode); + weight_q_ = std::make_shared(IsParamNode, "weight_q"); MS_CHECK_TRUE_RET(weight_q_ != nullptr, false); - weight_k_ = std::make_shared(IsParamNode); + weight_k_ = std::make_shared(IsParamNode, "weight_k"); MS_CHECK_TRUE_RET(weight_k_ != nullptr, false); - weight_v_ = std::make_shared(IsParamNode); + weight_v_ = std::make_shared(IsParamNode, "weight_v"); MS_CHECK_TRUE_RET(weight_v_ != nullptr, false); weight_o_ = std::make_shared(IsParamNode); MS_CHECK_TRUE_RET(weight_o_ != nullptr, false); - bias_q_ = std::make_shared(IsParamNode); + bias_q_ = std::make_shared(IsParamNode, "bias_q"); MS_CHECK_TRUE_RET(bias_q_ != nullptr, false); - bias_k_ = std::make_shared(IsParamNode); + bias_k_ = std::make_shared(IsParamNode, "bias_k"); MS_CHECK_TRUE_RET(bias_k_ != nullptr, false); - bias_v_ = std::make_shared(IsParamNode); + bias_v_ = std::make_shared(IsParamNode, "bias_v"); MS_CHECK_TRUE_RET(bias_v_ != nullptr, false); bias_o_ = std::make_shared(IsParamNode); MS_CHECK_TRUE_RET(bias_o_ != nullptr, false); - mask_ = std::make_shared(); + mask_ = std::make_shared("mask"); MS_CHECK_TRUE_RET(mask_ != nullptr, false); reshape_k_ = std::make_shared("reshape_k"); @@ -77,20 +80,43 @@ namespace { VectorRef DefineMask(const BaseRef &mask_input) { auto is_expand_dims = std::make_shared(std::bind(IsOpType, p1, prim::kPrimExpandDims)); MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {}); - auto var1 = std::make_shared(); + auto var1 = std::make_shared("m-var1"); MS_CHECK_TRUE_RET(var1 != nullptr, {}); auto expand_dims = VectorRef({is_expand_dims, mask_input, var1}); - auto is_sub = std::make_shared(std::bind(IsOpType, p1, prim::kPrimSubFusion)); + auto is_sub = std::make_shared(std::bind(IsOpType, p1, prim::kPrimSubFusion), "m-sub"); MS_CHECK_TRUE_RET(is_sub != nullptr, {}); - auto var2 = std::make_shared(); + auto var2 = std::make_shared("m-var2"); MS_CHECK_TRUE_RET(var2 != nullptr, {}); auto sub = VectorRef({is_sub, var2, expand_dims}); - auto is_mul = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMulFusion)); + auto is_mul = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMulFusion), "m-mul"); MS_CHECK_TRUE_RET(is_mul != nullptr, {}); - auto var3 = std::make_shared(); + auto var3 = std::make_shared("m-var3"); MS_CHECK_TRUE_RET(var3 != nullptr, {}); return VectorRef({is_mul, sub, var3}); } +STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector *result) { + if (param_ptr == nullptr || !param_ptr->has_default()) { + MS_LOG(DEBUG) << "param not have default"; + return RET_ERROR; + } + auto default_param = param_ptr->default_param(); + if (default_param == nullptr || !utils::isa(default_param)) { + MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr"; + return RET_ERROR; + } + auto default_param_ptr = utils::cast(default_param); + if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) { + MS_LOG(DEBUG) << "default param is not int"; + return RET_ERROR; + } + auto ptr = reinterpret_cast(default_param_ptr->data_c()); + int64_t shape_size = + std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>()); + for (int64_t i = 0; i < shape_size; i++) { + result->emplace_back(ptr[i]); 
+ } + return RET_OK; +} STATUS GetAxis(const BaseRef &n, std::vector *axes) { MS_ASSERT(axes != nullptr); @@ -99,8 +125,14 @@ STATUS GetAxis(const BaseRef &n, std::vector *axes) { *axes = CastToInt(axes_value_node->value()); return lite::RET_OK; } else { - MS_LOG(ERROR) << "GetAxis supports only value node"; + auto reshape = utils::cast(n); + if (reshape != nullptr) { + if (GetIntParameterData(reshape, axes) == lite::RET_OK) { + return lite::RET_OK; + } + } } + MS_LOG(ERROR) << " cannot get axes data"; return lite::RET_ERROR; } } // namespace @@ -132,23 +164,42 @@ VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const return conn; } -VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mask) const { +VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis, + const BaseRef &transpose_var, bool test_div, bool transpose) const { + auto is_matmul = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "e-matmul"); + MS_CHECK_TRUE_RET(is_matmul != nullptr, {}); + auto dense = VectorRef({is_matmul, input, weight}); + auto is_reshape = std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape"); + MS_CHECK_TRUE_RET(is_reshape != nullptr, {}); + auto reshape = VectorRef({is_reshape, dense, axis}); + auto var2 = std::make_shared(); + VectorRef conn; + if (transpose) { + conn = VectorRef({transpose_var, reshape, var2}); + } else { + conn = reshape; + } + if (test_div) { + auto is_div = std::make_shared(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div"); + MS_CHECK_TRUE_RET(is_div != nullptr, {}); + auto var3 = std::make_shared(); + MS_CHECK_TRUE_RET(var3 != nullptr, {}); + auto div = VectorRef({is_div, conn, var3}); + return div; + } + return conn; +} + +VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool mask) const { VectorRef k_embedding, v_embedding; auto q_transpose = std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)); MS_CHECK_TRUE_RET(q_transpose != nullptr, {}); auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true); MS_CHECK_TRUE_RET(!q_embedding.empty(), {}); - if (!cross) { - k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true); - MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); - v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_); - MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); - } else { - k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true); - MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); - v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_); - MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); - } + k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true); + MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); + v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_); + MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); auto is_matmul1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMulFusion)); MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {}); auto is_reshape1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape)); @@ -194,23 +245,16 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mas return matmul3; } -VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const { +VectorRef 
MultiHeadAttentionFusion::DefineMPWithMaskPatternT5() const { VectorRef k_embedding, v_embedding; auto q_transpose = std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose"); MS_CHECK_TRUE_RET(q_transpose != nullptr, {}); auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true, false); MS_CHECK_TRUE_RET(!q_embedding.empty(), {}); - if (!cross) { - k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true); - MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); - v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_); - MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); - } else { - k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true); - MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); - v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_); - MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); - } + k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true); + MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); + v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_, false, false); + MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); auto is_matmul1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1"); MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {}); auto is_reshape1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1"); @@ -246,23 +290,79 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const return matmul3; } -VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const { +VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose) const { + VectorRef k_embedding, v_embedding; + auto q_transpose = std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose"); + MS_CHECK_TRUE_RET(q_transpose != nullptr, {}); + VectorRef q_embedding; + if (transpose) { + q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true); + } else { + q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, false); + } + MS_CHECK_TRUE_RET(!q_embedding.empty(), {}); + k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, true, true); + MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); + v_embedding = DefineEmbedding(input_v_, weight_v_, reshape_v_, v_transpose_, false, true); + MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); + auto is_matmul1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1"); + MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {}); + auto is_reshape1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1"); + MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {}); + auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding}); + auto var1 = std::make_shared("var1"); + MS_CHECK_TRUE_RET(var1 != nullptr, {}); + auto is_add1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add1"); + MS_CHECK_TRUE_RET(is_add1 != nullptr, {}); + auto add1 = VectorRef({is_add1, matmul1, position_bias_}); + auto is_add2 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add2"); + MS_CHECK_TRUE_RET(is_add2 != nullptr, {}); + auto mask2 = DefineMask(mask_); + MS_CHECK_TRUE_RET(!mask2.empty(), {}); + auto add2 = VectorRef({is_add1, mask2, add1}); + auto reshape1 = VectorRef({is_reshape1, add2, var1}); + auto is_softmax = 
std::make_shared(std::bind(IsOpType, p1, prim::kPrimSoftmax), "softmax"); + MS_CHECK_TRUE_RET(is_softmax != nullptr, {}); + auto softmax = VectorRef({is_softmax, reshape1}); + auto is_reshape2 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape2"); + MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {}); + auto var2 = std::make_shared("var2"); + MS_CHECK_TRUE_RET(var2 != nullptr, {}); + auto reshape2 = VectorRef({is_reshape2, softmax, var2}); + auto is_matmul2 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul2"); + MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {}); + auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding}); + auto is_reshape3 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape3"); + MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {}); + auto var4 = std::make_shared("var4"); + MS_CHECK_TRUE_RET(var4 != nullptr, {}); + VectorRef reshape3; + if (transpose) { + auto is_transpose = std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose), "transpose"); + MS_CHECK_TRUE_RET(is_transpose != nullptr, {}); + auto var3 = std::make_shared("var3"); + MS_CHECK_TRUE_RET(var3 != nullptr, {}); + auto transpose1 = VectorRef({is_transpose, matmul2, var3}); + reshape3 = VectorRef({is_reshape3, transpose1, var4}); + } else { + reshape3 = VectorRef({is_reshape3, matmul2, var4}); + } + + auto is_matmul3 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul3"); + MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {}); + auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_}); + return matmul3; +} +VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const { VectorRef k_embedding, v_embedding; auto q_transpose = std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)); MS_CHECK_TRUE_RET(q_transpose != nullptr, {}); auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true); MS_CHECK_TRUE_RET(!q_embedding.empty(), {}); - if (!cross) { - k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true); - MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); - v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_); - MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); - } else { - k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true); - MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); - v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_); - MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); - } + k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true); + MS_CHECK_TRUE_RET(!k_embedding.empty(), {}); + v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_); + MS_CHECK_TRUE_RET(!v_embedding.empty(), {}); auto is_matmul1 = std::make_shared(std::bind(IsOpType, p1, prim::kPrimMatMulFusion)); MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {}); auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding}); @@ -296,13 +396,14 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const } namespace { +template STATUS TransposeMatrix(std::shared_ptr src, std::shared_ptr dst) { MS_CHECK_TRUE_RET(src->shape().size() == C2NUM, RET_ERROR); MS_CHECK_TRUE_RET(dst->shape().size() == C2NUM, RET_ERROR); int rows = src->shape().at(0); int cols = src->shape().at(1); - auto src_ptr = reinterpret_cast(src->data_c()); - auto dst_ptr = reinterpret_cast(dst->data_c()); + auto src_ptr = 
reinterpret_cast(src->data_c()); + auto dst_ptr = reinterpret_cast(dst->data_c()); for (int r = 0; r < rows; ++r) { for (int c = 0; c < cols; ++c) { auto val = src_ptr[r * cols + c]; @@ -354,8 +455,20 @@ std::shared_ptr ConcatTensors(const std::vector tshape = {new_shape[1], new_shape[0]}; auto transposed_tensor = std::make_shared(base_data_type, tshape); - auto status = TransposeMatrix(concat_tensor, transposed_tensor); - MS_CHECK_TRUE_RET(status == RET_OK, nullptr); + switch (base_data_type) { + case kNumberTypeFloat32: { + auto status = TransposeMatrix(concat_tensor, transposed_tensor); + MS_CHECK_TRUE_RET(status == RET_OK, nullptr); + break; + } + case kNumberTypeFloat16: { + auto status = TransposeMatrix(concat_tensor, transposed_tensor); + MS_CHECK_TRUE_RET(status == RET_OK, nullptr); + break; + } + default: + MS_LOG(ERROR) << "unsupported data type " << base_data_type << std::endl; + } return transposed_tensor; } return concat_tensor; @@ -368,14 +481,13 @@ std::unordered_map MultiHeadAttentionFusion::DefinePatte MS_LOG(ERROR) << "initial member failed."; return patterns; } + patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern(); - patterns[kMPAXWithMaskPatternName] = DefineMPWithMaskPattern(true); - patterns[kMPAPatternName] = DefineMPWithMaskPattern(false, false); - patterns[kMPAXPatternName] = DefineMPWithMaskPattern(true, false); + patterns[kMPAPatternName] = DefineMPWithMaskPattern(false); patterns[kMPAWithMaskPatternNamePA] = DefineMPWithMaskPatternPA(); - patterns[kMPAXWithMaskPatternNamePA] = DefineMPWithMaskPatternPA(true); patterns[kMPAWithMaskPatternNameT5] = DefineMPWithMaskPatternT5(); - patterns[kMPAXWithMaskPatternNameT5] = DefineMPWithMaskPatternT5(true); + patterns[kMPAWithMaskPatternNameT5New] = DefineMPWithMaskPatternT5New(false); + patterns[kMPAWithMaskTransposePatternNameT5New] = DefineMPWithMaskPatternT5New(); return patterns; } @@ -384,14 +496,24 @@ bool MultiHeadAttentionFusion::CheckPattern(const EquivPtr &equiv, int *head_num MS_ASSERT(head_num != nullptr); MS_ASSERT(head_size != nullptr); std::vector reshape_axes; + // UNDO !!!!!!!!! 
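For reference, the FP32/FP16 dispatch that the hunk above adds around TransposeMatrix can be sketched in stand-alone form as follows. This is an illustrative reduction only: the real code works on tensor::TensorPtr buffers and the MindSpore float16 type, while the DType enum, the uint16_t stand-in for half precision, and the TransposeByType helper below are assumptions made to keep the sketch self-contained.

#include <cstdint>

// Minimal stand-ins; in the patch these are kNumberTypeFloat32/kNumberTypeFloat16
// and the MindSpore float16 type.
enum class DType { kFloat32, kFloat16 };
using float16 = uint16_t;  // placeholder for the real half type

// Row-major [rows x cols] -> [cols x rows] transpose, element by element,
// matching what the templated TransposeMatrix does in the fusion pass.
template <typename T>
void TransposeMatrix(const T *src, T *dst, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      dst[c * rows + r] = src[r * cols + c];
    }
  }
}

// Dispatch on the concatenated tensor's element type, as the added switch on
// base_data_type does; anything other than FP32/FP16 is rejected.
bool TransposeByType(DType t, const void *src, void *dst, int rows, int cols) {
  switch (t) {
    case DType::kFloat32:
      TransposeMatrix(static_cast<const float *>(src), static_cast<float *>(dst), rows, cols);
      return true;
    case DType::kFloat16:
      TransposeMatrix(static_cast<const float16 *>(src), static_cast<float16 *>(dst), rows, cols);
      return true;
    default:
      return false;  // unsupported data type
  }
}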
if (GetAxis((*equiv)[reshape_axis_], &reshape_axes) != lite::RET_OK) { + MS_LOG(ERROR) << "cannot figure out reshape"; return false; } if (reshape_axes.size() != C4NUM) { return false; } - *head_num = reshape_axes.at(C2NUM); - *head_size = reshape_axes.at(C3NUM); + std::vector out; + std::for_each(reshape_axes.begin() + 1, reshape_axes.end(), [&out](const auto &x) { + if (x != -1) out.push_back(x); + }); + if (out.size() < C2NUM) { + MS_LOG(ERROR) << "cannot find head_num or head_size"; + return false; + } + *head_num = out.at(0); + *head_size = out.at(1); return true; } @@ -401,44 +523,43 @@ AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, co if (func_graph == nullptr || node == nullptr || equiv == nullptr) { return nullptr; } + ++match_count_; if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA) || - (pattern_name == kMPAWithMaskPatternNameT5)) { - return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope()); - } else if ((pattern_name == kMPAXWithMaskPatternName) || (pattern_name == kMPAXWithMaskPatternNamePA) || - (pattern_name == kMPAXWithMaskPatternNameT5)) { + (pattern_name == kMPAWithMaskPatternNameT5) || + (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New)) { + if (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New) { + t5_x_ = true; + } return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true); - } else if (pattern_name == kMPAPatternName) { - return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false, false); - } else if (pattern_name == kMPAXPatternName) { - return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true, false); } - - { return nullptr; } + if (pattern_name == kMPAPatternName) + return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false); + return nullptr; } -STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector *result) { - if (param_ptr == nullptr || !param_ptr->has_default()) { - MS_LOG(DEBUG) << "param not have default"; - return RET_ERROR; - } - auto default_param = param_ptr->default_param(); - if (default_param == nullptr || !utils::isa(default_param)) { - MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr"; - return RET_ERROR; - } - auto default_param_ptr = utils::cast(default_param); - if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) { - MS_LOG(DEBUG) << "default param is not int"; - return RET_ERROR; - } - auto ptr = reinterpret_cast(default_param_ptr->data_c()); - int64_t shape_size = - std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>()); - for (int64_t i = 0; i < shape_size; i++) { - result->emplace_back(ptr[i]); - } - return RET_OK; -} +// STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector *result) { +// if (param_ptr == nullptr || !param_ptr->has_default()) { +// MS_LOG(DEBUG) << "param not have default"; +// return RET_ERROR; +// } +// auto default_param = param_ptr->default_param(); +// if (default_param == nullptr || !utils::isa(default_param)) { +// MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr"; +// return RET_ERROR; +// } +// auto default_param_ptr = utils::cast(default_param); +// if (default_param_ptr->data_type() != kNumberTypeInt32 && 
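A minimal sketch of the head-dimension extraction that CheckPattern now performs: the leading (batch) axis is skipped, -1 placeholders are dropped, and the first two remaining values are taken as head_num and head_size. Names below are illustrative; the pass reads the axes from the matched reshape node rather than from a plain vector.

#include <cstddef>
#include <vector>

// Extract head_num / head_size from a 4-element reshape axes vector such as
// {batch, seq, head_num, head_size}, where any entry may be -1 (inferred dim).
bool ExtractHeadDims(const std::vector<int> &reshape_axes, int *head_num, int *head_size) {
  if (reshape_axes.size() != 4) {
    return false;
  }
  std::vector<int> out;
  // Skip the leading axis and drop -1 placeholders, as the pass does.
  for (size_t i = 1; i < reshape_axes.size(); ++i) {
    if (reshape_axes[i] != -1) {
      out.push_back(reshape_axes[i]);
    }
  }
  if (out.size() < 2) {
    return false;  // cannot find head_num or head_size
  }
  *head_num = out[0];
  *head_size = out[1];
  return true;
}

// Example: {8, -1, 12, 64} -> head_num = 12, head_size = 64.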
default_param_ptr->data_type() != kNumberTypeInt) { +// MS_LOG(DEBUG) << "default param is not int"; +// return RET_ERROR; +// } +// auto ptr = reinterpret_cast(default_param_ptr->data_c()); +// int64_t shape_size = +// std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>()); +// for (int64_t i = 0; i < shape_size; i++) { +// result->emplace_back(ptr[i]); +// } +// return RET_OK; +// } std::shared_ptr MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const { MS_ASSERT(equiv != nullptr); @@ -595,50 +716,115 @@ CNodePtr MultiHeadAttentionFusion::MakeGetTuple(const FuncGraphPtr &func_graph, return get_item_node; } -CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, - const EquivPtr &equiv, const string &base_name, - bool cross, bool mask) const { - MS_ASSERT(func_graph != nullptr); - MS_ASSERT(equiv != nullptr); - std::vector redundant; - auto attention_prim = CreatePrim(equiv, cross); - MS_CHECK_TRUE_RET(attention_prim != nullptr, nullptr); - auto attention_prim_c = attention_prim->GetPrim(); - MS_CHECK_TRUE_RET(attention_prim_c != nullptr, nullptr); - auto value_node = NewValueNode(attention_prim_c); - MS_CHECK_TRUE_RET(value_node != nullptr, nullptr); - +bool MultiHeadAttentionFusion::IsCross(const EquivPtr &equiv) const { auto input_q = utils::cast((*equiv)[input_q_]); - auto input_k = utils::cast((*equiv)[input_k_]); auto input_v = utils::cast((*equiv)[input_v_]); - AnfNodePtr input_mask; + + ShapeVector inputq_shape, inputv_shape; + auto ret = FetchShapeFromAbstract(input_q->abstract(), &inputq_shape); + MS_CHECK_TRUE_RET(ret == RET_OK, false); + ret = FetchShapeFromAbstract(input_v->abstract(), &inputv_shape); + MS_CHECK_TRUE_RET(ret == RET_OK, false); + + if ((inputq_shape != inputv_shape) || ((match_count_ > 1) && (input_q != input_v))) { + return true; + } + return false; +} + +using tup = std::tuple, std::shared_ptr, + std::shared_ptr>; + +tup MultiHeadAttentionFusion::GetAttentionNodeWeights(const EquivPtr &equiv, std::vector *redundant) const { auto weight_q = utils::cast((*equiv)[weight_q_]); - redundant.push_back(weight_q); auto weight_k = utils::cast((*equiv)[weight_k_]); auto weight_v = utils::cast((*equiv)[weight_v_]); - redundant.push_back(weight_k); - redundant.push_back(weight_v); + redundant->push_back(weight_q); + redundant->push_back(weight_k); + redundant->push_back(weight_v); auto weight_o = utils::cast((*equiv)[weight_o_]); - auto bias_q = utils::cast((*equiv)[bias_q_]); - if (!cross) { - redundant.push_back(bias_q); - } - auto bias_k = utils::cast((*equiv)[bias_k_]); - auto bias_v = utils::cast((*equiv)[bias_v_]); - redundant.push_back(bias_k); - redundant.push_back(bias_v); - auto bias_o = utils::cast((*equiv)[bias_o_]); - auto knode = utils::cast((*equiv)[k_transpose_]); - auto vnode = utils::cast((*equiv)[v_transpose_]); - if (mask) { - input_mask = utils::cast((*equiv)[mask_]); - } std::shared_ptr weight_q_tensor = GetTensorInfo(weight_q); std::shared_ptr weight_k_tensor = GetTensorInfo(weight_k); std::shared_ptr weight_v_tensor = GetTensorInfo(weight_v); - std::shared_ptr bias_q_tensor = GetTensorInfo(bias_q); - std::shared_ptr bias_k_tensor = GetTensorInfo(bias_k); - std::shared_ptr bias_v_tensor = GetTensorInfo(bias_v); + return make_tuple(weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor); +} + +std::vector MultiHeadAttentionFusion::GetNewNodeInputs(const EquivPtr &equiv, ParameterPtr q_weight_param, + ParameterPtr 
c_weight_param, AnfNodePtr weight_o, + ParameterPtr c_bias_param, AnfNodePtr bias_o, + bool mask, bool cross) const { + auto input_q = utils::cast((*equiv)[input_q_]); + auto input_k = utils::cast((*equiv)[input_k_]); + auto input_v = utils::cast((*equiv)[input_v_]); + AnfNodePtr input_mask = (mask) ? utils::cast((*equiv)[mask_]) : nullptr; + AnfNodePtr position_bias = (t5_x_) ? utils::cast((*equiv)[position_bias_]) : nullptr; + + auto attention_prim = CreatePrim(equiv, cross); + MS_CHECK_TRUE_RET(attention_prim != nullptr, {}); + auto attention_prim_c = attention_prim->GetPrim(); + MS_CHECK_TRUE_RET(attention_prim_c != nullptr, {}); + auto value_node = NewValueNode(attention_prim_c); + MS_CHECK_TRUE_RET(value_node != nullptr, {}); + std::vector new_node_inputs; + if (cross) { + if (t5_x_) { + new_node_inputs = {value_node, input_q, input_k, input_v, + q_weight_param, c_weight_param, weight_o, position_bias}; + } else { + new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param, + c_weight_param, weight_o, c_bias_param, bias_o}; + } + } else { + if (t5_x_) { + new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, position_bias}; + } else { + new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o}; + } + } + if (mask) { + new_node_inputs.push_back(input_mask); + } + return new_node_inputs; +} + +CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, + const EquivPtr &equiv, const string &base_name, + bool mask) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(equiv != nullptr); + bool cross = IsCross(equiv); + std::vector redundant; + auto [weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor] = GetAttentionNodeWeights(equiv, &redundant); + AnfNodePtr bias_q; + ParameterPtr c_bias_param; + AnfNodePtr bias_o; + if (!t5_x_) { + bias_q = utils::cast((*equiv)[bias_q_]); + auto bias_k = utils::cast((*equiv)[bias_k_]); + auto bias_v = utils::cast((*equiv)[bias_v_]); + redundant.push_back(bias_k); + redundant.push_back(bias_v); + bias_o = utils::cast((*equiv)[bias_o_]); + std::shared_ptr bias_q_tensor = GetTensorInfo(bias_q); + std::shared_ptr bias_k_tensor = GetTensorInfo(bias_k); + std::shared_ptr bias_v_tensor = GetTensorInfo(bias_v); + auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor}); + c_bias_param = func_graph->add_parameter(); + MS_CHECK_TRUE_RET(c_bias_param != nullptr, nullptr); + c_bias_param->set_name(base_name + "/bias_qkv"); + if (lite::InitParameterFromTensorInfo(c_bias_param, c_bias) != lite::RET_OK) { + MS_LOG(ERROR) << "Init parameter from tensor info failed."; + return nullptr; + } + } + auto knode = utils::cast((*equiv)[k_transpose_]); + AnfNodePtr vnode; + auto it_vnode = (*equiv).find(v_transpose_); + if (it_vnode != (*equiv).end() && !t5_x_) vnode = utils::cast(it_vnode->second); + if (!cross && !t5_x_) { + redundant.push_back(bias_q); + } + tensor::TensorPtr c_weights; tensor::TensorPtr q_weight_t; if (cross) { @@ -647,7 +833,6 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func } else { c_weights = ConcatTensors({weight_q_tensor, weight_k_tensor, weight_v_tensor}, true); } - auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor}); // convert tensors to params auto c_weight_param = func_graph->add_parameter(); MS_CHECK_TRUE_RET(c_weight_param != nullptr, nullptr); @@ -656,13 +841,6 @@ CNodePtr 
MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func return nullptr; } c_weight_param->set_name(base_name + "/weight_qkv"); - auto c_bias_param = func_graph->add_parameter(); - MS_CHECK_TRUE_RET(c_bias_param != nullptr, nullptr); - if (lite::InitParameterFromTensorInfo(c_bias_param, c_bias) != lite::RET_OK) { - MS_LOG(ERROR) << "Init parameter from tensor info failed."; - return nullptr; - } - c_bias_param->set_name(base_name + "/bias_qkv"); ParameterPtr q_weight_param; if (cross) { q_weight_param = func_graph->add_parameter(); @@ -672,25 +850,26 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func return nullptr; } } - std::vector new_node_inputs; - if (cross) { - new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param, - c_weight_param, weight_o, c_bias_param, bias_o}; - } else { - new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o}; - } - if (mask) { - new_node_inputs.push_back(input_mask); - } + std::vector new_node_inputs = + GetNewNodeInputs(equiv, q_weight_param, c_weight_param, weight_o, c_bias_param, bias_o, mask, cross); + auto new_node = func_graph->NewCNode(new_node_inputs); MS_CHECK_TRUE_RET(new_node != nullptr, nullptr); - auto status = SetAbstractTuple(new_node, kAttentionOutputs); - if (status != RET_OK) { - return nullptr; + if (vnode) { + auto status = SetAbstractTuple(new_node, kAttentionOutputs); + if (status != RET_OK) { + return nullptr; + } } new_node->set_fullname_with_scope(base_name + "/attention"); - auto get_item_node = MakeGetTuple(func_graph, new_node, knode, vnode); + CNodePtr ret_node; + if (vnode) { + auto get_item_node = MakeGetTuple(func_graph, new_node, knode, vnode); + ret_node = get_item_node; + } else { + ret_node = new_node; + } RemoveRedundantInput(func_graph, redundant); - return get_item_node; + return ret_node; } } // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.h b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.h index f886d3aa3c7..53fc9abc567 100644 --- a/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/multi_head_attention_fusion.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "tools/optimizer/common/multiple_pattern_process_pass.h" #include "include/common/utils/utils.h" #include "include/errorcode.h" @@ -46,15 +47,18 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass { private: // define patterns - VectorRef DefineMPWithMaskPattern(bool cross = false, bool mask = true) const; - VectorRef DefineMPWithMaskPatternPA(bool cross = false) const; - VectorRef DefineMPWithMaskPatternT5(bool cross = false) const; + VectorRef DefineMPWithMaskPattern(bool mask = true) const; + VectorRef DefineMPWithMaskPatternPA() const; + VectorRef DefineMPWithMaskPatternT5() const; + VectorRef DefineMPWithMaskPatternT5New(bool transpose = true) const; VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const BaseRef &axis, const BaseRef &transpose_var, bool test_div = false, bool transpose = true) const; + VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis, + const BaseRef &transpose_var, bool test_div, bool transpose) const; // create masked-multi-head-attention CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, - const std::string &base_name, bool 
cross = false, bool mask = true) const; + const std::string &base_name, bool mask = true) const; // check pattern bool CheckPattern(const EquivPtr &equiv, int *head_num, int *head_size) const; CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index) const; @@ -66,20 +70,27 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass { CNodePtr MakeGetTuple(const FuncGraphPtr &func_graph, const CNodePtr &new_node, const AnfNodePtr &knode, const AnfNodePtr &vnode) const; std::shared_ptr CreatePrim(const EquivPtr &equiv, bool cross) const; + bool IsCross(const EquivPtr &equiv) const; + std::vector GetNewNodeInputs(const EquivPtr &equiv, ParameterPtr q_weight_param, + ParameterPtr c_weight_param, AnfNodePtr weight_o, ParameterPtr c_bias_param, + AnfNodePtr bias_o, bool mask, bool cross) const; + std::tuple, std::shared_ptr, + std::shared_ptr > + GetAttentionNodeWeights(const EquivPtr &equiv, std::vector *redundant) const; + mutable int match_count_ = 0; protected: const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern"; - const std::string kMPAXWithMaskPatternName = "MPAXWithMaskPattern"; const std::string kMPAWithMaskPatternNamePA = "MPAWithMaskPatternPA"; - const std::string kMPAXWithMaskPatternNamePA = "MPAXWithMaskPatternPA"; const std::string kMPAPatternName = "MPAPattern"; - const std::string kMPAXPatternName = "MPAXPattern"; const std::string kMPAWithMaskPatternNameT5 = "MPAWithMaskPatternT5"; - const std::string kMPAXWithMaskPatternNameT5 = "MPAXWithMaskPatternT5"; + const std::string kMPAWithMaskPatternNameT5New = "MPAWithMaskPatternT5New"; + const std::string kMPAWithMaskTransposePatternNameT5New = "MPAWithMaskTransposePatternT5New"; mutable VarPtr input_q_{nullptr}; mutable VarPtr input_k_{nullptr}; mutable VarPtr input_v_{nullptr}; + mutable VarPtr position_bias_{nullptr}; mutable VarPtr weight_q_{nullptr}; mutable VarPtr weight_k_{nullptr}; @@ -98,8 +109,9 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass { mutable VarPtr reshape_axis_{nullptr}; mutable VarPtr v_transpose_{nullptr}; mutable VarPtr k_transpose_{nullptr}; -}; + mutable bool t5_x_{false}; +}; } // namespace opt } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MULTI_HEAD_ATTENTION_FUSION_H_ diff --git a/third_party/patch/fast_transformer/001-fast_transformer.patch b/third_party/patch/fast_transformer/001-fast_transformer.patch index aad56628365..f53ebccf624 100644 --- a/third_party/patch/fast_transformer/001-fast_transformer.patch +++ b/third_party/patch/fast_transformer/001-fast_transformer.patch @@ -116,19 +116,23 @@ index 6f535da..0000000 -} \ No newline at end of file diff --git a/3rdparty/trt_fused_multihead_attention/CMakeLists.txt b/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -index 8707220..aea35e6 100644 +index 8707220..c9369e0 100644 --- a/3rdparty/trt_fused_multihead_attention/CMakeLists.txt +++ b/3rdparty/trt_fused_multihead_attention/CMakeLists.txt -@@ -21,7 +21,6 @@ set(trt_fused_multi_head_attention_files +@@ -21,7 +21,10 @@ set(trt_fused_multi_head_attention_files ) file(GLOB trt_fused_multi_head_attention_files ${trt_fused_multi_head_attention_files} *.sm*.cpp) - ++if(${CUDA_VERSION_STRING} VERSION_LESS_EQUAL "10.1.105" ) ++#this cuda don't support sm80 ++ list(REMOVE_ITEM trt_fused_multi_head_attention_files fused_mha_with_relPosBias_fp16_64_32_kernel.sm80.cpp) ++endif() add_library(trt_fused_multi_head_attention STATIC ${trt_fused_multi_head_attention_files}) 
target_link_libraries(trt_fused_multi_head_attention PUBLIC -lcublas -lcudart) set_property(TARGET trt_fused_multi_head_attention PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/CMakeLists.txt b/CMakeLists.txt -index ea21014..97b842e 100644 +index ea21014..cf30782 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,7 +14,9 @@ @@ -151,7 +155,7 @@ index ea21014..97b842e 100644 ") # -rdc=true") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") -@@ -136,7 +136,11 @@ if(NOT (FIND_SM STREQUAL True)) +@@ -136,7 +136,12 @@ if(NOT (FIND_SM STREQUAL True)) set(ENV{TORCH_CUDA_ARCH_LIST} "7.0;7.5;8.0;8.6") endif() set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86) @@ -161,10 +165,11 @@ index ea21014..97b842e 100644 + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES 80 86) +endif() + message("-- Assign GPU architecture (sm=${CMAKE_CUDA_ARCHITECTURES})") ++ set(SM 70) endif() if(BUILD_PYT) -@@ -152,8 +156,9 @@ set(CMAKE_CXX_STANDARD "${CXX_STD}") +@@ -152,8 +157,9 @@ set(CMAKE_CXX_STANDARD "${CXX_STD}") set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") @@ -176,27 +181,31 @@ index ea21014..97b842e 100644 set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") # set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose") set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3") -@@ -230,8 +235,10 @@ link_directories( +@@ -230,9 +236,10 @@ link_directories( add_subdirectory(3rdparty) add_subdirectory(src) -add_subdirectory(examples) -add_subdirectory(tests) +- +if(EXAMPLES) + add_subdirectory(examples) + add_subdirectory(tests) +endif() - ######################################## -@@ -313,6 +320,7 @@ add_library(transformer-static STATIC + if(BUILD_MULTI_GPU) +@@ -313,8 +320,9 @@ add_library(transformer-static STATIC set_property(TARGET transformer-static PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET transformer-static PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(transformer-static PUBLIC -lcudart -lnccl -lmpi -lcublas -lcublasLt -lcurand) +endif() - add_library(transformer-shared SHARED +-add_library(transformer-shared SHARED ++set(transformer_objects $ + $ + $ @@ -324,29 +332,9 @@ add_library(transformer-shared SHARED $ $ @@ -227,19 +236,26 @@ index ea21014..97b842e 100644 $ $ $ -@@ -373,7 +361,6 @@ add_library(transformer-shared SHARED +@@ -373,9 +361,7 @@ add_library(transformer-shared SHARED $ $ $ - $ $ - $ +- $ $ -@@ -387,14 +374,17 @@ add_library(transformer-shared SHARED + $ + $ +@@ -387,14 +373,22 @@ add_library(transformer-shared SHARED $ $ $) + ++if(${SM} GREATER_EQUAL 70) ++ set(transformer_objects ${transformer_objects} $) ++endif() ++ ++add_library(transformer-shared SHARED ${transformer_objects}) set_target_properties(transformer-shared PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(transformer-shared PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) set_target_properties(transformer-shared PROPERTIES LINKER_LANGUAGE CXX) @@ -255,7 +271,7 @@ index ea21014..97b842e 100644 include(CMakePackageConfigHelpers) configure_package_config_file( ${CMAKE_CURRENT_LIST_DIR}/cmake/FasterTransformerConfig.cmake.in -@@ -402,28 +392,14 @@ configure_package_config_file( +@@ -402,52 +396,23 @@ configure_package_config_file( INSTALL_DESTINATION ${INSTALL_CONFIGDIR} ) @@ -286,7 +302,11 @@ index ea21014..97b842e 100644 ) file(GLOB_RECURSE HEADER_FILES "*.h" "*.hpp" "*.cuh") -@@ -434,20 +410,5 @@ foreach ( 
file ${HEADER_FILES} ) + foreach ( file ${HEADER_FILES} ) + file( RELATIVE_PATH rfile ${CMAKE_CURRENT_SOURCE_DIR} ${file} ) + get_filename_component( dir ${rfile} DIRECTORY ) +- install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir} ) ++ install( FILES ${file} DESTINATION ${CMAKE_INSTALL_PREFIX}/include/${dir}) endforeach() @@ -322,10 +342,10 @@ index a60983c..45b5374 100644 diff --git a/deploy.sh b/deploy.sh new file mode 100755 -index 0000000..ba7f644 +index 0000000..6a8d4b0 --- /dev/null +++ b/deploy.sh -@@ -0,0 +1,25 @@ +@@ -0,0 +1,26 @@ +#copy cuda folder (once) +base=`git rev-parse --show-toplevel` +server=10.10.10.174 @@ -343,6 +363,7 @@ index 0000000..ba7f644 +rsync -v ${file} ${server}:${file} +echo "file=${file}" +rsync -v ${base}/../mindspore/trc/transformer/*.fp32 ${server}:${base}/build/bin ++rsync -v ${base}/build/lib/*.so ${server}:${base}/build/lib +# echo "cd ${base}/build/bin/" +command=$(cat <<-ENDM + CUDA_VISIBLE_DEVICES=0 \ @@ -350,7 +371,7 @@ index 0000000..ba7f644 +ENDM +) +echo "command=${command}" -+ssh ${server} "cd ${base}/build/bin ;${command}" ++ssh ${server} "cd ${base}/build/bin; LD_LIBRARY_PATH={base}/../FasterTransformer:/usr/local/cuda-11.7/lib64 ${command}" diff --git a/docs/gpt_guide.md b/docs/gpt_guide.md index afcba9a..71c4fab 100644 --- a/docs/gpt_guide.md @@ -404,10 +425,10 @@ index 0000000..09920ff +endif() diff --git a/examples/cpp/ms/initialize.h b/examples/cpp/ms/initialize.h new file mode 100644 -index 0000000..b607656 +index 0000000..9a760a9 --- /dev/null +++ b/examples/cpp/ms/initialize.h -@@ -0,0 +1,275 @@ +@@ -0,0 +1,502 @@ +#pragma once + +#include "src/fastertransformer/layers/attention_layers/AttentionWeight.h" @@ -430,17 +451,21 @@ index 0000000..b607656 +}; +template +struct DecriptorTest{ -+ std::vector input_tensors; -+ std::vector output_tensors; -+ std::vector output_python_tensors; ++ std::vector input_tensors; // GPU ++ std::vector input_python_tensors; // CPU ++ std::vector output_tensors; // GPU ++ std::vector output_python_tensors; //CPU + std::vector w_tensors; + BaseAttentionLayer* Attn; ++ // +}; + +typedef enum { + MHA_X1 = 1, // AttnIn + AttnMask + MHA_X2, // AttnIn + EncOut -- same seq size + AttnMask + MHA_CROSS, // AttnIn + EncOut + AttnMAsk ++ MHA_T5, // AttnIn + AttnMAsk + position_bias ++ MHA_T5_CROSS, // AttnIn + EncOut + AttnMAsk + position_bias +}MODEL_TEST_ID_E; + +int ModelNum(std::string model_name) { @@ -448,8 +473,12 @@ index 0000000..b607656 + return MHA_X1; + } else if (model_name == "mha_x2") { + return MHA_X2; -+ } else if (model_name == "mha_cross") { ++ } else if (model_name == "mha_cross") { + return MHA_CROSS; ++ } else if (model_name == "mha_T5") { ++ return MHA_T5; ++ } else if (model_name == "mha_T5_cross") { ++ return MHA_T5_CROSS; + } else { + return -1; + } @@ -463,6 +492,7 @@ index 0000000..b607656 + Allocator* allocator) { + const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; + ++ // TODO Nizzan - check if need to be + desc.Attn = new MSMHALayer(opt_a->batch_size, + opt_a->seq_len, + opt_a->tgt_seq_len, @@ -471,9 +501,10 @@ index 0000000..b607656 + stream, + cublas_wrapper, + allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse ++ false, // free buffer after fwd ++ true, // is_qk_buf_float_ ++ false, // sparse ++ false); // is_position_bias + + desc.input_tensors.push_back(Tensor{MEMORY_GPU, + getTensorType(), @@ -483,6 +514,16 @@ index 0000000..b607656 + getTensorType(), + std::vector{opt_a->batch_size, 1, 
opt_a->seq_len, opt_a->seq_len}, + 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->seq_len,hidden_units}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, ++ 0}); + // GPU RESULTS + desc.output_tensors.push_back(Tensor{ + MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); @@ -523,9 +564,10 @@ index 0000000..b607656 + stream, + cublas_wrapper, + allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse ++ false, // free buffer after fwd ++ true, // is_qk_buf_float_ ++ false, // sparse ++ false); // is_position_bias + + desc.input_tensors.push_back(Tensor{MEMORY_GPU, + getTensorType(), @@ -541,7 +583,7 @@ index 0000000..b607656 + getTensorType(), + std::vector{opt_a->batch_size, 1, opt_a->seq_len, (size_t)(opt_a->seq_len)}, + 0}); -+ ++ + // GPU RESULTS + desc.output_tensors.push_back(Tensor{ + MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); @@ -584,9 +626,10 @@ index 0000000..b607656 + stream, + cublas_wrapper, + allocator, -+ false, // free buffer after fwd -+ true, // is_qk_buf_float_ -+ false); // sparse ++ false, // free buffer after fwd ++ true, // is_qk_buf_float_ ++ false, // sparse ++ false); // is_position_bias + + desc.input_tensors.push_back(Tensor{MEMORY_GPU, + getTensorType(), @@ -602,22 +645,36 @@ index 0000000..b607656 + getTensorType(), + std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, + 0}); ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->seq_len, hidden_units}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, ++ 0}); + + // GPU RESULTS + + desc.output_tensors.push_back(Tensor{ + MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); -+ desc.output_tensors.push_back(Tensor{ -+ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); + + desc.output_python_tensors.push_back(Tensor{ + MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); -+ desc.output_python_tensors.push_back(Tensor{ -+ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // 
MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); + + + desc.w_tensors.push_back( @@ -630,6 +687,172 @@ index 0000000..b607656 + desc.w_tensors.push_back( + Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units}, 0}); +} ++template ++void InitializeAttnT5(opt_arg* opt_a, ++ DecriptorTest &desc, ++ cudaStream_t stream, ++ cublasMMWrapper* cublas_wrapper, ++ Allocator* allocator) { ++ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; ++ ++ desc.Attn = new MSMHALayer(opt_a->batch_size, ++ opt_a->seq_len, ++ opt_a->tgt_seq_len, ++ opt_a->head_num, ++ opt_a->size_per_head, ++ stream, ++ cublas_wrapper, ++ allocator, ++ false, // free buffer after fwd ++ true, // is_qk_buf_float_ ++ false, // sparse ++ true); // is_position_bias ++ ++ desc.input_tensors.push_back(Tensor{MEMORY_GPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->seq_len,hidden_units}, ++ 0}); ++ desc.input_tensors.push_back(Tensor{MEMORY_GPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, ++ 0}); ++ ++ desc.input_tensors.push_back(Tensor{MEMORY_GPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->seq_len,hidden_units}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->seq_len}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, ++ 0}); ++ ++ ++ // GPU RESULTS ++ ++ desc.output_tensors.push_back(Tensor{ ++ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len},0}); ++ ++ desc.output_python_tensors.push_back(Tensor{ ++ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); ++ ++ desc.w_tensors.push_back( ++ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 3 * hidden_units}, 0}); ++ desc.w_tensors.push_back( ++ Tensor{MEMORY_GPU, getTensorType(), 
std::vector{hidden_units, hidden_units}, 0}); ++} ++ ++template ++void InitializeAttnT5Cross(opt_arg* opt_a, ++ DecriptorTest &desc, ++ cudaStream_t stream, ++ cublasMMWrapper* cublas_wrapper, ++ Allocator* allocator) { ++ const size_t hidden_units = opt_a->head_num * opt_a->size_per_head; ++ ++ desc.Attn = new MSMHALayer(opt_a->batch_size, ++ opt_a->seq_len, ++ opt_a->tgt_seq_len, ++ opt_a->head_num, ++ opt_a->size_per_head, ++ stream, ++ cublas_wrapper, ++ allocator, ++ false, // free buffer after fwd ++ true, // is_qk_buf_float_ ++ false, // sparse ++ true); // is_position_bias ++ ++ desc.input_tensors.push_back(Tensor{MEMORY_GPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->seq_len,hidden_units}, ++ 0}); ++ ++ desc.input_tensors.push_back(Tensor{MEMORY_GPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, ++ 0}); ++ ++ desc.input_tensors.push_back(Tensor{MEMORY_GPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, ++ 0}); ++ ++ desc.input_tensors.push_back(Tensor{MEMORY_GPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->seq_len,hidden_units}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size * opt_a->tgt_seq_len, hidden_units}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, 1, opt_a->seq_len, opt_a->tgt_seq_len}, ++ 0}); ++ ++ desc.input_python_tensors.push_back(Tensor{MEMORY_CPU, ++ getTensorType(), ++ std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, ++ 0}); ++ ++ ++ // GPU RESULTS ++ ++ desc.output_tensors.push_back(Tensor{ ++ MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_tensors.push_back(Tensor{ ++ // MEMORY_GPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len},0}); ++ ++ desc.output_python_tensors.push_back(Tensor{ ++ MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->seq_len, hidden_units}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->tgt_seq_len, opt_a->size_per_head}, 0}); ++ // desc.output_python_tensors.push_back(Tensor{ ++ // MEMORY_CPU, getTensorType(), std::vector{opt_a->batch_size, opt_a->head_num, opt_a->seq_len, opt_a->tgt_seq_len}, 0}); ++ ++ desc.w_tensors.push_back( ++ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, hidden_units}, 0}); ++ desc.w_tensors.push_back( ++ Tensor{MEMORY_GPU, getTensorType(), std::vector{hidden_units, 2 * hidden_units}, 0}); ++ desc.w_tensors.push_back( ++ Tensor{MEMORY_GPU, getTensorType(), 
std::vector{hidden_units, hidden_units}, 0}); ++} + +template +void Init(opt_arg* opt_a, @@ -648,19 +871,33 @@ index 0000000..b607656 + allocator); + break; + case MHA_X2: -+ InitializeAttnX2(opt_a, ++ InitializeAttnX2(opt_a, + desc, + stream, + cublas_wrapper, + allocator); + break; + case MHA_CROSS: -+ InitializeAttnCross(opt_a, ++ InitializeAttnCross(opt_a, + desc, + stream, + cublas_wrapper, + allocator); + break; ++ case MHA_T5: ++ InitializeAttnT5(opt_a, ++ desc, ++ stream, ++ cublas_wrapper, ++ allocator); ++ break; ++ case MHA_T5_CROSS: ++ InitializeAttnT5Cross(opt_a, ++ desc, ++ stream, ++ cublas_wrapper, ++ allocator); ++ break; + default: + break; + } @@ -679,16 +916,27 @@ index 0000000..b607656 + attn_weights.key_weight.kernel = (const T*)w_tensors[2].data; + attn_weights.attention_output_weight.kernel = (const T*)w_tensors[3].data; + attn_weights.attention_output_weight.bias = (const T*)w_tensors[4].data; ++ } else if (modelId==MHA_T5) { ++ attn_weights.query_weight.kernel = (const T*)w_tensors[0].data; ++ attn_weights.query_weight.bias = nullptr; ++ attn_weights.attention_output_weight.kernel = (const T*)w_tensors[1].data; ++ attn_weights.attention_output_weight.bias = nullptr; ++ } else if (modelId==MHA_T5_CROSS) { ++ attn_weights.query_weight.kernel = (const T*)w_tensors[0].data; ++ attn_weights.query_weight.bias = nullptr; ++ attn_weights.key_weight.kernel = (const T*)w_tensors[1].data; ++ attn_weights.attention_output_weight.kernel = (const T*)w_tensors[2].data; ++ attn_weights.attention_output_weight.bias = nullptr; + } else { + // return ERROR illegal model ! + } +} diff --git a/examples/cpp/ms/ms.cc b/examples/cpp/ms/ms.cc new file mode 100644 -index 0000000..3121200 +index 0000000..5117f9f --- /dev/null +++ b/examples/cpp/ms/ms.cc -@@ -0,0 +1,434 @@ +@@ -0,0 +1,458 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. 
+ * @@ -904,14 +1152,26 @@ index 0000000..3121200 + float meanError = 0; + std::cout << "Out tensor size is: " << size << std::endl; + std::cout << "Data of model output: "; -+ for (int j = 0; j < std::min(50, size); j++) { -+ std::cout << static_cast(msTensorData[j]) << " "; -+ } -+ std::cout << std::endl; -+ std::cout << "Data of Ref output : "; -+ for (int j = 0; j < std::min(50, size); j++) { -+ std::cout << static_cast(refOutput[j]) << " "; ++ static int x = 0; ++ ++ if (x == 0) { ++ // for (int j = 0; j < size; j++) { //std::min(50, size) ++ // std::cout << static_cast(msTensorData[j]) << " "; ++ // } ++ // std::cout << std::endl; ++ // std::cout << "Data of Ref output : "; ++ // for (int j = 0; j < size; j++) { //std::min(50, size) ++ // std::cout << static_cast(refOutput[j]) << " "; ++ // } ++ // for (int j = 0; j < size; j++) { //std::min(50, size) ++ // std::cout << "idx=" << j << ++ // " reference " << static_cast(refOutput[j]) << ++ // " model " << static_cast(msTensorData[j]) << ++ // " diff " << std::fabs(static_cast(msTensorData[j]) - static_cast(refOutput[j])) << ++ // std::endl; ++ // } + } ++ x++; + std::cout << std::endl; + for (int j = 0; j < size; j++) { + if (std::isnan(msTensorData[j]) || std::isinf(msTensorData[j])) { @@ -1015,6 +1275,16 @@ index 0000000..3121200 + } +} + ++uint64_t GetTimeUs() { ++ const int USEC = 1000000; ++ const int MSEC = 1000; ++ struct timespec ts = {0, 0}; ++ if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) { ++ return 0; ++ } ++ uint64_t retval = (uint64_t)((ts.tv_sec * USEC) + (ts.tv_nsec / MSEC)); ++ return retval; ++} + +template +int MsExample(opt_arg* opt_a) { @@ -1086,7 +1356,7 @@ index 0000000..3121200 + desc.Attn->forward(&desc.output_tensors, &desc.input_tensors, &attn_weights); + CompareOutput(desc.output_python_tensors, desc.output_tensors); + -+#define DO_TIME ++// #define DO_TIME +#ifdef DO_TIME + // warmup + for (int i = 0; i < 10; i++) { @@ -1095,13 +1365,15 @@ index 0000000..3121200 + + // profile time + const int ite = 1000; -+ CudaTimer cuda_timer(stream); -+ cuda_timer.start(); ++ size_t s = GetTimeUs(); ++ // CudaTimer cuda_timer(stream); ++ // cuda_timer.start(); + for (int i = 0; i < ite; i++) { + desc.Attn->forward(&desc.output_tensors, &desc.input_tensors, &attn_weights); + } -+ float total_time = cuda_timer.stop(); -+ ++ // float total_time = cuda_timer.stop(); ++ size_t e = GetTimeUs(); ++ float total_time = (e - s) / 1000.0; + printf("batch_size %ld seq_len %ld layer %ld " + "AVG FT-CPP-time %.2f ms (%d iterations) " + "Total Time %.2f ms\n", @@ -1144,6 +1416,26 @@ index 0000000..53f5ca6 +++ b/path.sh @@ -0,0 +1 @@ +export PATH=/usr/local/cuda-11/bin:/home/yoni/.vscode-server/bin/4af164ea3a06f701fe3e89a2bcbb421d2026b68f/bin/remote-cli:/home/yoni/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin +diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt +index 3db0830..3dd4210 100644 +--- a/src/fastertransformer/kernels/CMakeLists.txt ++++ b/src/fastertransformer/kernels/CMakeLists.txt +@@ -159,9 +159,12 @@ add_library(matrix_vector_multiplication STATIC matrix_vector_multiplication.cu) + set_property(TARGET matrix_vector_multiplication PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET matrix_vector_multiplication PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +-add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) +-set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) 
+-set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) ++if(${SM} GREATER_EQUAL 70) ++ message("-- Making custom kernels") ++ add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) ++ set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) ++ set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) ++endif() + + add_library(vit_kernels STATIC vit_kernels.cu) + set_property(TARGET vit_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/kernels/activation_kernels.cu b/src/fastertransformer/kernels/activation_kernels.cu index 7ff8e0f..e1be64c 100644 --- a/src/fastertransformer/kernels/activation_kernels.cu @@ -1210,10 +1502,51 @@ index 7ff8e0f..e1be64c 100644 template void invokeAddBias(float* out, const float* bias, const int m, const int n, cudaStream_t stream); diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu -index f951e71..4455879 100644 +index f951e71..597a266 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu -@@ -243,6 +243,116 @@ __global__ void softmax_kernel_v4(T* qk_buf_, +@@ -15,6 +15,14 @@ + * limitations under the License. + */ + ++#ifndef CUDART_VERSION ++#error CUDART_VERSION Undefined! ++#elif (CUDART_VERSION >= 11050) ++#include ++#else ++#include "3rdparty/cub/cub.cuh" ++#endif ++ + #include "src/fastertransformer/kernels/bfloat16_fallback_kenrels.cuh" + #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" + #include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" +@@ -23,6 +31,25 @@ + + namespace fastertransformer { + ++ ++const int WARP_SIZE = 32; ++const bool ATTENION_OPT = true; ++const int ATTENTION_BLOCK_SIZE = 256; ++ ++/////////////////////////////////////////////////////////////////////////////////////////////////// ++ ++template ++using Copy_half_t = typename std::conditional< ++ HALF_ELEMENTS_PER_WARP_LOAD == 32, ++ half, ++ typename std::conditional::type>::type>:: ++ type; ++ ++template ++using Copy_t = Copy_half_t; ++ + __inline__ __device__ int target_index(int id1, int id2, int id3, int id4, int dim_1, int dim_2, int dim_3, int dim_4) + { + return id1 * (dim_2 * dim_3 * dim_4) + id3 * (dim_2 * dim_4) + id2 * dim_4 + id4; +@@ -243,6 +270,172 @@ __global__ void softmax_kernel_v4(T* qk_buf_, } } @@ -1291,7 +1624,6 @@ index f951e71..4455879 100644 + qk_offset = + ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; + int mask_offset = (blockIdx.y * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; -+ + float qk = static_cast(qk_buf_src[qk_offset]); + float mask_val = static_cast(ldg(&attr_mask[mask_offset])); + @@ -1326,14 +1658,103 @@ index f951e71..4455879 100644 + } + } +} ++ ++template ++__global__ void softmax_mix_kernel_bias_v4(T* qk_buf_, ++ const T_M* attr_mask, ++ const T* position_bias, ++ const int batch_size, ++ const int head_num, ++ const int seq_len, ++ const int trgt_seq_len, ++ const T scalar) ++{ ++ T* qk_buf_src = qk_buf_; ++ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { ++ float data[ITEMS_PER_THREAD]; ++ int qk_offset; ++ __shared__ float s_mean, s_max; ++ float local_max = -1e20f; ++ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { ++ qk_offset = ++ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * trgt_seq_len + 
blockDim.x * i + threadIdx.x; ++ int mask_offset = (blockIdx.y * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; ++ int bias_offset = qk_offset; ++ float qk = static_cast(qk_buf_src[qk_offset]); ++ float mask_val = static_cast(ldg(&attr_mask[mask_offset])); ++ float bias_val = static_cast(ldg(&position_bias[bias_offset])); ++ ++ mask_val = (1.0f - mask_val) * -10000.0f; ++ ++ data[i] = qk * static_cast(scalar) + mask_val + bias_val; ++ local_max = fmax(local_max, data[i]); ++ } ++ ++ float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); ++ if (threadIdx.x == 0) { ++ s_max = max_val; ++ } ++ __syncthreads(); ++ ++ float local_sum = 0; ++ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { ++ data[i] = __expf(data[i] - s_max); ++ local_sum += data[i]; ++ } ++ float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum); ++ if (threadIdx.x == 0) { ++ s_mean = sum_val + 1e-6f; ++ s_mean = __fdividef(1.0f, s_mean); ++ } ++ __syncthreads(); ++ ++ for (int i = 0; blockDim.x * i + threadIdx.x < trgt_seq_len; i++) { ++ qk_offset = ++ ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * trgt_seq_len + blockDim.x * i + threadIdx.x; ++ qk_buf_[qk_offset] = (T)(data[i] * s_mean); ++ } ++ } ++} + template __global__ void softmax_kernel_v4_half2( T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) -@@ -298,6 +408,61 @@ __global__ void softmax_kernel_v4_half2( - } - } +@@ -267,49 +460,398 @@ __global__ void softmax_kernel_v4_half2( + data[i] = hadd2(hmul2(qk, type2type2(scalar)), mask_val); + +- local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y)); ++ local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y)); ++ } ++ ++ float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); ++ if (threadIdx.x == 0) { ++ s_max = max_val; ++ } ++ __syncthreads(); ++ ++ float local_sum = 0; ++ for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ data[i] = hexp2(hsub2(data[i], float2type2(s_max))); ++ local_sum += (float)(data[i].x + data[i].y); ++ } ++ ++ float sum_val = blockDim.x <= 32 ? 
warpReduceSum(local_sum) : blockReduceSum(local_sum); ++ ++ if (threadIdx.x == 0) { ++ s_mean = sum_val + 1e-6f; ++ s_mean = __fdividef(1.0f, s_mean); ++ } ++ __syncthreads(); ++ ++ for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (seq_len / 2) + blockDim.x * i ++ + threadIdx.x; ++ qk_buf_half2[qk_offset] = hmul2(data[i], float2type2(s_mean)); ++ } ++ } ++} ++ +template +__global__ void softmax_cross_kernel_v4_half2( + T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const int trgt_seq_len, const T scalar) @@ -1389,13 +1810,181 @@ index f951e71..4455879 100644 + } +} + - template - __global__ void softmax_kernel_v5_half2( - T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) -@@ -415,6 +580,123 @@ __global__ void softmax_kernel_v5_half2( - } - } - ++template ++__global__ void softmax_cross_kernel_bias_v4_half2( ++ T* qk_buf_, const T* attr_mask, const T* position_bias, const int batch_size, const int head_num, const int seq_len, const int trgt_seq_len, const T scalar) ++{ ++ using T2 = typename TypeConverter::Type; ++ T2* qk_buf_half2 = (T2*)qk_buf_; ++ const T2* attr_mask_half2 = (const T2*)attr_mask; ++ const T2* position_bias_half2 = (const T2*)position_bias; ++ ++ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x) { ++ T2 data[ITEMS_PER_THREAD]; ++ int qk_offset; ++ __shared__ float s_mean, s_max; ++ float local_max = -1e20f; ++ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (trgt_seq_len / 2) + blockDim.x * i ++ + threadIdx.x; ++ int mask_offset = (blockIdx.y * seq_len + seq_id) * (trgt_seq_len / 2) + blockDim.x * i + threadIdx.x; ++ int bias_offset = qk_offset; ++ ++ T2 qk = qk_buf_half2[qk_offset]; ++ T2 mask_val = ldg(&attr_mask_half2[mask_offset]); ++ mask_val = hmul2(hsub2(float2type2(1.0f), mask_val), float2type2(-10000.0f)); ++ T2 bias_val = (ldg(&position_bias_half2[bias_offset])); ++ ++ data[i] = hadd2(hadd2(hmul2(qk, type2type2(scalar)), mask_val), bias_val); ++ ++ local_max = fmax(local_max, fmax((float)data[i].x, (float)data[i].y)); ++ } ++ ++ float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); ++ if (threadIdx.x == 0) { ++ s_max = max_val; ++ } ++ __syncthreads(); ++ ++ float local_sum = 0; ++ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ data[i] = hexp2(hsub2(data[i], float2type2(s_max))); ++ local_sum += (float)(data[i].x + data[i].y); ++ } ++ ++ float sum_val = blockDim.x <= 32 ? 
warpReduceSum(local_sum) : blockReduceSum(local_sum); ++ ++ if (threadIdx.x == 0) { ++ s_mean = sum_val + 1e-6f; ++ s_mean = __fdividef(1.0f, s_mean); ++ } ++ __syncthreads(); ++ ++ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (trgt_seq_len / 2) + blockDim.x * i ++ + threadIdx.x; ++ qk_buf_half2[qk_offset] = hmul2(data[i], float2type2(s_mean)); ++ } ++ } ++} ++ ++template ++__global__ void softmax_kernel_v5_half2( ++ T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) ++{ ++ using T2 = typename TypeConverter::Type; ++ T2* qk_buf_half2 = (T2*)qk_buf_; ++ const T2* attr_mask_half2 = (const T2*)attr_mask; ++ ++ for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) { ++ T2 data[NUM][ITEMS_PER_THREAD]; ++ ++ int qk_offset[NUM]; ++ ++ __shared__ float s_sum[NUM], s_max[NUM]; ++ float local_max[NUM]; ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ local_max[j] = -1e20f; ++ } ++ ++ for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ int mask_offset[NUM]; ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) ++ + blockDim.x * i + threadIdx.x; ++ mask_offset[j] = ++ (blockIdx.y * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) + blockDim.x * i + threadIdx.x; ++ } ++ ++ T2 mask_val[NUM]; ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ mask_val[j] = ldg(&attr_mask_half2[mask_offset[j]]); ++ } ++ ++ T2 qk[NUM]; ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ qk[j] = qk_buf_half2[qk_offset[j]]; ++ } ++ ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ mask_val[j] = hmul2(hsub2(float2type2(1.0f), mask_val[j]), float2type2(-10000.0f)); ++ } ++ ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ data[j][i] = hadd2(hmul2(qk[j], type2type2(scalar)), mask_val[j]); ++ local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y)); ++ } ++ } ++ ++ if (blockDim.x <= 32) { ++ warpReduceMaxV2(local_max); ++ } ++ else { ++ blockReduceMaxV2(local_max); ++ } ++ ++ if (threadIdx.x == 0) { ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ s_max[j] = local_max[j]; ++ } ++ } ++ __syncthreads(); ++ ++ float local_sum[NUM]; ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ local_sum[j] = {0.f}; ++ } ++ ++ for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ data[j][i] = hexp2(hsub2(data[j][i], float2type2(s_max[j]))); ++ } ++ ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ local_sum[j] += (float)(data[j][i].x + data[j][i].y); ++ } ++ } ++ ++ if (blockDim.x <= 32) { ++ warpReduceSumV2(local_sum); ++ } ++ else { ++ blockReduceSumV2(local_sum); ++ } ++ ++ if (threadIdx.x == 0) { ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); ++ } ++ } ++ __syncthreads(); ++ ++ for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) ++ + blockDim.x * i + threadIdx.x; ++ } ++ ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ qk_buf_half2[qk_offset[j]] = hmul2(data[j][i], 
float2type2(s_sum[j])); ++ } ++ } ++ } ++} ++ +template +__global__ void softmax_cross_kernel_v5_half2( + T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const int trgt_seq_len, const T scalar) @@ -1455,22 +2044,29 @@ index f951e71..4455879 100644 + } + else { + blockReduceMaxV2(local_max); -+ } -+ -+ if (threadIdx.x == 0) { + } + +- float max_val = blockDim.x <= 32 ? warpReduceMax(local_max) : blockReduceMax(local_max); + if (threadIdx.x == 0) { +- s_max = max_val; +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_max[j] = local_max[j]; + } -+ } -+ __syncthreads(); -+ + } + __syncthreads(); + +- float local_sum = 0; +- for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { +- data[i] = hexp2(hsub2(data[i], float2type2(s_max))); +- local_sum += (float)(data[i].x + data[i].y); + float local_sum[NUM]; +#pragma unroll + for (int j = 0; j < NUM; j++) { + local_sum[j] = {0.f}; -+ } -+ + } + +- float sum_val = blockDim.x <= 32 ? warpReduceSum(local_sum) : blockReduceSum(local_sum); + for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { +#pragma unroll + for (int j = 0; j < NUM; j++) { @@ -1489,15 +2085,21 @@ index f951e71..4455879 100644 + else { + blockReduceSumV2(local_sum); + } -+ -+ if (threadIdx.x == 0) { + + if (threadIdx.x == 0) { +- s_mean = sum_val + 1e-6f; +- s_mean = __fdividef(1.0f, s_mean); +#pragma unroll + for (int j = 0; j < NUM; j++) { + s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f); + } -+ } -+ __syncthreads(); -+ + } + __syncthreads(); + +- for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { +- qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id) * (seq_len / 2) + blockDim.x * i +- + threadIdx.x; +- qk_buf_half2[qk_offset] = hmul2(data[i], float2type2(s_mean)); + for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { +#pragma unroll + for (int j = 0; j < NUM; j++) { @@ -1509,14 +2111,101 @@ index f951e71..4455879 100644 + for (int j = 0; j < NUM; j++) { + qk_buf_half2[qk_offset[j]] = hmul2(data[j][i], float2type2(s_sum[j])); + } -+ } -+ } -+} + } + } + } + + template +-__global__ void softmax_kernel_v5_half2( +- T* qk_buf_, const T* attr_mask, const int batch_size, const int head_num, const int seq_len, const T scalar) ++__global__ void softmax_cross_kernel_bias_v5_half2( ++ T* qk_buf_, const T* attr_mask, const T* position_bias, const int batch_size, const int head_num, const int seq_len, const int trgt_seq_len, const T scalar) + { + using T2 = typename TypeConverter::Type; + T2* qk_buf_half2 = (T2*)qk_buf_; + const T2* attr_mask_half2 = (const T2*)attr_mask; ++ const T2* position_bias_half2 = (const T2*)position_bias; + + for (int seq_id = blockIdx.x; seq_id < seq_len; seq_id += gridDim.x * NUM) { + T2 data[NUM][ITEMS_PER_THREAD]; + + int qk_offset[NUM]; ++ int pos_bias_offset[NUM]; + + __shared__ float s_sum[NUM], s_max[NUM]; + float local_max[NUM]; +@@ -318,14 +860,15 @@ __global__ void softmax_kernel_v5_half2( + local_max[j] = -1e20f; + } + +- for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { + int mask_offset[NUM]; + #pragma unroll + for (int j = 0; j < NUM; j++) { +- qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) ++ qk_offset[j] = 
((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (trgt_seq_len / 2) + + blockDim.x * i + threadIdx.x; + mask_offset[j] = +- (blockIdx.y * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) + blockDim.x * i + threadIdx.x; ++ (blockIdx.y * seq_len + seq_id + j * gridDim.x) * (trgt_seq_len / 2) + blockDim.x * i + threadIdx.x; ++ pos_bias_offset[j] = qk_offset[j]; + } + + T2 mask_val[NUM]; +@@ -339,6 +882,12 @@ __global__ void softmax_kernel_v5_half2( + for (int j = 0; j < NUM; j++) { + qk[j] = qk_buf_half2[qk_offset[j]]; + } ++ ++ T2 pos_bias_val[NUM]; ++#pragma unroll ++ for (int j = 0; j < NUM; j++) { ++ pos_bias_val[j] = ldg(&position_bias_half2[pos_bias_offset[j]]); ++ } + + #pragma unroll + for (int j = 0; j < NUM; j++) { +@@ -347,7 +896,7 @@ __global__ void softmax_kernel_v5_half2( + + #pragma unroll + for (int j = 0; j < NUM; j++) { +- data[j][i] = hadd2(hmul2(qk[j], type2type2(scalar)), mask_val[j]); ++ data[j][i] = hadd2(hadd2(hmul2(qk[j], type2type2(scalar)), mask_val[j]), pos_bias_val[j]); + local_max[j] = fmax(local_max[j], fmax((float)data[j][i].x, (float)data[j][i].y)); + } + } +@@ -373,7 +922,7 @@ __global__ void softmax_kernel_v5_half2( + local_sum[j] = {0.f}; + } + +- for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { + #pragma unroll + for (int j = 0; j < NUM; j++) { + data[j][i] = hexp2(hsub2(data[j][i], float2type2(s_max[j]))); +@@ -400,10 +949,10 @@ __global__ void softmax_kernel_v5_half2( + } + __syncthreads(); + +- for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) && i < ITEMS_PER_THREAD; i++) { ++ for (int i = 0; blockDim.x * i + threadIdx.x < (trgt_seq_len / 2) && i < ITEMS_PER_THREAD; i++) { + #pragma unroll + for (int j = 0; j < NUM; j++) { +- qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (seq_len / 2) ++ qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len + seq_id + j * gridDim.x) * (trgt_seq_len / 2) + + blockDim.x * i + threadIdx.x; + } + +@@ -415,6 +964,7 @@ __global__ void softmax_kernel_v5_half2( + } + } + + #define SOFTMAX_KERNEL(ITEMS_PER_THREAD) \ block.x /= ITEMS_PER_THREAD; \ assert(block.x <= 1024); \ -@@ -434,6 +716,50 @@ __global__ void softmax_kernel_v5_half2( +@@ -434,6 +984,72 @@ __global__ void softmax_kernel_v5_half2( <<>>(buffer, buffer_src, attr_mask, batch_size, head_num, seq_len, scalar); \ } @@ -1563,11 +2252,33 @@ index f951e71..4455879 100644 + <<>>(io_buffer, attr_mask, batch_size, head_num, seq_len, \ + trgt_seq_len, scalar); \ + } ++ ++#define SOFTMAX_MIX_KERNEL_BIAS(ITEMS_PER_THREAD) \ ++ block.x /= ITEMS_PER_THREAD; \ ++ assert(block.x <= 1024); \ ++ if (is_half2) { \ ++ if (grid.x % 4 == 0) { \ ++ grid.x /= 4; \ ++ softmax_cross_kernel_bias_v5_half2<<>>( \ ++ (half*)io_buffer, (const half*)attr_mask, (const half*)position_bias, batch_size, head_num, seq_len, trgt_seq_len, \ ++ (const half)scalar); \ ++ } \ ++ else { \ ++ softmax_cross_kernel_bias_v4_half2<<>>( \ ++ (half*)io_buffer, (const half*)attr_mask, (const half*)position_bias, batch_size, head_num, seq_len, trgt_seq_len, \ ++ (const half)scalar); \ ++ } \ ++ } \ ++ else { \ ++ softmax_mix_kernel_bias_v4 \ ++ <<>>(io_buffer, attr_mask, position_bias, batch_size, head_num, seq_len, \ ++ trgt_seq_len, scalar); \ ++ } + #ifdef ENABLE_BF16 #define SOFTMAX_KERNEL_BF16(ITEMS_PER_THREAD) \ block.x /= ITEMS_PER_THREAD; \ -@@ -501,6 +827,80 @@ void 
invokeMaskedSoftMax(T* buffer, +@@ -501,6 +1117,120 @@ void invokeMaskedSoftMax(T* buffer, } } @@ -1644,11 +2355,51 @@ index f951e71..4455879 100644 + FT_CHECK(trgt_seq_len <= 4096 || seq_len <= 4096); + } +} ++ ++template ++void invokeMixMaskedSoftMax(T* io_buffer, ++ const T_M* attr_mask, ++ const T* position_bias, ++ const int batch_size, ++ const int seq_len, ++ const int trgt_seq_len, ++ const int head_num, ++ const T scalar, ++ cudaStream_t stream) ++{ ++ if (position_bias == nullptr) { ++ invokeMixMaskedSoftMax(io_buffer, attr_mask, batch_size, seq_len, trgt_seq_len, head_num, scalar, stream); ++ } else { ++ dim3 grid(seq_len, batch_size, head_num); ++ if (batch_size * head_num > 360) { ++ grid.x = ceil(float(seq_len) / 32.0f); ++ } ++ ++ bool is_half2 = sizeof(T) == 2 && sizeof(T_M) == 2 && trgt_seq_len % 2 == 0; ++ dim3 block((trgt_seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32); ++ ++ if (block.x > 3072 && block.x <= 4096) { ++ SOFTMAX_MIX_KERNEL_BIAS(4) ++ } ++ if (block.x > 2048) { ++ SOFTMAX_MIX_KERNEL_BIAS(3) ++ } ++ else if (block.x > 1024) { ++ SOFTMAX_MIX_KERNEL_BIAS(2) ++ } ++ else if (block.x > 0) { ++ SOFTMAX_MIX_KERNEL_BIAS(1) ++ } ++ else { ++ FT_CHECK(trgt_seq_len <= 4096 || seq_len <= 4096); ++ } ++ } ++} + #ifdef ENABLE_BF16 template<> void invokeMaskedSoftMax(__nv_bfloat16* buffer, -@@ -574,13 +974,78 @@ void invokeMaskedSoftMax(__nv_bfloat16* buffer, +@@ -574,13 +1304,118 @@ void invokeMaskedSoftMax(__nv_bfloat16* buffer, FT_CHECK(seq_len <= 4096); } } @@ -1674,7 +2425,6 @@ index f951e71..4455879 100644 + cudaStream_t stream) {;} #endif // ENABLE_BF16 --template void invokeMaskedSoftMax(float* buffer, +template void invokeMixMaskedSoftMax(float* io_buffer, + const float* attr_mask, + const int batch_size, @@ -1711,7 +2461,47 @@ index f951e71..4455879 100644 + const half scalar, + cudaStream_t stream); + -+ template void invokeMaskedSoftMax(float* buffer, ++template void invokeMixMaskedSoftMax(float* io_buffer, ++ const float* attr_mask, ++ const float* position_bias, ++ const int batch_size, ++ const int seq_len, ++ const int tgt_seq_len, ++ const int head_num, ++ const float scalar, ++ cudaStream_t stream); ++ ++template void invokeMixMaskedSoftMax(half* io_buffer, ++ const half* attr_mask, ++ const half* position_bias, ++ const int batch_size, ++ const int seq_len, ++ const int tgt_seq_len, ++ const int head_num, ++ const half scalar, ++ cudaStream_t stream); ++ ++template void invokeMixMaskedSoftMax(float* io_buffer, ++ const half* attr_mask, ++ const float* position_bias, ++ const int batch_size, ++ const int seq_len, ++ const int tgt_seq_len, ++ const int head_num, ++ const float scalar, ++ cudaStream_t stream); ++ ++template void invokeMixMaskedSoftMax(half* io_buffer, ++ const float* attr_mask, ++ const half* position_bias, ++ const int batch_size, ++ const int seq_len, ++ const int tgt_seq_len, ++ const int head_num, ++ const half scalar, ++ cudaStream_t stream); ++ + template void invokeMaskedSoftMax(float* buffer, + const float* buffer_src, + const float* attr_mask, + const int batch_size, @@ -1729,7 +2519,7 @@ index f951e71..4455879 100644 const int head_num, const float scalar, cudaStream_t stream); -@@ -594,6 +1059,15 @@ template void invokeMaskedSoftMax(half* buffer, +@@ -594,6 +1429,15 @@ template void invokeMaskedSoftMax(half* buffer, const half scalar, cudaStream_t stream); @@ -1745,7 +2535,7 @@ index f951e71..4455879 100644 template void invokeMaskedSoftMax(half* buffer, const half* buffer_src, const half* attr_mask, -@@ -603,6 +1077,15 @@ 
template void invokeMaskedSoftMax(half* buffer, +@@ -603,6 +1447,15 @@ template void invokeMaskedSoftMax(half* buffer, const half scalar, cudaStream_t stream); @@ -1761,7 +2551,7 @@ index f951e71..4455879 100644 #ifdef ENABLE_BF16 template void invokeMaskedSoftMax(__nv_bfloat16* buffer, const __nv_bfloat16* buffer_src, -@@ -621,6 +1104,25 @@ template void invokeMaskedSoftMax(__nv_bfloat16* buffer, +@@ -621,6 +1474,25 @@ template void invokeMaskedSoftMax(__nv_bfloat16* buffer, const int head_num, const __nv_bfloat16 scalar, cudaStream_t stream); @@ -1787,7 +2577,7 @@ index f951e71..4455879 100644 #endif // ENABLE_BF16 template -@@ -726,9 +1228,10 @@ void invokeTransposeQKV(T* dst, +@@ -726,9 +1598,10 @@ void invokeTransposeQKV(T* dst, seq_per_block *= 2; } @@ -1800,7 +2590,7 @@ index f951e71..4455879 100644 block.x = seq_per_block * size_per_head / 2; if (std::is_same::value) { transpose<<>>( -@@ -1061,12 +1564,12 @@ template void invokeTransposeAttentionOutRemovePadding(half* src, +@@ -1061,12 +1934,12 @@ template void invokeTransposeAttentionOutRemovePadding(half* src, const int* mask_offset, cudaStream_t stream); @@ -1815,16 +2605,62 @@ index f951e71..4455879 100644 const int batch_size, const int seq_len, const int head_num, -@@ -1081,7 +1584,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, +@@ -1079,10 +1952,9 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, + T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; + const int n = head_num * size_per_head; for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; - index += gridDim.x * blockDim.x) { +- index += gridDim.x * blockDim.x) { ++ index += gridDim.x * blockDim.x) { int bias_id = index % (3 * n); - T val = ldg(&QKV[index]) + ldg(&qkv_bias[bias_id]); +- + T val = ldg(&QKV[index]) + (T)ldg(&qkv_bias[bias_id]); - int tmp_index = index; const int target_batch_id = tmp_index / (seq_len * 3 * n); -@@ -1116,12 +1619,12 @@ struct Vec_t<__nv_bfloat16> { + tmp_index -= target_batch_id * seq_len * 3 * n; +@@ -1098,6 +1970,41 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, + } + } + ++template ++__global__ void transposeQKV_kernel(T* q_buf, ++ T* k_buf, ++ T* v_buf, ++ const T* __restrict QKV, ++ const int batch_size, ++ const int seq_len, ++ const int head_num, ++ const int size_per_head) ++{ ++ // QKV: [m, 3, n] ++ // qkv_bias: [3, n] ++ // q_buf, k_buf, v_buf: [batch, head_num, seq_len, size_per_head] ++ ++ T* qkv_ptr[3] = {q_buf, k_buf, v_buf}; ++ const int n = head_num * size_per_head; ++ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 3 * n; ++ index += gridDim.x * blockDim.x) { ++ T val = ldg(&QKV[index]); ++ int tmp_index = index; ++ const int target_batch_id = tmp_index / (seq_len * 3 * n); ++ tmp_index -= target_batch_id * seq_len * 3 * n; ++ const int seq_id = tmp_index / (3 * n); ++ tmp_index -= seq_id * 3 * n; ++ const int qkv_id = tmp_index / n; ++ tmp_index -= qkv_id * n; ++ const int head_id = tmp_index / size_per_head; ++ const int size_id = tmp_index - head_id * size_per_head; ++ const int dst_id = target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head ++ + seq_id * size_per_head + size_id; ++ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head ++ + seq_id * size_per_head + size_id] = val; ++ } ++} ++ + template + struct Vec_t {}; + template<> +@@ -1116,12 +2023,12 @@ struct Vec_t<__nv_bfloat16> { }; #endif @@ -1839,7 +2675,7 @@ 
index f951e71..4455879 100644 const int batch_size, const int seq_len, const int head_num, -@@ -1170,12 +1673,12 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, +@@ -1170,12 +2077,12 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, *reinterpret_cast(&v_buf[dest_idx]) = v; } @@ -1854,10 +2690,46 @@ index f951e71..4455879 100644 const int batch_size, const int seq_len, const int head_num, -@@ -1200,6 +1703,155 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, +@@ -1183,23 +2090,262 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, + const int rotary_embedding_dim, + cudaStream_t stream) + { +- if (rotary_embedding_dim == 0) { ++ if (qkv_bias != nullptr) { ++ if (rotary_embedding_dim == 0) { ++ const int m = batch_size * seq_len; ++ const int n = head_num * size_per_head; ++ dim3 block(384); ++ dim3 grid((int)(ceil(1.0 * m * n / 384))); ++ add_fusedQKV_bias_transpose_kernel<<>>( ++ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); ++ } ++ else { ++ std::cout << "2"<< std::endl; ++ // To implement rotary embeddings, each thread processes two QKV elems: ++ dim3 block((size_per_head / 2 + 31) / 32 * 32); ++ dim3 grid(seq_len, head_num, batch_size); ++ add_fusedQKV_bias_transpose_kernel<<>>( ++ q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); ++ } ++ } else { + const int m = batch_size * seq_len; + const int n = head_num * size_per_head; + dim3 block(384); + dim3 grid((int)(ceil(1.0 * m * n / 384))); +- add_fusedQKV_bias_transpose_kernel<<>>( +- q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); ++ transposeQKV_kernel<<>>( ++ q_buf, k_buf, v_buf, QKV, batch_size, seq_len, head_num, size_per_head); } - } - +- else { +- // To implement rotary embeddings, each thread processes two QKV elems: +- dim3 block((size_per_head / 2 + 31) / 32 * 32); +- dim3 grid(seq_len, head_num, batch_size); +- add_fusedQKV_bias_transpose_kernel<<>>( +- q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, rotary_embedding_dim); ++} ++ + +template +__global__ void invokeCrossAddFusedQKVBiasTransposeQ(T* q_buf, @@ -1895,6 +2767,38 @@ index f951e71..4455879 100644 +} + +template ++__global__ void invokeCrossTransposeQ(T* q_buf, ++ const T* __restrict QKV, ++ const int batch_size, ++ const int seq_len, ++ const int head_num, ++ const int size_per_head) ++{ ++ // QKV: [m, 1, n] ++ // q_buf: [batch, head_num, seq_len, size_per_head] ++ ++ T* qkv_ptr[1] = {q_buf}; ++ const int n = head_num * size_per_head; ++ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 1 * n; ++ index += gridDim.x * blockDim.x) { ++ T val = ldg(&QKV[index]); ++ ++ int tmp_index = index; ++ const int target_batch_id = tmp_index / (seq_len * 1 * n); ++ tmp_index -= target_batch_id * seq_len * 1 * n; ++ const int seq_id = tmp_index / (1 * n); ++ tmp_index -= seq_id * 1 * n; ++ const int qkv_id = tmp_index / n; ++ tmp_index -= qkv_id * n; ++ const int head_id = tmp_index / size_per_head; ++ const int size_id = tmp_index - head_id * size_per_head; ++ ++ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head ++ + seq_id * size_per_head + size_id] = val; ++ } ++} ++ ++template +__global__ void invokeCrossAddFusedQKVBiasTransposeKV(T* k_buf, T* v_buf, + const T* __restrict QKV, + const U* __restrict qkv_bias, @@ -1929,6 +2833,39 @@ index f951e71..4455879 100644 + } +} + ++ ++template ++__global__ void 
invokeCrossTransposeKV(T* k_buf, T* v_buf, ++ const T* __restrict QKV, ++ const int batch_size, ++ const int seq_len, ++ const int head_num, ++ const int size_per_head) ++{ ++ // QKV: [m, 2, n] ++ // k_buf, v_buf: [batch, head_num, seq_len, size_per_head] ++ ++ T* qkv_ptr[2] = {k_buf, v_buf}; ++ const int n = head_num * size_per_head; ++ for (int index = blockDim.x * blockIdx.x + threadIdx.x; index < batch_size * seq_len * 2 * n; ++ index += gridDim.x * blockDim.x) { ++ T val = ldg(&QKV[index]); ++ ++ int tmp_index = index; ++ const int target_batch_id = tmp_index / (seq_len * 2 * n); ++ tmp_index -= target_batch_id * seq_len * 2 * n; ++ const int seq_id = tmp_index / (2 * n); ++ tmp_index -= seq_id * 2 * n; ++ const int qkv_id = tmp_index / n; ++ tmp_index -= qkv_id * n; ++ const int head_id = tmp_index / size_per_head; ++ const int size_id = tmp_index - head_id * size_per_head; ++ //printf("%d %d\n", head_id, size_id); ++ qkv_ptr[qkv_id][target_batch_id * head_num * seq_len * size_per_head + head_id * seq_len * size_per_head ++ + seq_id * size_per_head + size_id] = val; ++ } ++} ++ +template +void invokeCrossAddFusedQKVBiasTranspose(T* q_buf, + T* k_buf, @@ -1942,23 +2879,38 @@ index f951e71..4455879 100644 + const int size_per_head, + cudaStream_t stream) +{ ++ if (qkv_bias != nullptr) { ++ const int m = batch_size * seq_len; ++ const int n = head_num * size_per_head; ++ dim3 block(384); ++ dim3 grid((int)(ceil(1.0 * m * n / 384))); ++ invokeCrossAddFusedQKVBiasTransposeQ<<>>( ++ q_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); ++ ++ const int m2 = batch_size * tgt_seq_len; ++ const int n2 = head_num * size_per_head; ++ dim3 block2(384); ++ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); ++ invokeCrossAddFusedQKVBiasTransposeKV<<>>( ++ k_buf, v_buf, QKV + m * n, qkv_bias + n2, batch_size, tgt_seq_len, head_num, size_per_head); ++ } else { ++ const int m = batch_size * seq_len; ++ const int n = head_num * size_per_head; ++ dim3 block(384); ++ dim3 grid((int)(ceil(1.0 * m * n / 384))); ++ invokeCrossTransposeQ<<>>( ++ q_buf, QKV, batch_size, seq_len, head_num, size_per_head); ++ ++ const int m2 = batch_size * tgt_seq_len; ++ const int n2 = head_num * size_per_head; ++ dim3 block2(384); ++ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); ++ invokeCrossTransposeKV<<>>( ++ k_buf, v_buf, QKV + m * n, batch_size, tgt_seq_len, head_num, size_per_head); + } + -+ const int m = batch_size * seq_len; -+ const int n = head_num * size_per_head; -+ dim3 block(384); -+ dim3 grid((int)(ceil(1.0 * m * n / 384))); -+ invokeCrossAddFusedQKVBiasTransposeQ<<>>( -+ q_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head); -+ -+ const int m2 = batch_size * tgt_seq_len; -+ const int n2 = head_num * size_per_head; -+ dim3 block2(384); -+ dim3 grid2((int)(ceil(1.0 * m2 * n2 / 384))); -+ invokeCrossAddFusedQKVBiasTransposeKV<<>>( -+ k_buf, v_buf, QKV + m * n, qkv_bias + n2, batch_size, tgt_seq_len, head_num, size_per_head); -+ -+} -+ + } + +template void invokeCrossAddFusedQKVBiasTranspose(float* q_buf, + float* k_buf, + float* v_buf, @@ -2010,7 +2962,7 @@ index f951e71..4455879 100644 template void invokeAddFusedQKVBiasTranspose(float* q_buf, float* k_buf, float* v_buf, -@@ -1224,6 +1876,30 @@ template void invokeAddFusedQKVBiasTranspose(half* q_buf, +@@ -1224,6 +2370,30 @@ template void invokeAddFusedQKVBiasTranspose(half* q_buf, const int rotary_embedding_dim, cudaStream_t stream); @@ -2041,7 +2993,7 @@ index f951e71..4455879 100644 #ifdef ENABLE_BF16 template void 
invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, __nv_bfloat16* k_buf, -@@ -1236,6 +1912,19 @@ template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, +@@ -1236,6 +2406,19 @@ template void invokeAddFusedQKVBiasTranspose(__nv_bfloat16* q_buf, const int size_per_head, const int rotary_embedding_dim, cudaStream_t stream); @@ -2061,11 +3013,435 @@ index f951e71..4455879 100644 #endif template +@@ -1860,4 +3043,423 @@ template void invokeMaskedSoftMaxWithRelPosBias(half* qk_buf, + const float qk_scale, + cudaStream_t stream); + ++ ++ ++template ++__global__ void attention_kernel(T* query_buf, ++ const T* Q_bias, ++ T* key_cache, ++ const T* K_bias, ++ T* value_cache, ++ const T* V_bias, ++ const int* length_per_sample, ++ T* context_buf, ++ const bool* finished, ++ int batch_size, ++ int head_num, ++ int size_per_head, ++ int step, ++ const int seq_len, ++ const T scalar) ++{ ++ if (finished != nullptr && finished[blockIdx.x / head_num] == true) { ++ return; ++ } ++ int tid = threadIdx.x; ++ int bid = blockIdx.x / head_num; ++ int head_id = blockIdx.x % head_num; ++ ++ extern __shared__ __align__(sizeof(T)) unsigned s_buf[]; ++ T* sq = reinterpret_cast(s_buf); ++ T* logits = reinterpret_cast(&sq[size_per_head]); ++ ++ int length = __ldg(&length_per_sample[bid]); ++ ++ int qkv_id = bid * head_num * size_per_head + head_id * size_per_head + tid; ++ int qkv_bias_id = head_id * size_per_head + tid; ++ ++ if (tid < size_per_head) { ++ sq[tid] = query_buf[qkv_id] + Q_bias[qkv_bias_id]; ++ } ++ __syncthreads(); ++ ++ for (int ite = 0; ite < length; ++ite) { ++ int key_id = bid * (seq_len * head_num * size_per_head) + ite * (head_num * size_per_head) ++ + head_id * size_per_head + tid; ++ ++ T key = tid < size_per_head ? key_cache[key_id] : (T)(0.0f); ++ ++ // For the first step, we should add bias to key memory cache. ++ // The KV memory cache only need to be updated at the first step. ++ if (step == 1 && tid < size_per_head) { ++ key += K_bias[head_id * size_per_head + tid]; ++ key_cache[key_id] = key; ++ } ++ ++ T val = (tid < size_per_head) ? key * sq[tid] * scalar : (T)(0.0f); ++ T qk = blockReduceSum(val); ++ if (threadIdx.x == 0) { ++ logits[ite] = qk; ++ } ++ __syncthreads(); // try to remove ++ } ++ __syncthreads(); ++ ++ __shared__ float s_max_val, s_sum; ++ ++ float local_i = tid < length ? (float)logits[tid] : -1e20f; ++ float max_val = blockReduceMax(local_i); ++ if (tid == 0) { ++ s_max_val = max_val; ++ } ++ __syncthreads(); ++ ++ local_i -= s_max_val; ++ float local_o = tid < length ? 
__expf(local_i) : 0.0f; ++ float val = blockReduceSum(local_o); ++ ++ if (tid == 0) { ++ s_sum = val + 1e-6; ++ } ++ __syncthreads(); ++ if (tid < length) { ++ logits[tid] = local_o / s_sum; ++ } ++ __syncthreads(); ++ ++ if (tid < size_per_head) { ++ T sum = (T)0.0f; ++ for (int ite = 0; ite < length; ++ite) { ++ int value_id = bid * seq_len * head_num * size_per_head + ite * head_num * size_per_head ++ + head_id * size_per_head + tid; ++ ++ T value = value_cache[value_id]; ++ ++ // for the first step, we should add bias to key memory cache ++ if (step == 1) { ++ value += V_bias[head_id * size_per_head + tid]; ++ value_cache[value_id] = value; ++ } ++ sum += value * logits[ite]; ++ } ++ context_buf[bid * head_num * size_per_head + head_id * size_per_head + tid] = sum; ++ } ++} ++ ++template ++__global__ void attention_kernel_opt( ++ const T* __restrict qkv_buf, ++ const T* __restrict qkv_bias, ++ const T* __restrict attr_mask, ++ T* __restrict out_buf, ++ T* __restrict key_cache_output, ++ T* __restrict value_cache_output, ++ int batch_size, ++ int head_num, ++ const int seq_len, ++ const float scalar) ++{ ++ typedef Copy_t copy_t; ++ const int elems_per_thread = size_per_head / WARP_SIZE; ++ union Access_t { ++ copy_t v; ++ T x[elems_per_thread]; // supported size 1,2,4 ++ }; ++ typedef struct Float_n_t { ++ float x[elems_per_thread]; // supported size 1,2,4 ++ } float_n_t; ++ ++ __shared__ float_n_t sq[block_sz]; ++ extern __shared__ float logits[]; // use to store the logits from [0~step] ++ ++ const int warp_id = threadIdx.x / WARP_SIZE; ++ const int warp_num = block_sz / WARP_SIZE; ++ ++ typedef cub::BlockReduce MaxValBlockReduce; ++ typedef cub::BlockReduce BlockReduce; ++ __shared__ typename MaxValBlockReduce::TempStorage max_val_block_temp_storage; ++ __shared__ typename BlockReduce::TempStorage block_temp_storage; ++ ++ __shared__ typename cub::WarpReduce::TempStorage temp_storage[warp_num]; ++ ++ const int tid = threadIdx.x; ++ const int bid = blockIdx.x / head_num; ++ const int head_id = blockIdx.x % head_num; ++ int seq_id = blockIdx.y; ++ ++ ++ int length = seq_len; ++ const int lane_id = tid % WARP_SIZE; ++ ++ // QKV [m 3 n] shape ++ int qkv_id = bid * (3 * seq_len * head_num * size_per_head) + seq_id * (3 * head_num * size_per_head) ++ + head_id * size_per_head; ++ int q_id = bid * (seq_len * head_num * size_per_head) + seq_id * (head_num * size_per_head) ++ + head_id * size_per_head; ++ int qkv_bias_id = head_id * size_per_head; ++ int key_id = bid * (3 * seq_len * head_num * size_per_head) + head_num * size_per_head + head_id * size_per_head; ++ int value_id = bid * (3 * seq_len * head_num * size_per_head) + 2 * head_num * size_per_head + head_id * size_per_head; ++ ++ int key_trn_id = bid * (seq_len * head_num * size_per_head) + head_id * (size_per_head * seq_len); ++ int value_trn_id = bid * (seq_len * head_num * size_per_head) + head_id * (size_per_head * seq_len); ++ int mask_offset = bid * (seq_len * seq_len) + seq_id * seq_len; ++ ++ // get pointers ++ const T* query_buf = qkv_buf + qkv_id; ++ const T* Q_bias = qkv_bias + qkv_bias_id; ++ T* context_buf = out_buf + q_id; ++ ++ const T* key_cache = qkv_buf + key_id; ++ const T* K_bias = qkv_bias + head_num * size_per_head + qkv_bias_id; ++ T* key_cache_out = key_cache_output + key_trn_id; ++ ++ const T* value_cache = qkv_buf + value_id; ++ const T* V_bias = qkv_bias + 2 * head_num * size_per_head + qkv_bias_id; ++ T* value_cache_out = value_cache_output + value_trn_id; ++ ++ Access_t bias_r, key_val_r, query_buf_r; ++ 
// offset inside head ++ int minor_offset = lane_id; // offset in copy_t elements ++ // each warp will have its own copy of sq ++ query_buf_r.v = *((copy_t*)query_buf + minor_offset); ++ ++ bias_r.v = *((copy_t*)Q_bias + minor_offset); ++ float qb_r[elems_per_thread]; ++#pragma unroll ++ for (int i = 0; i < elems_per_thread; ++i) { ++ qb_r[i] = (float)query_buf_r.x[i] + (float)bias_r.x[i]; ++ } ++ ++ // offset for each step ++ int offset = 3 * head_num * size_per_head; ++ bias_r.v = *((copy_t*)K_bias + minor_offset); ++ for (int ite = warp_id; ite < length; ite += warp_num) { ++ key_val_r.v = *((copy_t*)&key_cache[ite * offset] + minor_offset); ++ ++ if (seq_id == 0) { ++ for (int i = 0; i < elems_per_thread; i++) { ++ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; ++ key_cache_out[ite + seq_len * (minor_offset * elems_per_thread + i)] = key_val_r.x[i]; ++ } ++ } else { ++ for (int i = 0; i < elems_per_thread; i++) { ++ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; ++ } ++ } ++ float val = 0; ++ for (int i = 0; i < elems_per_thread; i++) { ++ val = val + (float)key_val_r.x[i] * qb_r[i]; ++ } ++ float qk = cub::WarpReduce(temp_storage[warp_id]).Sum(val); ++ ++ if (lane_id == 0) { ++ T mask_val = attr_mask[mask_offset + ite]; ++ mask_val = (1.0f - mask_val) * -10000.0f; ++ logits[ite] = qk * scalar + mask_val; ++ } ++ } ++ ++ __syncthreads(); ++ ++ __shared__ float s_max_val, s_sum; ++ float local_i = -1e20f; ++ for (int i = tid; i < length; i += blockDim.x) { ++ local_i = max(local_i, logits[i]); ++ } ++ ++ float max_val = MaxValBlockReduce(max_val_block_temp_storage).Reduce(local_i, cub::Max()); ++ if (tid == 0) { ++ s_max_val = max_val; ++ } ++ __syncthreads(); ++ ++ float local_o = 0.0f; ++ for (int i = tid; i < length; i += blockDim.x) { ++ logits[i] = __expf(logits[i] - s_max_val); ++ local_o += logits[i]; ++ } ++ float val = BlockReduce(block_temp_storage).Sum(local_o); ++ ++ if (tid == 0) { ++ s_sum = val + 1e-6; ++ } ++ __syncthreads(); ++ ++ float s_sum_inverse = __fdividef(1.0f, s_sum); ++ for (int i = tid; i < length; i += blockDim.x) { ++ logits[i] = logits[i] * s_sum_inverse; ++ } ++ __syncthreads(); ++ ++ // This optimization introduces discrepancy because of different order in FP32 summation ++ float sum_r[elems_per_thread] = {0.f}; ++ bias_r.v = *((copy_t*)V_bias + minor_offset); ++ for (int ite = warp_id; ite < length; ite += warp_num) { ++ key_val_r.v = *((copy_t*)&value_cache[ite * offset] + minor_offset); ++#pragma unroll ++ for (int i = 0; i < elems_per_thread; i++) { ++ key_val_r.x[i] = (float)key_val_r.x[i] + (float)bias_r.x[i]; ++ } ++ if(seq_id == 0) ++ *((copy_t*)&value_cache_out[ite * size_per_head] + minor_offset) = key_val_r.v; ++#pragma unroll ++ for (int i = 0; i < elems_per_thread; ++i) { ++ sum_r[i] += (float)key_val_r.x[i] * logits[ite]; ++ } ++ } ++ for (int i = 0; i < elems_per_thread; i++) { ++ sq[warp_id * WARP_SIZE + lane_id].x[i] = sum_r[i]; ++ } ++ __syncthreads(); ++ if (threadIdx.x < WARP_SIZE) { ++#pragma unroll ++ for (int j = 1; j < warp_num; j++) { ++ for (int i = 0; i < elems_per_thread; ++i) { ++ sum_r[i] = sum_r[i] + (float)sq[j * WARP_SIZE + threadIdx.x].x[i]; ++ } ++ } ++ } ++ __syncthreads(); ++#pragma unroll ++ for (int i = 0; i < elems_per_thread; i++) { ++ key_val_r.x[i] = sum_r[i]; ++ } ++ if (threadIdx.x < WARP_SIZE) { ++ *((copy_t*)context_buf + minor_offset) = key_val_r.v; ++ } ++} ++ ++template ++void myAttnention( ++ const T* qkv_buf, ++ const T* qkv_bias, ++ const T* attr_mask, ++ T* context_buf, ++ 
T* key_cache_out, ++ T* value_cache_out, ++ const int inference_batch_size, ++ const int head_num, ++ const int size_per_head, ++ const int seq_len, ++ const float q_scaling, ++ cudaStream_t stream) ++{ ++ const int block_sz = ATTENTION_BLOCK_SIZE; // blockDim.x ++ float scalar = 1.f / (sqrtf(size_per_head * 1.0f) * q_scaling); ++ ++ dim3 grid(inference_batch_size * head_num, seq_len); // gridDim.x gridDim.y ++ int cond = size_per_head * ((ATTENION_OPT) ? 1 : 0); ++ switch (cond) { ++ case 32: ++ attention_kernel_opt ++ <<>>(qkv_buf, ++ qkv_bias, ++ attr_mask, ++ context_buf, ++ key_cache_out, ++ value_cache_out, ++ inference_batch_size, ++ head_num, ++ seq_len, ++ scalar); ++ break; ++ case 64: ++ attention_kernel_opt ++ <<>>(qkv_buf, ++ qkv_bias, ++ attr_mask, ++ context_buf, ++ key_cache_out, ++ value_cache_out, ++ inference_batch_size, ++ head_num, ++ seq_len, ++ scalar); ++ break; ++ case 128: ++ attention_kernel_opt ++ <<>>(qkv_buf, ++ qkv_bias, ++ attr_mask, ++ context_buf, ++ key_cache_out, ++ value_cache_out, ++ inference_batch_size, ++ head_num, ++ seq_len, ++ scalar); ++ break; ++ default: ++ ; ++ // default path ++ ++ // int block_size = 128; ++ ++ // if (seq_len <= 64) { ++ // block_size = 64; ++ // } ++ // else if (seq_len <= 128 && seq_len > size_per_head) { ++ // block_size = 128; ++ // } ++ // else if (seq_len > 128 && seq_len <= 256) { ++ // block_size = 256; ++ // } ++ // else if (seq_len > 256 && seq_len <= 512) { ++ // block_size = 512; ++ // } ++ // else { ++ // block_size = 1024; ++ // } ++ ++ // if (block_size < size_per_head) { ++ // block_size = size_per_head; ++ // } ++ ++ // assert(block_size <= 1024); ++ // dim3 block(block_size); ++ ++ // int shared_size = sizeof(T) * (size_per_head + seq_len); ++ // attention_kernel<<>>(query_buf, ++ // Q_bias, ++ // key_cache, ++ // K_bias, ++ // value_cache, ++ // V_bias, ++ // length, ++ // context_buf, ++ // finished, ++ // max_batch_size, ++ // head_num, ++ // size_per_head, ++ // step, ++ // seq_len, ++ // scalar); ++ } ++} ++ ++template void myAttnention(const float* qkv_buf, ++ const float* qkv_bias, ++ const float *attr_mask, ++ float* context_buf, ++ float* key_cache_out, ++ float* value_cache_out, ++ const int inference_batch_size, ++ const int head_num, ++ const int size_per_head, ++ const int seq_len, ++ const float q_scaling, ++ cudaStream_t stream); ++ ++// template void myAttnention(const half* qkv_buf, ++// const half* qkv_bias, ++// const half *attr_mask, ++// half* context_buf, ++// half* key_cache_out, ++// half* value_cache_out, ++// const int inference_batch_size, ++// const int head_num, ++// const int size_per_head, ++// const int seq_len, ++// const float q_scaling, ++// cudaStream_t stream); + } // namespace fastertransformer diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h -index be8b178..21d9a62 100644 +index be8b178..9e8be09 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_kernels.h -@@ -42,6 +42,26 @@ void invokeMaskedSoftMax(T* buffer, +@@ -42,6 +42,37 @@ void invokeMaskedSoftMax(T* buffer, const int head_num, const T scalar, cudaStream_t stream); @@ -2088,11 +3464,22 @@ index be8b178..21d9a62 100644 + const int tgt_seq_len, + const int head_num, + const T scalar, ++ cudaStream_t stream); ++ ++template ++void invokeMixMaskedSoftMax(T* io_buffer, ++ const T_M* attr_mask, ++ const T* position_bias, ++ const int batch_size, ++ const int 
seq_len, ++ const int tgt_seq_len, ++ const int head_num, ++ const T scalar, + cudaStream_t stream); template void invokeTransposeQKV(T* dst, -@@ -81,12 +101,12 @@ void invokeTransposeAttentionOutRemovePadding(T* src, +@@ -81,12 +112,12 @@ void invokeTransposeAttentionOutRemovePadding(T* src, const int* mask_offset, cudaStream_t stream); @@ -2107,7 +3494,7 @@ index be8b178..21d9a62 100644 const int batch_size, const int seq_len, const int head_num, -@@ -97,12 +117,29 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, +@@ -97,12 +128,29 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, q_buf, k_buf, v_buf, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, 0, stream); } @@ -2139,11 +3526,41 @@ index be8b178..21d9a62 100644 const int batch_size, const int seq_len, const int head_num, +@@ -166,4 +214,21 @@ void invokeMaskedSoftMaxWithRelPosBias(T* qk_buf, + const float qk_scale, + cudaStream_t stream); + ++ ++template ++void myAttnention(const T* qkv_buf, ++ const T* qkv_bias, ++ const T *attr_mask, ++ T* context_buf, ++ T* key_cache_out, ++ T* value_cache_out, ++ const int inference_batch_size, ++ const int head_num, ++ const int size_per_head, ++ const int seq_len, ++ const float q_scaling, ++ cudaStream_t stream); ++ ++ ++ + } // namespace fastertransformer diff --git a/src/fastertransformer/layers/CMakeLists.txt b/src/fastertransformer/layers/CMakeLists.txt -index cbaf4fa..2ab5320 100644 +index cbaf4fa..00a46d4 100644 --- a/src/fastertransformer/layers/CMakeLists.txt +++ b/src/fastertransformer/layers/CMakeLists.txt -@@ -30,15 +30,18 @@ set_property(TARGET FfnLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) +@@ -14,6 +14,7 @@ + + cmake_minimum_required(VERSION 3.8) + ++add_subdirectory(encoder_layers) + add_subdirectory(attention_layers) + add_subdirectory(attention_layers_int8) + add_subdirectory(xlnet_attention_layers) +@@ -30,15 +31,18 @@ set_property(TARGET FfnLayerINT8 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET FfnLayerINT8 PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(FfnLayerINT8 PUBLIC -lcublasLt -lcublas -lcudart cublasMMWrapper cublasINT8MMWrapper activation_int8_kernels memory_utils) @@ -2206,7 +3623,7 @@ index 9cef315..7170af4 100644 add_library(WindowAttention STATIC WindowAttention.cc) set_property(TARGET WindowAttention PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc -index bada640..e606bc2 100644 +index bada640..3ed150b 100644 --- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.cc @@ -17,6 +17,7 @@ @@ -2226,7 +3643,7 @@ index bada640..e606bc2 100644 const int m = input_tensors->at(0).shape[0]; -@@ -428,4 +429,503 @@ template class GptContextAttentionLayer; +@@ -428,4 +429,520 @@ template class GptContextAttentionLayer; template class GptContextAttentionLayer<__nv_bfloat16>; #endif @@ -2244,6 +3661,8 @@ index bada640..e606bc2 100644 + + + ++// HAIM Playground MS-MHA ++ +template +MSMHALayer::MSMHALayer(size_t max_batch_size, + size_t max_src_seq_len, @@ -2255,7 +3674,8 @@ index bada640..e606bc2 100644 + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, -+ bool sparse): ++ bool sparse, ++ bool is_position_bias): + BaseAttentionLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, sparse), + max_batch_size_(max_batch_size), + 
max_src_seq_len_(max_src_seq_len), @@ -2264,42 +3684,31 @@ index bada640..e606bc2 100644 + size_per_head_(size_per_head), + hidden_size_(head_num * size_per_head), + is_qk_buf_float_(false), // for now set to false -+ optimized_offset_(0) ++ optimized_offset_(0), ++ is_position_bias_(is_position_bias) +{ +} + ++ +template -+void MSMHALayer::forward(std::vector* output_tensors, ++void MSMHALayer::forward_1(std::vector* output_tensors, + const std::vector* input_tensors, -+ const AttentionWeight* attention_weights) -+{ -+ // input_tensors: use 1 gemm -- multi head attention -+ // input_query [batch_size * seq_len, hidden_dimension] -+ // attention_mask [batch_size, 1, seq_len, seq_len] -+ -+ // input_tensors: use 2 gemm -- cross attention -+ // input_query [batch_size * seq_len, hidden_dimension] -+ // enc_output [batch_size * tgt_len, hidden_dimension] -+ // attention_mask [batch_size, 1, seq_len, seq_len] -+ -+ // output_tensors: -+ // attention_out [batch_size * seq_len, hidden_dimension] -+ // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] -+ // value_cache [batch, local_head_num, max_seq_len, size_per_head] -+ -+ -+ // validate in / out tensors ++ const AttentionWeight* attention_weights) { + int in_tensor_number = input_tensors->size(); -+ FT_CHECK(in_tensor_number == 2 || in_tensor_number == 3); -+ FT_CHECK(output_tensors->size() == 3); ++// FT_CHECK(in_tensor_number == 2 || in_tensor_number == 3); ++ // FT_CHECK(output_tensors->size() == 3); + // FT_CHECK(isValidBatchSize(request_batch_size); + // FT_CHECK(isValidSrcSeqLen(request_src_seq_len); + // FT_CHECK(isValidTgtSeqLen(request_tgt_seq_len); + + // setup in tensors id + int attn_input_tensor_id = 0; -+ int encoder_out_tensor_id = (in_tensor_number == 3) ? 1 : 0; ++ int encoder_out_tensor_id = ((in_tensor_number == 3) || ((in_tensor_number == 4) && is_position_bias_)) ? 1 : 0; + int attn_mask_tensor_id = in_tensor_number - 1; // always last tensor ++ int position_bias_tensor_id = in_tensor_number - 1; ++ if(is_position_bias_) { // position bias is last ++ attn_mask_tensor_id -= 1; ++ } + + const int request_batch_size = input_tensors->at(attn_mask_tensor_id).shape[0]; + const int request_src_seq_len = input_tensors->at(attn_mask_tensor_id).shape[2]; @@ -2312,12 +3721,16 @@ index bada640..e606bc2 100644 + const T* attention_input = (const T*)input_tensors->at(attn_input_tensor_id).data; + const T* attention_mask = (const T*)input_tensors->at(attn_mask_tensor_id).data; + const bool is_final = false; -+ ++ const S* position_bias = (is_position_bias_) ? (const S*)input_tensors->at(position_bias_tensor_id).data : nullptr; ++ // std::cout << "attn_input_tensor_id: " << attn_input_tensor_id << std::endl; ++ // std::cout << "attn_mask_tensor_id: " << attn_mask_tensor_id << std::endl; ++ // std::cout << "encoder_out_tensor_id: " << encoder_out_tensor_id << std::endl; ++ // std::cout << "position_bias_tensor_id: " << position_bias_tensor_id << std::endl; + const cudaDataType_t gemm_data_type = (std::is_same::value) ? CUDA_R_32F : CUDA_R_16F; + const cudaDataType_t softmax_data_type = (std::is_same::value) ? 
CUDA_R_32F : CUDA_R_16F; + + const int m = input_tensors->at(attn_input_tensor_id).shape[0]; -+ if (input_tensors->size() == 3) { ++ if (((input_tensors->size() == 3) && !is_position_bias_) || ((input_tensors->size() == 4) && is_position_bias_)) { + // cross attention + Tensor encoder_output_tensor = input_tensors->at(encoder_out_tensor_id); + if (request_src_seq_len == request_tgt_seq_len) { @@ -2371,6 +3784,7 @@ index bada640..e606bc2 100644 + 2 * hidden_size_ /* n */); + } + } else { ++ std::cout << "--------------------------------------------------------------------"<< std::endl; + cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + 3 * hidden_size_, // n @@ -2382,14 +3796,14 @@ index bada640..e606bc2 100644 + hidden_size_, // k + qkv_buf_, + 3 * hidden_size_ /* n */); ++ std::cout << "--------------------------------------------------------------------"<< std::endl; + } + sync_check_cuda_error(); -+ + if (request_src_seq_len == request_tgt_seq_len) { + invokeAddFusedQKVBiasTranspose( + (T*)q_buf_2_, -+ (T*)output_tensors->at(1).data, //k_buf_2_, -+ (T*)output_tensors->at(2).data, //v_buf_2_, ++ (T*)output1_, //(T*)output_tensors->at(1).data, //(T*)output1_, //k_buf_2_, ++ (T*)output2_, //(T*)output_tensors->at(2).data, //(T*)output2_, //v_buf_2_, + (T*)qkv_buf_, + (U*)attention_weights->query_weight.bias, + request_batch_size, @@ -2401,8 +3815,8 @@ index bada640..e606bc2 100644 + } else { + invokeCrossAddFusedQKVBiasTranspose( + (T*)q_buf_2_, -+ (T*)output_tensors->at(1).data, //k_buf_2_, -+ (T*)output_tensors->at(2).data, //v_buf_2_, ++ (T*)output1_, //(T*)output_tensors->at(1).data, //(T*)output1_, //k_buf_2_, ++ (T*)output2_, //(T*)output_tensors->at(2).data, //(T*)output2_, //v_buf_2_, + qkv_buf_, + attention_weights->query_weight.bias, + request_batch_size, @@ -2413,118 +3827,49 @@ index bada640..e606bc2 100644 + stream_); + } + sync_check_cuda_error(); -+ // Use batch major -+ // put k/v_buf from shape [B, H, L, Dh] -+ // to cache [B, H, Dh/x, L, x] and [B, H, L, Dh/x, x] -+ // invokeTranspose4dBatchMajor((T*)output_tensors->at(1).data, //k_buf_2_, -+ // (T*)output_tensors->at(2).data, //v_buf_2_, -+ // (T*)output_tensors->at(1).data, //k_buf_2_, -+ // (T*)output_tensors->at(2).data, //v_buf_2_, -+ // request_batch_size, -+ // request_tgt_seq_len, //request_seq_len, -+ // request_tgt_seq_len, //max_seq_len, -+ // size_per_head_, -+ // head_num_, -+ // stream_); -+ // sync_check_cuda_error(); -+ -+ if (is_final == false) { ++ { + const cudaDataType_t gemm_data_type = getCudaDataType(); -+ if (is_qk_buf_float_ == true && gemm_data_type != CUDA_R_32F) { -+ // cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, -+ // CUBLAS_OP_N, -+ // request_seq_len, -+ // request_seq_len, -+ // size_per_head_, -+ // 1.0f, -+ // (T*)output_tensors->at(1).data, // k_buf_2_, -+ // gemm_data_type, -+ // size_per_head_, -+ // request_seq_len * size_per_head_, -+ // q_buf_2_, -+ // gemm_data_type, -+ // size_per_head_, -+ // request_seq_len * size_per_head_, -+ // 0.0f, -+ // qk_buf_float_, -+ // CUDA_R_32F, -+ // request_seq_len, -+ // request_seq_len * request_seq_len, -+ // request_batch_size * head_num_, -+ // CUDA_R_32F); -+ // sync_check_cuda_error(); -+ // T scalar = 1 / sqrtf(size_per_head_ * 1.0f); -+ // invokeMaskedSoftMax(qk_buf_, -+ // qk_buf_float_, -+ // attention_mask, -+ // request_batch_size, -+ // request_seq_len, -+ // head_num_, -+ // scalar, -+ // stream_); -+ // sync_check_cuda_error(); -+ } -+ else { -+ cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, -+ CUBLAS_OP_N, -+ 
request_tgt_seq_len, -+ request_src_seq_len, -+ size_per_head_, -+ 1.0f, -+ (T*)output_tensors->at(1).data, // k_buf_2_, -+ gemm_data_type, -+ size_per_head_, -+ request_tgt_seq_len * size_per_head_, -+ q_buf_2_, -+ gemm_data_type, -+ size_per_head_, -+ request_src_seq_len * size_per_head_, -+ 0.0f, -+ qk_buf_, -+ softmax_data_type, -+ request_tgt_seq_len, -+ request_src_seq_len * request_tgt_seq_len, -+ request_batch_size * head_num_, -+ CUDA_R_32F); ++ cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_T, ++ CUBLAS_OP_N, ++ request_tgt_seq_len, ++ request_src_seq_len, ++ size_per_head_, ++ 1.0f, ++ (T*)output1_, //(T*)output_tensors->at(1).data, //(T*)output1_, // k_buf_2_, ++ gemm_data_type, ++ size_per_head_, ++ request_tgt_seq_len * size_per_head_, ++ q_buf_2_, ++ gemm_data_type, ++ size_per_head_, ++ request_src_seq_len * size_per_head_, ++ 0.0f, ++ qk_buf_, ++ softmax_data_type, ++ request_tgt_seq_len, ++ request_src_seq_len * request_tgt_seq_len, ++ request_batch_size * head_num_, ++ CUDA_R_32F); + -+ S scalar = (S) (1.0f / sqrtf(size_per_head_ * 1.0f)); -+ invokeMixMaskedSoftMax(qk_buf_, -+ attention_mask, -+ request_batch_size, -+ request_src_seq_len, -+ request_tgt_seq_len, -+ head_num_, -+ scalar, -+ stream_); -+ -+ // if (request_src_seq_len == request_tgt_seq_len) { -+ // invokeMaskedSoftMax(qk_buf_, -+ // qk_buf_, -+ // attention_mask, -+ // request_batch_size, -+ // request_tgt_seq_len, -+ // head_num_, -+ // scalar, -+ // stream_); -+ // } else { -+ // invokeCrossMaskedSoftMax(qk_buf_, -+ // qk_buf_, -+ // attention_mask, -+ // request_batch_size, -+ // request_src_seq_len, request_tgt_seq_len, -+ // head_num_, -+ // scalar, -+ // stream_); -+ // } -+ sync_check_cuda_error(); -+ } ++ S scalar = (S) (1.0f / sqrtf(size_per_head_ * 1.0f)); ++ sync_check_cuda_error(); ++ invokeMixMaskedSoftMax(qk_buf_, ++ attention_mask, ++ position_bias, ++ request_batch_size, ++ request_src_seq_len, ++ request_tgt_seq_len, ++ head_num_, ++ scalar, ++ stream_); ++ sync_check_cuda_error(); + cublas_wrapper_->stridedBatchedGemm(CUBLAS_OP_N, + CUBLAS_OP_N, + size_per_head_, + request_src_seq_len, + request_tgt_seq_len, + 1.0f, -+ (T*)output_tensors->at(2).data, // v_buf_2_, ++ (T*)output2_, //(T*)output_tensors->at(2).data, //(T*)output2_, // v_buf_2_, + gemm_data_type, + size_per_head_, + request_tgt_seq_len * size_per_head_, @@ -2539,9 +3884,7 @@ index bada640..e606bc2 100644 + request_src_seq_len * size_per_head_, + request_batch_size * head_num_, + CUDA_R_32F); -+ sync_check_cuda_error(); -+ -+ ++ sync_check_cuda_error(); + invokeTransposeQKV( + qkv_buf_3_, + qkv_buf_2_, @@ -2551,20 +3894,86 @@ index bada640..e606bc2 100644 + size_per_head_, + stream_); + sync_check_cuda_error(); -+// #ifdef SPARSITY_ENABLED -+// if (sparse_ && cublas_wrapper_->isUseSparse(1, hidden_size_, m_padded, local_hidden_size_)) { -+// cublas_wrapper_->SpGemm(CUBLAS_OP_N, -+// CUBLAS_OP_N, -+// hidden_size_, -+// m_padded, -+// local_hidden_size_, -+// attention_weights->attention_output_weight.sp_kernel, -+// qkv_buf_3_, -+// attention_out); -+// } -+// else { -+// #endif -+ cublas_wrapper_->Gemm(CUBLAS_OP_N, ++ cublas_wrapper_->Gemm(CUBLAS_OP_N, ++ CUBLAS_OP_N, ++ hidden_size_, ++ m, ++ hidden_size_, ++ attention_weights->attention_output_weight.kernel, ++ hidden_size_, ++ qkv_buf_3_, ++ hidden_size_, ++ (T*)output_tensors->at(0).data, ++ hidden_size_); ++ sync_check_cuda_error(); ++ if (!is_position_bias_) { ++ int len = request_batch_size * request_src_seq_len; ++ invokeAddBias((T*)output_tensors->at(0).data, (const 
T*)attention_weights->attention_output_weight.bias, len, hidden_size_, stream_); ++ } ++ sync_check_cuda_error(); ++ } ++ if (is_free_buffer_after_forward_ == true) { ++ freeBuffer(); ++ } ++ sync_check_cuda_error(); ++} ++ ++template ++void MSMHALayer::forward_2(std::vector* output_tensors, ++ const std::vector* input_tensors, ++ const AttentionWeight* attention_weights) { ++ T* attention_out = (T*)(output_tensors->at(0).data); ++ T* key_mem_cache = (T*)(output_tensors->at(1).data); ++ T* value_mem_cache = (T*)(output_tensors->at(2).data); ++ Tensor encoder_output_tensor = input_tensors->at(0); ++ ++ int in_tensor_number = input_tensors->size(); ++ int attn_input_tensor_id = 0; ++ int encoder_out_tensor_id = (in_tensor_number == 3) ? 1 : 0; ++ int attn_mask_tensor_id = in_tensor_number - 1; // always last tensor ++ ++ const int request_batch_size = input_tensors->at(attn_mask_tensor_id).shape[0]; ++ const int request_src_seq_len = input_tensors->at(attn_mask_tensor_id).shape[2]; ++ const int request_tgt_seq_len = input_tensors->at(attn_mask_tensor_id).shape[3]; ++ ++ // alloc buffer according to curent in size ++ allocateBuffer(request_batch_size, request_src_seq_len, request_tgt_seq_len); ++ sync_check_cuda_error(); ++ ++ const T* attention_input = (const T*)input_tensors->at(attn_input_tensor_id).data; ++ const T* attention_mask = (const T*)input_tensors->at(attn_mask_tensor_id).data; ++ const int m = input_tensors->at(attn_input_tensor_id).shape[0]; ++ ++ cublas_wrapper_->Gemm(CUBLAS_OP_N, ++ CUBLAS_OP_N, ++ 3 * hidden_size_, // n ++ m, ++ hidden_size_, // k ++ attention_weights->query_weight.kernel, ++ 3 * hidden_size_, // n ++ attention_input, ++ hidden_size_, // k ++ qkv_buf_, ++ 3 * hidden_size_ /* n */); ++ ++ sync_check_cuda_error(); ++ myAttnention( ++ (const float*) qkv_buf_, ++ (const float*) attention_weights->query_weight.bias, ++ (const float*) attention_mask, ++ (float*) qkv_buf_3_, ++ (float*) key_mem_cache, ++ (float*) value_mem_cache, ++ request_batch_size, ++ head_num_, ++ size_per_head_, ++ request_src_seq_len, ++ 1.0f, ++ stream_); ++ sync_check_cuda_error(); ++ ++ ++ cublas_wrapper_->Gemm(CUBLAS_OP_N, + CUBLAS_OP_N, + hidden_size_, + m, @@ -2579,14 +3988,29 @@ index bada640..e606bc2 100644 + sync_check_cuda_error(); + invokeAddBias((T*)output_tensors->at(0).data, (const T*)attention_weights->attention_output_weight.bias, len, hidden_size_, stream_); + sync_check_cuda_error(); -+// #ifdef SPARSITY_ENABLED -+// } -+// #endif -+ } -+ if (is_free_buffer_after_forward_ == true) { -+ freeBuffer(); -+ } -+ sync_check_cuda_error(); ++} ++template ++void MSMHALayer::forward(std::vector* output_tensors, ++ const std::vector* input_tensors, ++ const AttentionWeight* attention_weights) ++{ ++ // input_tensors: use 1 gemm -- multi head attention ++ // input_query [batch_size * seq_len, hidden_dimension] ++ // attention_mask [batch_size, 1, seq_len, seq_len] ++ ++ // input_tensors: use 2 gemm -- cross attention ++ // input_query [batch_size * seq_len, hidden_dimension] ++ // enc_output [batch_size * tgt_len, hidden_dimension] ++ // attention_mask [batch_size, 1, seq_len, seq_len] ++ ++ // output_tensors: ++ // attention_out [batch_size * seq_len, hidden_dimension] ++ // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] ++ // value_cache [batch, local_head_num, max_seq_len, size_per_head] ++ ++ ++ forward_1(output_tensors, input_tensors, attention_weights); ++ +} + +template @@ -2602,14 +4026,18 @@ index bada640..e606bc2 100644 + FT_CHECK(false); + // allocate 
according to max parameters + if (is_allocate_buffer_ == false) { ++ // TODO nizzan fix allocator + #if 1 + size_t qkv_len = getQElemNum() + getKElemNum() + getVElemNum(); ++ // auto buff_size = qkv_len + q_buf_2_len + qk_buf_len + (qkv_buf_2_len * 2); ++ auto buff_size = qkv_len + getQElemNum() + getQKElemNum() + (getQKVElemNum() * 2); + qkv_buf_ = reinterpret_cast(allocator_->malloc(sizeof(T) * qkv_len, true)); + q_buf_2_ = reinterpret_cast(allocator_->malloc(sizeof(T) * getQElemNum(), true)); + qk_buf_ = reinterpret_cast(allocator_->malloc(sizeof(S) *getQKElemNum(), true)); + qkv_buf_2_ = reinterpret_cast(allocator_->malloc(sizeof(T) * getQKVElemNum(), true)); + qkv_buf_3_ = reinterpret_cast(allocator_->malloc(sizeof(T) * getQKVElemNum(), true)); -+ ++ output1_ = reinterpret_cast(allocator_->malloc(sizeof(T) * buff_size, true)); //static_cast(workspace) + buff_size; ++ output2_ = reinterpret_cast(allocator_->malloc(sizeof(T) * buff_size, true));//static_cast(output1_) + extra_size; + // if (is_qk_buf_float_ == true) { + // qk_buf_float_ = (float*)allocator_->malloc( + // sizeof(float) * getQKElemNum(), true); @@ -2632,16 +4060,20 @@ index bada640..e606bc2 100644 +template +void MSMHALayer::allocateBuffer(size_t batch_size, size_t src_seq_len, size_t tgt_seq_len) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); ++ // TODO nizzan fix allocator + #if 1 ++ + size_t qkv_len = getQElemNum(batch_size, src_seq_len) + getKElemNum(batch_size, tgt_seq_len) + + getVElemNum(batch_size, tgt_seq_len); ++ auto buff_size = qkv_len + getQElemNum() + getQKElemNum() + (getQKVElemNum() * 2); + qkv_buf_ = reinterpret_cast(allocator_->reMalloc(qkv_buf_, sizeof(T) * qkv_len, true)); + q_buf_2_ = reinterpret_cast(allocator_->reMalloc(q_buf_2_, sizeof(T) * getQElemNum(batch_size, src_seq_len), true)); + qk_buf_ = reinterpret_cast( + allocator_->reMalloc(qk_buf_, sizeof(S) * getQKElemNum(batch_size, src_seq_len, tgt_seq_len), true)); + qkv_buf_2_ = reinterpret_cast(allocator_->reMalloc(qkv_buf_2_, sizeof(T) * getQKVElemNum(batch_size, src_seq_len), true)); + qkv_buf_3_ = reinterpret_cast(allocator_->reMalloc(qkv_buf_3_, sizeof(T) * getQKVElemNum(batch_size, src_seq_len), true)); -+ ++ output1_ = reinterpret_cast(allocator_->malloc(sizeof(T) * buff_size, true)); //static_cast(workspace) + buff_size; ++ output2_ = reinterpret_cast(allocator_->malloc(sizeof(T) * buff_size, true));//static_cast(output1_) + extra_size; + // if (is_qk_buf_float_ == true) { + // qk_buf_float_ = (float*)allocator_->reMalloc( + // qk_buf_float_, sizeof(float) * getQKElemNum(batch_size, size_t src_seq_len, size_t tgt_seq_len), true); @@ -2663,6 +4095,8 @@ index bada640..e606bc2 100644 +void MSMHALayer::freeBuffer() { + if (is_allocate_buffer_) { + FT_LOG_DEBUG(__PRETTY_FUNCTION__); ++ // TODO nizzan fix allocator ++ //allocator_->free(buf_); + allocator_->free(qkv_buf_); + allocator_->free(q_buf_2_); + allocator_->free(qk_buf_); @@ -2731,14 +4165,15 @@ index bada640..e606bc2 100644 + } // namespace fastertransformer diff --git a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -index 92e2175..df67c9a 100644 +index 92e2175..557a4d1 100644 --- a/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/GptContextAttentionLayer.h -@@ -107,4 +107,205 @@ public: +@@ -107,4 +107,215 @@ public: const AttentionWeight* attention_weights) override; }; + ++// TODO(haim): Add template according to 
"mix" compute type (fp32, fp16) +template +class MSMHALayer: public BaseAttentionLayer { +private: @@ -2802,7 +4237,13 @@ index 92e2175..df67c9a 100644 + + bool is_qk_buf_float_; + int optimized_offset_; -+ ++ bool is_position_bias_; ++ void forward_1(std::vector* output_tensors, ++ const std::vector* input_tensors, ++ const AttentionWeight* attention_weights); ++ void forward_2(std::vector* output_tensors, ++ const std::vector* input_tensors, ++ const AttentionWeight* attention_weights); + +protected: + using BaseAttentionLayer::stream_; @@ -2816,6 +4257,8 @@ index 92e2175..df67c9a 100644 + T* qkv_buf_2_ = nullptr; + T* qkv_buf_3_ = nullptr; + T* buf_ = nullptr; ++ void *output1_{nullptr}; ++ void *output2_{nullptr}; + +public: + MSMHALayer(size_t batch_size, @@ -2828,7 +4271,8 @@ index 92e2175..df67c9a 100644 + IAllocator* allocator, + bool is_free_buffer_after_forward, + bool is_qk_buf_float, -+ bool sparse = false); ++ bool sparse = false, ++ bool is_position_bias=false); + + MSMHALayer(MSMHALayer const& attention_layer); + @@ -2940,6 +4384,531 @@ index 92e2175..df67c9a 100644 +// }; + } // namespace fastertransformer +diff --git a/src/fastertransformer/layers/encoder_layers/CMakeLists.txt b/src/fastertransformer/layers/encoder_layers/CMakeLists.txt +new file mode 100644 +index 0000000..b211a45 +--- /dev/null ++++ b/src/fastertransformer/layers/encoder_layers/CMakeLists.txt +@@ -0,0 +1,20 @@ ++# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
++ ++cmake_minimum_required(VERSION 3.8) ++ ++add_library(EncoderLayer STATIC encoder.cc) ++set_property(TARGET EncoderLayer PROPERTY POSITION_INDEPENDENT_CODE ON) ++set_property(TARGET EncoderLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) ++target_link_libraries(EncoderLayer PUBLIC -lcublas -lcudart unfused_attention_kernels activation_kernels) +diff --git a/src/fastertransformer/layers/encoder_layers/encoder.cc b/src/fastertransformer/layers/encoder_layers/encoder.cc +new file mode 100644 +index 0000000..0ad16ae +--- /dev/null ++++ b/src/fastertransformer/layers/encoder_layers/encoder.cc +@@ -0,0 +1,455 @@ ++ ++#include "src/fastertransformer/layers/encoder_layers/encoder.h" ++#include "src/fastertransformer/kernels/activation_kernels.h" ++#include "src/fastertransformer/kernels/layernorm_kernels.h" ++#include "src/fastertransformer/kernels/unfused_attention_kernels.h" ++ ++ ++namespace fastertransformer { ++ ++void CublasGemmWrapper(const void* a_addr, ++ const void* b_addr, ++ void* c_addr, ++ const int* params, ++ const int* lds, ++ const cublasOperation_t* operations, ++ const cudaDataType* data_types, ++ void* alpha, ++ void* beta, ++ cublasHandle_t cublas_handle, ++ cublasGemmAlgo_t algo) ++{ ++ const int m = params[0]; ++ const int n = params[1]; ++ const int k = params[2]; ++ cublasOperation_t trans_a = operations[0]; ++ cublasOperation_t trans_b = operations[1]; ++ const int lda = lds[0]; ++ const int ldb = lds[1]; ++ const int ldc = lds[2]; ++ cudaDataType type_a = data_types[0]; ++ cudaDataType type_b = data_types[1]; ++ cudaDataType type_c = data_types[2]; ++ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; ++ if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { ++ compute_type = CUBLAS_COMPUTE_16F; ++ } ++ cublasGemmEx(cublas_handle, ++ trans_a, ++ trans_b, ++ m, ++ n, ++ k, ++ alpha, ++ a_addr, ++ type_a, ++ lda, ++ b_addr, ++ type_b, ++ ldb, ++ beta, ++ c_addr, ++ type_c, ++ ldc, ++ compute_type, ++ algo); ++} ++ ++void CublasGemmStridedBatchedWrapper(const void* a_addr, ++ const void* b_addr, ++ void* c_addr, ++ const int* params, ++ const int* lds, ++ const cublasOperation_t* operations, ++ const int* strides, ++ const cudaDataType* data_types, ++ void* alpha, ++ void* beta, ++ int batch, ++ cublasHandle_t cublas_handle, ++ cublasGemmAlgo_t algo) ++{ ++ const int m = params[0]; ++ const int n = params[1]; ++ const int k = params[2]; ++ cublasOperation_t trans_a = operations[0]; ++ cublasOperation_t trans_b = operations[1]; ++ const int lda = lds[0]; ++ const int ldb = lds[1]; ++ const int ldc = lds[2]; ++ cudaDataType type_a = data_types[0]; ++ cudaDataType type_b = data_types[1]; ++ cudaDataType type_c = data_types[2]; ++ cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; ++ if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) { ++ compute_type = CUBLAS_COMPUTE_16F; ++ } ++ const int stride_a = strides[0]; ++ const int stride_b = strides[1]; ++ const int stride_c = strides[2]; ++ ++ cublasGemmStridedBatchedEx(cublas_handle, ++ trans_a, ++ trans_b, ++ m, ++ n, ++ k, ++ alpha, ++ a_addr, ++ type_a, ++ lda, ++ stride_a, ++ b_addr, ++ type_b, ++ ldb, ++ stride_b, ++ beta, ++ c_addr, ++ type_c, ++ ldc, ++ stride_c, ++ batch, ++ compute_type, ++ algo); ++} ++ ++ ++template ++size_t GetAttnWorkspaceSize(encoderParamT* param) ++{ ++ size_t size_q = param->batch_size * param->src_seq_len * param->hidden_size; ++ size_t size_k = param->batch_size * param->tgt_seq_len * param->hidden_size; ++ 
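++    // Workspace layout (element counts of T, converted to bytes on return):
++    // fused QKV projection output (size_q + size_k + size_v), transposed Q
++    // (q_buf_2), Q*K^T scores (qk_buf), context and transposed context
++    // (qkv_buf_2 / qkv_buf_3), plus two scratch regions of
++    // batch_size * head_num * head_size * tgt_seq_len elements each, which
++    // forward_attn uses for the transposed K and V tensors.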
size_t size_v = size_k; ++ size_t qkv_len = size_q + size_k + size_v; ++ size_t q_buf_2_len = size_q; ++ size_t qk_buf_len = param->batch_size * param->head_num * param->src_seq_len * param->tgt_seq_len; ++ size_t qkv_buf_2_len = param->batch_size * param->src_seq_len * param->hidden_size; ++ size_t qkv_buf_3_len = qkv_buf_2_len; ++ size_t attn_out_size = param->batch_size * param->head_num * param->head_size * param->tgt_seq_len; ++ return (qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len + 2 * attn_out_size) * sizeof(T); ++} ++ ++template ++size_t GetEncoderLayerWorkspaceSize(encoderParamT* param) ++{ ++ size_t norm1 = param->batch_size * param->src_seq_len * param->hidden_size; ++ size_t attn_out = param->batch_size * param->src_seq_len * param->hidden_size; ++ return (GetAttnWorkspaceSize(param) + norm1 + attn_out) * sizeof(T); ++} ++ ++template ++void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, encoderParamT* param, void* ws) ++{ ++ param->in_idx = 0; ++ size_t h_token_num = param->batch_size * param->src_seq_len; ++ T* normed_from_tensor = ws; ++ T* from_tensor = reinterpret_cast(inputs[param->in_idx++]); ++ invokeGeneralLayerNorm(normed_from_tensor, ++ from_tensor, // from tensor ++ reinterpret_cast(inputs[param->in_idx++]), // Gamma ++ reinterpret_cast(inputs[param->in_idx++]), // Beta ++ h_token_num, ++ param->hidden_size, ++ param->stream); ++ // simulate attention inputs ++ inputs[--param->in_idx] = normed_from_tensor; ++ T* attn_out = ws + param->batch_size * param->src_seq_len * param->hidden_size; ++ T* attn_ws = attn_out + param->batch_size * param->src_seq_len * param->hidden_size; ++ forward_attn(inputs, in_len, attn_out, 1, param, ws); ++ ++ T* normed_attn_out = attn_ws; ++ if (param->projection_bias) { ++ T* projection_bias = reinterpret_cast(inputs[param->in_idx++]); ++ invokeGeneralAddBiasResidualPreLayerNorm(attn_out, ++ normed_attn_out, ++ from_tensor, ++ reinterpret_cast(inputs[param->in_idx++]), // gamma ++ reinterpret_cast(inputs[param->in_idx++]), // beta ++ projection_bias, ++ h_token_num, ++ param->hidden_size, ++ param->stream); ++ } ++ else { ++ // without projection bias ++ } ++ // forward ffn ++ T* ffn_ws = attn_ws + param->batch_size * param->src_seq_len * param->hidden_size; ++ // simulate attention inputs ++ inputs[--param->in_idx] = normed_attn_out; ++ forward_ffn(inputs, in_len, output, out_len, param, ffn_ws); ++ invokeAddBiasResidual(reinterpret_cast(output[0]), ++ attn_out, ++ reinterpret_cast(inputs[param->in_idx++]), // FFN bias ++ param->hidden_size, ++ param->hidden_size, ++ param->stream); ++ return; ++} ++ ++template ++void forward_ffn(void* inputs[], int in_len, void* output[], int out_len, encoderParamT* param, void* ws) ++{ ++ size_t inter_size = param->ffn_hidden_size; // 4 * param->hidden_size; ++ size_t h_token_num = param->batch_size * param->src_seq_len; ++ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; ++ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; ++ if (std::is_same::value) { ++ gemm_data_types[0] = CUDA_R_16F; ++ gemm_data_types[1] = CUDA_R_16F; ++ gemm_data_types[2] = CUDA_R_16F; ++ } ++ float alpha = 1.0f; ++ float beta = 0.0f; ++ ++ int gemm_dims[] = {inter_size, h_token_num, param->hidden_size}; ++ int gemm_lds[] = {inter_size, param->hidden_size, inter_size}; ++ T* normed_attn_out = reinterpret_cast(inputs[param->in_idx++]); ++ CublasGemmWrapper(reinterpret_cast(inputs[param->in_idx++]), ++ normed_attn_out, ++ ws, ++ gemm_dims, ++ gemm_lds, ++ 
gemm_ops, ++ gemm_data_types, ++ &alpha, ++ &beta, ++ param->cublas_handle); ++ ++ invokeAddBiasGelu(ws, reinterpret_cast(inputs[param->in_idx++]), h_token_num, inter_size, param->stream); ++ gemm_dims[0] = param->hidden_size; ++ gemm_dims[1] = h_token_num; ++ gemm_dims[2] = inter_size; ++ gemm_lds[0] = param->hidden_size; ++ gemm_lds[1] = inter_size; ++ gemm_lds[2] = param->hidden_size; ++ CublasGemmWrapper(reinterpret_cast(inputs[param->in_idx++]), ++ ws, ++ output[0], ++ gemm_dims, ++ gemm_lds, ++ gemm_ops, ++ gemm_data_types, ++ &alpha, ++ &beta, ++ param->cublas_handle); ++} ++ ++template ++void forward_attn(void* inputs[], int in_len, void* output[], int out_len, encoderParamT* param, void* ws) ++{ ++ auto extra_tmp_size = param->batch_size * param->head_num * param->head_size * param->tgt_seq_len; ++ size_t size_q = param->batch_size * param->src_seq_len * param->hidden_size; ++ size_t size_k = param->batch_size * param->tgt_seq_len * param->hidden_size; ++ size_t size_v = size_k; ++ ++ size_t qkv_len = size_q + size_k + size_v; ++ size_t q_buf_2_len = size_q; ++ size_t qk_buf_len = param->batch_size * param->head_num * param->src_seq_len * param->tgt_seq_len; ++ size_t qkv_buf_2_len = param->batch_size * param->src_seq_len * param->hidden_size; ++ size_t qkv_buf_3_len = param->batch_size * param->src_seq_len * param->hidden_size; ++ auto buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len; ++ T* qkv_buf = ws; ++ T* q_buf_2 = static_cast(qkv_buf) + qkv_len; ++ T* qk_buf = static_cast(q_buf_2) + q_buf_2_len; ++ T* qkv_buf_2 = static_cast(qk_buf) + qk_buf_len; ++ T* qkv_buf_3 = static_cast(qkv_buf_2) + qkv_buf_2_len; ++ T* output1 = static_cast(ws) + buff_size; ++ T* output2 = static_cast(output1) + extra_tmp_size; ++ ++ int gemm_dims[] = {3 * param->hidden_size, param->batch_size * param->src_seq_len, param->hidden_size}; ++ int gemm_lds[] = {3 * param->hidden_size, param->hidden_size, 3 * param->hidden_size}; ++ T* from_tensor = reinterpret_cast(inputs[param->in_idx++]); ++ ++ cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N}; ++ cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F}; ++ if (std::is_same::value) { ++ gemm_data_types[0] = CUDA_R_16F; ++ gemm_data_types[1] = CUDA_R_16F; ++ gemm_data_types[2] = CUDA_R_16F; ++ } ++ float alpha = 1.0f; ++ float beta = 0.0f; ++ if (param->is_cross) { ++ gemm_dims[0] = param->hidden_size; ++ gemm_dims[1] = param->batch_size * param->src_seq_len; ++ gemm_dims[2] = param->hidden_size; ++ gemm_lds[0] = param->hidden_size; ++ gemm_lds[1] = param->hidden_size; ++ gemm_lds[2] = param->hidden_size; ++ T* encoder_output = reinterpret_cast(inputs[param->in_idx++]); ++ T* weight_q = reinterpret_cast(inputs[param->in_idx++]); ++ ++ CublasGemmWrapper(weight_q, ++ from_tensor, ++ qkv_buf, ++ gemm_dims, ++ gemm_lds, ++ gemm_ops, ++ gemm_data_types, ++ &alpha, ++ &beta, ++ param->cublas_handle); ++ gemm_dims[0] = 2 * param->hidden_size; ++ gemm_lds[0] = 2 * param->hidden_size; ++ gemm_lds[2] = 2 * param->hidden_size; ++ T* weight_kv = reinterpret_cast(inputs[param->in_idx++]); ++ CublasGemmWrapper(weight_kv, ++ encoder_output, ++ qkv_buf + (param->batch_size * param->src_seq_len) * param->hidden_size, ++ gemm_dims, ++ gemm_lds, ++ gemm_ops, ++ gemm_data_types, ++ &alpha, ++ &beta, ++ param->cublas_handle); ++ if (param->qkv_bias) { ++ T* bias_qkv = reinterpret_cast(inputs[param->in_idx++]); ++ invokeCrossAddFusedQKVBiasTranspose(q_buf_2, ++ output1, ++ output2, ++ qkv_buf, ++ bias_qkv, ++ 
param->batch_size, ++ param->src_seq_len, ++ param->tgt_seq_len, ++ param->head_num, ++ param->head_size, ++ param->stream); ++ } ++ else { ++ } ++ } ++ else { ++ T* weight_qkv = reinterpret_cast(inputs[param->in_idx++]); ++ CublasGemmWrapper(weight_qkv, ++ from_tensor, ++ qkv_buf, ++ gemm_dims, ++ gemm_lds, ++ gemm_ops, ++ const_cast(gemm_data_types), ++ &alpha, ++ &beta, ++ param->cublas_handle); ++ if (param->qkv_bias) { ++ T* bias_qkv = reinterpret_cast(inputs[param->in_idx++]); ++ fastertransformer::invokeAddFusedQKVBiasTranspose(static_cast(q_buf_2), ++ static_cast(output1), ++ static_cast(output2), ++ static_cast(qkv_buf), ++ bias_qkv, ++ param->batch_size, ++ param->src_seq_len, ++ param->head_num, ++ param->head_size, ++ 0, ++ param->stream); ++ } ++ else { ++ ; ++ } ++ } ++ gemm_ops[0] = CUBLAS_OP_T; ++ gemm_ops[1] = CUBLAS_OP_N; ++ gemm_dims[0] = param->tgt_seq_len; ++ gemm_dims[1] = param->src_seq_len; ++ gemm_dims[2] = param->head_size; ++ ++ gemm_lds[0] = param->head_size; ++ gemm_lds[1] = param->head_size; ++ gemm_lds[2] = param->tgt_seq_len; ++ ++ int gemm_strides[] = {param->tgt_seq_len * param->head_size, ++ param->src_seq_len * param->head_size, ++ param->src_seq_len * param->tgt_seq_len}; ++ ++ CublasGemmStridedBatchedWrapper(output1, ++ q_buf_2, ++ qk_buf, ++ gemm_dims, ++ gemm_lds, ++ gemm_ops, ++ gemm_strides, ++ const_cast(gemm_data_types), ++ &alpha, ++ &beta, ++ param->batch_size * param->head_num, ++ param->cublas_handle); ++ ++ T* position_bias = nullptr; ++ if (param->position_bias) { ++ position_bias = reinterpret_cast(inputs[param->in_idx++]); ++ } ++ T* attention_mask = reinterpret_cast(inputs[param->in_idx++]); ++ T scalar = static_cast(1.0f / sqrtf(param->head_size * 1.0f)); ++ fastertransformer::invokeMixMaskedSoftMax(static_cast(qk_buf), ++ attention_mask, ++ param->batch_size, ++ param->src_seq_len, ++ param->tgt_seq_len, ++ param->head_num, ++ scalar, ++ param->stream); ++ ++ gemm_ops[0] = CUBLAS_OP_N; ++ gemm_ops[1] = CUBLAS_OP_N; ++ gemm_dims[0] = param->head_size; ++ gemm_dims[1] = param->src_seq_len; ++ gemm_dims[2] = param->tgt_seq_len; ++ ++ gemm_lds[0] = param->head_size; ++ gemm_lds[1] = param->tgt_seq_len; ++ gemm_lds[2] = param->head_size; ++ ++ gemm_strides[0] = param->tgt_seq_len * attention_mask->head_size; ++ gemm_strides[1] = param->src_seq_len * attention_mask->tgt_seq_len; ++ gemm_strides[2] = param->src_seq_len * attention_mask->head_size; ++ ++ CublasGemmStridedBatchedWrapper(output2, ++ qk_buf, ++ qkv_buf_2, ++ gemm_dims, ++ gemm_lds, ++ gemm_ops, ++ gemm_strides, ++ const_cast(gemm_data_types), ++ &alpha, ++ &beta, ++ param->batch_size * param->head_num, ++ param->cublas_handle); ++ invokeTransposeQKV(static_cast(qkv_buf_3), ++ static_cast(qkv_buf_2), ++ param->batch_size, ++ param->src_seq_len, ++ param->head_num, ++ param->head_size, ++ param->stream); ++ ++ gemm_ops[0] = CUBLAS_OP_N; ++ gemm_ops[1] = CUBLAS_OP_N; ++ gemm_dims[0] = param->hidden_size; ++ gemm_dims[1] = param->batch_size * param->src_seq_len; ++ gemm_dims[2] = param->hidden_size; ++ ++ gemm_lds[0] = param->hidden_size; ++ gemm_lds[1] = param->hidden_size; ++ gemm_lds[2] = param->hidden_size; ++ ++ CublasGemmWrapper(reinterpret_cast(inputs[param->in_idx++]), ++ qkv_buf_3, ++ static_cast(output[0]), ++ gemm_dims, ++ gemm_lds, ++ gemm_ops, ++ const_cast(gemm_data_types), ++ &alpha, ++ &beta, ++ param->cublas_handle); ++ int len = param->batch_size * param->src_seq_len; ++ ++ return; ++} ++ ++} // namespace fastertransformer +diff --git 
a/src/fastertransformer/layers/encoder_layers/encoder.h b/src/fastertransformer/layers/encoder_layers/encoder.h +new file mode 100644 +index 0000000..758fba5 +--- /dev/null ++++ b/src/fastertransformer/layers/encoder_layers/encoder.h +@@ -0,0 +1,31 @@ ++#pragma once ++ ++#include ++#include ++ ++namespace fastertransformer { ++ ++typedef struct { ++ size_t batch_size; ++ size_t src_seq_len; ++ size_t tgt_seq_len; ++ size_t head_num; ++ size_t head_size; ++ size_t hidden_size; ++ size_t ffn_hidden_size; ++ // handle ++ cublasHandle_t cublas_handle; ++ cudaStream_t stream; ++ // ctrls ++ int in_idx; ++ bool qkv_bias; ++ bool projection_bias; ++ bool is_cross; ++ bool position_bias; ++ bool layernorm_post; ++} encoderParamT; ++ ++size_t GetEncoderLayerWorkspaceSize(encoderParamT* param); ++void forwardEncoder(void* inputs[], int in_len, void* output[], int out_len, encoderParamT* param, void* ws); ++ ++} +\ No newline at end of file diff --git a/src/fastertransformer/models/CMakeLists.txt b/src/fastertransformer/models/CMakeLists.txt index af33e76..21efb6d 100644 --- a/src/fastertransformer/models/CMakeLists.txt @@ -3010,6 +4979,192 @@ index 0829e0d..cfd72b8 100644 pipeline_para_.world_size_ - 1, pipeline_para_, stream_); +diff --git a/src/fastertransformer/models/ms/main.cc b/src/fastertransformer/models/ms/main.cc +new file mode 100644 +index 0000000..cd5844f +--- /dev/null ++++ b/src/fastertransformer/models/ms/main.cc +@@ -0,0 +1,179 @@ ++/* ++ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. 
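++ *
++ * A minimal usage sketch for the encoder layer API declared in encoder.h above,
++ * assuming the templated definitions in encoder.cc; handle, stream, inputs,
++ * outputs and ws are placeholder names:
++ *
++ *   fastertransformer::encoderParamT p{};
++ *   p.batch_size = 8;   p.src_seq_len = 128;  p.tgt_seq_len = 128;
++ *   p.head_num = 12;    p.head_size = 64;
++ *   p.hidden_size = p.head_num * p.head_size;
++ *   p.ffn_hidden_size = 4 * p.hidden_size;
++ *   p.cublas_handle = handle;  p.stream = stream;
++ *   p.qkv_bias = true;   p.projection_bias = true;
++ *   p.is_cross = false;  p.position_bias = false;  p.layernorm_post = false;
++ *
++ *   size_t ws_bytes = fastertransformer::GetEncoderLayerWorkspaceSize<float>(&p);
++ *   // allocate ws_bytes of device memory for ws, gather the layer's input and
++ *   // output pointers in the order consumed through p.in_idx, then:
++ *   fastertransformer::forwardEncoder<float>(inputs, in_len, outputs, out_len, &p, ws);
++ *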
++ */ ++ ++#include "src/fastertransformer/utils/gemm_test/ms_gemm_func.h" ++#include "src/fastertransformer/utils/memory_utils.h" ++ ++namespace ft = fastertransformer; ++ ++struct ms_opt_arg { ++ size_t batch_size; ++ size_t num_layers; ++ size_t seq_len; // source seq len ++ size_t tgt_seq_len; ++ size_t head_num; ++ size_t hidden_size; ++ size_t size_per_head; ++ bool is_remove_padding; ++ int m; ++ int n; ++ int k; ++ std::string model_name; ++ std::string compute_type; ++ std::string w_compute_type; ++ std::string s_compute_type; ++}; ++ ++void usage() { ++ std::cout << "Usage: ms_benchmark -b -l -t " ++ << "-s -H -S -p " ++ << "-T -W -F " ++ << "-m -c -M -N -K \n"; ++} ++ ++bool read_args(int argc, char* argv[], ms_opt_arg* opt_a) { ++ int opt; ++ while ((opt = getopt(argc, argv, "b:l:s:t:H:S:p:m:T:W:F:i:w:M:N:K:")) != -1) { ++ switch (opt) { ++ case 'b': ++ opt_a->batch_size = atoi(optarg); ++ break; ++ case 'l': ++ opt_a->num_layers = atoi(optarg); ++ break; ++ case 's': ++ opt_a->seq_len = atoi(optarg); ++ break; ++ case 't': ++ opt_a->tgt_seq_len = atoi(optarg); ++ break; ++ case 'H': ++ opt_a->head_num = atoi(optarg); ++ break; ++ case 'S': ++ opt_a->hidden_size = atoi(optarg); ++ break; ++ case 'p': ++ opt_a->is_remove_padding = static_cast(atoi(optarg)); ++ break; ++ case 'm': ++ opt_a->model_name = std::string(optarg); ++ break; ++ case 'T': ++ opt_a->compute_type = std::string(optarg); ++ break; ++ case 'W': ++ opt_a->w_compute_type = std::string(optarg); ++ break; ++ case 'F': ++ opt_a->s_compute_type = std::string(optarg); ++ break; ++ case 'M': ++ opt_a->m = atoi(optarg); ++ break; ++ case 'N': ++ opt_a->n = atoi(optarg); ++ break; ++ case 'K': ++ opt_a->k = atoi(optarg); ++ break; ++ case 'i': ++ case 'w': ++ break; ++ case 'h': ++ default: ++ usage(); ++ return false; ++ } ++ } ++ opt_a->size_per_head = opt_a->hidden_size / opt_a->head_num; ++ opt_a->tgt_seq_len = (opt_a->tgt_seq_len == -1) ? opt_a->seq_len : opt_a->tgt_seq_len; ++ return true; ++} ++ ++int main(int argc, char* argv[]) ++{ ++ ms_opt_arg opt_a; ++ opt_a.batch_size = 1; ++ opt_a.num_layers = 1; ++ opt_a.seq_len = 1; ++ opt_a.tgt_seq_len = -1; ++ opt_a.head_num = 1; ++ opt_a.hidden_size = 1; ++ opt_a.size_per_head = 1; ++ opt_a.is_remove_padding = false; ++ opt_a.m = 1; ++ opt_a.n = 1; ++ opt_a.k = 1; ++ opt_a.model_name = ""; ++ opt_a.compute_type = "fp32"; ++ opt_a.w_compute_type = "fp32"; ++ opt_a.s_compute_type = "fp32"; ++ ++ if (!read_args(argc, argv, &opt_a)) { ++ printf("[ERROR] Failed to read arguments. 
\n"); ++ usage(); ++ return 0; ++ } ++ ++ bool c_type_fp32 = (opt_a.compute_type.compare("fp32") == 0); ++ std::cout << "[INFO] arguments: " << std::endl; ++ std::cout << " batch_size: " << opt_a.batch_size << std::endl; ++ std::cout << " num of layers: " << opt_a.num_layers << std::endl; ++ std::cout << " seq len:" << opt_a.seq_len << std::endl; ++ std::cout << " target seq len: " << opt_a.tgt_seq_len << std::endl; ++ std::cout << " head_num: " << opt_a.head_num << std::endl; ++ std::cout << " size_per_head: " << opt_a.size_per_head << std::endl; ++ // std::cout << " compute_type: " << c_type_fp32 << std::endl; ++ ++ std::cout << std::endl; ++ ++ const int inter_size = 4 * opt_a.head_num * opt_a.size_per_head; ++ const ft::CublasDataType data_type = static_cast(0); // 0 FP32, 1 FP16, 2 BF 16 ++ void* gemm_test_buf; ++ size_t buf_size_in_byte = ft::calGemmTestBufSizeInByte(opt_a.batch_size, ++ opt_a.seq_len, ++ opt_a.head_num, ++ opt_a.size_per_head, ++ inter_size, ++ 0, // default ++ 0, // default ++ data_type); ++ ++ size_t total, free; ++ ft::check_cuda_error(cudaMemGetInfo(&free, &total)); ++ if (free < buf_size_in_byte + 10 * 1024 * 1024) { ++ printf("[ERROR] There is no enough device memory for gemm test!\n" ++ " %ld Bytes is needed, but only %ld Bytes is free.\n", ++ buf_size_in_byte, ++ free); ++ gemm_test_buf = NULL; ++ return -1; ++ } else { ++ ft::deviceMalloc(reinterpret_cast(&gemm_test_buf), buf_size_in_byte, false); ++ } ++ // int fast_algo = 0; ++ if (data_type == ft::FLOAT_DATATYPE) { ++ ft::generate_ms_gemm_config(opt_a.batch_size, opt_a.seq_len, opt_a.tgt_seq_len, opt_a.head_num, opt_a.size_per_head, gemm_test_buf, ++ false); ++ } else { ++ printf("[ERROR] data type only supports fp32(0). \n"); ++ return -1; ++ } ++ // std::cout << "main fast algo: " << fast_algo << std::endl; ++ ft::check_cuda_error(cudaFree(gemm_test_buf)); ++ return 0; ++} +\ No newline at end of file diff --git a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt b/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt index 10b9e0b..86d733f 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/CMakeLists.txt @@ -3114,10 +5269,37 @@ index 3d0f28a..3d2efbd 100644 add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc) set_property(TARGET cublasINT8MMWrapper PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/utils/cublasMMWrapper.cc b/src/fastertransformer/utils/cublasMMWrapper.cc -index e291151..6ddd6bd 100644 +index e291151..8294a4d 100644 --- a/src/fastertransformer/utils/cublasMMWrapper.cc +++ b/src/fastertransformer/utils/cublasMMWrapper.cc -@@ -313,6 +313,22 @@ void cublasMMWrapper::setFP16GemmConfig() +@@ -103,6 +103,8 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, + cublasGemmAlgo_t algo) + { + mu_->lock(); ++ std::cout << "m: " << m << " n: " << n << " k: "<< k << std::endl; ++ std::cout << "lda: " << lda << " ldb: " << ldb << " ldc: "<< ldc << std::endl; + check_cuda_error(cublasGemmEx(cublas_handle_, + transa, + transb, +@@ -181,6 +183,7 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, + } + + if (using_cublasLt) { ++ std::cout << "using_cublasLt" << std::endl; + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cudaDataType_t scaleType; +@@ -272,6 +275,9 @@ void cublasMMWrapper::Gemm(cublasOperation_t transa, + sync_check_cuda_error(); + } + else { ++ std::cout << "----- else -----" << std::endl; ++ std::cout << "m: " << m << " n: " << n << " k: "<< k << std::endl; ++ 
std::cout << "lda: " << lda << " ldb: " << ldb << " ldc: "<< ldc << std::endl; + int cublasAlgo = info.algoId; + check_cuda_error(cublasGemmEx(cublas_handle_, + transa, +@@ -313,6 +319,22 @@ void cublasMMWrapper::setFP16GemmConfig() computeType_ = CUDA_R_32F; } @@ -3153,6 +5335,447 @@ index 6f410ab..a2159e0 100644 #ifdef ENABLE_BF16 void setBF16GemmConfig(); #endif +diff --git a/src/fastertransformer/utils/cuda_utils.h b/src/fastertransformer/utils/cuda_utils.h +index 5d73c87..aef6ab9 100644 +--- a/src/fastertransformer/utils/cuda_utils.h ++++ b/src/fastertransformer/utils/cuda_utils.h +@@ -382,7 +382,7 @@ public: + + static double diffTime(timeval start, timeval end) + { +- return (end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001; ++ return (end.tv_sec - start.tv_sec) * 1000000 + (end.tv_usec - start.tv_usec); + } + + /* ***************************** common utils ****************************** */ +diff --git a/src/fastertransformer/utils/custom_ar_comm.cc b/src/fastertransformer/utils/custom_ar_comm.cc +index ded1e58..159faaf 100644 +--- a/src/fastertransformer/utils/custom_ar_comm.cc ++++ b/src/fastertransformer/utils/custom_ar_comm.cc +@@ -54,6 +54,7 @@ void CustomAllReduceComm::customAllReduce(size_t elts, cudaStream_t stream) + output_tensor_->at(0).data = (const void*)tmp_tensor_data_; + } + ++ + template + void CustomAllReduceComm::allocateAndExchangePeerAccessPointer( + std::vector>* custom_all_reduce_comms) +diff --git a/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc b/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc +new file mode 100644 +index 0000000..e8f88fe +--- /dev/null ++++ b/src/fastertransformer/utils/gemm_test/ms_gemm_func.cc +@@ -0,0 +1,364 @@ ++/* ++ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. 
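++ *
++ * generate_ms_gemm_config below profiles the four GEMM shapes used by the MS
++ * attention block: the fused QKV projection, the batched Q*K^T product, the
++ * batched score*V product, and the attention output projection. Every shape is
++ * timed across the cuBLAS algorithm range (with warmup); for the first three
++ * shapes in fp16/bf16 a cublasLt search is run as well, and the fastest
++ * configuration per shape is written to GEMM_CONFIG.
++ *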
++ */ ++ ++#include "src/fastertransformer/utils/gemm_test/ms_gemm_func.h" ++ ++namespace fastertransformer { ++ ++template ++void generate_ms_gemm_config( ++ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer_in, bool isAppend) ++{ ++ void* cublas_workspace; ++ void* buffer; ++ int workSpaceSize; ++ ++#ifdef ENABLE_BF16 ++ if (std::is_same::value || std::is_same::value) { ++#else ++ if (std::is_same::value) { ++#endif // ENABLE_BF16 ++ // cublas_workspace_ should be the start pointer of cudaMalloc() ++ // to ensure 16B alignemnet ++ cublas_workspace = buffer_in; ++ buffer = (void*)((char*)cublas_workspace + CUBLAS_WORKSPACE_SIZE); ++ workSpaceSize = CUBLAS_WORKSPACE_SIZE; ++ } ++ else { ++ cublas_workspace = nullptr; ++ buffer = buffer_in; ++ workSpaceSize = 0; ++ } ++ ++ struct cudaDeviceProp prop; ++ check_cuda_error(cudaGetDeviceProperties(&prop, 0)); ++ printf("Device %s\n", prop.name); ++ ++ // check config ++ FILE* fd; ++ int line_count = 0; ++ if (!isAppend) { ++ fd = fopen(GEMM_CONFIG, "w+"); ++ } ++ else { ++ fd = fopen(GEMM_CONFIG, "a+"); ++ std::vector config; ++ char line[1024]; ++ while (fgets(line, 1024, fd) != NULL) { ++ config.push_back(std::string(line)); ++ } ++ line_count = config.size(); ++ if (config.size() >= (MAX_CONFIG_NUM * GEMM_NUM + 1)) // 6 cublas/cublasLt, first row is not included ++ { ++ int startIdx = config.size() - ((MAX_CONFIG_NUM - 1) * GEMM_NUM); ++ fclose(fd); ++ fd = fopen(GEMM_CONFIG, "w+"); ++ fprintf(fd, "%s", config[0].c_str()); ++ for (uint i = startIdx; i < config.size(); i++) { ++ fprintf(fd, "%s", config[i].c_str()); ++ } ++ line_count = config.size() - (GEMM_NUM + 3); ++ } ++ } ++ ++ const int gemm_num = 4; ++ int M[gemm_num]; ++ int N[gemm_num]; ++ int K[gemm_num]; ++ int batchCount[gemm_num] = {1, 1, 1, 1}; ++ char mess[gemm_num][256]; ++ float exec_times[gemm_num]; ++ int gemm_lds[gemm_num][3]; // = {3 * hidden_size, hidden_size, 3 * hidden_size}; ++ cublasOperation_t gemm_ops[gemm_num][2]; // = {CUBLAS_OP_N, CUBLAS_OP_N}; ++ int gemm_strides[2][3]; ++ ++ // gemm1 ++ // int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size}; ++ int hidden_size = head_num * size_per_head; ++ M[0] = 3 * hidden_size; ++ N[0] = batch_size * seq_len; ++ K[0] = hidden_size; ++ gemm_lds[0][0] = 3 * hidden_size; ++ gemm_lds[0][1] = hidden_size; ++ gemm_lds[0][2] = 3 * hidden_size; ++ gemm_ops[0][0] = CUBLAS_OP_N; ++ gemm_ops[0][1] = CUBLAS_OP_N; ++ strcpy(mess[0], "cublasGemmEx "); ++ ++ // gemm2 ++ M[1] = tgt_seq_len; ++ N[1] = seq_len; ++ K[1] = size_per_head; ++ gemm_ops[1][0] = CUBLAS_OP_T; ++ gemm_ops[1][1] = CUBLAS_OP_N; ++ ++ gemm_lds[1][0] = size_per_head; ++ gemm_lds[1][1] = size_per_head; ++ gemm_lds[1][2] = tgt_seq_len; ++ ++ gemm_strides[0][0] = tgt_seq_len * size_per_head; ++ gemm_strides[0][1] = seq_len * size_per_head; ++ gemm_strides[0][2] = seq_len * tgt_seq_len; ++ strcpy(mess[1], "cublasGemmStridedBatchedEx"); ++ ++ // gemm3 ++ M[2] = size_per_head; ++ N[2] = seq_len; ++ K[2] = tgt_seq_len; ++ gemm_ops[2][0] = CUBLAS_OP_N; ++ gemm_ops[2][1] = CUBLAS_OP_N; ++ ++ gemm_lds[2][0] = size_per_head; ++ gemm_lds[2][1] = tgt_seq_len; ++ gemm_lds[2][2] = size_per_head; ++ ++ gemm_strides[1][0] = tgt_seq_len * size_per_head; ++ gemm_strides[1][1] = seq_len * tgt_seq_len; ++ gemm_strides[1][2] = seq_len * size_per_head; ++ strcpy(mess[2], "cublasGemmStridedBatchedEx"); ++ ++ // gemm4 ++ M[3] = hidden_size; ++ N[3] = batch_size * seq_len; ++ K[3] = hidden_size; ++ gemm_ops[3][0] = 
CUBLAS_OP_N; ++ gemm_ops[3][1] = CUBLAS_OP_N; ++ ++ gemm_lds[3][0] = hidden_size; ++ gemm_lds[3][1] = hidden_size; ++ gemm_lds[3][2] = hidden_size; ++ strcpy(mess[3], "cublasGemmEx"); ++ ++ cublasHandle_t cublas_handle; ++ check_cuda_error(cublasCreate(&cublas_handle)); ++ cublasLtHandle_t ltHandle; ++ check_cuda_error(cublasLtCreate(<Handle)); ++ ++ cudaDataType_t AType; ++ cudaDataType_t BType; ++ cudaDataType_t CType; ++ cublasComputeType_t computeType; ++ int startAlgo, endAlgo; ++ const int ites = 10000; ++ const int warmup_ites = 10000; ++ struct timeval start, end; ++ ++ CublasDataType data_type; ++ if (std::is_same::value) { ++ data_type = FLOAT_DATATYPE; ++ AType = CUDA_R_32F; ++ BType = CUDA_R_32F; ++ CType = CUDA_R_32F; ++ computeType = CUBLAS_COMPUTE_32F_FAST_TF32; ++ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; ++ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; ++ } ++ else if (std::is_same::value) { ++ data_type = HALF_DATATYPE; ++ AType = CUDA_R_16F; ++ BType = CUDA_R_16F; ++ CType = CUDA_R_16F; ++ computeType = CUBLAS_COMPUTE_16F; ++ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; ++ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; ++ } ++#ifdef ENABLE_BF16 ++ else if (std::is_same::value) { ++ data_type = BFLOAT16_DATATYPE; ++ AType = CUDA_R_16BF; ++ BType = CUDA_R_16BF; ++ CType = CUDA_R_16BF; ++ computeType = CUBLAS_COMPUTE_32F; ++ startAlgo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; ++ endAlgo = (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; ++ } ++#endif ++ using scaleT = typename ScaleTypeConverter::Type; ++ ++ scaleT alpha = (scaleT)1.0f; ++ scaleT beta = (scaleT)0.0f; ++ ++ printf("***Encoder Gemm Testing Begin***\n"); ++ printf("***Cublas Gemm Testing Begin***\n"); ++ if (line_count == 0) { ++ fprintf(fd, ++ "batch_size, seq_len, head_num, size_per_head dataType ### batchCount, n, m, k, algoId, " ++ "customOption, tile, numSplitsK, swizzle, reductionScheme, workspaceSize, stages, exec_time\n"); ++ } ++ for (int i = 0; i < gemm_num; ++i) { ++ // if(i != 0 && i != 5) continue; ++ ++ int m = M[i], n = N[i], k = K[i]; ++ printf("\n-----------------------------\n"); ++ printf("GEMM test %d: [M: %d, K: %d, N: %d] %s\n", i, m, k, n, mess[i]); ++ // printf("GEMM test %d: [M: %d, K: %d, N: %d] \n", i, m, k, n); ++ T* d_A = (T*)buffer; ++ T* d_B = d_A + m * k * batchCount[i]; ++ T* d_C = d_B + k * n * batchCount[i]; ++ ++ // array of pointer for batchedGemm ++ T* harray[12]; ++ harray[0] = (T*)buffer; ++ harray[1] = (T*)((char*)buffer + sizeof(T) * m * k); ++ harray[2] = (T*)((char*)buffer + 2 * sizeof(T) * m * k); ++ harray[4] = (T*)((char*)buffer + 3 * sizeof(T) * m * k); ++ harray[5] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + sizeof(T) * k * n); ++ harray[6] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 2 * sizeof(T) * k * n); ++ harray[8] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n); ++ harray[9] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + sizeof(T) * m * n); ++ harray[10] = (T*)((char*)buffer + 3 * sizeof(T) * m * k + 3 * sizeof(T) * k * n + 2 * sizeof(T) * m * n); ++ ++ T** darray = 0; ++ check_cuda_error(cudaMalloc((void**)&darray, sizeof(T*) * 12)); ++ cudaMemcpy((void*)darray, (void*)harray, sizeof(T*) * 12, cudaMemcpyHostToDevice); ++ T** dAarray = darray; ++ T** dBarray = darray + 4; ++ T** dCarray = darray + 8; ++ ++ float exec_time = 99999.0f; ++ int fast_algo = 0; ++ ++ // warmup ++ // for (int j = 0; j < ites*10; j++) { ++ // cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, 
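++        // Timing loop below: for every algo in [startAlgo, endAlgo], run
++        // warmup_ites untimed calls, then time ites calls between
++        // cudaDeviceSynchronize() and gettimeofday(). GEMMs 0 and 3 go through
++        // cublasGemmEx, GEMMs 1 and 2 through cublasGemmStridedBatchedEx.
++        // diffTime(start, end) / ites is the per-call cost (in microseconds,
++        // given the diffTime change above), and the minimum over algos is kept
++        // as fast_algo.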
gemm_lds[i][0], d_B, BType, ++ // gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(0)); ++ // } ++ ++ for (int algo = startAlgo; algo <= endAlgo; algo++) { ++ cublasStatus_t status; ++ //warmup ++ for (int ite = 0; ite < warmup_ites; ++ite) { ++ if ((i == 0) || (i == 3)) { ++ status = cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, ++ gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(algo)); ++ } else { ++ status = cublasGemmStridedBatchedEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], ++ gemm_strides[i-1][0], d_B, BType, gemm_lds[i][1], gemm_strides[i-1][1], &beta, d_C, CType, ++ gemm_lds[i][2], gemm_strides[i-1][2], batch_size, computeType, static_cast(algo)); ++ } ++ } ++ cudaDeviceSynchronize(); ++ gettimeofday(&start, NULL); ++ if ((i == 0) || (i == 3)) { ++ for (int ite = 0; ite < ites; ++ite) { ++ status = cublasGemmEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], d_B, BType, ++ gemm_lds[i][1], &beta, d_C, CType, gemm_lds[i][2], computeType, static_cast(algo)); ++ } ++ } else { ++ for (int ite = 0; ite < ites; ++ite) { ++ status = cublasGemmStridedBatchedEx(cublas_handle, gemm_ops[i][0], gemm_ops[i][1], m, n, k, &alpha, d_A, AType, gemm_lds[i][0], ++ gemm_strides[i-1][0], d_B, BType, gemm_lds[i][1], gemm_strides[i-1][1], &beta, d_C, CType, ++ gemm_lds[i][2], gemm_strides[i-1][2], batch_size, computeType, static_cast(algo)); ++ } ++ } ++ ++ if (status != CUBLAS_STATUS_SUCCESS) { ++ break; ++ } ++ // } ++ cudaDeviceSynchronize(); ++ gettimeofday(&end, NULL); ++ if (status == CUBLAS_STATUS_SUCCESS) { ++ printf("algo_%d costs %.6fms \n", algo, diffTime(start, end) / ites); ++ if (diffTime(start, end) / ites < exec_time) { ++ exec_time = diffTime(start, end) / ites; ++ fast_algo = algo; ++ } ++ } ++ } ++ printf("fast_algo %d costs %.6f ms \n", fast_algo, exec_time); ++ ++ // for fp16 and bf16, we compare cublasLt ++ if (i < 3 && data_type != FLOAT_DATATYPE) { ++ printf("***cublasLt Gemm Testing Beign***\n"); ++ // Let try a fixed number of combinations ++ int ALGO_COMBINATIONS = 5000; ++ customMatmulPerf_t perfResults[ALGO_COMBINATIONS]; ++ LtHgemmCustomFind(ltHandle, ++ batch_size, ++ seq_len, ++ head_num, ++ size_per_head, ++ n, ++ m, ++ k, ++ &alpha, ++ d_B, ++ d_A, ++ &beta, ++ d_C, ++ cublas_workspace, ++ workSpaceSize, ++ fd, ++ perfResults, ++ ALGO_COMBINATIONS); ++ if (perfResults[0].time < exec_time) { ++ printPerfStructure( ++ batch_size, seq_len, head_num, size_per_head, n, m, k, perfResults[0], fd, data_type, 0); ++ exec_time = perfResults[0].time; ++ } ++ else { ++ fprintf(fd, ++ "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", ++ batch_size, ++ seq_len, ++ head_num, ++ size_per_head, ++ data_type, ++ batchCount[i], ++ n, ++ m, ++ k, ++ fast_algo, ++ exec_time); ++ } ++ printf("***cublasLt Gemm Testing End***\n"); ++ } ++ else { ++ fprintf(fd, ++ "%d %d %d %d %d ### %d %d %d %d %d -1 -1 -1 -1 -1 -1 -1 %f\n", ++ batch_size, ++ seq_len, ++ head_num, ++ size_per_head, ++ data_type, ++ batchCount[i], ++ n, ++ m, ++ k, ++ fast_algo, ++ exec_time); ++ } ++ exec_times[i] = exec_time; ++ cudaFree(darray); ++ } ++ printf("***cublas Gemm Testing End***\n\n"); ++ fclose(fd); ++ printf("***Encoder Gemm Testing End***\n"); ++ ++ return; ++} ++ ++template void generate_ms_gemm_config( ++ int batch_size, int seq_len, int tgt_seq_len, int head_num, int 
size_per_head, void* buffer, bool isAppend); ++template void generate_ms_gemm_config( ++ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); ++#ifdef ENABLE_BF16 ++template void generate_ms_gemm_config<__nv_bfloat16>( ++ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend); ++#endif ++ ++} // namespace fastertransformer +diff --git a/src/fastertransformer/utils/gemm_test/ms_gemm_func.h b/src/fastertransformer/utils/gemm_test/ms_gemm_func.h +new file mode 100644 +index 0000000..c6f68ca +--- /dev/null ++++ b/src/fastertransformer/utils/gemm_test/ms_gemm_func.h +@@ -0,0 +1,40 @@ ++/* ++ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. ++ * ++ * Licensed under the Apache License, Version 2.0 (the "License"); ++ * you may not use this file except in compliance with the License. ++ * You may obtain a copy of the License at ++ * ++ * http://www.apache.org/licenses/LICENSE-2.0 ++ * ++ * Unless required by applicable law or agreed to in writing, software ++ * distributed under the License is distributed on an "AS IS" BASIS, ++ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++ * See the License for the specific language governing permissions and ++ * limitations under the License. ++ */ ++ ++#pragma once ++ ++#include "src/fastertransformer/utils/cublasAlgoMap.h" ++#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" ++#include "src/fastertransformer/utils/cuda_utils.h" ++#include "src/fastertransformer/utils/gemm_test/gemm_func.h" ++ ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++#include ++ ++namespace fastertransformer { ++ ++template ++void generate_ms_gemm_config( ++ int batch_size, int seq_len, int tgt_seq_len, int head_num, int size_per_head, void* buffer, bool isAppend = true); ++ ++} // namespace fastertransformer diff --git a/src/fastertransformer/utils/logger.h b/src/fastertransformer/utils/logger.h index bcdf8fa..e3e7007 100644 --- a/src/fastertransformer/utils/logger.h