forked from OSSInnovation/mindspore
stuff added

Squashed commit messages: box copying fix; common function for IOU; new sort func; final update + unit test; remove one comment; lint fixes 1-3 and a last lint fix; value fix in ST; nms with mask: addressing comments; pylint fixes; test file fix
parent 79225e044a
commit a2ffc9530e
@@ -0,0 +1,193 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nms_with_mask_impl.cuh"
#include <limits>
#include <algorithm>

int RoundUpPower2M(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}
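
// Editor's note (illustrative trace, not part of the original file): the OR-shift
// cascade smears the highest set bit of v-1 into every lower bit, yielding a mask
// of all ones just below the next power of two. For v = 20: 19 = 0b10011 becomes
// 0b11111 = 31, and 31 + 1 = 32, so RoundUpPower2M(20) == 32 (and
// RoundUpPower2M(32) == 32, since powers of two map to themselves).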

template <typename T>
__inline__ __device__ void SwapM(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}

template <typename T>
__inline__ __device__ bool IOUDecision(T *output, int box_A_ix, int box_B_ix, int box_A_start, int box_B_start, T *area,
                                       float IOU_value) {
  T x_1 = max(output[box_A_start + 0], output[box_B_start + 0]);
  T y_1 = max(output[box_A_start + 1], output[box_B_start + 1]);
  T x_2 = min(output[box_A_start + 2], output[box_B_start + 2]);
  T y_2 = min(output[box_A_start + 3], output[box_B_start + 3]);
  T width = max(x_2 - x_1, T(0));  // in case of no overlap
  T height = max(y_2 - y_1, T(0));
  T combined_area = area[box_A_ix] + area[box_B_ix];
  // return decision to keep or remove box
  return !(((width * height) / (combined_area - (width * height))) > IOU_value);
}
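
// Editor's note (worked example): for A = (0, 0, 10, 10) and B = (5, 5, 15, 15),
// each with area 100, the overlap is (5, 5, 10, 10), so width = height = 5 and
// the intersection is 25. IoU = 25 / (100 + 100 - 25) ≈ 0.143; with
// IOU_value = 0.5 the expression returns true, i.e. keep the lower-scored box.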

template <typename T>
__global__ void Preprocess(const int num, int *sel_idx, T *area, T *output, int box_size_) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
    sel_idx[box_num] = box_num;
    area[box_num] = (output[(box_num * box_size_) + 2] - output[(box_num * box_size_) + 0]) *
                    (output[(box_num * box_size_) + 3] - output[(box_num * box_size_) + 1]);
  }
}
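
// Editor's note: the grid-stride loop lets a fixed launch configuration cover
// any box count; each thread seeds sel_idx with its own index and caches one
// box area so IOUDecision can reuse it without recomputing.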

template <typename T>
__global__ void NMSWithMaskKernel(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes,
                                  int box_size_) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < num; box_num += blockDim.x * gridDim.x) {
    // represents highest score box in that GPU block
    if (threadIdx.x == 0) {
      sel_boxes[box_num] = true;
      continue;
    }
    int box_start_index = box_num * box_size_;  // start index adjustment
    int block_max_box_num = ((blockIdx.x * blockDim.x) + 0);
    int block_max_box_start_index = block_max_box_num * box_size_;  // start index adjustment
    sel_boxes[box_num] =
      IOUDecision(output, box_num, block_max_box_num, block_max_box_start_index, box_start_index, area,
                  IOU_value);  // update mask
  }
}
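
// Editor's note: because the boxes are pre-sorted by score, thread 0 of each
// block owns that block's highest-scoring box and is always kept; every other
// thread only tests its box against that block maximum. IOUDecision is
// symmetric in its two boxes, so the argument order here does not change the
// result. Cross-block overlaps are left for FinalPass below.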

template <typename T>
__global__ void FinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_) {
  int box_i, box_j;                          // access all shared mem meta data with these
  int box_i_start_index, box_j_start_index;  // actual input data indexing
  for (int i = 0; i < num - 1; i++) {
    box_i = i;
    box_i_start_index = box_i * box_size_;  // adjust starting index
    if (sel_boxes[box_i]) {
      for (int j = i + 1; j < num; j++) {
        box_j = j;
        box_j_start_index = box_j * box_size_;
        if (sel_boxes[box_j]) {
          sel_boxes[box_j] = IOUDecision(output, box_i, box_j, box_i_start_index, box_j_start_index, area, IOU_value);
        }
      }
    }
  }
}
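
// Editor's note: FinalPass is launched as <<<1, 1>>> (a single thread) and
// performs the classic O(n^2) greedy NMS sweep over the sorted boxes, serially
// cleaning up whatever the block-local pass above could not compare.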

template <typename T, typename S>
__global__ void BitonicSortByKeyKernelM(const int outer, const int inner, const int ceil_power2, S *data_in,
                                        S *data_out, T *index_buff, S *data_buff, int box_size_) {
  // default: sort with shared memory
  extern __shared__ T share_mem_NMS[];
  T *index_arr = share_mem_NMS;
  S *data_arr = reinterpret_cast<S *>(index_arr + ceil_power2);
  // sort with RAM
  if (index_buff != nullptr && data_buff != nullptr) {
    index_arr = index_buff + blockIdx.x * ceil_power2;
    data_arr = data_buff + blockIdx.x * ceil_power2;
  }
  for (int i = threadIdx.x; i < ceil_power2; i += blockDim.x) {
    index_arr[i] = (i < inner) ? T(i) : std::numeric_limits<T>::max();
    // populated directly from input data
    data_arr[i] = (i < inner) ? data_in[(blockIdx.x * inner + i) * box_size_ + 4] : std::numeric_limits<S>::max();
  }
  __syncthreads();
  for (size_t i = 2; i <= ceil_power2; i <<= 1) {
    for (size_t j = (i >> 1); j > 0; j >>= 1) {
      for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
        size_t tid_comp = tid ^ j;
        if (tid_comp > tid) {
          if ((tid & i) == 0) {
            if (data_arr[tid] > data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            }
          } else {
            if (data_arr[tid] < data_arr[tid_comp]) {
              SwapM(&index_arr[tid], &index_arr[tid_comp]);
              SwapM(&data_arr[tid], &data_arr[tid_comp]);
            }
          }
        }
      }
      __syncthreads();
    }
  }
  T correct_index;
  for (size_t tid = threadIdx.x; tid < inner; tid += blockDim.x) {
    correct_index = index_arr[(inner - 1) - tid];
    // move data from input to output, correcting the ordering using the sorted index array
    for (auto i : {0, 1, 2, 3, 4}) {
      data_out[(blockIdx.x * inner + tid) * box_size_ + i] =
        data_in[(blockIdx.x * inner + correct_index) * box_size_ + i];
    }
  }
}
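
// Editor's note: the nested i/j loops above form a standard in-place bitonic
// sorting network over ceil_power2 elements. Slots past `inner` are padded with
// numeric_limits::max() so they sink to the top of the ascending sort, and the
// final copy walks index_arr back to front ((inner - 1) - tid) to emit the
// 5-value rows in descending order of the score key (column 4).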

template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream) {
  Preprocess<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, sel_idx, area, output, box_size_);
}

template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream) {
  int ceil_power2 = RoundUpPower2M(inner);
  size_t share_mem = ceil_power2 * (sizeof(T) + sizeof(S));
  if (share_mem > SHARED_MEM_PER_BLOCK) {
    share_mem = 0;
  } else {
    data_buff = nullptr;
    index_buff = nullptr;
  }
  int thread = std::min(ceil_power2, GET_THREADS);
  BitonicSortByKeyKernelM<<<outer, thread, share_mem, stream>>>(outer, inner, ceil_power2, data_in, data_out,
                                                                index_buff, data_buff, box_size_);
}
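
// Editor's note: when the padded key/index arrays fit in shared memory, the
// kernel sorts there (share_mem bytes per block) and the global buffers are
// nulled out; otherwise share_mem is set to 0 and the kernel falls back to the
// preallocated index_buff/data_buff in device memory.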

template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                    cudaStream_t cuda_stream) {
  NMSWithMaskKernel<<<GET_BLOCKS(num), GET_THREADS, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes,
                                                                      box_size_);
}

template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                  cudaStream_t cuda_stream) {
  FinalPass<<<1, 1, 0, cuda_stream>>>(num, IOU_value, output, area, sel_boxes, box_size_);
}

template void CalPreprocess<float>(const int num, int *sel_idx, float *area, float *output, int box_size_,
                                   cudaStream_t cuda_stream);
template void BitonicSortByKeyM(const int &outer, const int &inner, float *data_in, float *data_out, int *index_buff,
                                float *data_buff, int box_size_, cudaStream_t stream);
template void CalNMSWithMask<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                                    int box_size_, cudaStream_t cuda_stream);
template void CalFinalPass<float>(const int num, const float IOU_value, float *output, float *area, bool *sel_boxes,
                                  int box_size_, cudaStream_t cuda_stream);

@@ -0,0 +1,37 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_

#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void CalPreprocess(const int num, int *sel_idx, T *area, T *output, int box_size_, cudaStream_t cuda_stream);

template <typename T>
void CalNMSWithMask(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                    cudaStream_t cuda_stream);

template <typename T, typename S>
void BitonicSortByKeyM(const int &outer, const int &inner, S *data_in, S *data_out, T *index_buff, S *data_buff,
                       int box_size_, cudaStream_t stream);

template <typename T>
void CalFinalPass(const int num, const float IOU_value, T *output, T *area, bool *sel_boxes, int box_size_,
                  cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_NMS_WITH_MASK_IMPL_H_

@@ -0,0 +1,29 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/kernel_compiler/gpu/math/nms_with_mask_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(NMSWithMask,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeBool),
                      NMSWithMaskGpuFwdKernel, float)
}  // namespace kernel
}  // namespace mindspore
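
// Editor's note: the registration above pins the op contract exercised by the
// tests further down: one float32 input of [N, 5] boxes and three outputs -
// sorted boxes (float32), selected indices (int32), and the keep mask (bool).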

@@ -0,0 +1,121 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_GPU_KERNEL_H_

#include <vector>
#include <memory>
#include <iostream>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/nms_with_mask_impl.cuh"
#include "backend/kernel_compiler/gpu/kernel_constants.h"

namespace mindspore {
namespace kernel {
template <typename T>
class NMSWithMaskGpuFwdKernel : public GpuKernel {
 public:
  NMSWithMaskGpuFwdKernel() : num_input_(0), iou_value_(0.5), input_size_(0), output_size_(0), workspace_size_(0) {}
  ~NMSWithMaskGpuFwdKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *data_buff = GetDeviceAddress<T>(workspace, 0);  // sort buffer
    int *index_buff = GetDeviceAddress<int>(workspace, 1);
    T *area = GetDeviceAddress<T>(workspace, 2);  // stores area values for all boxes
    T *output = GetDeviceAddress<T>(outputs, 0);
    int *sel_idx = GetDeviceAddress<int>(outputs, 1);
    bool *sel_boxes = GetDeviceAddress<bool>(outputs, 2);

    BitonicSortByKeyM(num_input_, num_input_, input, output, index_buff, data_buff, box_size_,
                      reinterpret_cast<cudaStream_t>(stream_ptr));
    CalPreprocess(num_input_, sel_idx, area, output, box_size_, reinterpret_cast<cudaStream_t>(stream_ptr));
    CalNMSWithMask(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    CalFinalPass(num_input_, iou_value_, output, area, sel_boxes, box_size_,
                 reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }
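
  // Editor's note on the Launch pipeline above: (1) BitonicSortByKeyM orders the
  // N x 5 rows by descending score (column 4) into `output`; (2) CalPreprocess
  // fills sel_idx with 0..N-1 and precomputes every box area; (3) CalNMSWithMask
  // runs the parallel block-local suppression; (4) CalFinalPass serially resolves
  // the remaining cross-block overlaps into sel_boxes.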

  bool Init(const CNodePtr &kernel_node) override {
    iou_value_ = GetAttr<float>(kernel_node, "iou_threshold");

    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but NMSWithMask needs 1 input.";
      return false;
    }

    size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 3) {
      MS_LOG(ERROR) << "Output number is " << output_num << ", but NMSWithMask needs 3 outputs.";
      return false;
    }

    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (CHECK_NULL_INPUT(input_shape)) {
      MS_LOG(WARNING) << "NMSWithMask input is null";
      InitSizeLists();
      return true;
    }

    num_input_ = input_shape[0];  // get the N value in the [N, 5] input data

    input_size_ = num_input_ * sizeof(T) * box_size_;  // 5 values per bbox
    output_size_ = (input_size_) + (num_input_ * sizeof(int)) + (num_input_ * sizeof(bool));
    workspace_size_ = (2 * num_input_ * sizeof(T)) + (1 * num_input_ * sizeof(int));

    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    // N sized input/output data
    input_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
    output_size_list_.push_back(num_input_ * sizeof(T) * box_size_);
    output_size_list_.push_back(num_input_ * sizeof(int));
    output_size_list_.push_back(num_input_ * sizeof(bool));

    // N sized workspace arrays
    workspace_size_list_.push_back(num_input_ * sizeof(T));
    workspace_size_list_.push_back(num_input_ * sizeof(int));
    workspace_size_list_.push_back(num_input_ * sizeof(T));
  }
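
  // Editor's note (worked example, not in the original): for N = 20 float32
  // boxes this reserves 20 * 5 * 4 = 400 bytes for the input and the box
  // output, 20 * 4 = 80 bytes for the index output, 20 bytes for the bool
  // mask, and 80 + 80 + 80 bytes of workspace (sort data, sort indices, areas).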

 private:
  int num_input_;
  float iou_value_;
  static const int box_size_ = 5;  // predefined box width: x1, y1, x2, y2, score

  // default values
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_NMS_WITH_MASK_GPU_KERNEL_H_

@@ -0,0 +1,154 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.context as context
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

def manualNMS(bbox, overlap_val_iou):
    mask = [True] * len(bbox)
    for box_a_index, _ in enumerate(bbox):
        if not mask[box_a_index]:
            continue  # already suppressed, skip
        box_a = bbox[box_a_index]  # select box for value extraction
        for box_b_index in range(box_a_index + 1, len(bbox)):
            if not mask[box_b_index]:
                continue  # already suppressed, skip
            box_b = bbox[box_b_index]
            areaA = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
            areaB = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
            overlap_x1 = max(box_a[0], box_b[0])
            overlap_y1 = max(box_a[1], box_b[1])
            overlap_x2 = min(box_a[2], box_b[2])
            overlap_y2 = min(box_a[3], box_b[3])
            width = max((overlap_x2 - overlap_x1), 0)
            height = max((overlap_y2 - overlap_y1), 0)
            # IOU decision: suppress box_b if its IoU with box_a exceeds the threshold
            mask[box_b_index] = not (
                (width * height) / (areaA + areaB - (width * height)) > overlap_val_iou)
    return mask
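
# Editor's note (illustrative, not part of the test suite): for two identical
# boxes the IoU is 1.0, so any threshold below 1.0 suppresses the second box:
#   manualNMS([[0, 0, 10, 10, 0.9], [0, 0, 10, 10, 0.4]], 0.5) -> [True, False]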


def runMSRun(op, bbox):
    inputs = Tensor(bbox, mindspore.float32)
    box, _, mask = op(inputs)
    box = box.asnumpy()
    mask = mask.asnumpy()
    sel_idx = np.where(mask)
    sel_rows = box[sel_idx][:, 0:4]
    sel_score = box[sel_idx][:, -1]
    return sel_rows, sel_score

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_check_order():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    nms_op = P.NMSWithMask(0.5)
    for _ in range(500):
        count = 20
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
        unsorted_scores = np.random.rand(count, 1)
        bbox = np.hstack((box, unsorted_scores))
        bbox = Tensor(bbox, dtype=mindspore.float32)
        prop, _, _ = nms_op(bbox)
        ms_sorted_scores = prop.asnumpy()[:, -1]  # select just the scores
        np_sorted_scores = np.sort(unsorted_scores, axis=0)[::-1][:, 0]  # sort manually
        np.testing.assert_array_almost_equal(
            ms_sorted_scores, np_sorted_scores)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_check_result():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_count = 500
    for x in range(1, test_count + 1):
        count = 20  # size of bbox lists
        nms_op = P.NMSWithMask(x * 0.002)  # tests the full threshold range between 0 and 1
        box = np.random.randint(1, 100, size=(count, 4))
        box[:, 2] = box[:, 0] + box[:, 2]
        box[:, 3] = box[:, 1] + box[:, 3]
        unsorted_scores = np.random.rand(count, 1)
        sorted_scores = np.sort(unsorted_scores, axis=0)[::-1]
        bbox = np.hstack((box, sorted_scores))
        bbox = Tensor(bbox, dtype=mindspore.float32)
        _, _, mask = nms_op(bbox)
        mask = mask.asnumpy()
        manual_mask = manualNMS(box, x * 0.002)
        np.testing.assert_array_equal(mask, np.array(manual_mask))

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_1():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 1 - FULL OVERLAP BOXES - every box is duplicated and has a different score
    nms_op1 = P.NMSWithMask(0.3)
    bbox1 = [[12, 4, 33, 17, 0.6], [20, 11, 38, 23, 0.1], [20, 10, 45, 26, 0.9], [15, 17, 35, 38, 0.5],
             [10, 20, 30, 40, 0.4], [35, 35, 89, 90, 0.8], [12, 4, 33, 17, 0.3], [20, 11, 38, 23, 0.2],
             [20, 10, 45, 26, 0.1], [15, 17, 35, 38, 0.8], [10, 20, 30, 40, 0.41], [35, 35, 89, 90, 0.82]]
    expected_bbox = np.array([[20., 10., 45., 26.],
                              [35., 35., 89., 90.],
                              [15., 17., 35., 38.],
                              [12., 4., 33., 17.]])
    expected_score = np.array([0.9, 0.82, 0.8, 0.6])

    sel_rows, sel_score = runMSRun(nms_op1, bbox1)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_2():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 2 - zero-size boxes with valid scores
    nms_op2 = P.NMSWithMask(0.5)
    bbox2 = [[0, 0, 0, 0, 0.6], [0, 0, 0, 0, 0.1]]
    expected_bbox = np.array([[0., 0., 0., 0.],
                              [0., 0., 0., 0.]])
    expected_score = np.array([0.6, 0.1])

    sel_rows, sel_score = runMSRun(nms_op2, bbox2)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nms_with_mask_edge_case_3():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    # CASE 3 - x2/x1 and y2/y1 sequence out of place
    nms_op3 = P.NMSWithMask(0.7)
    bbox3 = [[70, 70, 45, 75, 0.6], [30, 33, 43, 29, 0.1]]
    expected_bbox = np.array([[70., 70., 45., 75.],
                              [30., 33., 43., 29.]])
    expected_score = np.array([0.6, 0.1])

    sel_rows, sel_score = runMSRun(nms_op3, bbox3)
    np.testing.assert_almost_equal(sel_rows, expected_bbox)
    np.testing.assert_almost_equal(sel_score, expected_score)