!41886 [assistant][ops] Add GPU Operator CombinedNonMaxSuppression

Merge pull request !41886 from Forever Young/CombinedNonMaxSuppression
This commit is contained in:
i-robot 2022-12-01 12:56:01 +00:00 committed by Gitee
commit e2e0e009e6
11 changed files with 769 additions and 42 deletions

View File

@@ -0,0 +1,261 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/sort.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <vector>
#include <limits>
#include <iostream>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/combined_non_max_suppression_impl.cuh"
constexpr int DIM0 = 0;
constexpr int DIM1 = 1;
constexpr int DIM2 = 2;
constexpr int DIM3 = 3;
constexpr int DIM4 = 4;
constexpr float zero = 0;
constexpr float one = 1;
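// Each box occupies DIM4 (= 4) consecutive floats; after the boxsort kernel below,
// coordinates 0/1 hold the min corner and 2/3 the max corner of the box.
// IOU returns the intersection-over-union |A ∩ B| / (|A| + |B| - |A ∩ B|),
// treating degenerate (zero-area) boxes as non-overlapping.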
__inline__ __device__ float IOU(float *boxes_result, int i, int j) {
float lx, ly, rx, ry;
float w, h;
float area;
float area_a = (boxes_result[i * DIM4 + DIM2] - boxes_result[i * DIM4 + DIM0]) *
(boxes_result[i * DIM4 + DIM3] - boxes_result[i * DIM4 + DIM1]);
float area_b = (boxes_result[j * DIM4 + DIM2] - boxes_result[j * DIM4 + DIM0]) *
(boxes_result[j * DIM4 + DIM3] - boxes_result[j * DIM4 + DIM1]);
if ((area_a == zero) || (area_b == zero)) {
return zero;
}
lx = max(boxes_result[i * DIM4 + DIM0], boxes_result[j * DIM4 + DIM0]);
ly = max(boxes_result[i * DIM4 + DIM1], boxes_result[j * DIM4 + DIM1]);
rx = min(boxes_result[i * DIM4 + DIM2], boxes_result[j * DIM4 + DIM2]);
ry = min(boxes_result[i * DIM4 + DIM3], boxes_result[j * DIM4 + DIM3]);
w = (rx > lx) ? (rx - lx) : zero;
h = (ry > ly) ? (ry - ly) : zero;
area = w * h;
return area / (area_a + area_b - area);
}
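// permute: copies the boxes from their input layout (batch, num_boxes, q, 4)
// into the transposed layout (batch, q, num_boxes, 4) used by the later kernels.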
template <typename T>
__global__ void permute(int q, int num_boxes, int batch_size, T *boxes, float *new_boxes) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < q * num_boxes * batch_size * DIM4;
index += blockDim.x * gridDim.x) {
int i = index % DIM4;
int c = (index / DIM4) % q;
int d = (index / DIM4 / q) % num_boxes;
int n = index / DIM4 / q / num_boxes;
int new_index = ((n * q + c) * num_boxes + d) * DIM4 + i;
float result = boxes[index];
new_boxes[new_index] = result;
}
}
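// boxsort: orders each box's corners so coordinates 0/1 hold the minimum and
// 2/3 the maximum, making the IoU computation independent of corner order.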
__global__ void boxsort(float *boxes_result, float *new_boxes, int batch_size, int q, int num_boxes) {
for (int box_num = blockIdx.x * blockDim.x + threadIdx.x; box_num < batch_size * q * num_boxes;
box_num += blockDim.x * gridDim.x) {
boxes_result[box_num * DIM4 + DIM0] = min(new_boxes[box_num * DIM4 + DIM0], new_boxes[box_num * DIM4 + DIM2]);
boxes_result[box_num * DIM4 + DIM2] = max(new_boxes[box_num * DIM4 + DIM0], new_boxes[box_num * DIM4 + DIM2]);
boxes_result[box_num * DIM4 + DIM1] = min(new_boxes[box_num * DIM4 + DIM1], new_boxes[box_num * DIM4 + DIM3]);
boxes_result[box_num * DIM4 + DIM3] = max(new_boxes[box_num * DIM4 + DIM1], new_boxes[box_num * DIM4 + DIM3]);
}
}
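// presort: transposes the scores from (batch, num_boxes, num_classes) to
// (batch, num_classes, num_boxes), initializes the per-box sort indices, and
// pre-selects the boxes whose score exceeds score_threshold.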
template <typename T>
__global__ void presort(int num_classes, int num_boxes, int batch_size, T *scores, float *new_scores, int *index,
T *score_threshold, bool *sel) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_classes * num_boxes * batch_size;
idx += blockDim.x * gridDim.x) {
int c = idx % num_classes;
int d = (idx / num_classes) % num_boxes;
int n = idx / (num_classes * num_boxes);
int new_index = (n * num_classes + c) * num_boxes + d;
float result = scores[idx];
new_scores[new_index] = result;
index[new_index] = new_index;
sel[new_index] = new_scores[new_index] > score_threshold[DIM0];
}
}
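// Init: marks every entry of each page's (num_boxes x num_boxes) suppression
// mask as "kept" before the nms kernel prunes it.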
__global__ void Init(int num_classes, int num_boxes, int batch_size, bool *mask) {
for (int mat_pos = blockIdx.x * blockDim.x + threadIdx.x; mat_pos < batch_size * num_classes * num_boxes * num_boxes;
mat_pos += blockDim.x * gridDim.x) {
mask[mat_pos] = true;
}
}
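// nms: for each (batch, class) page, compares every pair of score-sorted boxes
// (i, j) with j > i and clears the mask entry when their IoU exceeds
// iou_threshold. When q == num_classes each class has its own boxes, so the
// sorted index addresses boxes_result directly; otherwise the classes of a batch
// share one box set and the index is folded back into the (batch, num_boxes) range.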
template <typename T>
__global__ void nms(int batch_size, int num_classes, T *iou_threshold, bool *sel, float *boxes_result, int *index,
int q, int num_boxes, bool *mask) {
int box_i, box_j;
for (int mask_index = blockIdx.x * blockDim.x + threadIdx.x;
mask_index < batch_size * num_classes * num_boxes * num_boxes; mask_index += blockDim.x * gridDim.x) {
box_i = mask_index / num_boxes;  // row in the per-page selection mask
box_j = mask_index / (num_boxes * num_boxes) * num_boxes + mask_index % num_boxes;  // col in the same page
if (box_j > box_i) {  // entries at or below the diagonal keep their initial true value
int idi = index[box_i];
int idj = index[box_j];
if (q == num_classes) {
if (IOU(boxes_result, idi, idj) > iou_threshold[0]) {
mask[mask_index] = false;
}
} else {
if (IOU(boxes_result, idi / (num_classes * num_boxes) * num_boxes + idi % num_boxes,
idj / (num_classes * num_boxes) * num_boxes + idj % num_boxes) > iou_threshold[0]) {
mask[mask_index] = false;
}
}
}
}
}
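// nmsReducePass: walks the sorted boxes of every page in order and, for each box
// still selected, applies its row of the suppression mask to the boxes after it,
// so suppression only propagates from boxes that themselves survived.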
__global__ void nmsReducePass(int batch_size, int num_classes, bool *sel, int *index, int num_boxes, bool *mask) {
for (int page = DIM0; page < batch_size * num_classes; page++) {
for (int i = DIM0; i < num_boxes - DIM1; ++i) {
int idxi = index[page * num_boxes + i];
if (!sel[idxi]) {
continue;
}
for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < num_boxes; j += blockDim.x * gridDim.x) {
int idxj = index[page * num_boxes + j];
sel[idxj] = sel[idxj] && mask[page * num_boxes * num_boxes + i * num_boxes + j];
}
}
}
}
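// sizeperclass: counts the surviving boxes of each (batch, class) page in score
// order and deselects any beyond max_output_size_per_class.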
__global__ void sizeperclass(int batch_size, int num_classes, bool *sel, int num_boxes, int *index,
int *max_output_size_per_class) {
for (int page = blockIdx.x * blockDim.x + threadIdx.x; page < batch_size * num_classes;
page += blockDim.x * gridDim.x) {
int class_idx_count = DIM0;
for (int i = page * num_boxes; i < (page + DIM1) * num_boxes; i++) {
int number = index[i];
if (sel[number]) {
class_idx_count++;
if (class_idx_count > max_output_size_per_class[DIM0]) {
sel[number] = false;
}
}
}
}
}
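// output: for each batch, gathers up to per_detections surviving boxes in
// descending score order, writes their class ids, scores and (optionally
// clipped to [0, 1]) coordinates, records the valid count in valid_detections,
// and zero-fills the unused trailing slots.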
template <typename T>
__global__ void output(int batch_size, int per_detections, int *index, float *new_scores, bool *sel, float *new_boxes,
T *nmsed_classes, T *nmsed_scores, T *nmsed_boxes, int *valid_detections, bool clip_boxes,
int num_classes, int num_boxes, int q) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; i += gridDim.x * blockDim.x) {
int num = DIM0;
for (int j = i * num_classes * num_boxes; (j < (i + DIM1) * num_classes * num_boxes) && (num < per_detections);
j++) {
int idx = index[j];
float score = new_scores[j];
if (sel[idx]) {
int bboxOffset = i * (q * num_boxes);
int bboxId = (idx % (q * num_boxes) + bboxOffset);
nmsed_classes[i * per_detections + num] = T((idx % (num_classes * num_boxes)) / num_boxes);
nmsed_scores[i * per_detections + num] = T(score);
float xMin = new_boxes[bboxId * DIM4];
float yMin = new_boxes[bboxId * DIM4 + DIM1];
float xMax = new_boxes[bboxId * DIM4 + DIM2];
float yMax = new_boxes[bboxId * DIM4 + DIM3];
nmsed_boxes[(i * per_detections + num) * DIM4] = T(clip_boxes ? max(min(xMin, one), zero) : xMin);
nmsed_boxes[(i * per_detections + num) * DIM4 + DIM1] = T(clip_boxes ? max(min(yMin, one), zero) : yMin);
nmsed_boxes[(i * per_detections + num) * DIM4 + DIM2] = T(clip_boxes ? max(min(xMax, one), zero) : xMax);
nmsed_boxes[(i * per_detections + num) * DIM4 + DIM3] = T(clip_boxes ? max(min(yMax, one), zero) : yMax);
num++;
}
}
valid_detections[i] = num;
while (num < per_detections) {
nmsed_classes[i * per_detections + num] = T(zero);
nmsed_scores[i * per_detections + num] = T(zero);
nmsed_boxes[(i * per_detections + num) * DIM4] = T(zero);
nmsed_boxes[(i * per_detections + num) * DIM4 + DIM1] = T(zero);
nmsed_boxes[(i * per_detections + num) * DIM4 + DIM2] = T(zero);
nmsed_boxes[(i * per_detections + num) * DIM4 + DIM3] = T(zero);
num++;
}
}
}
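// Host-side pipeline: CalSort lays out the boxes, normalizes their corners and
// sorts the scores of every (batch, class) segment; Calnms builds and applies the
// suppression mask, caps the per-class output size and re-sorts survivors per
// batch; Caloutput gathers the final detections.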
template <typename T>
void CalSort(T *scores, int *index, T *score_threshold, int num_classes, T *boxes, float *new_boxes, float *new_scores,
int batch_size, int num_boxes, float *boxes_result, int q, bool *sel, const uint32_t &device_id,
cudaStream_t cuda_stream) {
permute<<<CUDA_BLOCKS(device_id, q * num_boxes * batch_size * DIM4), CUDA_THREADS(device_id), 0, cuda_stream>>>(
q, num_boxes, batch_size, boxes, new_boxes);
boxsort<<<CUDA_BLOCKS(device_id, batch_size * q * num_boxes), CUDA_THREADS(device_id), 0, cuda_stream>>>(
boxes_result, new_boxes, batch_size, q, num_boxes);
presort<<<CUDA_BLOCKS(device_id, num_classes * num_boxes * batch_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
num_classes, num_boxes, batch_size, scores, new_scores, index, score_threshold, sel);
auto policy = thrust::cuda::par.on(cuda_stream);
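// Sort each (batch, class) segment of scores in descending order, carrying the
// per-box indices along so later kernels can refer back to the original boxes.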
for (int i = DIM0; i < num_classes * batch_size; i++) {
thrust::stable_sort_by_key(policy, thrust::device_pointer_cast(new_scores + i * num_boxes),
thrust::device_pointer_cast(new_scores + i * num_boxes) + num_boxes,
thrust::device_pointer_cast(index + i * num_boxes), thrust::greater<float>());
}
}
template <typename T>
void Calnms(int batch_size, int num_classes, T *iou_threshold, bool *sel, float *boxes_result, int *index, int q,
int num_boxes, int *max_output_size_per_class, float *new_scores, bool *mask, const uint32_t &device_id,
cudaStream_t cuda_stream) {
Init<<<CUDA_BLOCKS(device_id, batch_size * num_classes * num_boxes * num_boxes), CUDA_THREADS(device_id), 0,
cuda_stream>>>(num_classes, num_boxes, batch_size, mask);
nms<<<CUDA_BLOCKS(device_id, batch_size * num_classes * num_boxes * num_boxes), CUDA_THREADS(device_id), 0,
cuda_stream>>>(batch_size, num_classes, iou_threshold, sel, boxes_result, index, q, num_boxes, mask);
nmsReducePass<<<CUDA_BLOCKS(device_id, num_boxes), CUDA_THREADS(device_id), 0, cuda_stream>>>(
batch_size, num_classes, sel, index, num_boxes, mask);
sizeperclass<<<CUDA_BLOCKS(device_id, batch_size * num_classes), CUDA_THREADS(device_id), 0, cuda_stream>>>(
batch_size, num_classes, sel, num_boxes, index, max_output_size_per_class);
auto policy = thrust::cuda::par.on(cuda_stream);
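// Re-sort the surviving scores across all classes of each batch so the output
// kernel can take the top per_detections entries in order.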
for (int i = DIM0; i < batch_size; i++) {
thrust::stable_sort_by_key(
policy, thrust::device_pointer_cast(new_scores + i * num_boxes * num_classes),
thrust::device_pointer_cast(new_scores + i * num_boxes * num_classes) + (num_boxes * num_classes),
thrust::device_pointer_cast(index + i * num_boxes * num_classes), thrust::greater<float>());
}
}
template <typename T>
void Caloutput(int batch_size, int per_detections, int *index, float *new_scores, bool *sel, float *new_boxes,
T *nmsed_classes, T *nmsed_scores, T *nmsed_boxes, int *valid_detections, bool clip_boxes,
int num_classes, int num_boxes, int q, const uint32_t &device_id, cudaStream_t cuda_stream) {
output<<<CUDA_BLOCKS(device_id, batch_size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
batch_size, per_detections, index, new_scores, sel, new_boxes, nmsed_classes, nmsed_scores, nmsed_boxes,
valid_detections, clip_boxes, num_classes, num_boxes, q);
}
template CUDA_LIB_EXPORT void CalSort<float>(float *scores, int *index, float *score_threshold, int num_classes,
float *boxes, float *new_boxes, float *new_scores, int batch_size,
int num_boxes, float *boxes_result, int q, bool *sel,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Calnms<float>(int batch_size, int num_classes, float *iou_threshold, bool *sel,
float *boxes_result, int *index, int q, int num_boxes,
int *max_output_size_per_class, float *new_scores, bool *mask,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Caloutput<float>(int batch_size, int per_detections, int *index, float *new_scores,
bool *sel, float *new_boxes, float *nmsed_classes, float *nmsed_scores,
float *nmsed_boxes, int *valid_detections, bool clip_boxes,
int num_classes, int num_boxes, int q, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@@ -0,0 +1,35 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMBINED_NON_MAX_SUPPRESSION_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMBINED_NON_MAX_SUPPRESSION_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT void CalSort(T *scores, int *index, T *score_threshold, int num_classes, T *boxes, float *new_boxes,
float *new_scores, int batch_size, int num_boxes, float *boxes_result, int q, bool *sel,
const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Calnms(int batch_size, int num_classes, T *iou_threshold, bool *sel, float *boxes_result,
int *index, int q, int num_boxes, int *max_output_size_per_class, float *new_scores,
bool *mask, const uint32_t &device_id, cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void Caloutput(int batch_size, int per_detections, int *index, float *new_scores, bool *sel,
float *new_boxes, T *nmsed_classes, T *nmsed_scores, T *nmsed_boxes,
int *valid_detections, bool clip_boxes, int num_classes, int num_boxes, int q,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMBINED_NON_MAX_SUPPRESSION_IMPL_CUH_

View File

@@ -0,0 +1,167 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/combined_non_max_suppression_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/combined_non_max_suppression_impl.cuh"
#include <utility>
#include <algorithm>
#include <iostream>
namespace mindspore {
namespace kernel {
constexpr int DimSize4 = 4;
void CombinedNonMaxSuppressionGpuKernelMod::ResetResource() noexcept {
cuda_stream_ = nullptr;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void CombinedNonMaxSuppressionGpuKernelMod::InitSizeLists() {
input_size_list_.push_back(batch_size_ * num_boxes_ * q_ * DimSize4 * sizeof(T));
input_size_list_.push_back(batch_size_ * num_boxes_ * num_classes_ * sizeof(T));
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(int));
input_size_list_.push_back(sizeof(T));
input_size_list_.push_back(sizeof(T));
output_size_list_.push_back(batch_size_ * per_detections_ * DimSize4 * sizeof(T));
output_size_list_.push_back(batch_size_ * per_detections_ * sizeof(T));
output_size_list_.push_back(batch_size_ * per_detections_ * sizeof(T));
output_size_list_.push_back(batch_size_ * sizeof(int));
workspace_size_list_.push_back(q_ * batch_size_ * num_boxes_ * DimSize4 * sizeof(float)); // new_boxes
workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * sizeof(float)); // new_scores
workspace_size_list_.push_back(batch_size_ * q_ * num_boxes_ * DimSize4 * sizeof(float)); // boxes_result
workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * sizeof(int)); // index
workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * sizeof(bool)); // sel
workspace_size_list_.push_back(batch_size_ * num_classes_ * num_boxes_ * num_boxes_ * sizeof(bool));  // mask
}
bool CombinedNonMaxSuppressionGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::CombinedNonMaxSuppression>(base_operator);
pad_per_class_ = kernel_ptr->get_pad_per_class();
clip_boxes_ = kernel_ptr->get_clip_boxes();
kernel_name_ = kernel_ptr->name();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in "
<< "[float32, int32], but got: " << kernel_attr;
return false;
}
is_need_retrieve_output_shape_ = true;
kernel_func_ = func_list_[index].second;
return true;
}
int CombinedNonMaxSuppressionGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
outputs_ = outputs;
std::vector<size_t> input0_shape = std::vector<size_t>(inputs[kIndex0]->GetDeviceShapeAdaptively().begin(),
inputs[kIndex0]->GetDeviceShapeAdaptively().end());
std::vector<size_t> input1_shape = std::vector<size_t>(inputs[kIndex1]->GetDeviceShapeAdaptively().begin(),
inputs[kIndex1]->GetDeviceShapeAdaptively().end());
batch_size_ = static_cast<int>(input0_shape[kIndex0]);
num_boxes_ = static_cast<int>(input0_shape[kIndex1]);
q_ = static_cast<int>(input0_shape[kIndex2]);
num_classes_ = static_cast<int>(input1_shape[kIndex2]);
auto prim = base_operator->GetPrim();
if (prim->GetAttr("per_detections")) {
per_detections_ = GetValue<int>(prim->GetAttr("per_detections"));
}
InitSizeLists();
return KRET_OK;
}
template <typename T>
bool CombinedNonMaxSuppressionGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
cuda_stream_ = reinterpret_cast<cudaStream_t>(stream_ptr);
T *boxes = GetDeviceAddress<T>(inputs, kIndex0);
T *scores = GetDeviceAddress<T>(inputs, kIndex1);
int *max_output_size_per_class = GetDeviceAddress<int>(inputs, kIndex2);
T *iou_threshold = GetDeviceAddress<T>(inputs, kIndex4);
T *score_threshold = GetDeviceAddress<T>(inputs, kIndex5);
T *nmsed_boxes = GetDeviceAddress<T>(outputs, kIndex0);
T *nmsed_scores = GetDeviceAddress<T>(outputs, kIndex1);
T *nmsed_classes = GetDeviceAddress<T>(outputs, kIndex2);
int *valid_detections = GetDeviceAddress<int>(outputs, kIndex3);
float *new_boxes = GetDeviceAddress<float>(workspace, kIndex0);
float *new_scores = GetDeviceAddress<float>(workspace, kIndex1);
float *boxes_result = GetDeviceAddress<float>(workspace, kIndex2);
int *index = GetDeviceAddress<int>(workspace, kIndex3);
bool *sel = GetDeviceAddress<bool>(workspace, kIndex4);
bool *mask = GetDeviceAddress<bool>(workspace, kIndex5);
CalSort(scores, index, score_threshold, num_classes_, boxes, new_boxes, new_scores, batch_size_, num_boxes_,
boxes_result, q_, sel, device_id_, cuda_stream_);
Calnms(batch_size_, num_classes_, iou_threshold, sel, boxes_result, index, q_, num_boxes_, max_output_size_per_class,
new_scores, mask, device_id_, cuda_stream_);
Caloutput(batch_size_, per_detections_, index, new_scores, sel, new_boxes, nmsed_classes, nmsed_scores, nmsed_boxes,
valid_detections, clip_boxes_, num_classes_, num_boxes_, q_, device_id_, cuda_stream_);
return true;
}
void CombinedNonMaxSuppressionGpuKernelMod::SyncData() {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(cuda_stream_), "cudaStreamSynchronize failed");
std::vector<int64_t> shape0 = {batch_size_, per_detections_, DimSize4};
std::vector<int64_t> shape1 = {batch_size_, per_detections_};
std::vector<int64_t> shape2 = {batch_size_, per_detections_};
std::vector<int64_t> shape3 = {batch_size_};
outputs_[kIndex0]->SetShapeVector(shape0);
outputs_[kIndex1]->SetShapeVector(shape1);
outputs_[kIndex2]->SetShapeVector(shape2);
outputs_[kIndex3]->SetShapeVector(shape3);
}
std::vector<std::pair<KernelAttr, CombinedNonMaxSuppressionGpuKernelMod::CombinedNonMaxSuppressionLaunchFunc>>
CombinedNonMaxSuppressionGpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32),
&CombinedNonMaxSuppressionGpuKernelMod::LaunchKernel<float>},
};
std::vector<KernelAttr> CombinedNonMaxSuppressionGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, CombinedNonMaxSuppressionLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CombinedNonMaxSuppression, CombinedNonMaxSuppressionGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_COMBINED_NON_MAX_SUPPRESSION_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_COMBINED_NON_MAX_SUPPRESSION_GPU_KERNEL_H_
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <utility>
#include "mindspore/core/ops/combined_non_max_suppression.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class CombinedNonMaxSuppressionGpuKernelMod : public NativeGpuKernelMod {
public:
CombinedNonMaxSuppressionGpuKernelMod() { ResetResource(); }
~CombinedNonMaxSuppressionGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void SyncData() override;
std::vector<KernelAttr> GetOpSupport() override;
std::vector<KernelTensorPtr> GetOutputs() override { return outputs_; }
private:
void ResetResource() noexcept;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using CombinedNonMaxSuppressionLaunchFunc =
std::function<bool(CombinedNonMaxSuppressionGpuKernelMod *, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
static std::vector<std::pair<KernelAttr, CombinedNonMaxSuppressionLaunchFunc>> func_list_;
CombinedNonMaxSuppressionLaunchFunc kernel_func_;
cudaStream_t cuda_stream_;
void InitSizeLists();
size_t T;  // stand-in referenced as sizeof(T) in InitSizeLists(); sizeof(T) here is sizeof(size_t)
int batch_size_;
int num_boxes_;
int q_;
int num_classes_;
int per_detections_;
bool pad_per_class_;
bool clip_boxes_;
std::vector<KernelTensorPtr> outputs_{};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_COMBINED_NON_MAX_SUPPRESSION_GPU_KERNEL_H_

View File

@@ -107,7 +107,7 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionGetOutputShape(const PrimitiveP
if (pad_per_class) {
num_detection = std::min(max_total_size, max_output_size_per_class * static_cast<int32_t>(input1_shape[ksecond]));
}
(void)primitive->AddAttr("per_detections", MakeValue(num_detection));
int64_t bs = input0_shape[0];
ShapeVector shape1 = {bs, num_detection, 4};
ShapeVector shape2 = {bs, num_detection};
@@ -141,7 +141,6 @@ abstract::TupleShapePtr CombinedNonMaxSuppressionInferShape(const PrimitivePtr &
input3_shape, input4_shape, input5_shape};
auto is_dynamic = (IsDynamic(input0_shape) || IsDynamic(input1_shape));
auto is_dynamic_rank = std::any_of(all_shapes.begin(), all_shapes.end(), IsDynamicRank);
CombinedNonMaxSuppressionCheckShapeSize(input0_shape, input1_shape, input2_shape, input3_shape, input4_shape,
input5_shape, is_dynamic_rank, prim_name);
@@ -204,6 +203,16 @@ AbstractBasePtr CombinedNonMaxSuppressionInfer(const abstract::AnalysisEnginePtr
return abstract::MakeAbstract(infer_shape, infer_type);
}
bool CombinedNonMaxSuppression::get_pad_per_class() const {
auto value_ptr = this->GetAttr("pad_per_class");
return GetValue<bool>(value_ptr);
}
bool CombinedNonMaxSuppression::get_clip_boxes() const {
auto value_ptr = this->GetAttr("clip_boxes");
return GetValue<bool>(value_ptr);
}
REGISTER_PRIMITIVE_EVAL_IMPL(CombinedNonMaxSuppression, prim::kPrimCombinedNonMaxSuppression,
CombinedNonMaxSuppressionInfer, nullptr, true);
REGISTER_HOST_DEPENDS(kNameCombinedNonMaxSuppression, {2, 3, 4, 5});

View File

@@ -36,6 +36,8 @@ class MIND_API CombinedNonMaxSuppression : public BaseOperator {
InitIOName({"boxes", "scores", "max_output_size_per_class", "max_total_size", "iou_threshold", "score_threshold"},
{"nmsed_box", "nmsed_scores", "nmsed_classes", "valid_detections"});
}
bool get_pad_per_class() const;
bool get_clip_boxes() const;
};
abstract::AbstractBasePtr CombinedNonMaxSuppressionInfer(const abstract::AnalysisEnginePtr &,
const PrimitivePtr &primitive,

View File

@@ -1805,6 +1805,44 @@ class PReLUGrad(Primitive):
pass
class RandomGammaGrad(Primitive):
r"""
Computes the derivative of a random sample of Gamma with respect to alpha.
Inputs:
- **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
It must be greater than 0. Must be one of the following types: float32, float64.
- **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
following types: float32, float64.
Outputs:
The dtype is the same as `alpha`.
The output shape is derived from the inputs through broadcasting.
Raises:
TypeError: If the data type of `alpha` and `sample` is not float32 or float64.
TypeError: If the data types of `alpha` and `sample` are not the same.
ValueError: If the last dimension of `sample` and `alpha` is not equal.
Supported Platforms:
``GPU``
Examples:
>>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
>>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
>>> randomgammagrad = ops.RandomGammaGrad()
>>> output = randomgammagrad(alpha, sample)
>>> print(output)
[2.5142431 3.4334087 1.8847835 0.07780622]
"""
@prim_attr_register
def __init__(self):
"""Initialize RandomGammaGrad"""
self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
self.add_prim_attr("side_effect_hidden", True)
class ReluGrad(Primitive):
"""Performs grad of Relu operation."""

View File

@@ -1121,7 +1121,7 @@ class CombinedNonMaxSuppression(Primitive):
ValueError: If `iou_threshold` not in [0,1].
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> boxes = Tensor(np.array([[[[200, 100, 150, 100]],

View File

@@ -312,44 +312,6 @@ class LogNormalReverse(Primitive):
Validator.check_value_type("std", std, [float], self.name)
class RandomGammaGrad(Primitive):
r"""
Computes the derivative of a random sample of Gamma with respect to alpha.:
Inputs:
- **alpha** (Tensor) - α is the shape parameter of RandomGamma distribution.
It must be greater than 0. Must be one of the following types: float32, float64.
- **sample** (Tensor) - The sample of random gamma tensor. Must be one of the
following types: float32, float64.
Outputs:
The dtype is the same type as alpha.
The output shape is derived from the input through broadcasting.
Raises:
TypeError: If data type of `alpha` and `sample` is not float32 or float64.
TypeError: If data type of `alpha` and `sample` is not same.
ValueError: If the shape last dim of `sample` and `alpha` is not equal.
Supported Platforms:
``GPU``
Examples:
>>> alpha = Tensor(np.array([1., 0.6, 3., 26.]), mstype.float32)
>>> sample = Tensor(np.array([6., 7, 11., 0.5]), mstype.float32)
>>> randomgammagrad = ops.RandomGammaGrad()
>>> output = randomgammagrad(alpha, sample)
>>> print(output)
[2.5142431 3.4334087 1.8847835 0.07780622]
"""
@prim_attr_register
def __init__(self):
"""Initialize RandomGammaGrad"""
self.init_prim_io_names(inputs=['alpha', 'sample'], outputs=['output'])
self.add_prim_attr("side_effect_hidden", True)
class Gamma(PrimitiveWithInfer):
r"""
Produces random positive floating-point values x, distributed according to probability density function:

View File

@@ -0,0 +1,181 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops.operations.image_ops import CombinedNonMaxSuppression
class NetCombinedNonMaxSuppression(nn.Cell):
def __init__(self, pad_per_class, clip_boxes):
super(NetCombinedNonMaxSuppression, self).__init__()
self.combined_non_max_suppression = CombinedNonMaxSuppression(
pad_per_class, clip_boxes)
def construct(self, boxes, scores, max_output_size_per_class, max_total_size, iou_threshold, score_threshold):
return self.combined_non_max_suppression(boxes, scores, max_output_size_per_class, max_total_size,
iou_threshold, score_threshold)
def cmp(output, expect):
"""Compare an output tensor with the expected tensor by value, shape and dtype."""
output_type = output.asnumpy().dtype
expect_type = expect.asnumpy().dtype
diff0 = output.asnumpy() - expect.asnumpy()
error0 = np.zeros(shape=expect.shape)
assert np.all(diff0 == error0)
assert output.shape == expect.shape
assert output_type == expect_type
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_combined_non_max_suppression1():
"""
Feature: Combined non max suppression.
Description: test CombinedNonMaxSuppression in graph mode with pad_per_class=False and clip_boxes=True.
Expectation: The results are as expected.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
[[210, 112, 150, 100]]]], dtype=np.float32))
scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
[0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
max_output_size_per_class = Tensor(4, dtype=mstype.int32)
max_total_size = Tensor(1, dtype=mstype.int32)
iou_threshold = Tensor(0, dtype=mstype.float32)
score_threshold = Tensor(0, dtype=mstype.float32)
nmsed_boxes_expect = Tensor(
np.array([[[1., 1., 1., 1.]]], dtype=np.float32))
nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
net = NetCombinedNonMaxSuppression(False, True)
output = net(boxes, scores, max_output_size_per_class,
max_total_size, iou_threshold, score_threshold)
output_nmsed_boxes = output[0]
output_nmsed_scores = output[1]
output_nmsed_classes = output[2]
output_valid_detections = output[3]
cmp(output_nmsed_boxes, nmsed_boxes_expect)
cmp(output_nmsed_scores, nmsed_scores_expect)
cmp(output_nmsed_classes, nmsed_classes_expect)
cmp(output_valid_detections, valid_detections_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_combined_non_max_suppression2():
"""
Feature: Combined non max suppression.
Description: test CombinedNonMaxSuppression in PyNative mode with pad_per_class=False and clip_boxes=True.
Expectation: The results are as expected.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
[[210, 112, 150, 100]]]], dtype=np.float32))
scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
[0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
max_output_size_per_class = Tensor(4, dtype=mstype.int32)
max_total_size = Tensor(1, dtype=mstype.int32)
iou_threshold = Tensor(0, dtype=mstype.float32)
score_threshold = Tensor(0, dtype=mstype.float32)
nmsed_boxes_expect = Tensor(
np.array([[[1., 1., 1., 1.]]], dtype=np.float32))
nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
net = NetCombinedNonMaxSuppression(False, True)
output = net(boxes, scores, max_output_size_per_class,
max_total_size, iou_threshold, score_threshold)
output_nmsed_boxes = output[0]
output_nmsed_scores = output[1]
output_nmsed_classes = output[2]
output_valid_detections = output[3]
cmp(output_nmsed_boxes, nmsed_boxes_expect)
cmp(output_nmsed_scores, nmsed_scores_expect)
cmp(output_nmsed_classes, nmsed_classes_expect)
cmp(output_valid_detections, valid_detections_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_combined_non_max_suppression3():
"""
Feature: Combined non max suppression.
Description: test CombinedNonMaxSuppression in graph mode with pad_per_class=True and clip_boxes=False.
Expectation: The results are as expected.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
[[210, 112, 150, 100]]]], dtype=np.float32))
scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
[0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
max_output_size_per_class = Tensor(4, dtype=mstype.int32)
max_total_size = Tensor(1, dtype=mstype.int32)
iou_threshold = Tensor(0.2, dtype=mstype.float32)
score_threshold = Tensor(0.2, dtype=mstype.float32)
nmsed_boxes_expect = Tensor(
np.array([[[210., 112., 150., 100.]]], dtype=np.float32))
nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
net = NetCombinedNonMaxSuppression(True, False)
output = net(boxes, scores, max_output_size_per_class,
max_total_size, iou_threshold, score_threshold)
output_nmsed_boxes = output[0]
output_nmsed_scores = output[1]
output_nmsed_classes = output[2]
output_valid_detections = output[3]
cmp(output_nmsed_boxes, nmsed_boxes_expect)
cmp(output_nmsed_scores, nmsed_scores_expect)
cmp(output_nmsed_classes, nmsed_classes_expect)
cmp(output_valid_detections, valid_detections_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_combined_non_max_suppression4():
"""
Feature: Combined non max suppression.
Description: test CombinedNonMaxSuppression in PyNative mode with pad_per_class=True and clip_boxes=False.
Expectation: The results are as expected.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
boxes = Tensor(np.array([[[[200, 100, 150, 100]], [[220, 120, 150, 100]], [[190, 110, 150, 100]],
[[210, 112, 150, 100]]]], dtype=np.float32))
scores = Tensor(np.array([[[0.2000, 0.7000, 0.1000], [0.1000, 0.8000, 0.1000],
[0.3000, 0.6000, 0.1000], [0.0500, 0.9000, 0.0500]]], dtype=np.float32))
max_output_size_per_class = Tensor(4, dtype=mstype.int32)
max_total_size = Tensor(1, dtype=mstype.int32)
iou_threshold = Tensor(0, dtype=mstype.float32)
score_threshold = Tensor(0, dtype=mstype.float32)
nmsed_boxes_expect = Tensor(
np.array([[[210., 112., 150., 100.]]], dtype=np.float32))
nmsed_scores_expect = Tensor(np.array([[0.9]], dtype=np.float32))
nmsed_classes_expect = Tensor(np.array([[1.]], dtype=np.float32))
valid_detections_expect = Tensor(np.array([1], dtype=np.int32))
net = NetCombinedNonMaxSuppression(True, False)
output = net(boxes, scores, max_output_size_per_class,
max_total_size, iou_threshold, score_threshold)
output_nmsed_boxes = output[0]
output_nmsed_scores = output[1]
output_nmsed_classes = output[2]
output_valid_detections = output[3]
cmp(output_nmsed_boxes, nmsed_boxes_expect)
cmp(output_nmsed_scores, nmsed_scores_expect)
cmp(output_nmsed_classes, nmsed_classes_expect)
cmp(output_valid_detections, valid_detections_expect)

View File

@@ -17,7 +17,7 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations.random_ops import RandomGammaGrad
from mindspore.ops.operations._grad_ops import RandomGammaGrad
class RandomGammaGradNet(nn.Cell):