diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h index 81174dd81f6..139f0bde7ef 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/combine_momentum_gpu_kernel.h @@ -27,61 +27,54 @@ namespace kernel { template class CombineMomentumGpuKernel : public GpuKernel { public: - CombineMomentumGpuKernel() : element_num_(1), num_(0), max_(0), input_num_(6) {} + CombineMomentumGpuKernel() : element_num_(1), num_(0), input_num_(6) {} ~CombineMomentumGpuKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &workspace, void *stream_ptr) override { - const cudaStream_t stream = reinterpret_cast(stream_ptr); - auto weight_decay = std::make_unique(input_num_ * num_); - auto scale = std::make_unique(input_num_ * num_); - auto variable = std::make_unique(input_num_ * num_); - auto accumulation = std::make_unique(input_num_ * num_); - auto learning_rate = std::make_unique(input_num_ * num_); - auto gradient = std::make_unique(input_num_ * num_); - auto momentum = std::make_unique(input_num_ * num_); - if (input_num_ == 6) { - LaunchCombineMom(inputs, workspace, stream, scale, variable, accumulation, learning_rate, gradient, momentum); - } else { - LaunchCombineMomWeightDecay(inputs, workspace, stream, weight_decay, scale, variable, accumulation, learning_rate, - gradient, momentum); + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, + void *stream_ptr) override { + auto stream = reinterpret_cast(stream_ptr); + for (size_t i = 0; i < num_; i++) { + if (input_num_ == 6) { + T *scale = GetDeviceAddress(inputs, i * input_num_); + T *variable = GetDeviceAddress(inputs, i * input_num_ + 1); + T *acc = GetDeviceAddress(inputs, i * input_num_ + 2); + T *lr = GetDeviceAddress(inputs, i * input_num_ + 3); + S *grad = GetDeviceAddress(inputs, i * input_num_ + 4); + T *mom = GetDeviceAddress(inputs, i * input_num_ + 5); + FusedScaleMomentum(elements_[i], scale, variable, acc, lr, grad, mom, stream); + } else { + T *weight_decay = GetDeviceAddress(inputs, i * input_num_); + T *scale = GetDeviceAddress(inputs, i * input_num_ + 1); + T *variable = GetDeviceAddress(inputs, i * input_num_ + 2); + T *acc = GetDeviceAddress(inputs, i * input_num_ + 3); + T *lr = GetDeviceAddress(inputs, i * input_num_ + 4); + S *grad = GetDeviceAddress(inputs, i * input_num_ + 5); + T *mom = GetDeviceAddress(inputs, i * input_num_ + 6); + FusedWeightDecayScaleMomentum(elements_[i], weight_decay, scale, variable, acc, lr, grad, mom, stream); + } } - return true; } bool Init(const CNodePtr &kernel_node) override { num_ = GetAttr(kernel_node, "n"); - elements_ = std::make_unique(num_); auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); if (kernel_name == "CombineMomentum") { input_num_ = 6; } else { input_num_ = 7; - workspace_size_list_.push_back(sizeof(T *) * num_); } - for (size_t i = 0; i < num_; i++) { element_num_ = 1; - auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i * input_num_ + input_num_ - 4); + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i * input_num_ + input_num_ - 5); for (size_t j = 0; j < variable_shape.size(); j++) { element_num_ *= variable_shape[j]; } - if (max_ < element_num_) { - max_ = element_num_; - } - elements_[i] = element_num_; + elements_.push_back(element_num_); InitSizeLists(); } - workspace_size_list_.push_back(sizeof(T *) * num_); - workspace_size_list_.push_back(sizeof(T *) * num_); - workspace_size_list_.push_back(sizeof(T *) * num_); - workspace_size_list_.push_back(sizeof(T *) * num_); - workspace_size_list_.push_back(sizeof(S *) * num_); - workspace_size_list_.push_back(sizeof(T *) * num_); - workspace_size_list_.push_back(sizeof(size_t) * num_); return true; } @@ -100,102 +93,9 @@ class CombineMomentumGpuKernel : public GpuKernel { } private: - void LaunchCombineMom(const std::vector &inputs, const std::vector &workspace, - const cudaStream_t &stream, const std::unique_ptr &scale, - const std::unique_ptr &variable, const std::unique_ptr &accumulation, - const std::unique_ptr &learning_rate, const std::unique_ptr &gradient, - const std::unique_ptr &momentum) { - for (size_t i = 0; i < num_; i++) { - scale[i] = GetDeviceAddress(inputs, i * input_num_); - variable[i] = GetDeviceAddress(inputs, i * input_num_ + 1); - accumulation[i] = GetDeviceAddress(inputs, i * input_num_ + 2); - learning_rate[i] = GetDeviceAddress(inputs, i * input_num_ + 3); - gradient[i] = GetDeviceAddress(inputs, i * input_num_ + 4); - momentum[i] = GetDeviceAddress(inputs, i * input_num_ + 5); - } - T **scale_dev = GetDeviceAddress(workspace, 0); - T **variable_dev = GetDeviceAddress(workspace, 1); - T **accumulation_dev = GetDeviceAddress(workspace, 2); - T **learning_rate_dev = GetDeviceAddress(workspace, 3); - S **gradient_dev = GetDeviceAddress(workspace, 4); - T **momentum_dev = GetDeviceAddress(workspace, 5); - size_t *elements_dev = GetDeviceAddress(workspace, 6); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CombineFusedScaleMomentum(max_, num_, elements_dev, scale_dev, variable_dev, accumulation_dev, learning_rate_dev, - gradient_dev, momentum_dev, stream); - } - void LaunchCombineMomWeightDecay(const std::vector &inputs, const std::vector &workspace, - const cudaStream_t &stream, const std::unique_ptr &weight_decay, - const std::unique_ptr &scale, const std::unique_ptr &variable, - const std::unique_ptr &accumulation, - const std::unique_ptr &learning_rate, const std::unique_ptr &gradient, - const std::unique_ptr &momentum) { - for (size_t i = 0; i < num_; i++) { - weight_decay[i] = GetDeviceAddress(inputs, i * input_num_); - scale[i] = GetDeviceAddress(inputs, i * input_num_ + 1); - variable[i] = GetDeviceAddress(inputs, i * input_num_ + 2); - accumulation[i] = GetDeviceAddress(inputs, i * input_num_ + 3); - learning_rate[i] = GetDeviceAddress(inputs, i * input_num_ + 4); - gradient[i] = GetDeviceAddress(inputs, i * input_num_ + 5); - momentum[i] = GetDeviceAddress(inputs, i * input_num_ + 6); - } - T **weight_decay_dev = GetDeviceAddress(workspace, 0); - T **scale_dev = GetDeviceAddress(workspace, 1); - T **variable_dev = GetDeviceAddress(workspace, 2); - T **accumulation_dev = GetDeviceAddress(workspace, 3); - T **learning_rate_dev = GetDeviceAddress(workspace, 4); - S **gradient_dev = GetDeviceAddress(workspace, 5); - T **momentum_dev = GetDeviceAddress(workspace, 6); - size_t *elements_dev = GetDeviceAddress(workspace, 7); - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(weight_decay_dev, weight_decay.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CHECK_CUDA_RET_WITH_EXCEPT( - cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream), - "cudaMemCPY failed") - CombineFusedWeightDecayScaleMomentum(max_, num_, elements_dev, weight_decay_dev, scale_dev, variable_dev, - accumulation_dev, learning_rate_dev, gradient_dev, momentum_dev, stream); - } size_t element_num_; - std::unique_ptr elements_; + std::vector elements_; size_t num_; - size_t max_; int input_num_; std::vector input_size_list_; std::vector output_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h index daacab87e0b..710c67aa3d2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h @@ -17,12 +17,13 @@ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2DGPUKERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2DGPUKERNEL_H_ -#include -#include #include +#include +#include + +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" -#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" #include "backend/kernel_compiler/gpu/kernel_constants.h" namespace mindspore { @@ -77,7 +78,7 @@ class Conv2dGpuFwdKernel : public GpuKernel { const float beta = 0; if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { T *padded_addr = GetDeviceAddress(workspace, 1); - if (data_format_ == "NHWC") { + if (data_format_ == kOpFormat_NHWC) { CalPadNHWC(padded_size_ / sizeof(T), input_addr, n_, old_height_, old_width_, c_, old_height_ + pad_height_, old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr, reinterpret_cast(stream_ptr)); @@ -106,6 +107,10 @@ class Conv2dGpuFwdKernel : public GpuKernel { } cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); + auto format_attr = GetAttr(kernel_node, "data_format"); + if (format_attr == kOpFormat_NHWC) { + data_format_ = kOpFormat_NHWC; + } auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); @@ -116,7 +121,7 @@ class Conv2dGpuFwdKernel : public GpuKernel { return true; } SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_); - if (data_format_ == "NHWC") { + if (data_format_ == kOpFormat_NHWC) { compute_format_ = CUDNN_TENSOR_NHWC; } Set4DDesc(in_shape, filter_shape, output_shape); @@ -144,12 +149,12 @@ class Conv2dGpuFwdKernel : public GpuKernel { } int dimA[4]; int strideApadded[4]; - if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { + if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) { auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)}; SetDimA(padded_shape, dimA, 4, data_format_); SetStrideA(padded_shape, strideApadded, 4, data_format_); - } else if (data_format_ == "NHWC") { + } else if (data_format_ == kOpFormat_NHWC) { auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), IntToSize(c_)}; SetDimA(padded_shape, dimA, 4, data_format_); @@ -324,7 +329,7 @@ class Conv2dGpuFwdKernel : public GpuKernel { cudnnConvolutionDescriptor_t conv_desc_; cudnnTensorDescriptor_t padded_desc_; std::string pad_mode_; - std::string data_format_ = "NCHW"; + std::string data_format_ = kOpFormat_NCHW; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h index 1d82f2e6dfc..62b2c863670 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_filter_gpu_kernel.h @@ -17,12 +17,13 @@ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_ -#include -#include #include +#include +#include + +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" -#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" #include "backend/kernel_compiler/gpu/kernel_constants.h" namespace mindspore { @@ -79,7 +80,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) { T *padded = GetDeviceAddress(workspace, 1); - if (data_format_ == "NHWC") { + if (data_format_ == kOpFormat_NHWC) { CalPadNHWC(padded_size_ / sizeof(T), x, n_, old_height_, old_width_, c_, old_height_ + pad_height_, old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded, reinterpret_cast(stream_ptr)); @@ -115,9 +116,13 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { return true; } data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); + format_attr_ = GetAttr(kernel_node, "data_format"); + if (format_attr_ == kOpFormat_NHWC) { + data_format_ = kOpFormat_NHWC; + } std::vector filter_shape; GetFilterShape(kernel_node, &filter_shape); - if (data_format_ == "NHWC") { + if (data_format_ == kOpFormat_NHWC) { compute_format_ = CUDNN_TENSOR_NHWC; } SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_); @@ -145,12 +150,12 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { } int dimA[4]; int strideApadded[4]; - if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { + if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) { auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)}; SetDimA(padded_shape, dimA, 4, data_format_); SetStrideA(padded_shape, strideApadded, 4, data_format_); - } else if (data_format_ == "NHWC") { + } else if (data_format_ == kOpFormat_NHWC) { auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), IntToSize(c_)}; SetDimA(padded_shape, dimA, 4, data_format_); @@ -292,10 +297,9 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { SetStrideA(in_shape, strideAin, 4, data_format_); SetDimA(dy_shape, dimAdy, 4, data_format_); SetStrideA(dy_shape, strideAdy, 4, data_format_); - // filter shape always keep OIHW. - int filterDimA[4] = {SizeToInt(filter_shape[0]), SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]), - SizeToInt(filter_shape[3])}; - + // filter shape relued by format_attr_. In native mode it's OHWI. In transpose mode it's OIHW. + int filterDimA[4]; + SetDimA(filter_shape, filterDimA, 4, format_attr_); CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy), "cudnnSetTensorNdDescriptor failed"); CHECK_CUDNN_RET_WITH_EXCEPT( @@ -325,7 +329,8 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel { cudnnTensorDescriptor_t padded_descriptor_; cudnnConvolutionBwdFilterAlgo_t algo_; std::string pad_mode_; - std::string data_format_ = "NCHW"; + std::string data_format_ = kOpFormat_NCHW; + std::string format_attr_ = kOpFormat_NCHW; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h index d9490d84640..f13f9df8626 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_grad_input_gpu_kernel.h @@ -17,12 +17,13 @@ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_ -#include -#include #include +#include +#include + +#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" -#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh" #include "backend/kernel_compiler/gpu/kernel_constants.h" namespace mindspore { @@ -83,7 +84,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space, workspace_size_, &beta_, padded_descriptor_, padded), "ConvolutionBackwardData failed"); - if (data_format_ == "NHWC") { + if (data_format_ == kOpFormat_NHWC) { CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_, old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast(stream_ptr)); } else { @@ -105,6 +106,10 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { } cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); + auto format_attr = GetAttr(kernel_node, "data_format"); + if (format_attr == kOpFormat_NHWC) { + data_format_ = kOpFormat_NHWC; + } auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); is_null_input_ = CHECK_NULL_INPUT(dy_shape); @@ -116,9 +121,11 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { std::vector input_shape; GetInputShape(kernel_node, &input_shape); - if (data_format_ == "NHWC") { + if (data_format_ == kOpFormat_NHWC) { compute_format_ = CUDNN_TENSOR_NHWC; - ShapeNCHW2NHWC(&input_shape); + if (format_attr == kOpFormat_NCHW) { + ShapeNCHW2NHWC(&input_shape); + } } SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_); Set4DDesc(dy_shape, input_shape, filter_shape); @@ -146,12 +153,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { } int dimA[4]; int strideApadded[4]; - if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") { + if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) { auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_)}; SetDimA(padded_shape, dimA, 4, data_format_); SetStrideA(padded_shape, strideApadded, 4, data_format_); - } else if (data_format_ == "NHWC") { + } else if (data_format_ == kOpFormat_NHWC) { auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_), IntToSize(c_)}; SetDimA(padded_shape, dimA, 4, data_format_); @@ -326,7 +333,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel { cudnnTensorDescriptor_t padded_descriptor_; cudnnConvolutionBwdDataAlgo_t algo_; std::string pad_mode_; - std::string data_format_ = "NCHW"; + std::string data_format_ = kOpFormat_NCHW; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h index ac1b9817484..8262e3423bf 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h @@ -17,8 +17,8 @@ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_ -#include #include +#include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/kernel_constants.h" @@ -131,6 +131,10 @@ class FusedBatchNormExGpuKernel : public GpuKernel { return true; } auto format = AnfAlgo::GetInputFormat(kernel_node, 0); + auto format_attr = GetAttr(kernel_node, "data_format"); + if (format_attr == kOpFormat_NHWC) { + format = kOpFormat_NHWC; + } SetTensorDescriptor(format, shape); InitSizeLists(); return true; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h index b029929b022..2d107136f7d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_gpu_kernel.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_ +#include #include #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" @@ -98,11 +99,14 @@ class FusedBatchNormGpuKernel : public GpuKernel { InitSizeLists(); return true; } - batch_ = SizeToInt(shape[0]); - channel_ = SizeToInt(shape[1]); - height_ = SizeToInt(shape[2]); - width_ = SizeToInt(shape[3]); - + cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW; + auto format = AnfAlgo::GetInputFormat(kernel_node, 0); + auto format_attr = GetAttr(kernel_node, "data_format"); + if (format_attr == kOpFormat_NHWC) { + format = kOpFormat_NHWC; + cudnn_format = CUDNN_TENSOR_NHWC; + } + SetNCHW(shape, &batch_, &channel_, &height_, &width_, format); mode_ = CUDNN_BATCHNORM_SPATIAL; epsilon_ = GetAttr(kernel_node, "epsilon"); // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference @@ -113,15 +117,15 @@ class FusedBatchNormGpuKernel : public GpuKernel { } CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), "Set x desc failed"); CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_), + cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_), "Set y desc failed"); CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), + cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, cudnn_format, CUDNN_DATA_FLOAT, 1, channel_, 1, 1), "Set para desc failed"); InitSizeLists(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h index 20c18ab81d1..49d05b8729d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h @@ -17,12 +17,13 @@ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_ -#include #include +#include +#include "utils/utils.h" + #include "backend/kernel_compiler/gpu/gpu_kernel.h" #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/kernel_constants.h" -#include "utils/utils.h" namespace mindspore { namespace kernel { @@ -140,6 +141,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel { return true; } std::string format = AnfAlgo::GetInputFormat(kernel_node, 0); + auto format_attr = GetAttr(kernel_node, "data_format"); + if (format_attr == kOpFormat_NHWC) { + format = kOpFormat_NHWC; + } beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1; SetTensorDescriptor(format, shape); InitSizeLists(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h index 8e4e506d982..1b673e7c676 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h @@ -78,6 +78,10 @@ class PoolingGpuFwdKernel : public GpuKernel { } cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); + auto format_attr = GetAttr(kernel_node, "data_format"); + if (format_attr == kOpFormat_NHWC) { + data_format_ = kOpFormat_NHWC; + } auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); @@ -200,7 +204,7 @@ class PoolingGpuFwdKernel : public GpuKernel { std::vector stride_; std::string mode_; std::string pad_mode_; - std::string data_format_ = "NCHW"; + std::string data_format_ = kOpFormat_NCHW; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h index b486165c842..d271a23b973 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_grad_gpu_kernel.h @@ -86,6 +86,10 @@ class PoolingGradGpuKernel : public GpuKernel { auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0); + auto format_attr = GetAttr(kernel_node, "data_format"); + if (format_attr == kOpFormat_NHWC) { + data_format_ = kOpFormat_NHWC; + } cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask); if (is_null_input_) { @@ -236,7 +240,7 @@ class PoolingGradGpuKernel : public GpuKernel { std::vector workspace_size_list_; std::string mode_; std::string pad_mode_; - std::string data_format_ = "NCHW"; + std::string data_format_ = kOpFormat_NCHW; cudnnDataType_t cudnn_data_type_; cudnnTensorFormat_t compute_format_; int old_height_; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc index 21d2f5be0e2..466aeb39e6c 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc @@ -46,8 +46,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons MS_EXCEPTION_IF_NULL(tuple_get_item); auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast(tuple_get_item), 0); MS_EXCEPTION_IF_NULL(batch_norm_ex); - - if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) { + auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("data_format"); + MS_EXCEPTION_IF_NULL(format_attr); + auto format = GetValue(format_attr); + if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc index f7d6906f5c2..6e81a262b87 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc @@ -123,8 +123,10 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); - - if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) { + auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format"); + MS_EXCEPTION_IF_NULL(format_attr); + auto format = GetValue(format_attr); + if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc index 6e98a8ae01c..629dd17714e 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc @@ -43,8 +43,10 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A MS_EXCEPTION_IF_NULL(tuple_get_item); auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast(tuple_get_item), 0); MS_EXCEPTION_IF_NULL(batch_norm_ex); - - if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) { + auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("data_format"); + MS_EXCEPTION_IF_NULL(format_attr); + auto format = GetValue(format_attr); + if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc index e4332f16202..e8dc5395914 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc @@ -38,8 +38,10 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); - - if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) { + auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format"); + MS_EXCEPTION_IF_NULL(format_attr); + auto format = GetValue(format_attr); + if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { return nullptr; } diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 56a39a357a5..38091cf110a 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -26,6 +26,8 @@ #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h" #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h" #include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h" +#include "backend/optimizer/gpu/combine_momentum_fusion.h" +#include "backend/optimizer/gpu/combine_cast_fusion.h" #include "backend/optimizer/gpu/cudnn_inplace_fusion.h" #include "backend/optimizer/gpu/insert_format_transform_op.h" #include "backend/optimizer/gpu/replace_momentum_cast_fusion.h" @@ -85,6 +87,10 @@ void GPUSession::Optimize(const std::shared_ptr &kernel_graph) { auto pm = std::make_shared(); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared("cast_all")); + pm->AddPass(std::make_shared("combine_momentum")); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); @@ -98,6 +104,7 @@ void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_gra pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index eda93b54012..2bdeff5ae6a 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -341,6 +341,12 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptrGetAttr("data_format") != nullptr && + GetValue(value->GetAttr("data_format")) == kOpFormat_NHWC) { + format_transform_ = false; + return; + } if (kernel_name == prim::kPrimConv2D->name()) { conv_cnt++; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 02a3279e328..47fdc742396 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -85,6 +85,8 @@ constexpr auto kSplitVOpName = "SplitV"; constexpr auto kSparseApplyAdagradOpName = "SparseApplyAdagrad"; constexpr auto kMomentumOpName = "Momentum"; constexpr auto kApplyMomentumOpName = "ApplyMomentum"; +constexpr auto kCombineMomentumOpName = "CombineMomentum"; +constexpr auto kCombineMomentumWeightOpName = "CombineMomentumWeight"; constexpr auto kApplyAdadeltaOpName = "ApplyAdadelta"; constexpr auto kApplyAdagradOpName = "ApplyAdagrad"; constexpr auto kApplyAdagradDAName = "ApplyAdagradDA"; @@ -374,38 +376,38 @@ const std::set kOpFormatList = { kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM}; const std::set kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN}; -const std::set kOptOperatorSet = { - kMomentumOpName, - kApplyMomentumOpName, - kApplyAdadeltaOpName, - kApplyAdagradOpName, - kApplyAdagradDAName, - kApplyAdamOpName, - kApplyAdaMaxOpName, - kApplyAddSignOpName, - kApplyCenteredRMSPOpName, - kApplyFtrlOpName, - kApplyFtrlV2OpName, - kApplyGradientDescentOpName, - kApplyPowerSignOpName, - kApplyProximalAdagradOpName, - kApplyProximalGradientDescentOpName, - kApplyRMSPropOpName, - kFusedAdamWeightDecayName, - kFusedAdamName, - kFusedSparseAdamName, - kFusedWeightScaleApplyMomentum, - kFusedScaleApplyMomentum, - kApplyCenteredRMSPropOpName, - kFusedSparseFtrlName, - kFusedSparseProximalAdagradName, - kFusedSparseLazyAdamName, - kSparseApplyFtrlName, - kSparseApplyFtrlV2Name, - kSGDName, - kLARSUpdateName, - kPullOpName, -}; +const std::set kOptOperatorSet = {kMomentumOpName, + kApplyMomentumOpName, + kApplyAdadeltaOpName, + kApplyAdagradOpName, + kApplyAdagradDAName, + kApplyAdamOpName, + kApplyAdaMaxOpName, + kApplyAddSignOpName, + kApplyCenteredRMSPOpName, + kApplyFtrlOpName, + kApplyFtrlV2OpName, + kApplyGradientDescentOpName, + kApplyPowerSignOpName, + kApplyProximalAdagradOpName, + kApplyProximalGradientDescentOpName, + kApplyRMSPropOpName, + kFusedAdamWeightDecayName, + kFusedAdamName, + kFusedSparseAdamName, + kFusedWeightScaleApplyMomentum, + kFusedScaleApplyMomentum, + kApplyCenteredRMSPropOpName, + kFusedSparseFtrlName, + kFusedSparseProximalAdagradName, + kFusedSparseLazyAdamName, + kSparseApplyFtrlName, + kSparseApplyFtrlV2Name, + kSGDName, + kLARSUpdateName, + kPullOpName, + kCombineMomentumWeightOpName, + kCombineMomentumOpName}; const std::set kHWSpecialFormatSet = { kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index 793d338da4e..07ea983b3bf 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -45,6 +45,7 @@ class _Conv(Cell): has_bias, weight_init, bias_init, + data_format='NCHW', transposed=False): super(_Conv, self).__init__() self.in_channels = Validator.check_positive_int(in_channels) @@ -54,6 +55,9 @@ class _Conv(Cell): self.pad_mode = pad_mode self.weight_init = weight_init self.bias_init = bias_init + self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") if isinstance(padding, int): Validator.check_non_negative_int(padding, 'padding', self.cls_name) self.padding = padding @@ -89,7 +93,8 @@ class _Conv(Cell): if transposed: shape = [in_channels, out_channels // group, *kernel_size] else: - shape = [out_channels, in_channels // group, *kernel_size] + shape = [out_channels, in_channels // group, *kernel_size] if self.format == "NCHW" else \ + [out_channels, *kernel_size, in_channels // group] self.weight = Parameter(initializer(self.weight_init, shape), name='weight') if Validator.check_bool(has_bias): @@ -181,12 +186,15 @@ class Conv2d(_Conv): bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible Initializer and string are the same as 'weight_init'. Refer to the values of Initializer for more details. Default: 'zeros'. + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. + Default: 'NCHW'. Inputs: - - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})` \ + or `(N, H_{in}, W_{in}, C_{in})`. Outputs: - Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})` or `(N, H_{out}, W_{out}, C_{out})`. Examples: >>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal') @@ -207,7 +215,8 @@ class Conv2d(_Conv): group=1, has_bias=False, weight_init='normal', - bias_init='zeros'): + bias_init='zeros', + data_format='NCHW'): kernel_size = twice(kernel_size) stride = twice(stride) self._dilation = dilation @@ -223,7 +232,8 @@ class Conv2d(_Conv): group, has_bias, weight_init, - bias_init) + bias_init, + data_format) self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1, @@ -231,7 +241,8 @@ class Conv2d(_Conv): pad=self.padding, stride=self.stride, dilation=self.dilation, - group=self.group) + group=self.group, + data_format=self.format) self._init_depthwise_conv2d() self.bias_add = P.BiasAdd() @@ -263,8 +274,8 @@ class Conv2d(_Conv): def extend_repr(self): s = 'input_channels={}, output_channels={}, kernel_size={},' \ 'stride={}, pad_mode={}, padding={}, dilation={}, ' \ - 'group={}, has_bias={},' \ - 'weight_init={}, bias_init={}'.format( + 'group={}, has_bias={}' \ + 'weight_init={}, bias_init={}, format={}'.format( self.in_channels, self.out_channels, self.kernel_size, @@ -275,7 +286,8 @@ class Conv2d(_Conv): self.group, self.has_bias, self.weight_init, - self.bias_init) + self.bias_init, + self.format) return s diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 8555fc2be2c..7f0987b16b8 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -44,14 +44,17 @@ class _BatchNorm(Cell): moving_var_init='ones', use_batch_statistics=None, device_num_each_group=1, - input_dims='2d'): + input_dims='2d', + data_format='NCHW'): super(_BatchNorm, self).__init__() if num_features < 1: raise ValueError("num_features must be at least 1") if momentum < 0 or momentum > 1: raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum)) - + self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") self.use_batch_statistics = use_batch_statistics self.num_features = num_features self.eps = eps @@ -99,7 +102,8 @@ class _BatchNorm(Cell): elif self.is_gpu: self.bn_train = P.FusedBatchNormEx(mode=1, epsilon=self.eps, - momentum=self.momentum) + momentum=self.momentum, + data_format=self.format) else: self.bn_train = P.FusedBatchNorm(mode=1, epsilon=self.eps, @@ -352,6 +356,8 @@ class BatchNorm2d(_BatchNorm): use the mean value and variance value of specified value. If None, the training process will use the mean and variance of current batch data and track the running mean and variance, the evaluation process will use the running mean and variance. Default: None. + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. + Default: 'NCHW'. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -374,7 +380,8 @@ class BatchNorm2d(_BatchNorm): beta_init='zeros', moving_mean_init='zeros', moving_var_init='ones', - use_batch_statistics=None): + use_batch_statistics=None, + data_format='NCHW'): super(BatchNorm2d, self).__init__(num_features, eps, momentum, @@ -384,7 +391,8 @@ class BatchNorm2d(_BatchNorm): moving_mean_init, moving_var_init, use_batch_statistics, - input_dims='2d') + input_dims='2d', + data_format=data_format) def _check_data_dim(self, x): if x.dim() != 4: diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 21f49101a5d..f8b37978321 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -25,10 +25,12 @@ __all__ = ['AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'MaxPool1d'] class _PoolNd(Cell): """N-D AvgPool""" - def __init__(self, kernel_size, stride, pad_mode): + def __init__(self, kernel_size, stride, pad_mode, data_format="NCHW"): super(_PoolNd, self).__init__() self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name) - + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") def _check_int_or_tuple(arg_name, arg_value): validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name) error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \ @@ -93,6 +95,8 @@ class MaxPool2d(_PoolNd): - valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded. + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. + Default: 'NCHW'. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -121,11 +125,12 @@ class MaxPool2d(_PoolNd): [8. 8.]]]] """ - def __init__(self, kernel_size=1, stride=1, pad_mode="valid"): - super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode) + def __init__(self, kernel_size=1, stride=1, pad_mode="valid", data_format="NCHW"): + super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode, data_format) self.max_pool = P.MaxPool(ksize=self.kernel_size, strides=self.stride, - padding=self.pad_mode) + padding=self.pad_mode, + data_format=self.format) self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size, strides=self.stride, padding=self.pad_mode) @@ -252,6 +257,8 @@ class AvgPool2d(_PoolNd): - valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded. + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. + Default: 'NCHW'. Inputs: @@ -284,11 +291,13 @@ class AvgPool2d(_PoolNd): def __init__(self, kernel_size=1, stride=1, - pad_mode="valid"): - super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode) + pad_mode="valid", + data_format="NCHW"): + super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode, data_format) self.avg_pool = P.AvgPool(ksize=self.kernel_size, strides=self.stride, - padding=self.pad_mode) + padding=self.pad_mode, + data_format=self.format) def construct(self, x): return self.avg_pool(x) diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 68cb6e33abc..3853afeaaa5 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -31,7 +31,7 @@ from ... import context @bprop_getters.register(P.BiasAdd) def get_bprop_bias_add(self): """Grad definition for `BiasAdd` operation.""" - bias_grad = SG.BiasAddGrad() + bias_grad = SG.BiasAddGrad(self.data_format) def bprop(x, w, out, dout): return dout, bias_grad(dout) @@ -44,11 +44,11 @@ def get_bprop_conv2d(self): """Grad definition for `Conv2D` operation.""" input_grad = P.Conv2DBackpropInput( self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode, - dilation=self.dilation, stride=self.stride, group=self.group + dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format ) filter_grad = G.Conv2DBackpropFilter( self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode, - dilation=self.dilation, stride=self.stride, group=self.group + dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format ) get_shape = P.Shape() @@ -224,7 +224,8 @@ def get_bprop_max_pool_grad(self): maxpool_grad = G.MaxPoolGrad( ksize=self.ksize, strides=self.strides, - padding=self.padding) + padding=self.padding, + data_format=self.format) def bprop(x, out, dout): dx = maxpool_grad(x, out, dout) @@ -324,7 +325,8 @@ def get_bprop_avg_pool_grad(self): avgpool_grad_gpu = G.AvgPoolGradGpu( ksize=self.ksize, strides=self.strides, - padding=self.padding) + padding=self.padding, + data_format=self.format) def bprop_gpu(x, out, dout): dx = avgpool_grad_gpu(x, out, dout) @@ -574,7 +576,7 @@ def get_bprop_fused_batch_norm(self): @bprop_getters.register(P.FusedBatchNormEx) def get_bprop_fused_batch_norm_ex(self): """Grad definition for `FusedBatchNormEx` operation.""" - input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum) + input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum, self.format) def bprop(x, scale, b, mean, variance, out, dout): saved_mean = out[3] @@ -922,11 +924,11 @@ def get_bprop_conv2d_backprop_input(self): """Grad definition for `Conv2DBackpropInput` operation.""" filter_grad = G.Conv2DBackpropFilter( self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode, - dilation=self.dilation, stride=self.stride, group=self.group + dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format ) input_grad = P.Conv2D( self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad, - dilation=self.dilation, stride=self.stride, group=self.group + dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format ) def bprop(x, w, f_sizes, out, dout): diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 39ca6581ba5..cc61c9c4c77 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -21,6 +21,7 @@ from ..._checkparam import Validator as validator, Rel from .._utils import get_concat_offset from ...common import dtype as mstype from .. import functional as F +from ... import context class AbsGrad(PrimitiveWithInfer): """Computes gradients for abs operation.""" @@ -199,16 +200,23 @@ class BatchNormGrad(PrimitiveWithInfer): return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) -class BiasAddGrad(Primitive): +class BiasAddGrad(PrimitiveWithInfer): """Computes gradients of BiasAdd.""" @prim_attr_register - def __init__(self): + def __init__(self, data_format="NCHW"): self.init_prim_io_names(inputs=['dout'], outputs=['output']) - self.add_prim_attr('data_format', 'NCHW') + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) - def __call__(self, d_output): - raise NotImplementedError + def infer_shape(self, d_output): + channel = d_output[1] if self.format == "NCHW" else d_output[-1] + return (channel,) + + def infer_dtype(self, dout_dtype): + return dout_dtype class KLDivLossGrad(PrimitiveWithInfer): @@ -291,6 +299,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1). dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1). group (int): Splits input into groups. Default: 1. + data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\ + default is 'NCHW'. Returns: Tensor, the gradients of convolution. @@ -306,7 +316,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): mode=1, stride=(1, 1), dilation=(1, 1, 1, 1), - group=1): + group=1, + data_format="NCHW"): """Initialize Convolution""" self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output']) self.out_channel = out_channel @@ -321,7 +332,10 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): self.dilation = dilation self.group = group self.add_prim_attr('groups', group) - self.add_prim_attr('data_format', "NCHW") + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) def __infer__(self, doutput, x, w_size): w_size_v = w_size['value'] @@ -530,10 +544,13 @@ class FusedBatchNormGradEx(PrimitiveWithInfer): """Gradients of FusedBatchNormEx operation.""" @prim_attr_register - def __init__(self, epsilon=0.0, momentum=0.1): + def __init__(self, epsilon=0.0, momentum=0.1, data_format="NCHW"): self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'], outputs=['dx', 'bn_scale', 'bn_bias']) - self.add_prim_attr('data_format', "NCHW") + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape): return (x_shape, scale_shape, scale_shape) @@ -604,16 +621,19 @@ class _PoolGrad(PrimitiveWithInfer): """Gradients of the max/avg pool operation.""" @prim_attr_register - def __init__(self, ksize, strides, padding="VALID"): + def __init__(self, ksize, strides, padding="VALID", data_format="NCHW"): self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) validator.check_value_type('ksize', ksize, [int, tuple], self.name) validator.check_value_type('strides', strides, [int, tuple], self.name) self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name) self.add_prim_attr("padding", self.padding) + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax") if not self.is_maxpoolgradwithargmax: - self.add_prim_attr('data_format', "NCHW") + self.add_prim_attr('data_format', self.format) def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax): validator.check_value_type(arg_name, arg_val, (int, tuple), self.name) @@ -633,10 +653,12 @@ class _PoolGrad(PrimitiveWithInfer): raise error_msg return ret - self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax) + ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax) + self.ksize = ksize if self.format == "NCHW" else [ksize[0], ksize[2], ksize[3], ksize[1]] self.add_prim_attr("ksize", self.ksize) - self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax) + strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax) + self.strides = strides if self.format == "NCHW" else [strides[0], strides[2], strides[3], strides[1]] self.add_prim_attr("strides", self.strides) @@ -679,8 +701,8 @@ class AvgPoolGradGpu(_PoolGrad): """Gradients of the avg pool operation for gpu.""" @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="VALID"): - super(AvgPoolGradGpu, self).__init__(ksize, strides, padding) + def __init__(self, ksize=1, strides=1, padding="VALID", data_format="NCHW"): + super(AvgPoolGradGpu, self).__init__(ksize, strides, padding, data_format) def infer_shape(self, x1_shape, x2_shape, grad_shape): return x1_shape @@ -693,8 +715,8 @@ class MaxPoolGrad(_PoolGrad): """Performs gradients of the max pool operation.""" @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="VALID"): - super(MaxPoolGrad, self).__init__(ksize, strides, padding) + def __init__(self, ksize=1, strides=1, padding="VALID", data_format="NCHW"): + super(MaxPoolGrad, self).__init__(ksize, strides, padding, data_format) def infer_shape(self, x1_shape, x2_shape, grad_shape): return x1_shape @@ -763,7 +785,7 @@ class MaxPoolGradWithArgmax(_PoolGrad): """Computes the gradients of MaxPoolWithArgmax.""" @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="VALID",): + def __init__(self, ksize=1, strides=1, padding="VALID"): self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output']) super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 399e5ce73f4..3c0d460afaf 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -666,6 +666,8 @@ class FusedBatchNormEx(PrimitiveWithInfer): momentum (float): The hyper parameter to compute moving average for running_mean and running_var (e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`). Momentum value must be [0, 1]. Default: 0.9. + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. + Default: "NCHW". Inputs: - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`, @@ -706,20 +708,25 @@ class FusedBatchNormEx(PrimitiveWithInfer): ) @prim_attr_register - def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): + def __init__(self, mode=0, epsilon=1e-5, momentum=0.1, data_format="NCHW"): self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve']) self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name) self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) self._update_parameter = True - self.add_prim_attr('data_format', "NCHW") + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) def infer_shape(self, input_x, scale, bias, mean, variance): + input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2]) validator.check_equal_int(len(scale), 1, "scale rank", self.name) validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) - validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) + validator.check("scale shape[0]", scale[0], "input channel", input_shape_norm[1], Rel.EQ, self.name) validator.check_equal_int(len(mean), 1, "mean rank", self.name) + validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) return (input_x, scale, scale, scale, scale, scale) @@ -868,6 +875,8 @@ class BatchNorm(PrimitiveWithInfer): is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training. If `is_training` is False, they're loaded from checkpoint during inference. Default: False. epsilon (float): A small value added for numerical stability. Default: 1e-5. + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. + Default: "NCHW". Inputs: - **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type. @@ -896,17 +905,21 @@ class BatchNorm(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self, is_training=False, epsilon=1e-5): + def __init__(self, is_training=False, epsilon=1e-5, data_format="NCHW"): validator.check_value_type('is_training', is_training, (bool,), self.name) validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) - self.add_prim_attr('data_format', "NCHW") + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) def infer_shape(self, input_x, scale, bias, mean, variance): + input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2]) validator.check_equal_int(len(scale), 1, "scale rank", self.name) validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) - validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) + validator.check("scale shape[0]", scale[0], "input_x channel", input_shape_norm[1], Rel.EQ, self.name) if not self.is_training: validator.check_equal_int(len(mean), 1, "mean rank", self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) @@ -970,6 +983,7 @@ class Conv2D(PrimitiveWithInfer): stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1. dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1. group (int): Splits input into groups. Default: 1. + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. Default: "NCHW". Returns: Tensor, the value that applied 2D convolution. @@ -998,7 +1012,8 @@ class Conv2D(PrimitiveWithInfer): pad=0, stride=1, dilation=1, - group=1): + group=1, + data_format="NCHW"): """Initialize Conv2D""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) @@ -1021,54 +1036,63 @@ class Conv2D(PrimitiveWithInfer): validator.check_non_negative_int(item, 'pad item', self.name) self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) - self.add_prim_attr('data_format', "NCHW") + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) self.group = validator.check_positive_int(group, 'group', self.name) self.add_prim_attr('offset_a', 0) def infer_shape(self, x_shape, w_shape, b_shape=None): - validator.check_equal_int(len(w_shape), 4, "weight rank", self.name) - validator.check_equal_int(len(x_shape), 4, "x rank", self.name) - validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) - validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) - validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) + x_shape_norm = x_shape if self.format == "NCHW" else (x_shape[0], x_shape[3], x_shape[1], x_shape[2]) + w_shape_norm = w_shape if self.format == "NCHW" else (w_shape[0], w_shape[3], w_shape[1], w_shape[2]) + + validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name) + validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name) + validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \ + Rel.EQ, self.name) + validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name) + + kernel_size_h = w_shape_norm[2] + kernel_size_w = w_shape_norm[3] - kernel_size_h = w_shape[2] - kernel_size_w = w_shape[3] stride_h = self.stride[2] stride_w = self.stride[3] dilation_h = self.dilation[2] dilation_w = self.dilation[3] if self.pad_mode == "valid": - h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h) - w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w) + h_out = math.ceil((x_shape_norm[2] - dilation_h * (kernel_size_h - 1)) / stride_h) + w_out = math.ceil((x_shape_norm[3] - dilation_w * (kernel_size_w - 1)) / stride_w) pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0 elif self.pad_mode == "same": - h_out = math.ceil(x_shape[2] / stride_h) - w_out = math.ceil(x_shape[3] / stride_w) + h_out = math.ceil(x_shape_norm[2] / stride_h) + w_out = math.ceil(x_shape_norm[3] / stride_w) - pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]) + pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape_norm[2]) pad_top = math.floor(pad_needed_h / 2) pad_bottom = pad_needed_h - pad_top - pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]) + pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape_norm[3]) pad_left = math.floor(pad_needed_w / 2) pad_right = pad_needed_w - pad_left elif self.pad_mode == 'pad': pad_top, pad_bottom, pad_left, pad_right = self.padding - h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \ - / stride_h - w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \ - / stride_w + h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \ + * (dilation_h - 1)) / stride_h + w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \ + * (dilation_w - 1)) / stride_w h_out = math.floor(h_out) w_out = math.floor(w_out) self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) out_channel = self.out_channel - out_shape = [x_shape[0], out_channel, h_out, w_out] + out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else\ + [x_shape_norm[0], h_out, w_out, out_channel] _check_shape('output', out_shape, self.name) return out_shape @@ -1226,18 +1250,23 @@ class _Pool(PrimitiveWithInfer): a tuple of two `int` for height and width. Default: 1. padding (str): The optional value for pad mode, is "same" or "valid", not case sensitive. Default: "valid". + data_format (str): The optional value for data format, is 'NHWC' or 'NCHW'. + Default: "NCHW". """ @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="valid"): + def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"): self.init_prim_io_names(inputs=['x'], outputs=['output']) validator.check_value_type('ksize', ksize, [int, tuple], self.name) validator.check_value_type('strides', strides, [int, tuple], self.name) self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name) self.add_prim_attr("padding", self.padding) self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax") + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") if not self.is_maxpoolwithargmax: - self.add_prim_attr('data_format', "NCHW") + self.add_prim_attr('data_format', self.format) self.ksize = _check_positive_int_or_tuple("ksize", ksize, self.name, allow_four=False, ret_four=True) if self.is_maxpoolwithargmax: @@ -1250,8 +1279,9 @@ class _Pool(PrimitiveWithInfer): self.add_prim_attr("strides", self.strides) def infer_shape(self, x_shape): - validator.check_equal_int(len(x_shape), 4, "x rank", self.name) - batch, channel, input_h, input_w = x_shape + x_shape_norm = x_shape if self.format == "NCHW" else [x_shape[0], x_shape[3], x_shape[1], x_shape[2]] + validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name) + batch, channel, input_h, input_w = x_shape_norm if self.is_maxpoolwithargmax: _, kernel_h, kernel_w, _ = self.ksize _, stride_h, stride_w, _ = self.strides @@ -1265,7 +1295,7 @@ class _Pool(PrimitiveWithInfer): elif self.padding == "SAME": out_h = math.ceil(input_h / stride_h) out_w = math.ceil(input_w / stride_w) - out_shape = [batch, channel, out_h, out_w] + out_shape = [batch, channel, out_h, out_w] if self.format == "NCHW" else [batch, out_h, out_w, channel] for shape_value in out_shape: if shape_value <= 0: @@ -1301,6 +1331,8 @@ class MaxPool(_Pool): represent height and width of movement respectively. Default: 1. padding (str): The optional value for pad mode, is "same" or "valid", not case sensitive. Default: "valid". + format (str) : The optional value for data format, is 'NHWC' or 'NCHW'. + Default: 'NCHW'. - same: Adopts the way of completion. The height and width of the output will be the same as the input. The total number of padding will be calculated in horizontal and vertical @@ -1323,8 +1355,8 @@ class MaxPool(_Pool): """ @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="valid"): - super(MaxPool, self).__init__(ksize, strides, padding) + def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"): + super(MaxPool, self).__init__(ksize, strides, padding, data_format) class MaxPoolWithArgmax(_Pool): @@ -1374,8 +1406,8 @@ class MaxPoolWithArgmax(_Pool): >>> output_tensor, argmax = maxpool_arg_op(input_tensor) """ - def __init__(self, ksize=1, strides=1, padding="valid"): - super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding) + def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"): + super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding, data_format) self.is_tbe = context.get_context("device_target") == "Ascend" self.is_gpu = context.get_context("device_target") == "GPU" @@ -1439,6 +1471,8 @@ class AvgPool(_Pool): - valid: Adopts the way of discarding. The possible largest height and width of output will be returned without padding. Extra pixels will be discarded. + data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\ + default is 'NCHW'. Inputs: - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. @@ -1473,14 +1507,14 @@ class AvgPool(_Pool): """ @prim_attr_register - def __init__(self, ksize=1, strides=1, padding="valid"): + def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"): if context.get_context("device_target") == "GPU": self.target = "GPU" elif context.get_context("enable_ge"): self.target = "GE" else: self.target = "OTHER" - super(AvgPool, self).__init__(ksize, strides, padding) + super(AvgPool, self).__init__(ksize, strides, padding, data_format) class Conv2DBackpropInput(PrimitiveWithInfer): @@ -1500,6 +1534,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer): dilation (Union[int. tuple[int]]): Specifies the dilation rate to be used for the dilated convolution. Default: 1. group (int): Splits input into groups. Default: 1. + data_format (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\ + default is 'NCHW'. Returns: Tensor, the gradients of convolution. @@ -1522,7 +1558,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer): mode=1, stride=1, dilation=1, - group=1): + group=1, + data_format="NCHW"): """Initialize Conv2DBackpropInput""" self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name) @@ -1549,7 +1586,10 @@ class Conv2DBackpropInput(PrimitiveWithInfer): self.add_prim_attr('pad_mode', pad_mode) self.mode = validator.check_equal_int(mode, 1, 'mode', self.name) self.group = validator.check_positive_int(group, 'group', self.name) - self.add_prim_attr('data_format', "NCHW") + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) if pad_list: for x in pad_list: validator.check_non_negative_int(x, 'element of pad_list', self.name) @@ -1566,6 +1606,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer): # infer shape dout_shape = doutput['shape'] + dout_shape_norm = dout_shape if self.format == "NCHW" else\ + [dout_shape[0], dout_shape[2], dout_shape[3], dout_shape[1]] kernel_h = self.kernel_size[0] kernel_w = self.kernel_size[1] stride_h = self.stride[0] @@ -1577,11 +1619,11 @@ class Conv2DBackpropInput(PrimitiveWithInfer): if self.pad_list: pad_list = tuple(self.pad_list) elif self.pad_mode == "SAME": - pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2]) + pad_needed_h = max(0, (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2]) pad_top = math.floor(pad_needed_h / 2) pad_bottom = pad_needed_h - pad_top - pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3]) + pad_needed_w = max(0, (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3]) pad_left = math.floor(pad_needed_w / 2) pad_right = pad_needed_w - pad_left pad_list = (pad_top, pad_bottom, pad_left, pad_right) @@ -1606,6 +1648,8 @@ class BiasAdd(PrimitiveWithInfer): Inputs: - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions. - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. + - **data_format** (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\ + default is 'NCHW'. The shape of `bias` must be the same as `input_x` in the second dimension. Outputs: @@ -1619,14 +1663,18 @@ class BiasAdd(PrimitiveWithInfer): """ @prim_attr_register - def __init__(self): + def __init__(self, data_format="NCHW"): self.init_prim_io_names(inputs=['x', 'b'], outputs=['output']) - self.add_prim_attr('data_format', 'NCHW') + self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) + if context.get_context("device_target") != "GPU" and self.format == "NHWC": + raise ValueError("NHWC format only support in GPU target.") + self.add_prim_attr('data_format', self.format) def infer_shape(self, x_shape, b_shape): validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) validator.check_equal_int(len(b_shape), 1, "bias rank", self.name) - validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name) + x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1] + validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_channel, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, b_type):