!43897 [assistant][ops] Add NonMaxSuppressionV3
Merge pull request !43897 from GP/master
This commit is contained in:
commit e2c891571f
@@ -0,0 +1,197 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_HELPER_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_HELPER_H_
#include <stdio.h>
#include <memory>
#include <string>
#include <vector>
#include <iostream>
#include "mindspore/core/ops/non_max_suppression_v3.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/non_max_suppressionv3_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"

namespace mindspore {
namespace cukernel {
class NonMaxSuppressionV3Attr : public GpuKernelAttrBase {
 public:
  NonMaxSuppressionV3Attr() = default;
  ~NonMaxSuppressionV3Attr() override = default;
};

template <typename T, typename M, typename S>
class NonMaxSuppressionV3HelperGpuKernel : public GpuKernelHelperBase {
 public:
  explicit NonMaxSuppressionV3HelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
      : GpuKernelHelperBase(kernel_name, device_id) {
    // bound_ = 0;
    is_null_input_ = false;
    num_input = 0;
    u_num = 0;
    post_output_size_ = 0;
  }

  virtual ~NonMaxSuppressionV3HelperGpuKernel() = default;
  int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
                 const std::vector<std::vector<int64_t>> &output_shapes) override {
    constexpr int64_t kzero = 0;
    constexpr int64_t kone = 1;
    constexpr int64_t ktwo = 2;
    constexpr int64_t kthree = 3;
    constexpr int64_t kfour = 4;
    constexpr int64_t kfourbytes = 32;
    ResetResource();
    std::vector<std::vector<int64_t>> input_shapes_1;
    std::vector<std::vector<int64_t>> input_shapes_2;
    std::vector<std::vector<int64_t>> input_shapes_3;
    input_shapes_1.emplace_back(input_shapes[kzero]);
    input_shapes_1.emplace_back(input_shapes[kone]);
    input_shapes_2.emplace_back(input_shapes[ktwo]);
    input_shapes_3.emplace_back(input_shapes[kthree]);
    input_shapes_3.emplace_back(input_shapes[kfour]);
    int inp_flag_1 = CalShapesSizeInBytes<T>(input_shapes_1, ktwo, kernel_name_, "input_shapes_1", &input_size_list_);
    if (inp_flag_1 == -1) {
      return inp_flag_1;
    }
    int inp_flag_2 = CalShapesSizeInBytes<S>(input_shapes_2, kone, kernel_name_, "input_shapes_2", &input_size_list_);
    if (inp_flag_2 == -1) {
      return inp_flag_2;
    }
    int inp_flag_3 = CalShapesSizeInBytes<M>(input_shapes_3, ktwo, kernel_name_, "input_shapes_3", &input_size_list_);
    if (inp_flag_3 == -1) {
      return inp_flag_3;
    }
    output_size_list_.emplace_back(input_shapes[kzero][0] * sizeof(int));
    num_input = input_shapes[kzero][0];
    u_num = (num_input + kfourbytes - 1) / kfourbytes;
    work_size_list_.emplace_back(num_input * sizeof(int));                   // index buff
    work_size_list_.emplace_back(num_input * u_num * sizeof(unsigned int));  // sel mask
    work_size_list_.emplace_back(num_input * sizeof(bool));                  // box mask
    work_size_list_.emplace_back(sizeof(int));                               // count
    work_size_list_.emplace_back(sizeof(int));                               // num_keep
    is_null_input_ = (inp_flag_1 == 1 || inp_flag_2 == 1 || inp_flag_3 == 1);
    return CheckKernelParam();
  }

  int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
              const std::vector<void *> &work_ptrs, void *cuda_stream) override {
    constexpr int64_t kzero = 0;
    constexpr int64_t kone = 1;
    constexpr int64_t ktwo = 2;
    constexpr int64_t kthree = 3;
    constexpr int64_t kfour = 4;
    if (is_null_input_) {
      return 0;
    }
    T *input_ptr = nullptr;
    T *scores = nullptr;
    S *max_output_size_ = nullptr;
    M *iou_threshold_ = nullptr;
    M *score_threshold_ = nullptr;
    int *output_ptr = nullptr;
    int *index_buff = nullptr;
    unsigned int *sel_mask = nullptr;
    bool *sel_boxes = nullptr;
    int *count = nullptr;
    int *num_keep = nullptr;

    int flag = GetDeviceAddress<T>(input_ptrs, kzero, kernel_name_, &input_ptr);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<T>(input_ptrs, kone, kernel_name_, &scores);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<S>(input_ptrs, ktwo, kernel_name_, &max_output_size_);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<M>(input_ptrs, kthree, kernel_name_, &iou_threshold_);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<M>(input_ptrs, kfour, kernel_name_, &score_threshold_);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<int>(work_ptrs, kzero, kernel_name_, &index_buff);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<unsigned int>(work_ptrs, kone, kernel_name_, &sel_mask);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<bool>(work_ptrs, ktwo, kernel_name_, &sel_boxes);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<int>(work_ptrs, kthree, kernel_name_, &count);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<int>(work_ptrs, kfour, kernel_name_, &num_keep);
    if (flag != 0) {
      return flag;
    }
    flag = GetDeviceAddress<int>(output_ptrs, kzero, kernel_name_, &output_ptr);
    if (flag != 0) {
      return flag;
    }

    M iou_host = 0.0;
    cudaMemcpy(&iou_host, iou_threshold_, sizeof(M), cudaMemcpyDeviceToHost);
    float iou = static_cast<float>(iou_host);
    if (iou > 1 || iou < 0) {
      MS_EXCEPTION(ValueError) << "For NonMaxSuppressionV3, iou_threshold must be in [0, 1], but got " << iou;
      return -1;
    }
    S max_host = 0;
    cudaMemcpy(&max_host, max_output_size_, sizeof(S), cudaMemcpyDeviceToHost);
    int max = static_cast<int32_t>(max_host);
    if (max < 0) {
      max_host = 0;
    }
    M score_host = 0.0;
    cudaMemcpy(&score_host, score_threshold_, sizeof(M), cudaMemcpyDeviceToHost);
    const int b_size = 4;
    post_output_size_ =
      DoNms(num_input, count, num_keep, scores, input_ptr, iou_host, score_host, index_buff, max_host, b_size, sel_mask,
            sel_boxes, output_ptr, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
    return 0;
  }

  TensorInfo GetOutputTensorInfo() override {
    TensorInfo dyn_out;
    dyn_out.shapes.push_back({{post_output_size_}});
    return dyn_out;
  }

 private:
  std::vector<int64_t> input_shape_;
  bool is_null_input_;
  int num_input;
  int u_num;
  float iou_threshold;
  int post_output_size_;
};
}  // namespace cukernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_HELPER_H_
@@ -0,0 +1,350 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/non_max_suppressionv3_impl.cuh"
#include <cub/cub.cuh>
#include <thrust/sort.h>
#include <thrust/device_ptr.h>
#include <vector>
#include <limits>
#include <iostream>
#include <algorithm>

constexpr int kNmsBlockDim = 16;
constexpr int kNmsBlockDimMax = 128;
constexpr int kNmsBoxesPerThread = 8 * sizeof(int);

template <typename T>
struct GreaterThanCubOp {
  float threshold_;
  __host__ __device__ __forceinline__ GreaterThanCubOp(float threshold) : threshold_(threshold) {}
  __host__ __device__ __forceinline__ bool operator()(const T &val) const {
    return (static_cast<float>(val) > threshold_);
  }
};

template <typename T>
__inline__ __device__ void Swap(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}

template <typename T>
__inline__ __device__ T max(T x, T y) {
  if (x > y) {
    return x;
  } else {
    return y;
  }
}

template <typename T>
__inline__ __device__ T min(T x, T y) {
  if (x < y) {
    return x;
  } else {
    return y;
  }
}

template <typename T>
__inline__ __device__ void Flipped(T *box) {
  if (box[0] > box[2]) Swap(&box[0], &box[2]);
  if (box[1] > box[3]) Swap(&box[1], &box[3]);
}

template <typename T>
__inline__ __device__ bool IouDecision(T *box_A, T *box_B, T a_area, float IOU_threshold) {
  T b_area = (box_B[2] - box_B[0]) * (box_B[3] - box_B[1]);
  if (a_area == static_cast<T>(0.0) || b_area == static_cast<T>(0.0)) return false;
  T x_1 = max(box_A[0], box_B[0]);
  T y_1 = max(box_A[1], box_B[1]);
  T x_2 = min(box_A[2], box_B[2]);
  T y_2 = min(box_A[3], box_B[3]);
  T width = max(x_2 - x_1, T(0));  // in case of no overlap
  T height = max(y_2 - y_1, T(0));
  T intersection = width * height;

  float aa = static_cast<float>(intersection);
  T bb = a_area + b_area - intersection;
  float bt = static_cast<float>(bb) * IOU_threshold;

  return aa > bt;
}

template <typename T>
__inline__ __device__ void SelectHelper(int i_selected, int i_original, T *original, T *selected) {
  selected[i_selected * 4 + 0] = original[i_original * 4 + 0];
  selected[i_selected * 4 + 1] = original[i_original * 4 + 1];
  selected[i_selected * 4 + 2] = original[i_original * 4 + 2];
  selected[i_selected * 4 + 3] = original[i_original * 4 + 3];
  Flipped(selected + i_selected * 4);
}

template <typename T>
__global__ void IndexMultiSelect(const int num_elements, int *index_buff, T *original, T *selected) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; idx += blockDim.x * gridDim.x) {
    SelectHelper(idx, static_cast<int>(index_buff[idx]), original, selected);
  }
}

template <typename T>
__global__ void CastFloat(const int num_elements, T *scores, float *scores_float) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; idx += blockDim.x * gridDim.x) {
    scores_float[idx] = static_cast<float>(scores[idx]);
  }
}

__global__ void SetZeros(const int num_elements, unsigned int *target) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; idx += blockDim.x * gridDim.x) {
    target[idx] = 0;
  }
}

template <typename T>
bool CheckBitHost(T bit_mask, int bit) {
  return (bit_mask >> (bit % kNmsBoxesPerThread)) & 1;
}

template <typename T>
__launch_bounds__(kNmsBlockDim *kNmsBlockDim, 4) __global__
  void NMSReduce(const int num, int u_num, float iou_threshold, T *boxes_sort, int box_size, unsigned int *sel_mask) {
  __shared__ T shared_i_boxes[kNmsBlockDim * 4];
  // Same thing with areas
  __shared__ T shared_i_areas[kNmsBlockDim];
  // The condition of the for loop is common to all threads in the block.
  // This is necessary to be able to call __syncthreads() inside of the loop.
  for (int i_block_offset = blockIdx.x * blockDim.x; i_block_offset < num; i_block_offset += blockDim.x * gridDim.x) {
    const int i = i_block_offset + threadIdx.x;
    if (i < num) {
      // One 1D line load the boxes for x-dimension.
      if (threadIdx.y == 0) {
        shared_i_boxes[threadIdx.x * 4 + 0] = boxes_sort[i * 4 + 0];
        shared_i_boxes[threadIdx.x * 4 + 1] = boxes_sort[i * 4 + 1];
        shared_i_boxes[threadIdx.x * 4 + 2] = boxes_sort[i * 4 + 2];
        shared_i_boxes[threadIdx.x * 4 + 3] = boxes_sort[i * 4 + 3];
        T area = (boxes_sort[i * 4 + 2] - boxes_sort[i * 4 + 0]) * (boxes_sort[i * 4 + 3] - boxes_sort[i * 4 + 1]);
        shared_i_areas[threadIdx.x] = area;
      }
    }
    __syncthreads();
    for (int j_thread_offset = kNmsBoxesPerThread * (blockIdx.y * blockDim.y + threadIdx.y); j_thread_offset < num;
         j_thread_offset += kNmsBoxesPerThread * blockDim.y * gridDim.y) {
      int above_threshold = 0;
      // Make sure that threads are within valid domain.
      bool valid = false;
      // Loop over the next kNmsBoxesPerThread boxes and set corresponding bit
      // if it is overlapping with current box
      for (int ib = 0; ib < kNmsBoxesPerThread; ++ib) {
        const int j = j_thread_offset + ib;
        if (i >= j || i >= num || j >= num) continue;
        valid = true;
        T *j_box = boxes_sort + j * 4;
        T *i_box = shared_i_boxes + threadIdx.x * 4;
        if (IouDecision(i_box, j_box, shared_i_areas[threadIdx.x], iou_threshold)) {
          // we have score[j] <= score[i]. j > i
          above_threshold |= (1U << ib);
        }
      }
      if (valid) {
        sel_mask[i * u_num + j_thread_offset / kNmsBoxesPerThread] = above_threshold;
      }
    }
    __syncthreads();  // making sure everyone is done reading shared memory.
  }
}

template <typename T>
int CalNms(const int num_input, int *num_keep, float iou_threshold, int max_output_size, T *boxes_sort, int *index_buff,
           int box_size, unsigned int *sel_mask, bool *sel_boxes, int *output_ptr, const uint32_t &device_id,
           cudaStream_t cuda_stream) {
  int u_num = (num_input + kNmsBoxesPerThread - 1) / kNmsBoxesPerThread;
  const int max_nms_mask_size = num_input * u_num;
  int thread_num = 256 < num_input ? 256 : num_input;
  cudaDeviceProp prop;
  (void)cudaGetDeviceProperties(&prop, device_id);
  int max_blocks = prop.multiProcessorCount;
  int block_num = min(static_cast<int>(((num_input - 1) / thread_num) + 1), max_blocks);
  SetZeros<<<block_num, thread_num, 0, cuda_stream>>>(max_nms_mask_size, sel_mask);
  int num_blocks = (num_input + kNmsBlockDim - 1) / kNmsBlockDim;
  num_blocks = std::max(std::min(num_blocks, kNmsBlockDimMax), 1);
  dim3 blocks(num_blocks, num_blocks);
  dim3 threads(kNmsBlockDim, kNmsBlockDim);
  NMSReduce<<<blocks, threads, 0, cuda_stream>>>(num_input, u_num, iou_threshold, boxes_sort, box_size, sel_mask);

  std::vector<unsigned int> sel_mask_host(num_input * u_num);
  cudaMemcpyAsync(sel_mask_host.data(), sel_mask, num_input * u_num * sizeof(unsigned int), cudaMemcpyDeviceToHost,
                  cuda_stream);
  std::vector<int> local(u_num);
  std::vector<char> sel_boxes_host(num_input);
  for (int box = 0; box < u_num; box += 1) {
    local[box] = 0xFFFFFFFF;
  }
  int accepted_boxes = 0;
  for (int box = 0; box < num_input - 1; ++box) {
    if (!CheckBitHost(local[box / kNmsBoxesPerThread], box)) {
      continue;
    }
    accepted_boxes += 1;
    int offset = box * u_num;

    for (int b = 0; b < u_num; b += 1) {
      local[b] &= ~sel_mask_host[offset + b];
    }
    if (accepted_boxes > max_output_size) break;
  }
  for (int box = 0; box < num_input; box += 1) {
    sel_boxes_host[box] = CheckBitHost(local[box / kNmsBoxesPerThread], box);
  }
  cudaMemcpyAsync(sel_boxes, sel_boxes_host.data(), num_input * sizeof(char), cudaMemcpyHostToDevice, cuda_stream);

  void *d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  (void)cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, static_cast<int *>(nullptr),
                                   static_cast<char *>(nullptr), static_cast<int *>(nullptr),
                                   static_cast<int *>(nullptr), num_input, cuda_stream);
  (void)cudaMalloc(&d_temp_storage, temp_storage_bytes);
  (void)cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, index_buff, sel_boxes, output_ptr, num_keep,
                                   num_input, cuda_stream);
  (void)cudaFree(d_temp_storage);

  int num_count = 0;
  cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream));
  cudaMemcpyAsync(&num_count, num_keep, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream);
  num_count = max_output_size < num_count ? max_output_size : num_count;
  return num_count;
}

template <typename T, typename M, typename S>
int DoNms(const int num_input, int *count, int *num_keep, T *scores, T *boxes_in, M iou_threshold_, M score_threshold_,
          int *index_buff, S max_output_size_, int box_size, unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
          const uint32_t &device_id, cudaStream_t cuda_stream) {
  float iou_threshold = static_cast<float>(iou_threshold_);
  float score_threshold = static_cast<float>(score_threshold_);
  int max_output_size = static_cast<int>(max_output_size_);
  cudaMemset(count, 0, sizeof(int));

  float *scores_float = nullptr;
  size_t scores_float_temp_storage_bytes = num_input * sizeof(float);
  (void)cudaMalloc(&scores_float, scores_float_temp_storage_bytes);
  int thread_num = 256 < num_input ? 256 : num_input;
  cudaDeviceProp prop;
  (void)cudaGetDeviceProperties(&prop, device_id);
  int max_blocks = prop.multiProcessorCount;
  int block_num = std::min(static_cast<int>(((num_input - 1) / thread_num) + 1), max_blocks);
  CastFloat<<<block_num, thread_num, 0, cuda_stream>>>(num_input, scores, scores_float);

  auto policy = thrust::cuda::par.on(cuda_stream);
  thrust::device_ptr<int> dev_ptr(index_buff);
  thrust::sequence(policy, dev_ptr, dev_ptr + num_input);
  cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream));
  size_t cub_sort_temp_storage_bytes = 0;
  (void)cub::DeviceRadixSort::SortPairsDescending(nullptr, cub_sort_temp_storage_bytes,
                                                  static_cast<float *>(nullptr),  // scores
                                                  static_cast<float *>(nullptr),  // sorted scores
                                                  static_cast<int *>(nullptr),    // input indices
                                                  static_cast<int *>(nullptr),    // sorted indices
                                                  num_input,                      // num items
                                                  0, 8 * sizeof(float),           // sort all bits
                                                  cuda_stream);
  float *scores_sorted = nullptr;
  size_t scores_sorted_temp_storage_bytes = num_input * sizeof(float);
  (void)cudaMalloc(&scores_sorted, scores_sorted_temp_storage_bytes);
  int *index_sorted = nullptr;
  size_t index_sorted_temp_storage_bytes = num_input * sizeof(int);
  (void)cudaMalloc(&index_sorted, index_sorted_temp_storage_bytes);
  void *sort_temp_buff = nullptr;
  (void)cudaMalloc(&sort_temp_buff, cub_sort_temp_storage_bytes);
  (void)cub::DeviceRadixSort::SortPairsDescending(sort_temp_buff, cub_sort_temp_storage_bytes, scores_float,
                                                  scores_sorted, index_buff, index_sorted, num_input, 0,
                                                  8 * sizeof(float), cuda_stream);
  cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream));

  (void)cudaFree(sort_temp_buff);
  GreaterThanCubOp<T> score_limit(score_threshold);
  void *s_temp_storage = nullptr;
  size_t s_temp_storage_bytes = 0;
  (void)cub::DeviceSelect::If(nullptr, s_temp_storage_bytes, static_cast<float *>(nullptr),
                              static_cast<float *>(nullptr), static_cast<int *>(nullptr), num_input, score_limit,
                              cuda_stream);
  (void)cudaMalloc(&s_temp_storage, s_temp_storage_bytes);
  (void)cub::DeviceSelect::If(s_temp_storage, s_temp_storage_bytes, scores_sorted, scores_float, count, num_input,
                              score_limit, cuda_stream);
  (void)cudaFree(s_temp_storage);
  (void)cudaFree(scores_float);
  (void)cudaFree(scores_sorted);
  T *boxes_sort = nullptr;
  size_t boxes_temp_storage_bytes = num_input * box_size * sizeof(T);
  (void)cudaMalloc(&boxes_sort, boxes_temp_storage_bytes);

  IndexMultiSelect<<<block_num, thread_num, 0, cuda_stream>>>(num_input, index_sorted, boxes_in, boxes_sort);
  cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream));

  int num_count = 0;
  cudaMemcpyAsync(&num_count, count, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream);
  const int num_to_keep = num_count;
  if (num_to_keep <= 0) {
    return 0;
  }
  int output_size = CalNms(num_to_keep, num_keep, iou_threshold, max_output_size, boxes_sort, index_sorted, box_size,
                           sel_mask, sel_boxes, output_ptr, device_id, reinterpret_cast<cudaStream_t>(cuda_stream));
  (void)cudaFree(boxes_sort);
  (void)cudaFree(index_sorted);

  return output_size;
}

template CUDA_LIB_EXPORT int DoNms<float, float, int>(const int num_input, int *count, int *num_keep, float *scores,
                                                      float *boxes_in, float iou_threshold_, float score_threshold_,
                                                      int *index_buff, int max_output_size_, int box_size,
                                                      unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                      const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT int DoNms<float, float, int64_t>(const int num_input, int *count, int *num_keep, float *scores,
                                                          float *boxes_in, float iou_threshold_, float score_threshold_,
                                                          int *index_buff, int64_t max_output_size_, int box_size,
                                                          unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                          const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT int DoNms<half, float, int>(const int num_input, int *count, int *num_keep, half *scores,
                                                     half *boxes_in, float iou_threshold_, float score_threshold_,
                                                     int *index_buff, int max_output_size_, int box_size,
                                                     unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT int DoNms<half, float, int64_t>(const int num_input, int *count, int *num_keep, half *scores,
                                                         half *boxes_in, float iou_threshold_, float score_threshold_,
                                                         int *index_buff, int64_t max_output_size_, int box_size,
                                                         unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                         const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT int DoNms<float, half, int>(const int num_input, int *count, int *num_keep, float *scores,
                                                     float *boxes_in, half iou_threshold_, half score_threshold_,
                                                     int *index_buff, int max_output_size_, int box_size,
                                                     unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                     const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT int DoNms<float, half, int64_t>(const int num_input, int *count, int *num_keep, float *scores,
                                                         float *boxes_in, half iou_threshold_, half score_threshold_,
                                                         int *index_buff, int64_t max_output_size_, int box_size,
                                                         unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                         const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT int DoNms<half, half, int>(const int num_input, int *count, int *num_keep, half *scores,
                                                    half *boxes_in, half iou_threshold_, half score_threshold_,
                                                    int *index_buff, int max_output_size_, int box_size,
                                                    unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                    const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT int DoNms<half, half, int64_t>(const int num_input, int *count, int *num_keep, half *scores,
                                                        half *boxes_in, half iou_threshold_, half score_threshold_,
                                                        int *index_buff, int64_t max_output_size_, int box_size,
                                                        unsigned int *sel_mask, bool *sel_boxes, int *output_ptr,
                                                        const uint32_t &device_id, cudaStream_t cuda_stream);
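For reference, the kernels above implement the usual greedy NMS pipeline: cast the scores to float, sort them descending, drop entries whose score is not above score_threshold, then walk the survivors in score order and suppress any box whose IoU with an already-kept box exceeds iou_threshold, capping the result at max_output_size. Below is a rough NumPy sketch of that pipeline for illustration only; the function names, coordinate handling, and tie-breaking are assumptions made for this sketch and are not part of the change.

import numpy as np


def _iou(a, b):
    # IoU of two boxes given as 4 coordinates; corners may arrive swapped,
    # mirroring the Flipped() normalization in the CUDA code.
    ay1, ax1, ay2, ax2 = min(a[0], a[2]), min(a[1], a[3]), max(a[0], a[2]), max(a[1], a[3])
    by1, bx1, by2, bx2 = min(b[0], b[2]), min(b[1], b[3]), max(b[0], b[2]), max(b[1], b[3])
    area_a = (ay2 - ay1) * (ax2 - ax1)
    area_b = (by2 - by1) * (bx2 - bx1)
    if area_a <= 0 or area_b <= 0:
        return 0.0
    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter = ih * iw
    return inter / (area_a + area_b - inter)


def nms_v3_reference(boxes, scores, max_output_size, iou_threshold, score_threshold):
    # Indices of kept boxes, highest score first (illustrative reference only).
    order = [i for i in np.argsort(-scores, kind="stable") if scores[i] > score_threshold]
    keep = []
    for i in order:
        if len(keep) >= max_output_size:
            break
        if all(_iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return np.array(keep, dtype=np.int32)


# Matches the expectation of the GPU test later in this change: the two boxes
# do not overlap, so both survive an IoU threshold of 0.7.
boxes = np.array([[70, 70, 45, 75], [30, 33, 43, 29]], dtype=np.float32)
scores = np.array([0.6, 0.1], dtype=np.float32)
print(nms_v3_reference(boxes, scores, 2, 0.7, 0.05))  # -> [0 1]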
@@ -0,0 +1,27 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"

template <typename T, typename M, typename S>
CUDA_LIB_EXPORT int DoNms(const int num_input, int *count, int *num_keep, T *scores, T *boxes_in, M iou_threshold_,
                          M score_threshold_, int *index_buff, S max_output_size_, int box_size, unsigned int *sel_mask,
                          bool *sel_boxes, int *output_ptr, const uint32_t &device_id, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_IMPL_CUH_
@@ -0,0 +1,174 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/math/nonmaxsuppressionv3_gpu_kernel.h"
#include <utility>
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename M, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateNonMaxSuppressionV3KernelPtr(const std::string &kernel_name,
                                                                                  const uint32_t &device_id) {
  return std::make_unique<cukernel::NonMaxSuppressionV3HelperGpuKernel<T, M, S>>(kernel_name, device_id);
}
using NonMaxSuppressionV3PtrCreatorFunc =
  std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;

const std::vector<std::pair<KernelAttr, NonMaxSuppressionV3PtrCreatorFunc>> kernel_attr = {
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<float, float, int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<half, float, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<float, float, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<half, float, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<float, half, int32_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt32)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<half, half, int>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeFloat32)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<float, half, int64_t>},
  {KernelAttr()
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeInt64)
     .AddInputAttr(kNumberTypeFloat16)
     .AddInputAttr(kNumberTypeFloat16)
     .AddOutputAttr(kNumberTypeInt32),
   CreateNonMaxSuppressionV3KernelPtr<half, half, int64_t>}};
}  // namespace

bool NonMaxSuppressionV3GpuKernelMod::Launch(const std::vector<AddressPtr> &inputs,
                                             const std::vector<AddressPtr> &workspace,
                                             const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  std::vector<void *> input_ptrs = ConvertPtrs(inputs);
  std::vector<void *> work_ptrs = ConvertPtrs(workspace);
  std::vector<void *> output_ptrs = ConvertPtrs(outputs);
  if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
    return false;
  }
  return true;
}

bool NonMaxSuppressionV3GpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                           const std::vector<KernelTensorPtr> &inputs,
                                           const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::NonMaxSuppressionV3>(base_operator);
  kernel_name_ = kernel_ptr->name();
  auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
  if (!is_match) {
    return false;
  }
  is_need_retrieve_output_shape_ = true;
  helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
  return true;
}

int NonMaxSuppressionV3GpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                            const std::vector<KernelTensorPtr> &inputs,
                                            const std::vector<KernelTensorPtr> &outputs,
                                            const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  for (const auto &input : inputs) {
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  outputs_ = outputs;
  std::vector<std::vector<int64_t>> input_shapes;
  std::vector<std::vector<int64_t>> output_shapes;
  std::vector<int64_t> inp_shape1 = inputs[0]->GetShapeVector();
  std::vector<int64_t> inp_shape2 = inputs[1]->GetShapeVector();
  std::vector<int64_t> out_shape = inp_shape2;
  std::vector<int64_t> one = {1};
  input_shapes.emplace_back(inp_shape1);
  input_shapes.emplace_back(inp_shape2);
  input_shapes.emplace_back(one);
  input_shapes.emplace_back(one);
  input_shapes.emplace_back(one);
  output_shapes.emplace_back(out_shape);

  if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
    return KRET_RESIZE_FAILED;
  }

  input_size_list_ = helper_ptr_->GetInputSizeList();
  output_size_list_ = helper_ptr_->GetOutputSizeList();
  workspace_size_list_ = helper_ptr_->GetWorkSizeList();
  return KRET_OK;
}

void NonMaxSuppressionV3GpuKernelMod::SyncData() {
  std::vector<int64_t> shape = {-1};
  auto dyn_out = helper_ptr_->GetOutputTensorInfo();
  shape[0] = dyn_out.shapes[0][0];
  outputs_[0]->SetShapeVector(std::vector<int64_t>(shape.begin(), shape.end()));
}

std::vector<KernelAttr> NonMaxSuppressionV3GpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, NonMaxSuppressionV3PtrCreatorFunc> &item) { return item.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, NonMaxSuppressionV3, NonMaxSuppressionV3GpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,60 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/cuda_class_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/nonmaxsuppressionv3_helper.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/non_max_suppressionv3_impl.cuh"
namespace mindspore {
namespace kernel {
class NonMaxSuppressionV3GpuKernelMod : public NativeGpuKernelMod {
 public:
  NonMaxSuppressionV3GpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::NonMaxSuppressionV3Attr>(); }
  ~NonMaxSuppressionV3GpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
  std::vector<KernelAttr> GetOpSupport() override;

 protected:
  void SyncData() override;
  std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }

 private:
  std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
  std::shared_ptr<cukernel::NonMaxSuppressionV3Attr> attr_ptr_{nullptr};
  std::vector<KernelTensorPtr> outputs_ = {};
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NONMAXSUPPRESSIONV3_GPU_KERNEL_H_
@@ -1,5 +1,5 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -15,12 +15,15 @@
 */

#include <set>
#include <iostream>
#include <limits>

#include "ops/non_max_suppression_v3.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "abstract/dshape.h"
#include "mindapi/src/helper.h"
#include "mindapi/ir/value.h"

namespace mindspore {
namespace ops {
@@ -34,38 +37,49 @@ abstract::ShapePtr NonMaxSuppressionV3InferShape(const PrimitivePtr &primitive,
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
  (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
  auto boxes_shape = std::make_shared<abstract::Shape>(
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]);
  auto scores_shape = std::make_shared<abstract::Shape>(
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]);
  auto scores_shape_rank = SizeToLong(scores_shape->shape().size());
  auto max_output_size_shape = std::make_shared<abstract::Shape>(
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape]);
  auto max_output_size_shape_rank = SizeToLong(max_output_size_shape->shape().size());
  auto iou_threshold_shape = std::make_shared<abstract::Shape>(
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->GetShapeTrack())[kShape]);
  auto iou_threshold_shape_rank = SizeToLong(iou_threshold_shape->shape().size());
  auto score_threshold_shape = std::make_shared<abstract::Shape>(
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape]);
  auto score_threshold_shape_rank = SizeToLong(score_threshold_shape->shape().size());
  // boxes second dimension must euqal 4
  (void)CheckAndConvertUtils::CheckInteger("boxes second dimension", boxes_shape->shape()[1], kEqual, 4, prim_name);
  auto in_shape1 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
  auto in_shape2 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
  auto in_shape3 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
  auto in_shape4 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->GetShapeTrack())[kShape];
  auto in_shape5 = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape];
  if (IsDynamicRank(in_shape1) || IsDynamicRank(in_shape2) || IsDynamicRank(in_shape3) || IsDynamicRank(in_shape4) ||
      IsDynamicRank(in_shape5)) {
    return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
  }
  // boxes must be rank 2
  (void)CheckAndConvertUtils::CheckInteger("boxes rank", SizeToLong(boxes_shape->shape().size()), kEqual, 2, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("boxes rank", boxes_shape->shape().size(), kEqual, 2, prim_name);
  int x_shape = boxes_shape->shape()[1];
  if (x_shape > 0) {
    // boxes second dimension must euqal 4
    (void)CheckAndConvertUtils::CheckInteger("boxes second dimension", boxes_shape->shape()[1], kEqual, 4, prim_name);
  }
  // score must be rank 1
  (void)CheckAndConvertUtils::CheckInteger("scores rank", scores_shape_rank, kEqual, 1, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("scores rank", scores_shape->shape().size(), kEqual, 1, prim_name);
  // score length must be equal with boxes first dimension
  (void)CheckAndConvertUtils::CheckInteger("scores length", scores_shape->shape()[0], kEqual, boxes_shape->shape()[0],
                                           prim_name);
  // max_output_size,iou_threshold,score_threshold must be scalar
  (void)CheckAndConvertUtils::CheckInteger("max_output_size rank", max_output_size_shape_rank, kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("iou_threshold rank", iou_threshold_shape_rank, kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("score_threshold rank", score_threshold_shape_rank, kEqual, 0, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("max_output_size size", max_output_size_shape->shape().size(), kEqual, 0,
                                           prim_name);
  (void)CheckAndConvertUtils::CheckInteger("iou_threshold size", iou_threshold_shape->shape().size(), kEqual, 0,
                                           prim_name);
  (void)CheckAndConvertUtils::CheckInteger("score_threshold size", score_threshold_shape->shape().size(), kEqual, 0,
                                           prim_name);
  auto scores_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
  // calculate output shape
  ShapeVector selected_indices_shape = {abstract::Shape::kShapeDimAny};
  ShapeVector selected_indices_shape = {-1};
  ShapeVector selected_indices_min_shape = {0};
  ShapeVector selected_indices_max_shape;
  if (scores_shape_map[kShape].size() > 0 && scores_shape_map[kShape][0] == -1) {
@@ -94,19 +108,19 @@ TypePtr NonMaxSuppressionV3InferType(const PrimitivePtr &prim, const std::vector
  // boxes and scores must have same type
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  std::map<std::string, TypePtr> args;
  (void)args.insert(std::make_pair("boxes_type", boxes_type));
  (void)args.insert(std::make_pair("scores_type", scores_type));
  (void)args.insert({"boxes_type", boxes_type});
  (void)args.insert({"scores_type", scores_type});
  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
  // iou_threshold,score_threshold must be scalar
  std::map<std::string, TypePtr> args2;
  (void)args2.insert(std::make_pair("iou_threshold_type", iou_threshold_type));
  (void)args2.insert(std::make_pair("score_threshold_type", score_threshold_type));
  (void)args2.insert({"iou_threshold_type", iou_threshold_type});
  (void)args2.insert({"score_threshold_type", score_threshold_type});
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args2, valid_types, prim_name);
  // max_output_size must be scalar
  const std::set<TypePtr> valid_types2 = {kInt32, kInt64};
  const std::set<TypePtr> valid_types1 = {kInt32, kInt64};
  std::map<std::string, TypePtr> args3;
  (void)args3.insert(std::make_pair("max_output_size_type", max_output_size_type));
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types2, prim_name);
  (void)args3.insert({"max_output_size_type", max_output_size_type});
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types1, prim_name);
  return max_output_size_type;
}
}  // namespace
@@ -394,7 +394,7 @@ class NonMaxSuppressionV3(Primitive):
            `iou_threshold`, `score_threshold` is not 0.

    Supported Platforms:
        ``Ascend``
        ``Ascend`` ``GPU``

    Examples:
        >>> boxes = Tensor(np.array([[1, 2, 3, 4], [1, 3, 3, 4], [1, 3, 4, 4],
@@ -412,6 +412,8 @@ class NonMaxSuppressionV3(Primitive):
    @prim_attr_register
    def __init__(self):
        """Initialize NonMaxSuppressionV3"""
        self.init_prim_io_names(inputs=['boxes', 'scores', 'max_output_size', 'iou_threshold', 'score_threshold'],
                                outputs=['selected indices'])


class NonMaxSuppressionWithOverlaps(Primitive):
@@ -0,0 +1,52 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.nms = P.NonMaxSuppressionV3()

    def construct(self, boxes, scores, max_output_size, iou_threshold, score_threshold):
        return self.nms(boxes, scores, max_output_size, iou_threshold, score_threshold)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_boxes_float32_scores_float32():
    """
    Feature: test NonMaxSuppressionV3
    Description: test cases for NonMaxSuppressionV3
    Expectation: the result match to numpy
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    boxes = Tensor(np.array([[70, 70, 45, 75], [30, 33, 43, 29]]), mindspore.float32)
    scores = Tensor(np.array([0.6, 0.1]), mindspore.float32)
    max_output_size = Tensor(2, mindspore.int32)
    score_threshold = Tensor(0.05, mindspore.float16)
    iou_threshold = Tensor(0.7, mindspore.float16)
    expected_idx = np.array([0, 1])
    op = Net()
    sel_idx = op(boxes, scores, max_output_size, iou_threshold, score_threshold)
    assert np.array_equal(sel_idx.asnumpy(), expected_idx)
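For completeness, the new kernel can also be exercised interactively in PyNative mode. The snippet below is a hedged usage sketch only: the device target, box values, and expected indices are illustrative assumptions, not part of this change.

import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from mindspore.ops import operations as P

# Run the operator eagerly on the GPU backend added by this merge request.
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
nms = P.NonMaxSuppressionV3()
boxes = Tensor(np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]]), ms.float32)
scores = Tensor(np.array([0.9, 0.8, 0.7]), ms.float32)
out = nms(boxes, scores, Tensor(3, ms.int32), Tensor(0.5, ms.float32), Tensor(0.0, ms.float32))
# Box 1 overlaps box 0 with IoU of roughly 0.68 > 0.5, so it is suppressed;
# the disjoint box 2 is kept, giving selected indices [0, 2].
print(out)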