!41886 [assistant][ops] Add GPU Operator CombinedNonMaxSuppression
Merge pull request !41886 from Forever Young/CombinedNonMaxSuppression
This commit is contained in: e2e0e009e6
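For a quick sense of how the new GPU kernel is exercised from Python, here is a minimal usage sketch assembled from the test cases added at the bottom of this diff (shapes, attribute values, and the expected top score are taken from test_combined_non_max_suppresion1; the snippet itself is illustrative and not part of the commit):

import numpy as np
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops.operations.image_ops import CombinedNonMaxSuppression

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

# boxes: (batch, num_boxes, q, 4); scores: (batch, num_boxes, num_classes)
boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]],
                          [[190, 110, 150, 100]], [[210, 112, 150, 100]]]], dtype=np.float32))
scores = Tensor(np.array([[[0.2, 0.7, 0.1], [0.1, 0.8, 0.1],
                           [0.3, 0.6, 0.1], [0.05, 0.9, 0.05]]], dtype=np.float32))

op = CombinedNonMaxSuppression(False, True)  # pad_per_class=False, clip_boxes=True
nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = op(
    boxes, scores,
    Tensor(4, dtype=mstype.int32),    # max_output_size_per_class
    Tensor(1, dtype=mstype.int32),    # max_total_size
    Tensor(0, dtype=mstype.float32),  # iou_threshold
    Tensor(0, dtype=mstype.float32))  # score_threshold
print(nmsed_scores)  # the matching test below expects [[0.9]]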
@ -0,0 +1,261 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <thrust/sort.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <vector>
#include <limits>
#include <iostream>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/combined_non_max_suppression_impl.cuh"

constexpr int DIM0 = 0;
constexpr int DIM1 = 1;
constexpr int DIM2 = 2;
constexpr int DIM3 = 3;
constexpr int DIM4 = 4;
constexpr float zero = 0;
constexpr float one = 1;

__inline__ __device__ float IOU(float *boxes_result, int i, int j) {
  float lx, ly, rx, ry;
  float w, h;
  float area;
  float area_a = (boxes_result[i * DIM4 + DIM2] - boxes_result[i * DIM4 + DIM0]) *
                 (boxes_result[i * DIM4 + DIM3] - boxes_result[i * DIM4 + DIM1]);
  float area_b = (boxes_result[j * DIM4 + DIM2] - boxes_result[j * DIM4 + DIM0]) *
                 (boxes_result[j * DIM4 + DIM3] - boxes_result[j * DIM4 + DIM1]);
  if ((area_a == zero) || (area_b == zero)) {
    return zero;
  }
  lx = max(boxes_result[i * DIM4 + DIM0], boxes_result[j * DIM4 + DIM0]);
  ly = max(boxes_result[i * DIM4 + DIM1], boxes_result[j * DIM4 + DIM1]);
  rx = min(boxes_result[i * DIM4 + DIM2], boxes_result[j * DIM4 + DIM2]);
  ry = min(boxes_result[i * DIM4 + DIM3], boxes_result[j * DIM4 + DIM3]);
  w = (rx > lx) ? (rx - lx) : zero;
  h = (ry > ly) ? (ry - ly) : zero;
  area = w * h;
  return area / (area_a + area_b - area);
}

template <typename T>
__global__ void permute(int q, int num_boxes, int batch_size, T *boxes, float *new_boxes) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < q * num_boxes * batch_size * DIM4;
       index += blockDim.x * gridDim.x) {
    int i = index % DIM4;
    int c = (index / DIM4) % q;
    int d = (index / DIM4 / q) % num_boxes;
    int n = index / DIM4 / q / num_boxes;
    int new_index = ((n * q + c) * num_boxes + d) * DIM4 + i;
    float result = boxes[index];
    new_boxes[new_index] = result;
  }
}

__global__ void boxsort(float *boxes_result, float *new_boxes, int batch_size, int q, int num_boxes) {
  for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < batch_size * q * num_boxes;
       box_num += blockDim.x * gridDim.x) {
    boxes_result[box_num * DIM4 + DIM0] = min(new_boxes[box_num * DIM4 + DIM0], new_boxes[box_num * DIM4 + DIM2]);
    boxes_result[box_num * DIM4 + DIM2] = max(new_boxes[box_num * DIM4 + DIM0], new_boxes[box_num * DIM4 + DIM2]);
    boxes_result[box_num * DIM4 + DIM1] = min(new_boxes[box_num * DIM4 + DIM1], new_boxes[box_num * DIM4 + DIM3]);
    boxes_result[box_num * DIM4 + DIM3] = max(new_boxes[box_num * DIM4 + DIM1], new_boxes[box_num * DIM4 + DIM3]);
  }
}

template <typename T>
__global__ void presort(int num_classes, int num_boxes, int batch_size, T *scores, float *new_scores, int *index,
                        T *score_threshold, bool *sel) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_classes * num_boxes * batch_size;
       idx += blockDim.x * gridDim.x) {
    int c = idx % num_classes;
    int d = (idx / num_classes) % num_boxes;
    int n = idx / (num_classes * num_boxes);
    int new_index = (n * num_classes + c) * num_boxes + d;
    float result = scores[idx];
    new_scores[new_index] = result;
    index[new_index] = new_index;
    sel[new_index] = (new_scores[new_index] > score_threshold[DIM0]) ? true : false;
  }
}

__global__ void Init(int num_classes, int num_boxes, int batch_size, bool *mask) {
  for (int mat_pos = blockIdx.x * blockDim.x + threadIdx.x; mat_pos < batch_size * num_classes * num_boxes * num_boxes;
       mat_pos += blockDim.x * gridDim.x) {
    mask[mat_pos] = true;
  }
}

template <typename T>
__global__ void nms(int batch_size, int num_classes, T *iou_threshold, bool *sel, float *boxes_result, int *index,
                    int q, int num_boxes, bool *mask) {
  int box_i, box_j;
  for (int mask_index = blockIdx.x * blockDim.x + threadIdx.x;
       mask_index < batch_size * num_classes * num_boxes * num_boxes; mask_index += blockDim.x * gridDim.x) {
    box_i = mask_index / num_boxes;  // row in 2d sel_mask array
    box_j = mask_index / (num_boxes * num_boxes) * num_boxes + mask_index % num_boxes;  // col in 2d sel_mask array
    if (box_j > box_i) {
      int idi = index[box_i];
      int idj = index[box_j];  // skip when box_j index lower/equal to box_i - will remain true
      if (q == num_classes) {
        if (IOU(boxes_result, idi, idj) > iou_threshold[0]) {
          mask[mask_index] = false;
        }
      } else {
        if (IOU(boxes_result, idi / (num_classes * num_boxes) * num_boxes + idi % num_boxes,
                idj / (num_classes * num_boxes) * num_boxes + idj % num_boxes) > iou_threshold[0]) {
          mask[mask_index] = false;
        }
      }
    }
  }
}

__global__ void nmsReducePass(int batch_size, int num_classes, bool *sel, int *index, int num_boxes, bool *mask) {
  for (int page = DIM0; page < batch_size * num_classes; page++) {
    for (int i = DIM0; i < num_boxes - DIM1; ++i) {
      int idxi = index[page * num_boxes + i];
      if (!sel[idxi]) {
        continue;
      }
      for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_boxes; j += blockDim.x * gridDim.x) {
        int idxj = index[page * num_boxes + j];

        sel[idxj] = sel[idxj] && mask[page * num_boxes * num_boxes + i * num_boxes + j];
      }
    }
  }
}

__global__ void sizeperclass(int batch_size, int num_classes, bool *sel, int num_boxes, int *index,
                             int *max_output_size_per_class) {
  for (int page = blockIdx.x * blockDim.x + threadIdx.x; page < batch_size * num_classes;
       page += blockDim.x * gridDim.x) {
    int class_idx_count = DIM0;
    for (int i = page * num_boxes; i < (page + DIM1) * num_boxes; i++) {
      int number = index[i];
      if (sel[number]) {
        class_idx_count++;
        if (class_idx_count > max_output_size_per_class[DIM0]) {
          sel[number] = false;
        }
      }
    }
  }
}

template <typename T>
__global__ void output(int batch_size, int per_detections, int *index, float *new_scores, bool *sel, float *new_boxes,
                       T *nmsed_classes, T *nmsed_scores, T *nmsed_boxes, int *valid_detections, bool clip_boxes,
                       int num_classes, int num_boxes, int q) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
    int num = DIM0;
    for (int j = i * num_classes * num_boxes; (j < (i + DIM1) * num_classes * num_boxes) && (num < per_detections);
         j++) {
      int idx = index[j];
      float score = new_scores[j];
      if (sel[idx]) {
        int bboxOffset = i * (q * num_boxes);
        int bboxId = (idx % (q * num_boxes) + bboxOffset);
        nmsed_classes[i * per_detections + num] = T((idx % (num_classes * num_boxes)) / num_boxes);
        nmsed_scores[i * per_detections + num] = T(score);
        float xMin = new_boxes[bboxId * DIM4];
        float yMin = new_boxes[bboxId * DIM4 + DIM1];
        float xMax = new_boxes[bboxId * DIM4 + DIM2];
        float yMax = new_boxes[bboxId * DIM4 + DIM3];
        nmsed_boxes[(i * per_detections + num) * DIM4] = T(clip_boxes ? max(min(xMin, one), zero) : xMin);
        nmsed_boxes[(i * per_detections + num) * DIM4 + DIM1] = T(clip_boxes ? max(min(yMin, one), zero) : yMin);
        nmsed_boxes[(i * per_detections + num) * DIM4 + DIM2] = T(clip_boxes ? max(min(xMax, one), zero) : xMax);
        nmsed_boxes[(i * per_detections + num) * DIM4 + DIM3] = T(clip_boxes ? max(min(yMax, one), zero) : yMax);
        num++;
      }
    }
    valid_detections[i] = num;
    while (num < per_detections) {
      nmsed_classes[i * per_detections + num] = T(zero);
      nmsed_scores[i * per_detections + num] = T(zero);
      nmsed_boxes[(i * per_detections + num) * DIM4] = T(zero);
      nmsed_boxes[(i * per_detections + num) * DIM4 + DIM1] = T(zero);
      nmsed_boxes[(i * per_detections + num) * DIM4 + DIM2] = T(zero);
      nmsed_boxes[(i * per_detections + num) * DIM4 + DIM3] = T(zero);
      num++;
    }
  }
}

template <typename T>
void CalSort(T *scores, int *index, T *score_threshold, int num_classes, T *boxes, float *new_boxes, float *new_scores,
             int batch_size, int num_boxes, float *boxes_result, int q, bool *sel, const uint32_t &device_id,
             cudaStream_t cuda_stream) {
  permute<<<CUDA_BLOCKS(device_id, q * num_boxes * batch_size * DIM4), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    q, num_boxes, batch_size, boxes, new_boxes);
  boxsort<<<CUDA_BLOCKS(device_id, batch_size * q * num_boxes), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    boxes_result, new_boxes, batch_size, q, num_boxes);
  presort<<<CUDA_BLOCKS(device_id, num_classes * num_boxes * batch_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    num_classes, num_boxes, batch_size, scores, new_scores, index, score_threshold, sel);
  auto policy = thrust::cuda::par.on(cuda_stream);
  for (int i = DIM0; i < num_classes * batch_size; i++) {
    thrust::stable_sort_by_key(policy, thrust::device_pointer_cast(new_scores + i * num_boxes),
                               thrust::device_pointer_cast(new_scores + i * num_boxes) + num_boxes,
                               thrust::device_pointer_cast(index + i * num_boxes), thrust::greater<float>());
  }
}

template <typename T>
void Calnms(int batch_size, int num_classes, T *iou_threshold, bool *sel, float *boxes_result, int *index, int q,
            int num_boxes, int *max_output_size_per_class, float *new_scores, bool *mask, const uint32_t &device_id,
            cudaStream_t cuda_stream) {
  Init<<<CUDA_BLOCKS(device_id, batch_size * num_classes * num_boxes * num_boxes), CUDA_THREADS(device_id), 0,
         cuda_stream>>>(num_classes, num_boxes, batch_size, mask);
  nms<<<CUDA_BLOCKS(device_id, batch_size * num_classes * num_boxes * num_boxes), CUDA_THREADS(device_id), 0,
       cuda_stream>>>(batch_size, num_classes, iou_threshold, sel, boxes_result, index, q, num_boxes, mask);
  nmsReducePass<<<CUDA_BLOCKS(device_id, num_boxes), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    batch_size, num_classes, sel, index, num_boxes, mask);
  sizeperclass<<<CUDA_BLOCKS(device_id, batch_size * num_classes), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    batch_size, num_classes, sel, num_boxes, index, max_output_size_per_class);

  auto policy = thrust::cuda::par.on(cuda_stream);
  for (int i = DIM0; i < batch_size; i++) {
    thrust::stable_sort_by_key(
      policy, thrust::device_pointer_cast(new_scores + i * num_boxes * num_classes),
      thrust::device_pointer_cast(new_scores + i * num_boxes * num_classes) + (num_boxes * num_classes),
      thrust::device_pointer_cast(index + i * num_boxes * num_classes), thrust::greater<float>());
  }
}

template <typename T>
void Caloutput(int batch_size, int per_detections, int *index, float *new_scores, bool *sel, float *new_boxes,
               T *nmsed_classes, T *nmsed_scores, T *nmsed_boxes, int *valid_detections, bool clip_boxes,
               int num_classes, int num_boxes, int q, const uint32_t &device_id, cudaStream_t cuda_stream) {
  output<<<CUDA_BLOCKS(device_id, batch_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    batch_size, per_detections, index, new_scores, sel, new_boxes, nmsed_classes, nmsed_scores, nmsed_boxes,
    valid_detections, clip_boxes, num_classes, num_boxes, q);
}

template CUDA_LIB_EXPORT void CalSort<float>(float *scores, int *index, float *score_threshold, int num_classes,
                                             float *boxes, float *new_boxes, float *new_scores, int batch_size,
                                             int num_boxes, float *boxes_result, int q, bool *sel,
                                             const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Calnms<float>(int batch_size, int num_classes, float *iou_threshold, bool *sel,
                                            float *boxes_result, int *index, int q, int num_boxes,
                                            int *max_output_size_per_class, float *new_scores, bool *mask,
                                            const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Caloutput<float>(int batch_size, int per_detections, int *index, float *new_scores,
                                               bool *sel, float *new_boxes, float *nmsed_classes, float *nmsed_scores,
                                               float *nmsed_boxes, int *valid_detections, bool clip_boxes,
                                               int num_classes, int num_boxes, int q, const uint32_t &device_id,
                                               cudaStream_t cuda_stream);
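The IOU device function above computes the standard intersection-over-union of two axis-aligned boxes stored as (x1, y1, x2, y2) (boxsort has already ordered each pair so x1 <= x2 and y1 <= y2), and treats a box of zero area as having zero overlap. Written out:

\[
\mathrm{IoU}(A, B) = \frac{I}{\mathrm{area}(A) + \mathrm{area}(B) - I}, \qquad
I = \max\bigl(0,\ \min(x_2^A, x_2^B) - \max(x_1^A, x_1^B)\bigr) \cdot \max\bigl(0,\ \min(y_2^A, y_2^B) - \max(y_1^A, y_1^B)\bigr).
\]

The nms kernel clears the pairwise mask entry whenever this value exceeds iou_threshold, and nmsReducePass then folds those pairwise decisions into the per-box sel flags.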
@ -0,0 +1,35 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMBINED_NON_MAX_SUPPRESSION_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMBINED_NON_MAX_SUPPRESSION_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"

template <typename T>
CUDA_LIB_EXPORT void CalSort(T *scores, int *index, T *score_threshold, int num_classes, T *boxes, float *new_boxes,
                             float *new_scores, int batch_size, int num_boxes, float *boxes_result, int q, bool *sel,
                             const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Calnms(int batch_size, int num_classes, T *iou_threshold, bool *sel, float *boxes_result,
                            int *index, int q, int num_boxes, int *max_output_size_per_class, float *new_scores,
                            bool *mask, const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Caloutput(int batch_size, int per_detections, int *index, float *new_scores, bool *sel,
                               float *new_boxes, T *nmsed_classes, T *nmsed_scores, T *nmsed_boxes,
                               int *valid_detections, bool clip_boxes, int num_classes, int num_boxes, int q,
                               const uint32_t &device_id, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMBINED_NON_MAX_SUPPRESSION_IMPL_CUH_
@ -0,0 +1,167 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/math/combined_non_max_suppression_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/combined_non_max_suppression_impl.cuh"
#include <utility>
#include <algorithm>
#include <iostream>

namespace mindspore {
namespace kernel {
constexpr int DimSize4 = 4;
void CombinedNonMaxSuppressionGpuKernelMod::ResetResource() noexcept {
  cuda_stream_ = nullptr;
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
}

void CombinedNonMaxSuppressionGpuKernelMod::InitSizeLists() {
  input_size_list_.push_back(batch_size_ * num_boxes_ * q_ * DimSize4 * sizeof(T));
  input_size_list_.push_back(batch_size_ * num_boxes_ * num_classes_ * sizeof(T));
  input_size_list_.push_back(sizeof(int));
  input_size_list_.push_back(sizeof(int));
  input_size_list_.push_back(sizeof(T));
  input_size_list_.push_back(sizeof(T));
  output_size_list_.push_back(batch_size_ * per_detections_ * DimSize4 * sizeof(T));
  output_size_list_.push_back(batch_size_ * per_detections_ * sizeof(T));
  output_size_list_.push_back(batch_size_ * per_detections_ * sizeof(T));
  output_size_list_.push_back(batch_size_ * sizeof(int));
  workspace_size_list_.push_back(q_ * batch_size_ * num_boxes_ * DimSize4 * sizeof(float));  // new_boxes
  workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * sizeof(float));   // new_scores
  workspace_size_list_.push_back(batch_size_ * q_ * num_boxes_ * DimSize4 * sizeof(float));  // boxes_result
  workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * sizeof(int));     // index
  workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * sizeof(bool));    // sel
  workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * num_boxes_ * sizeof(bool));
}

bool CombinedNonMaxSuppressionGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                                 const std::vector<KernelTensorPtr> &inputs,
                                                 const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::CombinedNonMaxSuppression>(base_operator);
  pad_per_class_ = kernel_ptr->get_pad_per_class();
  clip_boxes_ = kernel_ptr->get_clip_boxes();
  kernel_name_ = kernel_ptr->name();
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in "
                  << "[float32,int32], but got: " << kernel_attr;
  }
  is_need_retrieve_output_shape_ = true;
  kernel_func_ = func_list_[index].second;
  return true;
}

int CombinedNonMaxSuppressionGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                                  const std::vector<KernelTensorPtr> &inputs,
                                                  const std::vector<KernelTensorPtr> &outputs,
                                                  const std::map<uint32_t, tensor::TensorPtr> &) {
  for (const auto &input : inputs) {
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  ResetResource();
  outputs_ = outputs;
  std::vector<size_t> input0_shape = std::vector<size_t>(inputs[kIndex0]->GetDeviceShapeAdaptively().begin(),
                                                         inputs[kIndex0]->GetDeviceShapeAdaptively().end());
  std::vector<size_t> input1_shape = std::vector<size_t>(inputs[kIndex1]->GetDeviceShapeAdaptively().begin(),
                                                         inputs[kIndex1]->GetDeviceShapeAdaptively().end());
  batch_size_ = static_cast<int>(input0_shape[kIndex0]);
  num_boxes_ = static_cast<int>(input0_shape[kIndex1]);
  q_ = static_cast<int>(input0_shape[kIndex2]);
  num_classes_ = static_cast<int>(input1_shape[kIndex2]);
  auto prim = base_operator->GetPrim();
  if ((prim->GetAttr("per_detections"))) {
    per_detections_ = GetValue<int>(prim->GetAttr("per_detections"));
  }

  InitSizeLists();
  return KRET_OK;
}

template <typename T>
bool CombinedNonMaxSuppressionGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                         const std::vector<AddressPtr> &workspace,
                                                         const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
  T *boxes = GetDeviceAddress<T>(inputs, kIndex0);
  T *scores = GetDeviceAddress<T>(inputs, kIndex1);
  int *max_output_size_per_class = GetDeviceAddress<int>(inputs, kIndex2);
  T *iou_threshold = GetDeviceAddress<T>(inputs, kIndex4);
  T *score_threshold = GetDeviceAddress<T>(inputs, kIndex5);
  T *nmsed_boxes = GetDeviceAddress<T>(outputs, kIndex0);
  T *nmsed_scores = GetDeviceAddress<T>(outputs, kIndex1);
  T *nmsed_classes = GetDeviceAddress<T>(outputs, kIndex2);
  int *valid_detections = GetDeviceAddress<int>(outputs, kIndex3);
  float *new_boxes = GetDeviceAddress<float>(workspace, kIndex0);
  float *new_scores = GetDeviceAddress<float>(workspace, kIndex1);
  float *boxes_result = GetDeviceAddress<float>(workspace, kIndex2);
  int *index = GetDeviceAddress<int>(workspace, kIndex3);
  bool *sel = GetDeviceAddress<bool>(workspace, kIndex4);
  bool *mask = GetDeviceAddress<bool>(workspace, kIndex5);

  CalSort(scores, index, score_threshold, num_classes_, boxes, new_boxes, new_scores, batch_size_, num_boxes_,
          boxes_result, q_, sel, device_id_, cuda_stream_);
  Calnms(batch_size_, num_classes_, iou_threshold, sel, boxes_result, index, q_, num_boxes_, max_output_size_per_class,
         new_scores, mask, device_id_, cuda_stream_);
  Caloutput(batch_size_, per_detections_, index, new_scores, sel, new_boxes, nmsed_classes, nmsed_scores, nmsed_boxes,
            valid_detections, clip_boxes_, num_classes_, num_boxes_, q_, device_id_, cuda_stream_);
  return true;
}

void CombinedNonMaxSuppressionGpuKernelMod::SyncData() {
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), "cudaStreamSynchronize failed");
  std::vector<int64_t> shape0 = {batch_size_, per_detections_, DimSize4};
  std::vector<int64_t> shape1 = {batch_size_, per_detections_};
  std::vector<int64_t> shape2 = {batch_size_, per_detections_};
  std::vector<int64_t> shape3 = {batch_size_};
  outputs_[kIndex0]->SetShapeVector(shape0);
  outputs_[kIndex1]->SetShapeVector(shape1);
  outputs_[kIndex2]->SetShapeVector(shape2);
  outputs_[kIndex3]->SetShapeVector(shape3);
}

std::vector<std::pair<KernelAttr, CombinedNonMaxSuppressionGpuKernelMod::CombinedNonMaxSuppressionLaunchFunc>>
  CombinedNonMaxSuppressionGpuKernelMod::func_list_ = {
    {KernelAttr()
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeInt32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddInputAttr(kNumberTypeFloat32)
       .AddOutputAttr(kNumberTypeFloat32)
       .AddOutputAttr(kNumberTypeFloat32)
       .AddOutputAttr(kNumberTypeFloat32)
       .AddOutputAttr(kNumberTypeInt32),
     &CombinedNonMaxSuppressionGpuKernelMod::LaunchKernel<float>},
};

std::vector<KernelAttr> CombinedNonMaxSuppressionGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(
    func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
    [](const std::pair<KernelAttr, CombinedNonMaxSuppressionLaunchFunc> &pair) { return pair.first; });
  return support_list;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CombinedNonMaxSuppression, CombinedNonMaxSuppressionGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,72 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_COMBINED_NON_MAX_SUPPRESSION_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_COMBINED_NON_MAX_SUPPRESSION_GPU_KERNEL_H_
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <utility>
#include "mindspore/core/ops/combined_non_max_suppression.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class CombinedNonMaxSuppressionGpuKernelMod : public NativeGpuKernelMod {
 public:
  CombinedNonMaxSuppressionGpuKernelMod() { ResetResource(); }
  ~CombinedNonMaxSuppressionGpuKernelMod() override = default;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) {
    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
  }
  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;
  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  void SyncData() override;
  std::vector<KernelAttr> GetOpSupport() override;
  std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }

 private:
  void ResetResource() noexcept;
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs, void *stream_ptr);
  using CombinedNonMaxSuppressionLaunchFunc =
    std::function<bool(CombinedNonMaxSuppressionGpuKernelMod *, const std::vector<AddressPtr> &,
                       const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
  static std::vector<std::pair<KernelAttr, CombinedNonMaxSuppressionLaunchFunc>> func_list_;
  CombinedNonMaxSuppressionLaunchFunc kernel_func_;
  cudaStream_t cuda_stream_;
  void InitSizeLists();
  size_t T;
  int batch_size_;
  int num_boxes_;
  int q_;
  int num_classes_;
  int per_detections_;
  bool pad_per_class_;
  bool clip_boxes_;
  std::vector<KernelTensorPtr> outputs_{};
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_COMBINED_NON_MAX_SUPPRESSION_GPU_KERNEL_H_
@ -107,7 +107,7 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionGetOutputShape(const PrimitiveP
  if (pad_per_class) {
    num_detection = std::min(max_total_size, max_output_size_per_class * static_cast<int32_t>(input1_shape[ksecond]));
  }

  (void)primitive->AddAttr("per_detections", MakeValue(num_detection));
  int64_t bs = input0_shape[0];
  ShapeVector shape1 = {bs, num_detection, 4};
  ShapeVector shape2 = {bs, num_detection};

@ -141,7 +141,6 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
                                          input3_shape, input4_shape, input5_shape};
  auto is_dynamic = (IsDynamic(input0_shape) || IsDynamic(input1_shape));
  auto is_dynamic_rank = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamicRank);

  CombinedNonMaxSuppressionCheckShapeSize(input0_shape, input1_shape, input2_shape, input3_shape, input4_shape,
                                          input5_shape, is_dynamic_rank, prim_name);

@ -204,6 +203,16 @@ AbstractBasePtr CombinedNonMaxSuppressionInfer(const abstract::AnalysisEnginePtr
  return abstract::MakeAbstract(infer_shape, infer_type);
}

bool CombinedNonMaxSuppression::get_pad_per_class() const {
  auto value_ptr = this->GetAttr("pad_per_class");
  return GetValue<bool>(value_ptr);
}

bool CombinedNonMaxSuppression::get_clip_boxes() const {
  auto value_ptr = this->GetAttr("clip_boxes");
  return GetValue<bool>(value_ptr);
}

REGISTER_PRIMITIVE_EVAL_IMPL(CombinedNonMaxSuppression, prim::kPrimCombinedNonMaxSuppression,
                             CombinedNonMaxSuppressionInfer, nullptr, true);
REGISTER_HOST_DEPENDS(kNameCombinedNonMaxSuppression, {2, 3, 4, 5});

@ -36,6 +36,8 @@ class MIND_API CombinedNonMaxSuppression : public BaseOperator {
    InitIOName({"boxes", "scores", "max_output_size_per_class", "max_total_size", "iou_threshold", "score_threshold"},
               {"nmsed_box", "nmsed_scores", "nmsed_classes", "valid_detections"});
  }
  bool get_pad_per_class() const;
  bool get_clip_boxes() const;
};
abstract::AbstractBasePtr CombinedNonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &,
                                                         const PrimitivePtr &primitive,

@ -1805,6 +1805,44 @@ class PReLUGrad(Primitive):
    pass


class RandomGammaGrad(Primitive):
    r"""
    Computes the derivative of a random sample of Gamma with respect to alpha.:

    Inputs:
        - **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
          It must be greater than 0. Must be one of the following types: float32, float64.
        - **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
          following types: float32, float64.

    Outputs:
        The dtype is the same type as alpha.
        The output shape is derived from the input through broadcasting.

    Raises:
        TypeError: If data type of `alpha` and `sample` is not float32 or float64.
        TypeError: If data type of `alpha` and `sample` is not same.
        ValueError: If the shape last dim of `sample` and `alpha` is not equal.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
        >>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
        >>> randomgammagrad = ops.RandomGammaGrad()
        >>> output = randomgammagrad(alpha, sample)
        >>> print(output)
        [2.5142431 3.4334087 1.8847835 0.07780622]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize RandomGammaGrad"""
        self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
        self.add_prim_attr("side_effect_hidden", True)


class ReluGrad(Primitive):
    """Performs grad of Relu operation."""

@ -1121,7 +1121,7 @@ class CombinedNonMaxSuppression(Primitive):
        ValueError: If `iou_threshold` not in [0,1].

    Supported Platforms:
        ``Ascend`` ``CPU``
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> boxes = Tensor(np.array([[[[200, 100, 150, 100]],

@ -312,44 +312,6 @@ class LogNormalReverse(Primitive):
        Validator.check_value_type("std", std, [float], self.name)


class RandomGammaGrad(Primitive):
    r"""
    Computes the derivative of a random sample of Gamma with respect to alpha.:

    Inputs:
        - **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
          It must be greater than 0. Must be one of the following types: float32, float64.
        - **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
          following types: float32, float64.

    Outputs:
        The dtype is the same type as alpha.
        The output shape is derived from the input through broadcasting.

    Raises:
        TypeError: If data type of `alpha` and `sample` is not float32 or float64.
        TypeError: If data type of `alpha` and `sample` is not same.
        ValueError: If the shape last dim of `sample` and `alpha` is not equal.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
        >>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
        >>> randomgammagrad = ops.RandomGammaGrad()
        >>> output = randomgammagrad(alpha, sample)
        >>> print(output)
        [2.5142431 3.4334087 1.8847835 0.07780622]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize RandomGammaGrad"""
        self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
        self.add_prim_attr("side_effect_hidden", True)


class Gamma(PrimitiveWithInfer):
    r"""
    Produces random positive floating-point values x, distributed according to probability density function:

@ -0,0 +1,181 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops.operations.image_ops import CombinedNonMaxSuppression


class NetCombinedNonMaxSuppression(nn.Cell):
    def __init__(self, pad_per_class, clip_boxes):
        super(NetCombinedNonMaxSuppression, self).__init__()
        self.combined_non_max_suppression = CombinedNonMaxSuppression(
            pad_per_class, clip_boxes)

    def construct(self, boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold):
        return self.combined_non_max_suppression(boxes, scores, max_output_size_per_class, max_total_size,
                                                 iou_threshold, score_threshold)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def cmp(output, expect):
    output_type = output.asnumpy().dtype
    expect_type = expect.asnumpy().dtype
    diff0 = output.asnumpy() - expect
    error0 = np.zeros(shape=expect.shape)
    assert np.all(diff0 == error0)
    assert output.shape == expect.shape
    assert output_type == expect_type


def test_combined_non_max_suppresion1():
    """
    Feature: Combined non max suppression.
    Description: test case for Combined non max suppression.
    Expectation: The results are as expected.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
                              [[210, 112, 150, 100]]]], dtype=np.float32))
    scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
                               [0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
    max_output_size_per_class = Tensor(4, dtype=mstype.int32)
    max_total_size = Tensor(1, dtype=mstype.int32)
    iou_threshold = Tensor(0, dtype=mstype.float32)
    score_threshold = Tensor(0, dtype=mstype.float32)
    nmsed_boxes_expect = Tensor(
        np.array([[[1., 1., 1., 1.]]], dtype=np.float32))
    nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
    nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
    valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
    net = NetCombinedNonMaxSuppression(False, True)
    output = net(boxes, scores, max_output_size_per_class,
                 max_total_size, iou_threshold, score_threshold)
    output_nmsed_boxes = output[0]
    output_nmsed_scores = output[1]
    output_nmsed_classes = output[2]
    output_valid_detections = output[3]
    cmp(output_nmsed_boxes, nmsed_boxes_expect)
    cmp(output_nmsed_scores, nmsed_scores_expect)
    cmp(output_nmsed_classes, nmsed_classes_expect)
    cmp(output_valid_detections, valid_detections_expect)


def test_combined_non_max_suppresion2():
    """
    Feature: Combined non max suppression.
    Description: test case for Combined non max suppression.
    Expectation: The results are as expected.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
                              [[210, 112, 150, 100]]]], dtype=np.float32))
    scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
                               [0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
    max_output_size_per_class = Tensor(4, dtype=mstype.int32)
    max_total_size = Tensor(1, dtype=mstype.int32)
    iou_threshold = Tensor(0, dtype=mstype.float32)
    score_threshold = Tensor(0, dtype=mstype.float32)
    nmsed_boxes_expect = Tensor(
        np.array([[[1., 1., 1., 1.]]], dtype=np.float32))
    nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
    nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
    valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
    net = NetCombinedNonMaxSuppression(False, True)
    output = net(boxes, scores, max_output_size_per_class,
                 max_total_size, iou_threshold, score_threshold)
    output_nmsed_boxes = output[0]
    output_nmsed_scores = output[1]
    output_nmsed_classes = output[2]
    output_valid_detections = output[3]
    cmp(output_nmsed_boxes, nmsed_boxes_expect)
    cmp(output_nmsed_scores, nmsed_scores_expect)
    cmp(output_nmsed_classes, nmsed_classes_expect)
    cmp(output_valid_detections, valid_detections_expect)


def test_combined_non_max_suppresion3():
    """
    Feature: Combined non max suppression.
    Description: test case for Combined non max suppression.
    Expectation: The results are as expected.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
                              [[210, 112, 150, 100]]]], dtype=np.float32))
    scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
                               [0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
    max_output_size_per_class = Tensor(4, dtype=mstype.int32)
    max_total_size = Tensor(1, dtype=mstype.int32)
    iou_threshold = Tensor(0.2, dtype=mstype.float32)
    score_threshold = Tensor(0.2, dtype=mstype.float32)
    nmsed_boxes_expect = Tensor(
        np.array([[[210., 112., 150., 100.]]], dtype=np.float32))
    nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
    nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
    valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
    net = NetCombinedNonMaxSuppression(True, False)
    output = net(boxes, scores, max_output_size_per_class,
                 max_total_size, iou_threshold, score_threshold)
    output_nmsed_boxes = output[0]
    output_nmsed_scores = output[1]
    output_nmsed_classes = output[2]
    output_valid_detections = output[3]
    cmp(output_nmsed_boxes, nmsed_boxes_expect)
    cmp(output_nmsed_scores, nmsed_scores_expect)
    cmp(output_nmsed_classes, nmsed_classes_expect)
    cmp(output_valid_detections, valid_detections_expect)


def test_combined_non_max_suppresion4():
    """
    Feature: Combined non max suppression.
    Description: test case for Combined non max suppression.
    Expectation: The results are as expected.
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
                              [[210, 112, 150, 100]]]], dtype=np.float32))
    scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
                               [0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
    max_output_size_per_class = Tensor(4, dtype=mstype.int32)
    max_total_size = Tensor(1, dtype=mstype.int32)
    iou_threshold = Tensor(0, dtype=mstype.float32)
    score_threshold = Tensor(0, dtype=mstype.float32)
    nmsed_boxes_expect = Tensor(
        np.array([[[210., 112., 150., 100.]]], dtype=np.float32))
    nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
    nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
    valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
    net = NetCombinedNonMaxSuppression(True, False)
    output = net(boxes, scores, max_output_size_per_class,
                 max_total_size, iou_threshold, score_threshold)
    output_nmsed_boxes = output[0]
    output_nmsed_scores = output[1]
    output_nmsed_classes = output[2]
    output_valid_detections = output[3]
    cmp(output_nmsed_boxes, nmsed_boxes_expect)
    cmp(output_nmsed_scores, nmsed_scores_expect)
    cmp(output_nmsed_classes, nmsed_classes_expect)
    cmp(output_valid_detections, valid_detections_expect)

@ -17,7 +17,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations.random_ops import RandomGammaGrad
from mindspore.ops.operations._grad_ops import RandomGammaGrad


class RandomGammaGradNet(nn.Cell):
