From 8fd1bf6f624055c8a89b7bf4c6eff3c893e8bece Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Tue, 16 Mar 2021 17:22:19 +0800
Subject: [PATCH] gpu support cuda-11.0 and cudnn-8.0

---
 build.sh                                     |  6 +----
 mindspore/_check_version.py                  |  2 +-
 mindspore/ccsrc/CMakeLists.txt               |  6 ++++-
 .../gpu/nn/conv2d_gpu_kernel.h               | 26 +++++++------------
 .../gpu/nn/conv2d_grad_filter_gpu_kernel.h   | 26 +++++++------------
 .../gpu/nn/conv2d_grad_input_gpu_kernel.h    | 26 +++++++------------
 .../kernel_compiler/gpu/nn/lstm_gpu_kernel.h | 17 +++++++++---
 .../gpu/nn/lstm_grad_data_gpu_kernel.h       | 17 +++++++++---
 .../gpu/nn/lstm_grad_weight_gpu_kernel.h     | 18 +++++++++----
 9 files changed, 73 insertions(+), 71 deletions(-)

diff --git a/build.sh b/build.sh
index 7f29dab7f81..388dbafbea9 100755
--- a/build.sh
+++ b/build.sh
@@ -357,15 +357,11 @@ checkopts()
     if [[ "X$DEVICE_VERSION" == "X" ]]; then
       DEVICE_VERSION=10.1
     fi
-    if [[ "X$DEVICE_VERSION" != "X9.2" && "X$DEVICE_VERSION" != "X10.1" ]]; then
+    if [[ "X$DEVICE_VERSION" != "X11.1" && "X$DEVICE_VERSION" != "X10.1" ]]; then
       echo "Invalid value ${DEVICE_VERSION} for option -V"
       usage
       exit 1
     fi
-    if [[ "X$DEVICE_VERSION" == "X9.2" ]]; then
-      echo "Unsupported CUDA version 9.2"
-      exit 1
-    fi
     CUDA_VERSION="$DEVICE_VERSION"
   elif [[ "X$DEVICE" == "Xd" || "X$DEVICE" == "Xascend" ]]; then
     # version default 910
diff --git a/mindspore/_check_version.py b/mindspore/_check_version.py
index 6d14757fbcb..7b51aaa1c71 100644
--- a/mindspore/_check_version.py
+++ b/mindspore/_check_version.py
@@ -45,7 +45,7 @@ class GPUEnvChecker(EnvChecker):
     """GPU environment check."""
 
     def __init__(self):
-        self.version = ["10.1"]
+        self.version = ["10.1", "11.1"]
         self.lib_key_to_lib_name = {'libcu': 'libcuda.so'}
         # env
         self.path = os.getenv("PATH")
diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 2b280741fcb..2c92b059f14 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -127,7 +127,11 @@ if(ENABLE_GPU)
     endif()
 
     set(NVCC_TMP_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
-    string(REPLACE "-std=c++17" "-std=c++11" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+    if(${CUDA_VERSION} VERSION_LESS 11.0)
+        string(REPLACE "-std=c++17" "-std=c++11" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+    else()
+        string(REPLACE "-std=c++17" "-std=c++14" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
+    endif()
     set_property(SOURCE ${GPU_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
     cuda_add_library(gpu_cuda_lib STATIC ${GPU_SRC_LIST})
     set(CMAKE_CXX_FLAGS ${NVCC_TMP_CMAKE_CXX_FLAGS})
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
index 668ee801f5a..f3db6d76584 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
@@ -310,23 +310,15 @@ class Conv2dGpuFwdKernel : public GpuKernel {
                                 "cudnnSetTensor4dDescriptor failed");
   }
   void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) {
-    if (group_ > 1 || CUDNN_MAJOR < 7) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                  cudnnGetConvolutionForwardAlgorithm(
-                                    cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, output_desc_,
-                                    CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 0, &conv_algorithm_),
-                                  "cudnnGetConvolutionForwardAlgorithm failed");
-    } else {
-      constexpr int requested_algo_count = 1;
-      int returned_algo_count;
-      cudnnConvolutionFwdAlgoPerf_t perf_results;
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_,
-                                               output_desc_, requested_algo_count, &returned_algo_count, &perf_results),
-        "cudnnGetConvolutionForwardAlgorithm_v7 failed");
-      conv_algorithm_ = perf_results.algo;
-    }
+    constexpr int requested_algo_count = 1;
+    int returned_algo_count = 0;
+    cudnnConvolutionFwdAlgoPerf_t perf_results;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnGetConvolutionForwardAlgorithm_v7(cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_,
+                                             output_desc_, requested_algo_count, &returned_algo_count, &perf_results),
+      "cudnnGetConvolutionForwardAlgorithm_v7 failed");
+    conv_algorithm_ = perf_results.algo;
     if (cudnn_data_type_ == CUDNN_DATA_HALF) {
       conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
     }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
