[MSLITE] bug fix for matmul op 0415_10
This commit is contained in:
parent
84a9c68b5c
commit
bb7781037e
|
@ -34,15 +34,15 @@ void CublasMM1Batch(const void *a_addr, const void *b_addr, void *c_addr, const
|
|||
const int k = params[2];
|
||||
cublasOperation_t trans_a = operations[0];
|
||||
cublasOperation_t trans_b = operations[1];
|
||||
const int lda = (trans_a == CUBLAS_OP_N) ? m : k;
|
||||
const int ldb = (trans_b == CUBLAS_OP_N) ? k : n;
|
||||
const int ldc = m;
|
||||
const int lda = (trans_a == CUBLAS_OP_N) ? k : m;
|
||||
const int ldb = (trans_b == CUBLAS_OP_N) ? n : k;
|
||||
const int ldc = n;
|
||||
cudaDataType_t type_a = data_types[0];
|
||||
cudaDataType_t type_b = data_types[1];
|
||||
cudaDataType_t type_c = data_types[2];
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_a, trans_b, m, n, k, &alpha, a_addr, type_a, lda, b_addr, type_b,
|
||||
ldb, &beta, c_addr, type_c, ldc, type_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
CUBLAS_CHECK_VOID(cublasGemmEx(cublas_handle, trans_b, trans_a, n, m, k, &alpha, b_addr, type_b, ldb, a_addr, type_a,
|
||||
lda, &beta, c_addr, type_c, ldc, type_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
} // namespace mindspore::lite
|
||||
|
|
|
@ -39,6 +39,15 @@ class CudaHelper {
|
|||
#define GET_BLOCKS(total_threads) CudaHelper::GetInstance().GetBlocksNum(total_threads)
|
||||
#define GET_THREADS CudaHelper::GetInstance().GetThreadNum()
|
||||
|
||||
#define CUDA_CHECK(ret) \
|
||||
do { \
|
||||
cudaError_t cuda_ret = (ret); \
|
||||
if ((cuda_ret) != cudaSuccess) { \
|
||||
MS_LOG(ERROR) << "cuda func call error: " << cudaGetErrorString(cuda_ret); \
|
||||
return -1; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CUDA_CHECK_VOID(ret) \
|
||||
do { \
|
||||
cudaError_t cuda_ret = (ret); \
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -38,17 +38,16 @@ int MatmulOptPlugin::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, const
|
|||
CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream));
|
||||
const nvinfer1::PluginTensorDesc desc_a = inputDesc[0];
|
||||
const nvinfer1::PluginTensorDesc desc_b = inputDesc[1];
|
||||
const nvinfer1::PluginTensorDesc desc_c = outputDesc[0];
|
||||
// a: m * k, b: k * n, c: m * n
|
||||
int m = a_trans_ ? desc_a.dims.d[1] : desc_a.dims.d[0];
|
||||
int k = a_trans_ ? desc_a.dims.d[0] : desc_a.dims.d[1];
|
||||
int n = b_trans_ ? desc_b.dims.d[0] : desc_b.dims.d[1];
|
||||
int m = desc_c.dims.d[0];
|
||||
int n = desc_c.dims.d[1];
|
||||
int k = b_trans_ ? desc_b.dims.d[1] : desc_b.dims.d[0];
|
||||
const int mm_params[]{m, n, k};
|
||||
const int trans_params[]{n, m};
|
||||
if (desc_a.type == nvinfer1::DataType::kFLOAT && desc_b.type == nvinfer1::DataType::kFLOAT) {
|
||||
CublasMM1Batch(inputs[0], inputs[1], c_addr_trans_, mm_params, operations_, data_types_, type_compute_,
|
||||
CublasMM1Batch(inputs[0], inputs[1], outputs[0], mm_params, operations_, data_types_, type_compute_,
|
||||
cublas_handle_);
|
||||
Cublas2DTranspose(static_cast<const float *>(c_addr_trans_), static_cast<float *>(outputs[0]), trans_params,
|
||||
cublas_handle_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << layer_name_ << " input datatype needs check a: " << static_cast<int>(desc_a.type)
|
||||
<< ", b: " << static_cast<int>(desc_a.type);
|
||||
|
@ -84,17 +83,8 @@ bool MatmulOptPlugin::supportsFormatCombination(int pos, const nvinfer1::PluginT
|
|||
void MatmulOptPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out, int nbOutputs) noexcept {
|
||||
bias_index_ = (nbInputs == INPUT_SIZE3) ? kBiasIndex : -1;
|
||||
if (a_addr_trans_ == nullptr) {
|
||||
CUDA_CHECK_VOID(cudaMalloc(&a_addr_trans_, GetDimsVolume(in[0].max) * sizeof(float)));
|
||||
}
|
||||
if (b_addr_trans_ == nullptr) {
|
||||
CUDA_CHECK_VOID(cudaMalloc(&b_addr_trans_, GetDimsVolume(in[1].max) * sizeof(float)));
|
||||
}
|
||||
if (c_addr_trans_ == nullptr) {
|
||||
CUDA_CHECK_VOID(cudaMalloc(&c_addr_trans_, GetDimsVolume(out[0].max) * sizeof(float)));
|
||||
}
|
||||
operations_[0] = a_trans_ ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
operations_[1] = b_trans_ ? CUBLAS_OP_N : CUBLAS_OP_T;
|
||||
operations_[0] = a_trans_ ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
operations_[1] = b_trans_ ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
data_types_[0] = ConvertDataType(in[0].desc.type); // input a
|
||||
data_types_[1] = ConvertDataType(in[1].desc.type); // input b
|
||||
data_types_[kBiasIndex] = ConvertDataType(out[0].desc.type); // output c
|
||||
|
@ -125,17 +115,7 @@ int MatmulOptPlugin::initialize() noexcept {
|
|||
}
|
||||
}
|
||||
|
||||
void MatmulOptPlugin::FreeCudaDeviceMemory(void **addr) {
|
||||
if (addr != nullptr) {
|
||||
CUDA_CHECK_VOID(cudaFree(*addr));
|
||||
addr = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void MatmulOptPlugin::terminate() noexcept {
|
||||
FreeCudaDeviceMemory(&a_addr_trans_);
|
||||
FreeCudaDeviceMemory(&b_addr_trans_);
|
||||
FreeCudaDeviceMemory(&c_addr_trans_);
|
||||
if (cublas_handle_ != nullptr) {
|
||||
auto cublas_ret = cublasDestroy(cublas_handle_);
|
||||
if (cublas_ret != CUBLAS_STATUS_SUCCESS) {
|
||||
|
|
|
@ -67,9 +67,6 @@ class MatmulOptPlugin : public nvinfer1::IPluginV2DynamicExt {
|
|||
bool b_trans_{false};
|
||||
int bias_index_{-1}; // -1 means no bias, otherwise should be 2
|
||||
cublasHandle_t cublas_handle_{nullptr};
|
||||
void *a_addr_trans_{nullptr};
|
||||
void *b_addr_trans_{nullptr};
|
||||
void *c_addr_trans_{nullptr};
|
||||
cublasOperation_t operations_[2]{CUBLAS_OP_N, CUBLAS_OP_N};
|
||||
cudaDataType_t data_types_[3]{CUDA_R_32F, CUDA_R_32F, CUDA_R_32F};
|
||||
cublasComputeType_t type_compute_;
|
||||
|
|
Loading…
Reference in New Issue