Adding support for FP16, cross & T5 MHA
This commit is contained in:
parent
d870c9090c
commit
dff877dbd3
|
@ -2,6 +2,9 @@ set(REQ_URL "https://github.com/NVIDIA/FasterTransformer/archive/refs/tags/relea
|
|||
set(SHA256 "7adffe2d53b3c1544295a6b7d1887e59b044eba25dd3e150bc909168d5e99081")
|
||||
set(ft_libs "transformer-shared")
|
||||
|
||||
if(DEFINED ENV{MSLITE_GPU_ARCH})
|
||||
set(arch_opt -DSM=$ENV{MSLITE_GPU_ARCH})
|
||||
endif()
|
||||
|
||||
mindspore_add_pkg(fast_transformers
|
||||
VER 0.5.0
|
||||
|
@ -10,7 +13,7 @@ mindspore_add_pkg(fast_transformers
|
|||
LIBS ${ft_libs}
|
||||
LIB_PATH lib
|
||||
PATCHES ${MINDSPORE_PROJECT_DIR}/third_party/patch/fast_transformer/001-fast_transformer.patch
|
||||
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release -DEXAMPLES=off)
|
||||
CMAKE_OPTION -DCMAKE_BUILD_TYPE=Release ${arch_opt} -DEXAMPLES=off)
|
||||
include_directories(${fast_transformers_INC})
|
||||
|
||||
add_library(mindspore::fast_transformers ALIAS fast_transformers::transformer-shared)
|
||||
|
|
|
@ -124,7 +124,8 @@ static BaseRef GetVar(const BaseRef &x) {
|
|||
|
||||
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
|
||||
MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
if (equiv == nullptr) MS_EXCEPTION_IF_NULL(equiv);
|
||||
if (utils::isa<VarPtr>(pattern)) {
|
||||
VarPtr var = utils::cast<VarPtr>(pattern);
|
||||
if (var->matches(expr)) {
|
||||
|
|
|
@ -39,13 +39,19 @@ int AttentionInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor
|
|||
if (q_weight->shape_size_ != C2NUM) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
int batch = (q_input->shape_size_ == 2) ? 1 : q_input->shape_[0];
|
||||
int f_seq = (q_input->shape_size_ == 2) ? q_input->shape_[0] : q_input->shape_[1];
|
||||
int batch = (q_input->shape_size_ == C2NUM) ? 1 : q_input->shape_[0];
|
||||
int f_seq = (q_input->shape_size_ == C2NUM) ? q_input->shape_[0] : q_input->shape_[1];
|
||||
int t_seq_len = k_input->shape_[1];
|
||||
if (q_input->shape_size_ == C2NUM) {
|
||||
output0->shape_[FIRST_INPUT] = batch * f_seq;
|
||||
output0->shape_[SECOND_INPUT] = param->head_num_ * param->head_size_;
|
||||
output0->shape_size_ = C2NUM;
|
||||
} else {
|
||||
output0->shape_[FIRST_INPUT] = batch;
|
||||
output0->shape_[SECOND_INPUT] = f_seq;
|
||||
output0->shape_[THIRD_INPUT] = param->head_num_ * param->head_size_;
|
||||
output0->shape_size_ = C3NUM;
|
||||
}
|
||||
if (outputs_size >= C3NUM) {
|
||||
TensorC *output1 = outputs[SECOND_INPUT];
|
||||
SetDataTypeFormat(output1, q_input);
|
||||
|
|
|
@ -71,7 +71,7 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *
|
|||
|
||||
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
|
||||
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
|
||||
cublasHandle_t cublas_handle) {
|
||||
cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
|
||||
const int m = params[0];
|
||||
const int n = params[1];
|
||||
const int k = params[2];
|
||||
|
@ -84,15 +84,17 @@ void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, con
|
|||
cudaDataType type_b = data_types[1];
|
||||
cudaDataType type_c = data_types[2];
|
||||
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
|
||||
if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
|
||||
compute_type = CUBLAS_COMPUTE_16F;
|
||||
}
|
||||
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda, b_addr, type_b,
|
||||
ldb, beta, c_addr, type_c, ldc, compute_type, CUBLAS_GEMM_DEFAULT));
|
||||
ldb, beta, c_addr, type_c, ldc, compute_type, algo));
|
||||
}
|
||||
|
||||
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
|
||||
const int *lds, const cublasOperation_t *operations, const int *strides,
|
||||
const cudaDataType *data_types, void *alpha, void *beta, int batch,
|
||||
cublasHandle_t cublas_handle) {
|
||||
cublasHandle_t cublas_handle, cublasGemmAlgo_t algo) {
|
||||
const int m = params[0];
|
||||
const int n = params[1];
|
||||
const int k = params[2];
|
||||
|
@ -105,12 +107,62 @@ void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, voi
|
|||
cudaDataType type_b = data_types[1];
|
||||
cudaDataType type_c = data_types[2];
|
||||
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
|
||||
compute_type = CUBLAS_COMPUTE_16F;
|
||||
}
|
||||
const int stride_a = strides[0];
|
||||
const int stride_b = strides[1];
|
||||
const int stride_c = strides[2];
|
||||
|
||||
CUBLAS_CHECK_VOID(cublasGemmStridedBatchedEx(cublas_handle, trans_a, trans_b, m, n, k, alpha, a_addr, type_a, lda,
|
||||
stride_a, b_addr, type_b, ldb, stride_b, beta, c_addr, type_c, ldc,
|
||||
stride_c, batch, compute_type, CUBLAS_GEMM_DEFAULT));
|
||||
stride_c, batch, compute_type, algo));
|
||||
}
|
||||
|
||||
void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
|
||||
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
|
||||
const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle) {
|
||||
cublasOperation_t trans_a = operations[0];
|
||||
cublasOperation_t trans_b = operations[1];
|
||||
cudaDataType type_a = data_types[0];
|
||||
cudaDataType type_b = data_types[1];
|
||||
cudaDataType type_c = data_types[2];
|
||||
const int m = params[0];
|
||||
const int n = params[1];
|
||||
const int k = params[2];
|
||||
const int lda = lds[0];
|
||||
const int ldb = lds[1];
|
||||
const int ldc = lds[2];
|
||||
|
||||
cublasLtMatrixLayout_t mat_a_desc = NULL;
|
||||
cublasLtMatrixLayoutCreate(&mat_a_desc, type_a, (trans_a == CUBLAS_OP_N) ? m : k, (trans_a == CUBLAS_OP_N) ? k : m,
|
||||
lda);
|
||||
cublasLtMatrixLayout_t mat_b_desc = NULL;
|
||||
cublasLtMatrixLayoutCreate(&mat_b_desc, type_b, (trans_b == CUBLAS_OP_N) ? k : n, (trans_b == CUBLAS_OP_N) ? n : k,
|
||||
ldb);
|
||||
cublasLtMatrixLayout_t mat_c_desc = NULL;
|
||||
cublasLtMatrixLayoutCreate(&mat_c_desc, type_c, m, n, ldc);
|
||||
|
||||
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
if ((type_a == CUDA_R_16F) && (type_b == CUDA_R_16F) && (type_c == CUDA_R_16F)) {
|
||||
compute_type = CUBLAS_COMPUTE_16F;
|
||||
}
|
||||
|
||||
cublasLtMatmulDesc_t mat_operation_desc = NULL;
|
||||
cublasLtMatmulDescCreate(&mat_operation_desc, compute_type, type_a);
|
||||
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(cublasOperation_t));
|
||||
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(cublasOperation_t));
|
||||
if (bias != nullptr) {
|
||||
cublasLtEpilogue_t epi = CUBLASLT_EPILOGUE_BIAS;
|
||||
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epi, sizeof(cublasLtEpilogue_t));
|
||||
cublasLtMatmulDescSetAttribute(mat_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(const void *));
|
||||
}
|
||||
|
||||
cublasLtMatmul(cublaslt_handle, mat_operation_desc, alpha, a_addr, mat_a_desc, b_addr, mat_b_desc, beta, c_addr,
|
||||
mat_c_desc, c_addr, mat_c_desc, NULL, NULL, 0, stream);
|
||||
cublasLtMatrixLayoutDestroy(mat_a_desc);
|
||||
cublasLtMatrixLayoutDestroy(mat_b_desc);
|
||||
cublasLtMatrixLayoutDestroy(mat_c_desc);
|
||||
cublasLtMatmulDescDestroy(mat_operation_desc);
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -62,10 +62,14 @@ void CublasMMBatched(void **a_addrs, void **b_addrs, void **c_addrs, const int *
|
|||
|
||||
void CublasGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
|
||||
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
|
||||
cublasHandle_t cublas_handle);
|
||||
cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
void CublasGemmStridedBatchedWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params,
|
||||
const int *lds, const cublasOperation_t *operations, const int *strides,
|
||||
const cudaDataType *data_types, void *alpha, void *beta, int batch,
|
||||
cublasHandle_t cublas_handle);
|
||||
cublasHandle_t cublas_handle, cublasGemmAlgo_t = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
||||
void CublasLtGemmWrapper(const void *a_addr, const void *b_addr, void *c_addr, const int *params, const int *lds,
|
||||
const cublasOperation_t *operations, const cudaDataType *data_types, void *alpha, void *beta,
|
||||
const void *bias, cudaStream_t stream, cublasLtHandle_t cublaslt_handle);
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_CUDA_IMPL_CUBLAS_UTILS_H_
|
||||
|
|
|
@ -34,13 +34,38 @@ namespace mindspore::lite {
|
|||
namespace {
|
||||
constexpr std::size_t kTwo = 2;
|
||||
constexpr std::size_t kThree = 3;
|
||||
std::ostream &operator<<(std::ostream &s, const nvinfer1::ITensor &t) {
|
||||
const auto &dims = t.getDimensions();
|
||||
s << "ndims=" << dims.nbDims << " [";
|
||||
for (int i = 0; i < dims.nbDims; ++i) {
|
||||
s << dims.d[i] << " ";
|
||||
}
|
||||
s << "]";
|
||||
return s;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#define SET_GEMM_PARAMS(gemm_ops_, gemm_lds_, gemm_op1_, gemm_op2_, gemm_ld1_, gemm_ld2_, gemm_ld3_) \
|
||||
do { \
|
||||
gemm_ops_[0] = gemm_op1_; \
|
||||
gemm_ops_[1] = gemm_op2_; \
|
||||
gemm_lds_[0] = gemm_ld1_; \
|
||||
gemm_lds_[1] = gemm_ld2_; \
|
||||
gemm_lds_[2] = gemm_ld3_; \
|
||||
} while (0)
|
||||
|
||||
#define SET_GEMM_DIMS(gemm_dims_, gemm_dim1_, gemm_dim2_, gemm_dim3_) \
|
||||
do { \
|
||||
gemm_dims_[0] = gemm_dim1_; \
|
||||
gemm_dims_[1] = gemm_dim2_; \
|
||||
gemm_dims_[2] = gemm_dim3_; \
|
||||
} while (0)
|
||||
|
||||
// Multi Head Attention TensorRT op
|
||||
int MhaTensorRT::IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
|
||||
const std::vector<TensorInfo> &out_tensors) {
|
||||
if (in_tensors.size() != 8 && in_tensors.size() != 6) {
|
||||
MS_LOG(ERROR) << "Unsupported input tensor size, size is " << in_tensors.size();
|
||||
if (in_tensors.size() < 7 || in_tensors.size() > 9) { // T5 has 6 or 7 inputs, other models have 8 or 9 inputs
|
||||
MS_LOG(ERROR) << "Unsupported number of inputs, size is " << in_tensors.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -59,13 +84,13 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
|||
// get attribute for Attn op - TODO - add attribute in op
|
||||
int head_number = mha_op->get_head_num();
|
||||
int head_size = mha_op->get_head_size();
|
||||
int compute_type = 1; // mha_op->get_compute_type();
|
||||
int is_cross = mha_op->get_cross();
|
||||
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
|
||||
|
||||
auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
|
||||
GetCublasHandle(), GetCublasLtHandle(), device_id_);
|
||||
auto compute_type = runtime_->GetRuntimePrecisionMode(); // mha_op->get_compute_type();
|
||||
bool is_cross = mha_op->get_cross();
|
||||
const int input_number = inputs().size();
|
||||
bool is_position_bias = (((input_number == 8) && is_cross) || ((input_number == 7) && !is_cross)) ? true : false;
|
||||
nvinfer1::ITensor *input_tensor = input(ctx, 0).trt_tensor_;
|
||||
auto plugin = std::make_shared<MhaPlugin>(input_tensor->getName(), compute_type, head_number, head_size, is_cross,
|
||||
is_position_bias, GetCublasHandle(), GetCublasLtHandle(), device_id_);
|
||||
nvinfer1::ITensor *inputTensors[input_number];
|
||||
for (int i = 0; i < input_number; i++) {
|
||||
inputTensors[i] = input(ctx, i).trt_tensor_;
|
||||
|
@ -75,16 +100,46 @@ int MhaTensorRT::AddInnerOp(TensorRTContext *ctx) {
|
|||
MS_LOG(ERROR) << "add mha op failed for TensorRT.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
mha_layer->setName(op_name_.c_str());
|
||||
mha_layer->setName((op_name_ + "plugin_attention").c_str());
|
||||
// TODO(haim) one output
|
||||
nvinfer1::ITensor *attn_tensor = mha_layer->getOutput(0);
|
||||
#ifndef TEST_
|
||||
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name());
|
||||
#else /* TEST_ */
|
||||
ctx->RegisterTensor(ITensorHelper{attn_tensor, Format::NCHW, true}, out_tensors_[0].Name() + "attn");
|
||||
#endif /* TEST_ */
|
||||
// nvinfer1::ITensor *key_tensor = mha_layer->getOutput(1);
|
||||
// ctx->RegisterTensor(ITensorHelper{key_tensor, Format::NCHW, true}, out_tensors_[1].Name());
|
||||
// nvinfer1::ITensor *value_tensor = mha_layer->getOutput(kTwo);
|
||||
// ctx->RegisterTensor(ITensorHelper{value_tensor, Format::NCHW, true}, out_tensors_[kTwo].Name());
|
||||
this->layer_ = mha_layer;
|
||||
#ifdef TEST_
|
||||
auto weight_projection = input(ctx, 4).trt_tensor_;
|
||||
auto bias_projection = input(ctx, 6).trt_tensor_;
|
||||
#endif /* TEST_ */
|
||||
|
||||
#ifdef TEST_
|
||||
auto matmul_layer = ctx->network()->addMatrixMultiply(*attn_tensor, nvinfer1::MatrixOperation::kNONE,
|
||||
*weight_projection, nvinfer1::MatrixOperation::kNONE);
|
||||
if (matmul_layer == nullptr) {
|
||||
MS_LOG(ERROR) << "failed to add matmul layer";
|
||||
return RET_ERROR;
|
||||
}
|
||||
matmul_layer->setName((op_name_ + "_matmul").c_str());
|
||||
auto matmul_tensor = matmul_layer->getOutput(0);
|
||||
auto shuffle_layer = ctx->network()->addShuffle(*bias_projection);
|
||||
const auto size = bias_projection->getDimensions().d[0];
|
||||
shuffle_layer->setReshapeDimensions(nvinfer1::Dims{2, {1, size}});
|
||||
auto shuffle_tensor = shuffle_layer->getOutput(0);
|
||||
auto addbias = ctx->network()->addElementWise(*matmul_tensor, *shuffle_tensor, nvinfer1::ElementWiseOperation::kSUM);
|
||||
if (addbias == nullptr) {
|
||||
MS_LOG(ERROR) << "failed to add bias layer";
|
||||
return RET_ERROR;
|
||||
}
|
||||
addbias->setName((op_name_ + "_bias").c_str());
|
||||
auto bias_out = addbias->getOutput(0);
|
||||
ctx->RegisterTensor(ITensorHelper{bias_out, Format::NCHW, true}, out_tensors_[0].Name());
|
||||
#endif /* TEST_ */
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -98,11 +153,80 @@ std::vector<nvinfer1::PluginField> TensorRTPluginCreater<T>::fields_;
|
|||
|
||||
int MhaPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) noexcept {
|
||||
return RunCudaMha(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
if (compute_type_ == RuntimePrecisionMode_FP16) {
|
||||
return RunCudaMha<half>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
|
||||
} else {
|
||||
return RunCudaMha<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream, fast_algo_gemm);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MhaPlugin::SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
|
||||
size_t extra_size) {
|
||||
size_t qkv_len = size_q + (size_k * 2); // size_v is equal to size_k
|
||||
size_t q_buf_2_len = size_q;
|
||||
auto buff_size =
|
||||
qkv_len + q_buf_2_len + qk_buf_len + (qkv_buf_2_len * 2); // qkv_buf_3_ len is equal to qkv_buf_2_len
|
||||
qkv_buf_ = workspace;
|
||||
q_buf_2_ = static_cast<T *>(qkv_buf_) + qkv_len;
|
||||
qk_buf_ = static_cast<T *>(q_buf_2_) + q_buf_2_len;
|
||||
qkv_buf_2_ = static_cast<T *>(qk_buf_) + qk_buf_len;
|
||||
qkv_buf_3_ = static_cast<T *>(qkv_buf_2_) + qkv_buf_2_len;
|
||||
output1_ = static_cast<T *>(workspace) + buff_size;
|
||||
output2_ = static_cast<T *>(output1_) + extra_size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MhaPlugin::RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
|
||||
int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha,
|
||||
void *beta, cublasGemmAlgo_t algoId, cudaStream_t stream) {
|
||||
int cross_tensor_offset = 0;
|
||||
if (is_cross_) cross_tensor_offset = 1;
|
||||
const int from_tensor_idx = 0, encoder_tensor_idx = 1, weight_qkv_tensor_idx = 3;
|
||||
const int weight_qkv_tensor_idx_cross = 3 + cross_tensor_offset;
|
||||
const int bias_qkv_tensor_idx = 5 + cross_tensor_offset;
|
||||
const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;
|
||||
|
||||
auto from_tensor = static_cast<const T *>(inputs[from_tensor_idx]);
|
||||
auto encoder_output_tensor = static_cast<const T *>(inputs[encoder_tensor_idx]);
|
||||
auto weight_q = static_cast<const T *>(inputs[weight_qkv_tensor_idx]);
|
||||
auto weight_kv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
|
||||
auto weight_qkv = static_cast<const T *>(inputs[weight_qkv_tensor_idx_cross]);
|
||||
auto bias_qkv = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_qkv_tensor_idx]);
|
||||
|
||||
auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
|
||||
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
|
||||
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
|
||||
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
|
||||
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
|
||||
if (is_cross_) {
|
||||
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
|
||||
SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
|
||||
CublasGemmWrapper(weight_q, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
|
||||
const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_);
|
||||
SET_GEMM_DIMS(gemm_dims, C2NUM * hidden_size, request_batch_size * request_tgt_seq_len, hidden_size);
|
||||
gemm_lds[0] = gemm_lds[THIRD_INPUT] = C2NUM * hidden_size;
|
||||
|
||||
CublasGemmWrapper(weight_kv, encoder_output_tensor,
|
||||
static_cast<T *>(qkv_buf_) + (request_batch_size * request_src_seq_len) * hidden_size, gemm_dims,
|
||||
gemm_lds, gemm_ops, const_cast<const cudaDataType *>(gemm_data_types), alpha, beta,
|
||||
cublas_handle_);
|
||||
fastertransformer::invokeCrossAddFusedQKVBiasTranspose(
|
||||
static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
|
||||
bias_qkv, request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_, head_size_, stream);
|
||||
} else {
|
||||
CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops,
|
||||
const_cast<const cudaDataType *>(gemm_data_types), alpha, beta, cublas_handle_, algoId);
|
||||
fastertransformer::invokeAddFusedQKVBiasTranspose(
|
||||
static_cast<T *>(q_buf_2_), static_cast<T *>(output1_), static_cast<T *>(output2_), static_cast<T *>(qkv_buf_),
|
||||
bias_qkv, request_batch_size, request_src_seq_len, head_number_, head_size_, 0, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream) {
|
||||
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
|
||||
cublasGemmAlgo_t *algoId) {
|
||||
// inputs order:
|
||||
// 0] Q
|
||||
// 1] K
|
||||
|
@ -112,134 +236,99 @@ int MhaPlugin::RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvi
|
|||
// 5] B
|
||||
// 6] PB
|
||||
// 7] AttnMask
|
||||
|
||||
// inputs order cross:
|
||||
// 0] Q
|
||||
// 1] K enco output
|
||||
// 2] V
|
||||
// 3] Wq
|
||||
// 4] Wkv
|
||||
// 5] PW
|
||||
// 6] Bqkv
|
||||
// 7] PB
|
||||
// 8] AttnMask
|
||||
int cross_tensor_offset = 0;
|
||||
cublasSetStream(cublas_handle_, stream);
|
||||
if (is_cross_) cross_tensor_offset = 1;
|
||||
const int weight_projection_tensor_idx = 4 + cross_tensor_offset;
|
||||
const int bias_projection_tensor_idx = 6 + cross_tensor_offset;
|
||||
const int attn_mask_tensor_idx = (is_position_bias_) ? 6 + cross_tensor_offset : 7 + cross_tensor_offset;
|
||||
const int bias_position_tensor_idx = 5 + cross_tensor_offset;
|
||||
|
||||
// TODO(Haim) - Fix tensor ids according to cross flag
|
||||
const int from_tensor_idx = 0;
|
||||
// const int encoder_tensor_idx = 1;
|
||||
const int weight_qkv_tensor_idx = 3;
|
||||
const int weight_projection_tensor_idx = 4;
|
||||
const int bias_qkv_tensor_idx = 5;
|
||||
const int bias_projection_tensor_idx = 6;
|
||||
const int attn_mask_tensor_idx = 7;
|
||||
|
||||
auto from_tensor = static_cast<const float *>(inputs[from_tensor_idx]);
|
||||
auto attention_mask = static_cast<const float *>(inputs[attn_mask_tensor_idx]);
|
||||
auto weight_qkv = static_cast<const float *>(inputs[weight_qkv_tensor_idx]);
|
||||
auto bias_qkv = static_cast<const float *>(inputs[bias_qkv_tensor_idx]);
|
||||
auto weight_projection = static_cast<const float *>(inputs[weight_projection_tensor_idx]);
|
||||
auto bias_projection = static_cast<const float *>(inputs[bias_projection_tensor_idx]);
|
||||
|
||||
auto output0 = static_cast<float *>(outputs[0]);
|
||||
// auto output1 = static_cast<float *>(outputs[1]);
|
||||
// auto output2 = static_cast<float *>(outputs[2]);
|
||||
|
||||
auto attention_mask = static_cast<const T *>(inputs[attn_mask_tensor_idx]);
|
||||
auto weight_projection = static_cast<const T *>(inputs[weight_projection_tensor_idx]);
|
||||
auto bias_projection = (is_position_bias_) ? nullptr : static_cast<const T *>(inputs[bias_projection_tensor_idx]);
|
||||
auto bias_position = (is_position_bias_) ? static_cast<const T *>(inputs[bias_position_tensor_idx]) : nullptr;
|
||||
auto output0 = static_cast<T *>(outputs[0]);
|
||||
auto attn_dim_size = inputDesc[attn_mask_tensor_idx].dims.nbDims;
|
||||
const int request_batch_size = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[0]);
|
||||
const int request_src_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 2]);
|
||||
const int request_tgt_seq_len = static_cast<const int>(inputDesc[attn_mask_tensor_idx].dims.d[attn_dim_size - 1]);
|
||||
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
|
||||
|
||||
// TODO(NIZZAN): fix allocator
|
||||
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
|
||||
auto extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
|
||||
|
||||
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
|
||||
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
|
||||
size_t size_v = size_k;
|
||||
|
||||
size_t qkv_len = size_q + size_k + size_v;
|
||||
size_t q_buf_2_len = size_q;
|
||||
size_t qk_buf_len = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len;
|
||||
size_t qkv_buf_2_len = request_batch_size * request_src_seq_len * hidden_size;
|
||||
size_t qkv_buf_3_len = qkv_buf_2_len;
|
||||
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
|
||||
qkv_buf_ = workspace;
|
||||
q_buf_2_ = static_cast<float *>(qkv_buf_) + qkv_len;
|
||||
qk_buf_ = static_cast<float *>(q_buf_2_) + q_buf_2_len;
|
||||
qkv_buf_2_ = static_cast<float *>(qk_buf_) + qk_buf_len;
|
||||
qkv_buf_3_ = static_cast<float *>(qkv_buf_2_) + qkv_buf_2_len;
|
||||
output1_ = static_cast<float *>(workspace) + buff_size;
|
||||
output2_ = static_cast<float *>(output1_) + extra_tmp_size;
|
||||
SetInnerAddr<T>(workspace, size_q, size_k, qk_buf_len, qkv_buf_2_len, extra_tmp_size);
|
||||
|
||||
int gemm_dims[3] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
|
||||
int gemm_lds[3] = {3 * hidden_size, hidden_size, 3 * hidden_size};
|
||||
cublasOperation_t gemm_ops[] = {CUBLAS_OP_N, CUBLAS_OP_N};
|
||||
cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
|
||||
if constexpr (std::is_same<T, half>::value)
|
||||
std::fill(std::begin(gemm_data_types), std::end(gemm_data_types), CUDA_R_16F);
|
||||
float alpha = 1.0f, beta = 0.0f;
|
||||
int gemm_dims[] = {3 * hidden_size, request_batch_size * request_src_seq_len, hidden_size};
|
||||
int gemm_lds[] = {3 * hidden_size, hidden_size, 3 * hidden_size};
|
||||
|
||||
cublasOperation_t gemm_ops[2] = {CUBLAS_OP_N, CUBLAS_OP_N};
|
||||
const cudaDataType gemm_data_types[] = {CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
|
||||
CublasGemmWrapper(weight_qkv, from_tensor, qkv_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta,
|
||||
cublas_handle_);
|
||||
|
||||
fastertransformer::invokeAddFusedQKVBiasTranspose(static_cast<float *>(q_buf_2_), static_cast<float *>(output1_),
|
||||
static_cast<float *>(output2_), static_cast<float *>(qkv_buf_),
|
||||
bias_qkv, request_batch_size, request_src_seq_len, head_number_,
|
||||
head_size_, 0, stream);
|
||||
gemm_ops[0] = CUBLAS_OP_T;
|
||||
gemm_ops[1] = CUBLAS_OP_N;
|
||||
gemm_dims[0] = request_tgt_seq_len;
|
||||
gemm_dims[1] = request_src_seq_len;
|
||||
gemm_dims[THIRD_INPUT] = head_size_;
|
||||
|
||||
gemm_lds[0] = head_size_;
|
||||
gemm_lds[1] = head_size_;
|
||||
gemm_lds[THIRD_INPUT] = request_tgt_seq_len;
|
||||
RunPhase1GEMM<T>(inputDesc, inputs, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha, &beta, algoId[0], stream);
|
||||
|
||||
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_T, CUBLAS_OP_N, head_size_, head_size_, request_tgt_seq_len);
|
||||
SET_GEMM_DIMS(gemm_dims, request_tgt_seq_len, request_src_seq_len, head_size_);
|
||||
int gemm_strides[] = {request_tgt_seq_len * head_size_, request_src_seq_len * head_size_,
|
||||
request_src_seq_len * request_tgt_seq_len};
|
||||
|
||||
CublasGemmStridedBatchedWrapper(output1_, q_buf_2_, qk_buf_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
|
||||
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);
|
||||
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
|
||||
request_batch_size * head_number_, cublas_handle_, algoId[1]);
|
||||
|
||||
float scalar = (1.0f / sqrtf(static_cast<float>(head_size_) * 1.0f));
|
||||
fastertransformer::invokeMixMaskedSoftMax(static_cast<float *>(qk_buf_), attention_mask, request_batch_size,
|
||||
request_src_seq_len, request_tgt_seq_len, head_number_, scalar, stream);
|
||||
gemm_ops[0] = CUBLAS_OP_N;
|
||||
gemm_ops[1] = CUBLAS_OP_N;
|
||||
gemm_dims[0] = head_size_;
|
||||
gemm_dims[1] = request_src_seq_len;
|
||||
gemm_dims[THIRD_INPUT] = request_tgt_seq_len;
|
||||
|
||||
gemm_lds[0] = head_size_;
|
||||
gemm_lds[1] = request_tgt_seq_len;
|
||||
gemm_lds[THIRD_INPUT] = head_size_;
|
||||
|
||||
gemm_strides[0] = request_tgt_seq_len * head_size_;
|
||||
T scalar = static_cast<T>(1.0f / sqrtf(head_size_ * 1.0f));
|
||||
fastertransformer::invokeMixMaskedSoftMax(static_cast<T *>(qk_buf_), attention_mask, bias_position,
|
||||
request_batch_size, request_src_seq_len, request_tgt_seq_len, head_number_,
|
||||
scalar, stream);
|
||||
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, head_size_, request_tgt_seq_len, head_size_);
|
||||
SET_GEMM_DIMS(gemm_dims, head_size_, request_src_seq_len, request_tgt_seq_len);
|
||||
gemm_strides[1] = request_src_seq_len * request_tgt_seq_len;
|
||||
gemm_strides[THIRD_INPUT] = request_src_seq_len * head_size_;
|
||||
|
||||
CublasGemmStridedBatchedWrapper(output2_, qk_buf_, qkv_buf_2_, gemm_dims, gemm_lds, gemm_ops, gemm_strides,
|
||||
gemm_data_types, &alpha, &beta, request_batch_size * head_number_, cublas_handle_);
|
||||
|
||||
fastertransformer::invokeTransposeQKV(static_cast<float *>(qkv_buf_3_), static_cast<float *>(qkv_buf_2_),
|
||||
request_batch_size, request_src_seq_len, head_number_, head_size_, stream);
|
||||
|
||||
gemm_ops[0] = CUBLAS_OP_N;
|
||||
gemm_ops[1] = CUBLAS_OP_N;
|
||||
gemm_dims[0] = hidden_size;
|
||||
gemm_dims[1] = request_batch_size * request_src_seq_len;
|
||||
gemm_dims[THIRD_INPUT] = hidden_size;
|
||||
|
||||
gemm_lds[0] = hidden_size;
|
||||
gemm_lds[1] = hidden_size;
|
||||
gemm_lds[THIRD_INPUT] = hidden_size;
|
||||
|
||||
CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops, gemm_data_types, &alpha,
|
||||
&beta, cublas_handle_);
|
||||
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta,
|
||||
request_batch_size * head_number_, cublas_handle_, algoId[2]);
|
||||
fastertransformer::invokeTransposeQKV(static_cast<T *>(qkv_buf_3_), static_cast<T *>(qkv_buf_2_), request_batch_size,
|
||||
request_src_seq_len, head_number_, head_size_, stream);
|
||||
SET_GEMM_PARAMS(gemm_ops, gemm_lds, CUBLAS_OP_N, CUBLAS_OP_N, hidden_size, hidden_size, hidden_size);
|
||||
SET_GEMM_DIMS(gemm_dims, hidden_size, request_batch_size * request_src_seq_len, hidden_size);
|
||||
CublasGemmWrapper(weight_projection, qkv_buf_3_, output0, gemm_dims, gemm_lds, gemm_ops,
|
||||
const_cast<const cudaDataType *>(gemm_data_types), &alpha, &beta, cublas_handle_, algoId[3]);
|
||||
if (!is_position_bias_) {
|
||||
int len = request_batch_size * request_src_seq_len;
|
||||
fastertransformer::invokeAddBias(reinterpret_cast<float *>(output0), reinterpret_cast<const float *>(bias_projection),
|
||||
len, hidden_size, stream);
|
||||
|
||||
fastertransformer::invokeAddBias(reinterpret_cast<T *>(output0), reinterpret_cast<const T *>(bias_projection), len,
|
||||
hidden_size, stream);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int MhaPlugin::RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
|
||||
void *const *outputs, void *workspace, cudaStream_t stream) {
|
||||
// Add Cross Mha Layer here
|
||||
return RET_OK;
|
||||
bool MhaPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
|
||||
int nbOutputs) noexcept {
|
||||
auto type = (compute_type_ == RuntimePrecisionMode_FP16) ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
|
||||
for (int i = 0; i < pos; i++) {
|
||||
if (tensorsDesc[pos].type != tensorsDesc[i].type) return false;
|
||||
}
|
||||
bool res = (tensorsDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) && (tensorsDesc[pos].type == type);
|
||||
return res;
|
||||
}
|
||||
|
||||
void MhaPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {}
|
||||
|
||||
size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept {
|
||||
|
@ -250,9 +339,6 @@ size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int
|
|||
auto hidden_size = static_cast<const int>(head_number_ * head_size_);
|
||||
|
||||
// TODO(NIZZAN) Fix efficient allocator
|
||||
// size_t buff_size = request_batch_size * head_number_ * request_src_seq_len * request_tgt_seq_len +
|
||||
// request_batch_size * request_src_seq_len * hidden_size;
|
||||
|
||||
size_t size_q = request_batch_size * request_src_seq_len * hidden_size;
|
||||
size_t size_k = request_batch_size * request_tgt_seq_len * hidden_size;
|
||||
size_t size_v = size_k;
|
||||
|
@ -265,8 +351,12 @@ size_t MhaPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int
|
|||
size_t buff_size = qkv_len + q_buf_2_len + qk_buf_len + qkv_buf_2_len + qkv_buf_3_len;
|
||||
|
||||
size_t extra_tmp_size = request_batch_size * head_number_ * head_size_ * request_tgt_seq_len;
|
||||
int elem_size = sizeof(float);
|
||||
if (compute_type_ == RuntimePrecisionMode_FP16) {
|
||||
elem_size = sizeof(half);
|
||||
}
|
||||
|
||||
return (buff_size + extra_tmp_size + extra_tmp_size) * sizeof(float);
|
||||
return (buff_size + extra_tmp_size + extra_tmp_size) * elem_size;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
|
||||
|
@ -282,14 +372,15 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
|
|||
// value_cache [batch, head_num, tgt_seq_len, size_per_head]
|
||||
nvinfer1::DimsExprs dims;
|
||||
if (index == 0) {
|
||||
// if (inputs[0].nbDims == 2) {
|
||||
// dims.nbDims = INPUT_SIZE2;
|
||||
// dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
|
||||
// auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
|
||||
// dims.d[1] = hidden_size;
|
||||
// } else
|
||||
{
|
||||
dims.nbDims = INPUT_SIZE3;
|
||||
#ifndef TEST_
|
||||
int num_dims = inputs[0].nbDims;
|
||||
dims.nbDims = num_dims;
|
||||
if (num_dims == INPUT_SIZE2) {
|
||||
dims.d[0] = exprBuilder.constant(inputs[nbInputDims - 1].d[0]->getConstantValue() *
|
||||
inputs[nbInputDims - 1].d[1]->getConstantValue());
|
||||
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
|
||||
dims.d[1] = hidden_size;
|
||||
} else if (num_dims == INPUT_SIZE3) {
|
||||
dims.d[0] = inputs[nbInputDims - 1].d[0]; // batch
|
||||
dims.d[1] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
|
||||
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
|
||||
|
@ -303,6 +394,13 @@ nvinfer1::DimsExprs MhaPlugin::getOutputDimensions(int32_t index, const nvinfer1
|
|||
dims.d[kTwo] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
|
||||
dims.d[kThree] = exprBuilder.constant(head_size_);
|
||||
}
|
||||
#else
|
||||
dims.nbDims = C2NUM;
|
||||
dims.d[0] = inputs[nbInputDims - 1].d[(inputs[nbInputDims - 1].nbDims) - 1];
|
||||
auto hidden_size = exprBuilder.constant(head_size_ * head_number_);
|
||||
dims.d[1] = hidden_size;
|
||||
}
|
||||
#endif
|
||||
return dims;
|
||||
}
|
||||
|
||||
|
@ -316,6 +414,8 @@ nvinfer1::IPluginV2DynamicExt *MhaPlugin::clone() const noexcept {
|
|||
return plugin;
|
||||
}
|
||||
|
||||
int MhaPlugin::initialize() noexcept { return 0; }
|
||||
|
||||
void MhaPlugin::terminate() noexcept {}
|
||||
|
||||
size_t MhaPlugin::getSerializationSize() const noexcept { return INPUT_SIZE4 * sizeof(int); }
|
||||
|
|
|
@ -31,6 +31,8 @@ class MhaTensorRT : public TensorRTOp {
|
|||
: TensorRTOp(base_operator, in_tensors, out_tensors, name) {}
|
||||
|
||||
~MhaTensorRT() override = default;
|
||||
|
||||
// bool IsWeightInputHanledInner() const override { return true; }
|
||||
int AddInnerOp(TensorRTContext *ctx) override;
|
||||
|
||||
int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
|
||||
|
@ -40,13 +42,14 @@ class MhaTensorRT : public TensorRTOp {
|
|||
constexpr auto MHA_PLUGIN_NAME{"AttentionPlugin"};
|
||||
class MhaPlugin : public TensorRTPlugin {
|
||||
public:
|
||||
MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, int is_cross,
|
||||
cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
|
||||
MhaPlugin(const std::string name, int compute_type, int head_number, int head_size, bool is_cross,
|
||||
bool is_position_bias, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle, uint32_t device_id)
|
||||
: TensorRTPlugin(name, std::string(MHA_PLUGIN_NAME), device_id),
|
||||
compute_type_(compute_type),
|
||||
head_number_(head_number),
|
||||
head_size_(head_size),
|
||||
is_cross_(is_cross),
|
||||
is_position_bias_(is_position_bias),
|
||||
cublas_handle_(cublas_handle),
|
||||
cublaslt_handle_(cublaslt_handle) {}
|
||||
|
||||
|
@ -57,6 +60,7 @@ class MhaPlugin : public TensorRTPlugin {
|
|||
head_number_ = static_cast<const int *>(fields[1].data)[0];
|
||||
head_size_ = static_cast<const int *>(fields[2].data)[0];
|
||||
is_cross_ = static_cast<const int *>(fields[3].data)[0];
|
||||
is_position_bias_ = static_cast<const int *>(fields[4].data)[0];
|
||||
}
|
||||
|
||||
MhaPlugin(const char *name, const void *serialData, size_t serialLength)
|
||||
|
@ -65,6 +69,7 @@ class MhaPlugin : public TensorRTPlugin {
|
|||
DeserializeValue(&serialData, &serialLength, &head_number_, sizeof(int));
|
||||
DeserializeValue(&serialData, &serialLength, &head_size_, sizeof(int));
|
||||
DeserializeValue(&serialData, &serialLength, &is_cross_, sizeof(int));
|
||||
DeserializeValue(&serialData, &serialLength, &is_position_bias_, sizeof(int));
|
||||
}
|
||||
|
||||
MhaPlugin() = delete;
|
||||
|
@ -82,21 +87,37 @@ class MhaPlugin : public TensorRTPlugin {
|
|||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int index, const nvinfer1::DimsExprs *inputs, int nbInputDims,
|
||||
nvinfer1::IExprBuilder &exprBuilder) noexcept override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept override;
|
||||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *tensorsDesc, int nbInputs,
|
||||
int nbOutputs) noexcept override;
|
||||
void terminate() noexcept override;
|
||||
int initialize() noexcept override;
|
||||
|
||||
private:
|
||||
bool needResize(const int *current_dims, const int *last_dims);
|
||||
template <typename T>
|
||||
int RunCudaMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);
|
||||
int RunCudaCrossMha(const nvinfer1::PluginTensorDesc *inputDesc, const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream);
|
||||
const void *const *inputs, void *const *outputs, void *workspace, cudaStream_t stream,
|
||||
cublasGemmAlgo_t *algoId);
|
||||
template <typename T>
|
||||
void SetInnerAddr(void *workspace, size_t size_q, size_t size_k, size_t qk_buf_len, size_t qkv_buf_2_len,
|
||||
size_t extra_size);
|
||||
template <typename T>
|
||||
void RunPhase1GEMM(const nvinfer1::PluginTensorDesc *inputDesc, const void *const *inputs, int *gemm_dims,
|
||||
int *gemm_lds, cublasOperation_t *gemm_ops, cudaDataType *gemm_data_types, void *alpha, void *beta,
|
||||
cublasGemmAlgo_t algoId, cudaStream_t stream);
|
||||
|
||||
const std::string layer_name_;
|
||||
std::string name_space_;
|
||||
int compute_type_;
|
||||
int head_number_;
|
||||
int head_size_;
|
||||
int is_cross_;
|
||||
bool is_cross_;
|
||||
bool is_position_bias_;
|
||||
cublasGemmAlgo_t fast_algo_gemm[4] = {CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP, CUBLAS_GEMM_DEFAULT_TENSOR_OP};
|
||||
|
||||
cublasHandle_t cublas_handle_;
|
||||
cublasLtHandle_t cublaslt_handle_;
|
||||
void *qkv_buf_{nullptr};
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include <limits>
|
||||
#include <unordered_map>
|
||||
#include "src/extendrt/delegate/delegate_utils.h"
|
||||
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
|
||||
#include "src/common/utils.h"
|
||||
|
||||
#include "ops/transpose.h"
|
||||
|
@ -68,6 +69,11 @@ TensorRTSubGraph::~TensorRTSubGraph() {
|
|||
config_->destroy();
|
||||
config_ = nullptr;
|
||||
}
|
||||
#ifdef PROFILER_
|
||||
auto profile = dynamic_cast<SimpleProfiler *>(trt_context_->getProfiler());
|
||||
if (profile != nullptr) std::cout << *profile << std::endl;
|
||||
delete profile;
|
||||
#endif
|
||||
if (trt_context_ != nullptr) {
|
||||
trt_context_->destroy();
|
||||
trt_context_ = nullptr;
|
||||
|
@ -494,6 +500,15 @@ int TensorRTSubGraph::Prepare() {
|
|||
MS_LOG(ERROR) << "TensorRTSubGraph create context failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
#ifdef PROFILER_
|
||||
auto profiler = new SimpleProfiler("myprofiler");
|
||||
if (profiler == nullptr) {
|
||||
MS_LOG(WARNING) << "Cannot create profiler";
|
||||
}
|
||||
this->trt_context_->setProfiler(profiler);
|
||||
#endif
|
||||
|
||||
int binding_num = this->engine_->getNbBindings();
|
||||
if (binding_num <= 0) {
|
||||
MS_LOG(ERROR) << "TensorRTSubGraph binding num < 0.";
|
||||
|
@ -504,7 +519,6 @@ int TensorRTSubGraph::Prepare() {
|
|||
MS_LOG(ERROR) << "malloc tensor binding array failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
profile_index_ = MaxVolumnProfileIndex();
|
||||
if (this->trt_context_->setOptimizationProfile(profile_index_)) {
|
||||
MS_LOG(INFO) << "setOptimizationProfile: " << profile_index_;
|
||||
|
@ -527,9 +541,6 @@ int TensorRTSubGraph::Prepare() {
|
|||
MS_LOG(INFO) << "device index " << index << " for tensor : " << tensor_name << " attr: " << device_ptr;
|
||||
tensor_bindings_[index] = device_ptr;
|
||||
nvinfer1::Dims input_dims = ConvertCudaDims(profile.inputs[i].max_dims);
|
||||
for (int od = 0; od < input_dims.nbDims; od++) {
|
||||
MS_LOG(DEBUG) << "in tensor " << tensor.Name() << " dims at " << od << " is " << input_dims.d[od];
|
||||
}
|
||||
if (!this->trt_context_->setBindingDimensions(index, input_dims)) {
|
||||
MS_LOG(ERROR) << "invalid input dims of " << tensor.Name();
|
||||
return RET_ERROR;
|
||||
|
@ -781,7 +792,6 @@ int TensorRTSubGraph::PostExecute(std::vector<tensor::Tensor> *outputs) {
|
|||
// actual output tensor dims
|
||||
auto out_dims = this->trt_context_->getBindingDimensions(index);
|
||||
std::vector<int64_t> new_shape = lite::ConvertMSShape(out_dims);
|
||||
outputs_[i].SetShape(new_shape);
|
||||
for (int od = 0; od < out_dims.nbDims; od++) {
|
||||
MS_LOG(DEBUG) << "out tensor " << trt_out_tensor_name << " dims at " << od << " is " << new_shape[od];
|
||||
}
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include <unordered_set>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include "src/extendrt/delegate/tensorrt/op/cast_plugin.h"
|
||||
#include "src/extendrt/delegate/tensorrt/distribution/distribution_collective.h"
|
||||
|
||||
|
@ -176,6 +178,7 @@ nvinfer1::DataType ConvertDataType(DataType type_id) {
|
|||
{DataType::kNumberTypeInt32, nvinfer1::DataType::kINT32},
|
||||
{DataType::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
|
||||
{DataType::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
|
||||
{DataType::kNumberTypeInt64, nvinfer1::DataType::kINT32},
|
||||
};
|
||||
auto iter = data_type_map.find(type_id);
|
||||
nvinfer1::DataType data_type;
|
||||
|
@ -895,4 +898,66 @@ nvinfer1::DataType GetNvinferDataType<int>() {
|
|||
|
||||
template nvinfer1::DataType GetNvinferDataType<float>();
|
||||
template nvinfer1::DataType GetNvinferDataType<int>();
|
||||
|
||||
#ifdef PROFILER_
|
||||
void SimpleProfiler::reportLayerTime(const char *layerName, float ms) noexcept {
|
||||
mProfile_[layerName].count++;
|
||||
mProfile_[layerName].time += ms;
|
||||
if (std::find(mLayerNames_.begin(), mLayerNames_.end(), layerName) == mLayerNames_.end()) {
|
||||
mLayerNames_.push_back(layerName);
|
||||
}
|
||||
}
|
||||
|
||||
SimpleProfiler::SimpleProfiler(const char *name, const std::vector<SimpleProfiler> &srcProfilers) : mName_(name) {
|
||||
for (const auto &srcProfiler : srcProfilers) {
|
||||
for (const auto &rec : srcProfiler.mProfile_) {
|
||||
auto it = mProfile_.find(rec.first);
|
||||
if (it == mProfile_.end()) {
|
||||
mProfile_.insert(rec);
|
||||
} else {
|
||||
it->second.time += rec.second.time;
|
||||
it->second.count += rec.second.count;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value) {
|
||||
out << "========== " << value.mName_ << " profile ==========" << std::endl;
|
||||
float totalTime = 0;
|
||||
std::string layerNameStr = "TensorRT layer name";
|
||||
int maxLayerNameLength = std::max(static_cast<int>(layerNameStr.size()), 70);
|
||||
for (const auto &elem : value.mProfile_) {
|
||||
totalTime += elem.second.time;
|
||||
maxLayerNameLength = std::max(maxLayerNameLength, static_cast<int>(elem.first.size()));
|
||||
}
|
||||
|
||||
auto old_settings = out.flags();
|
||||
auto old_precision = out.precision();
|
||||
// Output header
|
||||
{
|
||||
out << std::setw(maxLayerNameLength) << layerNameStr << " ";
|
||||
out << std::setw(C12NUM) << "Runtime, "
|
||||
<< "%"
|
||||
<< " ";
|
||||
out << std::setw(C12NUM) << "Invocations"
|
||||
<< " ";
|
||||
out << std::setw(C12NUM) << "Runtime, ms" << std::endl;
|
||||
}
|
||||
for (size_t i = 0; i < value.mLayerNames_.size(); i++) {
|
||||
const std::string layerName = value.mLayerNames_[i];
|
||||
auto elem = value.mProfile_.at(layerName);
|
||||
out << std::setw(maxLayerNameLength) << layerName << " ";
|
||||
out << std::setw(C12NUM) << std::fixed << std::setprecision(1) << (elem.time * 100.0F / totalTime) << "%"
|
||||
<< " ";
|
||||
out << std::setw(C12NUM) << elem.count << " ";
|
||||
out << std::setw(C12NUM) << std::fixed << std::setprecision(C2NUM) << elem.time << std::endl;
|
||||
}
|
||||
out.flags(old_settings);
|
||||
out.precision(old_precision);
|
||||
out << "========== " << value.mName_ << " total runtime = " << totalTime << " ms ==========" << std::endl;
|
||||
|
||||
return out;
|
||||
}
|
||||
#endif // PROFILER_
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <NvInferVersion.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "src/extendrt/delegate/tensorrt/tensorrt_context.h"
|
||||
#include "src/extendrt/delegate/tensorrt/tensor_info.h"
|
||||
#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
|
||||
|
@ -57,6 +58,28 @@ typedef union float32_bits {
|
|||
float f;
|
||||
} float32_bits;
|
||||
|
||||
// #define PROFILER_
|
||||
#ifdef PROFILER_
|
||||
struct SimpleProfiler : public nvinfer1::IProfiler {
|
||||
struct Record {
|
||||
float time{0};
|
||||
int count{0};
|
||||
};
|
||||
|
||||
virtual void reportLayerTime(const char *layerName, float ms) noexcept;
|
||||
|
||||
explicit SimpleProfiler(const char *name,
|
||||
const std::vector<SimpleProfiler> &srcProfilers = std::vector<SimpleProfiler>());
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &out, const SimpleProfiler &value);
|
||||
|
||||
private:
|
||||
std::string mName_;
|
||||
std::vector<std::string> mLayerNames_;
|
||||
std::map<std::string, Record> mProfile_;
|
||||
};
|
||||
#endif
|
||||
|
||||
// Convert Tensor data to Cuda dims.
|
||||
nvinfer1::Dims ConvertCudaDims(const std::vector<int> &data);
|
||||
|
||||
|
@ -194,4 +217,4 @@ void Data2Vector(std::vector<float> *dst, const void *src) {
|
|||
}
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_UTILS_H_
|
||||
#endif // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_UTILS_H_
|
||||
|
|
|
@ -64,7 +64,7 @@ Status ContextUtils::AddGpuDevice(bool enable_fp16, uint32_t device_id, int rank
|
|||
|
||||
Status ContextUtils::AddNpuDevice(int frequency, lite::InnerContext *inner_context) {
|
||||
lite::DeviceInfo device_info;
|
||||
device_info.npu_device_info_ = {false, frequency};
|
||||
device_info.npu_device_info_.frequency_ = frequency;
|
||||
inner_context->device_list_.push_back({lite::DT_NPU, device_info});
|
||||
return kSuccess;
|
||||
}
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
#include <thread>
|
||||
#include "src/common/config_file.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
constexpr size_t kDataToStringMaxNum = 40;
|
||||
constexpr int kPrintDataNum = 20;
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <functional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "ops/tuple_get_item.h"
|
||||
|
@ -32,32 +33,34 @@ const int kAttentionOutputs = 3;
|
|||
} // namespace
|
||||
|
||||
bool MultiHeadAttentionFusion::Init() const {
|
||||
input_q_ = std::make_shared<Var>();
|
||||
input_q_ = std::make_shared<Var>("input_q");
|
||||
MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
|
||||
input_k_ = std::make_shared<Var>();
|
||||
input_k_ = std::make_shared<Var>("input_k");
|
||||
MS_CHECK_TRUE_RET(input_k_ != nullptr, false);
|
||||
input_v_ = std::make_shared<Var>();
|
||||
input_v_ = std::make_shared<Var>("input_v");
|
||||
MS_CHECK_TRUE_RET(input_v_ != nullptr, false);
|
||||
position_bias_ = std::make_shared<Var>("position_bias_");
|
||||
MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
|
||||
|
||||
weight_q_ = std::make_shared<CondVar>(IsParamNode);
|
||||
weight_q_ = std::make_shared<CondVar>(IsParamNode, "weight_q");
|
||||
MS_CHECK_TRUE_RET(weight_q_ != nullptr, false);
|
||||
weight_k_ = std::make_shared<CondVar>(IsParamNode);
|
||||
weight_k_ = std::make_shared<CondVar>(IsParamNode, "weight_k");
|
||||
MS_CHECK_TRUE_RET(weight_k_ != nullptr, false);
|
||||
weight_v_ = std::make_shared<CondVar>(IsParamNode);
|
||||
weight_v_ = std::make_shared<CondVar>(IsParamNode, "weight_v");
|
||||
MS_CHECK_TRUE_RET(weight_v_ != nullptr, false);
|
||||
weight_o_ = std::make_shared<CondVar>(IsParamNode);
|
||||
MS_CHECK_TRUE_RET(weight_o_ != nullptr, false);
|
||||
|
||||
bias_q_ = std::make_shared<CondVar>(IsParamNode);
|
||||
bias_q_ = std::make_shared<CondVar>(IsParamNode, "bias_q");
|
||||
MS_CHECK_TRUE_RET(bias_q_ != nullptr, false);
|
||||
bias_k_ = std::make_shared<CondVar>(IsParamNode);
|
||||
bias_k_ = std::make_shared<CondVar>(IsParamNode, "bias_k");
|
||||
MS_CHECK_TRUE_RET(bias_k_ != nullptr, false);
|
||||
bias_v_ = std::make_shared<CondVar>(IsParamNode);
|
||||
bias_v_ = std::make_shared<CondVar>(IsParamNode, "bias_v");
|
||||
MS_CHECK_TRUE_RET(bias_v_ != nullptr, false);
|
||||
bias_o_ = std::make_shared<CondVar>(IsParamNode);
|
||||
MS_CHECK_TRUE_RET(bias_o_ != nullptr, false);
|
||||
|
||||
mask_ = std::make_shared<Var>();
|
||||
mask_ = std::make_shared<Var>("mask");
|
||||
MS_CHECK_TRUE_RET(mask_ != nullptr, false);
|
||||
|
||||
reshape_k_ = std::make_shared<Var>("reshape_k");
|
||||
|
@ -77,20 +80,43 @@ namespace {
|
|||
VectorRef DefineMask(const BaseRef &mask_input) {
|
||||
auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims));
|
||||
MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});
|
||||
auto var1 = std::make_shared<Var>();
|
||||
auto var1 = std::make_shared<Var>("m-var1");
|
||||
MS_CHECK_TRUE_RET(var1 != nullptr, {});
|
||||
auto expand_dims = VectorRef({is_expand_dims, mask_input, var1});
|
||||
auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion));
|
||||
auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion), "m-sub");
|
||||
MS_CHECK_TRUE_RET(is_sub != nullptr, {});
|
||||
auto var2 = std::make_shared<Var>();
|
||||
auto var2 = std::make_shared<Var>("m-var2");
|
||||
MS_CHECK_TRUE_RET(var2 != nullptr, {});
|
||||
auto sub = VectorRef({is_sub, var2, expand_dims});
|
||||
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion));
|
||||
auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "m-mul");
|
||||
MS_CHECK_TRUE_RET(is_mul != nullptr, {});
|
||||
auto var3 = std::make_shared<Var>();
|
||||
auto var3 = std::make_shared<Var>("m-var3");
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
return VectorRef({is_mul, sub, var3});
|
||||
}
|
||||
STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector<int> *result) {
|
||||
if (param_ptr == nullptr || !param_ptr->has_default()) {
|
||||
MS_LOG(DEBUG) << "param not have default";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto default_param = param_ptr->default_param();
|
||||
if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
|
||||
MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
|
||||
if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
|
||||
MS_LOG(DEBUG) << "default param is not int";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
|
||||
int64_t shape_size =
|
||||
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
|
||||
for (int64_t i = 0; i < shape_size; i++) {
|
||||
result->emplace_back(ptr[i]);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
|
||||
MS_ASSERT(axes != nullptr);
|
||||
|
@ -99,8 +125,14 @@ STATUS GetAxis(const BaseRef &n, std::vector<int> *axes) {
|
|||
*axes = CastToInt(axes_value_node->value());
|
||||
return lite::RET_OK;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "GetAxis supports only value node";
|
||||
auto reshape = utils::cast<ParameterPtr>(n);
|
||||
if (reshape != nullptr) {
|
||||
if (GetIntParameterData(reshape, axes) == lite::RET_OK) {
|
||||
return lite::RET_OK;
|
||||
}
|
||||
}
|
||||
}
|
||||
MS_LOG(ERROR) << " cannot get axes data";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
} // namespace
|
||||
|
@ -132,23 +164,42 @@ VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const
|
|||
return conn;
|
||||
}
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mask) const {
|
||||
VectorRef MultiHeadAttentionFusion::DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
|
||||
const BaseRef &transpose_var, bool test_div, bool transpose) const {
|
||||
auto is_matmul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "e-matmul");
|
||||
MS_CHECK_TRUE_RET(is_matmul != nullptr, {});
|
||||
auto dense = VectorRef({is_matmul, input, weight});
|
||||
auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "e-reshape");
|
||||
MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
|
||||
auto reshape = VectorRef({is_reshape, dense, axis});
|
||||
auto var2 = std::make_shared<Var>();
|
||||
VectorRef conn;
|
||||
if (transpose) {
|
||||
conn = VectorRef({transpose_var, reshape, var2});
|
||||
} else {
|
||||
conn = reshape;
|
||||
}
|
||||
if (test_div) {
|
||||
auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "e-div");
|
||||
MS_CHECK_TRUE_RET(is_div != nullptr, {});
|
||||
auto var3 = std::make_shared<Var>();
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
auto div = VectorRef({is_div, conn, var3});
|
||||
return div;
|
||||
}
|
||||
return conn;
|
||||
}
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool mask) const {
|
||||
VectorRef k_embedding, v_embedding;
|
||||
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
|
||||
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
|
||||
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
|
||||
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
|
||||
if (!cross) {
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
} else {
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
}
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
|
||||
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
|
||||
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
|
||||
|
@ -194,23 +245,16 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPattern(bool cross, bool mas
|
|||
return matmul3;
|
||||
}
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const {
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5() const {
|
||||
VectorRef k_embedding, v_embedding;
|
||||
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
|
||||
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
|
||||
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true, false);
|
||||
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
|
||||
if (!cross) {
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_, false, false);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
} else {
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
}
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1");
|
||||
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
|
||||
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1");
|
||||
|
@ -246,23 +290,79 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5(bool cross) const
|
|||
return matmul3;
|
||||
}
|
||||
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const {
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternT5New(bool transpose) const {
|
||||
VectorRef k_embedding, v_embedding;
|
||||
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "q_transpose");
|
||||
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
|
||||
VectorRef q_embedding;
|
||||
if (transpose) {
|
||||
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, true);
|
||||
} else {
|
||||
q_embedding = DefineEmbedding(input_q_, weight_q_, reshape_axis_, q_transpose, true, false);
|
||||
}
|
||||
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, reshape_k_, k_transpose_, true, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, reshape_v_, v_transpose_, false, true);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul1");
|
||||
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
|
||||
auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape1");
|
||||
MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
|
||||
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
|
||||
auto var1 = std::make_shared<Var>("var1");
|
||||
MS_CHECK_TRUE_RET(var1 != nullptr, {});
|
||||
auto is_add1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add1");
|
||||
MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
|
||||
auto add1 = VectorRef({is_add1, matmul1, position_bias_});
|
||||
auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "add2");
|
||||
MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
|
||||
auto mask2 = DefineMask(mask_);
|
||||
MS_CHECK_TRUE_RET(!mask2.empty(), {});
|
||||
auto add2 = VectorRef({is_add1, mask2, add1});
|
||||
auto reshape1 = VectorRef({is_reshape1, add2, var1});
|
||||
auto is_softmax = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSoftmax), "softmax");
|
||||
MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
|
||||
auto softmax = VectorRef({is_softmax, reshape1});
|
||||
auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape2");
|
||||
MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
|
||||
auto var2 = std::make_shared<Var>("var2");
|
||||
MS_CHECK_TRUE_RET(var2 != nullptr, {});
|
||||
auto reshape2 = VectorRef({is_reshape2, softmax, var2});
|
||||
auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul2");
|
||||
MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
|
||||
auto matmul2 = VectorRef({is_matmul2, reshape2, v_embedding});
|
||||
auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape3");
|
||||
MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
|
||||
auto var4 = std::make_shared<Var>("var4");
|
||||
MS_CHECK_TRUE_RET(var4 != nullptr, {});
|
||||
VectorRef reshape3;
|
||||
if (transpose) {
|
||||
auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "transpose");
|
||||
MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
|
||||
auto var3 = std::make_shared<Var>("var3");
|
||||
MS_CHECK_TRUE_RET(var3 != nullptr, {});
|
||||
auto transpose1 = VectorRef({is_transpose, matmul2, var3});
|
||||
reshape3 = VectorRef({is_reshape3, transpose1, var4});
|
||||
} else {
|
||||
reshape3 = VectorRef({is_reshape3, matmul2, var4});
|
||||
}
|
||||
|
||||
auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "matmul3");
|
||||
MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
|
||||
auto matmul3 = VectorRef({is_matmul3, reshape3, weight_o_});
|
||||
return matmul3;
|
||||
}
|
||||
VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA() const {
|
||||
VectorRef k_embedding, v_embedding;
|
||||
auto q_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
|
||||
MS_CHECK_TRUE_RET(q_transpose != nullptr, {});
|
||||
auto q_embedding = DefineEmbedding(input_q_, weight_q_, bias_q_, reshape_axis_, q_transpose, true);
|
||||
MS_CHECK_TRUE_RET(!q_embedding.empty(), {});
|
||||
if (!cross) {
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_k_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_v_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
} else {
|
||||
k_embedding = DefineEmbedding(input_k_, weight_k_, bias_k_, reshape_axis_, k_transpose_, true);
|
||||
MS_CHECK_TRUE_RET(!k_embedding.empty(), {});
|
||||
v_embedding = DefineEmbedding(input_v_, weight_v_, bias_v_, reshape_axis_, v_transpose_);
|
||||
MS_CHECK_TRUE_RET(!v_embedding.empty(), {});
|
||||
}
|
||||
auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion));
|
||||
MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
|
||||
auto matmul1 = VectorRef({is_matmul1, q_embedding, k_embedding});
|
||||
|
@ -296,13 +396,14 @@ VectorRef MultiHeadAttentionFusion::DefineMPWithMaskPatternPA(bool cross) const
|
|||
}
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
STATUS TransposeMatrix(std::shared_ptr<tensor::Tensor> src, std::shared_ptr<tensor::Tensor> dst) {
|
||||
MS_CHECK_TRUE_RET(src->shape().size() == C2NUM, RET_ERROR);
|
||||
MS_CHECK_TRUE_RET(dst->shape().size() == C2NUM, RET_ERROR);
|
||||
int rows = src->shape().at(0);
|
||||
int cols = src->shape().at(1);
|
||||
auto src_ptr = reinterpret_cast<float *>(src->data_c());
|
||||
auto dst_ptr = reinterpret_cast<float *>(dst->data_c());
|
||||
auto src_ptr = reinterpret_cast<T *>(src->data_c());
|
||||
auto dst_ptr = reinterpret_cast<T *>(dst->data_c());
|
||||
for (int r = 0; r < rows; ++r) {
|
||||
for (int c = 0; c < cols; ++c) {
|
||||
auto val = src_ptr[r * cols + c];
|
||||
|
@ -354,8 +455,20 @@ std::shared_ptr<tensor::Tensor> ConcatTensors(const std::vector<std::shared_ptr<
|
|||
if (transpose) {
|
||||
std::vector<int64_t> tshape = {new_shape[1], new_shape[0]};
|
||||
auto transposed_tensor = std::make_shared<tensor::Tensor>(base_data_type, tshape);
|
||||
auto status = TransposeMatrix(concat_tensor, transposed_tensor);
|
||||
switch (base_data_type) {
|
||||
case kNumberTypeFloat32: {
|
||||
auto status = TransposeMatrix<float>(concat_tensor, transposed_tensor);
|
||||
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeFloat16: {
|
||||
auto status = TransposeMatrix<float16>(concat_tensor, transposed_tensor);
|
||||
MS_CHECK_TRUE_RET(status == RET_OK, nullptr);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
MS_LOG(ERROR) << "unsupported data type " << base_data_type << std::endl;
|
||||
}
|
||||
return transposed_tensor;
|
||||
}
|
||||
return concat_tensor;
|
||||
|
@ -368,14 +481,13 @@ std::unordered_map<std::string, VectorRef> MultiHeadAttentionFusion::DefinePatte
|
|||
MS_LOG(ERROR) << "initial member failed.";
|
||||
return patterns;
|
||||
}
|
||||
|
||||
patterns[kMPAWithMaskPatternName] = DefineMPWithMaskPattern();
|
||||
patterns[kMPAXWithMaskPatternName] = DefineMPWithMaskPattern(true);
|
||||
patterns[kMPAPatternName] = DefineMPWithMaskPattern(false, false);
|
||||
patterns[kMPAXPatternName] = DefineMPWithMaskPattern(true, false);
|
||||
patterns[kMPAPatternName] = DefineMPWithMaskPattern(false);
|
||||
patterns[kMPAWithMaskPatternNamePA] = DefineMPWithMaskPatternPA();
|
||||
patterns[kMPAXWithMaskPatternNamePA] = DefineMPWithMaskPatternPA(true);
|
||||
patterns[kMPAWithMaskPatternNameT5] = DefineMPWithMaskPatternT5();
|
||||
patterns[kMPAXWithMaskPatternNameT5] = DefineMPWithMaskPatternT5(true);
|
||||
patterns[kMPAWithMaskPatternNameT5New] = DefineMPWithMaskPatternT5New(false);
|
||||
patterns[kMPAWithMaskTransposePatternNameT5New] = DefineMPWithMaskPatternT5New();
|
||||
return patterns;
|
||||
}
|
||||
|
||||
|
@ -384,14 +496,24 @@ bool MultiHeadAttentionFusion::CheckPattern(const EquivPtr &equiv, int *head_num
|
|||
MS_ASSERT(head_num != nullptr);
|
||||
MS_ASSERT(head_size != nullptr);
|
||||
std::vector<int> reshape_axes;
|
||||
// UNDO !!!!!!!!!
|
||||
if (GetAxis((*equiv)[reshape_axis_], &reshape_axes) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "cannot figure out reshape";
|
||||
return false;
|
||||
}
|
||||
if (reshape_axes.size() != C4NUM) {
|
||||
return false;
|
||||
}
|
||||
*head_num = reshape_axes.at(C2NUM);
|
||||
*head_size = reshape_axes.at(C3NUM);
|
||||
std::vector<int> out;
|
||||
std::for_each(reshape_axes.begin() + 1, reshape_axes.end(), [&out](const auto &x) {
|
||||
if (x != -1) out.push_back(x);
|
||||
});
|
||||
if (out.size() < C2NUM) {
|
||||
MS_LOG(ERROR) << "cannot find head_num or head_size";
|
||||
return false;
|
||||
}
|
||||
*head_num = out.at(0);
|
||||
*head_size = out.at(1);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -401,44 +523,43 @@ AnfNodePtr MultiHeadAttentionFusion::Process(const std::string &pattern_name, co
|
|||
if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
++match_count_;
|
||||
if ((pattern_name == kMPAWithMaskPatternName) || (pattern_name == kMPAWithMaskPatternNamePA) ||
|
||||
(pattern_name == kMPAWithMaskPatternNameT5)) {
|
||||
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope());
|
||||
} else if ((pattern_name == kMPAXWithMaskPatternName) || (pattern_name == kMPAXWithMaskPatternNamePA) ||
|
||||
(pattern_name == kMPAXWithMaskPatternNameT5)) {
|
||||
(pattern_name == kMPAWithMaskPatternNameT5) ||
|
||||
(pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New)) {
|
||||
if (pattern_name == kMPAWithMaskPatternNameT5New || pattern_name == kMPAWithMaskTransposePatternNameT5New) {
|
||||
t5_x_ = true;
|
||||
}
|
||||
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true);
|
||||
} else if (pattern_name == kMPAPatternName) {
|
||||
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false, false);
|
||||
} else if (pattern_name == kMPAXPatternName) {
|
||||
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), true, false);
|
||||
}
|
||||
if (pattern_name == kMPAPatternName)
|
||||
return CreateMaskedMultiHeadAttentionNode(func_graph, equiv, node->fullname_with_scope(), false);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
{ return nullptr; }
|
||||
}
|
||||
|
||||
STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector<int> *result) {
|
||||
if (param_ptr == nullptr || !param_ptr->has_default()) {
|
||||
MS_LOG(DEBUG) << "param not have default";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto default_param = param_ptr->default_param();
|
||||
if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
|
||||
MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
|
||||
if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
|
||||
MS_LOG(DEBUG) << "default param is not int";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
|
||||
int64_t shape_size =
|
||||
std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
|
||||
for (int64_t i = 0; i < shape_size; i++) {
|
||||
result->emplace_back(ptr[i]);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
// STATUS GetIntParameterData(const ParameterPtr ¶m_ptr, std::vector<int> *result) {
|
||||
// if (param_ptr == nullptr || !param_ptr->has_default()) {
|
||||
// MS_LOG(DEBUG) << "param not have default";
|
||||
// return RET_ERROR;
|
||||
// }
|
||||
// auto default_param = param_ptr->default_param();
|
||||
// if (default_param == nullptr || !utils::isa<tensor::TensorPtr>(default_param)) {
|
||||
// MS_LOG(DEBUG) << "tensor_info is not tensor::TensorPtr";
|
||||
// return RET_ERROR;
|
||||
// }
|
||||
// auto default_param_ptr = utils::cast<tensor::TensorPtr>(default_param);
|
||||
// if (default_param_ptr->data_type() != kNumberTypeInt32 && default_param_ptr->data_type() != kNumberTypeInt) {
|
||||
// MS_LOG(DEBUG) << "default param is not int";
|
||||
// return RET_ERROR;
|
||||
// }
|
||||
// auto ptr = reinterpret_cast<int *>(default_param_ptr->data_c());
|
||||
// int64_t shape_size =
|
||||
// std::accumulate(default_param_ptr->shape().begin(), default_param_ptr->shape().end(), 1, std::multiplies<>());
|
||||
// for (int64_t i = 0; i < shape_size; i++) {
|
||||
// result->emplace_back(ptr[i]);
|
||||
// }
|
||||
// return RET_OK;
|
||||
// }
|
||||
|
||||
std::shared_ptr<ops::Attention> MultiHeadAttentionFusion::BuildAttentionPrim(const EquivPtr &equiv) const {
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
|
@ -595,50 +716,115 @@ CNodePtr MultiHeadAttentionFusion::MakeGetTuple(const FuncGraphPtr &func_graph,
|
|||
return get_item_node;
|
||||
}
|
||||
|
||||
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
|
||||
const EquivPtr &equiv, const string &base_name,
|
||||
bool cross, bool mask) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
std::vector<AnfNodePtr> redundant;
|
||||
auto attention_prim = CreatePrim(equiv, cross);
|
||||
MS_CHECK_TRUE_RET(attention_prim != nullptr, nullptr);
|
||||
auto attention_prim_c = attention_prim->GetPrim();
|
||||
MS_CHECK_TRUE_RET(attention_prim_c != nullptr, nullptr);
|
||||
auto value_node = NewValueNode(attention_prim_c);
|
||||
MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
|
||||
bool MultiHeadAttentionFusion::IsCross(const EquivPtr &equiv) const {
|
||||
auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
|
||||
auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
|
||||
|
||||
ShapeVector inputq_shape, inputv_shape;
|
||||
auto ret = FetchShapeFromAbstract(input_q->abstract(), &inputq_shape);
|
||||
MS_CHECK_TRUE_RET(ret == RET_OK, false);
|
||||
ret = FetchShapeFromAbstract(input_v->abstract(), &inputv_shape);
|
||||
MS_CHECK_TRUE_RET(ret == RET_OK, false);
|
||||
|
||||
if ((inputq_shape != inputv_shape) || ((match_count_ > 1) && (input_q != input_v))) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
using tup = std::tuple<AnfNodePtr, std::shared_ptr<tensor::Tensor>, std::shared_ptr<tensor::Tensor>,
|
||||
std::shared_ptr<tensor::Tensor>>;
|
||||
|
||||
tup MultiHeadAttentionFusion::GetAttentionNodeWeights(const EquivPtr &equiv, std::vector<AnfNodePtr> *redundant) const {
|
||||
auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
|
||||
auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
|
||||
auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
|
||||
redundant->push_back(weight_q);
|
||||
redundant->push_back(weight_k);
|
||||
redundant->push_back(weight_v);
|
||||
auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);
|
||||
std::shared_ptr<tensor::Tensor> weight_q_tensor = GetTensorInfo(weight_q);
|
||||
std::shared_ptr<tensor::Tensor> weight_k_tensor = GetTensorInfo(weight_k);
|
||||
std::shared_ptr<tensor::Tensor> weight_v_tensor = GetTensorInfo(weight_v);
|
||||
return make_tuple(weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> MultiHeadAttentionFusion::GetNewNodeInputs(const EquivPtr &equiv, ParameterPtr q_weight_param,
|
||||
ParameterPtr c_weight_param, AnfNodePtr weight_o,
|
||||
ParameterPtr c_bias_param, AnfNodePtr bias_o,
|
||||
bool mask, bool cross) const {
|
||||
auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
|
||||
auto input_k = utils::cast<AnfNodePtr>((*equiv)[input_k_]);
|
||||
auto input_v = utils::cast<AnfNodePtr>((*equiv)[input_v_]);
|
||||
AnfNodePtr input_mask;
|
||||
auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_q_]);
|
||||
redundant.push_back(weight_q);
|
||||
auto weight_k = utils::cast<AnfNodePtr>((*equiv)[weight_k_]);
|
||||
auto weight_v = utils::cast<AnfNodePtr>((*equiv)[weight_v_]);
|
||||
redundant.push_back(weight_k);
|
||||
redundant.push_back(weight_v);
|
||||
auto weight_o = utils::cast<AnfNodePtr>((*equiv)[weight_o_]);
|
||||
auto bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
|
||||
if (!cross) {
|
||||
redundant.push_back(bias_q);
|
||||
AnfNodePtr input_mask = (mask) ? utils::cast<AnfNodePtr>((*equiv)[mask_]) : nullptr;
|
||||
AnfNodePtr position_bias = (t5_x_) ? utils::cast<AnfNodePtr>((*equiv)[position_bias_]) : nullptr;
|
||||
|
||||
auto attention_prim = CreatePrim(equiv, cross);
|
||||
MS_CHECK_TRUE_RET(attention_prim != nullptr, {});
|
||||
auto attention_prim_c = attention_prim->GetPrim();
|
||||
MS_CHECK_TRUE_RET(attention_prim_c != nullptr, {});
|
||||
auto value_node = NewValueNode(attention_prim_c);
|
||||
MS_CHECK_TRUE_RET(value_node != nullptr, {});
|
||||
std::vector<AnfNodePtr> new_node_inputs;
|
||||
if (cross) {
|
||||
if (t5_x_) {
|
||||
new_node_inputs = {value_node, input_q, input_k, input_v,
|
||||
q_weight_param, c_weight_param, weight_o, position_bias};
|
||||
} else {
|
||||
new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param,
|
||||
c_weight_param, weight_o, c_bias_param, bias_o};
|
||||
}
|
||||
} else {
|
||||
if (t5_x_) {
|
||||
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, position_bias};
|
||||
} else {
|
||||
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o};
|
||||
}
|
||||
}
|
||||
if (mask) {
|
||||
new_node_inputs.push_back(input_mask);
|
||||
}
|
||||
return new_node_inputs;
|
||||
}
|
||||
|
||||
CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph,
|
||||
const EquivPtr &equiv, const string &base_name,
|
||||
bool mask) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
bool cross = IsCross(equiv);
|
||||
std::vector<AnfNodePtr> redundant;
|
||||
auto [weight_o, weight_q_tensor, weight_k_tensor, weight_v_tensor] = GetAttentionNodeWeights(equiv, &redundant);
|
||||
AnfNodePtr bias_q;
|
||||
ParameterPtr c_bias_param;
|
||||
AnfNodePtr bias_o;
|
||||
if (!t5_x_) {
|
||||
bias_q = utils::cast<AnfNodePtr>((*equiv)[bias_q_]);
|
||||
auto bias_k = utils::cast<AnfNodePtr>((*equiv)[bias_k_]);
|
||||
auto bias_v = utils::cast<AnfNodePtr>((*equiv)[bias_v_]);
|
||||
redundant.push_back(bias_k);
|
||||
redundant.push_back(bias_v);
|
||||
auto bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
|
||||
auto knode = utils::cast<AnfNodePtr>((*equiv)[k_transpose_]);
|
||||
auto vnode = utils::cast<AnfNodePtr>((*equiv)[v_transpose_]);
|
||||
if (mask) {
|
||||
input_mask = utils::cast<AnfNodePtr>((*equiv)[mask_]);
|
||||
}
|
||||
std::shared_ptr<tensor::Tensor> weight_q_tensor = GetTensorInfo(weight_q);
|
||||
std::shared_ptr<tensor::Tensor> weight_k_tensor = GetTensorInfo(weight_k);
|
||||
std::shared_ptr<tensor::Tensor> weight_v_tensor = GetTensorInfo(weight_v);
|
||||
bias_o = utils::cast<AnfNodePtr>((*equiv)[bias_o_]);
|
||||
std::shared_ptr<tensor::Tensor> bias_q_tensor = GetTensorInfo(bias_q);
|
||||
std::shared_ptr<tensor::Tensor> bias_k_tensor = GetTensorInfo(bias_k);
|
||||
std::shared_ptr<tensor::Tensor> bias_v_tensor = GetTensorInfo(bias_v);
|
||||
auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor});
|
||||
c_bias_param = func_graph->add_parameter();
|
||||
MS_CHECK_TRUE_RET(c_bias_param != nullptr, nullptr);
|
||||
c_bias_param->set_name(base_name + "/bias_qkv");
|
||||
if (lite::InitParameterFromTensorInfo(c_bias_param, c_bias) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Init parameter from tensor info failed.";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
auto knode = utils::cast<AnfNodePtr>((*equiv)[k_transpose_]);
|
||||
AnfNodePtr vnode;
|
||||
auto it_vnode = (*equiv).find(v_transpose_);
|
||||
if (it_vnode != (*equiv).end() && !t5_x_) vnode = utils::cast<AnfNodePtr>(it_vnode->second);
|
||||
if (!cross && !t5_x_) {
|
||||
redundant.push_back(bias_q);
|
||||
}
|
||||
|
||||
tensor::TensorPtr c_weights;
|
||||
tensor::TensorPtr q_weight_t;
|
||||
if (cross) {
|
||||
|
@ -647,7 +833,6 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
|
|||
} else {
|
||||
c_weights = ConcatTensors({weight_q_tensor, weight_k_tensor, weight_v_tensor}, true);
|
||||
}
|
||||
auto c_bias = ConcatTensors({bias_q_tensor, bias_k_tensor, bias_v_tensor});
|
||||
// convert tensors to params
|
||||
auto c_weight_param = func_graph->add_parameter();
|
||||
MS_CHECK_TRUE_RET(c_weight_param != nullptr, nullptr);
|
||||
|
@ -656,13 +841,6 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
|
|||
return nullptr;
|
||||
}
|
||||
c_weight_param->set_name(base_name + "/weight_qkv");
|
||||
auto c_bias_param = func_graph->add_parameter();
|
||||
MS_CHECK_TRUE_RET(c_bias_param != nullptr, nullptr);
|
||||
if (lite::InitParameterFromTensorInfo(c_bias_param, c_bias) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Init parameter from tensor info failed.";
|
||||
return nullptr;
|
||||
}
|
||||
c_bias_param->set_name(base_name + "/bias_qkv");
|
||||
ParameterPtr q_weight_param;
|
||||
if (cross) {
|
||||
q_weight_param = func_graph->add_parameter();
|
||||
|
@ -672,25 +850,26 @@ CNodePtr MultiHeadAttentionFusion::CreateMaskedMultiHeadAttentionNode(const Func
|
|||
return nullptr;
|
||||
}
|
||||
}
|
||||
std::vector<AnfNodePtr> new_node_inputs;
|
||||
if (cross) {
|
||||
new_node_inputs = {value_node, input_q, input_k, input_v, q_weight_param,
|
||||
c_weight_param, weight_o, c_bias_param, bias_o};
|
||||
} else {
|
||||
new_node_inputs = {value_node, input_q, input_k, input_v, c_weight_param, weight_o, c_bias_param, bias_o};
|
||||
}
|
||||
if (mask) {
|
||||
new_node_inputs.push_back(input_mask);
|
||||
}
|
||||
std::vector<AnfNodePtr> new_node_inputs =
|
||||
GetNewNodeInputs(equiv, q_weight_param, c_weight_param, weight_o, c_bias_param, bias_o, mask, cross);
|
||||
|
||||
auto new_node = func_graph->NewCNode(new_node_inputs);
|
||||
MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
|
||||
if (vnode) {
|
||||
auto status = SetAbstractTuple(new_node, kAttentionOutputs);
|
||||
if (status != RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
new_node->set_fullname_with_scope(base_name + "/attention");
|
||||
CNodePtr ret_node;
|
||||
if (vnode) {
|
||||
auto get_item_node = MakeGetTuple(func_graph, new_node, knode, vnode);
|
||||
ret_node = get_item_node;
|
||||
} else {
|
||||
ret_node = new_node;
|
||||
}
|
||||
RemoveRedundantInput(func_graph, redundant);
|
||||
return get_item_node;
|
||||
return ret_node;
|
||||
}
|
||||
} // namespace mindspore::opt
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -46,15 +47,18 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
|
|||
|
||||
private:
|
||||
// define patterns
|
||||
VectorRef DefineMPWithMaskPattern(bool cross = false, bool mask = true) const;
|
||||
VectorRef DefineMPWithMaskPatternPA(bool cross = false) const;
|
||||
VectorRef DefineMPWithMaskPatternT5(bool cross = false) const;
|
||||
VectorRef DefineMPWithMaskPattern(bool mask = true) const;
|
||||
VectorRef DefineMPWithMaskPatternPA() const;
|
||||
VectorRef DefineMPWithMaskPatternT5() const;
|
||||
VectorRef DefineMPWithMaskPatternT5New(bool transpose = true) const;
|
||||
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &bias, const BaseRef &axis,
|
||||
const BaseRef &transpose_var, bool test_div = false, bool transpose = true) const;
|
||||
VectorRef DefineEmbedding(const BaseRef &input, const BaseRef &weight, const BaseRef &axis,
|
||||
const BaseRef &transpose_var, bool test_div, bool transpose) const;
|
||||
|
||||
// create masked-multi-head-attention
|
||||
CNodePtr CreateMaskedMultiHeadAttentionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
|
||||
const std::string &base_name, bool cross = false, bool mask = true) const;
|
||||
const std::string &base_name, bool mask = true) const;
|
||||
// check pattern
|
||||
bool CheckPattern(const EquivPtr &equiv, int *head_num, int *head_size) const;
|
||||
CNodePtr CreateOutputGetItem(const FuncGraphPtr &func_graph, const CNodePtr &node, const int item_index) const;
|
||||
|
@ -66,20 +70,27 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
|
|||
CNodePtr MakeGetTuple(const FuncGraphPtr &func_graph, const CNodePtr &new_node, const AnfNodePtr &knode,
|
||||
const AnfNodePtr &vnode) const;
|
||||
std::shared_ptr<ops::Attention> CreatePrim(const EquivPtr &equiv, bool cross) const;
|
||||
bool IsCross(const EquivPtr &equiv) const;
|
||||
std::vector<AnfNodePtr> GetNewNodeInputs(const EquivPtr &equiv, ParameterPtr q_weight_param,
|
||||
ParameterPtr c_weight_param, AnfNodePtr weight_o, ParameterPtr c_bias_param,
|
||||
AnfNodePtr bias_o, bool mask, bool cross) const;
|
||||
std::tuple<AnfNodePtr, std::shared_ptr<tensor::Tensor>, std::shared_ptr<tensor::Tensor>,
|
||||
std::shared_ptr<tensor::Tensor> >
|
||||
GetAttentionNodeWeights(const EquivPtr &equiv, std::vector<AnfNodePtr> *redundant) const;
|
||||
mutable int match_count_ = 0;
|
||||
|
||||
protected:
|
||||
const std::string kMPAWithMaskPatternName = "MPAWithMaskPattern";
|
||||
const std::string kMPAXWithMaskPatternName = "MPAXWithMaskPattern";
|
||||
const std::string kMPAWithMaskPatternNamePA = "MPAWithMaskPatternPA";
|
||||
const std::string kMPAXWithMaskPatternNamePA = "MPAXWithMaskPatternPA";
|
||||
const std::string kMPAPatternName = "MPAPattern";
|
||||
const std::string kMPAXPatternName = "MPAXPattern";
|
||||
const std::string kMPAWithMaskPatternNameT5 = "MPAWithMaskPatternT5";
|
||||
const std::string kMPAXWithMaskPatternNameT5 = "MPAXWithMaskPatternT5";
|
||||
const std::string kMPAWithMaskPatternNameT5New = "MPAWithMaskPatternT5New";
|
||||
const std::string kMPAWithMaskTransposePatternNameT5New = "MPAWithMaskTransposePatternT5New";
|
||||
|
||||
mutable VarPtr input_q_{nullptr};
|
||||
mutable VarPtr input_k_{nullptr};
|
||||
mutable VarPtr input_v_{nullptr};
|
||||
mutable VarPtr position_bias_{nullptr};
|
||||
|
||||
mutable VarPtr weight_q_{nullptr};
|
||||
mutable VarPtr weight_k_{nullptr};
|
||||
|
@ -98,8 +109,9 @@ class MultiHeadAttentionFusion : public MultiplePatternProcessPass {
|
|||
mutable VarPtr reshape_axis_{nullptr};
|
||||
mutable VarPtr v_transpose_{nullptr};
|
||||
mutable VarPtr k_transpose_{nullptr};
|
||||
};
|
||||
|
||||
mutable bool t5_x_{false};
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_MULTI_HEAD_ATTENTION_FUSION_H_
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue