GPU conv: support older GPU environments (cuDNN < 8.0)
This commit is contained in:
parent
3e0ada91ff
commit
c32ff48e70
2
build.sh
2
build.sh
|
@ -24,7 +24,7 @@ usage()
|
|||
echo "Usage:"
|
||||
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t ut|st] [-g on|off] [-h] [-b ge] [-m infer|train] \\"
|
||||
echo " [-a on|off] [-p on|off] [-i] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|npu] \\"
|
||||
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1|310|910] [-I arm64|arm32|x86_64] [-K] \\"
|
||||
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 10.1|11.1|310|910] [-I arm64|arm32|x86_64] [-K] \\"
|
||||
echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] [-H on|off] \\"
|
||||
echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\"
|
||||
echo " [-L Tensor-RT path] \\"
|
||||
|
|
|
@ -320,6 +320,15 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
output_desc_, requested_algo_count, &returned_algo_count, &perf_results),
|
||||
"cudnnGetConvolutionForwardAlgorithm_v7 failed");
|
||||
conv_algorithm_ = perf_results.algo;
|
||||
#if CUDNN_VERSION < 8000
|
||||
if (group_ > 1) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudnnGetConvolutionForwardAlgorithm(
|
||||
cudnn_handle_, input_descriptor_real, filter_desc_, conv_desc_, output_desc_,
|
||||
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, 0, &conv_algorithm_),
|
||||
"cudnnGetConvolutionForwardAlgorithm failed");
|
||||
}
|
||||
#endif
|
||||
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
||||
conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
||||
}
|
||||
|
|
|
@ -290,6 +290,15 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
requested_algo_count, &returned_algo_count, &perf_results),
|
||||
"GetConvolutionBackwardFilterAlgorithm failed");
|
||||
algo_ = perf_results.algo;
|
||||
#if CUDNN_VERSION < 8000
|
||||
if (group_ > 1) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle_, x_desc_real, dy_desc_, conv_desc_, dw_desc_,
|
||||
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, 0, &algo_),
|
||||
"GetConvolutionBackwardFilterAlgorithm failed");
|
||||
}
|
||||
#endif
|
||||
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
||||
algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
||||
}
|
||||
|
|
|
@ -299,6 +299,15 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
requested_algo_count, &returned_algo_count, &perf_results),
|
||||
"cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
|
||||
algo_ = perf_results.algo;
|
||||
#if CUDNN_VERSION < 8000
|
||||
if (group_ > 1) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
kernel_node_,
|
||||
cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle_, w_desc_, dy_desc_, conv_desc_, dx_desc_real,
|
||||
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, 0, &algo_),
|
||||
"cudnnGetConvolutionBackwardDataAlgorithm failed");
|
||||
}
|
||||
#endif
|
||||
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
||||
algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue