forked from mindspore-Ecosystem/mindspore
!18441 Fix conv3d cudnn algorithm error
Merge pull request !18441 from tom_chen/conv3d
This commit is contained in:
commit
26c7d274c9
|
@ -320,9 +320,6 @@ class Conv3dGpuKernel : public GpuKernel {
|
|||
output_desc_, requested_algo_count, &returned_algo_count, &perf_results),
|
||||
"cudnnGetConvolutionForwardAlgorithm_v7 failed");
|
||||
conv_algorithm_ = perf_results.algo;
|
||||
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
||||
conv_algorithm_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
||||
}
|
||||
}
|
||||
|
||||
void SetStrideAndDilation(const CNodePtr &kernel_node) {
|
||||
|
|
|
@ -297,9 +297,6 @@ class Conv3dGradInputGpuKernel : public GpuKernel {
|
|||
requested_algo_count, &returned_algo_count, &perf_results),
|
||||
"cudnnGetConvolutionBackwardDataAlgorithm_v7 failed");
|
||||
algo_ = perf_results.algo;
|
||||
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
||||
algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
||||
}
|
||||
}
|
||||
|
||||
void GetInputShape(const CNodePtr &kernel_node, std::vector<size_t> *input_shape) {
|
||||
|
|
Loading…
Reference in New Issue