forked from mindspore-Ecosystem/mindspore
!5744 add DetectionPostProcess op
Merge pull request !5744 from wangzhe/master
This commit is contained in:
commit
82310bb63f
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_
|
||||
#define MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef struct DetectionPostProcessParameter {
|
||||
OpParameter op_parameter_;
|
||||
float h_scale_;
|
||||
float w_scale_;
|
||||
float x_scale_;
|
||||
float y_scale_;
|
||||
float nms_iou_threshold_;
|
||||
float nms_score_threshold_;
|
||||
int64_t max_detections_;
|
||||
int64_t detections_per_class_;
|
||||
int64_t max_classes_per_detection_;
|
||||
int64_t num_classes_;
|
||||
bool use_regular_nms_;
|
||||
bool out_quantized_;
|
||||
|
||||
void *decoded_boxes_;
|
||||
void *nms_candidate_;
|
||||
void *selected_;
|
||||
void *score_with_class_;
|
||||
void *score_with_class_all_;
|
||||
} DetectionPostProcessParameter;
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_
|
|
@ -0,0 +1,220 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp32/detection_post_process.h"
|
||||
#include <math.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
int ScoreWithIndexCmp(const void *a, const void *b) {
|
||||
ScoreWithIndex *pa = (ScoreWithIndex *)a;
|
||||
ScoreWithIndex *pb = (ScoreWithIndex *)b;
|
||||
if (pa->score > pb->score) {
|
||||
return -1;
|
||||
} else if (pa->score < pb->score) {
|
||||
return 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
|
||||
const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin);
|
||||
const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin);
|
||||
if (area_a <= 0 || area_b <= 0) {
|
||||
return 0.0f;
|
||||
}
|
||||
const float ymin = a->ymin > b->ymin ? a->ymin : b->ymin;
|
||||
const float xmin = a->xmin > b->xmin ? a->xmin : b->xmin;
|
||||
const float ymax = a->ymax < b->ymax ? a->ymax : b->ymax;
|
||||
const float xmax = a->xmax < b->xmax ? a->xmax : b->xmax;
|
||||
const float h = ymax - ymin > 0.0f ? ymax - ymin : 0.0f;
|
||||
const float w = xmax - xmin > 0.0f ? xmax - xmin : 0.0f;
|
||||
const float inter = h * w;
|
||||
return inter / (area_a + area_b - inter + 1e-8);
|
||||
}
|
||||
|
||||
void DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors, const BboxCenter scaler,
|
||||
float *decoded_boxes) {
|
||||
for (int i = 0; i < num_boxes; ++i) {
|
||||
BboxCenter *box = (BboxCenter *)(input_boxes + i * 4);
|
||||
BboxCenter *anchor = (BboxCenter *)(anchors + i * 4);
|
||||
BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes + i * 4);
|
||||
float y_center = box->y / scaler.y * anchor->h + anchor->y;
|
||||
float x_center = box->x / scaler.x * anchor->w + anchor->x;
|
||||
float h_half = 0.5f * expf(box->h / scaler.h) * anchor->h;
|
||||
float w_half = 0.5f * expf(box->w / scaler.w) * anchor->w;
|
||||
decoded_box->ymin = y_center - h_half;
|
||||
decoded_box->xmin = x_center - w_half;
|
||||
decoded_box->ymax = y_center + h_half;
|
||||
decoded_box->xmax = x_center + w_half;
|
||||
}
|
||||
}
|
||||
|
||||
int NmsSingleClass(const int candidate_num, const float *decoded_boxes, const int max_detections,
|
||||
ScoreWithIndex *score_with_index, int *selected, const DetectionPostProcessParameter *param) {
|
||||
uint8_t *nms_candidate = param->nms_candidate_;
|
||||
const int output_num = candidate_num < max_detections ? candidate_num : max_detections;
|
||||
int possible_candidate_num = candidate_num;
|
||||
int selected_num = 0;
|
||||
qsort(score_with_index, candidate_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp);
|
||||
for (int i = 0; i < candidate_num; ++i) {
|
||||
nms_candidate[i] = 1;
|
||||
}
|
||||
for (int i = 0; i < candidate_num; ++i) {
|
||||
if (possible_candidate_num == 0 || selected_num >= output_num) {
|
||||
break;
|
||||
}
|
||||
if (nms_candidate[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
selected[selected_num++] = score_with_index[i].index;
|
||||
nms_candidate[i] = 0;
|
||||
possible_candidate_num--;
|
||||
for (int t = i + 1; t < candidate_num; ++t) {
|
||||
if (nms_candidate[t] == 1) {
|
||||
const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + score_with_index[i].index;
|
||||
const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + score_with_index[t].index;
|
||||
const float iou = IntersectionOverUnion(bbox_i, bbox_t);
|
||||
if (iou > param->nms_iou_threshold_) {
|
||||
nms_candidate[t] = 0;
|
||||
possible_candidate_num--;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return selected_num;
|
||||
}
|
||||
|
||||
int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
|
||||
const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
|
||||
const DetectionPostProcessParameter *param) {
|
||||
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
|
||||
int *selected = (int *)(param->selected_);
|
||||
ScoreWithIndex *score_with_index_single = (ScoreWithIndex *)(param->score_with_class_);
|
||||
int all_classes_sorted_num = 0;
|
||||
int all_classes_output_num = 0;
|
||||
ScoreWithIndex *score_with_index_all = (ScoreWithIndex *)(param->score_with_class_all_);
|
||||
for (int j = first_class_index; j < num_classes_with_bg; ++j) {
|
||||
int candidate_num = 0;
|
||||
// process single class
|
||||
for (int i = 0; i < num_boxes; ++i) {
|
||||
const float score = input_scores[i * num_classes_with_bg + j];
|
||||
if (score >= param->nms_score_threshold_) {
|
||||
score_with_index_single[candidate_num].score = score;
|
||||
score_with_index_single[candidate_num++].index = i;
|
||||
}
|
||||
}
|
||||
int selected_num = NmsSingleClass(candidate_num, decoded_boxes, param->detections_per_class_,
|
||||
score_with_index_single, selected, param);
|
||||
// process all classes
|
||||
for (int i = 0; i < selected_num; ++i) {
|
||||
// store class to index
|
||||
score_with_index_all[all_classes_sorted_num].index = selected[i] * num_classes_with_bg + j;
|
||||
score_with_index_all[all_classes_sorted_num++].score = input_scores[selected[i] * num_classes_with_bg + j];
|
||||
}
|
||||
all_classes_output_num =
|
||||
all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_;
|
||||
qsort(score_with_index_all, all_classes_sorted_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp);
|
||||
all_classes_sorted_num = all_classes_output_num;
|
||||
}
|
||||
for (int i = 0; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
|
||||
if (i < all_classes_output_num) {
|
||||
const int box_index = score_with_index_all[i].index / num_classes_with_bg;
|
||||
const int class_index = score_with_index_all[i].index - box_index * num_classes_with_bg - first_class_index;
|
||||
*((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index);
|
||||
output_classes[i] = (float)class_index;
|
||||
output_scores[i] = score_with_index_all[i].score;;
|
||||
} else {
|
||||
((BboxCorner *)(output_boxes) + i)->ymin = 0;
|
||||
((BboxCorner *)(output_boxes) + i)->xmin = 0;
|
||||
((BboxCorner *)(output_boxes) + i)->ymax = 0;
|
||||
((BboxCorner *)(output_boxes) + i)->xmax = 0;
|
||||
output_classes[i] = 0.0f;
|
||||
output_scores[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
return all_classes_output_num;
|
||||
}
|
||||
|
||||
int NmsMultiClassesFast(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
|
||||
const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
|
||||
const DetectionPostProcessParameter *param) {
|
||||
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
|
||||
const int64_t max_classes_per_anchor =
|
||||
param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_;
|
||||
int candidate_num = 0;
|
||||
ScoreWithIndex *score_with_class_all = (ScoreWithIndex *)(param->score_with_class_all_);
|
||||
ScoreWithIndex *score_with_class = (ScoreWithIndex *)(param->score_with_class_);
|
||||
int *selected = (int *)(param->selected_);
|
||||
int selected_num;
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < num_boxes; ++i) {
|
||||
for (int j = first_class_index; j < num_classes_with_bg; ++j) {
|
||||
float score_t = *(input_scores + i * num_classes_with_bg + j);
|
||||
score_with_class_all[i * param->num_classes_ + j - first_class_index].score = score_t;
|
||||
// save box and class info to index
|
||||
score_with_class_all[i * param->num_classes_ + j - first_class_index].index = i * num_classes_with_bg + j;
|
||||
}
|
||||
qsort(score_with_class_all + i * param->num_classes_, param->num_classes_, sizeof(ScoreWithIndex),
|
||||
ScoreWithIndexCmp);
|
||||
const float score_max = (score_with_class_all + i * param->num_classes_)->score;
|
||||
if (score_max >= param->nms_score_threshold_) {
|
||||
score_with_class[candidate_num].index = i;
|
||||
score_with_class[candidate_num++].score = score_max;
|
||||
}
|
||||
}
|
||||
selected_num =
|
||||
NmsSingleClass(candidate_num, decoded_boxes, param->max_detections_, score_with_class, selected, param);
|
||||
for (int i = 0; i < selected_num; ++i) {
|
||||
const ScoreWithIndex *box_score_with_class = score_with_class_all + selected[i] * param->num_classes_;
|
||||
const int box_index = box_score_with_class->index / num_classes_with_bg;
|
||||
for (int j = 0; j < max_classes_per_anchor; ++j) {
|
||||
*((BboxCorner *)(output_boxes) + output_num) = *((BboxCorner *)(decoded_boxes) + box_index);
|
||||
output_scores[output_num] = (box_score_with_class + j)->score;
|
||||
output_classes[output_num++] =
|
||||
(float)((box_score_with_class + j)->index % num_classes_with_bg - first_class_index);
|
||||
}
|
||||
}
|
||||
for (int i = output_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
|
||||
((BboxCorner *)(output_boxes) + i)->ymin = 0;
|
||||
((BboxCorner *)(output_boxes) + i)->xmin = 0;
|
||||
((BboxCorner *)(output_boxes) + i)->ymax = 0;
|
||||
((BboxCorner *)(output_boxes) + i)->xmax = 0;
|
||||
output_scores[i] = 0;
|
||||
output_classes[i] = 0;
|
||||
}
|
||||
return output_num;
|
||||
}
|
||||
|
||||
// Entry point of the fp32 detection post-process: decodes the boxes against their
// anchors into param->decoded_boxes_, then runs regular or fast multi-class NMS
// depending on param->use_regular_nms_. The detection count returned by the NMS
// routine is stored in *output_num as a float. Always returns NNACL_OK.
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes, float *input_scores,
                         float *input_anchors, float *output_boxes, float *output_classes, float *output_scores,
                         float *output_num, DetectionPostProcessParameter *param) {
  float *decoded_boxes = (float *)(param->decoded_boxes_);

  BboxCenter scaler;
  scaler.h = param->h_scale_;
  scaler.w = param->w_scale_;
  scaler.x = param->x_scale_;
  scaler.y = param->y_scale_;
  DecodeBoxes(num_boxes, input_boxes, input_anchors, scaler, decoded_boxes);

  int detected_num;
  if (param->use_regular_nms_) {
    detected_num = NmsMultiClassesRegular(num_boxes, num_classes_with_bg, decoded_boxes, input_scores, output_boxes,
                                          output_classes, output_scores, param);
  } else {
    detected_num = NmsMultiClassesFast(num_boxes, num_classes_with_bg, decoded_boxes, input_scores, output_boxes,
                                       output_classes, output_scores, param);
  }
  *output_num = (float)detected_num;
  return NNACL_OK;
}
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP32_DETECTION_POST_PROCESS_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_DETECTION_POST_PROCESS_H_
|
||||
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/detection_post_process_parameter.h"
|
||||
|
||||
// Box in center form. The field order must match the 4-float layout of the
// box/anchor tensors, because the kernel reinterprets float arrays as this struct.
typedef struct {
  float y;  // center y
  float x;  // center x
  float h;  // height
  float w;  // width
} BboxCenter;
|
||||
|
||||
// Box in corner form. The field order must match the 4-float layout of the
// decoded/output box buffers, because the kernel reinterprets float arrays as this struct.
typedef struct {
  float ymin;  // top edge
  float xmin;  // left edge
  float ymax;  // bottom edge
  float xmax;  // right edge
} BboxCorner;
|
||||
|
||||
// A score paired with the index it belongs to. Depending on the caller the index
// is either a box index or a flat box * num_classes_with_bg + class index.
typedef struct {
  float score;
  int index;
} ScoreWithIndex;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void nms_multi_classes_regular();
|
||||
|
||||
void nms_multi_classes_fase();
|
||||
|
||||
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes, float *input_scores,
|
||||
float *input_anchors, float *output_boxes, float *output_classes, float *output_scores,
|
||||
float *output_num, DetectionPostProcessParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP32_DETECTION_POST_PROCESS_H_
|
|
@ -351,8 +351,8 @@ table DetectionPostProcess {
|
|||
NmsIouThreshold: float;
|
||||
NmsScoreThreshold: float;
|
||||
MaxDetections: long;
|
||||
DetectionsPreClass: long;
|
||||
MaxClassesPreDetection: long;
|
||||
DetectionsPerClass: long;
|
||||
MaxClassesPerDetection: long;
|
||||
NumClasses: long;
|
||||
UseRegularNms: bool;
|
||||
OutQuantized: bool;
|
||||
|
|
|
@ -34,11 +34,11 @@ float DetectionPostProcess::GetNmsScoreThreshold() const {
|
|||
int64_t DetectionPostProcess::GetMaxDetections() const {
|
||||
return this->primitive_->value.AsDetectionPostProcess()->MaxDetections;
|
||||
}
|
||||
int64_t DetectionPostProcess::GetDetectionsPreClass() const {
|
||||
return this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass;
|
||||
int64_t DetectionPostProcess::GetDetectionsPerClass() const {
|
||||
return this->primitive_->value.AsDetectionPostProcess()->DetectionsPerClass;
|
||||
}
|
||||
int64_t DetectionPostProcess::GetMaxClassesPreDetection() const {
|
||||
return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection;
|
||||
int64_t DetectionPostProcess::GetMaxClassesPerDetection() const {
|
||||
return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPerDetection;
|
||||
}
|
||||
int64_t DetectionPostProcess::GetNumClasses() const {
|
||||
return this->primitive_->value.AsDetectionPostProcess()->NumClasses;
|
||||
|
@ -46,7 +46,6 @@ int64_t DetectionPostProcess::GetNumClasses() const {
|
|||
bool DetectionPostProcess::GetUseRegularNms() const {
|
||||
return this->primitive_->value.AsDetectionPostProcess()->UseRegularNms;
|
||||
}
|
||||
|
||||
void DetectionPostProcess::SetFormat(int format) {
|
||||
this->primitive_->value.AsDetectionPostProcess()->format = (schema::Format)format;
|
||||
}
|
||||
|
@ -72,13 +71,13 @@ void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {
|
|||
this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold;
|
||||
}
|
||||
// Stores the overall detection cap in the attribute table.
// Only MaxDetections is written; the stale assignment that wrote the same value
// into the max-classes-per-detection attribute has been dropped.
void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {
  this->primitive_->value.AsDetectionPostProcess()->MaxDetections = max_detections;
}
|
||||
void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {
|
||||
this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class;
|
||||
void DetectionPostProcess::SetDetectionsPerClass(int64_t detections_per_class) {
|
||||
this->primitive_->value.AsDetectionPostProcess()->DetectionsPerClass = detections_per_class;
|
||||
}
|
||||
void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {
|
||||
this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection;
|
||||
void DetectionPostProcess::SetMaxClassesPerDetection(int64_t max_classes_per_detection) {
|
||||
this->primitive_->value.AsDetectionPostProcess()->MaxClassesPerDetection = max_classes_per_detection;
|
||||
}
|
||||
void DetectionPostProcess::SetNumClasses(int64_t num_classes) {
|
||||
this->primitive_->value.AsDetectionPostProcess()->NumClasses = num_classes;
|
||||
|
@ -86,6 +85,9 @@ void DetectionPostProcess::SetNumClasses(int64_t num_classes) {
|
|||
// Stores the regular-NMS switch in the attribute table.
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {
  auto *attr = this->primitive_->value.AsDetectionPostProcess();
  attr->UseRegularNms = use_regular_nms;
}
|
||||
// Stores the OutQuantized flag in the attribute table.
void DetectionPostProcess::SetOutQuantized(bool out_quantized) {
  auto *attr = this->primitive_->value.AsDetectionPostProcess();
  attr->OutQuantized = out_quantized;
}
|
||||
|
||||
#else
|
||||
int DetectionPostProcess::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
|
@ -98,8 +100,8 @@ int DetectionPostProcess::UnPackToFlatBuilder(const schema::Primitive *primitive
|
|||
}
|
||||
auto val_offset = schema::CreateDetectionPostProcess(
|
||||
*fbb, attr->format(), attr->inputSize(), attr->hScale(), attr->wScale(), attr->xScale(), attr->yScale(),
|
||||
attr->NmsIouThreshold(), attr->NmsScoreThreshold(), attr->MaxDetections(), attr->DetectionsPreClass(),
|
||||
attr->MaxClassesPreDetection(), attr->NumClasses(), attr->UseRegularNms());
|
||||
attr->NmsIouThreshold(), attr->NmsScoreThreshold(), attr->MaxDetections(), attr->DetectionsPerClass(),
|
||||
attr->MaxClassesPerDetection(), attr->NumClasses(), attr->UseRegularNms(), attr->OutQuantized());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
|
@ -121,11 +123,11 @@ float DetectionPostProcess::GetNmsScoreThreshold() const {
|
|||
int64_t DetectionPostProcess::GetMaxDetections() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->MaxDetections();
|
||||
}
|
||||
int64_t DetectionPostProcess::GetDetectionsPreClass() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->DetectionsPreClass();
|
||||
int64_t DetectionPostProcess::GetDetectionsPerClass() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->DetectionsPerClass();
|
||||
}
|
||||
int64_t DetectionPostProcess::GetMaxClassesPreDetection() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPreDetection();
|
||||
int64_t DetectionPostProcess::GetMaxClassesPerDetection() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPerDetection();
|
||||
}
|
||||
int64_t DetectionPostProcess::GetNumClasses() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->NumClasses();
|
||||
|
@ -133,7 +135,67 @@ int64_t DetectionPostProcess::GetNumClasses() const {
|
|||
bool DetectionPostProcess::GetUseRegularNms() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->UseRegularNms();
|
||||
}
|
||||
bool DetectionPostProcess::GetOutQuantized() const {
|
||||
return this->primitive_->value_as_DetectionPostProcess()->OutQuantized();
|
||||
}
|
||||
|
||||
#endif
|
||||
namespace {
// Tensor counts InferShape requires: boxes, scores and anchors in;
// boxes, classes, scores and detection count out.
constexpr int kDetectionPostProcessInputNum = 3;
constexpr int kDetectionPostProcessOutputNum = 4;
}  // namespace
|
||||
// Infers format, data type and shape of the four outputs.
//
// inputs_:  [0] boxes, [1] scores, [2] anchors
// outputs_: [0] detected boxes, [1] detected classes, [2] detected scores, [3] detection count
//
// Every output inherits format and data type from the boxes input. When the infer
// flag is set, with N = MaxDetections * MaxClassesPerDetection the shapes become
// {1, N, 4}, {1, N}, {1, N} and {1}.
// Returns RET_PARAM_INVALID when the tensor counts are wrong, RET_OK otherwise.
int DetectionPostProcess::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
                                     std::vector<lite::tensor::Tensor *> outputs_) {
  if (outputs_.size() != kDetectionPostProcessOutputNum || inputs_.size() != kDetectionPostProcessInputNum) {
    MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs_.size() << ",input size: " << inputs_.size();
    return RET_PARAM_INVALID;
  }
  auto boxes = inputs_.at(0);
  MS_ASSERT(boxes != nullptr);
  auto scores = inputs_.at(1);
  MS_ASSERT(scores != nullptr);
  auto anchors = inputs_.at(2);
  MS_ASSERT(anchors != nullptr);

  const auto input_box_shape = boxes->shape();
  const auto input_scores_shape = scores->shape();
  const auto input_anchors_shape = anchors->shape();
  // The scores tensor carries either exactly NumClasses entries or one extra background class.
  MS_ASSERT(input_scores_shape[2] >= GetNumClasses());
  MS_ASSERT(input_scores_shape[2] - GetNumClasses() <= 1);
  // Boxes, scores and anchors must describe the same number of boxes (comparison, not assignment).
  MS_ASSERT(input_box_shape[1] == input_scores_shape[1]);
  MS_ASSERT(input_box_shape[1] == input_anchors_shape[0]);

  auto detected_boxes = outputs_.at(0);
  MS_ASSERT(detected_boxes != nullptr);
  auto detected_classes = outputs_.at(1);
  MS_ASSERT(detected_classes != nullptr);
  auto detected_scores = outputs_.at(2);
  MS_ASSERT(detected_scores != nullptr);
  auto num_det = outputs_.at(3);
  MS_ASSERT(num_det != nullptr);

  detected_boxes->SetFormat(boxes->GetFormat());
  detected_boxes->set_data_type(boxes->data_type());
  detected_classes->SetFormat(boxes->GetFormat());
  detected_classes->set_data_type(boxes->data_type());
  detected_scores->SetFormat(boxes->GetFormat());
  detected_scores->set_data_type(boxes->data_type());
  num_det->SetFormat(boxes->GetFormat());
  num_det->set_data_type(boxes->data_type());
  if (!GetInferFlag()) {
    return RET_OK;
  }
  const auto max_detections = GetMaxDetections();
  const auto max_classes_per_detection = GetMaxClassesPerDetection();
  const auto num_detected_boxes = static_cast<int>(max_detections * max_classes_per_detection);
  const std::vector<int> box_shape{1, num_detected_boxes, 4};
  const std::vector<int> class_shape{1, num_detected_boxes};
  const std::vector<int> num_shape{1};
  detected_boxes->set_shape(box_shape);
  detected_classes->set_shape(class_shape);
  detected_scores->set_shape(class_shape);
  num_det->set_shape(num_shape);
  return RET_OK;
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -40,15 +40,17 @@ class DetectionPostProcess : public PrimitiveC {
|
|||
void SetNmsIouThreshold(float nms_iou_threshold);
|
||||
void SetNmsScoreThreshold(float nms_score_threshold);
|
||||
void SetMaxDetections(int64_t max_detections);
|
||||
void SetDetectionsPreClass(int64_t detections_pre_class);
|
||||
void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
|
||||
void SetDetectionsPerClass(int64_t detections_per_class);
|
||||
void SetMaxClassesPerDetection(int64_t max_classes_per_detection);
|
||||
void SetNumClasses(int64_t num_classes);
|
||||
void SetUseRegularNms(bool use_regular_nms);
|
||||
void SetOutQuantized(bool out_quantized);
|
||||
#else
|
||||
DetectionPostProcess() = default;
|
||||
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
|
||||
int GetFormat() const;
|
||||
int GetInputSize() const;
|
||||
float GetHScale() const;
|
||||
|
@ -58,11 +60,13 @@ class DetectionPostProcess : public PrimitiveC {
|
|||
float GetNmsIouThreshold() const;
|
||||
float GetNmsScoreThreshold() const;
|
||||
int64_t GetMaxDetections() const;
|
||||
int64_t GetDetectionsPreClass() const;
|
||||
int64_t GetMaxClassesPreDetection() const;
|
||||
int64_t GetDetectionsPerClass() const;
|
||||
int64_t GetMaxClassesPerDetection() const;
|
||||
int64_t GetNumClasses() const;
|
||||
bool GetUseRegularNms() const;
|
||||
bool GetOutQuantized() const;
|
||||
};
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -113,6 +113,7 @@
|
|||
#include "src/ops/round.h"
|
||||
#include "src/ops/sparse_to_dense.h"
|
||||
#include "src/ops/l2_norm.h"
|
||||
#include "src/ops/detection_post_process.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "nnacl/fp32/arg_min_max.h"
|
||||
#include "nnacl/fp32/cast.h"
|
||||
|
@ -171,6 +172,7 @@
|
|||
#include "nnacl/leaky_relu_parameter.h"
|
||||
#include "nnacl/sparse_to_dense.h"
|
||||
#include "nnacl/l2_norm_parameter.h"
|
||||
#include "nnacl/detection_post_process_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
|
@ -1517,20 +1519,17 @@ OpParameter *PopulateEluParameter(const mindspore::lite::PrimitiveC *primitive)
|
|||
|
||||
OpParameter *PopulateL2NormParameter(
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
L2NormParameter *l2_norm_parameter =
|
||||
reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter)));
|
||||
L2NormParameter *l2_norm_parameter = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter)));
|
||||
if (l2_norm_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc L2NormParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(l2_norm_parameter, 0, sizeof(L2NormParameter));
|
||||
l2_norm_parameter->op_parameter_.type_ = primitive->Type();
|
||||
auto param = reinterpret_cast<mindspore::lite::L2Norm *>(
|
||||
const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
auto param = reinterpret_cast<mindspore::lite::L2Norm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
auto axis_vec = param->GetAxis();
|
||||
l2_norm_parameter->axis_num_ = axis_vec.size();
|
||||
l2_norm_parameter->axis_ =
|
||||
reinterpret_cast<int *>(malloc(axis_vec.size() * sizeof(int)));
|
||||
l2_norm_parameter->axis_ = reinterpret_cast<int *>(malloc(axis_vec.size() * sizeof(int)));
|
||||
for (size_t i = 0; i < axis_vec.size(); i++) {
|
||||
l2_norm_parameter->axis_[i] = axis_vec[i];
|
||||
}
|
||||
|
@ -1542,6 +1541,31 @@ OpParameter *PopulateL2NormParameter(
|
|||
return reinterpret_cast<OpParameter *>(l2_norm_parameter);
|
||||
}
|
||||
|
||||
// Builds a zero-initialised DetectionPostProcessParameter from the primitive's
// attributes. Returns nullptr when the allocation fails; otherwise the caller
// receives the malloc'ed parameter through its OpParameter header.
OpParameter *PopulateDetectionPostProcessParameter(const mindspore::lite::PrimitiveC *primitive) {
  DetectionPostProcessParameter *detection_post_process_parameter =
    reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter)));
  if (detection_post_process_parameter == nullptr) {
    MS_LOG(ERROR) << "malloc DetectionPostProcessParameter failed.";
    return nullptr;
  }
  // Zeroing also leaves out_quantized_ false and every scratch-buffer pointer null.
  // NOTE(review): out_quantized_ is not copied from the primitive here - confirm this is intended.
  memset(detection_post_process_parameter, 0, sizeof(DetectionPostProcessParameter));
  detection_post_process_parameter->op_parameter_.type_ = primitive->Type();
  auto param =
    reinterpret_cast<mindspore::lite::DetectionPostProcess *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
  detection_post_process_parameter->h_scale_ = param->GetHScale();
  detection_post_process_parameter->w_scale_ = param->GetWScale();
  detection_post_process_parameter->x_scale_ = param->GetXScale();
  detection_post_process_parameter->y_scale_ = param->GetYScale();
  detection_post_process_parameter->nms_iou_threshold_ = param->GetNmsIouThreshold();
  detection_post_process_parameter->nms_score_threshold_ = param->GetNmsScoreThreshold();
  detection_post_process_parameter->max_detections_ = param->GetMaxDetections();
  detection_post_process_parameter->detections_per_class_ = param->GetDetectionsPerClass();
  detection_post_process_parameter->max_classes_per_detection_ = param->GetMaxClassesPerDetection();
  detection_post_process_parameter->num_classes_ = param->GetNumClasses();
  detection_post_process_parameter->use_regular_nms_ = param->GetUseRegularNms();
  return reinterpret_cast<OpParameter *>(detection_post_process_parameter);
}
|
||||
|
||||
PopulateParameterRegistry::PopulateParameterRegistry() {
|
||||
populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
|
||||
|
@ -1640,6 +1664,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
|
|||
populate_parameter_funcs_[schema::PrimitiveType_EmbeddingLookup] = PopulateEmbeddingLookupParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_Elu] = PopulateEluParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter;
|
||||
populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter;
|
||||
}
|
||||
|
||||
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/fp32/detection_post_process.h"
|
||||
#include <vector>
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_DetectionPostProcess;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// No initialisation work is needed; always reports success.
int DetectionPostProcessCPUKernel::Init() {
  return RET_OK;
}
|
||||
|
||||
// Nothing is cached per shape, so resizing is a no-op that reports success.
int DetectionPostProcessCPUKernel::ReSize() {
  return RET_OK;
}
|
||||
|
||||
int DetectionPostProcessCPUKernel::Run() {
|
||||
auto prepare_ret = Prepare();
|
||||
if (prepare_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
auto input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
|
||||
auto input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
|
||||
auto input_anchors = reinterpret_cast<float *>(in_tensors_.at(2)->Data());
|
||||
|
||||
// output_classes and output_num use float type now
|
||||
auto output_boxes = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
|
||||
auto output_classes = reinterpret_cast<float *>(out_tensors_.at(1)->Data());
|
||||
auto output_scores = reinterpret_cast<float *>(out_tensors_.at(2)->Data());
|
||||
auto output_num = reinterpret_cast<float *>(out_tensors_.at(3)->Data());
|
||||
|
||||
MS_ASSERT(context_->allocator != nullptr);
|
||||
const int num_boxes = in_tensors_.at(0)->shape()[1];
|
||||
const int num_classes_with_bg = in_tensors_.at(1)->shape()[2];
|
||||
DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_);
|
||||
parameter->decoded_boxes_ = context_->allocator->Malloc(num_boxes * 4 * sizeof(float));
|
||||
parameter->nms_candidate_ = context_->allocator->Malloc(num_boxes * sizeof(uint8_t));
|
||||
parameter->selected_ = context_->allocator->Malloc(num_boxes * sizeof(int));
|
||||
parameter->score_with_class_ = context_->allocator->Malloc(num_boxes * sizeof(ScoreWithIndex));
|
||||
if (parameter->use_regular_nms_) {
|
||||
parameter->score_with_class_all_ =
|
||||
context_->allocator->Malloc((num_boxes + parameter->max_detections_) * sizeof(ScoreWithIndex));
|
||||
} else {
|
||||
parameter->score_with_class_all_ =
|
||||
context_->allocator->Malloc((num_boxes * parameter->num_classes_) * sizeof(ScoreWithIndex));
|
||||
}
|
||||
DetectionPostProcess(num_boxes, num_classes_with_bg, input_boxes, input_scores, input_anchors, output_boxes,
|
||||
output_classes, output_scores, output_num, parameter);
|
||||
context_->allocator->Free(parameter->decoded_boxes_);
|
||||
context_->allocator->Free(parameter->nms_candidate_);
|
||||
context_->allocator->Free(parameter->selected_);
|
||||
context_->allocator->Free(parameter->score_with_class_);
|
||||
context_->allocator->Free(parameter->score_with_class_all_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Factory registered for the fp32 CPU DetectionPostProcess operator.
//
// Builds a DetectionPostProcessCPUKernel from the given tensors, parameter, context and primitive, then calls
// Init() on it. Returns the new kernel on success. Returns nullptr when opParameter is null, when the kernel
// object cannot be allocated, or when Init() does not return RET_OK (the kernel is deleted in that case).
kernel::LiteKernel *CpuDetectionPostProcessFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                             const std::vector<lite::tensor::Tensor *> &outputs,
                                                             OpParameter *opParameter, const lite::Context *ctx,
                                                             const kernel::KernelKey &desc,
                                                             const mindspore::lite::PrimitiveC *primitive) {
  if (nullptr == opParameter) {
    MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_DetectionPostProcess. ";
    return nullptr;
  }
  MS_ASSERT(desc.type == schema::PrimitiveType_DetectionPostProcess);

  auto *detection_kernel =
    new (std::nothrow) DetectionPostProcessCPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (nullptr == detection_kernel) {
    MS_LOG(ERROR) << "new DetectionPostProcessCPUKernel fail!";
    return nullptr;
  }

  const auto init_status = detection_kernel->Init();
  if (init_status == RET_OK) {
    return detection_kernel;
  }

  // Initialisation failed: report it and dispose of the half-built kernel.
  MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
                << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
  delete detection_kernel;
  return nullptr;
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DetectionPostProcess, CpuDetectionPostProcessFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "include/context.h"
|
||||
#include "nnacl/fp32/detection_post_process.h"
|
||||
|
||||
using mindspore::lite::Context;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class DetectionPostProcessCPUKernel : public LiteKernel {
|
||||
public:
|
||||
DetectionPostProcessCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
|
||||
param_ = reinterpret_cast<DetectionPostProcessCPUKernel *>(parameter);
|
||||
}
|
||||
~DetectionPostProcessCPUKernel() override = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
DetectionPostProcessCPUKernel *param_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_
|
|
@ -0,0 +1,159 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
// gtest fixture for the fp32 DetectionPostProcess kernel tests; adds nothing on top of CommonTest.
class TestDetectionPostProcessFp32 : public mindspore::CommonTest {
 public:
  TestDetectionPostProcessFp32() = default;
};
|
||||
|
||||
void DetectionPostProcessTestInit(std::vector<lite::tensor::Tensor *> *inputs_,
|
||||
std::vector<lite::tensor::Tensor *> *outputs_, DetectionPostProcessParameter *param) {
|
||||
std::string input_boxes_path = "./test_data/detectionPostProcess/input_boxes.bin";
|
||||
size_t input_boxes_size;
|
||||
auto input_boxes_data =
|
||||
reinterpret_cast<float *>(mindspore::lite::ReadFile(input_boxes_path.c_str(), &input_boxes_size));
|
||||
auto *input_boxes = new lite::tensor::Tensor;
|
||||
input_boxes->set_data_type(kNumberTypeFloat32);
|
||||
input_boxes->SetFormat(schema::Format_NHWC);
|
||||
input_boxes->set_shape({1, 1917, 4});
|
||||
input_boxes->MallocData();
|
||||
memcpy(input_boxes->Data(), input_boxes_data, input_boxes_size);
|
||||
inputs_->push_back(input_boxes);
|
||||
|
||||
std::string input_scores_path = "./test_data/detectionPostProcess/input_scores.bin";
|
||||
size_t input_scores_size;
|
||||
auto input_scores_data =
|
||||
reinterpret_cast<float *>(mindspore::lite::ReadFile(input_scores_path.c_str(), &input_scores_size));
|
||||
auto *input_scores = new lite::tensor::Tensor;
|
||||
input_scores->set_data_type(kNumberTypeFloat32);
|
||||
input_scores->SetFormat(schema::Format_NHWC);
|
||||
input_scores->set_shape({1, 1917, 91});
|
||||
input_scores->MallocData();
|
||||
memcpy(input_scores->Data(), input_scores_data, input_scores_size);
|
||||
inputs_->push_back(input_scores);
|
||||
|
||||
std::string input_anchors_path = "./test_data/detectionPostProcess/input_anchors.bin";
|
||||
size_t input_anchors_size;
|
||||
auto input_anchors_data =
|
||||
reinterpret_cast<float *>(mindspore::lite::ReadFile(input_anchors_path.c_str(), &input_anchors_size));
|
||||
auto *input_anchors = new lite::tensor::Tensor;
|
||||
input_anchors->set_data_type(kNumberTypeFloat32);
|
||||
input_anchors->SetFormat(schema::Format_NHWC);
|
||||
input_anchors->set_shape({1917, 4});
|
||||
input_anchors->MallocData();
|
||||
memcpy(input_anchors->Data(), input_anchors_data, input_anchors_size);
|
||||
inputs_->push_back(input_anchors);
|
||||
|
||||
auto *output_boxes = new lite::tensor::Tensor;
|
||||
output_boxes->set_data_type(kNumberTypeFloat32);
|
||||
output_boxes->set_shape({1, 10, 4});
|
||||
output_boxes->SetFormat(schema::Format_NHWC);
|
||||
output_boxes->MallocData();
|
||||
memset(output_boxes->Data(), 0, output_boxes->ElementsNum() * sizeof(float));
|
||||
|
||||
auto *output_classes = new lite::tensor::Tensor;
|
||||
output_classes->set_data_type(kNumberTypeFloat32);
|
||||
output_classes->set_shape({1, 10});
|
||||
output_classes->SetFormat(schema::Format_NHWC);
|
||||
output_classes->MallocData();
|
||||
memset(output_classes->Data(), 0, output_classes->ElementsNum() * sizeof(float));
|
||||
|
||||
auto *output_scores = new lite::tensor::Tensor;
|
||||
output_scores->set_data_type(kNumberTypeFloat32);
|
||||
output_scores->set_shape({1, 10});
|
||||
output_scores->SetFormat(schema::Format_NHWC);
|
||||
output_scores->MallocData();
|
||||
memset(output_scores->Data(), 0, output_scores->ElementsNum() * sizeof(float));
|
||||
|
||||
auto *output_num_det = new lite::tensor::Tensor;
|
||||
output_num_det->set_data_type(kNumberTypeFloat32);
|
||||
output_num_det->set_shape({1});
|
||||
output_num_det->SetFormat(schema::Format_NHWC);
|
||||
output_num_det->MallocData();
|
||||
memset(output_num_det->Data(), 0, output_num_det->ElementsNum() * sizeof(float));
|
||||
|
||||
outputs_->push_back(output_boxes);
|
||||
outputs_->push_back(output_classes);
|
||||
outputs_->push_back(output_scores);
|
||||
outputs_->push_back(output_num_det);
|
||||
|
||||
param->h_scale_ = 5;
|
||||
param->w_scale_ = 5;
|
||||
param->x_scale_ = 10;
|
||||
param->y_scale_ = 10;
|
||||
param->nms_iou_threshold_ = 0.6;
|
||||
param->nms_score_threshold_ = 1e-8;
|
||||
param->max_detections_ = 10;
|
||||
param->detections_per_class_ = 100;
|
||||
param->max_classes_per_detection_ = 1;
|
||||
param->num_classes_ = 90;
|
||||
param->use_regular_nms_ = false;
|
||||
param->out_quantized_ = true;
|
||||
}
|
||||
|
||||
// Runs the kernel on the recorded inputs with use_regular_nms_ == false (as configured by
// DetectionPostProcessTestInit) and compares all four outputs with the reference files output_0..3.bin.
TEST_F(TestDetectionPostProcessFp32, Fast) {
  std::vector<lite::tensor::Tensor *> inputs_;
  std::vector<lite::tensor::Tensor *> outputs_;
  auto param = new DetectionPostProcessParameter();
  DetectionPostProcessTestInit(&inputs_, &outputs_, param);
  auto ctx = new lite::Context;
  ctx->thread_num_ = 1;
  kernel::DetectionPostProcessCPUKernel *op =
    new kernel::DetectionPostProcessCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr);

  // NOTE(review): 0 is presumably the value of lite::RET_OK - confirm against include/errorcode.h.
  // EXPECT (not ASSERT) is used throughout so that the cleanup at the end always runs.
  EXPECT_EQ(0, op->Init());
  EXPECT_EQ(0, op->Run());

  // Reference file for each output tensor, indexed like outputs_.
  const char *reference_paths[] = {
    "./test_data/detectionPostProcess/output_0.bin",  // boxes
    "./test_data/detectionPostProcess/output_1.bin",  // classes
    "./test_data/detectionPostProcess/output_2.bin",  // scores
    "./test_data/detectionPostProcess/output_3.bin",  // number of detections
  };
  for (size_t i = 0; i < sizeof(reference_paths) / sizeof(reference_paths[0]); ++i) {
    size_t reference_size = 0;
    auto *reference = reinterpret_cast<float *>(mindspore::lite::ReadFile(reference_paths[i], &reference_size));
    EXPECT_TRUE(reference != nullptr) << "missing reference data: " << reference_paths[i];
    if (reference == nullptr) {
      continue;  // nothing to compare against; the failure has already been recorded
    }
    float *actual = reinterpret_cast<float *>(outputs_.at(i)->Data());
    CompareOutputData(actual, reference, outputs_.at(i)->ElementsNum(), 0.0001);
  }

  delete op;
  for (auto t : inputs_) delete t;
  for (auto t : outputs_) delete t;
  // The context was allocated by this test; the kernel that referenced it has already been destroyed above.
  delete ctx;
}
|
||||
} // namespace mindspore
|
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -57,11 +57,11 @@ STATUS TfliteCustomParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
|
|||
attr->NmsScoreThreshold = attr_map["nms_score_threshold"].AsFloat();
|
||||
attr->MaxDetections = attr_map["max_detections"].AsInt32();
|
||||
if (attr_map["detections_per_class"].IsNull()) {
|
||||
attr->DetectionsPreClass = 100;
|
||||
attr->DetectionsPerClass = 100;
|
||||
} else {
|
||||
attr->DetectionsPreClass = attr_map["detections_per_class"].AsInt32();
|
||||
attr->DetectionsPerClass = attr_map["detections_per_class"].AsInt32();
|
||||
}
|
||||
attr->MaxClassesPreDetection = attr_map["max_classes_per_detection"].AsInt32();
|
||||
attr->MaxClassesPerDetection = attr_map["max_classes_per_detection"].AsInt32();
|
||||
attr->NumClasses = attr_map["num_classes"].AsInt32();
|
||||
if (attr_map["use_regular_nms"].IsNull()) {
|
||||
attr->UseRegularNms = false;
|
||||
|
|
Loading…
Reference in New Issue