[MSLITE] bug fix for matmul op 0415_10

This commit is contained in:
Liu_Xuu 2022-03-15 09:39:59 +08:00
parent 84a9c68b5c
commit bb7781037e
5 changed files with 22 additions and 36 deletions

View File

@ -34,15 +34,15 @@ void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const
const int k = params[2];
cublasOperation_t trans_a = operations[0];
cublasOperation_t trans_b = operations[1];
const int lda = (trans_a == CUBLAS_OP_N) ? m : k;
const int ldb = (trans_b == CUBLAS_OP_N) ? k : n;
const int ldc = m;
const int lda = (trans_a == CUBLAS_OP_N) ? k : m;
const int ldb = (trans_b == CUBLAS_OP_N) ? n : k;
const int ldc = n;
cudaDataType_t type_a = data_types[0];
cudaDataType_t type_b = data_types[1];
cudaDataType_t type_c = data_types[2];
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, &alpha, a_addr, type_a, lda, b_addr, type_b,
ldb, &beta, c_addr, type_c, ldc, type_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addr, type_b, ldb, a_addr, type_a,
lda, &beta, c_addr, type_c, ldc, type_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
} // namespace mindspore::lite

View File

@ -39,6 +39,15 @@ class CudaHelper {
#define GET_BLOCKS(total_threads) CudaHelper::GetInstance().GetBlocksNum(total_threads)
#define GET_THREADS CudaHelper::GetInstance().GetThreadNum()
#define CUDA_CHECK(ret) \
do { \
cudaError_t cuda_ret = (ret); \
if ((cuda_ret) != cudaSuccess) { \
MS_LOG(ERROR) << "cuda func call error: " << cudaGetErrorString(cuda_ret); \
return -1; \
} \
} while (0)
#define CUDA_CHECK_VOID(ret) \
do { \
cudaError_t cuda_ret = (ret); \

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -38,17 +38,16 @@ int MatmulOptPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const
CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream));
const nvinfer1::PluginTensorDesc desc_a = inputDesc[0];
const nvinfer1::PluginTensorDesc desc_b = inputDesc[1];
const nvinfer1::PluginTensorDesc desc_c = outputDesc[0];
// a: m * k, b: k * n, c: m * n
int m = a_trans_ ? desc_a.dims.d[1] : desc_a.dims.d[0];
int k = a_trans_ ? desc_a.dims.d[0] : desc_a.dims.d[1];
int n = b_trans_ ? desc_b.dims.d[0] : desc_b.dims.d[1];
int m = desc_c.dims.d[0];
int n = desc_c.dims.d[1];
int k = b_trans_ ? desc_b.dims.d[1] : desc_b.dims.d[0];
const int mm_params[]{m, n, k};
const int trans_params[]{n, m};
if (desc_a.type == nvinfer1::DataType::kFLOAT && desc_b.type == nvinfer1::DataType::kFLOAT) {
CublasMM1Batch(inputs[0], inputs[1], c_addr_trans_, mm_params, operations_, data_types_, type_compute_,
CublasMM1Batch(inputs[0], inputs[1], outputs[0], mm_params, operations_, data_types_, type_compute_,
cublas_handle_);
Cublas2DTranspose(static_cast<const float *>(c_addr_trans_), static_cast<float *>(outputs[0]), trans_params,
cublas_handle_);
} else {
MS_LOG(ERROR) << layer_name_ << " input datatype needs check a: " << static_cast<int>(desc_a.type)
<< ", b: " << static_cast<int>(desc_a.type);
@ -84,17 +83,8 @@ bool MatmulOptPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginT
void MatmulOptPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {
bias_index_ = (nbInputs == INPUT_SIZE3) ? kBiasIndex : -1;
if (a_addr_trans_ == nullptr) {
CUDA_CHECK_VOID(cudaMalloc(&a_addr_trans_, GetDimsVolume(in[0].max) * sizeof(float)));
}
if (b_addr_trans_ == nullptr) {
CUDA_CHECK_VOID(cudaMalloc(&b_addr_trans_, GetDimsVolume(in[1].max) * sizeof(float)));
}
if (c_addr_trans_ == nullptr) {
CUDA_CHECK_VOID(cudaMalloc(&c_addr_trans_, GetDimsVolume(out[0].max) * sizeof(float)));
}
operations_[0] = a_trans_ ? CUBLAS_OP_N : CUBLAS_OP_T;
operations_[1] = b_trans_ ? CUBLAS_OP_N : CUBLAS_OP_T;
operations_[0] = a_trans_ ? CUBLAS_OP_T : CUBLAS_OP_N;
operations_[1] = b_trans_ ? CUBLAS_OP_T : CUBLAS_OP_N;
data_types_[0] = ConvertDataType(in[0].desc.type); // input a
data_types_[1] = ConvertDataType(in[1].desc.type); // input b
data_types_[kBiasIndex] = ConvertDataType(out[0].desc.type); // output c
@ -125,17 +115,7 @@ int MatmulOptPlugin::initialize() noexcept {
}
}
void MatmulOptPlugin::FreeCudaDeviceMemory(void **addr) {
// Frees the device buffer pointed to by *addr and clears the caller's
// pointer so a subsequent terminate()/configure pass cannot double-free it.
// Takes void** precisely so the caller's pointer can be reset.
if (addr != nullptr && *addr != nullptr) {
// CUDA_CHECK_VOID logs the error string and returns on failure.
CUDA_CHECK_VOID(cudaFree(*addr));
// Bug fix: the original `addr = nullptr;` only nulled the local parameter
// copy; the caller's pointer must be cleared through the indirection.
*addr = nullptr;
}
}
void MatmulOptPlugin::terminate() noexcept {
FreeCudaDeviceMemory(&a_addr_trans_);
FreeCudaDeviceMemory(&b_addr_trans_);
FreeCudaDeviceMemory(&c_addr_trans_);
if (cublas_handle_ != nullptr) {
auto cublas_ret = cublasDestroy(cublas_handle_);
if (cublas_ret != CUBLAS_STATUS_SUCCESS) {

View File

@ -67,9 +67,6 @@ class MatmulOptPlugin : public nvinfer1::IPluginV2DynamicExt {
bool b_trans_{false};
int bias_index_{-1}; // -1 means no bias, otherwise should be 2
cublasHandle_t cublas_handle_{nullptr};
void *a_addr_trans_{nullptr};
void *b_addr_trans_{nullptr};
void *c_addr_trans_{nullptr};
cublasOperation_t operations_[2]{CUBLAS_OP_N, CUBLAS_OP_N};
cudaDataType_t data_types_[3]{CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
cublasComputeType_t type_compute_;