From 7829bab8117ccdde9f0e17d1b42064de4e85ed63 Mon Sep 17 00:00:00 2001
From: linqingke
Date: Thu, 23 Jul 2020 17:02:46 +0800
Subject: [PATCH] add iou ops.

---
 .../gpu/arrays/gathernd_gpu_kernel.h          |  25 ++--
 .../gpu/arrays/scatter_nd_gpu_kernel.h        |  26 ++--
 .../gpu/cuda_impl/check_valid_impl.cu         |  45 +++++++
 .../gpu/cuda_impl/check_valid_impl.cuh        |  25 ++++
 .../kernel_compiler/gpu/cuda_impl/iou_impl.cu |  72 +++++++++
 .../gpu/cuda_impl/iou_impl.cuh                |  29 +++++
 .../gpu/other/check_valid_gpu_kernel.cc       |  26 ++++
 .../gpu/other/check_valid_gpu_kernel.h        | 106 +++++++++++++++
 .../gpu/other/iou_gpu_kernel.cc               |  25 ++++
 .../gpu/other/iou_gpu_kernel.h                | 122 ++++++++++++++++++
 mindspore/ops/operations/other_ops.py         |   2 -
 tests/st/ops/gpu/test_check_valid_op.py       |  54 ++++++++
 tests/st/ops/gpu/test_iou_op.py               |  57 ++++++++
 13 files changed, 596 insertions(+), 18 deletions(-)
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc
 create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.h
 create mode 100644 tests/st/ops/gpu/test_check_valid_op.py
 create mode 100644 tests/st/ops/gpu/test_iou_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h
index af1efb84f6d..d4e8d3d8ad2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gathernd_gpu_kernel.h
@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T, typename S>
 class GatherNdGpuFwdKernel : public GpuKernel {
  public:
-  GatherNdGpuFwdKernel() : dev_batch_strides_(nullptr), dev_batch_indices_(nullptr) {}
+  GatherNdGpuFwdKernel() : dev_batch_strides_(nullptr), dev_batch_indices_(nullptr), memcpy_flag_(false) {}
   ~GatherNdGpuFwdKernel() {
     if (dev_batch_strides_ != nullptr) {
       device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(dev_batch_strides_));
@@ -48,12 +48,25 @@ class GatherNdGpuFwdKernel : public GpuKernel {
     S *indices_addr = GetDeviceAddress<S>(inputs, 1);
     T *output_addr = GetDeviceAddress<T>(outputs, 0);
 
+    if (!memcpy_flag_) {
+      const size_t strides_len = sizeof(S) * batch_strides_.size();
+      const size_t indices_len = sizeof(S) * batch_indices_.size();
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_batch_strides_, &batch_strides_[0], strides_len,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in GatherNdGpuFwdKernel::Launch.");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(dev_batch_indices_, &batch_indices_[0], indices_len,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in GatherNdGpuFwdKernel::Launch.");
+      memcpy_flag_ = true;
+    }
+
     GatherNd(input_addr, indices_addr, output_addr, dims_[0], dims_[1], dims_[2], dev_batch_strides_,
              dev_batch_indices_, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
     InitResource();
+    memcpy_flag_ = false;
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 2) {
       MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherNdGpuFwdKernel needs 2.";
@@ -77,25 +90,20 @@ class GatherNdGpuFwdKernel : public GpuKernel {
       batch_indices_[i - 1] = batch_indices_[i] * input_shapes_[i];
     }
 
-    size_t strides_len = sizeof(S) * batch_strides_.size();
+    const size_t strides_len = sizeof(S) * batch_strides_.size();
     void *dev_batch_strides_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(strides_len);
     if (dev_batch_strides_work == nullptr) {
       MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_strides_work, size: " << strides_len;
     }
     dev_batch_strides_ = static_cast<S *>(dev_batch_strides_work);
 
-    size_t indices_len = sizeof(S) * batch_indices_.size();
+    const size_t indices_len = sizeof(S) * batch_indices_.size();
     void *dev_batch_indices_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
     if (dev_batch_indices_work == nullptr) {
       MS_LOG(EXCEPTION) << "Failed to alloc dev_batch_indices_work, size: " << indices_len;
     }
     dev_batch_indices_ = static_cast<S *>(dev_batch_indices_work);
 
-    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_strides_, &batch_strides_[0], strides_len, cudaMemcpyHostToDevice),
-                               "cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");
-    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(dev_batch_indices_, &batch_indices_[0], indices_len, cudaMemcpyHostToDevice),
-                               "cudaMemcpy failed in GatherNdGpuFwdKernel::Init.");
-
     InitSizeLists();
     return true;
   }
@@ -155,6 +163,7 @@ class GatherNdGpuFwdKernel : public GpuKernel {
 
   S *dev_batch_strides_;
   S *dev_batch_indices_;
+  bool memcpy_flag_;
 };
 } // namespace kernel
 } // namespace mindspore
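Both index-buffer kernels get the same treatment in this patch: the blocking cudaMemcpy at Init time is replaced by a one-time cudaMemcpyAsync on the launch stream, guarded by memcpy_flag_. A minimal Python sketch of that control flow (the classes below are stand-ins, not MindSpore or CUDA APIs):

class FakeStream:
    def memcpy_async(self, dst, src):
        dst[:] = src  # stands in for cudaMemcpyAsync on this stream


class LazyCopyKernel:
    def __init__(self, host_data):
        self.host_data = list(host_data)
        self.dev_buf = [None] * len(self.host_data)  # Init(): allocate only
        self.memcpy_flag = False

    def launch(self, stream):
        if not self.memcpy_flag:  # first Launch(): one-time async copy
            stream.memcpy_async(self.dev_buf, self.host_data)
            self.memcpy_flag = True
        return self.dev_buf  # the real kernel body would run here


kernel = LazyCopyKernel([4, 2, 1])
print(kernel.launch(FakeStream()))  # [4, 2, 1]; later launches skip the copy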
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
index 29c229fbacb..51f4323b1df 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
@@ -35,7 +35,8 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
         indices_stride_(nullptr),
         work_shape_(nullptr),
         indices_dim_0_(0),
-        indices_dim_1_(0) {}
+        indices_dim_1_(0),
+        memcpy_flag_(false) {}
   ~ScatterNdGpuFwdKernel() {
     if (indices_stride_ != nullptr) {
       device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(static_cast<void *>(indices_stride_));
@@ -56,12 +57,25 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
     T *update = GetDeviceAddress<T>(inputs, 1);
     T *output = GetDeviceAddress<T>(outputs, 0);
 
+    if (!memcpy_flag_) {
+      const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
+      const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(indices_stride_, &vec_indices_stride_[0], indices_len,
+                                                 cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in ScatterNdGpuFwdKernel::Launch.");
+      CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice,
+                                                 reinterpret_cast<cudaStream_t>(stream_ptr)),
+                                 "cudaMemcpyAsync failed in ScatterNdGpuFwdKernel::Launch.");
+      memcpy_flag_ = true;
+    }
+
     ScatterNd(indices, update, output, block_size_, input_size_, output_size_, indices_dim_0_, indices_dim_1_,
               indices_stride_, work_shape_, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
+    memcpy_flag_ = false;
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 2) {
       MS_LOG(ERROR) << "Input number is " << input_num << ", but transpose needs 2 input.";
@@ -81,25 +95,20 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
 
     GetSize();
 
-    size_t indices_len = sizeof(S) * vec_indices_stride_.size();
+    const size_t indices_len = sizeof(S) * vec_indices_stride_.size();
     void *indices_stride_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(indices_len);
     if (indices_stride_work == nullptr) {
       MS_LOG(EXCEPTION) << "Failed to alloc indices_stride_work, size: " << indices_len;
     }
     indices_stride_ = static_cast<S *>(indices_stride_work);
 
-    size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
+    const size_t vec_work_len = sizeof(S) * vec_work_shape_.size();
     void *work_shape_work = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(vec_work_len);
     if (work_shape_work == nullptr) {
       MS_LOG(EXCEPTION) << "Failed to alloc work_shape_work, size: " << vec_work_len;
     }
     work_shape_ = static_cast<S *>(work_shape_work);
 
-    CHECK_CUDA_RET_WITH_EXCEPT(
-      cudaMemcpy(indices_stride_, &vec_indices_stride_[0], indices_len, cudaMemcpyHostToDevice),
-      "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
-    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpy(work_shape_, &vec_work_shape_[0], vec_work_len, cudaMemcpyHostToDevice),
-                               "cudaMemcpy failed in ScatterNdGpuFwdKernel::Init.");
 
     InitSizeLists();
     return true;
@@ -168,6 +177,7 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
   S *work_shape_;
   size_t indices_dim_0_;
   size_t indices_dim_1_;
+  bool memcpy_flag_;
 };
 } // namespace kernel
 } // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu
new file mode 100644
index 00000000000..588f8c60e22
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cu
@@ -0,0 +1,45 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh"
+
+template <typename T, typename S>
+__global__ void CheckValidKernel(const size_t size, const T *box, const T *img_metas, S *valid) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    const size_t left_x = i * 4;
+    const size_t left_y = i * 4 + 1;
+    const size_t right_x = i * 4 + 2;
+    const size_t right_y = i * 4 + 3;
+
+    S valid_flag = false;
+    valid_flag |= !(box[left_x] >= 0.f);
+    valid_flag |= !(box[left_y] >= 0.f);
+    valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]);
+    valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]);
+
+    valid[i] = !valid_flag;
+  }
+
+  return;
+}
+
+template <typename T, typename S>
+void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream) {
+  CheckValidKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, box, img_metas, valid);
+}
+
+template void CheckValid<float, bool>(const size_t &size, const float *box, const float *img_metas, bool *valid,
+                                      cudaStream_t cuda_stream);
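For reference, the predicate CheckValidKernel evaluates can be sketched in NumPy as below. check_valid_reference is an illustrative name, and the (m0, m1, ratio) reading of img_metas simply mirrors the kernel's indexing, assuming a three-element img_metas as in the test added later in this patch:

import numpy as np

def check_valid_reference(boxes, img_metas):
    # boxes: (N, 4) rows of (x1, y1, x2, y2); img_metas: (m0, m1, ratio).
    # Mirrors CheckValidKernel: a box is valid iff x1 >= 0, y1 >= 0,
    # x2 <= m0 * ratio - 1 and y2 <= m1 * ratio - 1.
    m0, m1, ratio = img_metas
    return ((boxes[:, 0] >= 0) & (boxes[:, 1] >= 0)
            & (boxes[:, 2] <= m0 * ratio - 1)
            & (boxes[:, 3] <= m1 * ratio - 1))

boxes = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], np.float32)
print(check_valid_reference(boxes, (768.0, 1280.0, 1.0)))  # [ True False False]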
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh
new file mode 100644
index 00000000000..fa82f109601
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T, typename S>
+void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CHECK_VALID_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
new file mode 100644
index 00000000000..f5e9f50dded
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
@@ -0,0 +1,72 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
+
+template <typename T>
+__device__ T CoordinateMax(const T a, const T b) {
+  return (a > b ? a : b);
+}
+
+template <typename T>
+__device__ T CoordinateMin(const T a, const T b) {
+  return (a < b ? a : b);
+}
+
+template <typename T>
+__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode,
+                          const size_t input_len_0) {
+  T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
+  T overlaps_coordinate[IOU_DIMENSION];
+  const T epsilon = 1e-10;
+
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    for (size_t j = 0; j < IOU_DIMENSION; j++) {
+      location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j];
+      location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j];
+    }
+
+    overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]);
+    overlaps_coordinate[1] = CoordinateMax(location_coordinate[0][1], location_coordinate[1][1]);
+    overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]);
+    overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]);
+
+    T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1);
+    T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1);
+    T overlaps = overlaps_w * overlaps_h;
+
+    T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] -
+              location_coordinate[0][1] + 1);
+    if (mode == 0) {
+      T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
+                location_coordinate[1][1] + 1);
+      iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon);
+    } else {
+      iou_results[i] = overlaps / (area1 + epsilon);
+    }
+  }
+
+  return;
+}
+
+template <typename T>
+void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const size_t &mode,
+         const size_t &input_len_0, cudaStream_t cuda_stream) {
+  IOUKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, box1, box2, iou_results, mode, input_len_0);
+}
+
+template void IOU<float>(const size_t &size, const float *box1, const float *box2, float *iou_results,
+                         const size_t &mode, const size_t &input_len_0, cudaStream_t cuda_stream);
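The kernel pairs element i of the flattened output with box1[i % input_len_0] and box2[i / input_len_0], and uses an inclusive "+ 1" pixel convention for widths and heights. A NumPy sketch of the same arithmetic (iou_reference is an illustrative name, not part of the patch):

import numpy as np

def iou_reference(box1, box2, mode="iou", eps=1e-10):
    # box1: (N, 4), box2: (K, 4), rows of (x1, y1, x2, y2).
    # Returns (K, N): row k, column n holds the overlap of box1[n] with
    # box2[k], matching the kernel's flattened i % / i // input_len_0 layout.
    b1, b2 = box1[None, :, :], box2[:, None, :]
    w = np.maximum(0.0, np.minimum(b1[..., 2], b2[..., 2]) - np.maximum(b1[..., 0], b2[..., 0]) + 1)
    h = np.maximum(0.0, np.minimum(b1[..., 3], b2[..., 3]) - np.maximum(b1[..., 1], b2[..., 1]) + 1)
    overlaps = w * h
    area1 = (b1[..., 2] - b1[..., 0] + 1) * (b1[..., 3] - b1[..., 1] + 1)
    if mode == "iou":
        area2 = (b2[..., 2] - b2[..., 0] + 1) * (b2[..., 3] - b2[..., 1] + 1)
        return overlaps / (area1 + area2 - overlaps + eps)
    return overlaps / (area1 + eps)  # mode == "iof": intersection over foreground

print(iou_reference(np.array([[101., 169., 246., 429.]]),
                    np.array([[121., 138., 304., 374.]])))  # [[0.46551168]]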
a : b); +} + +template +__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode, + const size_t input_len_0) { + T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION]; + T overlaps_coordinate[IOU_DIMENSION]; + const T epsilon = 1e-10; + + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + for (size_t j = 0; j < IOU_DIMENSION; j++) { + location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j]; + location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j]; + } + + overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]); + overlaps_coordinate[1] = CoordinateMax(location_coordinate[0][1], location_coordinate[1][1]); + overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]); + overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]); + + T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1); + T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1); + T overlaps = overlaps_w * overlaps_h; + + T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] - + location_coordinate[0][1] + 1); + if (mode == 0) { + T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] - + location_coordinate[1][1] + 1); + iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon); + } else { + iou_results[i] = overlaps / (area1 + epsilon); + } + } + + return; +} + +template +void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const size_t &mode, + const size_t &input_len_0, cudaStream_t cuda_stream) { + IOUKernel<<>>(size, box1, box2, iou_results, mode, input_len_0); +} + +template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode, + const size_t &input_len_0, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh new file mode 100644 index 00000000000..f8e7d98a247 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc
new file mode 100644
index 00000000000..208e217e1de
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.cc
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_TWO(
+  CheckValid,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
+  CheckValidGpuKernel, float, bool)
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h
new file mode 100644
index 00000000000..36a69c28b46
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/check_valid_gpu_kernel.h
@@ -0,0 +1,106 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/cuda_impl/check_valid_impl.cuh"
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T, typename S>
+class CheckValidGpuKernel : public GpuKernel {
+ public:
+  CheckValidGpuKernel() : anchor_boxes_size_(0), img_metas_size_(0), valid_size_(0) {}
+
+  ~CheckValidGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *anchor_boxes_addr = GetDeviceAddress<T>(inputs, 0);
+    T *img_metas_addr = GetDeviceAddress<T>(inputs, 1);
+    S *valid_addr = GetDeviceAddress<S>(outputs, 0);
+
+    const size_t coordinate = 4;
+    const size_t block_size = inputs[0]->size / sizeof(T);
+    if ((block_size % coordinate) != 0) {
+      MS_LOG(ERROR) << "The size of the box must be a multiple of 4.";
+      return false;
+    }
+
+    const size_t size = block_size / coordinate;
+    CheckValid(size, anchor_boxes_addr, img_metas_addr, valid_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but CheckValid needs 2 inputs.";
+      return false;
+    }
+    anchor_boxes_size_ = sizeof(T);
+    img_metas_size_ = sizeof(T);
+    valid_size_ = sizeof(S);
+
+    auto anchor_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < anchor_boxes_shape.size(); i++) {
+      anchor_boxes_size_ *= anchor_boxes_shape[i];
+    }
+
+    auto img_metas_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < img_metas_shape.size(); i++) {
+      img_metas_size_ *= img_metas_shape[i];
+    }
+
+    auto valid_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < valid_shape.size(); i++) {
+      valid_size_ *= valid_shape[i];
+    }
+
+    InitSizeLists();
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(anchor_boxes_size_);
+    input_size_list_.push_back(img_metas_size_);
+    output_size_list_.push_back(valid_size_);
+  }
+
+ private:
+  size_t anchor_boxes_size_;
+  size_t img_metas_size_;
+  size_t valid_size_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_CHECK_VALID_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc
new file mode 100644
index 00000000000..5d3f0f202b0
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.cc
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/other/iou_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  IOUGpuKernel, float)
+} // namespace kernel
+} // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.h
new file mode 100644
index 00000000000..c28e4f91ec6
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/iou_gpu_kernel.h
@@ -0,0 +1,122 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H
+#define MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H
+
+#include <string>
+#include <vector>
+#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class IOUGpuKernel : public GpuKernel {
+ public:
+  IOUGpuKernel() : gt_boxes_size_(0), anchor_boxes_size_(0), iou_size_(0), mode_(0) {}
+
+  ~IOUGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    VARIABLE_NOT_USED(workspace);
+    T *gt_boxes_addr = GetDeviceAddress<T>(inputs, 0);
+    T *anchor_boxes_addr = GetDeviceAddress<T>(inputs, 1);
+    T *iou_addr = GetDeviceAddress<T>(outputs, 0);
+
+    const size_t coordinate = 4;
+    const size_t block_size_0 = inputs[0]->size / sizeof(T);
+    const size_t block_size_1 = inputs[1]->size / sizeof(T);
+    if ((block_size_0 % coordinate) != 0 || (block_size_1 % coordinate) != 0) {
+      MS_LOG(ERROR) << "The size of the box must be a multiple of 4.";
+      return false;
+    }
+
+    const size_t input_len_0 = block_size_0 / coordinate;
+    const size_t input_len_1 = block_size_1 / coordinate;
+    IOU(input_len_0 * input_len_1, gt_boxes_addr, anchor_boxes_addr, iou_addr, mode_, input_len_0,
+        reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 2) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but IOU needs 2 inputs.";
+      return false;
+    }
+    gt_boxes_size_ = sizeof(T);
+    anchor_boxes_size_ = sizeof(T);
+    iou_size_ = sizeof(T);
+
+    auto gt_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < gt_boxes_shape.size(); i++) {
+      gt_boxes_size_ *= gt_boxes_shape[i];
+    }
+
+    auto anchor_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < anchor_boxes_shape.size(); i++) {
+      anchor_boxes_size_ *= anchor_boxes_shape[i];
+    }
+
+    auto iou_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < iou_shape.size(); i++) {
+      iou_size_ *= iou_shape[i];
+    }
+
+    InitSizeLists();
+
+    std::string mode = GetAttr<std::string>(kernel_node, "mode");
+
+    if (mode == "iou") {
+      mode_ = 0;
+    } else if (mode == "iof") {
+      mode_ = 1;
+    } else {
+      MS_LOG(ERROR) << "Mode only supports 'iou' or 'iof'.";
+      return false;
+    }
+
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(gt_boxes_size_);
+    input_size_list_.push_back(anchor_boxes_size_);
+    output_size_list_.push_back(iou_size_);
+  }
+
+ private:
+  size_t gt_boxes_size_;
+  size_t anchor_boxes_size_;
+  size_t iou_size_;
+  size_t mode_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H
"Input number is " << input_num << ", but IOU needs 2 inputs."; + return false; + } + gt_boxes_size_ = sizeof(T); + anchor_boxes_size_ = sizeof(T); + iou_size_ = sizeof(T); + + auto gt_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < gt_boxes_shape.size(); i++) { + gt_boxes_size_ *= gt_boxes_shape[i]; + } + + auto anchor_boxes_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < anchor_boxes_shape.size(); i++) { + anchor_boxes_size_ *= anchor_boxes_shape[i]; + } + + auto iou_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < iou_shape.size(); i++) { + iou_size_ *= iou_shape[i]; + } + + InitSizeLists(); + + std::string mode = GetAttr(kernel_node, "mode"); + + if (mode == "iou") { + mode_ = 0; + } else if (mode == "iof") { + mode_ = 1; + } else { + MS_LOG(ERROR) << "Mode only support 'iou' or 'iof'."; + return false; + } + + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(gt_boxes_size_); + input_size_list_.push_back(anchor_boxes_size_); + output_size_list_.push_back(iou_size_); + } + + private: + size_t gt_boxes_size_; + size_t anchor_boxes_size_; + size_t iou_size_; + size_t mode_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_OTHER_IOU_GPU_KERNEL_H diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 6555b03aa9a..1af25ed1f2d 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -262,8 +262,6 @@ class IOU(PrimitiveWithInfer): return iou def infer_dtype(self, anchor_boxes, gt_boxes): - args = {"anchor_boxes": anchor_boxes, "gt_boxes": gt_boxes} - validator.check_tensor_type_same(args, (mstype.float16,), self.name) return anchor_boxes diff --git a/tests/st/ops/gpu/test_check_valid_op.py b/tests/st/ops/gpu/test_check_valid_op.py new file mode 100644 index 00000000000..2f30ecfc6ee --- /dev/null +++ b/tests/st/ops/gpu/test_check_valid_op.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/st/ops/gpu/test_check_valid_op.py b/tests/st/ops/gpu/test_check_valid_op.py
new file mode 100644
index 00000000000..2f30ecfc6ee
--- /dev/null
+++ b/tests/st/ops/gpu/test_check_valid_op.py
@@ -0,0 +1,54 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================

+import numpy as np
+import pytest
+
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetCheckValid(nn.Cell):
+    def __init__(self):
+        super(NetCheckValid, self).__init__()
+        self.valid = P.CheckValid()
+
+    def construct(self, anchor, image_metas):
+        return self.valid(anchor, image_metas)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_check_valid():
+    anchor = np.array([[50, 0, 100, 700], [-2, 2, 8, 100], [10, 20, 300, 2000]], np.float32)
+    image_metas = np.array([768, 1280, 1], np.float32)
+    anchor_box = Tensor(anchor, mindspore.float32)
+    image_metas_box = Tensor(image_metas, mindspore.float32)
+    expect = np.array([True, False, False], np.bool_)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    check_valid = NetCheckValid()
+    output = check_valid(anchor_box, image_metas_box)
+    diff = (output.asnumpy() == expect)
+    assert (diff == 1).all()
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    check_valid = NetCheckValid()
+    output = check_valid(anchor_box, image_metas_box)
+    diff = (output.asnumpy() == expect)
+    assert (diff == 1).all()
diff --git a/tests/st/ops/gpu/test_iou_op.py b/tests/st/ops/gpu/test_iou_op.py
new file mode 100644
index 00000000000..17812f2d179
--- /dev/null
+++ b/tests/st/ops/gpu/test_iou_op.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class NetIOU(nn.Cell):
+    def __init__(self, mode):
+        super(NetIOU, self).__init__()
+        self.encode = P.IOU(mode=mode)
+
+    def construct(self, anchor, groundtruth):
+        return self.encode(anchor, groundtruth)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_iou():
+    pos1 = [101, 169, 246, 429]
+    pos2 = [121, 138, 304, 374]
+    mode = "iou"
+    pos1_box = Tensor(np.array(pos1).reshape(1, 4), mindspore.float32)
+    pos2_box = Tensor(np.array(pos2).reshape(1, 4), mindspore.float32)
+    expect_result = np.array(0.46551168, np.float32)
+
+    error = np.ones(shape=[1]) * 1.0e-6
+
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+    overlaps = NetIOU(mode)
+    output = overlaps(pos1_box, pos2_box)
+    diff = output.asnumpy() - expect_result
+    assert np.all(abs(diff) < error)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
+    overlaps = NetIOU(mode)
+    output = overlaps(pos1_box, pos2_box)
+    diff = output.asnumpy() - expect_result
+    assert np.all(abs(diff) < error)
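For reference, the expected value in test_iou can be reproduced by hand with the kernel's inclusive "+ 1" width/height convention:

# Hand check of expect_result = 0.46551168 for pos1/pos2 above.
w = min(246, 304) - max(101, 121) + 1         # 126
h = min(429, 374) - max(169, 138) + 1         # 206
overlaps = w * h                              # 25956
area1 = (246 - 101 + 1) * (429 - 169 + 1)     # 38106
area2 = (304 - 121 + 1) * (374 - 138 + 1)     # 43608
print(overlaps / (area1 + area2 - overlaps))  # 0.46551167...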