stuff added

box copying fix

common function for IOU

new sort func

final update + unit test

remove one comment

fix linting -1

lint fix 2

lint fix 3

last lint fix

value fix in ST nms with mask

addressing comments

pylint fix

pylint fix 1

test file fix
danish 2020-07-29 15:11:24 -04:00
parent 79225e044a
commit a2ffc9530e
5 changed files with 534 additions and 0 deletions

View File

@@ -0,0 +1,193 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nms_with_mask_impl.cuh"
#include <limits>
#include <algorithm>
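// Round v up to the nearest power of two (e.g. 20 -> 32, 64 -> 64) by smearing
// the top set bit of v - 1 into every lower bit, then adding one.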
int RoundUpPower2M(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}
template <typename T>
__inline__ __device__ void SwapM(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}
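// IOU of two boxes from their corner coordinates and precomputed areas.
// Returns true (keep) when IOU does not exceed IOU_value, false (suppress) otherwise.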
template <typename T>
__inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start, T *area,
                                       float IOU_value) {
  T x_1 = max(output[box_A_start + 0], output[box_B_start + 0]);
  T y_1 = max(output[box_A_start + 1], output[box_B_start + 1]);
  T x_2 = min(output[box_A_start + 2], output[box_B_start + 2]);
  T y_2 = min(output[box_A_start + 3], output[box_B_start + 3]);
  T width = max(x_2 - x_1, T(0));  // in case of no overlap
  T height = max(y_2 - y_1, T(0));
  T combined_area = area[box_A_ix] + area[box_B_ix];
  // keep the box only if IOU = intersection / (combined_area - intersection) does not exceed the threshold
  return !(((width * height) / (combined_area - (width * height))) > IOU_value);
}
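// Grid-stride loop over boxes: initialize each selected index and precompute
// each box's area from its (x1, y1, x2, y2) corners.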
template <typename T>
__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
    sel_idx[box_num] = box_num;
    area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) *
                    (output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]);
  }
}
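// First NMS pass: the input is already sorted by score (descending), so thread 0
// of each block owns the block's highest-scoring box; every other thread keeps or
// suppresses its own box by IOU against that block-local maximum.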
template <typename T>
__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes,
                                  int box_size_) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
    // thread 0 holds the highest-scoring box in this block; it always survives
    if (threadIdx.x == 0) {
      sel_boxes[box_num] = true;
      continue;
    }
    int box_start_index = box_num * box_size_;  // start index adjustment
    int block_max_box_num = blockIdx.x * blockDim.x;
    int block_max_box_start_index = block_max_box_num * box_size_;  // start index adjustment
    sel_boxes[box_num] =
      IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area,
                  IOU_value);  // update mask
  }
}
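// Second NMS pass: a sequential O(n^2) greedy sweep that resolves suppressions
// between boxes the block-local pass never compared.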
template <typename T>
__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) {
  int box_i, box_j;                          // box indices used to walk the selection mask
  int box_i_start_index, box_j_start_index;  // actual input data indexing
  for (int i = 0; i < num - 1; i++) {
    box_i = i;
    box_i_start_index = box_i * box_size_;  // adjust starting index
    if (sel_boxes[box_i]) {
      for (int j = i + 1; j < num; j++) {
        box_j = j;
        box_j_start_index = box_j * box_size_;
        if (sel_boxes[box_j]) {
          sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
        }
      }
    }
  }
}
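// Bitonic sort by key over the score column (element 4 of each row), padded to
// ceil_power2 with max sentinels; the sorted index array is later read in reverse
// to emit rows in descending score order.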
template <typename T, typename S>
__global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in,
                                        S *data_out, T *index_buff, S *data_buff, int box_size_) {
  // default: sort with shared memory
  extern __shared__ T share_mem_NMS[];
  T *index_arr = share_mem_NMS;
  S *data_arr = reinterpret_cast<S *>(index_arr + ceil_power2);
  // sort with RAM
  if (index_buff != nullptr && data_buff != nullptr) {
    index_arr = index_buff + blockIdx.x * ceil_power2;
    data_arr = data_buff + blockIdx.x * ceil_power2;
  }
  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
    index_arr[i] = (i < inner) ? T(i) : std::numeric_limits<T>::max();
    // populated directly from input data
    data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits<S>::max();
  }
  __syncthreads();
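  // i = size of the bitonic subsequences being merged, j = compare distance
  // within a merge step; (tid & i) picks the ascending or descending half.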
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          if ((tid & i) == 0) {
            if (data_arr[tid] > data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            }
          } else {
            if (data_arr[tid] < data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            }
          }
        }
      }
      __syncthreads();
    }
  }
  T correct_index;
  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
    correct_index = index_arr[(inner - 1) - tid];
    // move data from input to output, correcting the order with the sorted index array
    for (auto i : {0, 1, 2, 3, 4}) {
      data_out[(blockIdx.x * inner + tid) * box_size_ + i] =
        data_in[(blockIdx.x * inner + correct_index) * box_size_ + i];
    }
  }
}
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) {
  Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
}
template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream) {
  int ceil_power2 = RoundUpPower2M(inner);
  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
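  // If the scratch arrays do not fit in the per-block shared-memory budget, keep
  // the preallocated global-memory buffers; otherwise null them so the kernel
  // sorts entirely in shared memory.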
  if (share_mem > SHARED_MEM_PER_BLOCK) {
    share_mem = 0;
  } else {
    data_buff = nullptr;
    index_buff = nullptr;
  }
  int thread = std::min(ceil_power2, GET_THREADS);
  BitonicSortByKeyKernelM<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, data_in, data_out,
                                                                index_buff, data_buff, box_size_);
}
template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                    cudaStream_t cuda_stream) {
  NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes,
                                                                      box_size_);
}
template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                  cudaStream_t cuda_stream) {
  FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
}
template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *output, int box_size_,
                                   cudaStream_t cuda_stream);
template void BitonicSortByKeyM(const int &outer, const int &inner, float *data_in, float *data_out, int *index_buff,
                                float *data_buff, int box_size_, cudaStream_t stream);
template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                                    int box_size_, cudaStream_t cuda_stream);
template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                                  int box_size_, cudaStream_t cuda_stream);

View File

@@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream);
template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                    cudaStream_t cuda_stream);
template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream);
template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                  cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_

View File

@@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(NMSWithMask,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeBool),
                      NMSWithMaskGpuFwdKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,121 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_
#include <vector>
#include <memory>
#include <iostream>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class NMSWithMaskGpuFwdKernel : public GpuKernel {
 public:
  NMSWithMaskGpuFwdKernel() : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {}
  ~NMSWithMaskGpuFwdKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *data_buff = GetDeviceAddress<T>(workspace, 0);  // sort buffer
    int *index_buff = GetDeviceAddress<int>(workspace, 1);
    T *area = GetDeviceAddress<T>(workspace, 2);  // store area values for all boxes
    T *output = GetDeviceAddress<T>(outputs, 0);
    int *sel_idx = GetDeviceAddress<int>(outputs, 1);
    bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);
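    // Pipeline: (1) sort rows of `input` into `output` by descending score,
    // (2) precompute areas and indices, (3) block-local NMS pass,
    // (4) sequential cross-block final pass.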
    BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
    CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
    CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
  bool Init(const CNodePtr &kernel_node) override {
    iou_value_ = GetAttr<float>(kernel_node, "iou_threshold");
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but NMSWithMask needs 1 input.";
      return false;
    }
    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 3) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but NMSWithMask needs 3 outputs.";
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (CHECK_NULL_INPUT(input_shape)) {
      MS_LOG(WARNING) << "NMSWithMask input is null";
      InitSizeLists();
      return true;
    }
    num_input_ = input_shape[0];                       // N in the [N, 5] input
    input_size_ = num_input_ * sizeof(T) * box_size_;  // 5 values per bbox
    output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
    workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int));
    InitSizeLists();
    return true;
  }
 protected:
  void InitSizeLists() override {
    // N-sized input/output data
    input_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
    output_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
    output_size_list_.push_back(num_input_ * sizeof(int));
    output_size_list_.push_back(num_input_ * sizeof(bool));
    // N-sized workspace arrays
    workspace_size_list_.push_back(num_input_ * sizeof(T));
    workspace_size_list_.push_back(num_input_ * sizeof(int));
    workspace_size_list_.push_back(num_input_ * sizeof(T));
  }
 private:
  int num_input_;
  float iou_value_;
  static const int box_size_ = 5;  // predefined number of values per box: x1, y1, x2, y2, score
  // default values
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_IMPL_H_

View File

@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P
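
# Reference O(n^2) NMS: greedily keeps higher-scoring boxes and suppresses any
# later box whose IOU with a kept box exceeds the threshold; returns the keep-mask.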
def manualNMS(bbox, overlap_val_iou):
    mask = [True] * len(bbox)
    for box_a_index, _ in enumerate(bbox):
        if not mask[box_a_index]:
            continue  # ignore if not in list
        box_a = bbox[box_a_index]  # select box for value extraction
        for box_b_index in range(box_a_index + 1, len(bbox)):
            if not mask[box_b_index]:
                continue  # ignore if not in list
            box_b = bbox[box_b_index]
            areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
            areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
            overlap_x1 = max(box_a[0], box_b[0])
            overlap_y1 = max(box_a[1], box_b[1])
            overlap_x2 = min(box_a[2], box_b[2])
            overlap_y2 = min(box_a[3], box_b[3])
            width = max((overlap_x2 - overlap_x1), 0)
            height = max((overlap_y2 - overlap_y1), 0)
            # generate IOU decision
            mask[box_b_index] = not (
                (width * height) / (areaA + areaB - (width * height))) > overlap_val_iou
    return mask
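
# Run the op and gather the kept boxes' (x1, y1, x2, y2) rows and scores via the output mask.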
def runMSRun(op, bbox):
    inputs = Tensor(bbox, mindspore.float32)
    box, _, mask = op(inputs)
    box = box.asnumpy()
    mask = mask.asnumpy()
    sel_idx = np.where(mask)
    sel_rows = box[sel_idx][:, 0:4]
    sel_score = box[sel_idx][:, -1]
    return sel_rows, sel_score
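
# The first output must contain all boxes reordered by descending score, regardless of input order.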
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_check_order():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    nms_op = P.NMSWithMask(0.5)
    for _ in range(500):
        count = 20
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
        unsorted_scores = np.random.rand(count, 1)
        bbox = np.hstack((box, unsorted_scores))
        bbox = Tensor(bbox, dtype=mindspore.float32)
        prop, _, _ = nms_op(bbox)
        ms_sorted_scores = (prop.asnumpy()[:, -1])  # select just scores
        np_sorted_scores = (np.sort(unsorted_scores, axis=0)[::-1][:, 0])  # sort manually
        np.testing.assert_array_almost_equal(ms_sorted_scores, np_sorted_scores)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_check_result():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_count = 500
    for x in range(1, test_count + 1):
        count = 20  # size of bbox lists
        nms_op = P.NMSWithMask(x * 0.002)  # sweeps the full threshold range between 0 and 1
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
        unsorted_scores = np.random.rand(count, 1)
        sorted_scores = np.sort(unsorted_scores, axis=0)[::-1]
        bbox = np.hstack((box, sorted_scores))
        bbox = Tensor(bbox, dtype=mindspore.float32)
        _, _, mask = nms_op(bbox)
        mask = mask.asnumpy()
        manual_mask = manualNMS(box, x * 0.002)
        np.testing.assert_array_equal(mask, np.array(manual_mask))

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_1():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 1 - FULL OVERLAP BOXES - every box is duplicated and has a different score
    nms_op1 = P.NMSWithMask(0.3)
    bbox1 = [[12, 4, 33, 17, 0.6], [20, 11, 38, 23, 0.1], [20, 10, 45, 26, 0.9], [15, 17, 35, 38, 0.5],
             [10, 20, 30, 40, 0.4], [35, 35, 89, 90, 0.8], [12, 4, 33, 17, 0.3], [20, 11, 38, 23, 0.2],
             [20, 10, 45, 26, 0.1], [15, 17, 35, 38, 0.8], [10, 20, 30, 40, 0.41], [35, 35, 89, 90, 0.82]]
    expected_bbox = np.array([[20., 10., 45., 26.],
                              [35., 35., 89., 90.],
                              [15., 17., 35., 38.],
                              [12., 4., 33., 17.]])
    expected_score = np.array([0.9, 0.82, 0.8, 0.6])
    sel_rows, sel_score = runMSRun(nms_op1, bbox1)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_2():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 2 - zero-value boxes with valid scores
    nms_op2 = P.NMSWithMask(0.5)
    bbox2 = [[0, 0, 0, 0, 0.6], [0, 0, 0, 0, 0.1]]
    expected_bbox = np.array([[0., 0., 0., 0.],
                              [0., 0., 0., 0.]])
    expected_score = np.array([0.6, 0.1])
    sel_rows, sel_score = runMSRun(nms_op2, bbox2)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_3():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 3 - x1/x2 and y1/y2 coordinates out of order
    nms_op3 = P.NMSWithMask(0.7)
    bbox3 = [[70, 70, 45, 75, 0.6], [30, 33, 43, 29, 0.1]]
    expected_bbox = np.array([[70., 70., 45., 75.],
                              [30., 33., 43., 29.]])
    expected_score = np.array([0.6, 0.1])
    sel_rows, sel_score = runMSRun(nms_op3, bbox3)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)