support cusolverDn

fix clang format
This commit is contained in:
zongha 2020-07-17 15:54:53 +08:00
parent 946dcfa0ce
commit 82412429cf
4 changed files with 33 additions and 1 deletions

View File

@ -277,7 +277,8 @@ if (ENABLE_GPU)
${CUDA_PATH}/lib64/libcurand.so
${CUDNN_PATH}/lib64/libcudnn.so
${CUDA_PATH}/lib64/libcudart.so
${CUDA_PATH}/lib64/stubs/libcuda.so)
${CUDA_PATH}/lib64/stubs/libcuda.so
${CUDA_PATH}/lib64/libcusolver.so)
endif ()
if (ENABLE_CPU)

View File

@ -93,6 +93,22 @@ namespace gpu {
} \
}
#define CHECK_CUSOLVER_RET_WITH_EXCEPT(expression, message) \
{ \
cusolverStatus_t status = (expression); \
if (status != CUSOLVER_STATUS_SUCCESS) { \
MS_LOG(EXCEPTION) << "cusolver Error: " << message << " | Error Number: " << status; \
} \
}
#define CHECK_CUSOLVER_RET_WITH_ERROR(expression, message) \
{ \
cusolverStatus_t status = (expression); \
if (status != CUSOLVER_STATUS_SUCCESS) { \
MS_LOG(ERROR) << "cusolver Error: " << message << " | Error Number: " << status; \
} \
}
#define CHECK_NCCL_RET_WITH_EXCEPT(expression, message) \
{ \
int result = (expression); \

View File

@ -32,6 +32,11 @@ void GPUDeviceManager::InitDevice() {
CHECK_CUBLAS_RET_WITH_EXCEPT(cublasCreate(&cublas_handle_), "Failed to create cuBLAS handle.");
CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSetStream(cublas_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
"Failed to set stream for cuBLAS handle.");
CHECK_CUSOLVER_RET_WITH_EXCEPT(cusolverDnCreate(&cusolver_dn_handle_), "Failed to create cusolver dn handle.");
CHECK_CUSOLVER_RET_WITH_EXCEPT(
cusolverDnSetStream(cusolver_dn_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
"Failed to set stream for cusolver dn handle");
CHECK_OP_RET_WITH_EXCEPT(GPUMemoryAllocator::GetInstance().Init(), "Failed to Init gpu memory allocator")
}
@ -47,6 +52,9 @@ void GPUDeviceManager::ReleaseDevice() {
if (cublas_handle_ != nullptr) {
CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle.");
}
if (cusolver_dn_handle_ != nullptr) {
CHECK_CUSOLVER_RET_WITH_ERROR(cusolverDnDestroy(cusolver_dn_handle_), "Failed to destroy cusolver dn handle.");
}
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
}
@ -80,6 +88,8 @@ const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_han
const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; }
const cusolverDnHandle_t &GPUDeviceManager::GetCusolverDnHandle() const { return cusolver_dn_handle_; }
bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); }
bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const {

View File

@ -19,6 +19,7 @@
#include <cudnn.h>
#include <cublas_v2.h>
#include <cusolverDn.h>
#include <vector>
#include <memory>
#include "runtime/device/gpu/cuda_driver.h"
@ -43,6 +44,7 @@ class GPUDeviceManager {
const cudnnHandle_t &GetCudnnHandle() const;
const cublasHandle_t &GetCublasHandle() const;
const cusolverDnHandle_t &GetCusolverDnHandle() const;
bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const;
@ -73,6 +75,9 @@ class GPUDeviceManager {
// handle used for cuBLAS kernels.
cublasHandle_t cublas_handle_{nullptr};
// handle used for cusolver dn kernels;
cusolverDnHandle_t cusolver_dn_handle_{nullptr};
bool dev_id_init_;
uint32_t cur_dev_id_;
};