From e939d61a2ca91579f6c99e740cdf19f16a694512 Mon Sep 17 00:00:00 2001
From: VectorSL
Date: Wed, 19 Aug 2020 20:03:39 +0800
Subject: [PATCH] conv pooling pad support NHWC

---
 .../kernel_compiler/gpu/cuda_impl/pad_impl.cu |  67 ++++++++
 .../gpu/cuda_impl/pad_impl.cuh                |   8 +
 .../backend/kernel_compiler/gpu/gpu_kernel.h  |  54 ++++++
 .../gpu/nn/conv2d_gpu_kernel.h                | 124 +++++++++-----
 .../gpu/nn/conv2d_grad_filter_gpu_kernel.h    | 131 +++++++++------
 .../gpu/nn/conv2d_grad_input_gpu_kernel.h     | 121 +++++++++-----
 .../gpu/nn/pooling_gpu_kernel.h               | 121 ++++++--------
 .../gpu/nn/pooling_grad_gpu_kernel.h          | 154 +++++++-----------
 tests/st/ops/gpu/test_maxpool_grad_gpu_op.py  |  10 +-
 9 files changed, 482 insertions(+), 308 deletions(-)
 mode change 100755 => 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
 mode change 100755 => 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
old mode 100755
new mode 100644
index 3bb4d04a011..99f776a2c81
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cu
@@ -37,6 +37,39 @@ __global__ void Pad(const size_t size, const T* input, const int num, const int
   return;
 }
 
+template <typename T>
+__global__ void PadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
+                        const int channels, const int padded_height, const int padded_width, const int pad_top,
+                        const int pad_left, float pad_value, T* output) {
+  T pad_value_ = static_cast<T>(pad_value);
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int block_num = pos / channels / padded_width / padded_height;
+    const int padded_w = pos / channels % padded_width;
+    const int padded_h = pos / channels / padded_width % padded_height;
+    if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height ||
+        padded_w - pad_left >= old_width) {
+      output[pos] = pad_value_;
+    } else {
+      output[pos] = input[((block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left) *
+                            channels + pos % channels];
+    }
+  }
+  return;
+}
+
+template <typename T>
+__global__ void PadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
+                            const int channels, const int padded_height, const int padded_width, const int pad_top,
+                            const int pad_left, T* dx) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int block_num = pos / channels / old_width / old_height;
+    const int padded_w = pos / channels % old_width + pad_left;
+    const int padded_h = pos / channels / old_width % old_height + pad_top;
+    dx[pos] = dy[((block_num * padded_height + padded_h) * padded_width + padded_w) * channels + pos % channels];
+  }
+  return;
+}
+
 template <typename T>
 __global__ void PadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height,
                         const int old_width, const int padded_height, const int padded_width, const int pad_top,
@@ -60,6 +93,24 @@ void CalPad(const size_t size, const T* input, const int num, const int channels
   return;
 }
 
+template <typename T>
+void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
+                const int channels, const int padded_height, const int padded_width, const int pad_top,
+                const int pad_left, const float pad_value, T* output, cudaStream_t cuda_stream) {
+  PadNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, old_height, old_width, channels,
+                                                             padded_height, padded_width, pad_top, pad_left,
+                                                             pad_value, output);
+  return;
+}
+
+template <typename T>
+void CalPadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
+                    const int channels, const int padded_height, const int padded_width, const int pad_top,
+                    const int pad_left, T* dx, cudaStream_t cuda_stream) {
+  PadGradNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, old_height, old_width, channels,
+                                                                 padded_height, padded_width, pad_top, pad_left, dx);
+  return;
+}
+
 template <typename T>
 void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height,
                 const int old_width, const int padded_height, const int padded_width, const int pad_top,
@@ -85,3 +136,19 @@ template void CalPadGrad<half>(const size_t size, const half* dy, const int num,
                                const int old_height, const int old_width, const int padded_height,
                                const int padded_width, const int pad_top, const int pad_left, half* dx,
                                cudaStream_t cuda_stream);
+template void CalPadNHWC<float>(const size_t size, const float* input, const int num, const int old_height,
+                                const int old_width, const int channels, const int padded_height,
+                                const int padded_width, const int pad_top, const int pad_left, float pad_value,
+                                float* output, cudaStream_t cuda_stream);
+template void CalPadNHWC<half>(const size_t size, const half* input, const int num, const int old_height,
+                               const int old_width, const int channels, const int padded_height,
+                               const int padded_width, const int pad_top, const int pad_left, float pad_value,
+                               half* output, cudaStream_t cuda_stream);
+template void CalPadGradNHWC<float>(const size_t size, const float* dy, const int num, const int old_height,
+                                    const int old_width, const int channels, const int padded_height,
+                                    const int padded_width, const int pad_top, const int pad_left, float* dx,
+                                    cudaStream_t cuda_stream);
+template void CalPadGradNHWC<half>(const size_t size, const half* dy, const int num, const int old_height,
+                                   const int old_width, const int channels, const int padded_height,
+                                   const int padded_width, const int pad_top, const int pad_left, half* dx,
+                                   cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh
old mode 100755
new mode 100644
index b10804fdab8..b1c94b8dab6
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh
@@ -27,5 +27,13 @@ template <typename T>
 void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height,
                 const int old_width, const int padded_height, const int padded_width, const int pad_top,
                 const int pad_left, T* dx, cudaStream_t cuda_stream);
+template <typename T>
+void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
+                const int channels, const int padded_height, const int padded_width, const int pad_top,
+                const int pad_left, float pad_value, T* output, cudaStream_t cuda_stream);
+template <typename T>
+void CalPadGradNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
+                    const int channels, const int padded_height, const int padded_width, const int pad_top,
+                    const int pad_left, T* output, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_
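The two kernels above assume the standard NHWC flattening, pos = ((n * H + h) * W + w) * C + c. As a quick host-side cross-check of that index arithmetic, here is a minimal standalone sketch; the Decompose helper is illustrative only and not part of the patch:

#include <cstdio>

// Hypothetical host-side mirror of the kernel's NHWC offset arithmetic:
// a flat NHWC offset decomposes as pos = ((n * H + h) * W + w) * C + c.
struct NhwcIndex {
  int n, h, w, c;
};

NhwcIndex Decompose(int pos, int height, int width, int channels) {
  NhwcIndex idx;
  idx.c = pos % channels;
  idx.w = pos / channels % width;           // matches `pos / channels % padded_width`
  idx.h = pos / channels / width % height;  // matches `pos / channels / padded_width % padded_height`
  idx.n = pos / channels / width / height;  // matches `block_num`
  return idx;
}

int main() {
  // For a 1x4x5x3 NHWC tensor, offset 23 maps to (n=0, h=1, w=2, c=2).
  NhwcIndex idx = Decompose(23, 4, 5, 3);
  std::printf("n=%d h=%d w=%d c=%d\n", idx.n, idx.h, idx.w, idx.c);
  return 0;
}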
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
index 9ee6ead1cba..1456e9bda66 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
@@ -21,6 +21,7 @@
 #include <cuda.h>
 #include <string>
 #include <vector>
+#include <utility>
 #include "backend/kernel_compiler/kernel.h"
 #include "backend/kernel_compiler/gpu/kernel_constants.h"
 #include "runtime/device/gpu/gpu_device_manager.h"
@@ -73,6 +74,59 @@ class GpuKernel : public KernelMod {
     dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
   }
 
+  // transpose shape: NCHW to NHWC
+  void ShapeNCHW2NHWC(std::vector<size_t> *shape) {
+    std::swap((*shape)[1], (*shape)[3]);
+    std::swap((*shape)[2], (*shape)[1]);
+  }
+
+  void SetDimA(const std::vector<size_t> &shape, int *dimA, const std::string &format) {
+    if (format == "NCHW" || format == "DefaultFormat") {
+      dimA[0] = SizeToInt(shape[0]);
+      dimA[1] = SizeToInt(shape[1]);
+      dimA[2] = SizeToInt(shape[2]);
+      dimA[3] = SizeToInt(shape[3]);
+    } else if (format == "NHWC") {
+      dimA[0] = SizeToInt(shape[0]);
+      dimA[1] = SizeToInt(shape[3]);
+      dimA[2] = SizeToInt(shape[1]);
+      dimA[3] = SizeToInt(shape[2]);
+    } else {
+      MS_LOG(ERROR) << "Unsupported data format " << format;
+    }
+  }
+
+  void SetStrideA(const std::vector<size_t> &shape, int *strideA, const std::string &format) {
+    if (format == "NCHW" || format == "DefaultFormat") {
+      strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]);
+      strideA[1] = SizeToInt(shape[2] * shape[3]);
+      strideA[2] = SizeToInt(shape[3]);
+      strideA[3] = 1;
+    } else if (format == "NHWC") {
+      strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]);
+      strideA[1] = 1;
+      strideA[2] = SizeToInt(shape[2] * shape[3]);
+      strideA[3] = SizeToInt(shape[3]);
+    } else {
+      MS_LOG(ERROR) << "Unsupported data format " << format;
+    }
+  }
+
+  void SetNCHW(const std::vector<size_t> &shape, int *n, int *c, int *h, int *w, const std::string &format) {
+    if (format == "NCHW" || format == "DefaultFormat") {
+      *n = SizeToInt(shape[0]);
+      *c = SizeToInt(shape[1]);
+      *h = SizeToInt(shape[2]);
+      *w = SizeToInt(shape[3]);
+    } else if (format == "NHWC") {
+      *n = SizeToInt(shape[0]);
+      *c = SizeToInt(shape[3]);
+      *h = SizeToInt(shape[1]);
+      *w = SizeToInt(shape[2]);
+    } else {
+      MS_LOG(ERROR) << "Unsupported data format " << format;
+    }
+  }
+
   inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B,
                                       const std::vector<int> &Out) {
     if (A != Out && B != Out) {
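The helpers above always describe a tensor to cuDNN in (N, C, H, W) dim order; for NHWC data only the strides change, with the channel axis becoming the fastest-moving one. A standalone sketch of the same stride rule (the StridesFor function is a hypothetical mirror of SetStrideA, values are examples):

#include <cstdio>
#include <string>
#include <vector>

// Strides are reported in cuDNN's (N, C, H, W) order even when the data is
// physically NHWC; for NHWC the shape vector is {N, H, W, C}.
void StridesFor(const std::vector<int> &shape, const std::string &format, int *strideA) {
  if (format == "NCHW") {
    strideA[0] = shape[1] * shape[2] * shape[3];
    strideA[1] = shape[2] * shape[3];
    strideA[2] = shape[3];
    strideA[3] = 1;
  } else {  // "NHWC": channel stride is 1, row stride is W * C
    strideA[0] = shape[1] * shape[2] * shape[3];
    strideA[1] = 1;
    strideA[2] = shape[2] * shape[3];
    strideA[3] = shape[3];
  }
}

int main() {
  int s[4];
  StridesFor({2, 4, 5, 3}, "NHWC", s);                     // N=2, H=4, W=5, C=3
  std::printf("%d %d %d %d\n", s[0], s[1], s[2], s[3]);    // prints: 60 1 15 3
  return 0;
}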
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
index c5e8a26801e..e1b7aecef8b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h
@@ -38,6 +38,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
         conv_desc_(nullptr),
         padded_desc_(nullptr),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
+        compute_format_(CUDNN_TENSOR_NCHW),
         old_height_(0),
         old_width_(0),
         pad_height_(0),
@@ -76,9 +77,15 @@ class Conv2dGpuFwdKernel : public GpuKernel {
     const float beta = 0;
     if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
       T *padded_addr = GetDeviceAddress<T>(workspace, 1);
-      CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
-             old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
+      if (data_format_ == "NHWC") {
+        CalPadNHWC(padded_size_ / sizeof(T), input_addr, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
+                   old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
+               old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
+               reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
       CHECK_CUDNN_RET_WITH_EXCEPT(
         cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr,
                                 conv_desc_, conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_,
                                 output_addr),
@@ -97,15 +104,21 @@ class Conv2dGpuFwdKernel : public GpuKernel {
     if (!CheckParam(kernel_node)) {
       return false;
     }
-    auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
+    data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
+    auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+    auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
     is_null_input_ = CHECK_NULL_INPUT(in_shape);
     if (is_null_input_) {
       MS_LOG(WARNING) << "Conv2dGpuFwdKernel input is null.";
       InitSizeLists();
       return true;
     }
+    SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
+    if (data_format_ == "NHWC") {
+      compute_format_ = CUDNN_TENSOR_NHWC;
+    }
     Set4DDesc(in_shape, filter_shape, output_shape);
     group_ = GetAttr<int>(kernel_node, "group");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed");
@@ -116,17 +129,55 @@ class Conv2dGpuFwdKernel : public GpuKernel {
     pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
     SetStrideAndDilation(kernel_node);
     cudnnTensorDescriptor_t input_descriptor_real = nullptr;
+    int padA[2];
+    int strideA[2] = {stride_[2], stride_[3]};
+    int dilaA[2] = {dilation_[2], dilation_[3]};
     if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
-      SetPad(in_shape, kernel_node);
+      pad_height_ = pad_list[0] + pad_list[1];
+      pad_width_ = pad_list[2] + pad_list[3];
+      pad_top_ = pad_list[0];
+      pad_left_ = pad_list[2];
+
+      // if use_pad_ == true, apply zero padding in advance; otherwise use the default cudnn pad.
+      if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
+        use_pad_ = false;
+      }
+      int dimA[4];
+      int strideApadded[4];
+      if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
+        auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
+                             IntToSize(old_width_ + pad_width_)};
+        SetDimA(padded_shape, dimA, data_format_);
+        SetStrideA(padded_shape, strideApadded, data_format_);
+      } else if (data_format_ == "NHWC") {
+        auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
+                             IntToSize(c_)};
+        SetDimA(padded_shape, dimA, data_format_);
+        SetStrideA(padded_shape, strideApadded, data_format_);
+      }
+      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, 4, dimA, strideApadded),
+                                  "cudnnSetTensorNdDescriptor failed");
+      if (use_pad_) {
+        padA[0] = 0;
+        padA[1] = 0;
+      } else {
+        padA[0] = pad_top_;
+        padA[1] = pad_left_;
+      }
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION,
+                                        CUDNN_DATA_FLOAT),
+        "cudnnSetConvolutionNdDescriptor failed");
       input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_;
     } else {
       if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
         pad_height_ = 0;
         pad_width_ = 0;
       }
+      padA[0] = pad_height_;
+      padA[1] = pad_width_;
       CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2],
-                                        dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
-        "cudnnSetConvolution2dDescriptor failed");
+        cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION,
+                                        CUDNN_DATA_FLOAT),
+        "cudnnSetConvolutionNdDescriptor failed");
       input_descriptor_real = input_desc_;
     }
@@ -193,13 +244,11 @@ class Conv2dGpuFwdKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed");
   }
   bool CheckParam(const CNodePtr &kernel_node) {
-    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 2) {
       MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs.";
       return false;
     }
-
     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
     if (output_num != 1) {
       MS_LOG(ERROR) << "Output number is " << output_num << ", but conv2d needs 1 output.";
       return false;
     }
     return true;
   }
@@ -207,46 +256,29 @@ class Conv2dGpuFwdKernel : public GpuKernel {
-  void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) {
-    auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list");
-
-    n_ = SizeToInt(in_shape[0]);
-    c_ = SizeToInt(in_shape[1]);
-    old_height_ = SizeToInt(in_shape[2]);
-    old_width_ = SizeToInt(in_shape[3]);
-    pad_height_ = pad_list[0] + pad_list[1];
-    pad_width_ = pad_list[2] + pad_list[3];
-    pad_top_ = pad_list[0];
-    pad_left_ = pad_list[2];
-
-    // if use_pad_ == true, using zero padding in advance, else using the default cudnn pad.
-    if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
-      use_pad_ = false;
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_,
-                                                           old_height_ + pad_height_, old_width_ + pad_width_),
-                                "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
-                                  conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2],
-                                  stride_[3], dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
-                                "cudnnSetConvolution2dDescriptor failed");
-  }
   void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape,
                  const std::vector<size_t> &output_shape) {
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]),
-                                 SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
-      "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(filter_shape[0]),
-                                 SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])),
-      "cudnnSetFilter4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]),
-                                 SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
-      "cudnnSetTensor4dDescriptor failed");
+    int nbDims = 4;
+    int dimA[4];
+    int strideAin[4];
+    int dimAout[4];
+    int strideAout[4];
+    SetDimA(in_shape, dimA, data_format_);
+    SetStrideA(in_shape, strideAin, data_format_);
+    SetDimA(output_shape, dimAout, data_format_);
+    SetStrideA(output_shape, strideAout, data_format_);
+    int filterDimA[4];
+    // OHWI for NHWC; OIHW for NCHW
+    SetDimA(filter_shape, filterDimA, data_format_);
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
+                                "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetFilterNdDescriptor(filter_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA),
+      "cudnnSetFilterNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(output_desc_, cudnn_data_type_, nbDims, dimAout,
+                                                           strideAout),
+                                "cudnnSetTensorNdDescriptor failed");
   }
   void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) {
     if (group_ > 1 || CUDNN_MAJOR < 7) {
@@ -292,11 +324,13 @@ class Conv2dGpuFwdKernel : public GpuKernel {
   cudnnConvolutionDescriptor_t conv_desc_;
   cudnnTensorDescriptor_t padded_desc_;
   std::string pad_mode_;
+  std::string data_format_ = "NCHW";
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
   const float pad_value_ = 0.0;
   cudnnDataType_t cudnn_data_type_;
+  cudnnTensorFormat_t compute_format_;
   int old_height_;
   int old_width_;
   int pad_height_;
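The forward kernel now routes both pad modes through cudnnSetConvolutionNdDescriptor with 2-element pad/stride/dilation arrays, instead of the fixed-arity 2d call. A minimal self-contained sketch of that call sequence, with example values and error handling reduced to a printf:

#include <cudnn.h>
#include <cstdio>

// Sketch: a 2-D convolution descriptor built the same way the patch does it;
// padA carries the symmetric padding (here illustrative values).
int main() {
  cudnnConvolutionDescriptor_t conv_desc;
  cudnnCreateConvolutionDescriptor(&conv_desc);
  int padA[2] = {1, 1};     // pad_top_, pad_left_
  int strideA[2] = {1, 1};  // stride_[2], stride_[3]
  int dilaA[2] = {1, 1};    // dilation_[2], dilation_[3]
  cudnnStatus_t ret = cudnnSetConvolutionNdDescriptor(conv_desc, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION,
                                                      CUDNN_DATA_FLOAT);
  std::printf("cudnnSetConvolutionNdDescriptor: %s\n", cudnnGetErrorString(ret));
  cudnnDestroyConvolutionDescriptor(conv_desc);
  return 0;
}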
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
index ac4d127e437..56836291edc 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h
@@ -38,6 +38,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
         x_desc_(nullptr),
         padded_descriptor_(nullptr),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
+        compute_format_(CUDNN_TENSOR_NCHW),
         old_height_(0),
         old_width_(0),
         pad_height_(0),
@@ -75,12 +76,18 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
+
     if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
       T *padded = GetDeviceAddress<T>(workspace, 1);
-      CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
-             old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-
+      if (data_format_ == "NHWC") {
+        CalPadNHWC(padded_size_ / sizeof(T), x, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
+                   old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
+               old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
+               reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
       CHECK_CUDNN_RET_WITH_EXCEPT(
         cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_,
                                        algo_, work_space, workspace_size_, &beta, dw_desc_, dw),
@@ -99,16 +106,21 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
       return false;
     }
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
-    auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
     is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape);
     if (is_null_input_) {
       MS_LOG(WARNING) << "ConvGradFilterGpuBkwKernel input is null.";
       InitSizeLists();
       return true;
     }
-    std::vector<int> filter_shape;
+    data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
+    std::vector<size_t> filter_shape;
     GetFilterShape(kernel_node, &filter_shape);
+    if (data_format_ == "NHWC") {
+      compute_format_ = CUDNN_TENSOR_NHWC;
+    }
+    SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
     Set4DDesc(dy_shape, filter_shape, in_shape);
     group_ = GetAttr<int>(kernel_node, "group");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed");
@@ -120,18 +132,54 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
     pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
     SetStrideAndDilation(kernel_node);
     cudnnTensorDescriptor_t x_desc_real = nullptr;
+    int padA[2];
+    int strideA[2] = {stride_[0], stride_[1]};
+    int dilaA[2] = {dilation_[0], dilation_[1]};
     if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
-      SetPad(in_shape, kernel_node);
+      pad_height_ = pad_list[0] + pad_list[1];
+      pad_width_ = pad_list[2] + pad_list[3];
+      pad_top_ = pad_list[0];
+      pad_left_ = pad_list[2];
+      if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
+        use_pad_ = false;
+      }
+      int dimA[4];
+      int strideApadded[4];
+      if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
+        auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
+                             IntToSize(old_width_ + pad_width_)};
+        SetDimA(padded_shape, dimA, data_format_);
+        SetStrideA(padded_shape, strideApadded, data_format_);
+      } else if (data_format_ == "NHWC") {
+        auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
+                             IntToSize(c_)};
+        SetDimA(padded_shape, dimA, data_format_);
+        SetStrideA(padded_shape, strideApadded, data_format_);
+      }
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded),
+        "cudnnSetTensorNdDescriptor failed");
+      if (use_pad_) {
+        padA[0] = 0;
+        padA[1] = 0;
+      } else {
+        padA[0] = pad_top_;
+        padA[1] = pad_left_;
+      }
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION,
+                                        CUDNN_DATA_FLOAT),
+        "cudnnSetConvolutionNdDescriptor failed");
       x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_;
     } else {
       if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
         pad_height_ = 0;
         pad_width_ = 0;
       }
+      padA[0] = pad_height_;
+      padA[1] = pad_width_;
       CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2],
-                                        dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
-        "GetConvolution2dDescriptor failed");
+        cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION,
+                                        CUDNN_DATA_FLOAT),
+        "cudnnSetConvolutionNdDescriptor failed");
       x_desc_real = x_desc_;
     }
     if (cudnn_data_type_ == CUDNN_DATA_HALF) {
@@ -208,27 +256,6 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
     }
     return true;
   }
-  void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) {
-    auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list");
-    n_ = SizeToInt(in_shape[0]);
-    c_ = SizeToInt(in_shape[1]);
-    old_height_ = SizeToInt(in_shape[2]);
-    old_width_ = SizeToInt(in_shape[3]);
-    pad_height_ = pad_list[0] + pad_list[1];
-    pad_width_ = pad_list[2] + pad_list[3];
-    pad_top_ = pad_list[0];
-    pad_left_ = pad_list[2];
-    if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
-      use_pad_ = false;
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           n_, c_, old_height_ + pad_height_, old_width_ + pad_width_),
-                                "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
-                                  conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0],
-                                  stride_[1], dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
-                                "cudnnSetConvolution2dDescriptor failed");
-  }
   void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) {
     if (group_ > 1 || CUDNN_MAJOR < 7) {
       CHECK_CUDNN_RET_WITH_EXCEPT(
@@ -249,27 +276,33 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
       algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
     }
   }
-  void GetFilterShape(const CNodePtr &kernel_node, std::vector<int> *filter_shape) {
+  void GetFilterShape(const CNodePtr &kernel_node, std::vector<size_t> *filter_shape) {
     auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast<ValueTuplePtr>()->value();
     (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape),
-                         [](const ValuePtr &e) -> int { return e->cast<Int32ImmPtr>()->value(); });
+                         [](const ValuePtr &e) -> size_t { return e->cast<Int32ImmPtr>()->value(); });
   }
-  void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<int> &filter_shape,
+  void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape,
                  const std::vector<size_t> &in_shape) {
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]),
-                                 SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])),
-      "SetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetFilter4dDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]),
-                                 filter_shape[1], filter_shape[2], filter_shape[3]),
-      "SetFilter4dDescriptor failed");
-
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]),
-                                 SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
-      "SetTensor4dDescriptor failed");
+    int nbDims = 4;
+    int dimA[4];
+    int strideAin[4];
+    int dimAdy[4];
+    int strideAdy[4];
+    SetDimA(in_shape, dimA, data_format_);
+    SetStrideA(in_shape, strideAin, data_format_);
+    SetDimA(dy_shape, dimAdy, data_format_);
+    SetStrideA(dy_shape, strideAdy, data_format_);
+    // the filter shape always keeps OIHW order.
+    int filterDimA[4] = {SizeToInt(filter_shape[0]), SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]),
+                         SizeToInt(filter_shape[3])};
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
+                                "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA),
+      "cudnnSetFilterNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(x_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
+                                "cudnnSetTensorNdDescriptor failed");
   }
   void SetStrideAndDilation(const CNodePtr &kernel_node) {
     stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
@@ -292,11 +325,13 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
   cudnnTensorDescriptor_t padded_descriptor_;
   cudnnConvolutionBwdFilterAlgo_t algo_;
   std::string pad_mode_;
+  std::string data_format_ = "NCHW";
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
   const float pad_value_ = 0.0;
   cudnnDataType_t cudnn_data_type_;
+  cudnnTensorFormat_t compute_format_;
   int old_height_;
   int old_width_;
   int pad_height_;
@@ -319,4 +354,4 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_
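Note the comment in Set4DDesc above: the filter gradient keeps filterDimA in OIHW order and lets compute_format_ tell cuDNN how the filter data is laid out in memory. A minimal standalone sketch of that descriptor call (illustrative values only):

#include <cudnn.h>
#include <cstdio>

// Sketch: filter dimensions are given in OIHW order even when the tensor
// format is CUDNN_TENSOR_NHWC; the format flag describes the memory layout.
int main() {
  cudnnFilterDescriptor_t dw_desc;
  cudnnCreateFilterDescriptor(&dw_desc);
  int filterDimA[4] = {64, 3, 7, 7};  // out_channels, in_channels, kH, kW
  cudnnStatus_t ret =
    cudnnSetFilterNdDescriptor(dw_desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NHWC, 4, filterDimA);
  std::printf("cudnnSetFilterNdDescriptor: %s\n", cudnnGetErrorString(ret));
  cudnnDestroyFilterDescriptor(dw_desc);
  return 0;
}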
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
index e40bd6898fe..cc123b912f0 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h
@@ -38,6 +38,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
         dx_desc_(nullptr),
         padded_descriptor_(nullptr),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
+        compute_format_(CUDNN_TENSOR_NCHW),
         old_height_(0),
         old_width_(0),
         pad_height_(0),
@@ -75,7 +76,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
-
     if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
       T *padded = GetDeviceAddress<T>(workspace, 1);
@@ -83,8 +83,13 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
         cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
                                      workspace_size_, &beta, padded_descriptor_, padded),
         "ConvolutionBackwardData failed");
-      CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
-                 old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+      if (data_format_ == "NHWC") {
+        CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
+                       old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
+                   old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
     } else {
       CHECK_CUDNN_RET_WITH_EXCEPT(
         cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
@@ -99,16 +104,23 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
       return false;
     }
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
-    auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
+    auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
     is_null_input_ = CHECK_NULL_INPUT(dy_shape);
     if (is_null_input_) {
       MS_LOG(WARNING) << "ConvGradInputGpuBkwKernel input is null.";
       InitSizeLists();
       return true;
     }
-    std::vector<int> input_shape;
+
+    std::vector<size_t> input_shape;
     GetInputShape(kernel_node, &input_shape);
+    if (data_format_ == "NHWC") {
+      compute_format_ = CUDNN_TENSOR_NHWC;
+      ShapeNCHW2NHWC(&input_shape);
+    }
+    SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
     Set4DDesc(dy_shape, input_shape, filter_shape);
     group_ = GetAttr<int>(kernel_node, "group");
@@ -121,17 +133,53 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
     pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
     SetStrideAndDilation(kernel_node);
     cudnnTensorDescriptor_t dx_desc_real = nullptr;
+    int padA[2];
+    int strideA[2] = {stride_[0], stride_[1]};
+    int dilaA[2] = {dilation_[0], dilation_[1]};
     if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
-      SetPad(input_shape, kernel_node);
+      pad_height_ = pad_list[0] + pad_list[1];
+      pad_width_ = pad_list[2] + pad_list[3];
+      pad_top_ = pad_list[0];
+      pad_left_ = pad_list[2];
+      if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
+        use_pad_ = false;
+      }
+      int dimA[4];
+      int strideApadded[4];
+      if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
+        auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
+                             IntToSize(old_width_ + pad_width_)};
+        SetDimA(padded_shape, dimA, data_format_);
+        SetStrideA(padded_shape, strideApadded, data_format_);
+      } else if (data_format_ == "NHWC") {
+        auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
+                             IntToSize(c_)};
+        SetDimA(padded_shape, dimA, data_format_);
+        SetStrideA(padded_shape, strideApadded, data_format_);
+      }
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded),
+        "cudnnSetTensorNdDescriptor failed");
+      if (use_pad_) {
+        padA[0] = 0;
+        padA[1] = 0;
+      } else {
+        padA[0] = pad_top_;
+        padA[1] = pad_left_;
+      }
+      CHECK_CUDNN_RET_WITH_EXCEPT(
+        cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION,
+                                        CUDNN_DATA_FLOAT),
+        "cudnnSetConvolutionNdDescriptor failed");
       dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_;
     } else {
       if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
         pad_height_ = 0;
         pad_width_ = 0;
       }
+      padA[0] = pad_height_;
+      padA[1] = pad_width_;
       CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2],
-                                        dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
-        "cudnnSetConvolution2dDescriptor failed");
+        cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION,
+                                        CUDNN_DATA_FLOAT),
+        "cudnnSetConvolutionNdDescriptor failed");
       dx_desc_real = dx_desc_;
     }
@@ -208,24 +256,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
   }
   void SetPad(const std::vector<int> &input_shape, const CNodePtr &kernel_node) {
     auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list");
-    n_ = input_shape[0];
-    c_ = input_shape[1];
-    old_height_ = input_shape[2];
-    old_width_ = input_shape[3];
-    pad_height_ = pad_list[0] + pad_list[1];
-    pad_width_ = pad_list[2] + pad_list[3];
-    pad_top_ = pad_list[0];
-    pad_left_ = pad_list[2];
-    if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
-      use_pad_ = false;
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           n_, c_, old_height_ + pad_height_, old_width_ + pad_width_),
-                                "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
-                                  conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0],
-                                  stride_[1], dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
-                                "cudnnSetConvolution2dDescriptor failed");
   }
   void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
     if (group_ > 1 || CUDNN_MAJOR < 7) {
@@ -247,25 +277,32 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
       algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
     }
   }
-  void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) {
+  void GetInputShape(const CNodePtr &kernel_node, std::vector<size_t> *input_shape) {
    auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value();
     (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape),
-                         [](const ValuePtr &e) -> int { return e->cast<Int32ImmPtr>()->value(); });
+                         [](const ValuePtr &e) -> size_t { return e->cast<Int32ImmPtr>()->value(); });
   }
-  void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<int> &input_shape,
+  void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape,
                  const std::vector<size_t> &filter_shape) {
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetFilter4dDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]),
-                                 SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])),
-      "SetFilter4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]),
-                                 SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])),
-      "SetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, input_shape[0], input_shape[1],
-                                 input_shape[2], input_shape[3]),
-      "SetTensor4dDescriptor failed");
+    int nbDims = 4;
+    int dimA[4];
+    int strideAin[4];
+    int dimAdy[4];
+    int strideAdy[4];
+    int filterDimA[4];
+    SetDimA(input_shape, dimA, data_format_);
+    SetStrideA(input_shape, strideAin, data_format_);
+    SetDimA(dy_shape, dimAdy, data_format_);
+    SetStrideA(dy_shape, strideAdy, data_format_);
+    SetDimA(filter_shape, filterDimA, data_format_);
+
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
+                                "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA),
+      "cudnnSetFilterNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dx_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
+                                "cudnnSetTensorNdDescriptor failed");
   }
   void SetStrideAndDilation(const CNodePtr &kernel_node) {
     stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
@@ -288,10 +325,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
   cudnnTensorDescriptor_t padded_descriptor_;
   cudnnConvolutionBwdDataAlgo_t algo_;
   std::string pad_mode_;
+  std::string data_format_ = "NCHW";
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
   cudnnDataType_t cudnn_data_type_;
+  cudnnTensorFormat_t compute_format_;
   int old_height_;
   int old_width_;
   int pad_height_;
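Both pooling kernels below recompute SAME padding inline; the long conditional expression is a ceil-division in disguise. A small sketch of the same arithmetic, assuming stride > 0 (the SamePadding helper is hypothetical, not part of the patch):

#include <algorithm>
#include <cstdio>

// Total SAME padding along one axis: with ceil division the inline expression
// reduces to out = ceil(old / stride), pad = (out - 1) * stride + window - old.
int SamePadding(int old_size, int stride, int window) {
  int out = (old_size + stride - 1) / stride;  // ceil(old_size / stride)
  return std::max(0, (out - 1) * stride + window - old_size);
}

int main() {
  // H = 6, stride 2, window 3 -> output 3, total pad 1 (top 0, bottom 1).
  std::printf("pad_height = %d\n", SamePadding(6, 2, 3));
  return 0;
}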
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
index e9cf05d0dde..e16caf7fe6b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
@@ -35,9 +35,9 @@ class PoolingGpuFwdKernel : public GpuKernel {
         input_descriptor_(nullptr),
         output_descriptor_(nullptr),
         pooling_descriptor_(nullptr),
-        padded_descriptor_(nullptr),
         pooling_mode_(CUDNN_POOLING_MAX),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
+        compute_format_(CUDNN_TENSOR_NCHW),
         old_height_(0),
         old_width_(0),
         pad_height_(0),
@@ -50,9 +50,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
         is_null_input_(false),
         input_size_(0),
         output_size_(0),
-        padded_size_(0),
-        workspace_size_(0),
-        use_pad_(true) {}
+        workspace_size_(0) {}
   ~PoolingGpuFwdKernel() override { DestroyResource(); }
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -67,20 +65,10 @@ class PoolingGpuFwdKernel : public GpuKernel {
     T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
     const float alpha = 1;
     const float beta = 0;
-    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
-      T *padded_addr = reinterpret_cast<T *>(workspace[0]->addr);
-      CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
-             old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, padded_descriptor_,
-                                                      padded_addr, &beta, output_descriptor_, output_addr),
-                                  "cudnnPoolingForward failed");
-    } else {
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_,
-                                                      input_addr, &beta, output_descriptor_, output_addr),
-                                  "cudnnPoolingForward failed");
-    }
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_,
+                                                    input_addr, &beta, output_descriptor_, output_addr),
+                                "cudnnPoolingForward failed");
     return true;
   }
   bool Init(const CNodePtr &kernel_node) {
@@ -89,39 +77,64 @@ class PoolingGpuFwdKernel : public GpuKernel {
       return false;
     }
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
     is_null_input_ = CHECK_NULL_INPUT(input_shape);
     if (is_null_input_) {
       MS_LOG(WARNING) << "PoolingGpuFwdKernel input is null.";
       InitSizeLists();
       return true;
     }
+    SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
+    int nbDims = 4;
+    int dimA[4];
+    int strideAin[4];
+    int dimAout[4];
+    int strideAout[4];
+    SetDimA(input_shape, dimA, data_format_);
+    SetStrideA(input_shape, strideAin, data_format_);
+    SetDimA(output_shape, dimAout, data_format_);
+    SetStrideA(output_shape, strideAout, data_format_);
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]),
-                                 SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])),
-      "cudnnSetTensor4dDescriptor failed");
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]),
-                                 SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
-      "cudnnSetTensor4dDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensorNdDescriptor(input_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin),
+      "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnSetTensorNdDescriptor(output_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout),
+      "cudnnSetTensorNdDescriptor failed");
     auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
     int window_height = window[2];
     int window_width = window[3];
     stride_ = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
     SetPoolingMode(kernel_node);
+    int windowDimA[2] = {window_height, window_width};
+    int paddingA[2] = {0, 0};
+    int strideA[2] = {stride_[2], stride_[3]};
     if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
-      SetPad(input_shape, window_height, window_width);
+      pad_height_ =
+        std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
+                                                                             : (old_height_ / stride_[2]) + 1) -
+                     1) *
+                      stride_[2] +
+                      window_height - old_height_);
+      pad_width_ =
+        std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3])
+                                                                           : (old_width_ / stride_[3]) + 1) -
+                     1) *
+                      stride_[3] +
+                      window_width - old_width_);
+      pad_top_ = pad_height_ / 2;
+      pad_left_ = pad_width_ / 2;
+      paddingA[0] = pad_top_;
+      paddingA[1] = pad_left_;
     } else {
       pad_height_ = 0;
       pad_width_ = 0;
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height,
-                                    window_width, pad_height_, pad_width_, stride_[2], stride_[3]),
-        "cudnnSetPooling2dDescriptor failed");
     }
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_,
+                                                            CUDNN_NOT_PROPAGATE_NAN, 2, windowDimA, paddingA,
+                                                            strideA),
+                                "cudnnSetPoolingNdDescriptor failed");
     InitSizeLists();
     return true;
   }
@@ -131,7 +144,6 @@ class PoolingGpuFwdKernel : public GpuKernel {
     cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_),
                                 "cudnnCreatePoolingDescriptor failed");
   }
@@ -146,15 +158,6 @@ class PoolingGpuFwdKernel : public GpuKernel {
     }
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
-    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast<size_t *>(&padded_size_)),
-        "cudnnGetTensorSizeInBytes failed");
-      workspace_size_list_.push_back(padded_size_);
-      if (padded_size_ == 0) {
-        MS_LOG(EXCEPTION) << "Padded size is 0.";
-      }
-    }
     return;
   }
@@ -167,36 +170,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
     }
     return true;
   }
-  void SetPad(const std::vector<size_t> &input_shape, const int &window_height, const int &window_width) {
-    n_ = SizeToInt(input_shape[0]);
-    c_ = SizeToInt(input_shape[1]);
-    old_height_ = SizeToInt(input_shape[2]);
-    old_width_ = SizeToInt(input_shape[3]);
-    pad_height_ =
-      std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
-                                                                           : (old_height_ / stride_[2]) + 1) -
-                   1) *
-                    stride_[2] +
-                    window_height - old_height_);
-    pad_width_ =
-      std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3])
-                                                                         : (old_width_ / stride_[3]) + 1) -
-                   1) *
-                    stride_[3] +
-                    window_width - old_width_);
-    pad_top_ = pad_height_ / 2;
-    pad_left_ = pad_width_ / 2;
-    if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
-      use_pad_ = false;
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           n_, c_, old_height_ + pad_height_, old_width_ + pad_width_),
-                                "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_,
-                                                            CUDNN_NOT_PROPAGATE_NAN, window_height, window_width,
-                                                            use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_,
-                                                            stride_[2], stride_[3]),
-                                "cudnnSetPooling2dDescriptor failed");
-  }
+
   void SetPoolingMode(const CNodePtr &kernel_node) {
     pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
     mode_ = AnfAlgo::GetCNodeName(kernel_node);
@@ -211,7 +185,6 @@ class PoolingGpuFwdKernel : public GpuKernel {
   void DestroyResource() noexcept {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_),
                                "cudnnDestroyPoolingDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
   }
@@ -220,16 +193,16 @@ class PoolingGpuFwdKernel : public GpuKernel {
   cudnnTensorDescriptor_t input_descriptor_;
   cudnnTensorDescriptor_t output_descriptor_;
   cudnnPoolingDescriptor_t pooling_descriptor_;
-  cudnnTensorDescriptor_t padded_descriptor_;
   cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX;
   std::vector<int> stride_;
   std::string mode_;
   std::string pad_mode_;
+  std::string data_format_ = "NCHW";
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
   cudnnDataType_t cudnn_data_type_;
-
+  cudnnTensorFormat_t compute_format_;
   int old_height_;
   int old_width_;
   int pad_height_;
@@ -242,9 +215,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
   bool is_null_input_;
   size_t input_size_;
   size_t output_size_;
-  size_t padded_size_;
   size_t workspace_size_;
-  bool use_pad_;
 };
 }  // namespace kernel
 }  // namespace mindspore
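The grad kernel below mirrors the forward kernel's descriptor setup and ends in the same cudnnSetPoolingNdDescriptor call. For reference, a minimal standalone version of that call with example window/pad/stride values:

#include <cudnn.h>
#include <cstdio>

// Sketch: a 2-D max-pooling descriptor set through the Nd API, mirroring the
// windowDimA/paddingA/strideA arrays the patch builds.
int main() {
  cudnnPoolingDescriptor_t pooling_desc;
  cudnnCreatePoolingDescriptor(&pooling_desc);
  int windowDimA[2] = {3, 3};  // window_height, window_width
  int paddingA[2] = {1, 1};    // pad_top_, pad_left_
  int strideA[2] = {2, 2};     // stride_[2], stride_[3]
  cudnnStatus_t ret = cudnnSetPoolingNdDescriptor(pooling_desc, CUDNN_POOLING_MAX, CUDNN_NOT_PROPAGATE_NAN, 2,
                                                  windowDimA, paddingA, strideA);
  std::printf("cudnnSetPoolingNdDescriptor: %s\n", cudnnGetErrorString(ret));
  cudnnDestroyPoolingDescriptor(pooling_desc);
  return 0;
}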
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
index 0d16fc48a2c..120399f3289 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h
@@ -37,9 +37,9 @@ class PoolingGradGpuKernel : public GpuKernel {
         dy_descriptor_(nullptr),
         x_descriptor_(nullptr),
         dx_descriptor_(nullptr),
-        padded_descriptor_(nullptr),
         pooling_mode_(CUDNN_POOLING_MAX),
         cudnn_data_type_(CUDNN_DATA_FLOAT),
+        compute_format_(CUDNN_TENSOR_NCHW),
         old_height_(0),
         old_width_(0),
         pad_height_(0),
@@ -52,9 +52,7 @@ class PoolingGradGpuKernel : public GpuKernel {
         is_null_input_(false),
         input_size_(0),
         output_size_(0),
-        padded_size_(0),
-        workspace_size_(0),
-        use_pad_(true) {}
+        workspace_size_(0) {}
   ~PoolingGradGpuKernel() override { DestroyResource(); }
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -72,27 +70,10 @@ class PoolingGradGpuKernel : public GpuKernel {
     const float alpha = 1;
     const float beta = 0;
-    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
-      T *padded = GetDeviceAddress<T>(workspace, 0);
-      T *padded_dx = GetDeviceAddress<T>(workspace, 1);
-
-      CalPad(padded_size_ / sizeof(T), x_data, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
-             old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
-             reinterpret_cast<cudaStream_t>(stream_ptr));
-
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
-                             padded_descriptor_, padded, &beta, padded_descriptor_, padded_dx),
-        "cudnnPoolingBackward failed");
-
-      CalPadGrad(output_size_ / sizeof(T), padded_dx, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
-                 old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
-    } else {
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
-                             x_descriptor_, x_data, &beta, dx_descriptor_, dx),
-        "cudnnPoolingBackward failed");
-    }
+    CHECK_CUDNN_RET_WITH_EXCEPT(
+      cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
+                           x_descriptor_, x_data, &beta, dx_descriptor_, dx),
+      "cudnnPoolingBackward failed");
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
@@ -104,46 +85,73 @@ class PoolingGradGpuKernel : public GpuKernel {
     int window_height = window[2];
     int window_width = window[3];
     SetPoolingMode(kernel_node);
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
+    auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
+    auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+    data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
     is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
     if (is_null_input_) {
       MS_LOG(WARNING) << "PoolingGradGpuKernel input is null.";
       InitSizeLists();
       return true;
     }
+    SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
+    int windowDimA[2] = {window_height, window_width};
+    int paddingA[2] = {0, 0};
+    int strideA[2] = {stride_[2], stride_[3]};
+    int nbDims = 4;
+    int dimA[4];
+    int strideAin[4];
+    int dimAy[4];
+    int strideAiny[4];
+    int dimAdy[4];
+    int strideAdy[4];
+    int dimAout[4];
+    int strideAout[4];
+    SetDimA(input_shape, dimA, data_format_);
+    SetStrideA(input_shape, strideAin, data_format_);
+    SetDimA(input_mask, dimAy, data_format_);
+    SetStrideA(input_mask, strideAiny, data_format_);
+    SetDimA(dout_shape, dimAdy, data_format_);
+    SetStrideA(dout_shape, strideAdy, data_format_);
+    SetDimA(output_shape, dimAout, data_format_);
+    SetStrideA(output_shape, strideAout, data_format_);
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(y_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_mask[0]),
-                                 SizeToInt(input_mask[1]), SizeToInt(input_mask[2]), SizeToInt(input_mask[3])),
-      "cudnnSetTensor4dDescriptor failed");
-    auto dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(dy_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dout_shape[0]),
-                                 SizeToInt(dout_shape[1]), SizeToInt(dout_shape[2]), SizeToInt(dout_shape[3])),
-      "cudnnSetTensor4dDescriptor failed");
-    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(dx_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]),
-                                 SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
-      "cudnnSetTensor4dDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy,
+                                                           strideAiny),
+                                "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy,
+                                                           strideAdy),
+                                "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dx_descriptor_, cudnn_data_type_, nbDims, dimAout,
+                                                           strideAout),
+                                "cudnnSetTensorNdDescriptor failed");
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(x_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin),
+                                "cudnnSetTensorNdDescriptor failed");
     if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) {
-      SetPad(input_shape, window_height, window_width);
+      pad_height_ =
+        std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
+                                                                             : (old_height_ / stride_[2]) + 1) -
+                     1) *
+                      stride_[2] +
+                      window_height - old_height_);
+      pad_width_ =
+        std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3])
+                                                                           : (old_width_ / stride_[3]) + 1) -
+                     1) *
+                      stride_[3] +
+                      window_width - old_width_);
+      pad_top_ = pad_height_ / 2;
+      pad_left_ = pad_width_ / 2;
+      paddingA[0] = pad_top_;
+      paddingA[1] = pad_left_;
     } else {
       if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
         pad_height_ = 0;
         pad_width_ = 0;
       }
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height,
-                                    window_width, pad_height_, pad_width_, stride_[2], stride_[3]),
-        "cudnnSetPooling2dDescriptor failed");
-      CHECK_CUDNN_RET_WITH_EXCEPT(
-        cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]),
-                                   SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])),
-        "cudnnSetTensor4dDescriptor failed");
     }
+    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_,
+                                                            CUDNN_NOT_PROPAGATE_NAN, 2, windowDimA, paddingA,
+                                                            strideA),
+                                "cudnnSetPoolingNdDescriptor failed");
     InitSizeLists();
     return true;
   }
@@ -155,7 +163,6 @@ class PoolingGradGpuKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_descriptor_), "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_),
                                 "cudnnCreatePoolingDescriptor failed");
   }
@@ -179,16 +186,6 @@ class PoolingGradGpuKernel : public GpuKernel {
                                   "cudnnGetTensorSizeInBytes failed");
     }
     input_size_list_.push_back(input_size_);
-    if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_),
-                                  "cudnnGetTensorSizeInBytes failed");
-      if (padded_size_ == 0) {
-        MS_LOG(EXCEPTION) << "Padded size is 0.";
-      }
-      workspace_size_list_.push_back(padded_size_);
-      workspace_size_list_.push_back(padded_size_);
-    }
     return;
   }
@@ -206,35 +203,6 @@ class PoolingGradGpuKernel : public GpuKernel {
     c_ = SizeToInt(input_shape[1]);
     old_height_ = SizeToInt(input_shape[2]);
     old_width_ = SizeToInt(input_shape[3]);
-    pad_height_ =
-      std::max(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
-                                                                           : (old_height_ / stride_[2]) + 1) -
-                   1) *
-                    stride_[2] +
-                    window_height - old_height_);
-    pad_width_ =
-      std::max(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3])
-                                                                         : (old_width_ / stride_[3]) + 1) -
-                   1) *
-                    stride_[3] +
-                    window_width - old_width_);
-    pad_top_ = pad_height_ / 2;
-    pad_left_ = pad_width_ / 2;
-    if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
-      use_pad_ = false;
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
-                                                           n_, c_, old_height_ + pad_height_, old_width_ + pad_width_),
-                                "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]),
-                                 SizeToInt(input_shape[1]), SizeToInt(input_shape[2]) + (use_pad_ ? pad_height_ : 0),
-                                 SizeToInt(input_shape[3]) + (use_pad_ ? pad_width_ : 0)),
-      "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_,
-                                                            CUDNN_NOT_PROPAGATE_NAN, window_height, window_width,
-                                                            use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_,
-                                                            stride_[2], stride_[3]),
-                                "cudnnSetPooling2dDescriptor failed");
   }
   void SetPoolingMode(const CNodePtr &kernel_node) {
     pad_mode_ = GetAttr<std::string>(kernel_node, "padding");
@@ -252,7 +220,6 @@ class PoolingGradGpuKernel : public GpuKernel {
   void DestroyResource() noexcept {
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_),
                                "cudnnDestroyPoolingDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_descriptor_), "cudnnDestroyTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_descriptor_), "cudnnDestroyTensorDescriptor failed");
     CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_descriptor_), "cudnnDestroyTensorDescriptor failed");
@@ -265,7 +232,6 @@ class PoolingGradGpuKernel : public GpuKernel {
   cudnnTensorDescriptor_t dy_descriptor_;
   cudnnTensorDescriptor_t x_descriptor_;
   cudnnTensorDescriptor_t dx_descriptor_;
-  cudnnTensorDescriptor_t padded_descriptor_;
   cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX;
   std::vector<int> stride_;
   std::vector<size_t> input_size_list_;
@@ -273,7 +239,9 @@ class PoolingGradGpuKernel : public GpuKernel {
   std::vector<size_t> workspace_size_list_;
   std::string mode_;
   std::string pad_mode_;
+  std::string data_format_ = "NCHW";
   cudnnDataType_t cudnn_data_type_;
+  cudnnTensorFormat_t compute_format_;
   int old_height_;
   int old_width_;
   int pad_height_;
@@ -286,9 +254,7 @@ class PoolingGradGpuKernel : public GpuKernel {
   bool is_null_input_;
   size_t input_size_;
   size_t output_size_;
-  size_t padded_size_;
   size_t workspace_size_;
-  bool use_pad_;
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/tests/st/ops/gpu/test_maxpool_grad_gpu_op.py b/tests/st/ops/gpu/test_maxpool_grad_gpu_op.py
index 5b65f609644..a8762b0376b 100644
--- a/tests/st/ops/gpu/test_maxpool_grad_gpu_op.py
+++ b/tests/st/ops/gpu/test_maxpool_grad_gpu_op.py
@@ -45,23 +45,23 @@ def test_maxpool2d_grad():
         [24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35]
     ]]]).astype(np.float32))
-    a = Tensor(np.array([[[
+    d = Tensor(np.array([[[
         [3, 3, 3],
         [3, 3, 3],
         [3, 3, 3]
     ]]]).astype(np.float32))
-    d = Tensor(np.array([[[
+    a = Tensor(np.array([[[
         [7, 9, 11],
         [19, 21, 23],
         [31, 33, 35]
     ]]]).astype(np.float32))
     expect_result = (np.array([[[
         [0, 0, 0, 0, 0, 0],
-        [0, 7, 0, 9, 0, 11],
+        [0, 3, 0, 3, 0, 3],
         [0, 0, 0, 0, 0, 0],
-        [0, 19, 0, 21, 0, 23],
+        [0, 3, 0, 3, 0, 3],
         [0, 0, 0, 0, 0, 0],
-        [0, 31, 0, 33, 0, 35]
+        [0, 3, 0, 3, 0, 3]
     ]]]))
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")