From ca881ec03e887f838f3e4a82050c6e3ddee2fc3f Mon Sep 17 00:00:00 2001 From: panfengfeng Date: Tue, 21 Jul 2020 06:21:07 -0800 Subject: [PATCH] add maxpool_with_argmax/grad cuda kernel --- .../maxpool_with_argmax_grad_impl.cu | 226 ++++++++++++++++++ .../maxpool_with_argmax_grad_impl.cuh | 25 ++ .../gpu/cuda_impl/maxpool_with_argmax_impl.cu | 149 ++++++++++++ .../cuda_impl/maxpool_with_argmax_impl.cuh | 25 ++ .../gpu/nn/maxpool_with_argmax_gpu_kernel.cc | 30 +++ .../gpu/nn/maxpool_with_argmax_gpu_kernel.h | 160 +++++++++++++ .../nn/maxpool_with_argmax_grad_gpu_kernel.cc | 36 +++ .../nn/maxpool_with_argmax_grad_gpu_kernel.h | 168 +++++++++++++ mindspore/ops/operations/nn_ops.py | 3 + model_zoo/official/cv/googlenet/eval.py | 2 +- .../cv/googlenet/scripts/run_train_gpu.sh | 9 +- .../official/cv/googlenet/src/googlenet.py | 102 +++----- model_zoo/official/cv/googlenet/train.py | 2 +- .../gpu/test_maxpool_with_argmax_gpu_op.py | 147 ++++++++++++ .../test_maxpool_with_argmax_grad_gpu_op.py | 115 +++++++++ 15 files changed, 1128 insertions(+), 71 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_maxpool_with_argmax_gpu_op.py create mode 100644 
tests/st/ops/gpu/test_maxpool_with_argmax_grad_gpu_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cu new file mode 100644 index 00000000000..863f3a7a851 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cu @@ -0,0 +1,226 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include <algorithm>
+#include "maxpool_with_argmax_grad_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+
+template <typename T, typename S>  // T: element type of x/dy/dx; S: integer type of the stored argmax
+__global__ void MaxPoolWithArgmaxGrad(const T* x,   // forward input; not read by this kernel
+                                      const T* dy,  // upstream gradient, laid out [n][c][dyHeight][dyWidth]
+                                      const S* index,  // same layout as dy; compared against h*xWidth + w
+                                      const int n,
+                                      const int c,
+                                      const int xHeight,
+                                      const int xWidth,
+                                      const int dyHeight,
+                                      const int dyWidth,
+                                      const int windowHeight,
+                                      const int windowWidth,
+                                      const int strideHeight,
+                                      const int strideWidth,
+                                      const int padTop,
+                                      const int padLeft,
+                                      const int xNCHW,  // total number of dx elements
+                                      const int xCHW,
+                                      const int xHW,
+                                      const int dyCHW,
+                                      const int dyHW,
+                                      T* dx) {  // output gradient, xNCHW elements; every element is written
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;  // one dx element per iteration
+       pos < (xNCHW);
+       pos += blockDim.x * gridDim.x) {
+    const int posn = pos / xCHW;
+    const int posc = pos / xHW % c;
+    const int posh = pos / xWidth % xHeight;  // row = (pos / W) % H; dividing by xHeight broke non-square inputs
+    const int posw = pos % xWidth;
+    const S posIdx = posh*xWidth + posw;  // offset of this element inside its (n, c) plane
+    int hstart = posh+padTop;  // [hstart, hend) x [wstart, wend): dy cells whose window may cover (posh, posw)
+    if (hstart < windowHeight) {
+      hstart = 0;
+    } else {
+      hstart = (hstart-windowHeight)/strideHeight + 1;
+    }
+    int wstart = posw+padLeft;
+    if (wstart < windowWidth) {
+      wstart = 0;
+    } else {
+      wstart = (wstart-windowWidth)/strideWidth + 1;
+    }
+    const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
+    const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
+    const int channelStart = posn*dyCHW + posc*dyHW;
+    T dySum = static_cast<T>(0.0);
+    for (int hcur = hstart; hcur < hend; ++hcur) {
+      for (int wcur = wstart; wcur < wend; ++wcur) {
+        const int curIdx = hcur*dyWidth + wcur;
+        S maxIdx = index[channelStart+curIdx];
+        if (maxIdx == posIdx) {  // this dy cell recorded the current dx element as its maximum
+          dySum += dy[channelStart+curIdx];
+        }
+      }
+    }
+    dx[pos] = dySum;
+  }
+  return;
+}
+
+template <>  // half/int specialization: same traversal, but the sum is accumulated in float
+__global__ void MaxPoolWithArgmaxGrad(const half* x,
+                                      const half* dy,
+                                      const int* index,
+                                      const int n,
+                                      const int c,
+                                      const int xHeight,
+                                      const int xWidth,
+                                      const int dyHeight,
+                                      const int dyWidth,
+                                      const int windowHeight,
+                                      const int windowWidth,
+                                      const int strideHeight,
+                                      const int strideWidth,
+                                      const int padTop,
+                                      const int padLeft,
+                                      const int xNCHW,
+                                      const int xCHW,
+                                      const int xHW,
+                                      const int dyCHW,
+                                      const int dyHW,
+                                      half* dx) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
+       pos < (xNCHW);
+       pos += blockDim.x * gridDim.x) {
+    const int posn = pos / xCHW;
+    const int posc = pos / xHW % c;
+    const int posh = pos / xWidth % xHeight;  // row = (pos / W) % H; dividing by xHeight broke non-square inputs
+    const int posw = pos % xWidth;
+    const int posIdx = posh*xWidth + posw;
+    int hstart = posh+padTop;
+    if (hstart < windowHeight) {
+      hstart = 0;
+    } else {
+      hstart = (hstart-windowHeight)/strideHeight + 1;
+    }
+    int wstart = posw+padLeft;
+    if (wstart < windowWidth) {
+      wstart = 0;
+    } else {
+      wstart = (wstart-windowWidth)/strideWidth + 1;
+    }
+    const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
+    const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
+    const int channelStart = posn*dyCHW + posc*dyHW;
+    float dySum = 0.0f;  // float accumulator; converted to half once on store
+    for (int hcur = hstart; hcur < hend; ++hcur) {
+      for (int wcur = wstart; wcur < wend; ++wcur) {
+        const int curIdx = hcur*dyWidth + wcur;
+        int maxIdx = index[channelStart+curIdx];
+        if (maxIdx == posIdx) {
+          dySum += __half2float(dy[channelStart+curIdx]);
+        }
+      }
+    }
+    dx[pos] = __float2half(dySum);
+  }
+  return;
+}
+
+template <typename T, typename S>  // host launcher: derives the plane sizes and launches on cuda_stream
+void CalMaxPoolWithArgmaxGrad(const T* x,
+                              const T* dy,
+                              const S* index,
+                              const int n,
+                              const int c,
+                              const int xHeight,
+                              const int xWidth,
+                              const int dyHeight,
+                              const int dyWidth,
+                              const int windowHeight,
+                              const int windowWidth,
+                              const int strideHeight,
+                              const int strideWidth,
+                              const int padTop,
+                              const int padLeft,
+                              T* dx,
+                              cudaStream_t cuda_stream) {
+  const int xHW = xHeight*xWidth;
+  const int xCHW = c*xHW;
+  const int xNCHW = n*xCHW;
+  const int dyHW = dyHeight*dyWidth;
+  const int dyCHW = c*dyHW;
+  MaxPoolWithArgmaxGrad<<<GET_BLOCKS(xNCHW), GET_THREADS, 0, cuda_stream>>>(
+                         x,
+                         dy,
+                         index,
+                         n,
+                         c,
+                         xHeight,
+                         xWidth,
+                         dyHeight,
+                         dyWidth,
+                         windowHeight,
+                         windowWidth,
+                         strideHeight,
+                         strideWidth,
+                         padTop,
+                         padLeft,
+                         xNCHW,
+                         xCHW,
+                         xHW,
+                         dyCHW,
+                         dyHW,
+                         dx);
+  return;
+}
+
+template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
+                                                   const float* dy,
+                                                   const int* index,
+                                                   const int n,
+                                                   const int c,
+                                                   const int xHeight,
+                                                   const int xWidth,
+                                                   const int dyHeight,
+                                                   const int dyWidth,
+                                                   const int windowHeight,
+                                                   const int windowWidth,
+                                                   const int strideHeight,
+                                                   const int strideWidth,
+                                                   const int padTop,
+                                                   const int padLeft,
+                                                   float* dx,
+                                                   cudaStream_t cuda_stream);
+template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
+                                                  const half* dy,
+                                                  const int* index,
+                                                  const int n,
+                                                  const int c,
+                                                  const int xHeight,
+                                                  const int xWidth,
+                                                  const int dyHeight,
+                                                  const int dyWidth,
+                                                  const int windowHeight,
+                                                  const int windowWidth,
+                                                  const int strideHeight,
+                                                  const int strideWidth,
+                                                  const int padTop,
+                                                  const int padLeft,
+                                                  half* dx,
+                                                  cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh
new file mode 100644
index 00000000000..fe378acec6f
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
+template <typename T, typename S>  // T: element type of x/dy/dx; S: integer type of the argmax indices
+void CalMaxPoolWithArgmaxGrad(const T* x, const T* dy, const S* index, const int n, const int c, const int xHeight,
+                              const int xWidth, const int dyHeight, const int dyWidth, const int windowHeight,
+                              const int windowWidth, const int strideHeight, const int strideWidth, const int padTop,
+                              const int padLeft, T* dx, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cu
new file mode 100644
index 00000000000..7126a3feda4
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cu
@@ -0,0 +1,149 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <algorithm>
+#include "maxpool_with_argmax_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+#include "include/cuda_fp16.h"
+template <typename T, typename S>  // T: element type of input/output; S: integer type of the argmax output
+__global__ void MaxPoolWithArgmax(const T* input,  // read as [n][c][h][w]
+                                  const int n,
+                                  const int c,
+                                  const int h,
+                                  const int w,
+                                  const int windowHeight,
+                                  const int windowWidth,
+                                  const int strideHeight,
+                                  const int strideWidth,
+                                  const int padTop,
+                                  const int padLeft,
+                                  const int outputHeight,
+                                  const int outputWidth,
+                                  const int outputNCHW,  // total number of output elements
+                                  const int outputCHW,
+                                  const int outputHW,
+                                  T* output,  // window maximum, one per output element
+                                  S *index) {  // offset (row*w + col) of that maximum inside its input plane
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;  // one output element per iteration
+       pos < (outputNCHW);
+       pos += blockDim.x * gridDim.x) {
+    const int posn = pos / outputCHW;
+    const int posc = pos / outputHW % c;
+    const int posh = pos / outputWidth % outputHeight;  // row = (pos / W) % H; was divided by outputHeight
+    const int posw = pos % outputWidth;
+    int hstart = posh * strideHeight - padTop;
+    int wstart = posw * strideWidth - padLeft;
+    const int hend = min(hstart + windowHeight, h);  // ends use the unclamped starts, then starts are clamped
+    const int wend = min(wstart + windowWidth, w);
+    hstart = max(hstart, 0);
+    wstart = max(wstart, 0);
+    S inputStart = posn*c*h*w + posc*h*w;
+    S maxIdx = hstart*w + wstart;  // seed the search with the window's first in-bounds cell
+    T maxData = input[inputStart+maxIdx];
+    for (int hcur = hstart; hcur < hend; ++hcur) {
+      for (int wcur = wstart; wcur < wend; ++wcur) {
+        S inputIdx = hcur*w + wcur;
+        T inputData = input[inputStart+inputIdx];
+        if (inputData > maxData) {  // strict '>' keeps the first maximum on ties
+          maxIdx = inputIdx;
+          maxData = inputData;
+        }
+      }
+    }
+    output[pos] = maxData;
+    index[pos] = maxIdx;
+  }
+  return;
+}
+
+template <typename T, typename S>  // host launcher: derives the output sizes and launches on cuda_stream
+void CalMaxPoolWithArgmax(const T* input,
+                          const int n,
+                          const int c,
+                          const int h,
+                          const int w,
+                          const int windowHeight,
+                          const int windowWidth,
+                          const int strideHeight,
+                          const int strideWidth,
+                          const int padTop,
+                          const int padLeft,
+                          const int outputHeight,
+                          const int outputWidth,
+                          T* output,
+                          S *index,
+                          cudaStream_t cuda_stream) {
+  const int outputNCHW = n*c*outputHeight*outputWidth;
+  const int outputCHW = c*outputHeight*outputWidth;
+  const int outputHW = outputHeight*outputWidth;
+  MaxPoolWithArgmax<<<GET_BLOCKS(outputNCHW), GET_THREADS, 0, cuda_stream>>>(
+                     input,
+                     n,
+                     c,
+                     h,
+                     w,
+                     windowHeight,
+                     windowWidth,
+                     strideHeight,
+                     strideWidth,
+                     padTop,
+                     padLeft,
+                     outputHeight,
+                     outputWidth,
+                     outputNCHW,
+                     outputCHW,
+                     outputHW,
+                     output,
+                     index);
+  return;
+}
+
+template void CalMaxPoolWithArgmax<float, int>(const float* input,
+                                               const int n,
+                                               const int c,
+                                               const int h,
+                                               const int w,
+                                               const int windowHeight,
+                                               const int windowWidth,
+                                               const int strideHeight,
+                                               const int strideWidth,
+                                               const int padTop,
+                                               const int padLeft,
+                                               const int outputHeight,
+                                               const int outputWidth,
+                                               float* output,
+                                               int* index,
+                                               cudaStream_t cuda_stream);
+
+template void CalMaxPoolWithArgmax<half, int>(const half* input,
+                                              const int n,
+                                              const int c,
+                                              const int h,
+                                              const int w,
+                                              const int windowHeight,
+                                              const int windowWidth,
+                                              const int strideHeight,
+                                              const int strideWidth,
+                                              const int padTop,
+                                              const int padLeft,
+                                              const int outputHeight,
+                                              const int outputWidth,
+                                              half* output,
+                                              int* index,
+                                              cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh
new file mode 100644
index 00000000000..8b088067edc
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_
+template <typename T, typename S>  // T: element type of input/output; S: integer type of the argmax output
+void CalMaxPoolWithArgmax(const T* input, const int n, const int c, const int h, const int w, const int windowHeight,
+                          const int windowWidth, const int strideHeight, const int strideWidth, const int padTop,
+                          const int padLeft, const int outputHeight, const int outputWidth, T* output, S *index,
+                          cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.cc
new file mode 100644
index 00000000000..1866c834668
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.cc
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(  // float32 input -> float32 pooled output + int32 argmax output
+  MaxPoolWithArgmax,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
+  MaxPoolWithArgmaxGpuFwdKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(  // float16 input -> float16 pooled output + int32 argmax output
+  MaxPoolWithArgmax,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
+  MaxPoolWithArgmaxGpuFwdKernel, half, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h
new file mode 100644
index 00000000000..aef408c4033
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_gpu_kernel.h
@@ -0,0 +1,160 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_impl.cuh"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>  // T: element type of input/output; S: integer type of the argmax output
+class MaxPoolWithArgmaxGpuFwdKernel : public GpuKernel {
+ public:
+  MaxPoolWithArgmaxGpuFwdKernel()
+      : n_(0),
+        c_(0),
+        input_height_(0),
+        input_width_(0),
+        window_height_(0),
+        window_width_(0),
+        pad_height_(0),
+        pad_width_(0),
+        pad_top_(0),
+        pad_left_(0),
+        stride_height_(0),
+        stride_width_(0),
+        output_height_(0),
+        output_width_(0),
+        input_size_(0),
+        output_size_(0) {}
+  ~MaxPoolWithArgmaxGpuFwdKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) {  // workspace is unused; always returns true
+    T *input_addr = GetDeviceAddress<T>(inputs, 0);
+    T *output_addr = GetDeviceAddress<T>(outputs, 0);
+    S *index_addr = GetDeviceAddress<S>(outputs, 1);  // second output holds the argmax indices
+    CalMaxPoolWithArgmax(input_addr, n_, c_, input_height_, input_width_, window_height_, window_width_, stride_height_,
+                         stride_width_, pad_top_, pad_left_, output_height_, output_width_, output_addr, index_addr,
+                         reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) {  // returns false when the node does not have 1 input and 2 outputs
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 1) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but MaxPoolWithArgmax needs 1 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 2) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but MaxPoolWithArgmax needs 2 output.";
+      return false;
+    }
+    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    input_size_ = sizeof(T);  // byte size = sizeof(T) * product of all dimensions
+    for (auto x : input_shape) {
+      input_size_ *= x;
+    }
+    output_size_ = sizeof(T);
+    for (auto x : output_shape) {
+      output_size_ *= x;
+    }
+    n_ = SizeToInt(input_shape[0]);  // shapes are indexed as 4-D (n, c, h, w); rank is not checked here
+    c_ = SizeToInt(input_shape[1]);
+    input_height_ = SizeToInt(input_shape[2]);
+    input_width_ = SizeToInt(input_shape[3]);
+    output_height_ = SizeToInt(output_shape[2]);
+    output_width_ = SizeToInt(output_shape[3]);
+    auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
+    window_height_ = window[1];  // assumes ksize/strides are laid out (1, h, w, 1) - TODO confirm against the op
+    window_width_ = window[2];
+    auto stride = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
+    stride_height_ = stride[1];
+    stride_width_ = stride[2];
+    pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
+    pad_top_ = 0;
+    pad_left_ = 0;
+    if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
+      SetPad();  // only SAME padding yields non-zero pad_top_/pad_left_
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(input_size_);
+    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(output_size_ / sizeof(T) * sizeof(S));  // argmax: same element count, S-sized
+  }
+
+ private:
+  void SetPad() {  // total pad = max(0, (ceil(in / stride) - 1) * stride + window - in); top/left get half
+    pad_height_ = std::max<int>(
+      0, (((input_height_ / stride_height_) * stride_height_ == input_height_ ? (input_height_ / stride_height_)
+                                                                              : (input_height_ / stride_height_) + 1) -
+          1) *
+             stride_height_ +
+           window_height_ - input_height_);
+    pad_width_ = std::max<int>(
+      0, (((input_width_ / stride_width_) * stride_width_ == input_width_ ? (input_width_ / stride_width_)
+                                                                          : (input_width_ / stride_width_) + 1) -
+          1) *
+             stride_width_ +
+           window_width_ - input_width_);
+    pad_top_ = pad_height_ / 2;
+    pad_left_ = pad_width_ / 2;
+  }
+
+  std::string pad_mode_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int n_;
+  int c_;
+  int input_height_;
+  int input_width_;
+  int window_height_;
+  int window_width_;
+  int pad_height_;
+  int pad_width_;
+  int pad_top_;
+  int pad_left_;
+  int stride_height_;
+  int stride_width_;
+  int output_height_;
+  int output_width_;
+
+  size_t input_size_;
+  size_t output_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.cc
new file mode 100644
index 00000000000..954a5cfbf9b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.cc
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(MaxPoolGradWithArgmax,  // float32 variant: inputs (float32, float32, int32) -> float32
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      MaxPoolWithArgmaxGradGpuKernel, float, int)
+MS_REG_GPU_KERNEL_TWO(MaxPoolGradWithArgmax,  // float16 variant: inputs (float16, float16, int32) -> float16
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      MaxPoolWithArgmaxGradGpuKernel, half, int)
+}  // namespace kernel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h
new file mode 100644
index 00000000000..9d90e2d9f4b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/maxpool_with_argmax_grad_gpu_kernel.h
@@ -0,0 +1,168 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/maxpool_with_argmax_grad_impl.cuh"
+#include "backend/kernel_compiler/gpu/kernel_constants.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>  // T: element type of x/dy/dx; S: integer type of the argmax input
+class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
+ public:
+  MaxPoolWithArgmaxGradGpuKernel()
+      : n_(0),
+        c_(0),
+        x_height_(0),
+        x_width_(0),
+        dy_height_(0),
+        dy_width_(0),
+        x_size_(0),
+        dy_size_(0),
+        index_size_(0),
+        dx_size_(0) {}
+  ~MaxPoolWithArgmaxGradGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) {  // workspace is unused; always returns true
+    T *x_addr = GetDeviceAddress<T>(inputs, 0);
+    T *dy_addr = GetDeviceAddress<T>(inputs, 1);
+    S *index_addr = GetDeviceAddress<S>(inputs, 2);  // third input holds the argmax indices
+    T *dx_addr = GetDeviceAddress<T>(outputs, 0);
+    CalMaxPoolWithArgmaxGrad(x_addr, dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_,
+                             window_height_, window_width_, stride_height_, stride_width_, pad_top_, pad_left_, dx_addr,
+                             reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) {  // returns false when the node does not have 3 inputs and 1 output
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but MaxPoolGradWithArgmax needs 3 inputs.";
+      return false;
+    }
+    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+    if (output_num != 1) {
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but MaxPoolGradWithArgmax needs 1 output.";
+      return false;
+    }
+    auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    auto dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    auto index_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    auto dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    x_size_ = sizeof(T);  // byte size = element size * product of all dimensions
+    for (auto x : x_shape) {
+      x_size_ *= x;
+    }
+    dy_size_ = sizeof(T);
+    for (auto x : dy_shape) {
+      dy_size_ *= x;
+    }
+    index_size_ = sizeof(S);
+    for (auto x : index_shape) {
+      index_size_ *= x;
+    }
+    dx_size_ = sizeof(T);
+    for (auto x : dx_shape) {
+      dx_size_ *= x;
+    }
+    n_ = SizeToInt(x_shape[0]);  // shapes are indexed as 4-D (n, c, h, w); rank is not checked here
+    c_ = SizeToInt(x_shape[1]);
+    x_height_ = SizeToInt(x_shape[2]);
+    x_width_ = SizeToInt(x_shape[3]);
+    dy_height_ = SizeToInt(dy_shape[2]);
+    dy_width_ = SizeToInt(dy_shape[3]);
+    auto window = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
+    window_height_ = window[1];  // assumes ksize/strides are laid out (1, h, w, 1) - TODO confirm against the op
+    window_width_ = window[2];
+    auto stride = GetValue<std::vector<int>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
+    stride_height_ = stride[1];
+    stride_width_ = stride[2];
+    pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
+    pad_top_ = 0;
+    pad_left_ = 0;
+    if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
+      SetPad();  // only SAME padding yields non-zero pad_top_/pad_left_
+    }
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(x_size_);
+    input_size_list_.push_back(dy_size_);
+    input_size_list_.push_back(index_size_);
+    output_size_list_.push_back(dx_size_);
+  }
+
+ private:
+  void SetPad() {  // total pad = max(0, (ceil(in / stride) - 1) * stride + window - in); top/left get half
+    pad_height_ = std::max<int>(
+      0, (((x_height_ / stride_height_) * stride_height_ == x_height_ ? (x_height_ / stride_height_)
+                                                                      : (x_height_ / stride_height_) + 1) -
+          1) *
+             stride_height_ +
+           window_height_ - x_height_);
+    pad_width_ =
+      std::max<int>(0, (((x_width_ / stride_width_) * stride_width_ == x_width_ ? (x_width_ / stride_width_)
+                                                                                : (x_width_ / stride_width_) + 1) -
+                        1) *
+                           stride_width_ +
+                         window_width_ - x_width_);
+    pad_top_ = pad_height_ / 2;
+    pad_left_ = pad_width_ / 2;
+  }
+
+  std::string pad_mode_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+
+  int n_;
+  int c_;
+  int x_height_;
+  int x_width_;
+  int dy_height_;
+  int dy_width_;
+  int window_height_{0};  // zero-initialized here because the constructor's init list omits these eight members
+  int window_width_{0};
+  int pad_height_{0};
+  int pad_width_{0};
+  int pad_top_{0};
+  int pad_left_{0};
+  int stride_height_{0};
+  int stride_width_{0};
+
+  size_t x_size_;
+  size_t dy_size_;
+  size_t index_size_;
+  size_t dx_size_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_MAXPOOLWITHARGMAX_GRAD_GPU_KERNEL_H_
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index a5160d79c28..95289dd23a3 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1181,6 +1181,7 @@ class MaxPoolWithArgmax(_Pool):
     def __init__(self, ksize=1, strides=1, padding="valid"):
         super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding)
         self.is_tbe = context.get_context("device_target") == "Ascend"
