diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
index 48afc575765..37788fc0a98 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cu
@@ -36,12 +36,44 @@ __inline__ __device__ void Swap(T *lhs, T *rhs) {
   rhs[0] = tmp;
 }
 
+// Initialize the per-row mask array to all true
+__global__ void MaskInit(int numSq, bool *row_mask) {
+  for (int mat_pos = blockIdx.x * blockDim.x + threadIdx.x; mat_pos < numSq; mat_pos += blockDim.x * gridDim.x) {
+    row_mask[mat_pos] = true;
+  }
+}
+
+// Copy data from input to the output array, sorted by the indices returned from the bitonic sort.
+// Flips boxes if asked to (default: false), i.e. swaps coordinates where x1 > x2 or y1 > y2.
 template <typename T>
-__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_) {
+__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_,
+                               bool flip_mode = false) {
   for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
     int correct_index = index_buff[(num - 1) - box_num];  // flip the array around
-    for (int x = 0; x < 5; x++) {
-      data_out[(box_num * box_size_) + x] = data_in[(correct_index * box_size_) + x];
+    int correct_arr_start = correct_index * box_size_;
+    int current_arr_start = box_num * box_size_;
+    if (flip_mode) {  // flip boxes
+      // check x
+      if (data_in[correct_arr_start + 0] > data_in[correct_arr_start + 2]) {
+        data_out[current_arr_start + 0] = data_in[correct_arr_start + 2];
+        data_out[current_arr_start + 2] = data_in[correct_arr_start + 0];
+      } else {
+        data_out[current_arr_start + 0] = data_in[correct_arr_start + 0];
+        data_out[current_arr_start + 2] = data_in[correct_arr_start + 2];
+      }
+      // check y
+      if (data_in[correct_arr_start + 1] > data_in[correct_arr_start + 3]) {
+        data_out[current_arr_start + 1] = data_in[correct_arr_start + 3];
+        data_out[current_arr_start + 3] = data_in[correct_arr_start + 1];
+      } else {
+        data_out[current_arr_start + 1] = data_in[correct_arr_start + 1];
+        data_out[current_arr_start + 3] = data_in[correct_arr_start + 3];
+      }
+      data_out[current_arr_start + 4] = data_in[correct_arr_start + 4];
+    } else {  // default behaviour, don't flip
+      for (int x = 0; x < 5; x++) {
+        data_out[current_arr_start + x] = data_in[correct_arr_start + x];
+      }
     }
   }
 }
@@ -57,55 +89,55 @@ __inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, in
   T height = max(y_2 - y_1, T(0));
   T combined_area = area[box_A_ix] + area[box_B_ix];
   // return decision to keep or remove box
-  return !(((width * height) / (combined_area - (width * height))) > IOU_value);
+  return !(((width * height) / (combined_area - (width * height))) >= IOU_value);
 }
 
+// Calculate areas for the boxes (in sorted output order);
+// populate the return mask (initialized to all true) and the return index array.
 template <typename T>
-__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) {
+__global__ void Preprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *output, int box_size_) {
   for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
     sel_idx[box_num] = box_num;
+    sel_boxes[box_num] = true;
     area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) *
                     (output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]);
   }
 }
 
+// Run the parallel NMS pass.
+// Every box updates its own mask row in row_mask in a separate thread.
 template <typename T>
-__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes,
-                                  int box_size_) {
-  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
-    // represents highest score box in that GPU block
-    if (threadIdx.x == 0) {
-      sel_boxes[box_num] = true;
+__global__ void NMSPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
+                        bool *row_mask) {
+  int box_i_start_index, box_j_start_index;  // actual input data indexing
+  int mask_offset = 0;
+  for (int box_i = blockIdx.x * blockDim.x + threadIdx.x; box_i < num - 1; box_i += blockDim.x * gridDim.x) {
+    mask_offset = box_i * num;
+    box_i_start_index = box_i * box_size_;  // adjust starting index
+    for (int box_j = box_i + 1; box_j < num; box_j++) {
+      box_j_start_index = box_j * box_size_;
+      row_mask[mask_offset + box_j] =
+        IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
+    }
+  }
+}
+
+// Reduce pass runs on a single block to allow thread synchronization
+__global__ void ReducePass(const int num, bool *sel_boxes, bool *row_mask) {
+  // loop over every box in order of high to low confidence score
+  for (int i = 0; i < num - 1; ++i) {
+    if (!sel_boxes[i]) {
       continue;
     }
-    int box_start_index = box_num * box_size_;  // start index adjustment
-    int block_max_box_num = ((blockIdx.x * blockDim.x) + 0);
-    int block_max_box_start_index = block_max_box_num * box_size_;  // start index adjustment
-    sel_boxes[box_num] =
-      IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area,
-                  IOU_value);  // update mask
-  }
-}
-
-template <typename T>
-__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) {
-  int box_i, box_j;                          // access all shared mem meta data with these
-  int box_i_start_index, box_j_start_index;  // actual input data indexing
-  for (int i = 0; i < num - 1; i++) {
-    box_i = i;
-    box_i_start_index = box_i * box_size_;  // adjust starting index
-    if (sel_boxes[box_i]) {
-      for (int j = i + 1; j < num; j++) {
-        box_j = j;
-        box_j_start_index = box_j * box_size_;
-        if (sel_boxes[box_j]) {
-          sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
-        }
-      }
+    // every thread handles a different set of boxes (striding over all boxes in order)
+    for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num; j += blockDim.x * gridDim.x) {
+      sel_boxes[j] = sel_boxes[j] && row_mask[i * num + j];
    }
+    __syncthreads();  // sync all threads before moving all active threads to the next iteration
   }
 }
 
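The rewrite above splits NMS into two stages: NMSPass fills an N x N row_mask in parallel, where row i records which later boxes survive a pairwise IoU check against box i, and ReducePass then folds those rows into sel_boxes in descending score order, skipping rows whose box is already suppressed. A minimal host-side C++ sketch of that reduction, using a hypothetical 3-box row_mask (illustrative only, not part of the patch):

#include <cstdio>
#include <vector>

// Mirrors the ReducePass logic: row_mask[i * n + j] == true means
// "box j survives the IoU check against box i".
void ReducePassHost(int n, const std::vector<bool> &row_mask, std::vector<bool> *sel_boxes) {
  for (int i = 0; i < n - 1; ++i) {
    if (!(*sel_boxes)[i]) continue;  // a suppressed box does not get to suppress others
    for (int j = i + 1; j < n; ++j) {
      (*sel_boxes)[j] = (*sel_boxes)[j] && row_mask[i * n + j];
    }
  }
}

int main() {
  const int n = 3;
  std::vector<bool> sel_boxes(n, true);
  // Row 0 suppresses box 1; row 1 would suppress box 2, but box 1 is gone by then.
  std::vector<bool> row_mask = {true, false, true,
                                true, true,  false,
                                true, true,  true};
  ReducePassHost(n, row_mask, &sel_boxes);
  for (int i = 0; i < n; ++i) printf("box %d kept: %s\n", i, sel_boxes[i] ? "yes" : "no");
  return 0;
}

Boxes 0 and 2 are kept: the result depends on visiting rows in score order, which is why ReducePass runs on a single block with __syncthreads() between iterations.
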
+// Sorting function, based on BitonicSort from the TopK kernel
 template <typename T>
 __global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
                                            T *data_buff, int *index_buff, int box_size_) {
@@ -139,41 +171,37 @@ __global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, con
 }
 
 template <typename T>
-void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
-                   cudaStream_t cuda_stream) {
-  PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_);
-  Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
+void CalPreprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *input, T *output, int *index_buff,
+                   int box_size_, bool *row_mask, cudaStream_t cuda_stream) {
+  int total_val = num * num;
+  MaskInit<<<GET_BLOCKS(total_val), GET_THREADS, 0, cuda_stream>>>(total_val, row_mask);
+  // default for flipping boxes -> false (provision available to flip if the API is updated)
+  PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_, false);
+  Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, sel_boxes, area, output, box_size_);
 }
 
 template <typename T>
-void CalSortInit(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
-                 cudaStream_t stream) {
+void CalSort(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
+             cudaStream_t stream) {
   int ceil_p_2 = NMSRoundUpPower2(num);
   int thread = std::min(ceil_p_2, GET_THREADS);
   NMS_BitonicSortByKeyKernel<<<1, thread, 0, stream>>>(1, num, ceil_p_2, data_in, data_buff, index_buff, box_size_);
 }
 
 template <typename T>
-void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
-                    cudaStream_t cuda_stream) {
-  NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes,
-                                                                      box_size_);
+void CalNMS(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, bool *row_mask,
+            cudaStream_t cuda_stream) {
+  NMSPass<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_,
+                                                            row_mask);
+  ReducePass<<<1, GET_THREADS, 0, cuda_stream>>>(num, sel_boxes, row_mask);
 }
 
-template <typename T>
-void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
-                  cudaStream_t cuda_stream) {
-  FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
-}
+template void CalSort<float>(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
+                             int box_size_, cudaStream_t stream);
 
-template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *input, float *output,
-                                   int *index_buff, int box_size_, cudaStream_t cuda_stream);
+template void CalPreprocess<float>(const int num, int *sel_idx, bool *sel_boxes, float *area, float *input,
+                                   float *output, int *index_buff, int box_size_, bool *row_mask,
+                                   cudaStream_t cuda_stream);
 
-template void CalSortInit<float>(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
-                                 int box_size_, cudaStream_t stream);
-
-template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
-                                    int box_size_, cudaStream_t cuda_stream);
-
-template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
-                                  int box_size_, cudaStream_t cuda_stream);
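In IOUDecision above, the comparison changed from > to >=, so a pair whose IoU exactly equals the threshold is now suppressed. A standalone sketch of the same intersection-over-union arithmetic (box values made up for illustration, not from the patch):

#include <algorithm>
#include <cstdio>

// Returns true when box B should be KEPT relative to box A (IoU below the threshold),
// mirroring the decision in IOUDecision.
bool IouKeep(float ax1, float ay1, float ax2, float ay2,
             float bx1, float by1, float bx2, float by2, float threshold) {
  float w = std::max(std::min(ax2, bx2) - std::max(ax1, bx1), 0.0f);
  float h = std::max(std::min(ay2, by2) - std::max(ay1, by1), 0.0f);
  float inter = w * h;
  float union_area = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter;
  return !(inter / union_area >= threshold);  // >= matches the patched kernel
}

int main() {
  // Two unit squares shifted by half a box: intersection 0.5, union 1.5, IoU = 1/3 -> kept.
  printf("IoU 1/3 vs threshold 0.5 -> keep: %d\n", IouKeep(0, 0, 1, 1, 0.5f, 0, 1.5f, 1, 0.5f));
  // Identical squares: IoU = 1.0 -> suppressed.
  printf("IoU 1.0 vs threshold 0.5 -> keep: %d\n", IouKeep(0, 0, 1, 1, 0, 0, 1, 1, 0.5f));
  return 0;
}
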
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
index b20c6704ed7..f3a81f73c2c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh
@@ -20,20 +20,16 @@
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
-void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
-                   cudaStream_t cuda_stream);
+void CalSort(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
+             cudaStream_t stream);
 
 template <typename T>
-void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
-                    cudaStream_t cuda_stream);
+void CalPreprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *input, T *output, int *index_buff,
+                   int box_size_, bool *row_mask, cudaStream_t cuda_stream);
 
 template <typename T>
-void CalSortInit(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
-                 cudaStream_t stream);
-
-template <typename T>
-void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
-                  cudaStream_t cuda_stream);
+void CalNMS(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, bool *row_mask,
+            cudaStream_t cuda_stream);
 
 int NMSRoundUpPower2(int v);
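CalSort pads the box count up to a power of two via NMSRoundUpPower2 before the bitonic sort; its definition is outside this diff. One common bit-twiddling way to round a 32-bit value up to the next power of two looks like the sketch below (an assumption for illustration, not necessarily the implementation in the source tree):

#include <cstdio>

int RoundUpPower2(int v) {
  // Propagate the highest set bit into every lower bit, then add one.
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}

int main() {
  printf("%d -> %d\n", 4000, RoundUpPower2(4000));  // 4096
  printf("%d -> %d\n", 4096, RoundUpPower2(4096));  // 4096 (already a power of two)
  return 0;
}
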
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
index 7c22d80ba73..d219b9d0cae 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h
@@ -41,21 +41,19 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *input = GetDeviceAddress<T>(inputs, 0);
-    T *area = GetDeviceAddress<T>(workspace, 0);       // store area values for all boxes
-    T *data_buff = GetDeviceAddress<T>(workspace, 1);  // sort buffer
+    T *area = GetDeviceAddress<T>(workspace, 0);
+    T *data_buff = GetDeviceAddress<T>(workspace, 1);
     int *index_buff = GetDeviceAddress<int>(workspace, 2);
+    bool *row_mask = GetDeviceAddress<bool>(workspace, 3);
     T *output = GetDeviceAddress<T>(outputs, 0);
     int *sel_idx = GetDeviceAddress<int>(outputs, 1);
     bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);
 
-    CalSortInit(num_input_, input, output, index_buff, data_buff, box_size_,
-                reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalPreprocess(num_input_, sel_idx, area, input, output, index_buff, box_size_,
+    CalSort(num_input_, input, output, index_buff, data_buff, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalPreprocess(num_input_, sel_idx, sel_boxes, area, input, output, index_buff, box_size_, row_mask,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
-                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalNMS(num_input_, iou_value_, output, area, sel_boxes, box_size_, row_mask,
+           reinterpret_cast<cudaStream_t>(stream_ptr));
 
     return true;
   }
@@ -87,8 +85,9 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
     input_size_ = num_input_ * sizeof(T) * box_size_;  // 5 values per bbox
     output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
 
-    workspace_size_ = num_input_ * sizeof(int);
-    workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int));
+    workspace_size_ = num_input_ * sizeof(int);                   // storing areas
+    workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int));  // sorting buffers
+    workspace_size_ += (num_input_ * num_input_ * sizeof(bool));  // Row mask - NMS
 
     InitSizeLists();
     return true;
@@ -103,9 +102,10 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
     output_size_list_.push_back(num_input_ * sizeof(bool));
 
     // N sized workspace arrs
-    workspace_size_list_.push_back(num_input_ * sizeof(T));      // area list
-    workspace_size_list_.push_back(ceil_power_2 * sizeof(T));    // data buff
-    workspace_size_list_.push_back(ceil_power_2 * sizeof(int));  // index buff
+    workspace_size_list_.push_back(num_input_ * sizeof(T));                  // area list
+    workspace_size_list_.push_back(ceil_power_2 * sizeof(T));                // data buff
+    workspace_size_list_.push_back(ceil_power_2 * sizeof(int));              // index buff
+    workspace_size_list_.push_back(num_input_ * num_input_ * sizeof(bool));  // mask list
   }
 
  private:
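The row mask is the new dominant workspace cost: num_input_ * num_input_ * sizeof(bool) bytes, quadratic in the box count. A quick back-of-envelope check (illustrative only) shows why the stress test below likely also drops count from 8000 to 4000:

#include <cstdio>

int main() {
  const int counts[] = {4000, 8000};
  for (int n : counts) {
    long long mask_bytes = 1LL * n * n * sizeof(bool);  // mirrors num_input_ * num_input_ * sizeof(bool)
    printf("count = %d -> row mask = %lld bytes (about %lld MiB)\n", n, mask_bytes, mask_bytes >> 20);
  }
  return 0;  // prints ~15 MiB for 4000 boxes, ~61 MiB for 8000 boxes
}
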
diff --git a/tests/st/ops/gpu/test_nms_with_mask_op.py b/tests/st/ops/gpu/test_nms_with_mask_op.py
index ed0be91a845..45764d7907e 100644
--- a/tests/st/ops/gpu/test_nms_with_mask_op.py
+++ b/tests/st/ops/gpu/test_nms_with_mask_op.py
@@ -40,7 +40,7 @@ def test_nms_with_mask_check_order():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     nms_op = P.NMSWithMask(0.5)
     for _ in range(10):
-        count = 8000
+        count = 4000
         box = np.random.randint(1, 100, size=(count, 4))
         box[:, 2] = box[:, 0] + box[:, 2]
         box[:, 3] = box[:, 1] + box[:, 3]