diff --git a/mindspore/ccsrc/include/common/utils/utils.h b/mindspore/ccsrc/include/common/utils/utils.h
index 35ac6d1023f..05de6d3ca15 100644
--- a/mindspore/ccsrc/include/common/utils/utils.h
+++ b/mindspore/ccsrc/include/common/utils/utils.h
@@ -867,7 +867,7 @@ const std::set<std::string> DynamicShapeConstInputToAttrCPU = {
 const std::set<std::string> DynamicShapeConstInputToAttrGPU = {
   kCastOpName,      kExpandDimsOpName, kReshapeOpName,    kEmbeddingLookupOpName, kTransposeOpName,
   kReduceSumOpName, kReduceMinOpName,  kReduceMeanOpName, kReduceMaxOpName,       kReduceAllOpName,
-  kReduceAnyOpName, kConcatOpName,     kScatterNdOpName};
+  kReduceAnyOpName, kConcatOpName,     kScatterNdOpName,  kGatherV2OpName,        kAvgPool3DGradOpName};
 
 // The map between kernel's output and input ref relationship.
 // Key is the output index while the value is input index which will be used as the reference of output.
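Note: registering kGatherV2OpName and kAvgPool3DGradOpName here lets the GPU backend fold these ops' constant tensor inputs (Gather's axis, AvgPool3DGrad's original input shape) into primitive attributes during dynamic-shape processing, which is what the attribute fallbacks added below in gather.cc and avg_pool_3d_grad.cc consume. A minimal standalone sketch of the membership check itself; the two op-name strings are illustrative stand-ins for the real constants in utils.h:

#include <iostream>
#include <set>
#include <string>

// Illustrative stand-ins; the real constants live in utils.h.
const std::string kGatherV2OpName = "Gather";
const std::string kAvgPool3DGradOpName = "AvgPool3DGrad";

// Ops listed here get their constant tensor inputs rewritten into attributes.
const std::set<std::string> DynamicShapeConstInputToAttrGPU = {kGatherV2OpName, kAvgPool3DGradOpName};

bool ShouldConvertConstInputToAttr(const std::string &op_name) {
  return DynamicShapeConstInputToAttrGPU.count(op_name) > 0;
}

int main() {
  std::cout << std::boolalpha << ShouldConvertConstInputToAttr("AvgPool3DGrad") << "\n";  // true
  return 0;
}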
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cu
new file mode 100644
index 00000000000..909261f480d
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cu
@@ -0,0 +1,69 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh"
+#include "include/cuda_fp16.h"
+
+template <typename T>
+__global__ void RealKernelSize(const size_t size, T *kernel, const int64_t kernel_size, const int64_t shape_d,
+                               const int64_t shape_h, const int64_t shape_w, const int64_t kernel_d,
+                               const int64_t kernel_h, const int64_t kernel_w, const int64_t edge_kernel_d,
+                               const int64_t edge_kernel_h, const int64_t edge_kernel_w) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    const int64_t d_max = shape_d - 1;
+    const int64_t h_max = shape_h - 1;
+    const int64_t w_max = shape_w - 1;
+    for (int64_t d = 0; d < shape_d; ++d) {
+      for (int64_t h = 0; h < shape_h; ++h) {
+        for (int64_t w = 0; w < shape_w; ++w) {
+          const int64_t valid_d = ((d == d_max) ? edge_kernel_d : kernel_d);
+          const int64_t valid_h = ((h == h_max) ? edge_kernel_h : kernel_h);
+          const int64_t valid_w = ((w == w_max) ? edge_kernel_w : kernel_w);
+          const int64_t cur_kernel_size = valid_d * valid_h * valid_w;
+          if (cur_kernel_size != kernel_size) {
+            const int64_t index = pos * shape_d * shape_h * shape_w + d * shape_h * shape_w + h * shape_w + w;
+            kernel[index] =
+              kernel[index] * static_cast<T>(static_cast<float>(cur_kernel_size) / static_cast<float>(kernel_size));
+          }
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void CalRealKernelSize(const std::vector<int64_t> &input_shape, const std::vector<int64_t> &kernel_size,
+                       const std::vector<int64_t> &edge_kernel_size, T *kernel, const uint32_t &device_id,
+                       cudaStream_t cuda_stream) {
+  const int64_t kernel_prod = kernel_size[0] * kernel_size[1] * kernel_size[2];
+  const int64_t nc_size = input_shape[0] * input_shape[1];
+  RealKernelSize<<<CUDA_BLOCKS(device_id, nc_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
+    nc_size, kernel, kernel_prod, input_shape[2], input_shape[3], input_shape[4], kernel_size[0], kernel_size[1],
+    kernel_size[2], edge_kernel_size[0], edge_kernel_size[1], edge_kernel_size[2]);
+}
+
+template CUDA_LIB_EXPORT void CalRealKernelSize<double>(const std::vector<int64_t> &input_shape,
+                                                        const std::vector<int64_t> &kernel_size,
+                                                        const std::vector<int64_t> &edge_kernel_size, double *kernel,
+                                                        const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalRealKernelSize<float>(const std::vector<int64_t> &input_shape,
+                                                       const std::vector<int64_t> &kernel_size,
+                                                       const std::vector<int64_t> &edge_kernel_size, float *kernel,
+                                                       const uint32_t &device_id, cudaStream_t cuda_stream);
+template CUDA_LIB_EXPORT void CalRealKernelSize<half>(const std::vector<int64_t> &input_shape,
+                                                      const std::vector<int64_t> &kernel_size,
+                                                      const std::vector<int64_t> &edge_kernel_size, half *kernel,
+                                                      const uint32_t &device_id, cudaStream_t cuda_stream);
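Note: with ceil_mode, the last window along each spatial dimension may overhang the (padded) input, so it averages over fewer elements than kernel_d * kernel_h * kernel_w; RealKernelSize rescales the per-element weight by cur_kernel_size / kernel_size at those edge positions. A host-side sketch of the same arithmetic, under assumed sizes (3x3x3 kernel whose depth is clipped to 2 at the edge):

#include <cstdint>
#include <iostream>

// Host-side mirror of the RealKernelSize rescale rule: positions in the last
// depth/height/width slice use the clipped (edge) kernel extent, so the
// uniform averaging weight is rescaled by cur_kernel_size / kernel_size.
int main() {
  const int64_t kd = 3, kh = 3, kw = 3;                 // full kernel extents
  const int64_t edge_kd = 2, edge_kh = 3, edge_kw = 3;  // extents clipped at the edge
  const int64_t kernel_size = kd * kh * kw;             // 27
  const int64_t cur_kernel_size = edge_kd * edge_kh * edge_kw;  // 18
  const float factor = static_cast<float>(cur_kernel_size) / static_cast<float>(kernel_size);
  std::cout << "edge windows are rescaled by " << factor << "\n";  // 0.666667
  return 0;
}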
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh
new file mode 100644
index 00000000000..b3db752949b
--- /dev/null
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_AVG_POOL3D_HELPER_IMPL_CUH_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_AVG_POOL3D_HELPER_IMPL_CUH_
+#include <cuda_runtime.h>
+#include <vector>
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
+
+template <typename T>
+CUDA_LIB_EXPORT void CalRealKernelSize(const std::vector<int64_t> &input_shape, const std::vector<int64_t> &kernel_size,
+                                       const std::vector<int64_t> &edge_kernel_size, T *kernel,
+                                       const uint32_t &device_id, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_AVG_POOL3D_HELPER_IMPL_CUH_
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h
index a0d56016b37..55ec5e4614c 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/gpu_kernel.h
@@ -251,6 +251,7 @@ class DeprecatedNativeGpuKernelMod : public NativeGpuKernelMod {
     }
     return GetValue<T>(attr);
   }
+
   template <typename T>
   inline T GetAttrWithDefault(const CNodePtr &kernel_node, const std::string &key, const T &value) const {
     const PrimitivePtr &prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
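Note: GetAttrWithDefault mirrors GetAttr but falls back to a caller-supplied value when the primitive lacks the attribute, which is convenient for attrs such as divisor_override that older graphs may not carry. A standalone analogue of its semantics, assuming the attribute dictionary behaves like a string-keyed map (the types here are simplified stand-ins, not the real PrimitivePtr API):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Hypothetical analogue: returns the stored attribute if present, else the default.
int64_t GetAttrWithDefault(const std::map<std::string, int64_t> &attrs, const std::string &key, int64_t fallback) {
  auto it = attrs.find(key);
  return (it == attrs.end()) ? fallback : it->second;
}

int main() {
  std::map<std::string, int64_t> attrs = {{"divisor_override", 5}};
  std::cout << GetAttrWithDefault(attrs, "divisor_override", 0) << "\n";  // 5, from the attr
  std::cout << GetAttrWithDefault(attrs, "pad_front", 0) << "\n";         // 0, the default
  return 0;
}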
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.cc
index 8aabfab211e..c72e6195066 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.cc
@@ -30,6 +30,8 @@ MS_REG_GPU_KERNEL_ONE(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       PoolingFwdGpuKernelMod, float)
 MS_REG_GPU_KERNEL_ONE(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       PoolingFwdGpuKernelMod, half)
+MS_REG_GPU_KERNEL_ONE(AvgPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      PoolingFwdGpuKernelMod, double)
 MS_REG_GPU_KERNEL_ONE(AvgPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       PoolingFwdGpuKernelMod, float)
 MS_REG_GPU_KERNEL_ONE(AvgPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.h
index 888d8e55081..70b273cbbca 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_gpu_kernel.h
@@ -20,13 +20,18 @@
 #include <algorithm>
 #include <string>
 #include <vector>
+#include <numeric>
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pad_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
 #include "plugin/device/gpu/kernel/kernel_constants.h"
 
 namespace mindspore {
 namespace kernel {
+constexpr auto kAvgPool = "AvgPool";
+constexpr auto kAvgPool3D = "AvgPool3D";
+
 template <typename T>
 class PoolingFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
  public:
@@ -49,8 +54,10 @@ class PoolingFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
         pad_left_(0),
         n_(0),
         c_(0),
+        divisor_override_(0),
         pad_value_(0),
         is_null_input_(false),
+        ceil_mode_(false),
         kernel_name_("Pooling"),
         input_size_(0),
         output_size_(0),
@@ -71,13 +78,38 @@
       cudnnPoolingForward(cudnn_handle_, pooling_descriptor_, &alpha, input_descriptor_, input_addr, &beta,
                           output_descriptor_, output_addr),
       "cudnnPoolingForward failed");
+
+    if (divisor_override_ != 0) {
+      T *work_addr = GetDeviceAddress<T>(workspace, 0);
+      size_t output_num = output_size_ / sizeof(T);
+      int64_t size = std::accumulate(kernel_size_.begin(), kernel_size_.end(), 1, std::multiplies<int64_t>());
+      T divisor = static_cast<T>(LongToFloat(size) / LongToFloat(divisor_override_));
+      std::vector<T> divisor_value(output_num, divisor);
+      CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+        cudaMemcpyAsync(work_addr, divisor_value.data(), output_size_, cudaMemcpyHostToDevice,
+                        reinterpret_cast<cudaStream_t>(stream_ptr)),
+        "cudaMemcpyAsync failed.");
+      if (ceil_mode_) {
+        CalRealKernelSize(output_shape_exclude_nc_, kernel_size_, edge_kernel_, work_addr, 0,
+                          reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
+      ElewiseArith(output_num, BROADCAST_TYPE_MUL, output_addr, work_addr, output_addr,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
+
   bool Init(const CNodePtr &kernel_node) {
     kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
     kernel_node_ = kernel_node;
     InitResource();
     (void)CheckParam(kernel_node);
+    auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
+    MS_EXCEPTION_IF_NULL(prim);
+    if (kernel_name_ == kAvgPool3D) {
+      divisor_override_ = GetAttr<int64_t>(kernel_node, "divisor_override");
+      ceil_mode_ = GetAttr<bool>(kernel_node, "ceil_mode");
+    }
     cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     data_format_ = AnfAlgo::GetInputFormat(kernel_node, 0);
     auto format_attr = GetAttr<std::string>(kernel_node, "format");
@@ -123,6 +155,7 @@
     } else if (dim == kDim3DShapeSize) {
       SetPad3D(kernel_node);
     }
+    edge_kernel_ = GetEdgeKernelSize(kernel_node);
     InitSizeLists();
     return true;
   }
@@ -157,6 +190,7 @@
     }
     input_size_list_.push_back(input_size_);
     output_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(output_size_);
   }
 
  private:
@@ -169,14 +203,22 @@
   void SetPoolingMode(const CNodePtr &kernel_node) {
     mode_ = common::AnfAlgo::GetCNodeName(kernel_node);
-    if (mode_ == "AvgPool" || mode_ == "AvgPool3D") {
-      pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+    auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
+    MS_EXCEPTION_IF_NULL(prim);
+    bool include = false;
+    if (prim->HasAttr("count_include_pad")) {
+      include = GetAttr<bool>(kernel_node, "count_include_pad");
+    }
+    if (mode_ == kAvgPool || mode_ == kAvgPool3D) {
+      pooling_mode_ =
+        include ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
       pad_value_ = 0.0;
     } else {
       pooling_mode_ = CUDNN_POOLING_MAX;
       pad_value_ = kSignedMinFloat;
     }
   }
+
   void SetPad(const CNodePtr &kernel_node) {
     auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
     MS_EXCEPTION_IF_NULL(prim);
@@ -281,12 +323,66 @@
     "cudnnSetPoolingNdDescriptor failed");
   }
 
+  std::vector<int64_t> GetEdgeKernelSize(const CNodePtr &kernel_node) {
+    if (!ceil_mode_ && divisor_override_ == 0) {
+      return {};
+    }
+
+    const size_t k3dSizeLowerLimit = 5;
+    const size_t kIdxD = 2;
+    const size_t kIdxH = 3;
+    const size_t kIdxW = 4;
+    const size_t kScale = 2;
+    std::vector<int64_t> edge_kernel;
+    std::vector<int64_t> kernel_size = GetAttr<std::vector<int64_t>>(kernel_node, "kernel_size");
+    std::vector<int64_t> strides = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
+    std::vector<int64_t> pad = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
+    if (kernel_size.size() != k3dSizeLowerLimit) {
+      MS_LOG(EXCEPTION) << "kernel_size must be " << k3dSizeLowerLimit << "D, but got " << kernel_size.size();
+    }
+    if (strides.size() != k3dSizeLowerLimit) {
+      MS_LOG(EXCEPTION) << "strides must be " << k3dSizeLowerLimit << "D, but got " << strides.size();
+    }
+    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
+    kernel_size_ = {kernel_size[kIdxD], kernel_size[kIdxH], kernel_size[kIdxW]};
+    std::vector<int64_t> stride = {strides[kIdxD], strides[kIdxH], strides[kIdxW]};
+    std::vector<int64_t> shape_exclude_nc = {SizeToLong(input_shape[kIdxD]), SizeToLong(input_shape[kIdxH]),
+                                             SizeToLong(input_shape[kIdxW])};
+    (void)std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(output_shape_exclude_nc_),
+                         SizeToLong);
+
+    const size_t dim = shape_exclude_nc.size();
+    if (pad.size() != dim * kScale) {
+      MS_LOG(EXCEPTION) << "pad_list must be " << (dim * kScale) << "D, but got " << pad.size() << "D!";
+    }
+
+    for (size_t i = 0; i < dim; ++i) {
+      size_t l_index = kScale * i;
+      size_t r_index = kScale * i + 1;
+
+      int64_t len = shape_exclude_nc[i] + pad[l_index] + pad[r_index] - kernel_size_[i];
+      int64_t padding_iv = FloatToLong(std::ceil(LongToFloat(len) / LongToFloat(stride[i]))) * stride[i] - len;
+      int64_t padding_r = pad[r_index] + padding_iv;
+      if (padding_r > pad[r_index] && padding_r < kernel_size_[i]) {
+        edge_kernel.push_back(kernel_size_[i] - padding_iv);
+      } else {
+        edge_kernel.push_back(kernel_size_[i]);
+      }
+    }
+    return edge_kernel;
+  }
+
   cudnnHandle_t cudnn_handle_;
   cudnnTensorDescriptor_t input_descriptor_;
   cudnnTensorDescriptor_t output_descriptor_;
   cudnnPoolingDescriptor_t pooling_descriptor_;
   cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX;
   std::vector<int> stride_;
+  std::vector<int64_t> kernel_size_;
+  std::vector<int64_t> shape_exclude_nc_;
+  std::vector<int64_t> edge_kernel_;
+  std::vector<int64_t> output_shape_exclude_nc_;
   std::string mode_;
   std::string pad_mode_;
   std::string data_format_ = kOpFormat_NCHW;
@@ -304,8 +400,10 @@
   int pad_left_;
   int n_;
   int c_;
+  int64_t divisor_override_;
   float pad_value_;
   bool is_null_input_;
+  bool ceil_mode_;
   std::string kernel_name_;
   size_t input_size_;
   size_t output_size_;
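Note: when divisor_override is non-zero, cuDNN still produces sum / window_volume, so the Launch path above multiplies every output element by window_volume / divisor_override to obtain sum / divisor_override. A toy check of that identity:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// cuDNN returns sum / window_volume; multiplying by window_volume /
// divisor_override yields sum / divisor_override, which is what the op wants.
int main() {
  std::vector<float> window = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f};  // a 2x2x2 window
  const int64_t volume = static_cast<int64_t>(window.size());            // 8
  const int64_t divisor_override = 4;
  float sum = std::accumulate(window.begin(), window.end(), 0.f);        // 36
  float cudnn_avg = sum / volume;                                        // 4.5
  float rescaled = cudnn_avg * (static_cast<float>(volume) / divisor_override);
  std::cout << rescaled << " == " << sum / divisor_override << "\n";     // 9 == 9
  return 0;
}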
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.cc
index ca7a62450d3..17fc679079d 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.cc
@@ -15,54 +15,511 @@
  */
 
 #include "plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.h"
+#include <algorithm>
+#include <memory>
+#include "mindspore/core/ops/grad/pool_grad.h"
+#include "mindspore/core/ops/grad/avg_pool_3d_grad.h"
+#include "mindspore/core/ops/grad/max_pool_3d_grad.h"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
+#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(MaxPoolGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      PoolingGradGpuKernelMod, float)
-MS_REG_GPU_KERNEL_ONE(MaxPoolGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      PoolingGradGpuKernelMod, half)
-MS_REG_GPU_KERNEL_ONE(MaxPool3DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      PoolingGradGpuKernelMod, float)
-MS_REG_GPU_KERNEL_ONE(MaxPool3DGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      PoolingGradGpuKernelMod, half)
-MS_REG_GPU_KERNEL_ONE(AvgPoolGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddInputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32),
-                      PoolingGradGpuKernelMod, float)
-MS_REG_GPU_KERNEL_ONE(AvgPoolGrad,
-                      KernelAttr()
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddInputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16),
-                      PoolingGradGpuKernelMod, half)
-MS_REG_GPU_KERNEL_ONE(AvgPool3DGrad, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      PoolingGradGpuKernelMod, float)
-MS_REG_GPU_KERNEL_ONE(AvgPool3DGrad, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      PoolingGradGpuKernelMod, half)
+constexpr auto kMaxPoolGrad = "MaxPoolGrad";
+constexpr auto kMaxPool3DGrad = "MaxPool3DGrad";
+constexpr auto kAvgPoolGrad = "AvgPoolGrad";
+constexpr auto kAvgPool3DGrad = "AvgPool3DGrad";
+constexpr size_t kInputNum = 3;
+constexpr size_t kAvgPool3DGradInputNum = 1;
+
+bool PoolingGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                   const std::vector<KernelTensorPtr> &outputs) {
+  kernel_name_ = base_operator->name();
+  auto pool_grad_ptr = std::make_shared<ops::PoolGrad>(base_operator->GetPrim());
+  format_attr_ = pool_grad_ptr->get_format();
+  pad_mode_ = pool_grad_ptr->get_pad_mode();
+  stride_me_ = pool_grad_ptr->get_strides();
+  window_me_ = pool_grad_ptr->get_kernel_size();
+  if (kernel_name_ == kMaxPool3DGrad) {
+    auto kernel_ptr = std::make_shared<ops::MaxPool3DGrad>(base_operator->GetPrim());
+    pad_list_ = kernel_ptr->get_pad_list();
+  } else if (kernel_name_ == kAvgPool3DGrad) {
+    auto kernel_ptr = std::make_shared<ops::AvgPool3DGrad>(base_operator->GetPrim());
+    pad_list_ = kernel_ptr->get_pad_list();
+    divisor_override_ = kernel_ptr->get_divisor_override();
+    ceil_mode_ = kernel_ptr->get_ceil_mode();
+    include_ = kernel_ptr->get_count_include_pad();
+  }
+  cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs.at(kIndex0)->GetDtype()));
+  SetPoolingMode();
+
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
+    return false;
+  }
+  kernel_func_ = kernel_attr_map_.at(kernel_name_)[index].second;
+  return true;
+}
+
+int PoolingGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                                    const std::vector<KernelTensorPtr> &outputs,
+                                    const std::map<uint32_t, tensor::TensorPtr> &) {
+  int ret = KernelMod::Resize(base_operator, inputs, outputs);
+  if (ret != KRET_OK) {
+    return ret;
+  }
+  size_t input_num = inputs.size();
+  size_t expect_input_num = (kernel_name_ == kAvgPool3DGrad) ? kAvgPool3DGradInputNum : kInputNum;
+  if (input_num != expect_input_num) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << expect_input_num
+                      << ", but got " << input_num;
+  }
+
+  input_shape_ = inputs.at(kIndex0)->GetShapeVector();
+  output_shape_ = outputs.at(kIndex0)->GetShapeVector();
+  int nbDims = SizeToInt(input_shape_.size());
+  int dimA[kPoolingNbDims];
+  int strideAin[kPoolingNbDims];
+  int dimAy[kPoolingNbDims];
+  int strideAiny[kPoolingNbDims];
+  int dimAdy[kPoolingNbDims];
+  int strideAdy[kPoolingNbDims];
+  int dimAout[kPoolingNbDims];
+  int strideAout[kPoolingNbDims];
+  if (!InitShape(inputs, outputs, dimA, strideAin, dimAy, strideAiny, dimAdy, strideAdy, dimAout, strideAout,
+                 nbDims)) {
+    return ret;
+  }
+  if (nbDims == kDim2DShapeSize) {
+    SetPad();
+  } else if (nbDims == kDim3DShapeSize) {
+    SetPad3D();
+  }
+  std::string err_msg = "For '" + kernel_name_ + "', cudnnSetTensor4dDescriptor failed";
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnSetTensorNdDescriptor(dx_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnSetTensorNdDescriptor(x_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin), err_msg);
+  edge_kernel_ = GetEdgeKernelSize();
+  InitSizeLists();
+  return ret;
+}
+
+template <typename T>
+bool PoolingGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
+                                           const std::vector<AddressPtr> &workspace,
+                                           const std::vector<AddressPtr> &outputs) {
+  T *x_data = nullptr;
+  T *y = nullptr;
+  T *dy = nullptr;
+  T *dx = nullptr;
+  if (kernel_name_ == kAvgPool3DGrad) {
+    dy = GetDeviceAddress<T>(inputs, kIndex0);
+    dx = GetDeviceAddress<T>(outputs, kIndex0);
+    x_data = GetDeviceAddress<T>(workspace, kIndex0);
+    y = GetDeviceAddress<T>(workspace, kIndex1);
+  } else {
+    x_data = GetDeviceAddress<T>(inputs, kIndex0);
+    y = GetDeviceAddress<T>(inputs, kIndex1);
+    dy = GetDeviceAddress<T>(inputs, kIndex2);
+    dx = GetDeviceAddress<T>(outputs, kIndex0);
+  }
+
+  const float alpha = 1;
+  const float beta = 0;
+
+  if (divisor_override_ != 0) {
+    T *work_addr = GetDeviceAddress<T>(workspace, kIndex2);
+    T *dy_work_addr = GetDeviceAddress<T>(workspace, kIndex3);
+    size_t output_num = input_size_ / sizeof(T);
+    int64_t size = std::accumulate(kernel_size_.begin(), kernel_size_.end(), 1, std::multiplies<int64_t>());
+    T divisor = static_cast<T>(LongToFloat(size) / LongToFloat(divisor_override_));
+    std::vector<T> divisor_value(output_num, divisor);
+    std::string err_msg = "For '" + kernel_name_ + "', cudaMemcpyAsync failed.";
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+      cudaMemcpyAsync(work_addr, divisor_value.data(), input_size_, cudaMemcpyHostToDevice,
+                      reinterpret_cast<cudaStream_t>(cuda_stream_)),
+      err_msg);
+    CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(dy_work_addr, dy, input_size_, cudaMemcpyDeviceToDevice,
+                                                       reinterpret_cast<cudaStream_t>(cuda_stream_)),
+                                       err_msg);
+    if (ceil_mode_) {
+      CalRealKernelSize(input_shape_, kernel_size_, edge_kernel_, work_addr, device_id_,
+                        reinterpret_cast<cudaStream_t>(cuda_stream_));
+    }
+    ElewiseArith(output_num, BROADCAST_TYPE_MUL, dy_work_addr, work_addr, dy_work_addr,
+                 reinterpret_cast<cudaStream_t>(cuda_stream_));
+
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+      cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy_work_addr,
+                           x_descriptor_, x_data, &beta, dx_descriptor_, dx),
+      "For '" + kernel_name_ + "', cudnnPoolingBackward failed");
+    return true;
+  }
+
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
+                         x_descriptor_, x_data, &beta, dx_descriptor_, dx),
+    "For '" + kernel_name_ + "', cudnnPoolingBackward failed");
+  return true;
+}
+
+bool PoolingGradGpuKernelMod::InitShape(const std::vector<KernelTensorPtr> &inputs,
+                                        const std::vector<KernelTensorPtr> &outputs, int *dimA, int *strideAin,
+                                        int *dimAy, int *strideAiny, int *dimAdy, int *strideAdy, int *dimAout,
+                                        int *strideAout, int nbDims) {
+  ShapeVector dout_shape, input_mask, output_shape, input_shape;
+  if (kernel_name_ == kAvgPool3DGrad) {
+    dout_shape = inputs.at(kIndex0)->GetShapeVector();
+    output_shape = outputs.at(kIndex0)->GetShapeVector();
+    input_mask = dout_shape;
+    input_shape = output_shape;
+  } else {
+    input_shape = inputs.at(kIndex0)->GetShapeVector();
+    input_mask = inputs.at(kIndex1)->GetShapeVector();
+    dout_shape = inputs.at(kIndex2)->GetShapeVector();
+    output_shape = outputs.at(kIndex0)->GetShapeVector();
+  }
+  is_null_input_ =
+    CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") || CHECK_SHAPE_NULL(input_mask, kernel_name_, "mask") ||
+    CHECK_SHAPE_NULL(dout_shape, kernel_name_, "dout") || CHECK_SHAPE_NULL(output_shape, kernel_name_, "output");
+  if (is_null_input_) {
+    InitSizeLists();
+    return false;
+  }
+  auto data_format = GetFormatFromEnumToStr(inputs.at(kIndex0)->GetFormat());
+  if (Anyone(format_attr_, Format::NHWC, Format::NDHWC)) {
+    data_format = GetFormatFromEnumToStr(format_attr_);
+  }
+
+  CheckTensorSize({input_shape, input_mask, dout_shape, output_shape});
+  if (nbDims == kDim2DShapeSize) {
+    SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
+  } else if (nbDims == kDim3DShapeSize) {
+    SetNCDHW(input_shape, &n_, &c_, &old_depth_, &old_height_, &old_width_, data_format);
+  }
+  SetDimA(input_shape, dimA, nbDims, data_format);
+  SetStrideA(input_shape, strideAin, nbDims, data_format);
+  SetDimA(input_mask, dimAy, nbDims, data_format);
+  SetStrideA(input_mask, strideAiny, nbDims, data_format);
+  SetDimA(dout_shape, dimAdy, nbDims, data_format);
+  SetStrideA(dout_shape, strideAdy, nbDims, data_format);
+  SetDimA(output_shape, dimAout, nbDims, data_format);
+  SetStrideA(output_shape, strideAout, nbDims, data_format);
+  return true;
+}
+
+void PoolingGradGpuKernelMod::DestroyResource() noexcept {
+  std::string err_msg = "For '" + kernel_name_ + "', cudnnDestroyPoolingDescriptor failed";
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyPoolingDescriptor(pooling_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(dx_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(x_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(dy_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnDestroyTensorDescriptor(y_descriptor_), err_msg);
+}
+
+void PoolingGradGpuKernelMod::InitResource() {
+  pooling_mode_ = CUDNN_POOLING_MAX;
+  cudnn_data_type_ = CUDNN_DATA_FLOAT;
+  compute_format_ = CUDNN_TENSOR_NCHW;
+  cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
+  std::string err_msg = "For '" + kernel_name_ + "', cudnnCreateTensorDescriptor failed";
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&y_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&dy_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&x_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&dx_descriptor_), err_msg);
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreatePoolingDescriptor(&pooling_descriptor_), err_msg);
+}
+
+void PoolingGradGpuKernelMod::InitSizeLists() {
+  input_size_list_.clear();
+  output_size_list_.clear();
+  workspace_size_list_.clear();
+  std::string err_msg = "For '" + kernel_name_ + "', cudnnGetTensorSizeInBytes failed";
+  if (!is_null_input_) {
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(x_descriptor_, &input_size_), err_msg);
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(dx_descriptor_, &output_size_), err_msg);
+  }
+  input_size_list_.push_back(input_size_);
+  output_size_list_.push_back(output_size_);
+  if (kernel_name_ == kAvgPool3DGrad) {
+    workspace_size_list_.push_back(output_size_);
+    workspace_size_list_.push_back(input_size_);
+    if (divisor_override_ != 0) {
+      workspace_size_list_.push_back(input_size_);
+      workspace_size_list_.push_back(input_size_);
+    }
+  }
+  if (!is_null_input_) {
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(y_descriptor_, &input_size_), err_msg);
+  }
+  input_size_list_.push_back(input_size_);
+
+  if (!is_null_input_) {
+    CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnGetTensorSizeInBytes(dy_descriptor_, &input_size_), err_msg);
+  }
+  input_size_list_.push_back(input_size_);
+}
+
+void PoolingGradGpuKernelMod::SetPad() {
+  std::vector<int> window;
+  std::vector<int> stride;
+  (void)std::transform(stride_me_.begin(), stride_me_.end(), std::back_inserter(stride),
+                       [](const int64_t &value) { return static_cast<int>(value); });
+  (void)std::transform(window_me_.begin(), window_me_.end(), std::back_inserter(window),
+                       [](const int64_t &value) { return static_cast<int>(value); });
+  const size_t kSizeLowerLimit = 4;
+  const size_t kIdxH = 2;
+  const size_t kIdxW = 3;
+  if (window.size() < kSizeLowerLimit) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'kernel_size' cannot be less than 4, but got "
+                      << window.size();
+  }
+  int window_height = window[kIdxH];
+  int window_width = window[kIdxW];
+  if (stride.size() < kSizeLowerLimit) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'strides' cannot be less than 4, but got "
+                      << stride.size();
+  }
+  int stride_h = stride[kIdxH];
+  int stride_w = stride[kIdxW];
+  if (format_attr_ == Format::NHWC) {
+    const size_t kNHWCIdxH = 1;
+    const size_t kNHWCIdxW = 2;
+    window_height = window[kNHWCIdxH];
+    window_width = window[kNHWCIdxW];
+    stride_h = stride[kNHWCIdxH];
+    stride_w = stride[kNHWCIdxW];
+  }
+  const size_t k2dDim = 2;
+  int windowDimA[k2dDim] = {window_height, window_width};
+  int paddingA[k2dDim] = {0, 0};
+  int strideA[k2dDim] = {stride_h, stride_w};
+  if (pad_mode_ == PadMode::SAME) {
+    pad_height_ = GetPad(old_height_, window_height, stride_h);
+    pad_width_ = GetPad(old_width_, window_width, stride_w);
+    const int kSymCoef = 2;
+    pad_top_ = pad_height_ / kSymCoef;
+    pad_left_ = pad_width_ / kSymCoef;
+    paddingA[kIndex0] = pad_top_;
+    paddingA[kIndex1] = pad_left_;
+  } else {
+    if (pad_mode_ == PadMode::VALID) {
+      pad_height_ = 0;
+      pad_width_ = 0;
+    }
+  }
+  std::string err_msg = "For '" + kernel_name_ + "', cudnnSetPoolingNdDescriptor failed";
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, k2dDim, windowDimA,
+                                paddingA, strideA),
+    err_msg);
+}
+
+void PoolingGradGpuKernelMod::SetPad3D() {
+  const int kPadListSize = 6;
+  const int kPadScale = 2;
+  std::vector<int> window;
+  std::vector<int> stride;
+  (void)std::transform(stride_me_.begin(), stride_me_.end(), std::back_inserter(stride),
+                       [](const int64_t &value) { return static_cast<int>(value); });
+  (void)std::transform(window_me_.begin(), window_me_.end(), std::back_inserter(window),
+                       [](const int64_t &value) { return static_cast<int>(value); });
+  const size_t k3dSizeLowerLimit = 5;
+  const size_t kIdxD = 2;
+  const size_t kIdxH = 3;
+  const size_t kIdxW = 4;
+  if (window.size() < k3dSizeLowerLimit) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'kernel_size' cannot be less than 5, but got "
+                      << window.size();
+  }
+  int window_depth = window[kIdxD];
+  int window_height = window[kIdxH];
+  int window_width = window[kIdxW];
+  if (stride.size() < k3dSizeLowerLimit) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'strides' cannot be less than 5, but got "
+                      << stride.size();
+  }
+  int stride_d = stride[kIdxD];
+  int stride_h = stride[kIdxH];
+  int stride_w = stride[kIdxW];
+  if (format_attr_ == Format::NDHWC) {
+    const size_t kNDHWCIdxD = 1;
+    const size_t kNDHWCIdxH = 2;
+    const size_t kNDHWCIdxW = 3;
+    window_depth = window[kNDHWCIdxD];
+    window_height = window[kNDHWCIdxH];
+    window_width = window[kNDHWCIdxW];
+    stride_d = stride[kNDHWCIdxD];
+    stride_h = stride[kNDHWCIdxH];
+    stride_w = stride[kNDHWCIdxW];
+  }
+  const size_t k3dDimSize = 3;
+  int windowDimA[k3dDimSize] = {window_depth, window_height, window_width};
+  int paddingA[k3dDimSize] = {0, 0, 0};
+  int strideA[k3dDimSize] = {stride_d, stride_h, stride_w};
+  if (pad_mode_ == PadMode::SAME) {
+    pad_depth_ = GetPad(old_depth_, window_depth, stride_d);
+    pad_height_ = GetPad(old_height_, window_height, stride_h);
+    pad_width_ = GetPad(old_width_, window_width, stride_w);
+    const int kSymCoef = 2;
+    pad_front_ = pad_depth_ / kSymCoef;
+    pad_top_ = pad_height_ / kSymCoef;
+    pad_left_ = pad_width_ / kSymCoef;
+    paddingA[kIndex0] = pad_front_;
+    paddingA[kIndex1] = pad_top_;
+    paddingA[kIndex2] = pad_left_;
+  } else if (pad_mode_ == PadMode::VALID) {
+    pad_depth_ = 0;
+    pad_height_ = 0;
+    pad_width_ = 0;
+  } else {
+    if (pad_list_.size() != kPadListSize) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'pad_list' must be 6, but got "
+                        << pad_list_.size();
+    }
+    for (size_t idx = 0; idx < k3dDimSize; idx++) {
+      paddingA[idx] = pad_list_[idx * kPadScale];
+    }
+  }
+  std::string err_msg = "For '" + kernel_name_ + "', cudnnSetPoolingNdDescriptor failed";
+  CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
+    cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN, k3dDimSize, windowDimA,
+                                paddingA, strideA),
+    err_msg);
+}
+
+void PoolingGradGpuKernelMod::SetPoolingMode() {
+  if (kernel_name_ == kAvgPool3DGrad) {
+    pooling_mode_ =
+      include_ ? CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+    pad_value_ = 0.0;
+  } else if (kernel_name_ == kAvgPoolGrad) {
+    pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+    pad_value_ = 0.0;
+  } else {
+    pooling_mode_ = CUDNN_POOLING_MAX;
+    pad_value_ = kSignedMinFloat;
+  }
+}
+
+std::vector<int64_t> PoolingGradGpuKernelMod::GetEdgeKernelSize() {
+  if (!ceil_mode_ && divisor_override_ == 0) {
+    return {};
+  }
+
+  const size_t k3dSizeLowerLimit = 5;
+  const size_t kIdxD = 2;
+  const size_t kIdxH = 3;
+  const size_t kIdxW = 4;
+  const size_t kScale = 2;
+  std::vector<int64_t> edge_kernel;
+  if (window_me_.size() != k3dSizeLowerLimit) {
+    MS_LOG(EXCEPTION) << "kernel_size must be " << k3dSizeLowerLimit << "D, but got " << window_me_.size();
+  }
+  if (stride_me_.size() != k3dSizeLowerLimit) {
+    MS_LOG(EXCEPTION) << "strides must be " << k3dSizeLowerLimit << "D, but got " << stride_me_.size();
+  }
+
+  kernel_size_ = {window_me_[kIdxD], window_me_[kIdxH], window_me_[kIdxW]};
+  std::vector<int64_t> stride = {stride_me_[kIdxD], stride_me_[kIdxH], stride_me_[kIdxW]};
+  std::vector<int64_t> shape_exclude_nc = {output_shape_[kIdxD], output_shape_[kIdxH], output_shape_[kIdxW]};
+
+  const size_t kDim = shape_exclude_nc.size();
+  if (pad_list_.size() != kDim * kScale) {
+    MS_LOG(EXCEPTION) << "pad_list must be " << (kDim * kScale) << "D, but got " << pad_list_.size() << "D!";
+  }
+
+  for (size_t i = 0; i < kDim; ++i) {
+    size_t l_index = kScale * i;
+    size_t r_index = kScale * i + 1;
+
+    int64_t len = shape_exclude_nc[i] + pad_list_[l_index] + pad_list_[r_index] - kernel_size_[i];
+    int64_t padding_iv = FloatToLong(std::ceil(LongToFloat(len) / LongToFloat(stride[i]))) * stride[i] - len;
+    int64_t padding_r = pad_list_[r_index] + padding_iv;
+    if (padding_r > pad_list_[r_index] && padding_r < kernel_size_[i]) {
+      edge_kernel.push_back(kernel_size_[i] - padding_iv);
+    } else {
+      edge_kernel.push_back(kernel_size_[i]);
+    }
+  }
+  return edge_kernel;
+}
+
+std::map<std::string, std::vector<std::pair<KernelAttr, PoolingGradGpuKernelMod::PoolingGradFunc>>>
+  PoolingGradGpuKernelMod::kernel_attr_map_ = {
+    {kMaxPoolGrad,
+     {{KernelAttr()
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddOutputAttr(kNumberTypeFloat32),
+       &PoolingGradGpuKernelMod::LaunchKernel<float>},
+      {KernelAttr()
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddOutputAttr(kNumberTypeFloat16),
+       &PoolingGradGpuKernelMod::LaunchKernel<half>}}},
+    {kMaxPool3DGrad,
+     {{KernelAttr()
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddOutputAttr(kNumberTypeFloat32),
+       &PoolingGradGpuKernelMod::LaunchKernel<float>},
+      {KernelAttr()
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddOutputAttr(kNumberTypeFloat16),
+       &PoolingGradGpuKernelMod::LaunchKernel<half>}}},
+    {kAvgPoolGrad,
+     {{KernelAttr()
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddInputAttr(kNumberTypeFloat32)
+         .AddOutputAttr(kNumberTypeFloat32),
+       &PoolingGradGpuKernelMod::LaunchKernel<float>},
+      {KernelAttr()
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddInputAttr(kNumberTypeFloat16)
+         .AddOutputAttr(kNumberTypeFloat16),
+       &PoolingGradGpuKernelMod::LaunchKernel<half>}}},
+    {kAvgPool3DGrad,
+     {{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+       &PoolingGradGpuKernelMod::LaunchKernel<double>},
+      {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+       &PoolingGradGpuKernelMod::LaunchKernel<float>},
+      {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+       &PoolingGradGpuKernelMod::LaunchKernel<half>}}}};
+
+std::vector<KernelAttr> PoolingGradGpuKernelMod::GetOpSupport() {
+  auto iter = kernel_attr_map_.find(kernel_name_);
+  if (iter == kernel_attr_map_.end()) {
+    MS_LOG(ERROR) << "For 'PoolingGradGpuKernelMod', the kernel name must be in " << kernel::Map2Str(kernel_attr_map_)
+                  << ", but got " << kernel_name_;
+    return std::vector<KernelAttr>{};
+  }
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, PoolingGradFunc> &item) { return item.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, MaxPoolGrad,
+                                 []() { return std::make_shared<PoolingGradGpuKernelMod>(kMaxPoolGrad); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, MaxPool3DGrad,
+                                 []() { return std::make_shared<PoolingGradGpuKernelMod>(kMaxPool3DGrad); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, AvgPoolGrad,
+                                 []() { return std::make_shared<PoolingGradGpuKernelMod>(kAvgPoolGrad); });
+MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, AvgPool3DGrad,
+                                 []() { return std::make_shared<PoolingGradGpuKernelMod>(kAvgPool3DGrad); });
 }  // namespace kernel
 }  // namespace mindspore
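Note: GetEdgeKernelSize (duplicated in the forward kernel and here) computes, per spatial dimension, how many input elements the last window actually covers once ceil_mode's implicit right padding is accounted for. A worked standalone example of that formula, with assumed sizes (input extent 6, kernel 3, stride 2, no explicit padding):

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t shape = 6, pad_l = 0, pad_r = 0, kernel = 3, stride = 2;
  const int64_t len = shape + pad_l + pad_r - kernel;  // 3
  // Implicit right padding ceil_mode adds so the last window starts in-bounds.
  const int64_t padding_iv =
    static_cast<int64_t>(std::ceil(static_cast<float>(len) / static_cast<float>(stride))) * stride - len;  // 1
  const int64_t padding_r = pad_r + padding_iv;
  // The edge window covers kernel - padding_iv real elements when the extra
  // padding is non-trivial but smaller than the kernel itself.
  const int64_t edge_kernel = (padding_r > pad_r && padding_r < kernel) ? kernel - padding_iv : kernel;
  std::cout << "edge kernel extent = " << edge_kernel << "\n";  // 2
  return 0;
}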
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.h
index 0d99a7de310..f78621b9ae6 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/nn/pooling_grad_gpu_kernel.h
@@ -18,368 +18,62 @@
 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_POOLING_GRAD_GPU_KERNEL_H_
 
 #include <algorithm>
+#include <map>
+#include <utility>
 #include <vector>
 #include <string>
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pad_impl.cuh"
 #include "plugin/device/gpu/kernel/kernel_constants.h"
+#include "utils/check_convert_utils.h"
 
 namespace mindspore {
 namespace kernel {
-constexpr size_t kInputNum = 3;
-constexpr size_t kAvgPool3DGradInputNum = 1;
-template <typename T>
-class PoolingGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
+class PoolingGradGpuKernelMod : public NativeGpuKernelMod {
  public:
-  PoolingGradGpuKernelMod()
-      : cudnn_handle_(nullptr),
-        pooling_descriptor_(nullptr),
-        y_descriptor_(nullptr),
-        dy_descriptor_(nullptr),
-        x_descriptor_(nullptr),
-        dx_descriptor_(nullptr),
-        pooling_mode_(CUDNN_POOLING_MAX),
-        cudnn_data_type_(CUDNN_DATA_FLOAT),
-        compute_format_(CUDNN_TENSOR_NCHW),
-        old_depth_(0),
-        old_height_(0),
-        old_width_(0),
-        pad_depth_(0),
-        pad_height_(0),
-        pad_width_(0),
-        pad_front_(0),
-        pad_top_(0),
-        pad_left_(0),
-        n_(0),
-        c_(0),
-        pad_value_(0),
-        is_null_input_(false),
-        kernel_name_("PoolingGrad"),
-        input_size_(0),
-        output_size_(0),
-        workspace_size_(0) {}
+  explicit PoolingGradGpuKernelMod(const std::string &kernel_name) : kernel_name_(kernel_name) { InitResource(); }
   ~PoolingGradGpuKernelMod() override { DestroyResource(); }
 
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
     if (is_null_input_) {
       return true;
     }
-    T *x_data = nullptr;
-    T *y = nullptr;
-    T *dy = nullptr;
-    T *dx = nullptr;
-    if (kernel_name_ == kAvgPool3DGradOpName) {
-      dy = GetDeviceAddress<T>(inputs, 0);
-      dx = GetDeviceAddress<T>(outputs, 0);
-      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&x_data), outputs[0]->size),
-                                 "cudaMalloc failed.");
-      CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaMalloc(reinterpret_cast<void **>(&y), inputs[0]->size),
-                                 "cudaMalloc failed.");
-    } else {
-      x_data = GetDeviceAddress<T>(inputs, 0);
-      y = GetDeviceAddress<T>(inputs, 1);
-      dy = GetDeviceAddress<T>(inputs, 2);
-      dx = GetDeviceAddress<T>(outputs, 0);
-    }
-
-    const float alpha = 1;
-    const float beta = 0;
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy,
-                           x_descriptor_, x_data, &beta, dx_descriptor_, dx),
-      "cudnnPoolingBackward failed");
-    return true;
+    cuda_stream_ = cuda_stream;
+    return kernel_func_(this, inputs, workspace, outputs);
   }
 
-  bool InitShape(const CNodePtr &kernel_node, int *dimA, int *strideAin, int *dimAy, int *strideAiny, int *dimAdy,
-                 int *strideAdy, int *dimAout, int *strideAout, int nbDims) {
-    ShapeVector dout_shape, input_mask, output_shape, input_shape;
-    if (kernel_name_ == kAvgPool3DGradOpName) {
-      dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-      output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-      input_mask = dout_shape;
-      input_shape = output_shape;
-    } else {
-      const size_t kDoutIdx = 2;
-      input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-      input_mask = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
-      dout_shape = AnfAlgo::GetInputDeviceShape(kernel_node, kDoutIdx);
-      output_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-    }
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
 
-    auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
-    format_attr_ = GetAttr<std::string>(kernel_node, "format");
-    if (Anyone(format_attr_, kOpFormat_NHWC, kOpFormat_NDHWC)) {
-      data_format = format_attr_;
-    }
-    cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
-    is_null_input_ =
-      CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") || CHECK_SHAPE_NULL(input_mask, kernel_name_, "mask") ||
-      CHECK_SHAPE_NULL(dout_shape, kernel_name_, "dout") || CHECK_SHAPE_NULL(output_shape, kernel_name_, "output");
-    if (is_null_input_ || AnfAlgo::IsShapesDynamic({input_shape, output_shape, input_mask, dout_shape})) {
-      InitSizeLists();
-      return true;
-    }
-    CheckTensorSize({input_shape, input_mask, dout_shape, output_shape});
-    if (nbDims == kDim2DShapeSize) {
-      SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format);
-    } else if (nbDims == kDim3DShapeSize) {
-      SetNCDHW(input_shape, &n_, &c_, &old_depth_, &old_height_, &old_width_, data_format);
-    }
-    SetDimA(input_shape, dimA, nbDims, data_format);
-    SetStrideA(input_shape, strideAin, nbDims, data_format);
-    SetDimA(input_mask, dimAy, nbDims, data_format);
-    SetStrideA(input_mask, strideAiny, nbDims, data_format);
-    SetDimA(dout_shape, dimAdy, nbDims, data_format);
-    SetStrideA(dout_shape, strideAdy, nbDims, data_format);
-    SetDimA(output_shape, dimAout, nbDims, data_format);
-    SetStrideA(output_shape, strideAout, nbDims, data_format);
-    return true;
-  }
-
-  bool Init(const CNodePtr &kernel_node) override {
-    kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-    kernel_node_ = kernel_node;
-    InitResource();
-    (void)CheckParam(kernel_node);
-    auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-    int nbDims = SizeToInt(input_shape.size());
-
-    int dimA[kPoolingNbDims];
-    int strideAin[kPoolingNbDims];
-    int dimAy[kPoolingNbDims];
-    int strideAiny[kPoolingNbDims];
-    int dimAdy[kPoolingNbDims];
-    int strideAdy[kPoolingNbDims];
-    int dimAout[kPoolingNbDims];
-    int strideAout[kPoolingNbDims];
-    if (!InitShape(kernel_node, dimA, strideAin, dimAy, strideAiny, dimAdy, strideAdy, dimAout, strideAout, nbDims)) {
-      return true;
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetTensorNdDescriptor(y_descriptor_, cudnn_data_type_, nbDims, dimAy, strideAiny),
-                                "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetTensorNdDescriptor(dy_descriptor_, cudnn_data_type_, nbDims, dimAdy, strideAdy),
-                                "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(
-      kernel_node_, cudnnSetTensorNdDescriptor(dx_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout),
-      "cudnnSetTensor4dDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetTensorNdDescriptor(x_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin),
-                                "cudnnSetTensor4dDescriptor failed");
-    SetPoolingMode(kernel_node);
-    if (nbDims == kDim2DShapeSize) {
-      SetPad(kernel_node);
-    } else if (nbDims == kDim3DShapeSize) {
-      SetPad3D(kernel_node);
-    }
-    InitSizeLists();
-    return true;
-  }
-
-  void DestroyResource() noexcept override {
-    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyPoolingDescriptor(pooling_descriptor_),
-                               "cudnnDestroyPoolingDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dx_descriptor_),
-                               "cudnnDestroyTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_descriptor_),
-                               "cudnnDestroyTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(dy_descriptor_),
-                               "cudnnDestroyTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_descriptor_),
-                               "cudnnDestroyTensorDescriptor failed");
-  }
+  void DestroyResource() noexcept override;
 
  protected:
-  void InitResource() override {
-    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&y_descriptor_),
-                                "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dy_descriptor_),
-                                "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_descriptor_),
-                                "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&dx_descriptor_),
-                                "cudnnCreateTensorDescriptor failed");
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreatePoolingDescriptor(&pooling_descriptor_),
-                                "cudnnCreatePoolingDescriptor failed");
-  }
-  void InitSizeLists() override {
-    if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_descriptor_, &input_size_),
-                                  "cudnnGetTensorSizeInBytes failed");
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dx_descriptor_, &output_size_),
-                                  "cudnnGetTensorSizeInBytes failed");
-    }
-    input_size_list_.push_back(input_size_);
-    output_size_list_.push_back(output_size_);
-    if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(y_descriptor_, &input_size_),
-                                  "cudnnGetTensorSizeInBytes failed");
-    }
-    input_size_list_.push_back(input_size_);
-
-    if (!is_null_input_) {
-      CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(dy_descriptor_, &input_size_),
-                                  "cudnnGetTensorSizeInBytes failed");
-    }
-    input_size_list_.push_back(input_size_);
-  }
+  std::vector<KernelAttr> GetOpSupport() override;
+  void InitResource() override;
+  void InitSizeLists();
 
  private:
-  void CheckParam(const CNodePtr &kernel_node) {
-    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-    size_t expect_input_num = (kernel_name_ == kAvgPool3DGradOpName) ? kAvgPool3DGradInputNum : kInputNum;
-    if (input_num != expect_input_num) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be " << expect_input_num
-                        << ", but got " << input_num;
-    }
-  }
-  void SetPad(const CNodePtr &kernel_node) {
-    pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
-    std::vector<int64_t> stride_me = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
-    std::vector<int> window;
-    std::vector<int64_t> window_me = GetAttr<std::vector<int64_t>>(kernel_node, "kernel_size");
-    (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
-                         [](const int64_t &value) { return static_cast<int>(value); });
-    (void)std::transform(window_me.begin(), window_me.end(), std::back_inserter(window),
-                         [](const int64_t &value) { return static_cast<int>(value); });
-    const size_t kSizeLowerLimit = 4;
-    const size_t kIdxH = 2;
-    const size_t kIdxW = 3;
-    if (window.size() < kSizeLowerLimit) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'kernel_size' cannot be less than 4, but got "
-                        << window.size();
-    }
-    int window_height = window[kIdxH];
-    int window_width = window[kIdxW];
-    if (stride_.size() < kSizeLowerLimit) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'strides' cannot be less than 4, but got "
-                        << stride_.size();
-    }
-    int stride_h = stride_[kIdxH];
-    int stride_w = stride_[kIdxW];
-    if (format_attr_ == kOpFormat_NHWC) {
-      const size_t kNHWCIdxH = 1;
-      const size_t kNHWCIdxW = 2;
-      window_height = window[kNHWCIdxH];
-      window_width = window[kNHWCIdxW];
-      stride_h = stride_[kNHWCIdxH];
-      stride_w = stride_[kNHWCIdxW];
-    }
-    const size_t k2dDim = 2;
-    int windowDimA[k2dDim] = {window_height, window_width};
-    int paddingA[k2dDim] = {0, 0};
-    int strideA[k2dDim] = {stride_h, stride_w};
-    if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) {
-      pad_height_ = GetPad(old_height_, window_height, stride_h);
-      pad_width_ = GetPad(old_width_, window_width, stride_w);
-      const int kSymCoef = 2;
-      pad_top_ = pad_height_ / kSymCoef;
-      pad_left_ = pad_width_ / kSymCoef;
-      paddingA[0] = pad_top_;
-      paddingA[1] = pad_left_;
-    } else {
-      if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
-        pad_height_ = 0;
-        pad_width_ = 0;
-      }
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN,
-                                                            2, windowDimA, paddingA, strideA),
-                                "cudnnSetPoolingNdDescriptor failed");
-  }
-
-  void SetPad3D(const CNodePtr &kernel_node) {
-    const int kPadListSize = 6;
-    const int kPadScale = 2;
-    pad_mode_ = GetAttr<std::string>(kernel_node, "pad_mode");
-    std::vector<int64_t> stride_me = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
-    std::vector<int> window;
-    std::vector<int64_t> window_me = GetAttr<std::vector<int64_t>>(kernel_node, "kernel_size");
-    (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
-                         [](const int64_t &value) { return static_cast<int>(value); });
-    (void)std::transform(window_me.begin(), window_me.end(), std::back_inserter(window),
-                         [](const int64_t &value) { return static_cast<int>(value); });
-    const size_t k3dSizeLowerLimit = 5;
-    const size_t kIdxD = 2;
-    const size_t kIdxH = 3;
-    const size_t kIdxW = 4;
-    if (window.size() < k3dSizeLowerLimit) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'kernel_size' cannot be less than 5, but got "
-                        << window.size();
-    }
-    int window_depth = window[kIdxD];
-    int window_height = window[kIdxH];
-    int window_width = window[kIdxW];
-    if (stride_.size() < k3dSizeLowerLimit) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'strides' cannot be less than 5, but got "
-                        << stride_.size();
-    }
-    int stride_d = stride_[kIdxD];
-    int stride_h = stride_[kIdxH];
-    int stride_w = stride_[kIdxW];
-    if (format_attr_ == kOpFormat_NDHWC) {
-      const size_t kNDHWCIdxD = 1;
-      const size_t kNDHWCIdxH = 2;
-      const size_t kNDHWCIdxW = 3;
-      window_depth = window[kNDHWCIdxD];
-      window_height = window[kNDHWCIdxH];
-      window_width = window[kNDHWCIdxW];
-      stride_d = stride_[kNDHWCIdxD];
-      stride_h = stride_[kNDHWCIdxH];
-      stride_w = stride_[kNDHWCIdxW];
-    }
-    const size_t k3dDimSize = 3;
-    int windowDimA[k3dDimSize] = {window_depth, window_height, window_width};
-    int paddingA[k3dDimSize] = {0, 0, 0};
-    int strideA[k3dDimSize] = {stride_d, stride_h, stride_w};
-    if (kSamePadModeUpperCase == pad_mode_ || kSamePadModeLowerCase == pad_mode_) {
-      pad_depth_ = GetPad(old_depth_, window_depth, stride_d);
-      pad_height_ = GetPad(old_height_, window_height, stride_h);
-      pad_width_ = GetPad(old_width_, window_width, stride_w);
-      const int kSymCoef = 2;
-      pad_front_ = pad_depth_ / kSymCoef;
-      pad_top_ = pad_height_ / kSymCoef;
-      pad_left_ = pad_width_ / kSymCoef;
-      paddingA[0] = pad_front_;
-      paddingA[1] = pad_top_;
-      const size_t kPadLeftIdx = 2;
-      paddingA[kPadLeftIdx] = pad_left_;
-    } else if (pad_mode_ == kValidPadModeUpperCase || pad_mode_ == kValidPadModeLowerCase) {
-      pad_depth_ = 0;
-      pad_height_ = 0;
-      pad_width_ = 0;
-    } else {
-      const std::vector<int64_t> &pad_list = GetAttr<std::vector<int64_t>>(kernel_node, "pad_list");
-      if (pad_list.size() != kPadListSize) {
-        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the length of 'pad_list' must be 6, but got "
-                          << pad_list.size();
-      }
-      for (size_t idx = 0; idx < k3dDimSize; idx++) {
-        paddingA[idx] = pad_list[idx * kPadScale];
-      }
-    }
-    CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
-                                cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN,
-                                                            k3dDimSize, windowDimA, paddingA, strideA),
-                                "cudnnSetPoolingNdDescriptor failed");
-  }
-
-  void SetPoolingMode(const CNodePtr &kernel_node) {
-    if (kernel_name_ == kAvgPoolGradOpName || kernel_name_ == kAvgPool3DGradOpName) {
-      pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
-      pad_value_ = 0.0;
-    } else {
-      pooling_mode_ = CUDNN_POOLING_MAX;
-      pad_value_ = kSignedMinFloat;
-    }
-  }
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs);
+  bool InitShape(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs, int *dimA,
+                 int *strideAin, int *dimAy, int *strideAiny, int *dimAdy, int *strideAdy, int *dimAout,
+                 int *strideAout, int nbDims);
+  void SetPad();
+  void SetPad3D();
+  void SetPoolingMode();
+  std::vector<int64_t> GetEdgeKernelSize();
+  using PoolingGradFunc =
+    std::function<bool(PoolingGradGpuKernelMod *, const std::vector<AddressPtr> &,
+                       const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
+  PoolingGradFunc kernel_func_;
+  static std::map<std::string, std::vector<std::pair<KernelAttr, PoolingGradGpuKernelMod::PoolingGradFunc>>>
+    kernel_attr_map_;
 
   cudnnHandle_t cudnn_handle_;
   cudnnPoolingDescriptor_t pooling_descriptor_;
@@ -387,30 +81,42 @@ class PoolingGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
   cudnnTensorDescriptor_t dy_descriptor_;
   cudnnTensorDescriptor_t x_descriptor_;
   cudnnTensorDescriptor_t dx_descriptor_;
-  cudnnPoolingMode_t pooling_mode_ = CUDNN_POOLING_MAX;
-  std::vector<int> stride_;
-
-  std::string pad_mode_;
-  std::string format_attr_ = kOpFormat_NCHW;
   cudnnDataType_t cudnn_data_type_;
   cudnnTensorFormat_t compute_format_;
-  int old_depth_;
-  int old_height_;
-  int old_width_;
-  int pad_depth_;
-  int pad_height_;
-  int pad_width_;
-  int pad_front_;
-  int pad_top_;
-  int pad_left_;
-  int n_;
-  int c_;
-  float pad_value_;
-  bool is_null_input_;
+  cudnnPoolingMode_t pooling_mode_;
+
+  std::vector<int> stride_;
+  std::vector<int64_t> stride_me_;
+  std::vector<int64_t> window_me_;
+  std::vector<int64_t> input_shape_;
+  std::vector<int64_t> output_shape_;
+  std::vector<int64_t> edge_kernel_;
+  std::vector<int64_t> kernel_size_;
+  std::vector<int64_t> pad_list_;
   std::string kernel_name_;
-  size_t input_size_;
-  size_t output_size_;
-  size_t workspace_size_;
+  PadMode pad_mode_;
+  Format format_attr_ = Format::NCHW;
+
+  int old_depth_{0};
+  int old_height_{0};
+  int old_width_{0};
+  int pad_depth_{0};
+  int pad_height_{0};
+  int pad_width_{0};
+  int pad_front_{0};
+  int pad_top_{0};
+  int pad_left_{0};
+  int n_{0};
+  int c_{0};
+  float pad_value_{0.0};
+  bool is_null_input_{false};
+  bool include_{false};
+  bool ceil_mode_{false};
+  int64_t divisor_override_{0};
+  size_t input_size_{0};
+  size_t output_size_{0};
+  size_t workspace_size_{0};
+  void *cuda_stream_{nullptr};
 };
 }  // namespace kernel
 }  // namespace mindspore
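Note: the header now declares one non-templated kernel mod whose Launch dispatches through kernel_func_, selected at Init time from the static kernel_attr_map_. A minimal sketch of that table-of-member-functions pattern; std::string stands in for KernelAttr and everything here is illustrative, not the real interface:

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

class Kernel {
 public:
  // One templated implementation; the table binds concrete instantiations.
  template <typename T>
  bool LaunchKernel() {
    std::cout << "launching with sizeof(T) = " << sizeof(T) << "\n";
    return true;
  }
  using Func = std::function<bool(Kernel *)>;
  static std::map<std::string, std::vector<std::pair<std::string, Func>>> attr_map_;
};

std::map<std::string, std::vector<std::pair<std::string, Kernel::Func>>> Kernel::attr_map_ = {
  {"AvgPool3DGrad",
   {{"float64", &Kernel::LaunchKernel<double>},
    {"float32", &Kernel::LaunchKernel<float>}}}};

int main() {
  Kernel k;
  // Init() would pick the entry whose attr matches the runtime dtype.
  return Kernel::attr_map_["AvgPool3DGrad"][0].second(&k) ? 0 : 1;
}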
diff --git a/mindspore/core/ops/core_ops.h b/mindspore/core/ops/core_ops.h
index 731c4e21339..8ac331087b7 100644
--- a/mindspore/core/ops/core_ops.h
+++ b/mindspore/core/ops/core_ops.h
@@ -533,6 +533,7 @@ GVAR_DEF(PrimitivePtr, kPrimPSROIPoolingGrad, std::make_shared<Primitive>("PSROIPoolingGrad"));
 GVAR_DEF(PrimitivePtr, kPrimROIPooling, std::make_shared<Primitive>("ROIPooling"));
 GVAR_DEF(PrimitivePtr, kPrimMaxPool, std::make_shared<Primitive>("MaxPool"));
 GVAR_DEF(PrimitivePtr, kPrimMaxPoolGrad, std::make_shared<Primitive>("MaxPoolGrad"));
+GVAR_DEF(PrimitivePtr, kPrimMaxPool3DGrad, std::make_shared<Primitive>("MaxPool3DGrad"));
 GVAR_DEF(PrimitivePtr, kPrimMaxPoolV1, std::make_shared<Primitive>("MaxPoolV1"));
 GVAR_DEF(PrimitivePtr, kPrimMaxPoolGradV1, std::make_shared<Primitive>("MaxPoolGradV1"));
 GVAR_DEF(PrimitivePtr, kPrimMaxPoolWithArgmax, std::make_shared<Primitive>("MaxPoolWithArgmax"));
diff --git a/mindspore/core/ops/gather.cc b/mindspore/core/ops/gather.cc
index b93688df39b..5fedf10623c 100644
--- a/mindspore/core/ops/gather.cc
+++ b/mindspore/core/ops/gather.cc
@@ -30,6 +30,8 @@ namespace {
 abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   const std::string &op_name = primitive->name();
+  const int64_t input_num = 2;
+  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name);
   abstract::AbstractTensorPtr indices =
     CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(op_name, input_args, 1);
   abstract::AbstractTensorPtr params =
@@ -42,7 +44,11 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
     std::any_of(params->shape()->shape().begin(), params->shape()->shape().end(), [](int64_t s) { return s < 0; });
   int64_t axis_val = 0;
   // 3rd input is a Tensor when Gather is a dynamic shape operator
-  if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
+  if (SizeToLong(input_args.size()) == input_num) {
+    auto axis_attr = primitive->GetAttr("axis");
+    MS_EXCEPTION_IF_NULL(axis_attr);
+    axis_val = GetValue<int64_t>(axis_attr);
+  } else if (input_args[kInputIndex2]->isa<abstract::AbstractTensor>()) {
(input_args[kInputIndex2]->isa()) { auto axis = input_args[kInputIndex2]->cast(); MS_EXCEPTION_IF_NULL(axis); auto axis_value_ptr = axis->BuildValue(); @@ -94,15 +100,17 @@ abstract::ShapePtr GatherInferShape(const PrimitivePtr &primitive, const std::ve TypePtr GatherInferType(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); const std::string &op_name = primitive->name(); - constexpr size_t input_num = 3; - abstract::CheckArgsSize(op_name, input_args, input_num); + constexpr int64_t input_num = 2; + (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, op_name); std::set valid_params_types = {kTensorType}; (void)CheckAndConvertUtils::CheckSubClass("params", input_args[kInputIndex0]->BuildType(), valid_params_types, op_name); std::set int_types = {kInt8, kInt16, kInt32, kInt64}; (void)CheckAndConvertUtils::CheckTensorTypeValid("indices", input_args[kInputIndex1]->BuildType(), int_types, op_name); - (void)CheckAndConvertUtils::CheckTypeValid("axis", input_args[kInputIndex2]->BuildType(), int_types, op_name); + if (SizeToLong(input_args.size()) > input_num) { + (void)CheckAndConvertUtils::CheckTypeValid("axis", input_args[kInputIndex2]->BuildType(), int_types, op_name); + } abstract::AbstractTensorPtr params = CheckAndConvertUtils::CheckArgs(op_name, input_args, 0); @@ -113,7 +121,7 @@ TypePtr GatherInferType(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - const int64_t kInputsNum = 3; + const int64_t kInputsNum = 2; CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputsNum, primitive->name()); auto infer_type = GatherInferType(primitive, input_args); auto infer_shape = GatherInferShape(primitive, input_args); diff --git a/mindspore/core/ops/grad/avg_pool_3d_grad.cc b/mindspore/core/ops/grad/avg_pool_3d_grad.cc index 428cc6b92cc..fd73562ceaf 100644 --- a/mindspore/core/ops/grad/avg_pool_3d_grad.cc +++ b/mindspore/core/ops/grad/avg_pool_3d_grad.cc @@ -22,28 +22,76 @@ #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" #include "mindapi/src/helper.h" +#include "utils/ms_context.h" namespace mindspore { namespace ops { -namespace { +void AvgPool3DGrad::Init(const std::vector &kernel_size, const std::vector &strides, + const PadMode &pad_mode, const std::vector &pad_list, bool ceil_mode, + bool count_include_pad, int64_t divisor_override, const Format &format) { + set_kernel_size(kernel_size); + set_strides(strides); + set_pad_mode(pad_mode); + set_pad_list(pad_list); + set_ceil_mode(ceil_mode); + set_count_include_pad(count_include_pad); + set_divisor_override(divisor_override); + set_format(format); +} + +void AvgPool3DGrad::set_pad_list(const std::vector &pad_list) { + const int64_t pad_size = 4; + (void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, pad_size, name()); + (void)AddAttr(kPadList, api::MakeValue(pad_list)); +} + +void AvgPool3DGrad::set_ceil_mode(bool ceil_mode) { (void)AddAttr(kCeilMode, api::MakeValue(ceil_mode)); } + +void AvgPool3DGrad::set_count_include_pad(bool count_include_pad) { + (void)AddAttr(kCountIncludePad, api::MakeValue(count_include_pad)); +} + +void AvgPool3DGrad::set_divisor_override(int64_t divisor_override) { + (void)AddAttr(kDivisorOverride, api::MakeValue(divisor_override)); +} + +std::vector AvgPool3DGrad::get_pad_list() const { + auto value_ptr = GetAttr(kPadList); + return GetValue>(value_ptr); +} + +bool AvgPool3DGrad::get_ceil_mode() const { 
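The new ceil_mode / count_include_pad / divisor_override attributes interact with the edge_kernel_ bookkeeping declared earlier: when ceil_mode rounds the output length up, the last window hangs past the input, and its clipped extent is what GetEdgeKernelSize() must capture so the averaging divisor stays correct. A small self-contained sketch of that arithmetic, using the standard pooling formulas rather than this patch's code:

import math

def pool_out_len(in_len, kernel, stride, pad, ceil_mode):
    # floor((L + 2p - k) / s) + 1, with ceil instead of floor in ceil mode
    span = in_len + 2 * pad - kernel
    return (math.ceil(span / stride) if ceil_mode else span // stride) + 1

def edge_kernel_len(in_len, kernel, stride, pad, ceil_mode):
    # Extent of the last window that actually overlaps the (padded) input.
    out_len = pool_out_len(in_len, kernel, stride, pad, ceil_mode)
    last_start = (out_len - 1) * stride - pad
    return min(kernel, in_len + pad - last_start)

# kernel=2, stride=2 over length 5: ceil mode adds a final window covering a
# single element, so the divisor at that edge must shrink from 2 to 1.
assert pool_out_len(5, 2, 2, 0, ceil_mode=False) == 2
assert pool_out_len(5, 2, 2, 0, ceil_mode=True) == 3
assert edge_kernel_len(5, 2, 2, 0, ceil_mode=True) == 1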
return GetValue(GetAttr(kCeilMode)); } + +bool AvgPool3DGrad::get_count_include_pad() const { return GetValue(GetAttr(kCountIncludePad)); } + +int64_t AvgPool3DGrad::get_divisor_override() const { return GetValue(GetAttr(kDivisorOverride)); } + abstract::ShapePtr AvgPool3DGradInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - const int64_t input_num = 2; - (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name); + const int64_t input_num = 1; + (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, + op_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]; + size_t grad_index = input_args.size() - 1; + auto grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[grad_index]->GetShapeTrack())[kShape]; constexpr int64_t k5DInputDims = 5; (void)CheckAndConvertUtils::CheckInteger("grad_rank", SizeToLong(grad_shape.size()), kEqual, k5DInputDims, op_name); std::vector origin_input_size; - if (input_args[0]->isa()) { // origin_size is tuple - origin_input_size = GetValue>(input_args[0]->BuildValue()); + if (SizeToLong(input_args.size()) == input_num) { + auto shape_attr = primitive->GetAttr("origin_input_shape"); + MS_EXCEPTION_IF_NULL(shape_attr); + origin_input_size = GetValue(shape_attr); } else { - MS_LOG(EXCEPTION) << "For '" << op_name << "', the first input data size must be a tuple, but got: " - << input_args[0]->BuildShape()->ToString() << "."; + if (input_args[0]->isa()) { // origin_size is tuple + origin_input_size = GetValue>(input_args[0]->BuildValue()); + } else { + MS_LOG(EXCEPTION) << "For '" << op_name << "', the first input data size must be a tuple, but got: " + << input_args[0]->BuildShape()->ToString() << "."; + } } return std::make_shared(origin_input_size); } @@ -51,18 +99,25 @@ abstract::ShapePtr AvgPool3DGradInferShape(const PrimitivePtr &primitive, TypePtr AvgPool3DGradInferType(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); auto op_name = primitive->name(); - const int64_t input_num = 2; - (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name); + const int64_t input_num = 1; + (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kGreaterEqual, input_num, + op_name); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } - auto grad_dtype = input_args[1]->BuildType(); - const std::set valid_types = {kFloat16, kFloat32}; + size_t grad_index = input_args.size() - 1; + auto grad_dtype = input_args[grad_index]->BuildType(); + std::set valid_types = {kFloat16, kFloat32}; + auto context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context); + bool is_gpu = (context->get_param(MS_CTX_DEVICE_TARGET) == kGPUDevice); + if (is_gpu) { + valid_types = {kFloat16, kFloat32, kFloat64}; + } return CheckAndConvertUtils::CheckTensorTypeValid("grad", grad_dtype, valid_types, op_name); } -} // namespace -MIND_API_OPERATOR_IMPL(AvgPool3DGrad, BaseOperator); +MIND_API_OPERATOR_IMPL(AvgPool3DGrad, PoolGrad); AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { auto res = 
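AvgPool3DGradInferShape now supports the attribute-only form: when the node carries a single (grad) input, the output shape is read from the 'origin_input_shape' attribute rather than from a tuple input, and on GPU the grad dtype whitelist grows to include float64. A sketch of the shape branch, with illustrative dict-based stand-ins for the C++ abstract values:

def avg_pool3d_grad_infer_shape(input_args, attrs):
    grad_shape = input_args[-1]["shape"]           # grad is always the last input
    assert len(grad_shape) == 5, "grad must be 5-D"
    if len(input_args) == 1:
        return list(attrs["origin_input_shape"])   # attribute fallback
    origin = input_args[0]
    if origin.get("kind") != "tuple":              # origin_input_size must be a tuple
        raise TypeError("first input must be a tuple of ints")
    return list(origin["value"])

print(avg_pool3d_grad_infer_shape(
    [{"shape": [1, 3, 1, 2, 3]}],
    {"origin_input_shape": (1, 3, 2, 3, 4)}))      # -> [1, 3, 2, 3, 4]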
std::make_shared(AvgPool3DGradInferType(primitive, input_args), diff --git a/mindspore/core/ops/grad/avg_pool_3d_grad.h b/mindspore/core/ops/grad/avg_pool_3d_grad.h index ea58adfc241..bcdbb19d206 100644 --- a/mindspore/core/ops/grad/avg_pool_3d_grad.h +++ b/mindspore/core/ops/grad/avg_pool_3d_grad.h @@ -21,17 +21,33 @@ #include #include #include -#include "ops/base_operator.h" +#include "ops/grad/pool_grad.h" #include "mindapi/base/types.h" +#include "mindapi/base/format.h" namespace mindspore { namespace ops { constexpr auto kNameAvgPool3DGrad = "AvgPool3DGrad"; -class MIND_API AvgPool3DGrad : public BaseOperator { +class MIND_API AvgPool3DGrad : public PoolGrad { public: MIND_API_BASE_MEMBER(AvgPool3DGrad); - AvgPool3DGrad() : BaseOperator(kNameAvgPool3DGrad) { InitIOName({"origin_input_size", "grad"}, {"output"}); } + AvgPool3DGrad() : PoolGrad(kNameAvgPool3DGrad) { InitIOName({"origin_input_size", "grad"}, {"output"}); } + + void Init(const std::vector &kernel_size = {1, 1, 1, 1, 1}, + const std::vector &strides = {1, 1, 1, 1, 1}, const PadMode &pad_mode = VALID, + const std::vector &pad_list = {0, 0, 0, 0, 0, 0}, bool ceil_mode = false, + bool count_include_pad = true, int64_t divisor_override = 0, const Format &format = NCHW); + + void set_pad_list(const std::vector &pad_list); + void set_ceil_mode(bool ceil_mode); + void set_count_include_pad(bool count_include_pad); + void set_divisor_override(int64_t divisor_override); + + std::vector get_pad_list() const; + bool get_ceil_mode() const; + bool get_count_include_pad() const; + int64_t get_divisor_override() const; }; abstract::AbstractBasePtr AvgPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/ops/grad/max_pool_3d_grad.cc b/mindspore/core/ops/grad/max_pool_3d_grad.cc new file mode 100644 index 00000000000..fb73c8cdf1f --- /dev/null +++ b/mindspore/core/ops/grad/max_pool_3d_grad.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
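The Init defaults above expose count_include_pad and divisor_override, whose semantics follow the usual 3-D average-pooling convention (a hedged summary, not spelled out in this patch): a nonzero divisor_override replaces the averaging divisor outright; otherwise count_include_pad decides whether padded cells count toward it.

def avg_divisor(valid_cells, kernel_volume, count_include_pad, divisor_override):
    # Sketch of the divisor choice for one pooling window.
    if divisor_override != 0:
        return divisor_override            # explicit override wins
    return kernel_volume if count_include_pad else valid_cells

# A 2x2x2 window at a padded corner with one real input cell:
assert avg_divisor(1, 8, count_include_pad=True, divisor_override=0) == 8
assert avg_divisor(1, 8, count_include_pad=False, divisor_override=0) == 1
assert avg_divisor(1, 8, count_include_pad=False, divisor_override=5) == 5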
+ */ + +#include "ops/grad/max_pool_3d_grad.h" +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/ops/primitive_infer_map.h" +#include "mindapi/src/helper.h" + +namespace mindspore { +namespace ops { +void MaxPool3DGrad::Init(const std::vector &kernel_size, const std::vector &strides, + const PadMode &pad_mode, const std::vector &pad_list, const Format &format) { + set_kernel_size(kernel_size); + set_strides(strides); + set_pad_mode(pad_mode); + set_pad_list(pad_list); + set_format(format); +} + +void MaxPool3DGrad::set_pad_list(const std::vector &pad_list) { + const int64_t pad_size = 4; + (void)CheckAndConvertUtils::CheckInteger(kPadList, SizeToLong(pad_list.size()), kEqual, pad_size, name()); + (void)AddAttr(kPadList, api::MakeValue(pad_list)); +} + +std::vector MaxPool3DGrad::get_pad_list() const { + auto value_ptr = GetAttr(kPadList); + return GetValue>(value_ptr); +} + +abstract::ShapePtr MaxPool3DGradInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]; + constexpr int64_t k5DInputDims = 5; + (void)CheckAndConvertUtils::CheckInteger("input_rank", SizeToLong(x_shape.size()), kEqual, k5DInputDims, op_name); + return std::make_shared(x_shape); +} + +TypePtr MaxPool3DGradInferType(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + const int64_t input_num = 3; + (void)CheckAndConvertUtils::CheckInteger("input size", SizeToLong(input_args.size()), kEqual, input_num, op_name); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto x_dtype = input_args[0]->BuildType(); + const std::set valid_types = {kFloat16, kFloat32}; + return CheckAndConvertUtils::CheckTensorTypeValid("input", x_dtype, valid_types, op_name); +} + +MIND_API_OPERATOR_IMPL(MaxPool3DGrad, PoolGrad); +AbstractBasePtr MaxPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto res = std::make_shared(MaxPool3DGradInferType(primitive, input_args), + MaxPool3DGradInferShape(primitive, input_args)->shape()); + return res; +} + +REGISTER_PRIMITIVE_EVAL_IMPL(MaxPool3DGrad, prim::kPrimMaxPool3DGrad, MaxPool3DGradInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/grad/max_pool_3d_grad.h b/mindspore/core/ops/grad/max_pool_3d_grad.h new file mode 100644 index 00000000000..62c9dbb6ae2 --- /dev/null +++ b/mindspore/core/ops/grad/max_pool_3d_grad.h @@ -0,0 +1,47 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
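The new MaxPool3DGrad infer pins dx to the shape of x_origin and restricts dtypes to float16/float32. For intuition, a NumPy reference of the gradient contract — a sketch only, since the real kernel is cuDNN-backed: each grad element flows to the argmax position of its pooling window.

import numpy as np

def max_pool3d_grad_ref(x, dy, kernel, stride):
    # dx has x's shape; route every dy element to its window's argmax.
    dx = np.zeros_like(x)
    n, c, od, oh, ow = dy.shape
    for ni, ci, zi, yi, xi in np.ndindex(n, c, od, oh, ow):
        z0, y0, x0 = zi * stride, yi * stride, xi * stride
        window = x[ni, ci, z0:z0 + kernel, y0:y0 + kernel, x0:x0 + kernel]
        dz, dh, dw = np.unravel_index(np.argmax(window), window.shape)
        dx[ni, ci, z0 + dz, y0 + dh, x0 + dw] += dy[ni, ci, zi, yi, xi]
    return dx

x = np.arange(8, dtype=np.float32).reshape(1, 1, 2, 2, 2)
dy = np.ones((1, 1, 1, 1, 1), dtype=np.float32)
assert max_pool3d_grad_ref(x, dy, kernel=2, stride=2)[0, 0, 1, 1, 1] == 1.0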
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_MAX_POOL_3D_GRAD_H_ +#define MINDSPORE_CORE_OPS_MAX_POOL_3D_GRAD_H_ +#include +#include +#include +#include +#include "ops/grad/pool_grad.h" +#include "mindapi/base/types.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameMaxPool3DGrad = "MaxPool3DGrad"; +class MIND_API MaxPool3DGrad : public PoolGrad { + public: + MIND_API_BASE_MEMBER(MaxPool3DGrad); + MaxPool3DGrad() : PoolGrad(kNameMaxPool3DGrad) { InitIOName({"x_origin", "out_origin", "grad"}, {"output"}); } + + void Init(const std::vector &kernel_size = {1, 1, 1, 1, 1}, + const std::vector &strides = {1, 1, 1, 1, 1}, const PadMode &pad_mode = VALID, + const std::vector &pad_list = {0, 0, 0, 0, 0, 0}, const Format &format = NCHW); + void set_pad_list(const std::vector &pad_list); + std::vector get_pad_list() const; +}; + +abstract::AbstractBasePtr MaxPool3DGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +using PrimMaxPool3DGradPtr = std::shared_ptr; +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_MAX_POOL_3D_GRAD_H_ diff --git a/mindspore/core/ops/grad/pool_grad.cc b/mindspore/core/ops/grad/pool_grad.cc index bcf337b4e76..6db1a8e5556 100644 --- a/mindspore/core/ops/grad/pool_grad.cc +++ b/mindspore/core/ops/grad/pool_grad.cc @@ -15,11 +15,26 @@ */ #include "ops/grad/pool_grad.h" +#include #include "ops/op_utils.h" #include "mindapi/src/helper.h" namespace mindspore { namespace ops { +static std::map pad_map = { + {"CALCULATED", PadMode::PAD}, + {"PAD", PadMode::PAD}, + {"SAME", PadMode::SAME}, + {"VALID", PadMode::VALID}, +}; + +static std::map dataformat_map = { + {"NCHW", Format::NCHW}, + {"NHWC", Format::NHWC}, + {"NCDHW", Format::NCDHW}, + {"NDHWC", Format::NDHWC}, +}; + MIND_API_OPERATOR_IMPL(PoolGrad, BaseOperator); std::vector PoolGrad::_grad_check_vector(const std::string &arg_name, std::vector arg_val, const std::string &op_name) { @@ -91,13 +106,31 @@ std::vector PoolGrad::get_strides() const { PadMode PoolGrad::get_pad_mode() const { auto value_ptr = GetAttr(kPadMode); MS_EXCEPTION_IF_NULL(value_ptr); - return PadMode(GetValue(value_ptr)); + if (!value_ptr->isa()) { + return PadMode(GetValue(value_ptr)); + } + auto attr_value_str = GetValue(value_ptr); + std::transform(attr_value_str.begin(), attr_value_str.end(), attr_value_str.begin(), toupper); + auto iter = pad_map.find(attr_value_str); + if (iter == pad_map.end()) { + MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use CALCULATED, PAD, VALID or SAME"; + } + return PadMode(iter->second); } Format PoolGrad::get_format() const { auto value_ptr = GetAttr(kFormat); MS_EXCEPTION_IF_NULL(value_ptr); - return Format(GetValue(value_ptr)); + if (!value_ptr->isa()) { + return Format(GetValue(value_ptr)); + } + auto attr_value_str = GetValue(value_ptr); + std::transform(attr_value_str.begin(), attr_value_str.end(), attr_value_str.begin(), toupper); + auto iter = dataformat_map.find(attr_value_str); + if (iter == dataformat_map.end()) { + MS_LOG(EXCEPTION) << "Invalid format " << attr_value_str << " use NCHW, NHWC NCDHW or NDHWC"; + } + return Format(iter->second); } REGISTER_PRIMITIVE_C(kNamePoolGrad, PoolGrad); diff --git a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py index 6a724927f1f..eea30910045 100644 --- 
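PoolGrad::get_pad_mode and get_format above now tolerate string-valued attributes by upper-casing them and mapping through pad_map / dataformat_map, while staying backward compatible with enum-valued attrs. The same normalization in Python, with a dict standing in for the std::map and illustrative return values:

PAD_MAP = {"CALCULATED": "PAD", "PAD": "PAD", "SAME": "SAME", "VALID": "VALID"}

def get_pad_mode(attr_value):
    if not isinstance(attr_value, str):   # already an enum value: pass through
        return attr_value
    key = attr_value.upper()              # accept "same", "Valid", ...
    if key not in PAD_MAP:
        raise ValueError(f"Invalid pad mode {attr_value}, use CALCULATED, PAD, VALID or SAME")
    return PAD_MAP[key]

assert get_pad_mode("same") == "SAME"
assert get_pad_mode(2) == 2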
a/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py +++ b/mindspore/python/mindspore/ops/_vmap/vmap_grad_nn_ops.py @@ -130,6 +130,34 @@ def get_avg_pool_grad_vmap_rule(prim, axis_size): return vmap_rule +@vmap_rules_getters.register(G.AvgPool3DGrad) +def get_avg_pool3d_grad_vmap_rule(prim, axis_size): + """VmapRule for `AvgPool3DGrad`.""" + cdhw_reverse_index = -4 + + def vmap_rule(shape_bdim, dy_bdim): + is_all_none, result = vmap_general_preprocess(prim, shape_bdim, dy_bdim) + if is_all_none: + return result + + shape, shape_dim = shape_bdim + dy, dy_dim = dy_bdim + if shape_dim is not None: + _raise_value_error("The source axis of 'origin_input_shape' in 'AvgPool3DGrad' must be None, " + "but got {}.".format(shape_dim)) + dy = _bdim_at_front(dy, dy_dim, axis_size) + dy_shape = F.shape(dy) + dy = F.reshape(dy, (-1,) + dy_shape[cdhw_reverse_index:]) + input_shape = (F.shape(dy)[0],) + shape[cdhw_reverse_index:] + out = prim(input_shape, dy) + out_shape = F.shape(out) + return_shape = dy_shape[:cdhw_reverse_index] + out_shape[cdhw_reverse_index:] + out = F.reshape(out, return_shape) + return (out, 0) + + return vmap_rule + + @vmap_rules_getters.register(G.CdistGrad) def get_cdist_grad_vmap_rule(prim, axis_size): """VmapRule for `cdist grad` operation.""" diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py index abc0ebd5afc..b6bfffb8e2b 100644 --- a/mindspore/python/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py @@ -1046,12 +1046,13 @@ def _get_max_pool3d_grad_pads_by_pad_mode(input_shape, kernel_size, strides, pad return pads -class MaxPool3DGrad(PrimitiveWithInfer): +class MaxPool3DGrad(Primitive): """Gradients of the max pool3d operation.""" @prim_attr_register def __init__(self, kernel_size=(1, 1, 1, 1, 1), strides=(1, 1, 1, 1, 1), pad_mode='VALID', pad_list=0, data_format="NCDHW"): + self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name) validator.check_value_type('strides', strides, [int, tuple], self.name) validator.check_value_type('pad_mode', pad_mode, [str], self.name) @@ -1080,15 +1081,6 @@ class MaxPool3DGrad(PrimitiveWithInfer): validator.check_non_negative_int(item, 'pad_list item', self.name) self.add_prim_attr("pad_list", self.pad_list) - def infer_shape(self, x_shape, y_shape, grad_shape): - validator.check_equal_int(len(x_shape), 5, "x rank", self.name) - return x_shape - - def infer_dtype(self, x_dtype, y_dtype, grad_dtype): - args = {'x_dtype': x_dtype, 'y_dtype': y_dtype, 'grad_dtype': grad_dtype} - validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) - return x_dtype - class MaxPool3DGradGrad(PrimitiveWithInfer): """Gradients of the max pool3d grad operation.""" diff --git a/tests/st/ops/gpu/test_avgpool_op.py b/tests/st/ops/gpu/test_avgpool_op.py index 4ad117beba6..fc751c84389 100644 --- a/tests/st/ops/gpu/test_avgpool_op.py +++ b/tests/st/ops/gpu/test_avgpool_op.py @@ -341,3 +341,62 @@ def test_avgpool_grad_vmap(): nest_vmap = vmap(vmap(net, in_axes=in_axes, out_axes=0), in_axes=in_axes, out_axes=0) out = nest_vmap(x, sens) assert out[0].shape == (6, 3, 1, 1, 6, 6) + + +class DynamicShapeAvgPool3DGrad(nn.Cell): + def __init__(self, net, axis=0): + super(DynamicShapeAvgPool3DGrad, self).__init__() + self.net = net + self.unique = P.Unique() + self.gather = P.Gather() + self.axis = axis 
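In the AvgPool3DGrad vmap rule registered above, the leading vmap axis is folded into N so the primitive runs once on a 5-D tensor, then the result is unfolded. A NumPy-only walk-through of those reshapes, with shapes invented for the example:

import numpy as np

axis_size, n, c, d, h, w = 6, 2, 3, 2, 3, 4
dy = np.zeros((axis_size, n, c, 1, 2, 3), dtype=np.float32)  # batched grad
dy_flat = dy.reshape((-1,) + dy.shape[-4:])                  # fold vmap dim into N
input_shape = (dy_flat.shape[0], c, d, h, w)                 # origin shape with folded N
out = np.zeros(input_shape, dtype=np.float32)                # stand-in for prim(input_shape, dy_flat)
out = out.reshape(dy.shape[:-4] + out.shape[-4:])            # unfold to (axis_size, n, C, D, H, W)
assert out.shape == (axis_size, n, c, d, h, w)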
+ + def construct(self, x_shape, sens, indices): + unique_indices, _ = self.unique(indices) + sens = self.gather(sens, unique_indices, self.axis) + return self.net(x_shape, sens) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_avgpool3d_grad_dynamic_shape(): + """ + Feature: AvgPool3dGrad dynamic test. + Description: Run unique and gather ops before AvgPool3dGrad. + Expectation: success. + """ + x_shape = (1, 3, 2, 3, 4) + x = Tensor(np.arange(reduce(lambda x, y: x * y, x_shape))).reshape(x_shape).astype(np.float32) + avgpool = AvgPool(dim=3, kernel_size=2, strides=1, pad_mode='VALID') + expect_output = np.array([[[[[8.5, 9.5, 10.5], + [12.5, 13.5, 14.5]]], + [[[32.5, 33.5, 34.5], + [36.5, 37.5, 38.5]]], + [[[56.5, 57.5, 58.5], + [60.5, 61.5, 62.5]]]]]).astype(np.float32) + + avgpool_grad = AvgPoolGrad(avgpool) + net = DynamicShapeAvgPool3DGrad(avgpool_grad) + sens = Tensor(expect_output) + 1 + indices = Tensor(np.array([0]).astype(np.int32)) + actual_grad = net(x, sens, indices) + expect_grad = np.array([[[[[1.1875, 2.5, 2.75, 1.4375], + [2.875, 6., 6.5, 3.375], + [1.6875, 3.5, 3.75, 1.9375]], + [[1.1875, 2.5, 2.75, 1.4375], + [2.875, 6., 6.5, 3.375], + [1.6875, 3.5, 3.75, 1.9375]]], + [[[4.1875, 8.5, 8.75, 4.4375], + [8.875, 18., 18.5, 9.375], + [4.6875, 9.5, 9.75, 4.9375]], + [[4.1875, 8.5, 8.75, 4.4375], + [8.875, 18., 18.5, 9.375], + [4.6875, 9.5, 9.75, 4.9375]]], + [[[7.1875, 14.5, 14.75, 7.4375], + [14.875, 30., 30.5, 15.375], + [7.6875, 15.5, 15.75, 7.9375]], + [[7.1875, 14.5, 14.75, 7.4375], + [14.875, 30., 30.5, 15.375], + [7.6875, 15.5, 15.75, 7.9375]]]]]).astype(np.float32) + assert np.allclose(actual_grad[0].asnumpy(), expect_grad)
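The expected gradient in this test can be cross-checked with a short NumPy reference (VALID padding, kernel 2, stride 1): every sens element is spread as sens / 8 over its 2x2x2 window. A sketch, not the GPU kernel:

import numpy as np

def avg_pool3d_grad_ref(x_shape, sens, kernel=2, stride=1):
    dx = np.zeros(x_shape, dtype=sens.dtype)
    share = sens / float(kernel ** 3)                  # uniform split per window
    n, c, od, oh, ow = sens.shape
    for ni, ci, zi, yi, xi in np.ndindex(n, c, od, oh, ow):
        z0, y0, x0 = zi * stride, yi * stride, xi * stride
        dx[ni, ci, z0:z0 + kernel, y0:y0 + kernel, x0:x0 + kernel] += share[ni, ci, zi, yi, xi]
    return dx

x = np.arange(72, dtype=np.float32).reshape(1, 3, 2, 3, 4)
sens = np.zeros((1, 3, 1, 2, 3), dtype=np.float32)
for idx in np.ndindex(*sens.shape):
    ni, ci, zi, yi, xi = idx
    sens[idx] = x[ni, ci, zi:zi + 2, yi:yi + 2, xi:xi + 2].mean() + 1.0  # expect_output + 1
dx = avg_pool3d_grad_ref(x.shape, sens)
assert np.isclose(dx[0, 0, 0, 0, 0], 1.1875)  # first entry of expect_grad above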