+        self.is_gpu = context.get_context("device_target") == "GPU"
 
     def infer_shape(self, x_shape):
         out_shape = _Pool.infer_shape(self, x_shape)
@@ -1207,6 +1208,8 @@ class MaxPoolWithArgmax(_Pool):
         out_dtype = x_dtype
         validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name)
         argmax_dtype = mstype.uint16
+        if self.is_gpu:
+            argmax_dtype = mstype.int32
 
         return out_dtype, argmax_dtype
 
diff --git a/model_zoo/official/cv/googlenet/eval.py b/model_zoo/official/cv/googlenet/eval.py
index 045aba56df4..31646c97135 100644
--- 
a/model_zoo/official/cv/googlenet/eval.py +++ b/model_zoo/official/cv/googlenet/eval.py @@ -38,7 +38,7 @@ if __name__ == '__main__': if device_target == "Ascend": context.set_context(device_id=cfg.device_id) - net = GoogleNet(num_classes=cfg.num_classes, platform=device_target) + net = GoogleNet(num_classes=cfg.num_classes) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, weight_decay=cfg.weight_decay) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) diff --git a/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh b/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh index 3f2419f7eb8..507e556dea3 100644 --- a/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh +++ b/model_zoo/official/cv/googlenet/scripts/run_train_gpu.sh @@ -41,5 +41,10 @@ mkdir ../train cd ../train || exit export CUDA_VISIBLE_DEVICES="$2" -mpirun -n $1 --allow-run-as-root \ -python3 ${BASEPATH}/../train.py > train.log 2>&1 & +if [ $1 -gt 1 ] +then + mpirun -n $1 --allow-run-as-root \ + python3 ${BASEPATH}/../train.py > train.log 2>&1 & +else + python3 ${BASEPATH}/../train.py > train.log 2>&1 & +fi diff --git a/model_zoo/official/cv/googlenet/src/googlenet.py b/model_zoo/official/cv/googlenet/src/googlenet.py index 20bd96454b9..701b3aeb5a9 100644 --- a/model_zoo/official/cv/googlenet/src/googlenet.py +++ b/model_zoo/official/cv/googlenet/src/googlenet.py @@ -56,35 +56,24 @@ class Inception(nn.Cell): Inception Block """ - def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes, platform="Ascend"): + def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): super(Inception, self).__init__() - self.platform = platform self.b1 = Conv2dBlock(in_channels, n1x1, kernel_size=1) self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red, kernel_size=1), Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)]) self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, 
n5x5red, kernel_size=1), Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)]) - if self.platform == "Ascend": - self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=1, padding="same") - else: # GPU - self.maxpool = P.MaxPool(ksize=3, strides=1, padding="same") + self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=1, padding="same") self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1) self.concat = P.Concat(axis=1) def construct(self, x): - ''' - construct inception model - ''' branch1 = self.b1(x) branch2 = self.b2(x) branch3 = self.b3(x) - if self.platform == "Ascend": - cell, argmax = self.maxpool(x) - branch4 = self.b4(cell) - _ = argmax - else: # GPU - cell = self.maxpool(x) - branch4 = self.b4(cell) + cell, argmax = self.maxpool(x) + branch4 = self.b4(cell) + _ = argmax return self.concat((branch1, branch2, branch3, branch4)) @@ -93,82 +82,61 @@ class GoogleNet(nn.Cell): Googlenet architecture """ - def __init__(self, num_classes, platform="Ascend"): + def __init__(self, num_classes): super(GoogleNet, self).__init__() self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0) - self.platform = platform + self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") + self.conv2 = Conv2dBlock(64, 64, kernel_size=1) self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0) - self.block3a = Inception(192, 64, 96, 128, 16, 32, 32, platform=self.platform) - self.block3b = Inception(256, 128, 128, 192, 32, 96, 64, platform=self.platform) - self.block4a = Inception(480, 192, 96, 208, 16, 48, 64, platform=self.platform) - self.block4b = Inception(512, 160, 112, 224, 24, 64, 64, platform=self.platform) - self.block4c = Inception(512, 128, 128, 256, 24, 64, 64, platform=self.platform) - self.block4d = Inception(512, 112, 144, 288, 32, 64, 64, platform=self.platform) - self.block4e = Inception(528, 256, 160, 320, 32, 128, 128, platform=self.platform) - if self.platform == "Ascend": - self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, 
padding="same") - self.maxpool2 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") - self.maxpool3 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") - self.maxpool4 = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="same") - else: # GPU - self.maxpool1 = P.MaxPool(ksize=3, strides=2, padding="same") - self.maxpool2 = P.MaxPool(ksize=3, strides=2, padding="same") - self.maxpool3 = P.MaxPool(ksize=3, strides=2, padding="same") - self.maxpool4 = P.MaxPool(ksize=2, strides=2, padding="same") - self.block5a = Inception(832, 256, 160, 320, 32, 128, 128, platform=self.platform) - self.block5b = Inception(832, 384, 192, 384, 48, 128, 128, platform=self.platform) + self.maxpool2 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") + + self.block3a = Inception(192, 64, 96, 128, 16, 32, 32) + self.block3b = Inception(256, 128, 128, 192, 32, 96, 64) + self.maxpool3 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same") + + self.block4a = Inception(480, 192, 96, 208, 16, 48, 64) + self.block4b = Inception(512, 160, 112, 224, 24, 64, 64) + self.block4c = Inception(512, 128, 128, 256, 24, 64, 64) + self.block4d = Inception(512, 112, 144, 288, 32, 64, 64) + self.block4e = Inception(528, 256, 160, 320, 32, 128, 128) + self.maxpool4 = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="same") + + self.block5a = Inception(832, 256, 160, 320, 32, 128, 128) + self.block5b = Inception(832, 384, 192, 384, 48, 128, 128) + self.mean = P.ReduceMean(keep_dims=True) self.dropout = nn.Dropout(keep_prob=0.8) self.flatten = nn.Flatten() self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(), bias_init=weight_variable()) + def construct(self, x): - ''' - construct googlenet model - ''' x = self.conv1(x) - if self.platform == "Ascend": - x, argmax = self.maxpool1(x) - else: # GPU - x = self.maxpool1(x) + x, argmax = self.maxpool1(x) x = self.conv2(x) x = self.conv3(x) - if self.platform == "Ascend": - x, argmax = self.maxpool2(x) - else: # GPU - x = 
self.maxpool2(x) + x, argmax = self.maxpool2(x) x = self.block3a(x) x = self.block3b(x) - if self.platform == "Ascend": - x, argmax = self.maxpool3(x) - else: # GPU - x = self.maxpool3(x) + x, argmax = self.maxpool3(x) x = self.block4a(x) x = self.block4b(x) x = self.block4c(x) x = self.block4d(x) x = self.block4e(x) - if self.platform == "Ascend": - x, argmax = self.maxpool4(x) - x = self.block5a(x) - x = self.block5b(x) + x, argmax = self.maxpool4(x) - x = self.mean(x, (2, 3)) - x = self.flatten(x) - x = self.classifier(x) - _ = argmax - else: # GPU - x = self.maxpool4(x) - x = self.block5a(x) - x = self.block5b(x) + x = self.block5a(x) + x = self.block5b(x) - x = self.mean(x, (2, 3)) - x = self.flatten(x) - x = self.classifier(x) + x = self.mean(x, (2, 3)) + x = self.flatten(x) + x = self.classifier(x) + _ = argmax return x diff --git a/model_zoo/official/cv/googlenet/train.py b/model_zoo/official/cv/googlenet/train.py index b7668a017f9..50e3ec7bc21 100644 --- a/model_zoo/official/cv/googlenet/train.py +++ b/model_zoo/official/cv/googlenet/train.py @@ -93,7 +93,7 @@ if __name__ == '__main__': dataset = create_dataset(cfg.data_path, 1) batch_num = dataset.get_dataset_size() - net = GoogleNet(num_classes=cfg.num_classes, platform=device_target) + net = GoogleNet(num_classes=cfg.num_classes) # Continue training if set pre_trained to be True if cfg.pre_trained: param_dict = load_checkpoint(cfg.checkpoint_path) diff --git a/tests/st/ops/gpu/test_maxpool_with_argmax_gpu_op.py b/tests/st/ops/gpu/test_maxpool_with_argmax_gpu_op.py new file mode 100644 index 00000000000..a2ff4017382 --- /dev/null +++ b/tests/st/ops/gpu/test_maxpool_with_argmax_gpu_op.py @@ -0,0 +1,147 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+class Net_Pool(nn.Cell):  # Cell wrapper: MaxPoolWithArgmax with a 2x2 window, stride 2, VALID padding
+    def __init__(self):
+        super(Net_Pool, self).__init__()
+        self.maxpool_fun = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="VALID")
+
+    def construct(self, x):
+        return self.maxpool_fun(x)  # the tests unpack this as (pooled output, argmax)
+
+
+class Net_Pool2(nn.Cell):  # Cell wrapper: MaxPoolWithArgmax with a 3x3 window, stride 2, SAME padding
+    def __init__(self):
+        super(Net_Pool2, self).__init__()
+        self.maxpool_fun = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
+
+    def construct(self, x):
+        return self.maxpool_fun(x)  # the tests unpack this as (pooled output, argmax)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_maxpool_with_argmax_2d():  # float32: checks values and argmax for both nets, PyNative then Graph mode
+    x = Tensor(np.array([[[
+        [0, 1, 2, 3, -4, -5],
+        [6, 7, 8, 9, -10, -11],
+        [12, 13, 14, -15, -16, -17],
+        [18, 19, 20, 21, 22, 23],
+        [24, 25, 26, 27, 28, 29],
+        [30, 31, 32, 33, 34, 35]
+    ]]]).astype(np.float32))  # shape (1, 1, 6, 6); every non-negative entry equals its flat offset row * 6 + col
+    expect_result = (np.array([[[  # expected maxima for Net_Pool (VALID)
+        [7, 9, -4],
+        [19, 21, 23],
+        [31, 33, 35]
+    ]]]))
+    expect_result2 = (np.array([[[  # expected maxima for Net_Pool2 (SAME)
+        [14, 14, -4],
+        [26, 28, 29],
+        [32, 34, 35]
+    ]]]))
+    expect_index_result = (np.array([[[  # expected argmax for Net_Pool, as flat offsets into the 6x6 plane
+        [7, 9, 4],  # the maximum -4 of the last window sits at flat offset 4
+        [19, 21, 23],
+        [31, 33, 35]
+    ]]]))
+    expect_index_result2 = (np.array([[[  # expected argmax for Net_Pool2, as flat offsets into the 6x6 plane
+        [14, 14, 4],  # the first two windows overlap on column 2, so both pick the element at offset 14
+        [26, 28, 29],
+        [32, 34, 35]
+    ]]]))
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    maxpool2d = Net_Pool()
+    maxpool2d2 = Net_Pool2()
+    output2, index2 = maxpool2d2(x)
+    output, index = maxpool2d(x)
+    assert (output.asnumpy() == expect_result).all()
+    assert (output2.asnumpy() == expect_result2).all()
+    assert (index.asnumpy() == expect_index_result).all()
+    assert (index2.asnumpy() == expect_index_result2).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")  # repeat the same checks in graph mode
+    maxpool2d = Net_Pool()
+    maxpool2d2 = Net_Pool2()
+    output2, index2 = maxpool2d2(x)
+    output, index = maxpool2d(x)
+    assert (output.asnumpy() == expect_result).all()
+    assert (output2.asnumpy() == expect_result2).all()
+    assert (index.asnumpy() == expect_index_result).all()
+    assert (index2.asnumpy() == expect_index_result2).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_maxpool_with_argmax_2d_fp16():  # same data and expectations as the float32 test, with a float16 input
+    x = Tensor(np.array([[[
+        [0, 1, 2, 3, -4, -5],
+        [6, 7, 8, 9, -10, -11],
+        [12, 13, 14, -15, -16, -17],
+        [18, 19, 20, 21, 22, 23],
+        [24, 25, 26, 27, 28, 29],
+        [30, 31, 32, 33, 34, 35]
+    ]]]).astype(np.float16))  # shape (1, 1, 6, 6); every non-negative entry equals its flat offset row * 6 + col
+    expect_result = (np.array([[[  # expected maxima for Net_Pool (VALID)
+        [7, 9, -4],
+        [19, 21, 23],
+        [31, 33, 35]
+    ]]]))
+    expect_result2 = (np.array([[[  # expected maxima for Net_Pool2 (SAME)
+        [14, 14, -4],
+        [26, 28, 29],
+        [32, 34, 35]
+    ]]]))
+    expect_index_result = (np.array([[[  # expected argmax for Net_Pool, as flat offsets into the 6x6 plane
+        [7, 9, 4],
+        [19, 21, 23],
+        [31, 33, 35]
+    ]]]))
+    expect_index_result2 = (np.array([[[  # expected argmax for Net_Pool2, as flat offsets into the 6x6 plane
+        [14, 14, 4],
+        [26, 28, 29],
+        [32, 34, 35]
+    ]]]))
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    maxpool2d = Net_Pool()
+    maxpool2d2 = Net_Pool2()
+    output2, index2 = maxpool2d2(x)
+    output, index = maxpool2d(x)
+    assert (output.asnumpy() == expect_result).all()
+    assert (output2.asnumpy() == expect_result2).all()
+    assert (index.asnumpy() == expect_index_result).all()
+    assert (index2.asnumpy() == expect_index_result2).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")  # repeat the same checks in graph mode
+    maxpool2d = Net_Pool()
+    maxpool2d2 = Net_Pool2()
+    output2, index2 = maxpool2d2(x)
+    output, index = maxpool2d(x)
+    assert (output.asnumpy() == expect_result).all()
+    assert (output2.asnumpy() == expect_result2).all()
+    assert (index.asnumpy() == expect_index_result).all()
+    assert (index2.asnumpy() == expect_index_result2).all()
+    
\ No newline at end of file
diff --git a/tests/st/ops/gpu/test_maxpool_with_argmax_grad_gpu_op.py b/tests/st/ops/gpu/test_maxpool_with_argmax_grad_gpu_op.py
new file mode 100644
index 00000000000..a9ea790ffa4
--- /dev/null
+++ b/tests/st/ops/gpu/test_maxpool_with_argmax_grad_gpu_op.py
@@ -0,0 +1,115 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class Net_Pool_Grad(nn.Cell):  # Cell wrapper: MaxPoolGradWithArgmax with a 2x2 window, stride 2, VALID padding
+    def __init__(self):
+        super(Net_Pool_Grad, self).__init__()
+        self.maxpool_grad_fun = G.MaxPoolGradWithArgmax(padding="VALID", ksize=2, strides=2)
+
+    def construct(self, x, dy, index):
+        return self.maxpool_grad_fun(x, dy, index)  # forwards (x, dy, index) to the op unchanged
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_maxpool2d_grad():  # float32: each dy entry must land at its index position, zeros elsewhere
+    x = Tensor(np.array([[[
+        [0, 1, 2, 3, 4, 5],
+        [6, 7, 8, 9, 10, 11],
+        [12, 13, 14, 15, 16, 17],
+        [18, 19, 20, 21, 22, 23],
+        [24, 25, 26, 27, 28, 29],
+        [30, 31, 32, 33, 34, 35]
+    ]]]).astype(np.float32))  # shape (1, 1, 6, 6); each entry equals its flat offset row * 6 + col
+    dy = Tensor(np.array([[[
+        [0.7, 0.9, 0.11],
+        [0.19, 0.21, 0.23],
+        [0.31, 0.33, 0.35]
+    ]]]).astype(np.float32))  # one incoming gradient per 2x2 window
+    index = Tensor(np.array([[[
+        [7, 9, 11],
+        [19, 21, 23],
+        [31, 33, 35]
+    ]]]).astype(np.int32))  # flat offsets into the 6x6 plane, e.g. 7 -> row 1, col 1
+    expect_result = (np.array([[[
+        [0., 0., 0., 0., 0., 0.],
+        [0., 0.7, 0., 0.9, 0., 0.11],
+        [0., 0., 0., 0., 0., 0.],
+        [0., 0.19, 0., 0.21, 0., 0.23],
+        [0., 0., 0., 0., 0., 0.],
+        [0., 0.31, 0., 0.33, 0., 0.35]
+    ]]]))  # dy scattered to the offsets listed in index
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    maxpool2d_grad = Net_Pool_Grad()
+    output = maxpool2d_grad(x, dy, index)
+    assert np.allclose(expect_result, output.asnumpy())
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")  # repeat the same check in graph mode
+    maxpool2d_grad = Net_Pool_Grad()
+    output = maxpool2d_grad(x, dy, index)
+    assert np.allclose(expect_result, output.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_maxpool2d_grad_fp16():  # same scenario as test_maxpool2d_grad with float16 x and dy
+    x = Tensor(np.array([[[
+        [0, 1, 2, 3, 4, 5],
+        [6, 7, 8, 9, 10, 11],
+        [12, 13, 14, 15, 16, 17],
+        [18, 19, 20, 21, 22, 23],
+        [24, 25, 26, 27, 28, 29],
+        [30, 31, 32, 33, 34, 35]
+    ]]]).astype(np.float16))  # shape (1, 1, 6, 6); each entry equals its flat offset row * 6 + col
+    dy = Tensor(np.array([[[
+        [0.7, 0.9, 0.11],
+        [0.19, 0.21, 0.23],
+        [0.31, 0.33, 0.35]
+    ]]]).astype(np.float16))  # one incoming gradient per 2x2 window
+    index = Tensor(np.array([[[
+        [7, 9, 11],
+        [19, 21, 23],
+        [31, 33, 35]
+    ]]]).astype(np.int32))  # index stays int32 even though x and dy are float16
+    expect_result = np.array([[[
+        [0., 0., 0., 0., 0., 0.],
+        [0., 0.7, 0., 0.9, 0., 0.11],
+        [0., 0., 0., 0., 0., 0.],
+        [0., 0.19, 0., 0.21, 0., 0.23],
+        [0., 0., 0., 0., 0., 0.],
+        [0., 0.31, 0., 0.33, 0., 0.35]
+    ]]]).astype(np.float16)  # cast to float16, the same dtype as the x and dy inputs
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    maxpool2d_grad = Net_Pool_Grad()
+    output = maxpool2d_grad(x, dy, index)
+    assert np.allclose(expect_result, output.asnumpy())
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")  # repeat the same check in graph mode
+    maxpool2d_grad = Net_Pool_Grad()
+    output = maxpool2d_grad(x, dy, index)
+    assert np.allclose(expect_result, output.asnumpy())