!7458 GPU kernel support NHWC layout

Merge pull request !7458 from VectorSL/nhwc
This commit is contained in:
mindspore-ci-bot 2020-10-26 14:38:03 +08:00 committed by Gitee
commit b4ce0aa933
22 changed files with 362 additions and 300 deletions

View File

@ -27,61 +27,54 @@ namespace kernel {
template <typename T, typename S>
class CombineMomentumGpuKernel : public GpuKernel {
public:
CombineMomentumGpuKernel() : element_num_(1), num_(0), max_(0), input_num_(6) {}
CombineMomentumGpuKernel() : element_num_(1), num_(0), input_num_(6) {}
~CombineMomentumGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &workspace, void *stream_ptr) override {
const cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);
auto weight_decay = std::make_unique<T *[]>(input_num_ * num_);
auto scale = std::make_unique<T *[]>(input_num_ * num_);
auto variable = std::make_unique<T *[]>(input_num_ * num_);
auto accumulation = std::make_unique<T *[]>(input_num_ * num_);
auto learning_rate = std::make_unique<T *[]>(input_num_ * num_);
auto gradient = std::make_unique<S *[]>(input_num_ * num_);
auto momentum = std::make_unique<T *[]>(input_num_ * num_);
if (input_num_ == 6) {
LaunchCombineMom(inputs, workspace, stream, scale, variable, accumulation, learning_rate, gradient, momentum);
} else {
LaunchCombineMomWeightDecay(inputs, workspace, stream, weight_decay, scale, variable, accumulation, learning_rate,
gradient, momentum);
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
auto stream = reinterpret_cast<cudaStream_t>(stream_ptr);
for (size_t i = 0; i < num_; i++) {
if (input_num_ == 6) {
T *scale = GetDeviceAddress<T>(inputs, i * input_num_);
T *variable = GetDeviceAddress<T>(inputs, i * input_num_ + 1);
T *acc = GetDeviceAddress<T>(inputs, i * input_num_ + 2);
T *lr = GetDeviceAddress<T>(inputs, i * input_num_ + 3);
S *grad = GetDeviceAddress<S>(inputs, i * input_num_ + 4);
T *mom = GetDeviceAddress<T>(inputs, i * input_num_ + 5);
FusedScaleMomentum(elements_[i], scale, variable, acc, lr, grad, mom, stream);
} else {
T *weight_decay = GetDeviceAddress<T>(inputs, i * input_num_);
T *scale = GetDeviceAddress<T>(inputs, i * input_num_ + 1);
T *variable = GetDeviceAddress<T>(inputs, i * input_num_ + 2);
T *acc = GetDeviceAddress<T>(inputs, i * input_num_ + 3);
T *lr = GetDeviceAddress<T>(inputs, i * input_num_ + 4);
S *grad = GetDeviceAddress<S>(inputs, i * input_num_ + 5);
T *mom = GetDeviceAddress<T>(inputs, i * input_num_ + 6);
FusedWeightDecayScaleMomentum(elements_[i], weight_decay, scale, variable, acc, lr, grad, mom, stream);
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
num_ = GetAttr<size_t>(kernel_node, "n");
elements_ = std::make_unique<size_t[]>(num_);
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "CombineMomentum") {
input_num_ = 6;
} else {
input_num_ = 7;
workspace_size_list_.push_back(sizeof(T *) * num_);
}
for (size_t i = 0; i < num_; i++) {
element_num_ = 1;
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i * input_num_ + input_num_ - 4);
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i * input_num_ + input_num_ - 5);
for (size_t j = 0; j < variable_shape.size(); j++) {
element_num_ *= variable_shape[j];
}
if (max_ < element_num_) {
max_ = element_num_;
}
elements_[i] = element_num_;
elements_.push_back(element_num_);
InitSizeLists();
}
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(S *) * num_);
workspace_size_list_.push_back(sizeof(T *) * num_);
workspace_size_list_.push_back(sizeof(size_t) * num_);
return true;
}
@ -100,102 +93,9 @@ class CombineMomentumGpuKernel : public GpuKernel {
}
private:
void LaunchCombineMom(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const cudaStream_t &stream, const std::unique_ptr<T *[]> &scale,
const std::unique_ptr<T *[]> &variable, const std::unique_ptr<T *[]> &accumulation,
const std::unique_ptr<T *[]> &learning_rate, const std::unique_ptr<S *[]> &gradient,
const std::unique_ptr<T *[]> &momentum) {
for (size_t i = 0; i < num_; i++) {
scale[i] = GetDeviceAddress<T>(inputs, i * input_num_);
variable[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 1);
accumulation[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 2);
learning_rate[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 3);
gradient[i] = GetDeviceAddress<S>(inputs, i * input_num_ + 4);
momentum[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 5);
}
T **scale_dev = GetDeviceAddress<T *>(workspace, 0);
T **variable_dev = GetDeviceAddress<T *>(workspace, 1);
T **accumulation_dev = GetDeviceAddress<T *>(workspace, 2);
T **learning_rate_dev = GetDeviceAddress<T *>(workspace, 3);
S **gradient_dev = GetDeviceAddress<S *>(workspace, 4);
T **momentum_dev = GetDeviceAddress<T *>(workspace, 5);
size_t *elements_dev = GetDeviceAddress<size_t>(workspace, 6);
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CombineFusedScaleMomentum(max_, num_, elements_dev, scale_dev, variable_dev, accumulation_dev, learning_rate_dev,
gradient_dev, momentum_dev, stream);
}
void LaunchCombineMomWeightDecay(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const cudaStream_t &stream, const std::unique_ptr<T *[]> &weight_decay,
const std::unique_ptr<T *[]> &scale, const std::unique_ptr<T *[]> &variable,
const std::unique_ptr<T *[]> &accumulation,
const std::unique_ptr<T *[]> &learning_rate, const std::unique_ptr<S *[]> &gradient,
const std::unique_ptr<T *[]> &momentum) {
for (size_t i = 0; i < num_; i++) {
weight_decay[i] = GetDeviceAddress<T>(inputs, i * input_num_);
scale[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 1);
variable[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 2);
accumulation[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 3);
learning_rate[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 4);
gradient[i] = GetDeviceAddress<S>(inputs, i * input_num_ + 5);
momentum[i] = GetDeviceAddress<T>(inputs, i * input_num_ + 6);
}
T **weight_decay_dev = GetDeviceAddress<T *>(workspace, 0);
T **scale_dev = GetDeviceAddress<T *>(workspace, 1);
T **variable_dev = GetDeviceAddress<T *>(workspace, 2);
T **accumulation_dev = GetDeviceAddress<T *>(workspace, 3);
T **learning_rate_dev = GetDeviceAddress<T *>(workspace, 4);
S **gradient_dev = GetDeviceAddress<S *>(workspace, 5);
T **momentum_dev = GetDeviceAddress<T *>(workspace, 6);
size_t *elements_dev = GetDeviceAddress<size_t>(workspace, 7);
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(weight_decay_dev, weight_decay.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(scale_dev, scale.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream), "cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(variable_dev, variable.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(accumulation_dev, accumulation.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(learning_rate_dev, learning_rate.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(gradient_dev, gradient.get(), sizeof(S *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(momentum_dev, momentum.get(), sizeof(T *) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemcpyAsync(elements_dev, elements_.get(), sizeof(size_t) * num_, cudaMemcpyHostToDevice, stream),
"cudaMemCPY failed")
CombineFusedWeightDecayScaleMomentum(max_, num_, elements_dev, weight_decay_dev, scale_dev, variable_dev,
accumulation_dev, learning_rate_dev, gradient_dev, momentum_dev, stream);
}
size_t element_num_;
std::unique_ptr<size_t[]> elements_;
std::vector<size_t> elements_;
size_t num_;
size_t max_;
int input_num_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
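
For reference, a minimal NumPy sketch of the per-parameter update these fused kernels are assumed to implement (standard ApplyMomentum semantics, with the gradient pre-scaled and, in the weight-decay variant, the decay term folded into the gradient first; the refactored Launch above simply loops this over the num_ parameter tuples):

import numpy as np

def fused_scale_momentum(scale, variable, acc, lr, grad, mom):
    # assumed semantics: scale the gradient, then apply momentum
    g = scale * grad
    acc[:] = mom * acc + g
    variable[:] -= lr * acc

def fused_weight_decay_scale_momentum(wd, scale, variable, acc, lr, grad, mom):
    # assumed semantics: fold weight decay into the gradient before scaling
    g = scale * (grad + wd * variable)
    acc[:] = mom * acc + g
    variable[:] -= lr * acc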

View File

@ -17,12 +17,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2DGPUKERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2DGPUKERNEL_H_
#include <vector>
#include <string>
#include <algorithm>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
@ -77,7 +78,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
const float beta = 0;
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
T *padded_addr = GetDeviceAddress<T>(workspace, 1);
if (data_format_ == "NHWC") {
if (data_format_ == kOpFormat_NHWC) {
CalPadNHWC(padded_size_ / sizeof(T), input_addr, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
@ -106,6 +107,10 @@ class Conv2dGpuFwdKernel : public GpuKernel {
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC;
}
auto in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
@ -116,7 +121,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
return true;
}
SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
if (data_format_ == "NHWC") {
if (data_format_ == kOpFormat_NHWC) {
compute_format_ = CUDNN_TENSOR_NHWC;
}
Set4DDesc(in_shape, filter_shape, output_shape);
@ -144,12 +149,12 @@ class Conv2dGpuFwdKernel : public GpuKernel {
}
int dimA[4];
int strideApadded[4];
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} else if (data_format_ == "NHWC") {
} else if (data_format_ == kOpFormat_NHWC) {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)};
SetDimA(padded_shape, dimA, 4, data_format_);
@ -324,7 +329,7 @@ class Conv2dGpuFwdKernel : public GpuKernel {
cudnnConvolutionDescriptor_t conv_desc_;
cudnnTensorDescriptor_t padded_desc_;
std::string pad_mode_;
std::string data_format_ = "NCHW";
std::string data_format_ = kOpFormat_NCHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
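
The SetDimA/SetStrideA pair presumably follows cuDNN's Nd-descriptor convention: extents are reported in NCHW order and the memory layout is expressed through the strides. A hedged Python sketch for the padded descriptor above:

def nd_descriptor(shape, data_format):
    # shape is given in layout order: (N, C, H, W) or (N, H, W, C)
    if data_format == "NHWC":
        n, h, w, c = shape
        dims = [n, c, h, w]                  # extents in NCHW order
        strides = [h * w * c, 1, w * c, c]   # NHWC memory layout
    else:  # NCHW / DefaultFormat
        n, c, h, w = shape
        dims = [n, c, h, w]
        strides = [c * h * w, h * w, w, 1]   # contiguous NCHW
    return dims, strides

nd_descriptor((1, 6, 6, 16), "NHWC")  # dims [1, 16, 6, 6], strides [576, 1, 96, 16]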

View File

@ -17,12 +17,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_FILTER_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <algorithm>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
@ -79,7 +80,7 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
if ((pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) && use_pad_) {
T *padded = GetDeviceAddress<T>(workspace, 1);
if (data_format_ == "NHWC") {
if (data_format_ == kOpFormat_NHWC) {
CalPadNHWC(padded_size_ / sizeof(T), x, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
old_width_ + pad_width_, pad_top_, pad_left_, pad_value_, padded,
reinterpret_cast<cudaStream_t>(stream_ptr));
@ -115,9 +116,13 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
return true;
}
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
format_attr_ = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr_ == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC;
}
std::vector<size_t> filter_shape;
GetFilterShape(kernel_node, &filter_shape);
if (data_format_ == "NHWC") {
if (data_format_ == kOpFormat_NHWC) {
compute_format_ = CUDNN_TENSOR_NHWC;
}
SetNCHW(in_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
@ -145,12 +150,12 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
}
int dimA[4];
int strideApadded[4];
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} else if (data_format_ == "NHWC") {
} else if (data_format_ == kOpFormat_NHWC) {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)};
SetDimA(padded_shape, dimA, 4, data_format_);
@ -292,10 +297,9 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
SetStrideA(in_shape, strideAin, 4, data_format_);
SetDimA(dy_shape, dimAdy, 4, data_format_);
SetStrideA(dy_shape, strideAdy, 4, data_format_);
// filter shape always keep OIHW.
int filterDimA[4] = {SizeToInt(filter_shape[0]), SizeToInt(filter_shape[1]), SizeToInt(filter_shape[2]),
SizeToInt(filter_shape[3])};
// filter shape is determined by format_attr_. In native mode it's OHWI; in transpose mode it's OIHW.
int filterDimA[4];
SetDimA(filter_shape, filterDimA, 4, format_attr_);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(dy_desc_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
@ -325,7 +329,8 @@ class ConvGradFilterGpuBkwKernel : public GpuKernel {
cudnnTensorDescriptor_t padded_descriptor_;
cudnnConvolutionBwdFilterAlgo_t algo_;
std::string pad_mode_;
std::string data_format_ = "NCHW";
std::string data_format_ = kOpFormat_NCHW;
std::string format_attr_ = kOpFormat_NCHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
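
The filterDimA change replaces the hard-coded OIHW assumption with a format-dependent unpacking, so natively NHWC filters (stored OHWI) are described correctly. A sketch, assuming cuDNN still expects the filter extents in KCRS (OIHW) order:

def filter_dims(filter_shape, format_attr):
    # storage order is OIHW for NCHW graphs, OHWI for native NHWC graphs
    if format_attr == "NHWC":
        o, h, w, i = filter_shape
    else:
        o, i, h, w = filter_shape
    return [o, i, h, w]

filter_dims([64, 3, 3, 16], "NHWC")  # -> [64, 16, 3, 3]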

View File

@ -17,12 +17,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CONV2D_GRAD_INPUT_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <algorithm>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
@ -83,7 +84,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
cudnnConvolutionBackwardData(cudnn_handle_, &alpha, w_desc_, w, dy_desc_, dy, conv_desc_, algo_, work_space,
workspace_size_, &beta_, padded_descriptor_, padded),
"ConvolutionBackwardData failed");
if (data_format_ == "NHWC") {
if (data_format_ == kOpFormat_NHWC) {
CalPadGradNHWC(output_size_ / sizeof(T), padded, n_, old_height_, old_width_, c_, old_height_ + pad_height_,
old_width_ + pad_width_, pad_top_, pad_left_, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
@ -105,6 +106,10 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC;
}
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto filter_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
is_null_input_ = CHECK_NULL_INPUT(dy_shape);
@ -116,9 +121,11 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
std::vector<size_t> input_shape;
GetInputShape(kernel_node, &input_shape);
if (data_format_ == "NHWC") {
if (data_format_ == kOpFormat_NHWC) {
compute_format_ = CUDNN_TENSOR_NHWC;
ShapeNCHW2NHWC(&input_shape);
if (format_attr == kOpFormat_NCHW) {
ShapeNCHW2NHWC(&input_shape);
}
}
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
Set4DDesc(dy_shape, input_shape, filter_shape);
@ -146,12 +153,12 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
}
int dimA[4];
int strideApadded[4];
if (data_format_ == "NCHW" || data_format_ == "DefaultFormat") {
if (data_format_ == kOpFormat_NCHW || data_format_ == kOpFormat_DEFAULT) {
auto padded_shape = {IntToSize(n_), IntToSize(c_), IntToSize(old_height_ + pad_height_),
IntToSize(old_width_ + pad_width_)};
SetDimA(padded_shape, dimA, 4, data_format_);
SetStrideA(padded_shape, strideApadded, 4, data_format_);
} else if (data_format_ == "NHWC") {
} else if (data_format_ == kOpFormat_NHWC) {
auto padded_shape = {IntToSize(n_), IntToSize(old_height_ + pad_height_), IntToSize(old_width_ + pad_width_),
IntToSize(c_)};
SetDimA(padded_shape, dimA, 4, data_format_);
@ -326,7 +333,7 @@ class ConvGradInputGpuBkwKernel : public GpuKernel {
cudnnTensorDescriptor_t padded_descriptor_;
cudnnConvolutionBwdDataAlgo_t algo_;
std::string pad_mode_;
std::string data_format_ = "NCHW";
std::string data_format_ = kOpFormat_NCHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
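
The ShapeNCHW2NHWC guard distinguishes the two ways an NHWC kernel can arise: in transpose mode the input-size attribute is still NCHW and must be permuted, while in native NHWC mode it already matches the device layout. A sketch of that decision:

def normalize_input_shape(input_shape, device_format, attr_format):
    if device_format == "NHWC" and attr_format == "NCHW":
        # transpose mode: attribute shape is NCHW, device tensors are NHWC
        n, c, h, w = input_shape
        return [n, h, w, c]
    return input_shape  # native mode: shape already in device layout

normalize_input_shape([1, 16, 6, 6], "NHWC", "NCHW")  # -> [1, 6, 6, 16]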

View File

@ -17,8 +17,8 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_EX_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
@ -131,6 +131,10 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
return true;
}
auto format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) {
format = kOpFormat_NHWC;
}
SetTensorDescriptor(format, shape);
InitSizeLists();
return true;

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GPU_KERNEL_H_
#include <string>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
@ -98,11 +99,14 @@ class FusedBatchNormGpuKernel : public GpuKernel {
InitSizeLists();
return true;
}
batch_ = SizeToInt(shape[0]);
channel_ = SizeToInt(shape[1]);
height_ = SizeToInt(shape[2]);
width_ = SizeToInt(shape[3]);
cudnnTensorFormat_t cudnn_format = CUDNN_TENSOR_NCHW;
auto format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) {
format = kOpFormat_NHWC;
cudnn_format = CUDNN_TENSOR_NHWC;
}
SetNCHW(shape, &batch_, &channel_, &height_, &width_, format);
mode_ = CUDNN_BATCHNORM_SPATIAL;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
// P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
@ -113,15 +117,15 @@ class FusedBatchNormGpuKernel : public GpuKernel {
}
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, batch_, channel_, height_, width_),
cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch_, channel_, height_, width_),
"Set y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, cudnn_format, CUDNN_DATA_FLOAT, 1, channel_, 1, 1),
"Set para desc failed");
InitSizeLists();

View File

@ -17,12 +17,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <vector>
#include "utils/utils.h"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/utils.h"
namespace mindspore {
namespace kernel {
@ -140,6 +141,10 @@ class FusedBatchNormGradExGpuKernel : public GpuKernel {
return true;
}
std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) {
format = kOpFormat_NHWC;
}
beta_data_diff_ = GetAttrWithDefault(kernel_node, "inplace_algo", std::string("cover")) == "cover" ? 0 : 1;
SetTensorDescriptor(format, shape);
InitSizeLists();

View File

@ -78,6 +78,10 @@ class PoolingGpuFwdKernel : public GpuKernel {
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
@ -200,7 +204,7 @@ class PoolingGpuFwdKernel : public GpuKernel {
std::vector<int> stride_;
std::string mode_;
std::string pad_mode_;
std::string data_format_ = "NCHW";
std::string data_format_ = kOpFormat_NCHW;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

View File

@ -86,6 +86,10 @@ class PoolingGradGpuKernel : public GpuKernel {
auto dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
auto format_attr = GetAttr<std::string>(kernel_node, "data_format");
if (format_attr == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC;
}
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_mask);
if (is_null_input_) {
@ -236,7 +240,7 @@ class PoolingGradGpuKernel : public GpuKernel {
std::vector<size_t> workspace_size_list_;
std::string mode_;
std::string pad_mode_;
std::string data_format_ = "NCHW";
std::string data_format_ = kOpFormat_NCHW;
cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t compute_format_;
int old_height_;

View File

@ -46,8 +46,10 @@ const AnfNodePtr BatchNormAddReluFusion::Process(const FuncGraphPtr &graph, cons
MS_EXCEPTION_IF_NULL(tuple_get_item);
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
MS_EXCEPTION_IF_NULL(batch_norm_ex);
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}
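
The relaxed guard admits a node as NHWC if either its device input format or its data_format attribute says so; the same predicate recurs in the three sibling fusion passes below. A sketch:

def is_nhwc_batch_norm(input_format, data_format_attr):
    # proceed with the fusion only when one of the two signals is NHWC
    return input_format == "NHWC" or data_format_attr == "NHWC"

is_nhwc_batch_norm("DefaultFormat", "NHWC")  # -> True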

View File

@ -123,8 +123,10 @@ const AnfNodePtr BatchNormAddReluGradFusion::Process(const FuncGraphPtr &graph,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}

View File

@ -43,8 +43,10 @@ const AnfNodePtr BatchNormReluFusion::Process(const FuncGraphPtr &graph, const A
MS_EXCEPTION_IF_NULL(tuple_get_item);
auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(tuple_get_item), 0);
MS_EXCEPTION_IF_NULL(batch_norm_ex);
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC) {
auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}

View File

@ -38,8 +38,10 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC) {
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr);
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr;
}

View File

@ -26,6 +26,8 @@
#include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h"
#include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h"
#include "backend/optimizer/gpu/combine_momentum_fusion.h"
#include "backend/optimizer/gpu/combine_cast_fusion.h"
#include "backend/optimizer/gpu/cudnn_inplace_fusion.h"
#include "backend/optimizer/gpu/insert_format_transform_op.h"
#include "backend/optimizer/gpu/replace_momentum_cast_fusion.h"
@ -85,6 +87,10 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::AdamWeightDecayFusion>());
pm->AddPass(std::make_shared<opt::AdamFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumWeightDecayScaleFusion>());
pm->AddPass(std::make_shared<opt::ApplyMomentumScaleFusion>());
pm->AddPass(std::make_shared<opt::CastAllFusion>("cast_all"));
pm->AddPass(std::make_shared<opt::CombineMomentumFusion>("combine_momentum"));
pm->AddPass(std::make_shared<opt::ReplaceMomentumCastFusion>());
pm->AddPass(std::make_shared<opt::ReplaceAddNFusion>());
optimizer->AddPassManager(pm);
@ -98,6 +104,7 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
pm->AddPass(std::make_shared<opt::BatchNormReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormReluGradFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluFusion>());
pm->AddPass(std::make_shared<opt::BatchNormAddReluGradFusion>());
pm->AddPass(std::make_shared<opt::InsertFormatTransformOp>());
pm->AddPass(std::make_shared<opt::RemoveFormatTransformPair>());
pm->AddPass(std::make_shared<opt::RemoveRedundantFormatTransform>());

View File

@ -341,6 +341,12 @@ void FormatTransformChecker::CheckSupportFormatTransform(const std::shared_ptr<s
format_transform_ = false;
return;
}
auto value = AnfAlgo::GetCNodePrimitive(kernel);
if (value != nullptr && value->GetAttr("data_format") != nullptr &&
GetValue<std::string>(value->GetAttr("data_format")) == kOpFormat_NHWC) {
format_transform_ = false;
return;
}
if (kernel_name == prim::kPrimConv2D->name()) {
conv_cnt++;
}
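
A sketch of the added guard in CheckSupportFormatTransform: once any kernel carries an explicit NHWC data_format attribute, the pass that auto-inserts NCHW/NHWC transposes is disabled for the whole graph:

def supports_format_transform(kernel_attrs):
    # kernel_attrs: list of per-kernel attribute dicts (hypothetical shape)
    for attrs in kernel_attrs:
        if attrs.get("data_format") == "NHWC":
            return False  # graph is already natively NHWC
    return True  # remaining conv/bn counting heuristics elided

supports_format_transform([{"data_format": "NHWC"}])  # -> False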

View File

@ -85,6 +85,8 @@ constexpr auto kSplitVOpName = "SplitV";
constexpr auto kSparseApplyAdagradOpName = "SparseApplyAdagrad";
constexpr auto kMomentumOpName = "Momentum";
constexpr auto kApplyMomentumOpName = "ApplyMomentum";
constexpr auto kCombineMomentumOpName = "CombineMomentum";
constexpr auto kCombineMomentumWeightOpName = "CombineMomentumWeight";
constexpr auto kApplyAdadeltaOpName = "ApplyAdadelta";
constexpr auto kApplyAdagradOpName = "ApplyAdagrad";
constexpr auto kApplyAdagradDAName = "ApplyAdagradDA";
@ -374,38 +376,38 @@ const std::set<std::string> kOpFormatList = {
kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC, kOpFormat_FRACTAL_ZN_LSTM};
const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
const std::set<std::string> kOptOperatorSet = {
kMomentumOpName,
kApplyMomentumOpName,
kApplyAdadeltaOpName,
kApplyAdagradOpName,
kApplyAdagradDAName,
kApplyAdamOpName,
kApplyAdaMaxOpName,
kApplyAddSignOpName,
kApplyCenteredRMSPOpName,
kApplyFtrlOpName,
kApplyFtrlV2OpName,
kApplyGradientDescentOpName,
kApplyPowerSignOpName,
kApplyProximalAdagradOpName,
kApplyProximalGradientDescentOpName,
kApplyRMSPropOpName,
kFusedAdamWeightDecayName,
kFusedAdamName,
kFusedSparseAdamName,
kFusedWeightScaleApplyMomentum,
kFusedScaleApplyMomentum,
kApplyCenteredRMSPropOpName,
kFusedSparseFtrlName,
kFusedSparseProximalAdagradName,
kFusedSparseLazyAdamName,
kSparseApplyFtrlName,
kSparseApplyFtrlV2Name,
kSGDName,
kLARSUpdateName,
kPullOpName,
};
const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
kApplyMomentumOpName,
kApplyAdadeltaOpName,
kApplyAdagradOpName,
kApplyAdagradDAName,
kApplyAdamOpName,
kApplyAdaMaxOpName,
kApplyAddSignOpName,
kApplyCenteredRMSPOpName,
kApplyFtrlOpName,
kApplyFtrlV2OpName,
kApplyGradientDescentOpName,
kApplyPowerSignOpName,
kApplyProximalAdagradOpName,
kApplyProximalGradientDescentOpName,
kApplyRMSPropOpName,
kFusedAdamWeightDecayName,
kFusedAdamName,
kFusedSparseAdamName,
kFusedWeightScaleApplyMomentum,
kFusedScaleApplyMomentum,
kApplyCenteredRMSPropOpName,
kFusedSparseFtrlName,
kFusedSparseProximalAdagradName,
kFusedSparseLazyAdamName,
kSparseApplyFtrlName,
kSparseApplyFtrlV2Name,
kSGDName,
kLARSUpdateName,
kPullOpName,
kCombineMomentumWeightOpName,
kCombineMomentumOpName};
const std::set<std::string> kHWSpecialFormatSet = {
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ,

View File

@ -45,6 +45,7 @@ class _Conv(Cell):
has_bias,
weight_init,
bias_init,
data_format='NCHW',
transposed=False):
super(_Conv, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
@ -54,6 +55,9 @@ class _Conv(Cell):
self.pad_mode = pad_mode
self.weight_init = weight_init
self.bias_init = bias_init
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
if isinstance(padding, int):
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
self.padding = padding
@ -89,7 +93,8 @@ class _Conv(Cell):
if transposed:
shape = [in_channels, out_channels // group, *kernel_size]
else:
shape = [out_channels, in_channels // group, *kernel_size]
shape = [out_channels, in_channels // group, *kernel_size] if self.format == "NCHW" else \
[out_channels, *kernel_size, in_channels // group]
self.weight = Parameter(initializer(self.weight_init, shape), name='weight')
if Validator.check_bool(has_bias):
@ -181,12 +186,15 @@ class Conv2d(_Conv):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
data_format (str): The optional value for data format, either 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})` \
or :math:`(N, H_{in}, W_{in}, C_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})` or :math:`(N, H_{out}, W_{out}, C_{out})`.
Examples:
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
@ -207,7 +215,8 @@ class Conv2d(_Conv):
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros'):
bias_init='zeros',
data_format='NCHW'):
kernel_size = twice(kernel_size)
stride = twice(stride)
self._dilation = dilation
@ -223,7 +232,8 @@ class Conv2d(_Conv):
group,
has_bias,
weight_init,
bias_init)
bias_init,
data_format)
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
@ -231,7 +241,8 @@ class Conv2d(_Conv):
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
group=self.group,
data_format=self.format)
self._init_depthwise_conv2d()
self.bias_add = P.BiasAdd()
@ -263,8 +274,8 @@ class Conv2d(_Conv):
def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(
'group={}, has_bias={}, ' \
'weight_init={}, bias_init={}, format={}'.format(
self.in_channels,
self.out_channels,
self.kernel_size,
@ -275,7 +286,8 @@ class Conv2d(_Conv):
self.group,
self.has_bias,
self.weight_init,
self.bias_init)
self.bias_init,
self.format)
return s
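
A hypothetical NHWC usage of the new argument, in the docstring's example style (GPU target only; output spatial sizes follow the default 'same' pad mode):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal', data_format='NHWC')
>>> input = Tensor(np.ones([1, 1024, 640, 120]), mindspore.float32)
>>> net(input).shape
(1, 1024, 640, 240)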

View File

@ -44,14 +44,17 @@ class _BatchNorm(Cell):
moving_var_init='ones',
use_batch_statistics=None,
device_num_each_group=1,
input_dims='2d'):
input_dims='2d',
data_format='NCHW'):
super(_BatchNorm, self).__init__()
if num_features < 1:
raise ValueError("num_features must be at least 1")
if momentum < 0 or momentum > 1:
raise ValueError("momentum should be a number in range [0, 1], but got {}".format(momentum))
self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.use_batch_statistics = use_batch_statistics
self.num_features = num_features
self.eps = eps
@ -99,7 +102,8 @@ class _BatchNorm(Cell):
elif self.is_gpu:
self.bn_train = P.FusedBatchNormEx(mode=1,
epsilon=self.eps,
momentum=self.momentum)
momentum=self.momentum,
data_format=self.format)
else:
self.bn_train = P.FusedBatchNorm(mode=1,
epsilon=self.eps,
@ -352,6 +356,8 @@ class BatchNorm2d(_BatchNorm):
use the mean value and variance value of specified value. If None, the training process will use the mean
and variance of current batch data and track the running mean and variance, the evaluation process will use
the running mean and variance. Default: None.
data_format (str): The optional value for data format, either 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -374,7 +380,8 @@ class BatchNorm2d(_BatchNorm):
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=None):
use_batch_statistics=None,
data_format='NCHW'):
super(BatchNorm2d, self).__init__(num_features,
eps,
momentum,
@ -384,7 +391,8 @@ class BatchNorm2d(_BatchNorm):
moving_mean_init,
moving_var_init,
use_batch_statistics,
input_dims='2d')
input_dims='2d',
data_format=data_format)
def _check_data_dim(self, x):
if x.dim() != 4:
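
A hypothetical NHWC usage sketch (GPU target only); the normalized output keeps the input layout:

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> bn = nn.BatchNorm2d(num_features=3, data_format='NHWC')
>>> x = Tensor(np.ones([1, 32, 32, 3]), mindspore.float32)
>>> bn(x).shape
(1, 32, 32, 3)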

View File

@ -25,10 +25,12 @@ __all__ = ['AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'MaxPool1d']
class _PoolNd(Cell):
"""N-D AvgPool"""
def __init__(self, kernel_size, stride, pad_mode):
def __init__(self, kernel_size, stride, pad_mode, data_format="NCHW"):
super(_PoolNd, self).__init__()
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
def _check_int_or_tuple(arg_name, arg_value):
validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name)
error_msg = f'For \'{self.cls_name}\' the {arg_name} should be a positive int number or ' \
@ -93,6 +95,8 @@ class MaxPool2d(_PoolNd):
- valid: Adopts the way of discarding. The possible largest height and width of output
will be returned without padding. Extra pixels will be discarded.
data_format (str): The optional value for data format, either 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -121,11 +125,12 @@ class MaxPool2d(_PoolNd):
[8. 8.]]]]
"""
def __init__(self, kernel_size=1, stride=1, pad_mode="valid"):
super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode)
def __init__(self, kernel_size=1, stride=1, pad_mode="valid", data_format="NCHW"):
super(MaxPool2d, self).__init__(kernel_size, stride, pad_mode, data_format)
self.max_pool = P.MaxPool(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
padding=self.pad_mode,
data_format=self.format)
self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
@ -252,6 +257,8 @@ class AvgPool2d(_PoolNd):
- valid: Adopts the way of discarding. The possible largest height and width of output
will be returned without padding. Extra pixels will be discarded.
data_format (str): The optional value for data format, either 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Inputs:
@ -284,11 +291,13 @@ class AvgPool2d(_PoolNd):
def __init__(self,
kernel_size=1,
stride=1,
pad_mode="valid"):
super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode)
pad_mode="valid",
data_format="NCHW"):
super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode, data_format)
self.avg_pool = P.AvgPool(ksize=self.kernel_size,
strides=self.stride,
padding=self.pad_mode)
padding=self.pad_mode,
data_format=self.format)
def construct(self, x):
return self.avg_pool(x)
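
A hypothetical NHWC pooling sketch (GPU target only; output spatial dims follow the 'valid' default):

>>> import numpy as np
>>> import mindspore
>>> from mindspore import Tensor, nn
>>> pool = nn.MaxPool2d(kernel_size=3, stride=1, data_format='NHWC')
>>> x = Tensor(np.ones([1, 4, 4, 2]), mindspore.float32)
>>> pool(x).shape
(1, 2, 2, 2)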

View File

@ -31,7 +31,7 @@ from ... import context
@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
"""Grad definition for `BiasAdd` operation."""
bias_grad = SG.BiasAddGrad()
bias_grad = SG.BiasAddGrad(self.data_format)
def bprop(x, w, out, dout):
return dout, bias_grad(dout)
@ -44,11 +44,11 @@ def get_bprop_conv2d(self):
"""Grad definition for `Conv2D` operation."""
input_grad = P.Conv2DBackpropInput(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
)
filter_grad = G.Conv2DBackpropFilter(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
)
get_shape = P.Shape()
@ -224,7 +224,8 @@ def get_bprop_max_pool_grad(self):
maxpool_grad = G.MaxPoolGrad(
ksize=self.ksize,
strides=self.strides,
padding=self.padding)
padding=self.padding,
data_format=self.format)
def bprop(x, out, dout):
dx = maxpool_grad(x, out, dout)
@ -324,7 +325,8 @@ def get_bprop_avg_pool_grad(self):
avgpool_grad_gpu = G.AvgPoolGradGpu(
ksize=self.ksize,
strides=self.strides,
padding=self.padding)
padding=self.padding,
data_format=self.format)
def bprop_gpu(x, out, dout):
dx = avgpool_grad_gpu(x, out, dout)
@ -574,7 +576,7 @@ def get_bprop_fused_batch_norm(self):
@bprop_getters.register(P.FusedBatchNormEx)
def get_bprop_fused_batch_norm_ex(self):
"""Grad definition for `FusedBatchNormEx` operation."""
input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum)
input_grad = G.FusedBatchNormGradEx(self.epsilon, self.momentum, self.format)
def bprop(x, scale, b, mean, variance, out, dout):
saved_mean = out[3]
@ -922,11 +924,11 @@ def get_bprop_conv2d_backprop_input(self):
"""Grad definition for `Conv2DBackpropInput` operation."""
filter_grad = G.Conv2DBackpropFilter(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, self.pad_list, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
)
input_grad = P.Conv2D(
self.out_channel, self.kernel_size, pad_mode=self.pad_mode.lower(), pad=self.pad,
dilation=self.dilation, stride=self.stride, group=self.group
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
)
def bprop(x, w, f_sizes, out, dout):

View File

@ -21,6 +21,7 @@ from ..._checkparam import Validator as validator, Rel
from .._utils import get_concat_offset
from ...common import dtype as mstype
from .. import functional as F
from ... import context
class AbsGrad(PrimitiveWithInfer):
"""Computes gradients for abs operation."""
@ -199,16 +200,23 @@ class BatchNormGrad(PrimitiveWithInfer):
return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)
class BiasAddGrad(Primitive):
class BiasAddGrad(PrimitiveWithInfer):
"""Computes gradients of BiasAdd."""
@prim_attr_register
def __init__(self):
def __init__(self, data_format="NCHW"):
self.init_prim_io_names(inputs=['dout'], outputs=['output'])
self.add_prim_attr('data_format', 'NCHW')
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
def __call__(self, d_output):
raise NotImplementedError
def infer_shape(self, d_output):
channel = d_output[1] if self.format == "NCHW" else d_output[-1]
return (channel,)
def infer_dtype(self, dout_dtype):
return dout_dtype
class KLDivLossGrad(PrimitiveWithInfer):
@ -291,6 +299,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
stride (tuple): The stride to be applied to the convolution filter. Default: (1, 1).
dilation (tuple): Specifies the dilation rate to be used for the dilated convolution. Default: (1, 1, 1, 1).
group (int): Splits input into groups. Default: 1.
data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Returns:
Tensor, the gradients of convolution.
@ -306,7 +316,8 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
mode=1,
stride=(1, 1),
dilation=(1, 1, 1, 1),
group=1):
group=1,
data_format="NCHW"):
"""Initialize Convolution"""
self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
self.out_channel = out_channel
@ -321,7 +332,10 @@ class Conv2DBackpropFilter(PrimitiveWithInfer):
self.dilation = dilation
self.group = group
self.add_prim_attr('groups', group)
self.add_prim_attr('data_format', "NCHW")
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
def __infer__(self, doutput, x, w_size):
w_size_v = w_size['value']
@ -530,10 +544,13 @@ class FusedBatchNormGradEx(PrimitiveWithInfer):
"""Gradients of FusedBatchNormEx operation."""
@prim_attr_register
def __init__(self, epsilon=0.0, momentum=0.1):
def __init__(self, epsilon=0.0, momentum=0.1, data_format="NCHW"):
self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance', 'reserve'],
outputs=['dx', 'bn_scale', 'bn_bias'])
self.add_prim_attr('data_format', "NCHW")
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, save_mean_shape, save_variance_shape, reserve_shape):
return (x_shape, scale_shape, scale_shape)
@ -604,16 +621,19 @@ class _PoolGrad(PrimitiveWithInfer):
"""Gradients of the max/avg pool operation."""
@prim_attr_register
def __init__(self, ksize, strides, padding="VALID"):
def __init__(self, ksize, strides, padding="VALID", data_format="NCHW"):
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
if not self.is_maxpoolgradwithargmax:
self.add_prim_attr('data_format', "NCHW")
self.add_prim_attr('data_format', self.format)
def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax):
validator.check_value_type(arg_name, arg_val, (int, tuple), self.name)
@ -633,10 +653,12 @@ class _PoolGrad(PrimitiveWithInfer):
raise error_msg
return ret
self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax)
ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax)
self.ksize = ksize if self.format == "NCHW" else [ksize[0], ksize[2], ksize[3], ksize[1]]
self.add_prim_attr("ksize", self.ksize)
self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax)
self.strides = strides if self.format == "NCHW" else [strides[0], strides[2], strides[3], strides[1]]
self.add_prim_attr("strides", self.strides)
@ -679,8 +701,8 @@ class AvgPoolGradGpu(_PoolGrad):
"""Gradients of the avg pool operation for gpu."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"):
super(AvgPoolGradGpu, self).__init__(ksize, strides, padding)
def __init__(self, ksize=1, strides=1, padding="VALID", data_format="NCHW"):
super(AvgPoolGradGpu, self).__init__(ksize, strides, padding, data_format)
def infer_shape(self, x1_shape, x2_shape, grad_shape):
return x1_shape
@ -693,8 +715,8 @@ class MaxPoolGrad(_PoolGrad):
"""Performs gradients of the max pool operation."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"):
super(MaxPoolGrad, self).__init__(ksize, strides, padding)
def __init__(self, ksize=1, strides=1, padding="VALID", data_format="NCHW"):
super(MaxPoolGrad, self).__init__(ksize, strides, padding, data_format)
def infer_shape(self, x1_shape, x2_shape, grad_shape):
return x1_shape
@ -763,7 +785,7 @@ class MaxPoolGradWithArgmax(_PoolGrad):
"""Computes the gradients of MaxPoolWithArgmax."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID",):
def __init__(self, ksize=1, strides=1, padding="VALID"):
self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding)

View File

@ -666,6 +666,8 @@ class FusedBatchNormEx(PrimitiveWithInfer):
momentum (float): The hyper parameter to compute moving average for running_mean and running_var
(e.g. :math:`new\_running\_mean = momentum * running\_mean + (1 - momentum) * current\_mean`).
Momentum value must be in [0, 1]. Default: 0.9.
data_format (str): The optional value for data format, either 'NHWC' or 'NCHW'.
Default: "NCHW".
Inputs:
- **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`,
@ -706,20 +708,25 @@ class FusedBatchNormEx(PrimitiveWithInfer):
)
@prim_attr_register
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'],
outputs=['y', 'save_scale', 'save_bias', 'save_mean', 'save_inv_variance', 'reserve'])
self.mode = validator.check_int(mode, [0, 1], Rel.IN, 'mode', self.name)
self.epsilon = validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.momentum = validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
self._update_parameter = True
self.add_prim_attr('data_format', "NCHW")
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
def infer_shape(self, input_x, scale, bias, mean, variance):
input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2])
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input channel", input_shape_norm[1], Rel.EQ, self.name)
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
return (input_x, scale, scale, scale, scale, scale)
@ -868,6 +875,8 @@ class BatchNorm(PrimitiveWithInfer):
is_training (bool): If `is_training` is True, `mean` and `variance` are computed during training.
If `is_training` is False, they're loaded from checkpoint during inference. Default: False.
epsilon (float): A small value added for numerical stability. Default: 1e-5.
data_format (str): The optional value for data format, either 'NHWC' or 'NCHW'.
Default: "NCHW".
Inputs:
- **input_x** (Tensor) - Tensor of shape :math:`(N, C)`, with float16 or float32 data type.
@ -896,17 +905,21 @@ class BatchNorm(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5):
def __init__(self, is_training=False, epsilon=1e-5, data_format="NCHW"):
validator.check_value_type('is_training', is_training, (bool,), self.name)
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
self.add_prim_attr('data_format', "NCHW")
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
def infer_shape(self, input_x, scale, bias, mean, variance):
input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2])
validator.check_equal_int(len(scale), 1, "scale rank", self.name)
validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name)
validator.check("scale shape[0]", scale[0], "input_x channel", input_shape_norm[1], Rel.EQ, self.name)
if not self.is_training:
validator.check_equal_int(len(mean), 1, "mean rank", self.name)
validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
@ -970,6 +983,7 @@ class Conv2D(PrimitiveWithInfer):
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
group (int): Splits input into groups. Default: 1.
data_format (str): The optional value for data format, either 'NHWC' or 'NCHW'. Default: "NCHW".
Returns:
Tensor, the value that applied 2D convolution.
@ -998,7 +1012,8 @@ class Conv2D(PrimitiveWithInfer):
pad=0,
stride=1,
dilation=1,
group=1):
group=1,
data_format="NCHW"):
"""Initialize Conv2D"""
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
@ -1021,54 +1036,63 @@ class Conv2D(PrimitiveWithInfer):
validator.check_non_negative_int(item, 'pad item', self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('data_format', "NCHW")
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
self.add_prim_attr('data_format', self.format)
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape, b_shape=None):
validator.check_equal_int(len(w_shape), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape), 4, "x rank", self.name)
validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
x_shape_norm = x_shape if self.format == "NCHW" else (x_shape[0], x_shape[3], x_shape[1], x_shape[2])
w_shape_norm = w_shape if self.format == "NCHW" else (w_shape[0], w_shape[3], w_shape[1], w_shape[2])
validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name)
validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name)
validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \
Rel.EQ, self.name)
validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name)
validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name)
kernel_size_h = w_shape_norm[2]
kernel_size_w = w_shape_norm[3]
kernel_size_h = w_shape[2]
kernel_size_w = w_shape[3]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
if self.pad_mode == "valid":
h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
h_out = math.ceil((x_shape_norm[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
w_out = math.ceil((x_shape_norm[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
elif self.pad_mode == "same":
h_out = math.ceil(x_shape[2] / stride_h)
w_out = math.ceil(x_shape[3] / stride_w)
h_out = math.ceil(x_shape_norm[2] / stride_h)
w_out = math.ceil(x_shape_norm[3] / stride_w)
pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape_norm[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape_norm[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
elif self.pad_mode == 'pad':
pad_top, pad_bottom, pad_left, pad_right = self.padding
h_out = 1 + (x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \
* (dilation_h - 1)) / stride_h
w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \
* (dilation_w - 1)) / stride_w
h_out = math.floor(h_out)
w_out = math.floor(w_out)
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
out_channel = self.out_channel
out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else \
[x_shape_norm[0], h_out, w_out, out_channel]
_check_shape('output', out_shape, self.name)
return out_shape
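To make the normalization trick above concrete, here is a standalone sketch of the "valid" branch with dilation 1 (a hypothetical helper, not part of this patch): NHWC activations and OHWI weights are permuted into NCHW/OIHW order, the existing NCHW math runs unchanged, and the result is permuted back for NHWC callers.

import math

def conv2d_valid_out_shape(x_shape, w_shape, stride, data_format="NCHW"):
    # Normalize NHWC (N, H, W, C) activations and (O, H, W, I) weights to NCHW / OIHW.
    x = x_shape if data_format == "NCHW" else (x_shape[0], x_shape[3], x_shape[1], x_shape[2])
    w = w_shape if data_format == "NCHW" else (w_shape[0], w_shape[3], w_shape[1], w_shape[2])
    h_out = math.ceil((x[2] - (w[2] - 1)) / stride)
    w_out = math.ceil((x[3] - (w[3] - 1)) / stride)
    out = (x[0], w[0], h_out, w_out)
    # Permute the NCHW result back for NHWC callers.
    return out if data_format == "NCHW" else (out[0], out[2], out[3], out[1])

conv2d_valid_out_shape((10, 32, 32, 64), (32, 3, 3, 64), 1, "NHWC")  # -> (10, 30, 30, 32)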
@ -1226,18 +1250,23 @@ class _Pool(PrimitiveWithInfer):
a tuple of two `int` for height and width. Default: 1.
padding (str): The optional value for pad mode: "same" or "valid", not case sensitive.
Default: "valid".
data_format (str): The optional value for data format: 'NHWC' or 'NCHW'.
Default: "NCHW".
"""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
validator.check_value_type('ksize', ksize, [int, tuple], self.name)
validator.check_value_type('strides', strides, [int, tuple], self.name)
self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
self.add_prim_attr("padding", self.padding)
self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax")
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format is only supported on GPU targets.")
if not self.is_maxpoolwithargmax:
self.add_prim_attr('data_format', self.format)
self.ksize = _check_positive_int_or_tuple("ksize", ksize, self.name, allow_four=False, ret_four=True)
if self.is_maxpoolwithargmax:
@ -1250,8 +1279,9 @@ class _Pool(PrimitiveWithInfer):
self.add_prim_attr("strides", self.strides)
def infer_shape(self, x_shape):
x_shape_norm = x_shape if self.format == "NCHW" else [x_shape[0], x_shape[3], x_shape[1], x_shape[2]]
validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name)
batch, channel, input_h, input_w = x_shape_norm
if self.is_maxpoolwithargmax:
_, kernel_h, kernel_w, _ = self.ksize
_, stride_h, stride_w, _ = self.strides
@ -1265,7 +1295,7 @@ class _Pool(PrimitiveWithInfer):
elif self.padding == "SAME":
out_h = math.ceil(input_h / stride_h)
out_w = math.ceil(input_w / stride_w)
out_shape = [batch, channel, out_h, out_w] if self.format == "NCHW" else [batch, out_h, out_w, channel]
for shape_value in out_shape:
if shape_value <= 0:
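The same normalize-compute-denormalize pattern can be checked in isolation; a minimal sketch (hypothetical helper), under the assumption that the "VALID" branch uses the same ceil((size - (k - 1)) / stride) formula as Conv2D (the hunk above elides it):

import math

def pool_out_shape(x_shape, k, s, padding="VALID", data_format="NCHW"):
    # Normalize NHWC input to NCHW order before the shape math.
    n, c, h, w = x_shape if data_format == "NCHW" else [x_shape[0], x_shape[3], x_shape[1], x_shape[2]]
    if padding == "VALID":
        out_h, out_w = math.ceil((h - (k - 1)) / s), math.ceil((w - (k - 1)) / s)
    else:  # "SAME"
        out_h, out_w = math.ceil(h / s), math.ceil(w / s)
    return [n, c, out_h, out_w] if data_format == "NCHW" else [n, out_h, out_w, c]

pool_out_shape([10, 32, 32, 3], 2, 2, "VALID", "NHWC")  # -> [10, 16, 16, 3]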
@ -1301,6 +1331,8 @@ class MaxPool(_Pool):
represent height and width of movement respectively. Default: 1.
padding (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
Default: "valid".
data_format (str): The optional value for data format: 'NHWC' or 'NCHW'.
Default: 'NCHW'.
- same: Adopts the way of completion. The height and width of the output will be the same as
the input. The total number of padding will be calculated in horizontal and vertical
@ -1323,8 +1355,8 @@ class MaxPool(_Pool):
"""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"):
super(MaxPool, self).__init__(ksize, strides, padding, data_format)
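A minimal NHWC usage sketch on a GPU target (illustrative shapes; same assumed imports as the Conv2D example above):

>>> maxpool_op = P.MaxPool(ksize=2, strides=2, data_format="NHWC")
>>> x = Tensor(np.ones([1, 32, 32, 64]), mindspore.float32)
>>> output = maxpool_op(x)  # (1, 16, 16, 64): ceil((32 - 1) / 2) = 16 per spatial dim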
class MaxPoolWithArgmax(_Pool):
@ -1374,8 +1406,8 @@ class MaxPoolWithArgmax(_Pool):
>>> output_tensor, argmax = maxpool_arg_op(input_tensor)
"""
def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"):
super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding, data_format)
self.is_tbe = context.get_context("device_target") == "Ascend"
self.is_gpu = context.get_context("device_target") == "GPU"
@ -1439,6 +1471,8 @@ class AvgPool(_Pool):
- valid: Adopts the way of discarding. The possible largest height and width of output
will be returned without padding. Extra pixels will be discarded.
data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -1473,14 +1507,14 @@ class AvgPool(_Pool):
"""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="valid", data_format="NCHW"):
if context.get_context("device_target") == "GPU":
self.target = "GPU"
elif context.get_context("enable_ge"):
self.target = "GE"
else:
self.target = "OTHER"
super(AvgPool, self).__init__(ksize, strides, padding, data_format)
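A companion sketch for the "same" branch (illustrative shapes; same assumed imports):

>>> avgpool_op = P.AvgPool(ksize=3, strides=2, padding="same", data_format="NHWC")
>>> x = Tensor(np.ones([1, 32, 32, 16]), mindspore.float32)
>>> output = avgpool_op(x)  # (1, 16, 16, 16), since ceil(32 / 2) = 16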
class Conv2DBackpropInput(PrimitiveWithInfer):
@ -1500,6 +1534,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
dilation (Union[int, tuple[int]]): Specifies the dilation rate to be used for the dilated convolution.
Default: 1.
group (int): Splits input into groups. Default: 1.
data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
Default: 'NCHW'.
Returns:
Tensor, the gradients of convolution.
@ -1522,7 +1558,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
mode=1,
stride=1,
dilation=1,
group=1,
data_format="NCHW"):
"""Initialize Conv2DBackpropInput"""
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
@ -1549,7 +1586,10 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('pad_mode', pad_mode)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format is only supported on GPU targets.")
self.add_prim_attr('data_format', self.format)
if pad_list:
for x in pad_list:
validator.check_non_negative_int(x, 'element of pad_list', self.name)
@ -1566,6 +1606,8 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
# infer shape
dout_shape = doutput['shape']
dout_shape_norm = dout_shape if self.format == "NCHW" else \
[dout_shape[0], dout_shape[3], dout_shape[1], dout_shape[2]]
kernel_h = self.kernel_size[0]
kernel_w = self.kernel_size[1]
stride_h = self.stride[0]
@ -1577,11 +1619,11 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
if self.pad_list:
pad_list = tuple(self.pad_list)
elif self.pad_mode == "SAME":
pad_needed_h = max(0, (dout_shape_norm[2] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_size_v[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (dout_shape_norm[3] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_size_v[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
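A worked instance of this SAME-pad arithmetic, with illustrative sizes:

import math
dout_h, stride_h, kernel_h, dilation_h, x_h = 14, 2, 3, 1, 28  # hypothetical sizes
pad_needed_h = max(0, (dout_h - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h)  # max(0, 29 - 28) = 1
pad_top = math.floor(pad_needed_h / 2)  # 0
pad_bottom = pad_needed_h - pad_top     # 1, an asymmetric (0, 1) pad on the height axis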
@ -1606,6 +1648,8 @@ class BiasAdd(PrimitiveWithInfer):
Inputs:
- **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions.
- **bias** (Tensor) - The bias tensor, with shape :math:`(C)`.
The shape of `bias` must match the channel dimension of `input_x`: the second dimension for 'NCHW', the last for 'NHWC'.
- **data_format** (str) - The format of input and output data. It should be 'NHWC' or 'NCHW'.
Default is 'NCHW'.
Outputs:
@ -1619,14 +1663,18 @@ class BiasAdd(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, data_format="NCHW"):
self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format is only supported on GPU targets.")
self.add_prim_attr('data_format', self.format)
def infer_shape(self, x_shape, b_shape):
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1]
validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name)
return x_shape
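A minimal NHWC sketch of this check (illustrative shapes; same assumed imports):

>>> bias_add = P.BiasAdd(data_format="NHWC")
>>> x = Tensor(np.ones([2, 8, 8, 16]), mindspore.float32)
>>> b = Tensor(np.ones([16]), mindspore.float32)
>>> output = bias_add(x, b)  # b is validated against the channel dim x_shape[-1] == 16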
def infer_dtype(self, x_type, b_type):