forked from mindspore-Ecosystem/mindspore
!4758 GPU kernel support NHWC
Merge pull request !4758 from VectorSL/nhwc-support
This commit is contained in:
commit
a3dae6344b
|
@ -37,6 +37,39 @@ __global__ void Pad(const size_t size, const T* input, const int num, const int
|
|||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void PadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||
const int pad_left, float pad_value, T* output) {
|
||||
T pad_value_ = static_cast<T>(pad_value);
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
int block_num = pos / channels / padded_width / padded_height;
|
||||
const int padded_w = pos / channels % padded_width;
|
||||
const int padded_h = pos / channels / padded_width % padded_height;
|
||||
if (padded_h - pad_top < 0 || padded_w - pad_left < 0 || padded_h - pad_top >= old_height ||
|
||||
padded_w - pad_left >= old_width) {
|
||||
output[pos] = pad_value_;
|
||||
} else {
|
||||
output[pos] = input[((block_num * old_height + padded_h - pad_top) * old_width + padded_w - pad_left)
|
||||
*channels + pos % channels];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void PadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
|
||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||
const int pad_left, T* dx) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
|
||||
int block_num = pos / channels / old_width / old_height;
|
||||
const int padded_w = pos / channels % old_width + pad_left;
|
||||
const int padded_h = pos / channels / old_width % old_height + pad_top;
|
||||
dx[pos] = dy[((block_num * padded_height + padded_h) * padded_width + padded_w)*channels+pos%channels];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void PadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height,
|
||||
const int old_width, const int padded_height, const int padded_width, const int pad_top,
|
||||
|
@ -60,6 +93,24 @@ void CalPad(const size_t size, const T* input, const int num, const int channels
|
|||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||
const int pad_left, const float pad_value, T* output, cudaStream_t cuda_stream) {
|
||||
PadNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, old_height, old_width, channels,
|
||||
padded_height, padded_width, pad_top, pad_left, pad_value, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalPadGradNHWC(const size_t size, const T* dy, const int num, const int old_height, const int old_width,
|
||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||
const int pad_left, T* dx, cudaStream_t cuda_stream) {
|
||||
PadGradNHWC<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, num, old_height, old_width, channels,
|
||||
padded_height, padded_width, pad_top, pad_left, dx);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height,
|
||||
const int old_width, const int padded_height, const int padded_width, const int pad_top,
|
||||
|
@ -85,3 +136,19 @@ template void CalPadGrad<half>(const size_t size, const half* dy, const int num,
|
|||
const int old_height, const int old_width, const int padded_height,
|
||||
const int padded_width, const int pad_top, const int pad_left, half* dx,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalPadNHWC<float>(const size_t size, const float* input, const int num, const int old_height,
|
||||
const int old_width, const int channels, const int padded_height,
|
||||
const int padded_width, const int pad_top, const int pad_left, float pad_value,
|
||||
float* output, cudaStream_t cuda_stream);
|
||||
template void CalPadNHWC<half>(const size_t size, const half* input, const int num, const int old_height,
|
||||
const int old_width, const int channels, const int padded_height,
|
||||
const int padded_width, const int pad_top, const int pad_left, float pad_value,
|
||||
half* output, cudaStream_t cuda_stream);
|
||||
template void CalPadGradNHWC<float>(const size_t size, const float* dy, const int num, const int old_height,
|
||||
const int old_width, const int channels, const int padded_height,
|
||||
const int padded_width, const int pad_top, const int pad_left, float* dx,
|
||||
cudaStream_t cuda_stream);
|
||||
template void CalPadGradNHWC<half>(const size_t size, const half* dy, const int num, const int old_height,
|
||||
const int old_width, const int channels, const int padded_height,
|
||||
const int padded_width, const int pad_top, const int pad_left, half* dx,
|
||||
cudaStream_t cuda_stream);
|
||||
|
|
|
@ -27,5 +27,13 @@ template <typename T>
|
|||
void CalPadGrad(const size_t size, const T* dy, const int num, const int channels, const int old_height,
|
||||
const int old_width, const int padded_height, const int padded_width, const int pad_top,
|
||||
const int pad_left, T* dx, cudaStream_t cuda_stream);
|
||||
template <typename T>
|
||||
void CalPadNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
||||
const int channels, const int padded_height, const int padded_width, const int pad_top, const int pad_left,
|
||||
float pad_value, T* output, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalPadGradNHWC(const size_t size, const T* input, const int num, const int old_height, const int old_width,
|
||||
const int channels, const int padded_height, const int padded_width, const int pad_top,
|
||||
const int pad_left, T* output, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <cudnn.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/kernel_constants.h"
|
||||
#include "runtime/device/gpu/gpu_device_manager.h"
|
||||
|
@ -73,6 +74,59 @@ class GpuKernel : public KernelMod {
|
|||
dst->push_back(src.size() == 0 ? 1 : SizeToInt(src[src.size() - 1]));
|
||||
}
|
||||
|
||||
// transpose shape: NCHW To NHWC
|
||||
void ShapeNCHW2NHWC(std::vector<size_t> *shape) {
|
||||
std::swap((*shape)[1], (*shape)[3]);
|
||||
std::swap((*shape)[2], (*shape)[1]);
|
||||
}
|
||||
|
||||
void SetDimA(const std::vector<size_t> &shape, int *dimA, const std::string &format) {
|
||||
if (format == "NCHW" || format == "DefaultFormat") {
|
||||
dimA[0] = SizeToInt(shape[0]);
|
||||
dimA[1] = SizeToInt(shape[1]);
|
||||
dimA[2] = SizeToInt(shape[2]);
|
||||
dimA[3] = SizeToInt(shape[3]);
|
||||
} else if (format == "NHWC") {
|
||||
dimA[0] = SizeToInt(shape[0]);
|
||||
dimA[1] = SizeToInt(shape[3]);
|
||||
dimA[2] = SizeToInt(shape[1]);
|
||||
dimA[3] = SizeToInt(shape[2]);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data format " << format;
|
||||
}
|
||||
}
|
||||
void SetStrideA(const std::vector<size_t> &shape, int *strideA, const std::string &format) {
|
||||
if (format == "NCHW" || format == "DefaultFormat") {
|
||||
strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]);
|
||||
strideA[1] = SizeToInt(shape[2] * shape[3]);
|
||||
strideA[2] = SizeToInt(shape[3]);
|
||||
strideA[3] = 1;
|
||||
} else if (format == "NHWC") {
|
||||
strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]);
|
||||
strideA[1] = 1;
|
||||
strideA[2] = SizeToInt(shape[2] * shape[3]);
|
||||
strideA[3] = SizeToInt(shape[3]);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data format " << format;
|
||||
}
|
||||
}
|
||||
|
||||
void SetNCHW(const std::vector<size_t> &shape, int *n, int *c, int *h, int *w, const std::string &format) {
|
||||
if (format == "NCHW" || format == "DefaultFormat") {
|
||||
*n = SizeToInt(shape[0]);
|
||||
*c = SizeToInt(shape[1]);
|
||||
*h = SizeToInt(shape[2]);
|
||||
*w = SizeToInt(shape[3]);
|
||||
} else if (format == "NHWC") {
|
||||
*n = SizeToInt(shape[0]);
|
||||
*c = SizeToInt(shape[3]);
|
||||
*h = SizeToInt(shape[1]);
|
||||
*w = SizeToInt(shape[2]);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported data format " << format;
|
||||
}
|
||||
}
|
||||
|
||||
inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B,
|
||||
const std::vector<int> &Out) {
|
||||
if (A != Out && B != Out) {
|
||||
|
|
|
@ -38,6 +38,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
conv_desc_(nullptr),
|
||||
padded_desc_(nullptr),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT),
|
||||
compute_format_(CUDNN_TENSOR_NCHW),
|
||||
old_height_(0),
|
||||
old_width_(0),
|
||||
pad_height_(0),
|
||||
|
@ -76,9 +77,15 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
const float beta = 0;
|
||||
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
|
||||
T *padded_addr = GetDeviceAddress<T>(workspace, 1);
|
||||
if (data_format_ == "NHWC") {
|
||||
CalPadNHWC(padded_size_ / sizeof(T), input_addr, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnConvolutionForward(cudnn_handle_, &alpha, padded_desc_, padded_addr, filter_desc_, filter_addr, conv_desc_,
|
||||
conv_algorithm_, workspace_addr, workspace_size_, &beta, output_desc_, output_addr),
|
||||
|
@ -97,15 +104,21 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
if (!CheckParam(kernel_node)) {
|
||||
return false;
|
||||
}
|
||||
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_NULL_INPUT(in_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "Conv2dGpuFwdKernel input is null.";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
|
||||
if (data_format_ == "NHWC") {
|
||||
compute_format_ = CUDNN_TENSOR_NHWC;
|
||||
}
|
||||
Set4DDesc(in_shape, filter_shape, output_shape);
|
||||
group_ = GetAttr<int>(kernel_node, "group");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed");
|
||||
|
@ -116,17 +129,55 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
|
||||
SetStrideAndDilation(kernel_node);
|
||||
cudnnTensorDescriptor_t input_descriptor_real = nullptr;
|
||||
int padA[2];
|
||||
int strideA[2] = {stride_[2], stride_[3]};
|
||||
int dilaA[2] = {dilation_[2], dilation_[3]};
|
||||
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
|
||||
SetPad(in_shape, kernel_node);
|
||||
pad_height_ = pad_list[0] + pad_list[1];
|
||||
pad_width_ = pad_list[2] + pad_list[3];
|
||||
pad_top_ = pad_list[0];
|
||||
pad_left_ = pad_list[2];
|
||||
|
||||
// if use_pad_ == true, using zero padding in advance, else using the default cudnn pad.
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
}
|
||||
int dimA[4];
|
||||
int strideApadded[4];
|
||||
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
|
||||
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
|
||||
IntToSize(old_width_ + pad_width_)};
|
||||
SetDimA(padded_shape, dimA, data_format_);
|
||||
SetStrideA(padded_shape, strideApadded, data_format_);
|
||||
} else if (data_format_ == "NHWC") {
|
||||
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
|
||||
IntToSize(c_)};
|
||||
SetDimA(padded_shape, dimA, data_format_);
|
||||
SetStrideA(padded_shape, strideApadded, data_format_);
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(padded_desc_, cudnn_data_type_, 4, dimA, strideApadded),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
|
||||
if (use_pad_) {
|
||||
padA[0] = 0;
|
||||
padA[1] = 0;
|
||||
} else {
|
||||
padA[0] = pad_top_;
|
||||
padA[1] = pad_left_;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolutionNdDescriptor failed");
|
||||
input_descriptor_real = use_pad_ ? padded_desc_ : input_desc_;
|
||||
} else {
|
||||
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
|
||||
pad_height_ = 0;
|
||||
pad_width_ = 0;
|
||||
}
|
||||
padA[0] = pad_height_;
|
||||
padA[1] = pad_width_;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[2], stride_[3], dilation_[2],
|
||||
dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolution2dDescriptor failed");
|
||||
input_descriptor_real = input_desc_;
|
||||
}
|
||||
|
@ -193,13 +244,11 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_desc_), "cudnnDestroyTensorDescriptor failed");
|
||||
}
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 2) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but conv2d needs 2 inputs.";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(ERROR) << "Output number is " << output_num << ", but conv2d needs 1 output.";
|
||||
|
@ -207,45 +256,28 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) {
|
||||
auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list");
|
||||
|
||||
n_ = SizeToInt(in_shape[0]);
|
||||
c_ = SizeToInt(in_shape[1]);
|
||||
old_height_ = SizeToInt(in_shape[2]);
|
||||
old_width_ = SizeToInt(in_shape[3]);
|
||||
pad_height_ = pad_list[0] + pad_list[1];
|
||||
pad_width_ = pad_list[2] + pad_list[3];
|
||||
pad_top_ = pad_list[0];
|
||||
pad_left_ = pad_list[2];
|
||||
|
||||
// if use_pad_ == true, using zero padding in advance, else using the default cudnn pad.
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_, c_,
|
||||
old_height_ + pad_height_, old_width_ + pad_width_),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
|
||||
conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[2], stride_[3],
|
||||
dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolution2dDescriptor failed");
|
||||
}
|
||||
|
||||
void Set4DDesc(const std::vector<size_t> &in_shape, const std::vector<size_t> &filter_shape,
|
||||
const std::vector<size_t> &output_shape) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(input_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]),
|
||||
SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
|
||||
int nbDims = 4;
|
||||
int dimA[4];
|
||||
int strideAin[4];
|
||||
int dimAout[4];
|
||||
int strideAout[4];
|
||||
SetDimA(in_shape, dimA, data_format_);
|
||||
SetStrideA(in_shape, strideAin, data_format_);
|
||||
SetDimA(output_shape, dimAout, data_format_);
|
||||
SetStrideA(output_shape, strideAout, data_format_);
|
||||
int filterDimA[4];
|
||||
// OHWI for NHWC; OIHW for NCHW
|
||||
SetDimA(filter_shape, filterDimA, data_format_);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(input_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetFilter4dDescriptor(filter_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(filter_shape[0]),
|
||||
SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])),
|
||||
cudnnSetFilterNdDescriptor(filter_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA),
|
||||
"cudnnSetFilter4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(output_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]),
|
||||
SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(output_desc_, cudnn_data_type_, nbDims, dimAout, strideAout),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
}
|
||||
void SelectAlgorithm(cudnnTensorDescriptor_t input_descriptor_real) {
|
||||
|
@ -292,11 +324,13 @@ class Conv2dGpuFwdKernel : public GpuKernel {
|
|||
cudnnConvolutionDescriptor_t conv_desc_;
|
||||
cudnnTensorDescriptor_t padded_desc_;
|
||||
std::string pad_mode_;
|
||||
std::string data_format_ = "NCHW";
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
const float pad_value_ = 0.0;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
cudnnTensorFormat_t compute_format_;
|
||||
int old_height_;
|
||||
int old_width_;
|
||||
int pad_height_;
|
||||
|
|
|
@ -38,6 +38,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
x_desc_(nullptr),
|
||||
padded_descriptor_(nullptr),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT),
|
||||
compute_format_(CUDNN_TENSOR_NCHW),
|
||||
old_height_(0),
|
||||
old_width_(0),
|
||||
pad_height_(0),
|
||||
|
@ -75,12 +76,18 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
|
||||
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
|
||||
T *padded = GetDeviceAddress<T>(workspace, 1);
|
||||
if (data_format_ == "NHWC") {
|
||||
CalPadNHWC(padded_size_ / sizeof(T), x, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CalPad(padded_size_ / sizeof(T), x, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnConvolutionBackwardFilter(cudnn_handle_, &alpha, padded_descriptor_, padded, dy_desc_, dy, conv_desc_,
|
||||
algo_, work_space, workspace_size_, &beta, dw_desc_, dw),
|
||||
|
@ -99,16 +106,21 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
return false;
|
||||
}
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
is_null_input_ = CHECK_NULL_INPUT(dy_shape) || CHECK_NULL_INPUT(in_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "ConvGradFilterGpuBkwKernel input is null.";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
std::vector<int> filter_shape;
|
||||
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
std::vector<size_t> filter_shape;
|
||||
GetFilterShape(kernel_node, &filter_shape);
|
||||
if (data_format_ == "NHWC") {
|
||||
compute_format_ = CUDNN_TENSOR_NHWC;
|
||||
}
|
||||
SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
|
||||
Set4DDesc(dy_shape, filter_shape, in_shape);
|
||||
group_ = GetAttr<int>(kernel_node, "group");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolutionGroupCount(conv_desc_, group_), "cudnnSetConvGroupCount failed");
|
||||
|
@ -120,18 +132,54 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
|
||||
SetStrideAndDilation(kernel_node);
|
||||
cudnnTensorDescriptor_t x_desc_real = nullptr;
|
||||
int padA[2];
|
||||
int strideA[2] = {stride_[0], stride_[1]};
|
||||
int dilaA[2] = {dilation_[0], dilation_[1]};
|
||||
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
|
||||
SetPad(in_shape, kernel_node);
|
||||
pad_height_ = pad_list[0] + pad_list[1];
|
||||
pad_width_ = pad_list[2] + pad_list[3];
|
||||
pad_top_ = pad_list[0];
|
||||
pad_left_ = pad_list[2];
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
}
|
||||
int dimA[4];
|
||||
int strideApadded[4];
|
||||
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
|
||||
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
|
||||
IntToSize(old_width_ + pad_width_)};
|
||||
SetDimA(padded_shape, dimA, data_format_);
|
||||
SetStrideA(padded_shape, strideApadded, data_format_);
|
||||
} else if (data_format_ == "NHWC") {
|
||||
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
|
||||
IntToSize(c_)};
|
||||
SetDimA(padded_shape, dimA, data_format_);
|
||||
SetStrideA(padded_shape, strideApadded, data_format_);
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
if (use_pad_) {
|
||||
padA[0] = 0;
|
||||
padA[1] = 0;
|
||||
} else {
|
||||
padA[0] = pad_top_;
|
||||
padA[1] = pad_left_;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolutionNdDescriptor failed");
|
||||
x_desc_real = use_pad_ ? padded_descriptor_ : x_desc_;
|
||||
} else {
|
||||
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
|
||||
pad_height_ = 0;
|
||||
pad_width_ = 0;
|
||||
}
|
||||
padA[0] = pad_height_;
|
||||
padA[1] = pad_width_;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2],
|
||||
dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"GetConvolution2dDescriptor failed");
|
||||
cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolution2dDescriptor failed");
|
||||
x_desc_real = x_desc_;
|
||||
}
|
||||
if (cudnn_data_type_ == CUDNN_DATA_HALF) {
|
||||
|
@ -208,27 +256,6 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
void SetPad(const std::vector<size_t> &in_shape, const CNodePtr &kernel_node) {
|
||||
auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list");
|
||||
n_ = SizeToInt(in_shape[0]);
|
||||
c_ = SizeToInt(in_shape[1]);
|
||||
old_height_ = SizeToInt(in_shape[2]);
|
||||
old_width_ = SizeToInt(in_shape[3]);
|
||||
pad_height_ = pad_list[0] + pad_list[1];
|
||||
pad_width_ = pad_list[2] + pad_list[3];
|
||||
pad_top_ = pad_list[0];
|
||||
pad_left_ = pad_list[2];
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_,
|
||||
c_, old_height_ + pad_height_, old_width_ + pad_width_),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
|
||||
conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1],
|
||||
dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolution2dDescriptor failed");
|
||||
}
|
||||
void SelectAlgorithm(cudnnTensorDescriptor_t x_desc_real) {
|
||||
if (group_ > 1 || CUDNN_MAJOR < 7) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
|
@ -249,27 +276,33 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
||||
}
|
||||
}
|
||||
void GetFilterShape(const CNodePtr &kernel_node, std::vector<int> *filter_shape) {
|
||||
void GetFilterShape(const CNodePtr &kernel_node, std::vector<size_t> *filter_shape) {
|
||||
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("filter_sizes")->cast<ValueTuplePtr>()->value();
|
||||
(void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*filter_shape),
|
||||
[](const ValuePtr &e) -> int { return e->cast<Int32ImmPtr>()->value(); });
|
||||
[](const ValuePtr &e) -> size_t { return e->cast<Int32ImmPtr>()->value(); });
|
||||
}
|
||||
void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<int> &filter_shape,
|
||||
void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &filter_shape,
|
||||
const std::vector<size_t> &in_shape) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]),
|
||||
SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])),
|
||||
"SetTensor4dDescriptor failed");
|
||||
int nbDims = 4;
|
||||
int dimA[4];
|
||||
int strideAin[4];
|
||||
int dimAdy[4];
|
||||
int strideAdy[4];
|
||||
SetDimA(in_shape, dimA, data_format_);
|
||||
SetStrideA(in_shape, strideAin, data_format_);
|
||||
SetDimA(dy_shape, dimAdy, data_format_);
|
||||
SetStrideA(dy_shape, strideAdy, data_format_);
|
||||
// filter shape always keep OIHW.
|
||||
int filterDimA[4] = {SizeToInt(filter_shape[0]), SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]),
|
||||
SizeToInt(filter_shape[3])};
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetFilter4dDescriptor(dw_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]), filter_shape[1],
|
||||
filter_shape[2], filter_shape[3]),
|
||||
"SetFilter4dDescriptor failed");
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(in_shape[0]),
|
||||
SizeToInt(in_shape[1]), SizeToInt(in_shape[2]), SizeToInt(in_shape[3])),
|
||||
"SetTensor4dDescriptor failed");
|
||||
cudnnSetFilterNdDescriptor(dw_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA),
|
||||
"cudnnSetFilterNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(x_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
}
|
||||
void SetStrideAndDilation(const CNodePtr &kernel_node) {
|
||||
stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
|
||||
|
@ -292,11 +325,13 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
cudnnTensorDescriptor_t padded_descriptor_;
|
||||
cudnnConvolutionBwdFilterAlgo_t algo_;
|
||||
std::string pad_mode_;
|
||||
std::string data_format_ = "NCHW";
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
const float pad_value_ = 0.0;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
cudnnTensorFormat_t compute_format_;
|
||||
int old_height_;
|
||||
int old_width_;
|
||||
int pad_height_;
|
||||
|
@ -319,4 +354,4 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_
|
||||
#endif // MINDePORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_
|
||||
|
|
|
@ -38,6 +38,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
dx_desc_(nullptr),
|
||||
padded_descriptor_(nullptr),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT),
|
||||
compute_format_(CUDNN_TENSOR_NCHW),
|
||||
old_height_(0),
|
||||
old_width_(0),
|
||||
pad_height_(0),
|
||||
|
@ -75,7 +76,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
|
||||
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
|
||||
T *padded = GetDeviceAddress<T>(workspace, 1);
|
||||
|
||||
|
@ -83,8 +83,13 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
|
||||
workspace_size_, &beta, padded_descriptor_, padded),
|
||||
"ConvolutionBackwardData failed");
|
||||
if (data_format_ == "NHWC") {
|
||||
CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CalPadGrad(output_size_ / sizeof(T), padded, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
|
||||
|
@ -99,16 +104,23 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
return false;
|
||||
}
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto filter_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
is_null_input_ = CHECK_NULL_INPUT(dy_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "ConvGradInputGpuBkwKernel input is null.";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
std::vector<int> input_shape;
|
||||
|
||||
std::vector<size_t> input_shape;
|
||||
GetInputShape(kernel_node, &input_shape);
|
||||
if (data_format_ == "NHWC") {
|
||||
compute_format_ = CUDNN_TENSOR_NHWC;
|
||||
ShapeNCHW2NHWC(&input_shape);
|
||||
}
|
||||
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
|
||||
Set4DDesc(dy_shape, input_shape, filter_shape);
|
||||
|
||||
group_ = GetAttr<int>(kernel_node, "group");
|
||||
|
@ -121,17 +133,53 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
|
||||
SetStrideAndDilation(kernel_node);
|
||||
cudnnTensorDescriptor_t dx_desc_real = nullptr;
|
||||
int padA[2];
|
||||
int strideA[2] = {stride_[0], stride_[1]};
|
||||
int dilaA[2] = {dilation_[0], dilation_[1]};
|
||||
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase || !symmetry_pad) {
|
||||
SetPad(input_shape, kernel_node);
|
||||
pad_height_ = pad_list[0] + pad_list[1];
|
||||
pad_width_ = pad_list[2] + pad_list[3];
|
||||
pad_top_ = pad_list[0];
|
||||
pad_left_ = pad_list[2];
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
}
|
||||
int dimA[4];
|
||||
int strideApadded[4];
|
||||
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
|
||||
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
|
||||
IntToSize(old_width_ + pad_width_)};
|
||||
SetDimA(padded_shape, dimA, data_format_);
|
||||
SetStrideA(padded_shape, strideApadded, data_format_);
|
||||
} else if (data_format_ == "NHWC") {
|
||||
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
|
||||
IntToSize(c_)};
|
||||
SetDimA(padded_shape, dimA, data_format_);
|
||||
SetStrideA(padded_shape, strideApadded, data_format_);
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensorNdDescriptor(padded_descriptor_, cudnn_data_type_, 4, dimA, strideApadded),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
if (use_pad_) {
|
||||
padA[0] = 0;
|
||||
padA[1] = 0;
|
||||
} else {
|
||||
padA[0] = pad_top_;
|
||||
padA[1] = pad_left_;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolutionNdDescriptor failed");
|
||||
dx_desc_real = use_pad_ ? padded_descriptor_ : dx_desc_;
|
||||
} else {
|
||||
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
|
||||
pad_height_ = 0;
|
||||
pad_width_ = 0;
|
||||
}
|
||||
padA[0] = pad_height_;
|
||||
padA[1] = pad_width_;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetConvolution2dDescriptor(conv_desc_, pad_height_, pad_width_, stride_[0], stride_[1], dilation_[2],
|
||||
dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
cudnnSetConvolutionNdDescriptor(conv_desc_, 2, padA, strideA, dilaA, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolution2dDescriptor failed");
|
||||
dx_desc_real = dx_desc_;
|
||||
}
|
||||
|
@ -208,24 +256,6 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
}
|
||||
void SetPad(const std::vector<int> &input_shape, const CNodePtr &kernel_node) {
|
||||
auto pad_list = GetAttr<std::vector<int>>(kernel_node, "pad_list");
|
||||
n_ = input_shape[0];
|
||||
c_ = input_shape[1];
|
||||
old_height_ = input_shape[2];
|
||||
old_width_ = input_shape[3];
|
||||
pad_height_ = pad_list[0] + pad_list[1];
|
||||
pad_width_ = pad_list[2] + pad_list[3];
|
||||
pad_top_ = pad_list[0];
|
||||
pad_left_ = pad_list[2];
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_,
|
||||
c_, old_height_ + pad_height_, old_width_ + pad_width_),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetConvolution2dDescriptor(
|
||||
conv_desc_, use_pad_ ? 0 : pad_top_, use_pad_ ? 0 : pad_left_, stride_[0], stride_[1],
|
||||
dilation_[2], dilation_[3], CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT),
|
||||
"cudnnSetConvolution2dDescriptor failed");
|
||||
}
|
||||
void SelectAlgorithm(cudnnTensorDescriptor_t dx_desc_real) {
|
||||
if (group_ > 1 || CUDNN_MAJOR < 7) {
|
||||
|
@ -247,25 +277,32 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
||||
}
|
||||
}
|
||||
void GetInputShape(const CNodePtr &kernel_node, std::vector<int> *input_shape) {
|
||||
void GetInputShape(const CNodePtr &kernel_node, std::vector<size_t> *input_shape) {
|
||||
auto shp_tuple_x = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("input_sizes")->cast<ValueTuplePtr>()->value();
|
||||
(void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(*input_shape),
|
||||
[](const ValuePtr &e) -> int { return e->cast<Int32ImmPtr>()->value(); });
|
||||
[](const ValuePtr &e) -> size_t { return e->cast<Int32ImmPtr>()->value(); });
|
||||
}
|
||||
void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<int> &input_shape,
|
||||
void Set4DDesc(const std::vector<size_t> &dy_shape, const std::vector<size_t> &input_shape,
|
||||
const std::vector<size_t> &filter_shape) {
|
||||
int nbDims = 4;
|
||||
int dimA[4];
|
||||
int strideAin[4];
|
||||
int dimAdy[4];
|
||||
int strideAdy[4];
|
||||
int filterDimA[4];
|
||||
SetDimA(input_shape, dimA, data_format_);
|
||||
SetStrideA(input_shape, strideAin, data_format_);
|
||||
SetDimA(dy_shape, dimAdy, data_format_);
|
||||
SetStrideA(dy_shape, strideAdy, data_format_);
|
||||
SetDimA(filter_shape, filterDimA, data_format_);
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetFilter4dDescriptor(w_desc_, cudnn_data_type_, CUDNN_TENSOR_NCHW, SizeToInt(dy_shape[1]),
|
||||
SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), SizeToInt(filter_shape[3])),
|
||||
"SetFilter4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(dy_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dy_shape[0]),
|
||||
SizeToInt(dy_shape[1]), SizeToInt(dy_shape[2]), SizeToInt(dy_shape[3])),
|
||||
"SetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(dx_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, input_shape[0], input_shape[1],
|
||||
input_shape[2], input_shape[3]),
|
||||
"SetTensor4dDescriptor failed");
|
||||
cudnnSetFilterNdDescriptor(w_desc_, cudnn_data_type_, compute_format_, nbDims, filterDimA),
|
||||
"cudnnSetFilterNdDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dx_desc_, cudnn_data_type_, nbDims, dimA, strideAin),
|
||||
"cudnnSetTensorNdDescriptor failed");
|
||||
}
|
||||
void SetStrideAndDilation(const CNodePtr &kernel_node) {
|
||||
stride_ = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, "stride");
|
||||
|
@ -288,10 +325,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
|
|||
cudnnTensorDescriptor_t padded_descriptor_;
|
||||
cudnnConvolutionBwdDataAlgo_t algo_;
|
||||
std::string pad_mode_;
|
||||
std::string data_format_ = "NCHW";
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
cudnnTensorFormat_t compute_format_;
|
||||
int old_height_;
|
||||
int old_width_;
|
||||
int pad_height_;
|
||||
|
|
|
@ -35,9 +35,9 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
input_descriptor_(nullptr),
|
||||
output_descriptor_(nullptr),
|
||||
pooling_descriptor_(nullptr),
|
||||
padded_descriptor_(nullptr),
|
||||
pooling_mode_(CUDNN_POOLING_MAX),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT),
|
||||
compute_format_(CUDNN_TENSOR_NCHW),
|
||||
old_height_(0),
|
||||
old_width_(0),
|
||||
pad_height_(0),
|
||||
|
@ -50,9 +50,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
is_null_input_(false),
|
||||
input_size_(0),
|
||||
output_size_(0),
|
||||
padded_size_(0),
|
||||
workspace_size_(0),
|
||||
use_pad_(true) {}
|
||||
workspace_size_(0) {}
|
||||
~PoolingGpuFwdKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -67,20 +65,10 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
|
||||
T *padded_addr = reinterpret_cast<T *>(workspace[0]->addr);
|
||||
CalPad(padded_size_ / sizeof(T), input_addr, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, padded_descriptor_,
|
||||
padded_addr, &beta, output_descriptor_, output_addr),
|
||||
"cudnnPoolingForward failed");
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_,
|
||||
input_addr, &beta, output_descriptor_, output_addr),
|
||||
"cudnnPoolingForward failed");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) {
|
||||
|
@ -89,89 +77,41 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
return false;
|
||||
}
|
||||
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_NULL_INPUT(input_shape);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "PoolingGpuFwdKernel input is null.";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
|
||||
int nbDims = 4;
|
||||
int dimA[4];
|
||||
int strideAin[4];
|
||||
int dimAout[4];
|
||||
int strideAout[4];
|
||||
SetDimA(input_shape, dimA, data_format_);
|
||||
SetStrideA(input_shape, strideAin, data_format_);
|
||||
SetDimA(output_shape, dimAout, data_format_);
|
||||
SetStrideA(output_shape, strideAout, data_format_);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]),
|
||||
SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])),
|
||||
cudnnSetTensorNdDescriptor(input_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(output_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]),
|
||||
SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
|
||||
cudnnSetTensorNdDescriptor(output_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
|
||||
int window_height = window[2];
|
||||
int window_width = window[3];
|
||||
stride_ = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
|
||||
SetPoolingMode(kernel_node);
|
||||
int windowDimA[2] = {window_height, window_width};
|
||||
int paddingA[2] = {0, 0};
|
||||
int strideA[2] = {stride_[2], stride_[3]};
|
||||
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
|
||||
SetPad(input_shape, window_height, window_width);
|
||||
} else {
|
||||
pad_height_ = 0;
|
||||
pad_width_ = 0;
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height,
|
||||
window_width, pad_height_, pad_width_, stride_[2], stride_[3]),
|
||||
"cudnnSetPooling2dDescriptor failed");
|
||||
}
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() {
|
||||
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_),
|
||||
"cudnnCreatePoolingDescriptor failed");
|
||||
}
|
||||
void InitSizeLists() {
|
||||
if (!is_null_input_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast<size_t *>(&input_size_)),
|
||||
"cudnnGetTensorSizeInBytes failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetTensorSizeInBytes(output_descriptor_, reinterpret_cast<size_t *>(&output_size_)),
|
||||
"cudnnGetTensorSizeInBytes failed");
|
||||
}
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetTensorSizeInBytes(padded_descriptor_, reinterpret_cast<size_t *>(&padded_size_)),
|
||||
"cudnnGetTensorSizeInBytes failed");
|
||||
workspace_size_list_.push_back(padded_size_);
|
||||
if (padded_size_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Padded size is 0.";
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but pooling needs 1 inputs.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void SetPad(const std::vector<size_t> &input_shape, const int &window_height, const int &window_width) {
|
||||
n_ = SizeToInt(input_shape[0]);
|
||||
c_ = SizeToInt(input_shape[1]);
|
||||
old_height_ = SizeToInt(input_shape[2]);
|
||||
old_width_ = SizeToInt(input_shape[3]);
|
||||
pad_height_ =
|
||||
std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
|
||||
: (old_height_ / stride_[2]) + 1) -
|
||||
|
@ -186,17 +126,51 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
window_width - old_width_);
|
||||
pad_top_ = pad_height_ / 2;
|
||||
pad_left_ = pad_width_ / 2;
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
paddingA[0] = pad_top_;
|
||||
paddingA[1] = pad_left_;
|
||||
} else {
|
||||
pad_height_ = 0;
|
||||
pad_width_ = 0;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_,
|
||||
c_, old_height_ + pad_height_, old_width_ + pad_width_),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN,
|
||||
window_height, window_width, use_pad_ ? 0 : pad_top_,
|
||||
use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]),
|
||||
"cudnnSetPooling2dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN,
|
||||
2, windowDimA, paddingA, strideA),
|
||||
"cudnnSetPoolingNdDescriptor failed");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitResource() {
|
||||
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&input_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&output_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_),
|
||||
"cudnnCreatePoolingDescriptor failed");
|
||||
}
|
||||
void InitSizeLists() {
|
||||
if (!is_null_input_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetTensorSizeInBytes(input_descriptor_, reinterpret_cast<size_t *>(&input_size_)),
|
||||
"cudnnGetTensorSizeInBytes failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnGetTensorSizeInBytes(output_descriptor_, reinterpret_cast<size_t *>(&output_size_)),
|
||||
"cudnnGetTensorSizeInBytes failed");
|
||||
}
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
return;
|
||||
}
|
||||
|
||||
private:
|
||||
bool CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(ERROR) << "Input number is " << input_num << ", but pooling needs 1 inputs.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void SetPoolingMode(const CNodePtr &kernel_node) {
|
||||
pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
|
||||
mode_ = AnfAlgo::GetCNodeName(kernel_node);
|
||||
|
@ -211,7 +185,6 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
void DestroyResource() noexcept {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_),
|
||||
"cudnnDestroyPoolingDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(output_descriptor_), "cudnnDestroyTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(input_descriptor_), "cudnnDestroyTensorDescriptor failed");
|
||||
}
|
||||
|
@ -220,16 +193,16 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
cudnnTensorDescriptor_t input_descriptor_;
|
||||
cudnnTensorDescriptor_t output_descriptor_;
|
||||
cudnnPoolingDescriptor_t pooling_descriptor_;
|
||||
cudnnTensorDescriptor_t padded_descriptor_;
|
||||
cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX;
|
||||
std::vector<int> stride_;
|
||||
std::string mode_;
|
||||
std::string pad_mode_;
|
||||
std::string data_format_ = "NCHW";
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
|
||||
cudnnTensorFormat_t compute_format_;
|
||||
int old_height_;
|
||||
int old_width_;
|
||||
int pad_height_;
|
||||
|
@ -242,9 +215,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
|
|||
bool is_null_input_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t padded_size_;
|
||||
size_t workspace_size_;
|
||||
bool use_pad_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -37,9 +37,9 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
dy_descriptor_(nullptr),
|
||||
x_descriptor_(nullptr),
|
||||
dx_descriptor_(nullptr),
|
||||
padded_descriptor_(nullptr),
|
||||
pooling_mode_(CUDNN_POOLING_MAX),
|
||||
cudnn_data_type_(CUDNN_DATA_FLOAT),
|
||||
compute_format_(CUDNN_TENSOR_NCHW),
|
||||
old_height_(0),
|
||||
old_width_(0),
|
||||
pad_height_(0),
|
||||
|
@ -52,9 +52,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
is_null_input_(false),
|
||||
input_size_(0),
|
||||
output_size_(0),
|
||||
padded_size_(0),
|
||||
workspace_size_(0),
|
||||
use_pad_(true) {}
|
||||
workspace_size_(0) {}
|
||||
~PoolingGradGpuKernel() override { DestroyResource(); }
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
|
@ -72,27 +70,10 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
|
||||
const float alpha = 1;
|
||||
const float beta = 0;
|
||||
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
|
||||
T *padded = GetDeviceAddress<T>(workspace, 0);
|
||||
T *padded_dx = GetDeviceAddress<T>(workspace, 1);
|
||||
|
||||
CalPad(padded_size_ / sizeof(T), x_data, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
|
||||
padded_descriptor_, padded, &beta, padded_descriptor_, padded_dx),
|
||||
"cudnnPoolingBackward failed");
|
||||
|
||||
CalPadGrad(output_size_ / sizeof(T), padded_dx, n_, c_, old_height_, old_width_, old_height_ + pad_height_,
|
||||
old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
|
||||
x_descriptor_, x_data, &beta, dx_descriptor_, dx),
|
||||
"cudnnPoolingBackward failed");
|
||||
}
|
||||
return true;
|
||||
}
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
|
@ -104,46 +85,73 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
int window_height = window[2];
|
||||
int window_width = window[3];
|
||||
SetPoolingMode(kernel_node);
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
auto input_mask = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
|
||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
|
||||
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
|
||||
if (is_null_input_) {
|
||||
MS_LOG(WARNING) << "PoolingGradGpuKernel input is null.";
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
|
||||
int windowDimA[2] = {window_height, window_width};
|
||||
int paddingA[2] = {0, 0};
|
||||
int strideA[2] = {stride_[2], stride_[3]};
|
||||
int nbDims = 4;
|
||||
int dimA[4];
|
||||
int strideAin[4];
|
||||
int dimAy[4];
|
||||
int strideAiny[4];
|
||||
int dimAdy[4];
|
||||
int strideAdy[4];
|
||||
int dimAout[4];
|
||||
int strideAout[4];
|
||||
SetDimA(input_shape, dimA, data_format_);
|
||||
SetStrideA(input_shape, strideAin, data_format_);
|
||||
SetDimA(input_mask, dimAy, data_format_);
|
||||
SetStrideA(input_mask, strideAiny, data_format_);
|
||||
SetDimA(dout_shape, dimAdy, data_format_);
|
||||
SetStrideA(dout_shape, strideAdy, data_format_);
|
||||
SetDimA(output_shape, dimAout, data_format_);
|
||||
SetStrideA(output_shape, strideAout, data_format_);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(y_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_mask[0]),
|
||||
SizeToInt(input_mask[1]), SizeToInt(input_mask[2]), SizeToInt(input_mask[3])),
|
||||
"cudnnSetTensor4dDescriptor");
|
||||
|
||||
auto dout_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(dy_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(dout_shape[0]),
|
||||
SizeToInt(dout_shape[1]), SizeToInt(dout_shape[2]), SizeToInt(dout_shape[3])),
|
||||
"cudnnSetTensor4dDescriptor");
|
||||
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(dx_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(output_shape[0]),
|
||||
SizeToInt(output_shape[1]), SizeToInt(output_shape[2]), SizeToInt(output_shape[3])),
|
||||
cudnnSetTensorNdDescriptor(dx_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(x_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) {
|
||||
SetPad(input_shape, window_height, window_width);
|
||||
pad_height_ =
|
||||
std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
|
||||
: (old_height_ / stride_[2]) + 1) -
|
||||
1) *
|
||||
stride_[2] +
|
||||
window_height - old_height_);
|
||||
pad_width_ =
|
||||
std::max<int>(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3])
|
||||
: (old_width_ / stride_[3]) + 1) -
|
||||
1) *
|
||||
stride_[3] +
|
||||
window_width - old_width_);
|
||||
pad_top_ = pad_height_ / 2;
|
||||
pad_left_ = pad_width_ / 2;
|
||||
paddingA[0] = pad_top_;
|
||||
paddingA[1] = pad_left_;
|
||||
} else {
|
||||
if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
|
||||
pad_height_ = 0;
|
||||
pad_width_ = 0;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, window_height,
|
||||
window_width, pad_height_, pad_width_, stride_[2], stride_[3]),
|
||||
"cudnnSetPooling2dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]),
|
||||
SizeToInt(input_shape[1]), SizeToInt(input_shape[2]), SizeToInt(input_shape[3])),
|
||||
"cudnnSetTensor4dDescriptor");
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN,
|
||||
2, windowDimA, paddingA, strideA),
|
||||
"cudnnSetPoolingNdDescriptor failed");
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
@ -155,7 +163,6 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&padded_descriptor_), "cudnnCreateTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreatePoolingDescriptor(&pooling_descriptor_),
|
||||
"cudnnCreatePoolingDescriptor failed");
|
||||
}
|
||||
|
@ -179,16 +186,6 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
"cudnnGetTensorSizeInBytes failed");
|
||||
}
|
||||
input_size_list_.push_back(input_size_);
|
||||
|
||||
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_ && !is_null_input_) {
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(padded_descriptor_, &padded_size_),
|
||||
"cudnnGetTensorSizeInBytes failed");
|
||||
if (padded_size_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "Padded size is 0.";
|
||||
}
|
||||
workspace_size_list_.push_back(padded_size_);
|
||||
workspace_size_list_.push_back(padded_size_);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -206,35 +203,6 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
c_ = SizeToInt(input_shape[1]);
|
||||
old_height_ = SizeToInt(input_shape[2]);
|
||||
old_width_ = SizeToInt(input_shape[3]);
|
||||
pad_height_ =
|
||||
std::max<int>(0, (((old_height_ / stride_[2]) * stride_[2] == old_height_ ? (old_height_ / stride_[2])
|
||||
: (old_height_ / stride_[2]) + 1) -
|
||||
1) *
|
||||
stride_[2] +
|
||||
window_height - old_height_);
|
||||
pad_width_ =
|
||||
std::max<int>(0, (((old_width_ / stride_[3]) * stride_[3] == old_width_ ? (old_width_ / stride_[3])
|
||||
: (old_width_ / stride_[3]) + 1) -
|
||||
1) *
|
||||
stride_[3] +
|
||||
window_width - old_width_);
|
||||
pad_top_ = pad_height_ / 2;
|
||||
pad_left_ = pad_width_ / 2;
|
||||
if (pad_height_ % 2 == 0 && pad_width_ % 2 == 0) {
|
||||
use_pad_ = false;
|
||||
}
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(padded_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, n_,
|
||||
c_, old_height_ + pad_height_, old_width_ + pad_width_),
|
||||
"cudnnSetTensor4dDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(
|
||||
cudnnSetTensor4dDescriptor(x_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(input_shape[0]),
|
||||
SizeToInt(input_shape[1]), SizeToInt(input_shape[2]) + (use_pad_ ? pad_height_ : 0),
|
||||
SizeToInt(input_shape[3]) + (use_pad_ ? pad_width_ : 0)),
|
||||
"cudnnSetTensor4dDescriptor");
|
||||
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetPooling2dDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN,
|
||||
window_height, window_width, use_pad_ ? 0 : pad_top_,
|
||||
use_pad_ ? 0 : pad_left_, stride_[2], stride_[3]),
|
||||
"cudnnSetPooling2dDescriptor failed");
|
||||
}
|
||||
void SetPoolingMode(const CNodePtr &kernel_node) {
|
||||
pad_mode_ = GetAttr<std::string>(kernel_node, "padding");
|
||||
|
@ -252,7 +220,6 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
void DestroyResource() noexcept {
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyPoolingDescriptor(pooling_descriptor_),
|
||||
"cudnnDestroyPoolingDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(padded_descriptor_), "cudnnDestroyTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_descriptor_), "cudnnDestroyTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_descriptor_), "cudnnDestroyTensorDescriptor failed");
|
||||
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_descriptor_), "cudnnDestroyTensorDescriptor failed");
|
||||
|
@ -265,7 +232,6 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
cudnnTensorDescriptor_t dy_descriptor_;
|
||||
cudnnTensorDescriptor_t x_descriptor_;
|
||||
cudnnTensorDescriptor_t dx_descriptor_;
|
||||
cudnnTensorDescriptor_t padded_descriptor_;
|
||||
cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX;
|
||||
std::vector<int> stride_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
|
@ -273,7 +239,9 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
std::vector<size_t> workspace_size_list_;
|
||||
std::string mode_;
|
||||
std::string pad_mode_;
|
||||
std::string data_format_ = "NCHW";
|
||||
cudnnDataType_t cudnn_data_type_;
|
||||
cudnnTensorFormat_t compute_format_;
|
||||
int old_height_;
|
||||
int old_width_;
|
||||
int pad_height_;
|
||||
|
@ -286,9 +254,7 @@ class PoolingGradGpuKernel : public GpuKernel {
|
|||
bool is_null_input_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t padded_size_;
|
||||
size_t workspace_size_;
|
||||
bool use_pad_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -45,23 +45,23 @@ def test_maxpool2d_grad():
|
|||
[24, 25, 26, 27, 28, 29],
|
||||
[30, 31, 32, 33, 34, 35]
|
||||
]]]).astype(np.float32))
|
||||
a = Tensor(np.array([[[
|
||||
d = Tensor(np.array([[[
|
||||
[3, 3, 3],
|
||||
[3, 3, 3],
|
||||
[3, 3, 3]
|
||||
]]]).astype(np.float32))
|
||||
d = Tensor(np.array([[[
|
||||
a = Tensor(np.array([[[
|
||||
[7, 9, 11],
|
||||
[19, 21, 23],
|
||||
[31, 33, 35]
|
||||
]]]).astype(np.float32))
|
||||
expect_result = (np.array([[[
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 7, 0, 9, 0, 11],
|
||||
[0, 3, 0, 3, 0, 3],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 19, 0, 21, 0, 23],
|
||||
[0, 3, 0, 3, 0, 3],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 31, 0, 33, 0, 35]
|
||||
[0, 3, 0, 3, 0, 3]
|
||||
]]]))
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
|
|
Loading…
Reference in New Issue