From de28cdc71cf1a1d068b11a0353023e38fe37bb78 Mon Sep 17 00:00:00 2001 From: TFbunny Date: Thu, 16 Jul 2020 18:54:42 -0400 Subject: [PATCH] add GPU support to ResizeNearestNeighbor --- .../resize_nearest_neighbor_gpu_kernel.cc | 31 +++++ .../resize_nearest_neighbor_gpu_kernel.h | 111 ++++++++++++++++++ ...resize_nearest_neighbor_grad_gpu_kernel.cc | 31 +++++ .../resize_nearest_neighbor_grad_gpu_kernel.h | 111 ++++++++++++++++++ .../resize_nearest_neighbor_grad_impl.cu | 81 +++++++++++++ .../resize_nearest_neighbor_grad_impl.cuh | 28 +++++ .../cuda_impl/resize_nearest_neighbor_impl.cu | 81 +++++++++++++ .../resize_nearest_neighbor_impl.cuh | 28 +++++ mindspore/ops/operations/array_ops.py | 6 +- .../test_resize_nearest_neighbor_grad_op.py | 75 ++++++++++++ .../gpu/test_resize_nearest_neighbor_op.py | 71 +++++++++++ 11 files changed, 651 insertions(+), 3 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh create mode 100644 tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py create mode 100644 tests/st/ops/gpu/test_resize_nearest_neighbor_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.cc new file mode 100644 index 00000000000..3e38ca599e5 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeNearestNeighborGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ResizeNearestNeighborGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ResizeNearestNeighborGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h new file mode 100644 index 00000000000..4650b033e54 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h @@ -0,0 +1,111 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ResizeNearestNeighborGpuKernel : public GpuKernel { + public: + ResizeNearestNeighborGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {} + ~ResizeNearestNeighborGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int size = SizeToInt(output_size_ / sizeof(T)); + float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_); + float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_); + CalResizeNearestNeighbor(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output, + output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], align_corners_, + h_scale, w_scale, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + shape_size_ = input_shape.size(); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " + << RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs."; + return false; + } + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + output_size_ *= output_shape[i]; + output_shape_.push_back(output_shape[i]); + } + output_size_ *= sizeof(T); + align_corners_ = GetAttr(kernel_node, "align_corners"); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + float Scaling(const int in_size, const int out_size, bool align_corners) { + return (align_corners && out_size > 1) ? (in_size - 1) / static_cast(out_size - 1) + : in_size / static_cast(out_size); + } + + bool align_corners_; + size_t shape_size_; + std::vector input_shape_; + std::vector output_shape_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.cc new file mode 100644 index 00000000000..14a886a4b8b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeNearestNeighborGradGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ResizeNearestNeighborGradGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ResizeNearestNeighborGradGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h new file mode 100644 index 00000000000..e32ee44894b --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h @@ -0,0 +1,111 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ + +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class ResizeNearestNeighborGradGpuKernel : public GpuKernel { + public: + ResizeNearestNeighborGradGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {} + ~ResizeNearestNeighborGradGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + int size = SizeToInt(output_size_ / sizeof(T)); + float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_); + float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_); + CalResizeNearestNeighborGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], + output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], + align_corners_, h_scale, w_scale, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + shape_size_ = input_shape.size(); + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) { + MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " + << RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs."; + return false; + } + input_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + input_size_ *= input_shape[i]; + input_shape_.push_back(input_shape[i]); + } + input_size_ *= sizeof(T); + output_size_ = 1; + for (size_t i = 0; i < shape_size_; i++) { + output_size_ *= output_shape[i]; + output_shape_.push_back(output_shape[i]); + } + output_size_ *= sizeof(T); + align_corners_ = GetAttr(kernel_node, "align_corners"); + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + float Scaling(const int in_size, const int out_size, bool align_corners) { + return (align_corners && out_size > 1) ? (in_size - 1) / static_cast(out_size - 1) + : in_size / static_cast(out_size); + } + + bool align_corners_; + size_t shape_size_; + std::vector input_shape_; + std::vector output_shape_; + size_t input_size_; + size_t output_size_; + size_t workspace_size_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu new file mode 100644 index 00000000000..546960b1393 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cu @@ -0,0 +1,81 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" + +template +__global__ void ResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, + const int s4, T *output, const int d1, const int d2, const int d3, + const int d4, bool align_corners, float h_scale, float w_scale) { + // initialization + // HalfPixelCenters false + int input_pos; + int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION]; + int in_height = s3; + int in_width = s4; + // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + + // pos_array[1] * output_shape[2] * output_shape[3] + + // pos_array[2] * output_shape[3] + + // pos_array[3] + T h_scale_ = static_cast(h_scale); + T w_scale_ = static_cast(w_scale); + T out_h_; + T out_w_; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + pos_array[0] = pos / (d2 * d3 * d4) % d1; + pos_array[1] = pos / (d3 * d4) % d2; + pos_array[2] = pos / (d4) % d3; + pos_array[3] = pos % d4; + out_h_ = static_cast(pos_array[2]); + out_w_ = static_cast(pos_array[3]); + const int in_y = + min((align_corners) ? static_cast(roundf(out_h_ * h_scale_)) : static_cast(floorf(out_h_ * h_scale_)), + in_height - 1); + const int in_x = + min((align_corners) ? static_cast(roundf(out_w_ * w_scale_)) : static_cast(floorf(out_w_ * w_scale_)), + in_width - 1); + // pos_array[0] N, pos_array[1] C, in_y H, in_x W + input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; + output[pos] = input[input_pos]; + } + return; +} + +template +void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, + const int s4, T *output, const int d1, const int d2, const int d3, const int d4, + bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) { + ResizeNearestNeighborGrad<<>>( + size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale); + return; +} + +template void CalResizeNearestNeighborGrad(const int size, const float *input, const int s1, const int s2, + const int s3, const int s4, float *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighborGrad(const int size, const half *input, const int s1, const int s2, + const int s3, const int s4, half *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighborGrad(const int size, const int *input, const int s1, const int s2, + const int s3, const int s4, int *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh new file mode 100644 index 00000000000..d1acdaab514 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ +#include +#include "runtime/device/gpu/cuda_common.h" +#define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4 + +template +void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, + const int s4, T *output, const int d1, const int d2, const int d3, const int d4, + bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu new file mode 100644 index 00000000000..4280e33fd3d --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cu @@ -0,0 +1,81 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh" + +template +__global__ void ResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, + const int s4, T *output, const int d1, const int d2, const int d3, const int d4, + bool align_corners, float h_scale, float w_scale) { + // initialization + // HalfPixelCenters false + int input_pos; + int pos_array[RESIZENEARESTNEIGHBOR_DIMENSION]; + int in_height = s3; + int in_width = s4; + // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + + // pos_array[1] * output_shape[2] * output_shape[3] + + // pos_array[2] * output_shape[3] + + // pos_array[3] + T h_scale_ = static_cast(h_scale); + T w_scale_ = static_cast(w_scale); + T out_h_; + T out_w_; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + pos_array[0] = pos / (d2 * d3 * d4) % d1; + pos_array[1] = pos / (d3 * d4) % d2; + pos_array[2] = pos / (d4) % d3; + pos_array[3] = pos % d4; + out_h_ = static_cast(pos_array[2]); + out_w_ = static_cast(pos_array[3]); + const int in_y = + min((align_corners) ? static_cast(roundf(out_h_ * h_scale_)) : static_cast(floorf(out_h_ * h_scale_)), + in_height - 1); + const int in_x = + min((align_corners) ? static_cast(roundf(out_w_ * w_scale_)) : static_cast(floorf(out_w_ * w_scale_)), + in_width - 1); + // pos_array[0] N, pos_array[1] C, in_y H, in_x W + input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; + output[pos] = input[input_pos]; + } + return; +} + +template +void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4, + T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, + float h_scale, float w_scale, cudaStream_t cuda_stream) { + ResizeNearestNeighbor<<>>(size, input, s1, s2, s3, s4, output, d1, d2, + d3, d4, align_corners, h_scale, w_scale); + return; +} + +template void CalResizeNearestNeighbor(const int size, const float *input, const int s1, const int s2, + const int s3, const int s4, float *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighbor(const int size, const half *input, const int s1, const int s2, + const int s3, const int s4, half *output, const int d1, const int d2, + const int d3, const int d4, bool align_corners, float h_scale, + float w_scale, cudaStream_t cuda_stream); +template void CalResizeNearestNeighbor(const int size, const int *input, const int s1, const int s2, const int s3, + const int s4, int *output, const int d1, const int d2, const int d3, + const int d4, bool align_corners, float h_scale, float w_scale, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh new file mode 100644 index 00000000000..a9eafe36cec --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ +#include +#include "runtime/device/gpu/cuda_common.h" +#define RESIZENEARESTNEIGHBOR_DIMENSION 4 + +template +void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4, + T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, + float h_scale, float w_scale, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 1e28a56db1f..99ae128328f 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -2256,10 +2256,10 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): and output tensors are aligned. Default: False. Inputs: - - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`. + - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`. Outputs: - Tensor, the shape of the output tensor is :math:`(N, NEW\_C, NEW\_H, W)`. + Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`. Examples: >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) @@ -2278,7 +2278,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) def infer_shape(self, x): - validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name) + validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name) return tuple(x)[:-2] + tuple(self.size) def infer_dtype(self, x): diff --git a/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py b/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py new file mode 100644 index 00000000000..70de771e7db --- /dev/null +++ b/tests/st/ops/gpu/test_resize_nearest_neighbor_grad_op.py @@ -0,0 +1,75 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + +class ResizeNearestNeighborGradAlignCornerT(nn.Cell): + def __init__(self): + super(ResizeNearestNeighborGradAlignCornerT, self).__init__() + self.ResizeNearestNeighborGradAlignCornerT = G.ResizeNearestNeighborGrad(align_corners=True) + + def construct(self, dy, size): + return self.ResizeNearestNeighborGradAlignCornerT(dy, size) + +class ResizeNearestNeighborGradAlignCornerF(nn.Cell): + def __init__(self): + super(ResizeNearestNeighborGradAlignCornerF, self).__init__() + self.ResizeNearestNeighborGradAlignCornerF = G.ResizeNearestNeighborGrad(align_corners=False) + + def construct(self, dy, size): + return self.ResizeNearestNeighborGradAlignCornerF(dy, size) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborGradAlignCornerT(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + size = (2, 2) + expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) + rnn = ResizeNearestNeighborGradAlignCornerT() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + size = (2, 2) + expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerT() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborGradAlignCornerF(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + size = (2, 2) + expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) + rnn = ResizeNearestNeighborGradAlignCornerF() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + size = (2, 2) + expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerF() + output = rnn(Tensor(dy), size) + assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py b/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py new file mode 100644 index 00000000000..79101af06ea --- /dev/null +++ b/tests/st/ops/gpu/test_resize_nearest_neighbor_op.py @@ -0,0 +1,71 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + + +class ResizeNearestNeighborAlignCornerT(nn.Cell): + def __init__(self, size): + super(ResizeNearestNeighborAlignCornerT, self).__init__() + self.ResizeNearestNeighborAlignCornerT = P.ResizeNearestNeighbor(size, align_corners=True) + + def construct(self, x): + return self.ResizeNearestNeighborAlignCornerT(x) + +class ResizeNearestNeighborAlignCornerF(nn.Cell): + def __init__(self, size): + super(ResizeNearestNeighborAlignCornerF, self).__init__() + self.ResizeNearestNeighborAlignCornerF = P.ResizeNearestNeighbor(size, align_corners=False) + + def construct(self, x): + return self.ResizeNearestNeighborAlignCornerF(x) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborAlignCornerT(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + rnn = ResizeNearestNeighborAlignCornerT((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + rnn = ResizeNearestNeighborAlignCornerT((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ResizeNearestNeighborAlignCornerF(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + rnn = ResizeNearestNeighborAlignCornerF((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect) + input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)) + expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + rnn = ResizeNearestNeighborAlignCornerF((4, 4)) + output = rnn(input_tensor) + assert np.all(output.asnumpy() == expect)