!7967 optimize detection_post_process op

Merge pull request !7967 from wangzhe/dpp_refactor
This commit is contained in:
mindspore-ci-bot 2020-10-30 09:36:31 +08:00 committed by Gitee
commit 691b5fdea0
8 changed files with 357 additions and 261 deletions

View File

@ -38,9 +38,11 @@ typedef struct DetectionPostProcessParameter {
void *decoded_boxes_;
void *nms_candidate_;
void *indexes_;
void *scores_;
void *all_class_indexes_;
void *all_class_scores_;
void *single_class_indexes_;
void *selected_;
void *score_with_class_;
void *score_with_class_all_;
} DetectionPostProcessParameter;
#endif // MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_

View File

@ -19,84 +19,6 @@
#include "nnacl/errorcode.h"
#include "nnacl/op_base.h"
bool ScoreWithIndexCmp(ScoreWithIndex *pa, ScoreWithIndex *pb) {
if (pa->score > pb->score) {
return true;
} else if (pa->score < pb->score) {
return false;
} else {
return pa->index < pb->index;
}
}
void PushHeap(ScoreWithIndex *root, int cur, int top_index, ScoreWithIndex value) {
int parent = (cur - 1) / 2;
while (cur > top_index && ScoreWithIndexCmp(root + parent, &value)) {
*(root + cur) = root[parent];
cur = parent;
parent = (cur - 1) / 2;
}
*(root + cur) = value;
}
void AdjustHeap(ScoreWithIndex *root, int cur, int limit, ScoreWithIndex value) {
int top_index = cur;
int second_child = cur;
while (second_child < (limit - 1) / 2) {
second_child = 2 * (second_child + 1);
if (ScoreWithIndexCmp(root + second_child, root + second_child - 1)) {
second_child--;
}
*(root + cur) = *(root + second_child);
cur = second_child;
}
if ((limit & 1) == 0 && second_child == (limit - 2) / 2) {
second_child = 2 * (second_child + 1);
*(root + cur) = *(root + second_child - 1);
cur = second_child - 1;
}
PushHeap(root, cur, top_index, value);
}
void PopHeap(ScoreWithIndex *root, int limit, ScoreWithIndex *result) {
ScoreWithIndex value = *result;
*result = *root;
AdjustHeap(root, 0, limit, value);
}
void MakeHeap(ScoreWithIndex *values, int limit) {
if (limit < 2) return;
int parent = (limit - 2) / 2;
while (true) {
AdjustHeap(values, parent, limit, values[parent]);
if (parent == 0) {
return;
}
parent--;
}
}
void SortHeap(ScoreWithIndex *root, int limit) {
while (limit > 1) {
--limit;
PopHeap(root, limit, root + limit);
}
}
void HeapSelect(ScoreWithIndex *root, int cur, int limit) {
MakeHeap(root, cur);
for (int i = cur; i < limit; ++i) {
if (ScoreWithIndexCmp(root + i, root)) {
PopHeap(root, cur, root + i);
}
}
}
void PartialSort(ScoreWithIndex *values, int num_to_sort, int num_values) {
HeapSelect(values, num_to_sort, num_values);
SortHeap(values, num_to_sort);
}
float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin);
const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin);
@ -113,8 +35,17 @@ float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
return inter / (area_a + area_b - inter);
}
void DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors, const BboxCenter scaler,
float *decoded_boxes) {
int DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors,
DetectionPostProcessParameter *param) {
if (input_boxes == NULL || anchors == NULL || param == NULL) {
return NNACL_NULL_PTR;
}
float *decoded_boxes = (float *)param->decoded_boxes_;
BboxCenter scaler;
scaler.y = param->y_scale_;
scaler.x = param->x_scale_;
scaler.h = param->h_scale_;
scaler.w = param->w_scale_;
for (int i = 0; i < num_boxes; ++i) {
BboxCenter *box = (BboxCenter *)(input_boxes) + i;
BboxCenter *anchor = (BboxCenter *)(anchors) + i;
@ -128,35 +59,40 @@ void DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anc
decoded_box->ymax = y_center + h_half;
decoded_box->xmax = x_center + w_half;
}
return NNACL_OK;
}
int NmsSingleClass(const int candidate_num, const float *decoded_boxes, const int max_detections,
ScoreWithIndex *score_with_index, int *selected, const DetectionPostProcessParameter *param) {
int NmsSingleClass(const int num_boxes, const float *decoded_boxes, const int max_detections, const float *scores,
int *selected, void (*PartialArgSort)(const float *, int *, int, int),
const DetectionPostProcessParameter *param) {
uint8_t *nms_candidate = param->nms_candidate_;
const int output_num = candidate_num < max_detections ? candidate_num : max_detections;
int possible_candidate_num = candidate_num;
const int output_num = num_boxes < max_detections ? num_boxes : max_detections;
int possible_candidate_num = num_boxes;
int selected_num = 0;
PartialSort(score_with_index, candidate_num, candidate_num);
for (int i = 0; i < candidate_num; ++i) {
int *indexes = (int *)param->single_class_indexes_;
for (int i = 0; i < num_boxes; ++i) {
indexes[i] = i;
nms_candidate[i] = 1;
}
for (int i = 0; i < candidate_num; ++i) {
if (possible_candidate_num == 0 || selected_num >= output_num) {
PartialArgSort(scores, indexes, num_boxes, num_boxes);
for (int i = 0; i < num_boxes; ++i) {
if (possible_candidate_num == 0 || selected_num >= output_num || scores[indexes[i]] < param->nms_score_threshold_) {
break;
}
if (nms_candidate[i] == 0) {
if (nms_candidate[indexes[i]] == 0) {
continue;
}
selected[selected_num++] = score_with_index[i].index;
nms_candidate[i] = 0;
selected[selected_num++] = indexes[i];
nms_candidate[indexes[i]] = 0;
possible_candidate_num--;
for (int t = i + 1; t < candidate_num; ++t) {
if (nms_candidate[t] == 1) {
const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + score_with_index[i].index;
const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + score_with_index[t].index;
const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + indexes[i];
for (int t = i + 1; t < num_boxes; ++t) {
if (scores[indexes[t]] < param->nms_score_threshold_) break;
if (nms_candidate[indexes[t]] == 1) {
const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + indexes[t];
const float iou = IntersectionOverUnion(bbox_i, bbox_t);
if (iou > param->nms_iou_threshold_) {
nms_candidate[t] = 0;
nms_candidate[indexes[t]] = 0;
possible_candidate_num--;
}
}
@ -165,54 +101,117 @@ int NmsSingleClass(const int candidate_num, const float *decoded_boxes, const in
return selected_num;
}
int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
const DetectionPostProcessParameter *param) {
int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
void (*PartialArgSort)(const float *, int *, int, int),
const DetectionPostProcessParameter *param, const int task_id, const int thread_num) {
if (input_scores == NULL || param == NULL || PartialArgSort == NULL) {
return NNACL_NULL_PTR;
}
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
int *selected = (int *)(param->selected_);
ScoreWithIndex *score_with_index_single = (ScoreWithIndex *)(param->score_with_class_);
const int64_t max_classes_per_anchor =
param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_;
float *scores = (float *)param->scores_;
for (int i = task_id; i < num_boxes; i += thread_num) {
int *indexes = (int *)param->indexes_ + i * param->num_classes_;
for (int j = 0; j < param->num_classes_; ++j) {
indexes[j] = i * num_classes_with_bg + first_class_index + j;
}
PartialArgSort(input_scores, indexes, max_classes_per_anchor, param->num_classes_);
scores[i] = input_scores[indexes[0]];
}
return NNACL_OK;
}
int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
const float *decoded_boxes, float *output_boxes, float *output_classes,
float *output_scores, float *output_num,
void (*PartialArgSort)(const float *, int *, int, int),
const DetectionPostProcessParameter *param) {
if (input_scores == NULL || decoded_boxes == NULL || output_boxes == NULL || output_classes == NULL ||
output_scores == NULL || output_num == NULL || param == NULL || PartialArgSort == NULL) {
return NNACL_NULL_PTR;
}
int out_num = 0;
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
const int64_t max_classes_per_anchor =
param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_;
int *selected = (int *)param->selected_;
int selected_num = NmsSingleClass(num_boxes, decoded_boxes, param->max_detections_, (float *)param->scores_, selected,
PartialArgSort, param);
for (int i = 0; i < selected_num; ++i) {
int *indexes = (int *)param->indexes_ + selected[i] * param->num_classes_;
BboxCorner *box = (BboxCorner *)(decoded_boxes) + selected[i];
for (int j = 0; j < max_classes_per_anchor; ++j) {
*((BboxCorner *)(output_boxes) + out_num) = *box;
output_scores[out_num] = input_scores[indexes[j]];
output_classes[out_num++] = (float)(indexes[j] % num_classes_with_bg - first_class_index);
}
}
*output_num = (float)out_num;
for (int i = out_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
((BboxCorner *)(output_boxes) + i)->ymin = 0;
((BboxCorner *)(output_boxes) + i)->xmin = 0;
((BboxCorner *)(output_boxes) + i)->ymax = 0;
((BboxCorner *)(output_boxes) + i)->xmax = 0;
output_scores[i] = 0;
output_classes[i] = 0;
}
return NNACL_OK;
}
int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
float *output_boxes, float *output_classes, float *output_scores, float *output_num,
void (*PartialArgSort)(const float *, int *, int, int),
const DetectionPostProcessParameter *param) {
if (input_scores == NULL || output_boxes == NULL || output_classes == NULL || output_scores == NULL ||
output_num == NULL || param == NULL || PartialArgSort == NULL) {
return NNACL_NULL_PTR;
}
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
float *decoded_boxes = (float *)param->decoded_boxes_;
int *selected = (int *)param->selected_;
float *scores = (float *)param->scores_;
float *all_scores = (float *)param->all_class_scores_;
int *indexes = (int *)(param->indexes_);
int *all_indexes = (int *)(param->all_class_indexes_);
int all_classes_sorted_num = 0;
int all_classes_output_num = 0;
ScoreWithIndex *score_with_index_all = (ScoreWithIndex *)(param->score_with_class_all_);
int *indexes = (int *)(param->indexes_);
for (int j = first_class_index; j < num_classes_with_bg; ++j) {
int candidate_num = 0;
// process single class
for (int i = 0; i < num_boxes; ++i) {
const float score = input_scores[i * num_classes_with_bg + j];
if (score >= param->nms_score_threshold_) {
score_with_index_single[candidate_num].score = score;
score_with_index_single[candidate_num++].index = i;
}
scores[i] = input_scores[i * num_classes_with_bg + j];
}
int selected_num = NmsSingleClass(candidate_num, decoded_boxes, param->detections_per_class_,
score_with_index_single, selected, param);
int selected_num =
NmsSingleClass(num_boxes, decoded_boxes, param->detections_per_class_, scores, selected, PartialArgSort, param);
for (int i = 0; i < all_classes_sorted_num; ++i) {
indexes[i] = score_with_index_all[i].index;
score_with_index_all[i].index = i;
indexes[i] = all_indexes[i];
all_indexes[i] = i;
}
// process all classes
for (int i = 0; i < selected_num; ++i) {
// store class to index
indexes[all_classes_sorted_num] = selected[i] * num_classes_with_bg + j;
score_with_index_all[all_classes_sorted_num].index = all_classes_sorted_num;
score_with_index_all[all_classes_sorted_num++].score = input_scores[selected[i] * num_classes_with_bg + j];
all_indexes[all_classes_sorted_num] = all_classes_sorted_num;
all_scores[all_classes_sorted_num++] = scores[selected[i]];
}
all_classes_output_num =
all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_;
PartialSort(score_with_index_all, all_classes_output_num, all_classes_sorted_num);
PartialArgSort(all_scores, all_indexes, all_classes_output_num, all_classes_sorted_num);
for (int i = 0; i < all_classes_output_num; ++i) {
score_with_index_all[i].index = indexes[score_with_index_all[i].index];
scores[i] = all_scores[all_indexes[i]];
all_indexes[i] = indexes[all_indexes[i]];
}
for (int i = 0; i < all_classes_output_num; ++i) {
all_scores[i] = scores[i];
}
all_classes_sorted_num = all_classes_output_num;
}
for (int i = 0; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
if (i < all_classes_output_num) {
const int box_index = score_with_index_all[i].index / num_classes_with_bg;
const int class_index = score_with_index_all[i].index - box_index * num_classes_with_bg - first_class_index;
const int box_index = all_indexes[i] / num_classes_with_bg;
const int class_index = all_indexes[i] % num_classes_with_bg - first_class_index;
*((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index);
output_classes[i] = (float)class_index;
output_scores[i] = score_with_index_all[i].score;
output_scores[i] = all_scores[i];
} else {
((BboxCorner *)(output_boxes) + i)->ymin = 0;
((BboxCorner *)(output_boxes) + i)->xmin = 0;
@ -222,73 +221,6 @@ int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, c
output_scores[i] = 0.0f;
}
}
return all_classes_output_num;
}
int NmsMultiClassesFast(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
const DetectionPostProcessParameter *param) {
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
const int64_t max_classes_per_anchor =
param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_;
int candidate_num = 0;
ScoreWithIndex *score_with_class_all = (ScoreWithIndex *)(param->score_with_class_all_);
ScoreWithIndex *score_with_class = (ScoreWithIndex *)(param->score_with_class_);
int *selected = (int *)(param->selected_);
int selected_num;
int output_num = 0;
for (int i = 0; i < num_boxes; ++i) {
for (int j = first_class_index; j < num_classes_with_bg; ++j) {
float score_t = *(input_scores + i * num_classes_with_bg + j);
score_with_class_all[i * param->num_classes_ + j - first_class_index].score = score_t;
// save box and class info to index
score_with_class_all[i * param->num_classes_ + j - first_class_index].index = i * num_classes_with_bg + j;
}
PartialSort(score_with_class_all + i * param->num_classes_, max_classes_per_anchor, param->num_classes_);
const float score_max = (score_with_class_all + i * param->num_classes_)->score;
if (score_max >= param->nms_score_threshold_) {
score_with_class[candidate_num].index = i;
score_with_class[candidate_num++].score = score_max;
}
}
selected_num =
NmsSingleClass(candidate_num, decoded_boxes, param->max_detections_, score_with_class, selected, param);
for (int i = 0; i < selected_num; ++i) {
const ScoreWithIndex *box_score_with_class = score_with_class_all + selected[i] * param->num_classes_;
const int box_index = box_score_with_class->index / num_classes_with_bg;
for (int j = 0; j < max_classes_per_anchor; ++j) {
*((BboxCorner *)(output_boxes) + output_num) = *((BboxCorner *)(decoded_boxes) + box_index);
output_scores[output_num] = (box_score_with_class + j)->score;
output_classes[output_num++] =
(float)((box_score_with_class + j)->index % num_classes_with_bg - first_class_index);
}
}
for (int i = output_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
((BboxCorner *)(output_boxes) + i)->ymin = 0;
((BboxCorner *)(output_boxes) + i)->xmin = 0;
((BboxCorner *)(output_boxes) + i)->ymax = 0;
((BboxCorner *)(output_boxes) + i)->xmax = 0;
output_scores[i] = 0;
output_classes[i] = 0;
}
return output_num;
}
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes,
const float *input_scores, float *input_anchors, float *output_boxes, float *output_classes,
float *output_scores, float *output_num, DetectionPostProcessParameter *param) {
BboxCenter scaler;
scaler.y = param->y_scale_;
scaler.x = param->x_scale_;
scaler.h = param->h_scale_;
scaler.w = param->w_scale_;
DecodeBoxes(num_boxes, input_boxes, input_anchors, scaler, param->decoded_boxes_);
if (param->use_regular_nms_) {
*output_num = NmsMultiClassesRegular(num_boxes, num_classes_with_bg, param->decoded_boxes_, input_scores,
output_boxes, output_classes, output_scores, param);
} else {
*output_num = NmsMultiClassesFast(num_boxes, num_classes_with_bg, param->decoded_boxes_, input_scores, output_boxes,
output_classes, output_scores, param);
}
*output_num = (float)all_classes_output_num;
return NNACL_OK;
}

View File

@ -34,18 +34,24 @@ typedef struct {
float xmax;
} BboxCorner;
typedef struct {
float score;
int index;
} ScoreWithIndex;
#ifdef __cplusplus
extern "C" {
#endif
int DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors,
DetectionPostProcessParameter *param);
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes,
const float *input_scores, float *input_anchors, float *output_boxes, float *output_classes,
float *output_scores, float *output_num, DetectionPostProcessParameter *param);
int NmsMultiClassesFastCore(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
void (*)(const float *, int *, int, int), const DetectionPostProcessParameter *param,
const int task_id, const int thread_num);
int DetectionPostProcessFast(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
const float *decoded_boxes, float *output_boxes, float *output_classes,
float *output_scores, float *output_num, void (*)(const float *, int *, int, int),
const DetectionPostProcessParameter *param);
int DetectionPostProcessRegular(const int num_boxes, const int num_classes_with_bg, const float *input_scores,
float *output_boxes, float *output_classes, float *output_scores, float *output_num,
void (*)(const float *, int *, int, int), const DetectionPostProcessParameter *param);
#ifdef __cplusplus
}
#endif

View File

@ -26,10 +26,27 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DetectionPostProcess;
namespace mindspore::kernel {
void PartialArgSort(const float *scores, int *indexes, int num_to_sort, int num_values) {
std::partial_sort(indexes, indexes + num_to_sort, indexes + num_values, [&scores](const int i, const int j) {
if (scores[i] == scores[j]) {
return i < j;
}
return scores[i] > scores[j];
});
}
int DetectionPostProcessBaseCPUKernel::Init() {
params_->decoded_boxes_ = nullptr;
params_->nms_candidate_ = nullptr;
params_->indexes_ = nullptr;
params_->scores_ = nullptr;
params_->all_class_indexes_ = nullptr;
params_->all_class_scores_ = nullptr;
params_->single_class_indexes_ = nullptr;
params_->selected_ = nullptr;
params_->anchors_ = nullptr;
auto anchor_tensor = in_tensors_.at(2);
DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_);
parameter->anchors_ = nullptr;
if (anchor_tensor->data_type() == kNumberTypeInt8) {
auto quant_param = anchor_tensor->GetQuantParams().front();
auto anchor_int8 = reinterpret_cast<int8_t *>(anchor_tensor->MutableData());
@ -40,14 +57,14 @@ int DetectionPostProcessBaseCPUKernel::Init() {
}
DoDequantizeInt8ToFp32(anchor_int8, anchor_fp32, quant_param.scale, quant_param.zeroPoint,
anchor_tensor->ElementsNum());
parameter->anchors_ = anchor_fp32;
params_->anchors_ = anchor_fp32;
} else if (anchor_tensor->data_type() == kNumberTypeFloat32 || anchor_tensor->data_type() == kNumberTypeFloat) {
parameter->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (parameter->anchors_ == nullptr) {
params_->anchors_ = new (std::nothrow) float[anchor_tensor->ElementsNum()];
if (params_->anchors_ == nullptr) {
MS_LOG(ERROR) << "Malloc anchor failed";
return RET_ERROR;
}
memcpy(parameter->anchors_, anchor_tensor->MutableData(), anchor_tensor->Size());
memcpy(params_->anchors_, anchor_tensor->MutableData(), anchor_tensor->Size());
} else {
MS_LOG(ERROR) << "unsupported anchor data type " << anchor_tensor->data_type();
return RET_ERROR;
@ -55,13 +72,56 @@ int DetectionPostProcessBaseCPUKernel::Init() {
return RET_OK;
}
DetectionPostProcessBaseCPUKernel::~DetectionPostProcessBaseCPUKernel() {
DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_);
delete[](parameter->anchors_);
}
DetectionPostProcessBaseCPUKernel::~DetectionPostProcessBaseCPUKernel() { delete[](params_->anchors_); }
int DetectionPostProcessBaseCPUKernel::ReSize() { return RET_OK; }
int NmsMultiClassesFastCoreRun(void *cdata, int task_id) {
auto KernelData = reinterpret_cast<DetectionPostProcessBaseCPUKernel *>(cdata);
int ret = NmsMultiClassesFastCore(KernelData->num_boxes_, KernelData->num_classes_with_bg_, KernelData->input_scores_,
PartialArgSort, KernelData->params_, task_id, KernelData->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "NmsMultiClassesFastCore error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
void DetectionPostProcessBaseCPUKernel::FreeAllocatedBuffer() {
if (params_->decoded_boxes_ != nullptr) {
context_->allocator->Free(params_->decoded_boxes_);
params_->decoded_boxes_ = nullptr;
}
if (params_->nms_candidate_ != nullptr) {
context_->allocator->Free(params_->nms_candidate_);
params_->nms_candidate_ = nullptr;
}
if (params_->indexes_ != nullptr) {
context_->allocator->Free(params_->indexes_);
params_->indexes_ = nullptr;
}
if (params_->scores_ != nullptr) {
context_->allocator->Free(params_->scores_);
params_->scores_ = nullptr;
}
if (params_->all_class_indexes_ != nullptr) {
context_->allocator->Free(params_->all_class_indexes_);
params_->all_class_indexes_ = nullptr;
}
if (params_->all_class_scores_ != nullptr) {
context_->allocator->Free(params_->all_class_scores_);
params_->all_class_scores_ = nullptr;
}
if (params_->single_class_indexes_ != nullptr) {
context_->allocator->Free(params_->single_class_indexes_);
params_->single_class_indexes_ = nullptr;
}
if (params_->selected_ != nullptr) {
context_->allocator->Free(params_->selected_);
params_->selected_ = nullptr;
}
}
int DetectionPostProcessBaseCPUKernel::Run() {
MS_ASSERT(context_->allocator != nullptr);
int status = GetInputData();
@ -73,56 +133,105 @@ int DetectionPostProcessBaseCPUKernel::Run() {
auto output_scores = reinterpret_cast<float *>(out_tensors_.at(2)->MutableData());
auto output_num = reinterpret_cast<float *>(out_tensors_.at(3)->MutableData());
const int num_boxes = in_tensors_.at(0)->shape()[1];
const int num_classes_with_bg = in_tensors_.at(1)->shape()[2];
DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_);
parameter->decoded_boxes_ = context_->allocator->Malloc(num_boxes * 4 * sizeof(float));
parameter->nms_candidate_ = context_->allocator->Malloc(num_boxes * sizeof(uint8_t));
parameter->selected_ = context_->allocator->Malloc(num_boxes * sizeof(int));
parameter->score_with_class_ = context_->allocator->Malloc(num_boxes * sizeof(ScoreWithIndex));
if (!parameter->decoded_boxes_ || !parameter->nms_candidate_ || !parameter->selected_ ||
!parameter->score_with_class_) {
MS_LOG(ERROR) << "malloc parameter->decoded_boxes_ || parameter->nms_candidate_ || parameter->selected_ || "
"parameter->score_with_class_ failed.";
num_boxes_ = in_tensors_.at(0)->shape()[1];
num_classes_with_bg_ = in_tensors_.at(1)->shape()[2];
params_->decoded_boxes_ = context_->allocator->Malloc(num_boxes_ * 4 * sizeof(float));
if (params_->decoded_boxes_ == nullptr) {
MS_LOG(ERROR) << "malloc params->decoded_boxes_ failed.";
FreeAllocatedBuffer();
return RET_ERROR;
}
if (parameter->use_regular_nms_) {
parameter->score_with_class_all_ =
context_->allocator->Malloc((num_boxes + parameter->max_detections_) * sizeof(ScoreWithIndex));
if (parameter->score_with_class_all_ == nullptr) {
MS_LOG(ERROR) << "malloc parameter->score_with_class_all_failed.";
params_->nms_candidate_ = context_->allocator->Malloc(num_boxes_ * sizeof(uint8_t));
if (params_->nms_candidate_ == nullptr) {
MS_LOG(ERROR) << "malloc params->nms_candidate_ failed.";
FreeAllocatedBuffer();
return RET_ERROR;
}
params_->selected_ = context_->allocator->Malloc(num_boxes_ * sizeof(int));
if (params_->selected_ == nullptr) {
MS_LOG(ERROR) << "malloc params->selected_ failed.";
FreeAllocatedBuffer();
return RET_ERROR;
}
params_->single_class_indexes_ = context_->allocator->Malloc(num_boxes_ * sizeof(int));
if (params_->single_class_indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc params->single_class_indexes_ failed.";
FreeAllocatedBuffer();
return RET_ERROR;
}
if (params_->use_regular_nms_) {
params_->scores_ = context_->allocator->Malloc((num_boxes_ + params_->max_detections_) * sizeof(float));
if (params_->scores_ == nullptr) {
MS_LOG(ERROR) << "malloc params->scores_ failed";
FreeAllocatedBuffer();
return RET_ERROR;
}
parameter->indexes_ = context_->allocator->Malloc((num_boxes + parameter->max_detections_) * sizeof(int));
if (parameter->indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc parameter->indexes_ failed.";
context_->allocator->Free(parameter->score_with_class_all_);
params_->indexes_ = context_->allocator->Malloc((num_boxes_ + params_->max_detections_) * sizeof(int));
if (params_->indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc params->indexes_ failed";
FreeAllocatedBuffer();
return RET_ERROR;
}
params_->all_class_scores_ = context_->allocator->Malloc((num_boxes_ + params_->max_detections_) * sizeof(float));
if (params_->all_class_scores_ == nullptr) {
MS_LOG(ERROR) << "malloc params->all_class_scores_ failed";
FreeAllocatedBuffer();
return RET_ERROR;
}
params_->all_class_indexes_ = context_->allocator->Malloc((num_boxes_ + params_->max_detections_) * sizeof(int));
if (params_->all_class_indexes_ == nullptr) {
MS_LOG(ERROR) << "malloc params->all_class_indexes_ failed";
FreeAllocatedBuffer();
return RET_ERROR;
}
} else {
parameter->score_with_class_all_ =
context_->allocator->Malloc((num_boxes * parameter->num_classes_) * sizeof(ScoreWithIndex));
if (!parameter->score_with_class_all_) {
MS_LOG(ERROR) << "malloc parameter->score_with_class_all_ failed.";
params_->scores_ = context_->allocator->Malloc(num_boxes_ * sizeof(float));
if (params_->scores_ == nullptr) {
MS_LOG(ERROR) << "malloc params->scores_ failed";
FreeAllocatedBuffer();
return RET_ERROR;
}
params_->indexes_ = context_->allocator->Malloc(num_boxes_ * params_->num_classes_ * sizeof(int));
if (!params_->indexes_) {
MS_LOG(ERROR) << "malloc params->indexes_ failed.";
FreeAllocatedBuffer();
return RET_ERROR;
}
}
DetectionPostProcess(num_boxes, num_classes_with_bg, input_boxes, input_scores, parameter->anchors_, output_boxes,
output_classes, output_scores, output_num, parameter);
context_->allocator->Free(parameter->decoded_boxes_);
parameter->decoded_boxes_ = nullptr;
context_->allocator->Free(parameter->nms_candidate_);
parameter->nms_candidate_ = nullptr;
context_->allocator->Free(parameter->selected_);
parameter->selected_ = nullptr;
context_->allocator->Free(parameter->score_with_class_);
parameter->score_with_class_ = nullptr;
context_->allocator->Free(parameter->score_with_class_all_);
parameter->score_with_class_all_ = nullptr;
if (parameter->use_regular_nms_) {
context_->allocator->Free(parameter->indexes_);
parameter->indexes_ = nullptr;
status = DecodeBoxes(num_boxes_, input_boxes_, params_->anchors_, params_);
if (status != RET_OK) {
MS_LOG(ERROR) << "DecodeBoxes error";
FreeAllocatedBuffer();
return status;
}
if (params_->use_regular_nms_) {
status = DetectionPostProcessRegular(num_boxes_, num_classes_with_bg_, input_scores_, output_boxes, output_classes,
output_scores, output_num, PartialArgSort, params_);
if (status != RET_OK) {
MS_LOG(ERROR) << "DetectionPostProcessRegular error error_code[" << status << "]";
FreeAllocatedBuffer();
return status;
}
} else {
status = ParallelLaunch(this->context_->thread_pool_, NmsMultiClassesFastCoreRun, this, op_parameter_->thread_num_);
if (status != RET_OK) {
MS_LOG(ERROR) << "NmsMultiClassesFastCoreRun error error_code[" << status << "]";
FreeAllocatedBuffer();
return status;
}
status = DetectionPostProcessFast(num_boxes_, num_classes_with_bg_, input_scores_,
reinterpret_cast<float *>(params_->decoded_boxes_), output_boxes, output_classes,
output_scores, output_num, PartialArgSort, params_);
if (status != RET_OK) {
MS_LOG(ERROR) << "DetectionPostProcessFast error error_code[" << status << "]";
FreeAllocatedBuffer();
return status;
}
}
FreeAllocatedBuffer();
return RET_OK;
}
} // namespace mindspore::kernel
} // namespace mindspore::kernel

View File

@ -30,18 +30,27 @@ class DetectionPostProcessBaseCPUKernel : public LiteKernel {
DetectionPostProcessBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {
params_ = reinterpret_cast<DetectionPostProcessParameter *>(parameter);
}
virtual ~DetectionPostProcessBaseCPUKernel();
int Init() override;
int ReSize() override;
int Run() override;
protected:
float *input_boxes = nullptr;
float *input_scores = nullptr;
int thread_num_;
int num_boxes_;
int num_classes_with_bg_;
float *input_boxes_ = nullptr;
float *input_scores_ = nullptr;
DetectionPostProcessParameter *params_ = nullptr;
protected:
virtual int GetInputData() = 0;
private:
void FreeAllocatedBuffer();
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_BASE_H_

View File

@ -33,8 +33,8 @@ int DetectionPostProcessCPUKernel::GetInputData() {
MS_LOG(ERROR) << "Input data type error";
return RET_ERROR;
}
input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
input_boxes_ = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
input_scores_ = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
return RET_OK;
}

View File

@ -27,8 +27,30 @@ using mindspore::schema::PrimitiveType_DetectionPostProcess;
namespace mindspore::kernel {
int DetectionPostProcessInt8CPUKernel::DequantizeInt8ToFp32(const int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, quant_size_ - task_id * thread_n_stride_);
int thread_offset = task_id * thread_n_stride_;
int ret = DoDequantizeInt8ToFp32(data_int8_ + thread_offset, data_fp32_ + thread_offset, quant_param_.scale,
quant_param_.zeroPoint, num_unit_thread);
if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantDTypeCast error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int DequantizeInt8ToFp32Run(void *cdata, int task_id) {
auto KernelData = reinterpret_cast<DetectionPostProcessInt8CPUKernel *>(cdata);
auto ret = KernelData->DequantizeInt8ToFp32(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantDTypeCastRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
}
int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **data) {
auto data_int8 = reinterpret_cast<int8_t *>(tensor->MutableData());
data_int8_ = reinterpret_cast<int8_t *>(tensor->MutableData());
*data = reinterpret_cast<float *>(context_->allocator->Malloc(tensor->ElementsNum() * sizeof(float)));
if (*data == nullptr) {
MS_LOG(ERROR) << "Malloc data failed.";
@ -38,8 +60,17 @@ int DetectionPostProcessInt8CPUKernel::Dequantize(lite::Tensor *tensor, float **
MS_LOG(ERROR) << "null quant param";
return RET_ERROR;
}
auto quant_param = tensor->GetQuantParams().front();
DoDequantizeInt8ToFp32(data_int8, *data, quant_param.scale, quant_param.zeroPoint, tensor->ElementsNum());
quant_param_ = tensor->GetQuantParams().front();
data_fp32_ = *data;
quant_size_ = tensor->ElementsNum();
thread_n_stride_ = UP_DIV(quant_size_, op_parameter_->thread_num_);
auto ret = ParallelLaunch(this->context_->thread_pool_, DequantizeInt8ToFp32Run, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "QuantDTypeCastRun error error_code[" << ret << "]";
context_->allocator->Free(*data);
return RET_ERROR;
}
return RET_OK;
}
int DetectionPostProcessInt8CPUKernel::GetInputData() {
@ -47,11 +78,11 @@ int DetectionPostProcessInt8CPUKernel::GetInputData() {
MS_LOG(ERROR) << "Input data type error";
return RET_ERROR;
}
int status = Dequantize(in_tensors_.at(0), &input_boxes);
int status = Dequantize(in_tensors_.at(0), &input_boxes_);
if (status != RET_OK) {
return status;
}
status = Dequantize(in_tensors_.at(1), &input_scores);
status = Dequantize(in_tensors_.at(1), &input_scores_);
if (status != RET_OK) {
return status;
}

View File

@ -34,9 +34,16 @@ class DetectionPostProcessInt8CPUKernel : public DetectionPostProcessBaseCPUKern
: DetectionPostProcessBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~DetectionPostProcessInt8CPUKernel() = default;
int8_t *data_int8_ = nullptr;
float *data_fp32_ = nullptr;
lite::QuantArg quant_param_;
int quant_size_;
int thread_n_stride_;
int DequantizeInt8ToFp32(const int task_id);
private:
int GetInputData();
int Dequantize(lite::Tensor *tensor, float **data);
int GetInputData();
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_DETECTION_POST_PROCESS_INT8_H_