index 24c3f4a75b3..caa692f80c8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
@@ -280,23 +280,15 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
     return true;
   }
   void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) {
-    if (group_ > 1 || CUDNN_MAJOR < 7) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_,
-                                                   CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, 0, &algo_),
-        "GetConvolutionBackwardFilterAlgorithm failed");
-    } else {
-      constexpr int requested_algo_count = 1;
-      int returned_algo_count;
-      cudnnConvolutionBwdFilterAlgoPerf_t perf_results;
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_,
-                                                      requested_algo_count, &returned_algo_count, &perf_results),
-        "GetConvolutionBackwardFilterAlgorithm failed");
-      algo_ = perf_results.algo;
-    }
+    constexpr int requested_algo_count = 1;
+    int returned_algo_count = 0;
+    cudnnConvolutionBwdFilterAlgoPerf_t perf_results;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnGetConvolutionBackwardFilterAlgorithm_v7(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_,
+                                                    requested_algo_count, &returned_algo_count, &perf_results),
+      "GetConvolutionBackwardFilterAlgorithm failed");
+    algo_ = perf_results.algo;
     if (cudnn_data_type_ == CUDNN_DATA_HALF) {
       algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
     }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
index 960315dd6f1..795154c1744 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
@@ -289,23 +289,15 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
                          [](const int64_t &value) { return static_cast<int>(value); });
   }
   void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
-    if (group_ > 1 || CUDNN_MAJOR < 7) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real,
-                                                 CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, 0, &algo_),
-        "cudnnGetConvolutionBackwardDataAlgorithm failed");
-    } else {
-      constexpr int requested_algo_count = 1;
-      int returned_algo_count;
-      cudnnConvolutionBwdDataAlgoPerf_t perf_results;
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        kernel_node_,
-        cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real,
-                                                    requested_algo_count, &returned_algo_count, &perf_results),
-        "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
-      algo_ = perf_results.algo;
-    }
+    constexpr int requested_algo_count = 1;
+    int returned_algo_count = 0;
+    cudnnConvolutionBwdDataAlgoPerf_t perf_results;
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      kernel_node_,
+      cudnnGetConvolutionBackwardDataAlgorithm_v7(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real,
+                                                  requested_algo_count, &returned_algo_count, &perf_results),
+      "cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
+    algo_ = perf_results.algo;
     if (cudnn_data_type_ == CUDNN_DATA_HALF) {
       algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
     }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h
index 425802de085..afbf9715bb1 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_gpu_kernel.h
@@ -125,12 +125,21 @@ class LstmGpuKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                 cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0),
                                 "set dropout_desc failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
-                                                      input_mode, direction, rnn_mode, algo, cudnn_data_type_),
-                                "set rnn_desc failed");
     cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
+#if CUDNN_VERSION < 8000
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnSetRNNDescriptor_v6(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
+                                                         input_mode, direction, rnn_mode, algo, cudnn_data_type_),
+                                "set rnn_desc failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed");
+#else
+    cudnnMathType_t math_type = (cudnn_data_type_ == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnSetRNNDescriptor_v8(rnn_desc_, algo, rnn_mode, bias_mode, direction, input_mode,
+                                                         cudnn_data_type_, cudnn_data_type_, math_type, input_size_,
+                                                         hidden_size_, hidden_size_, num_layers_, dropout_desc_, 0),
+                                "set rnn_desc failed");
+#endif
     auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
     size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T);
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h
index f5a007c2cd3..c0b468e4b2d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_data_gpu_kernel.h
@@ -140,12 +140,21 @@ class LstmGradDataGpuKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                 cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0),
                                 "set dropout_desc failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
-                                                      input_mode, direction, rnn_mode, algo, cudnn_data_type_),
-                                "set rnn_desc failed");
     cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
+#if CUDNN_VERSION < 8000
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnSetRNNDescriptor_v6(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
+                                                         input_mode, direction, rnn_mode, algo, cudnn_data_type_),
+                                "set rnn_desc failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed");
+#else
+    cudnnMathType_t math_type = (cudnn_data_type_ == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnSetRNNDescriptor_v8(rnn_desc_, algo, rnn_mode, bias_mode, direction, input_mode,
+                                                         cudnn_data_type_, cudnn_data_type_, math_type, input_size_,
+                                                         hidden_size_, hidden_size_, num_layers_, dropout_desc_, 0),
+                                "set rnn_desc failed");
+#endif
     auto weight_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 4);
     size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T);
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h
index bcb3b2365b1..b771c1c28fe 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/lstm_grad_weight_gpu_kernel.h
@@ -114,13 +114,21 @@ class LstmGradWeightGpuKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                                 cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_, nullptr, 0, 0),
                                 "set dropout_desc failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetRNNDescriptor(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
-                                                      input_mode, direction, rnn_mode, algo, cudnn_data_type_),
-                                "set rnn_desc failed");
     cudnnRNNBiasMode_t bias_mode = has_bias_ ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS;
+#if CUDNN_VERSION < 8000
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnSetRNNDescriptor_v6(handle_, rnn_desc_, hidden_size_, num_layers_, dropout_desc_,
+                                                         input_mode, direction, rnn_mode, algo, cudnn_data_type_),
+                                "set rnn_desc failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnSetRNNBiasMode(rnn_desc_, bias_mode), "set bias_mode failed");
-
+#else
+    cudnnMathType_t math_type = (cudnn_data_type_ == CUDNN_DATA_HALF) ? CUDNN_TENSOR_OP_MATH : CUDNN_FMA_MATH;
+    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
+                                cudnnSetRNNDescriptor_v8(rnn_desc_, algo, rnn_mode, bias_mode, direction, input_mode,
+                                                         cudnn_data_type_, cudnn_data_type_, math_type, input_size_,
+                                                         hidden_size_, hidden_size_, num_layers_, dropout_desc_, 0),
+                                "set rnn_desc failed");
+#endif
     auto weight_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     size_t weight_size = weight_shape[0] * weight_shape[1] * weight_shape[2] * sizeof(T);