forked from mindspore-Ecosystem/mindspore
reduce based nms final pass - speed improvement
refactored faster nms; refactored faster nms + typo fix; added box flipping choice; set choice to true for testing - yz; switching back; new test file
This commit is contained in:
parent 0e65b3ba70
commit 7d7fa760a0
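The diff below replaces the old block-wise NMSWithMaskKernel plus the serial FinalPass with a mask-based scheme: an NMS pass fills a num x num row_mask of pairwise IoU decisions, and a reduce pass folds that mask into sel_boxes. As a reading aid only, here is a rough CPU-side sketch of that scheme under the new >= threshold comparison; the helper names (IouKeep, NmsReference) are invented for illustration, and boxes are assumed already sorted by descending score, as they are after the bitonic sort stage.

#include <algorithm>
#include <vector>

// Mirrors IOUDecision: returns true when box j should be kept relative to box i.
static bool IouKeep(const float *a, const float *b, float iou_value) {
  float width = std::max(std::min(a[2], b[2]) - std::max(a[0], b[0]), 0.0f);
  float height = std::max(std::min(a[3], b[3]) - std::max(a[1], b[1]), 0.0f);
  float inter = width * height;
  float combined_area = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]);
  return !((inter / (combined_area - inter)) >= iou_value);
}

// boxes: num rows of {x1, y1, x2, y2, score}, sorted by descending score.
std::vector<bool> NmsReference(const std::vector<float> &boxes, int num, float iou_value) {
  std::vector<bool> row_mask(static_cast<size_t>(num) * num, true);  // MaskInit: all true
  for (int i = 0; i < num - 1; ++i) {      // NMS pass: one (i, j) decision per mask cell
    for (int j = i + 1; j < num; ++j) {
      row_mask[i * num + j] = IouKeep(&boxes[i * 5], &boxes[j * 5], iou_value);
    }
  }
  std::vector<bool> sel_boxes(num, true);  // Preprocess initialises sel_boxes to true
  for (int i = 0; i < num - 1; ++i) {      // reduce pass: rows applied in score order
    if (!sel_boxes[i]) continue;           // a suppressed box suppresses nothing
    for (int j = 0; j < num; ++j) {
      sel_boxes[j] = sel_boxes[j] && row_mask[i * num + j];
    }
  }
  return sel_boxes;
}

On the GPU, the (i, j) decisions of the first stage run fully in parallel, while the reduce stage keeps the same row order on one block with __syncthreads() between rows, which is where the speed-up over the old <<<1, 1>>> FinalPass comes from.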
@@ -36,12 +36,44 @@ __inline__ __device__ void Swap(T *lhs, T *rhs) {
rhs[0] = tmp;
}

// Initialize per row mask array to all true
__global__ void MaskInit(int numSq, bool *row_mask) {
for (int mat_pos = blockIdx.x * blockDim.x + threadIdx.x; mat_pos < numSq; mat_pos += blockDim.x * gridDim.x) {
row_mask[mat_pos] = true;
}
}

// copy data from input to output array sorted by indices returned from bitonic sort
// flips boxes if asked to, default - false -> if (x1/y1 > x2/y2)
template <typename T>
__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_) {
__global__ void PopulateOutput(T *data_in, T *data_out, int *index_buff, const int num, int box_size_,
bool flip_mode = false) {
for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
int correct_index = index_buff[(num - 1) - box_num]; // flip the array around
for (int x = 0; x < 5; x++) {
data_out[(box_num * box_size_) + x] = data_in[(correct_index * box_size_) + x];
int correct_arr_start = correct_index * box_size_;
int current_arr_start = box_num * box_size_;
if (flip_mode) { // flip boxes
// check x
if (data_in[correct_arr_start + 0] > data_in[correct_arr_start + 2]) {
data_out[current_arr_start + 0] = data_in[correct_arr_start + 2];
data_out[current_arr_start + 2] = data_in[correct_arr_start + 0];
} else {
data_out[current_arr_start + 0] = data_in[correct_arr_start + 0];
data_out[current_arr_start + 2] = data_in[correct_arr_start + 2];
}
// check y
if (data_in[correct_arr_start + 1] > data_in[correct_arr_start + 3]) {
data_out[current_arr_start + 1] = data_in[correct_arr_start + 3];
data_out[current_arr_start + 3] = data_in[correct_arr_start + 1];
} else {
data_out[current_arr_start + 1] = data_in[correct_arr_start + 1];
data_out[current_arr_start + 3] = data_in[correct_arr_start + 3];
}
data_out[current_arr_start + 4] = data_in[correct_arr_start + 4];
} else { // default behaviour, don't flip
for (int x = 0; x < 5; x++) {
data_out[current_arr_start + x] = data_in[correct_arr_start + x];
}
}
}
}
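As a concrete illustration of the flip_mode branch above (made-up values): for an input row {x1, y1, x2, y2, score} = {7, 9, 3, 4, 0.9}, flip_mode = true writes {3, 4, 7, 9, 0.9} to the output, whereas the default flip_mode = false copies the row through unchanged.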
@@ -57,55 +89,55 @@ __inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, in
T height = max(y_2 - y_1, T(0));
T combined_area = area[box_A_ix] + area[box_B_ix];
// return decision to keep or remove box
return !(((width * height) / (combined_area - (width * height))) > IOU_value);
return !(((width * height) / (combined_area - (width * height))) >= IOU_value);
}

// calculate areas for boxes -> sorted by output boxes
// populated return mask (init to all true) and return index array
template <typename T>
__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) {
__global__ void Preprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *output, int box_size_) {
for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
sel_idx[box_num] = box_num;
sel_boxes[box_num] = true;
area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) *
(output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]);
}
}

// Run parallel NMS pass
// Every box updates its own mask in row_mask in sep threads
template <typename T>
__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes,
int box_size_) {
for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
// represents highest score box in that GPU block
if (threadIdx.x == 0) {
sel_boxes[box_num] = true;
__global__ void NMSPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
bool *row_mask) {
int box_i_start_index, box_j_start_index; // actual input data indexing
int mask_offset = 0;
for (int box_i = blockIdx.x * blockDim.x + threadIdx.x; box_i < num - 1; box_i += blockDim.x * gridDim.x) {
mask_offset = box_i * num;
box_i_start_index = box_i * box_size_; // adjust starting index
for (int box_j = box_i + 1; box_j < num; box_j++) {
box_j_start_index = box_j * box_size_;
row_mask[mask_offset + box_j] =
IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
}
}
}

// Reduce pass runs on 1 block to allow thread sync
__global__ void ReducePass(const int num, bool *sel_boxes, bool *row_mask) {
// loop over every box in order of high to low confidence score
for (int i = 0; i < num - 1; ++i) {
if (!sel_boxes[i]) {
continue;
}
int box_start_index = box_num * box_size_; // start index adjustment
int block_max_box_num = ((blockIdx.x * blockDim.x) + 0);
int block_max_box_start_index = block_max_box_num * box_size_; // start index adjustment
sel_boxes[box_num] =
IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area,
IOU_value); // update mask
}
}

template <typename T>
__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) {
int box_i, box_j; // access all shared mem meta data with these
int box_i_start_index, box_j_start_index; // actual input data indexing
for (int i = 0; i < num - 1; i++) {
box_i = i;
box_i_start_index = box_i * box_size_; // adjust starting index
if (sel_boxes[box_i]) {
for (int j = i + 1; j < num; j++) {
box_j = j;
box_j_start_index = box_j * box_size_;
if (sel_boxes[box_j]) {
sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
}
}
// every thread handles a different set of boxes (per all boxes in order)
for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num; j += blockDim.x * gridDim.x) {
sel_boxes[j] = sel_boxes[j] && row_mask[i * num + j];
}
__syncthreads(); // sync all threads before moving all active threads to next iteration
}
}

// Sorting function based on BitonicSort from TopK kernel
template <typename T>
__global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, const int ceil_power2, T *input,
T *data_buff, int *index_buff, int box_size_) {
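A small behavioural note on the IOUDecision change above, with made-up numbers: switching the comparison from > to >= means a pair whose IoU lands exactly on the threshold is now suppressed. For an intersection of 2.0 and box areas of 3.0 each, IoU = 2 / (3 + 3 - 2) = 0.5; with IOU_value = 0.5 the old !(0.5 > 0.5) kept the lower-scored box, while the new !(0.5 >= 0.5) removes it.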
@@ -139,41 +171,37 @@ __global__ void NMS_BitonicSortByKeyKernel(const int outer, const int inner, con
}

template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
cudaStream_t cuda_stream) {
PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_);
Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
void CalPreprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *input, T *output, int *index_buff,
int box_size_, bool *row_mask, cudaStream_t cuda_stream) {
int total_val = num * num;
MaskInit<<<GET_BLOCKS(total_val), GET_THREADS, 0, cuda_stream>>>(total_val, row_mask);
// default for flipping boxes -> false (provision available to flip if API updated)
PopulateOutput<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(input, output, index_buff, num, box_size_, false);
Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, sel_boxes, area, output, box_size_);
}

template <typename T>
void CalSortInit(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream) {
void CalSort(const int &num, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream) {
int ceil_p_2 = NMSRoundUpPower2(num);
int thread = std::min(ceil_p_2, GET_THREADS);
NMS_BitonicSortByKeyKernel<<<1, thread, 0, stream>>>(1, num, ceil_p_2, data_in, data_buff, index_buff, box_size_);
}

template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream) {
NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes,
box_size_);
void CalNMS(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, bool *row_mask,
cudaStream_t cuda_stream) {
NMSPass<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_,
row_mask);
ReducePass<<<1, GET_THREADS, 0, cuda_stream>>>(num, sel_boxes, row_mask);
}

template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream) {
FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
}
template void CalSort<float>(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
int box_size_, cudaStream_t stream);

template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *input, float *output,
int *index_buff, int box_size_, cudaStream_t cuda_stream);
template void CalPreprocess<float>(const int num, int *sel_idx, bool *sel_boxes, float *area, float *input,
float *output, int *index_buff, int box_size_, bool *row_mask,
cudaStream_t cuda_stream);

template void CalSortInit<float>(const int &inner, float *data_in, float *data_out, int *index_buff, float *data_buff,
int box_size_, cudaStream_t stream);

template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
int box_size_, cudaStream_t cuda_stream);

template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
int box_size_, cudaStream_t cuda_stream);
template void CalNMS<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
int box_size_, bool *row_mask, cudaStream_t cuda_stream);
@@ -20,20 +20,16 @@
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *input, T *output, int *index_buff, int box_size_,
cudaStream_t cuda_stream);
void CalSort(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream);

template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream);
void CalPreprocess(const int num, int *sel_idx, bool *sel_boxes, T *area, T *input, T *output, int *index_buff,
int box_size_, bool *row_mask, cudaStream_t cuda_stream);

template <typename T>
void CalSortInit(const int &inner, T *data_in, T *data_out, int *index_buff, T *data_buff, int box_size_,
cudaStream_t stream);

template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
cudaStream_t cuda_stream);
void CalNMS(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_, bool *row_mask,
cudaStream_t cuda_stream);

int NMSRoundUpPower2(int v);
@@ -41,21 +41,19 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *area = GetDeviceAddress<T>(workspace, 0); // store area values for all boxes
T *data_buff = GetDeviceAddress<T>(workspace, 1); // sort buffer
T *area = GetDeviceAddress<T>(workspace, 0);
T *data_buff = GetDeviceAddress<T>(workspace, 1);
int *index_buff = GetDeviceAddress<int>(workspace, 2);
bool *row_mask = GetDeviceAddress<bool>(workspace, 3);
T *output = GetDeviceAddress<T>(outputs, 0);
int *sel_idx = GetDeviceAddress<int>(outputs, 1);
bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);

CalSortInit(num_input_, input, output, index_buff, data_buff, box_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalPreprocess(num_input_, sel_idx, area, input, output, index_buff, box_size_,
CalSort(num_input_, input, output, index_buff, data_buff, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
CalPreprocess(num_input_, sel_idx, sel_boxes, area, input, output, index_buff, box_size_, row_mask,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalNMS(num_input_, iou_value_, output, area, sel_boxes, box_size_, row_mask,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
@@ -87,8 +85,9 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
input_size_ = num_input_ * sizeof(T) * box_size_; // 5 values per bbox
output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));

workspace_size_ = num_input_ * sizeof(int);
workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int));
workspace_size_ = num_input_ * sizeof(int); // storing areas
workspace_size_ += ceil_power_2 * (sizeof(T) + sizeof(int)); // sorting buffers
workspace_size_ += (num_input_ * num_input_ * sizeof(bool)); // Row mask - NMS

InitSizeLists();
return true;
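To put the added workspace term in perspective (illustrative numbers only): with num_input_ = 4000 and T = float, ceil_power_2 = 4096, the area list needs 4000 * 4 B = 16 KB and the sort buffers 4096 * (4 + 4) B = 32 KB, while the new row mask needs 4000 * 4000 * 1 B ≈ 15.3 MiB, so the num * num mask dominates the workspace and grows quadratically with the box count.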
@@ -103,9 +102,10 @@ class NMSWithMaskGpuFwdKernel : public GpuKernel {
output_size_list_.push_back(num_input_ * sizeof(bool));

// N sized workspace arrs
workspace_size_list_.push_back(num_input_ * sizeof(T)); // area list
workspace_size_list_.push_back(ceil_power_2 * sizeof(T)); // data buff
workspace_size_list_.push_back(ceil_power_2 * sizeof(int)); // index buff
workspace_size_list_.push_back(num_input_ * sizeof(T)); // area list
workspace_size_list_.push_back(ceil_power_2 * sizeof(T)); // data buff
workspace_size_list_.push_back(ceil_power_2 * sizeof(int)); // index buff
workspace_size_list_.push_back(num_input_ * num_input_ * sizeof(bool)); // mask list
}

private:
@@ -40,7 +40,7 @@ def test_nms_with_mask_check_order():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
nms_op = P.NMSWithMask(0.5)
for _ in range(10):
count = 8000
count = 4000
box = np.random.randint(1, 100, size=(count, 4))
box[:, 2] = box[:, 0] + box[:, 2]
box[:, 3] = box[:, 1] + box[:, 3]
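A side note on the test tweak above: the corners are built as x2 = x1 + w and y2 = y1 + h with w, h drawn from [1, 100), so every generated box already satisfies x1 < x2 and y1 < y2, consistent with PopulateOutput being launched with flip_mode = false; the per-iteration box count is also halved from 8000 to 4000.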