!5744 add DetectionPostProcess op

Merge pull request !5744 from wangzhe/master
mindspore-ci-bot 2020-09-04 15:11:51 +08:00 committed by Gitee
commit 82310bb63f
18 changed files with 1128 additions and 31 deletions

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct DetectionPostProcessParameter {
OpParameter op_parameter_;
float h_scale_;
float w_scale_;
float x_scale_;
float y_scale_;
float nms_iou_threshold_;
float nms_score_threshold_;
int64_t max_detections_;
int64_t detections_per_class_;
int64_t max_classes_per_detection_;
int64_t num_classes_;
bool use_regular_nms_;
bool out_quantized_;
void *decoded_boxes_;
void *nms_candidate_;
void *selected_;
void *score_with_class_;
void *score_with_class_all_;
} DetectionPostProcessParameter;
#endif // MINDSPORE_LITE_NNACL_DETECTION_POST_PROCESS_PARAMETER_H_

View File

@ -0,0 +1,220 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp32/detection_post_process.h"
#include <math.h>
#include "nnacl/errorcode.h"
#include "nnacl/op_base.h"
int ScoreWithIndexCmp(const void *a, const void *b) {
ScoreWithIndex *pa = (ScoreWithIndex *)a;
ScoreWithIndex *pb = (ScoreWithIndex *)b;
if (pa->score > pb->score) {
return -1;
} else if (pa->score < pb->score) {
return 1;
} else {
return 0;
}
}
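// Intersection-over-union of two corner-format boxes; the small epsilon keeps the denominator non-zero.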
float IntersectionOverUnion(const BboxCorner *a, const BboxCorner *b) {
const float area_a = (a->ymax - a->ymin) * (a->xmax - a->xmin);
const float area_b = (b->ymax - b->ymin) * (b->xmax - b->xmin);
if (area_a <= 0 || area_b <= 0) {
return 0.0f;
}
const float ymin = a->ymin > b->ymin ? a->ymin : b->ymin;
const float xmin = a->xmin > b->xmin ? a->xmin : b->xmin;
const float ymax = a->ymax < b->ymax ? a->ymax : b->ymax;
const float xmax = a->xmax < b->xmax ? a->xmax : b->xmax;
const float h = ymax - ymin > 0.0f ? ymax - ymin : 0.0f;
const float w = xmax - xmin > 0.0f ? xmax - xmin : 0.0f;
const float inter = h * w;
return inter / (area_a + area_b - inter + 1e-8);
}
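// Decode center-size box deltas against their anchors into corner-format (ymin, xmin, ymax, xmax) boxes,
// applying the configured y/x/h/w scales.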
void DecodeBoxes(const int num_boxes, const float *input_boxes, const float *anchors, const BboxCenter scaler,
float *decoded_boxes) {
for (int i = 0; i < num_boxes; ++i) {
BboxCenter *box = (BboxCenter *)(input_boxes + i * 4);
BboxCenter *anchor = (BboxCenter *)(anchors + i * 4);
BboxCorner *decoded_box = (BboxCorner *)(decoded_boxes + i * 4);
float y_center = box->y / scaler.y * anchor->h + anchor->y;
float x_center = box->x / scaler.x * anchor->w + anchor->x;
float h_half = 0.5f * expf(box->h / scaler.h) * anchor->h;
float w_half = 0.5f * expf(box->w / scaler.w) * anchor->w;
decoded_box->ymin = y_center - h_half;
decoded_box->xmin = x_center - w_half;
decoded_box->ymax = y_center + h_half;
decoded_box->xmax = x_center + w_half;
}
}
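// Greedy NMS for a single class: sort candidates by score, keep the best remaining box,
// and suppress any candidate whose IoU with a kept box exceeds nms_iou_threshold_.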
int NmsSingleClass(const int candidate_num, const float *decoded_boxes, const int max_detections,
ScoreWithIndex *score_with_index, int *selected, const DetectionPostProcessParameter *param) {
uint8_t *nms_candidate = param->nms_candidate_;
const int output_num = candidate_num < max_detections ? candidate_num : max_detections;
int possible_candidate_num = candidate_num;
int selected_num = 0;
qsort(score_with_index, candidate_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp);
for (int i = 0; i < candidate_num; ++i) {
nms_candidate[i] = 1;
}
for (int i = 0; i < candidate_num; ++i) {
if (possible_candidate_num == 0 || selected_num >= output_num) {
break;
}
if (nms_candidate[i] == 0) {
continue;
}
selected[selected_num++] = score_with_index[i].index;
nms_candidate[i] = 0;
possible_candidate_num--;
for (int t = i + 1; t < candidate_num; ++t) {
if (nms_candidate[t] == 1) {
const BboxCorner *bbox_i = (BboxCorner *)(decoded_boxes) + score_with_index[i].index;
const BboxCorner *bbox_t = (BboxCorner *)(decoded_boxes) + score_with_index[t].index;
const float iou = IntersectionOverUnion(bbox_i, bbox_t);
if (iou > param->nms_iou_threshold_) {
nms_candidate[t] = 0;
possible_candidate_num--;
}
}
}
}
return selected_num;
}
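// Regular multi-class NMS: run single-class NMS per class, merge the surviving (box, class) pairs,
// and keep only the top max_detections_ scores overall.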
int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
const DetectionPostProcessParameter *param) {
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
int *selected = (int *)(param->selected_);
ScoreWithIndex *score_with_index_single = (ScoreWithIndex *)(param->score_with_class_);
int all_classes_sorted_num = 0;
int all_classes_output_num = 0;
ScoreWithIndex *score_with_index_all = (ScoreWithIndex *)(param->score_with_class_all_);
for (int j = first_class_index; j < num_classes_with_bg; ++j) {
int candidate_num = 0;
// process single class
for (int i = 0; i < num_boxes; ++i) {
const float score = input_scores[i * num_classes_with_bg + j];
if (score >= param->nms_score_threshold_) {
score_with_index_single[candidate_num].score = score;
score_with_index_single[candidate_num++].index = i;
}
}
int selected_num = NmsSingleClass(candidate_num, decoded_boxes, param->detections_per_class_,
score_with_index_single, selected, param);
// process all classes
for (int i = 0; i < selected_num; ++i) {
// store class to index
score_with_index_all[all_classes_sorted_num].index = selected[i] * num_classes_with_bg + j;
score_with_index_all[all_classes_sorted_num++].score = input_scores[selected[i] * num_classes_with_bg + j];
}
all_classes_output_num =
all_classes_sorted_num < param->max_detections_ ? all_classes_sorted_num : param->max_detections_;
qsort(score_with_index_all, all_classes_sorted_num, sizeof(ScoreWithIndex), ScoreWithIndexCmp);
all_classes_sorted_num = all_classes_output_num;
}
for (int i = 0; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
if (i < all_classes_output_num) {
const int box_index = score_with_index_all[i].index / num_classes_with_bg;
const int class_index = score_with_index_all[i].index - box_index * num_classes_with_bg - first_class_index;
*((BboxCorner *)(output_boxes) + i) = *((BboxCorner *)(decoded_boxes) + box_index);
output_classes[i] = (float)class_index;
output_scores[i] = score_with_index_all[i].score;
} else {
((BboxCorner *)(output_boxes) + i)->ymin = 0;
((BboxCorner *)(output_boxes) + i)->xmin = 0;
((BboxCorner *)(output_boxes) + i)->ymax = 0;
((BboxCorner *)(output_boxes) + i)->xmax = 0;
output_classes[i] = 0.0f;
output_scores[i] = 0.0f;
}
}
return all_classes_output_num;
}
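// Fast multi-class NMS: rank boxes by their best class score, run one class-agnostic NMS pass,
// then emit up to max_classes_per_detection_ classes for every surviving box.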
int NmsMultiClassesFast(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
const DetectionPostProcessParameter *param) {
const int first_class_index = num_classes_with_bg - (int)(param->num_classes_);
const int64_t max_classes_per_anchor =
param->max_classes_per_detection_ < param->num_classes_ ? param->max_classes_per_detection_ : param->num_classes_;
int candidate_num = 0;
ScoreWithIndex *score_with_class_all = (ScoreWithIndex *)(param->score_with_class_all_);
ScoreWithIndex *score_with_class = (ScoreWithIndex *)(param->score_with_class_);
int *selected = (int *)(param->selected_);
int selected_num;
int output_num = 0;
for (int i = 0; i < num_boxes; ++i) {
for (int j = first_class_index; j < num_classes_with_bg; ++j) {
float score_t = *(input_scores + i * num_classes_with_bg + j);
score_with_class_all[i * param->num_classes_ + j - first_class_index].score = score_t;
// save box and class info to index
score_with_class_all[i * param->num_classes_ + j - first_class_index].index = i * num_classes_with_bg + j;
}
qsort(score_with_class_all + i * param->num_classes_, param->num_classes_, sizeof(ScoreWithIndex),
ScoreWithIndexCmp);
const float score_max = (score_with_class_all + i * param->num_classes_)->score;
if (score_max >= param->nms_score_threshold_) {
score_with_class[candidate_num].index = i;
score_with_class[candidate_num++].score = score_max;
}
}
selected_num =
NmsSingleClass(candidate_num, decoded_boxes, param->max_detections_, score_with_class, selected, param);
for (int i = 0; i < selected_num; ++i) {
const ScoreWithIndex *box_score_with_class = score_with_class_all + selected[i] * param->num_classes_;
const int box_index = box_score_with_class->index / num_classes_with_bg;
for (int j = 0; j < max_classes_per_anchor; ++j) {
*((BboxCorner *)(output_boxes) + output_num) = *((BboxCorner *)(decoded_boxes) + box_index);
output_scores[output_num] = (box_score_with_class + j)->score;
output_classes[output_num++] =
(float)((box_score_with_class + j)->index % num_classes_with_bg - first_class_index);
}
}
for (int i = output_num; i < param->max_detections_ * param->max_classes_per_detection_; ++i) {
((BboxCorner *)(output_boxes) + i)->ymin = 0;
((BboxCorner *)(output_boxes) + i)->xmin = 0;
((BboxCorner *)(output_boxes) + i)->ymax = 0;
((BboxCorner *)(output_boxes) + i)->xmax = 0;
output_scores[i] = 0;
output_classes[i] = 0;
}
return output_num;
}
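// Entry point: decode all boxes against the anchors, then run either the regular (per-class)
// or the fast NMS path depending on use_regular_nms_.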
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes, float *input_scores,
float *input_anchors, float *output_boxes, float *output_classes, float *output_scores,
float *output_num, DetectionPostProcessParameter *param) {
BboxCenter scaler;
scaler.y = param->y_scale_;
scaler.x = param->x_scale_;
scaler.h = param->h_scale_;
scaler.w = param->w_scale_;
DecodeBoxes(num_boxes, input_boxes, input_anchors, scaler, param->decoded_boxes_);
if (param->use_regular_nms_) {
*output_num = NmsMultiClassesRegular(num_boxes, num_classes_with_bg, param->decoded_boxes_, input_scores,
output_boxes, output_classes, output_scores, param);
} else {
*output_num = NmsMultiClassesFast(num_boxes, num_classes_with_bg, param->decoded_boxes_, input_scores, output_boxes,
output_classes, output_scores, param);
}
return NNACL_OK;
}

View File

@ -0,0 +1,57 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_DETECTION_POST_PROCESS_H_
#define MINDSPORE_LITE_NNACL_FP32_DETECTION_POST_PROCESS_H_
#include "nnacl/op_base.h"
#include "nnacl/detection_post_process_parameter.h"
typedef struct {
float y;
float x;
float h;
float w;
} BboxCenter;
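// Box in corner form: (ymin, xmin, ymax, xmax).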
typedef struct {
float ymin;
float xmin;
float ymax;
float xmax;
} BboxCorner;
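// A candidate score paired with the flattened index it was taken from.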
typedef struct {
float score;
int index;
} ScoreWithIndex;
#ifdef __cplusplus
extern "C" {
#endif
int NmsMultiClassesRegular(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
                           const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
                           const DetectionPostProcessParameter *param);
int NmsMultiClassesFast(const int num_boxes, const int num_classes_with_bg, const float *decoded_boxes,
                        const float *input_scores, float *output_boxes, float *output_classes, float *output_scores,
                        const DetectionPostProcessParameter *param);
int DetectionPostProcess(const int num_boxes, const int num_classes_with_bg, float *input_boxes, float *input_scores,
float *input_anchors, float *output_boxes, float *output_classes, float *output_scores,
float *output_num, DetectionPostProcessParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_DETECTION_POST_PROCESS_H_

View File

@ -351,8 +351,8 @@ table DetectionPostProcess {
NmsIouThreshold: float;
NmsScoreThreshold: float;
MaxDetections: long;
- DetectionsPreClass: long;
- MaxClassesPreDetection: long;
+ DetectionsPerClass: long;
+ MaxClassesPerDetection: long;
NumClasses: long;
UseRegularNms: bool;
OutQuantized: bool;

View File

@ -34,11 +34,11 @@ float DetectionPostProcess::GetNmsScoreThreshold() const {
int64_t DetectionPostProcess::GetMaxDetections() const {
return this->primitive_->value.AsDetectionPostProcess()->MaxDetections;
}
- int64_t DetectionPostProcess::GetDetectionsPreClass() const {
- return this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass;
+ int64_t DetectionPostProcess::GetDetectionsPerClass() const {
+ return this->primitive_->value.AsDetectionPostProcess()->DetectionsPerClass;
}
- int64_t DetectionPostProcess::GetMaxClassesPreDetection() const {
- return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection;
+ int64_t DetectionPostProcess::GetMaxClassesPerDetection() const {
+ return this->primitive_->value.AsDetectionPostProcess()->MaxClassesPerDetection;
}
int64_t DetectionPostProcess::GetNumClasses() const {
return this->primitive_->value.AsDetectionPostProcess()->NumClasses;
@ -46,7 +46,6 @@ int64_t DetectionPostProcess::GetNumClasses() const {
bool DetectionPostProcess::GetUseRegularNms() const {
return this->primitive_->value.AsDetectionPostProcess()->UseRegularNms;
}
void DetectionPostProcess::SetFormat(int format) {
this->primitive_->value.AsDetectionPostProcess()->format = (schema::Format)format;
}
@ -72,13 +71,13 @@ void DetectionPostProcess::SetNmsScoreThreshold(float nms_score_threshold) {
this->primitive_->value.AsDetectionPostProcess()->NmsScoreThreshold = nms_score_threshold;
}
void DetectionPostProcess::SetMaxDetections(int64_t max_detections) {
- this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_detections;
+ this->primitive_->value.AsDetectionPostProcess()->MaxDetections = max_detections;
}
- void DetectionPostProcess::SetDetectionsPreClass(int64_t detections_pre_class) {
- this->primitive_->value.AsDetectionPostProcess()->DetectionsPreClass = detections_pre_class;
+ void DetectionPostProcess::SetDetectionsPerClass(int64_t detections_per_class) {
+ this->primitive_->value.AsDetectionPostProcess()->DetectionsPerClass = detections_per_class;
}
- void DetectionPostProcess::SetMaxClassesPreDetection(int64_t max_classes_pre_detection) {
- this->primitive_->value.AsDetectionPostProcess()->MaxClassesPreDetection = max_classes_pre_detection;
+ void DetectionPostProcess::SetMaxClassesPerDetection(int64_t max_classes_per_detection) {
+ this->primitive_->value.AsDetectionPostProcess()->MaxClassesPerDetection = max_classes_per_detection;
}
void DetectionPostProcess::SetNumClasses(int64_t num_classes) {
this->primitive_->value.AsDetectionPostProcess()->NumClasses = num_classes;
@ -86,6 +85,9 @@ void DetectionPostProcess::SetNumClasses(int64_t num_classes) {
void DetectionPostProcess::SetUseRegularNms(bool use_regular_nms) {
this->primitive_->value.AsDetectionPostProcess()->UseRegularNms = use_regular_nms;
}
void DetectionPostProcess::SetOutQuantized(bool out_quantized) {
this->primitive_->value.AsDetectionPostProcess()->OutQuantized = out_quantized;
}
#else
int DetectionPostProcess::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
@ -98,8 +100,8 @@ int DetectionPostProcess::UnPackToFlatBuilder(const schema::Primitive *primitive
}
auto val_offset = schema::CreateDetectionPostProcess(
*fbb, attr->format(), attr->inputSize(), attr->hScale(), attr->wScale(), attr->xScale(), attr->yScale(),
- attr->NmsIouThreshold(), attr->NmsScoreThreshold(), attr->MaxDetections(), attr->DetectionsPreClass(),
- attr->MaxClassesPreDetection(), attr->NumClasses(), attr->UseRegularNms());
+ attr->NmsIouThreshold(), attr->NmsScoreThreshold(), attr->MaxDetections(), attr->DetectionsPerClass(),
+ attr->MaxClassesPerDetection(), attr->NumClasses(), attr->UseRegularNms(), attr->OutQuantized());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_DetectionPostProcess, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
@ -121,11 +123,11 @@ float DetectionPostProcess::GetNmsScoreThreshold() const {
int64_t DetectionPostProcess::GetMaxDetections() const {
return this->primitive_->value_as_DetectionPostProcess()->MaxDetections();
}
- int64_t DetectionPostProcess::GetDetectionsPreClass() const {
- return this->primitive_->value_as_DetectionPostProcess()->DetectionsPreClass();
+ int64_t DetectionPostProcess::GetDetectionsPerClass() const {
+ return this->primitive_->value_as_DetectionPostProcess()->DetectionsPerClass();
}
- int64_t DetectionPostProcess::GetMaxClassesPreDetection() const {
- return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPreDetection();
+ int64_t DetectionPostProcess::GetMaxClassesPerDetection() const {
+ return this->primitive_->value_as_DetectionPostProcess()->MaxClassesPerDetection();
}
int64_t DetectionPostProcess::GetNumClasses() const {
return this->primitive_->value_as_DetectionPostProcess()->NumClasses();
@ -133,7 +135,67 @@ int64_t DetectionPostProcess::GetNumClasses() const {
bool DetectionPostProcess::GetUseRegularNms() const {
return this->primitive_->value_as_DetectionPostProcess()->UseRegularNms();
}
bool DetectionPostProcess::GetOutQuantized() const {
return this->primitive_->value_as_DetectionPostProcess()->OutQuantized();
}
#endif
namespace {
constexpr int kDetectionPostProcessOutputNum = 4;
constexpr int kDetectionPostProcessInputNum = 3;
} // namespace
int DetectionPostProcess::InferShape(std::vector<lite::tensor::Tensor *> inputs_,
std::vector<lite::tensor::Tensor *> outputs_) {
if (outputs_.size() != kDetectionPostProcessOutputNum || inputs_.size() != kDetectionPostProcessInputNum) {
MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs_.size() << ",input size: " << inputs_.size();
return RET_PARAM_INVALID;
}
auto boxes = inputs_.at(0);
MS_ASSERT(boxes != nullptr);
auto scores = inputs_.at(1);
MS_ASSERT(scores != nullptr);
auto anchors = inputs_.at(2);
MS_ASSERT(anchors != nullptr);
const auto input_box_shape = boxes->shape();
const auto input_scores_shape = scores->shape();
const auto input_anchors_shape = anchors->shape();
MS_ASSERT(input_scores_shape[2] >= GetNumClasses());
MS_ASSERT(input_scores_shape[2] - GetNumClasses() <= 1);
MS_ASSERT(input_box_shape[1] == input_scores_shape[1]);
MS_ASSERT(input_box_shape[1] == input_anchors_shape[0]);
auto detected_boxes = outputs_.at(0);
MS_ASSERT(detected_boxes != nullptr);
auto detected_classes = outputs_.at(1);
MS_ASSERT(detected_classes != nullptr);
auto detected_scores = outputs_.at(2);
MS_ASSERT(detected_scores != nullptr);
auto num_det = outputs_.at(3);
MS_ASSERT(num_det != nullptr);
detected_boxes->SetFormat(boxes->GetFormat());
detected_boxes->set_data_type(boxes->data_type());
detected_classes->SetFormat(boxes->GetFormat());
detected_classes->set_data_type(boxes->data_type());
detected_scores->SetFormat(boxes->GetFormat());
detected_scores->set_data_type(boxes->data_type());
num_det->SetFormat(boxes->GetFormat());
num_det->set_data_type(boxes->data_type());
if (!GetInferFlag()) {
return RET_OK;
}
const auto max_detections = GetMaxDetections();
const auto max_classes_per_detection = GetMaxClassesPerDetection();
const auto num_detected_boxes = static_cast<int>(max_detections * max_classes_per_detection);
const std::vector<int> box_shape{1, num_detected_boxes, 4};
const std::vector<int> class_shape{1, num_detected_boxes};
const std::vector<int> num_shape{1};
detected_boxes->set_shape(box_shape);
detected_classes->set_shape(class_shape);
detected_scores->set_shape(class_shape);
num_det->set_shape(num_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -40,15 +40,17 @@ class DetectionPostProcess : public PrimitiveC {
void SetNmsIouThreshold(float nms_iou_threshold);
void SetNmsScoreThreshold(float nms_score_threshold);
void SetMaxDetections(int64_t max_detections);
- void SetDetectionsPreClass(int64_t detections_pre_class);
- void SetMaxClassesPreDetection(int64_t max_classes_pre_detection);
+ void SetDetectionsPerClass(int64_t detections_per_class);
+ void SetMaxClassesPerDetection(int64_t max_classes_per_detection);
void SetNumClasses(int64_t num_classes);
void SetUseRegularNms(bool use_regular_nms);
void SetOutQuantized(bool out_quantized);
#else
DetectionPostProcess() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetFormat() const;
int GetInputSize() const;
float GetHScale() const;
@ -58,11 +60,13 @@ class DetectionPostProcess : public PrimitiveC {
float GetNmsIouThreshold() const;
float GetNmsScoreThreshold() const;
int64_t GetMaxDetections() const;
- int64_t GetDetectionsPreClass() const;
- int64_t GetMaxClassesPreDetection() const;
+ int64_t GetDetectionsPerClass() const;
+ int64_t GetMaxClassesPerDetection() const;
int64_t GetNumClasses() const;
bool GetUseRegularNms() const;
bool GetOutQuantized() const;
};
} // namespace lite
} // namespace mindspore

View File

@ -113,6 +113,7 @@
#include "src/ops/round.h"
#include "src/ops/sparse_to_dense.h"
#include "src/ops/l2_norm.h"
#include "src/ops/detection_post_process.h"
#include "nnacl/op_base.h"
#include "nnacl/fp32/arg_min_max.h"
#include "nnacl/fp32/cast.h"
@ -171,6 +172,7 @@
#include "nnacl/leaky_relu_parameter.h"
#include "nnacl/sparse_to_dense.h"
#include "nnacl/l2_norm_parameter.h"
#include "nnacl/detection_post_process_parameter.h"
namespace mindspore::kernel {
@ -1517,20 +1519,17 @@ OpParameter *PopulateEluParameter(const mindspore::lite::PrimitiveC *primitive)
OpParameter *PopulateL2NormParameter(
const mindspore::lite::PrimitiveC *primitive) {
- L2NormParameter *l2_norm_parameter =
- reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter)));
+ L2NormParameter *l2_norm_parameter = reinterpret_cast<L2NormParameter *>(malloc(sizeof(L2NormParameter)));
if (l2_norm_parameter == nullptr) {
MS_LOG(ERROR) << "malloc L2NormParameter failed.";
return nullptr;
}
memset(l2_norm_parameter, 0, sizeof(L2NormParameter));
l2_norm_parameter->op_parameter_.type_ = primitive->Type();
- auto param = reinterpret_cast<mindspore::lite::L2Norm *>(
- const_cast<mindspore::lite::PrimitiveC *>(primitive));
+ auto param = reinterpret_cast<mindspore::lite::L2Norm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
auto axis_vec = param->GetAxis();
l2_norm_parameter->axis_num_ = axis_vec.size();
- l2_norm_parameter->axis_ =
- reinterpret_cast<int *>(malloc(axis_vec.size() * sizeof(int)));
+ l2_norm_parameter->axis_ = reinterpret_cast<int *>(malloc(axis_vec.size() * sizeof(int)));
for (size_t i = 0; i < axis_vec.size(); i++) {
l2_norm_parameter->axis_[i] = axis_vec[i];
}
@ -1542,6 +1541,31 @@ OpParameter *PopulateL2NormParameter(
return reinterpret_cast<OpParameter *>(l2_norm_parameter);
}
OpParameter *PopulateDetectionPostProcessParameter(const mindspore::lite::PrimitiveC *primitive) {
DetectionPostProcessParameter *detection_post_process_parameter =
reinterpret_cast<DetectionPostProcessParameter *>(malloc(sizeof(DetectionPostProcessParameter)));
if (detection_post_process_parameter == nullptr) {
MS_LOG(ERROR) << "malloc EluParameter failed.";
return nullptr;
}
memset(detection_post_process_parameter, 0, sizeof(DetectionPostProcessParameter));
detection_post_process_parameter->op_parameter_.type_ = primitive->Type();
auto param =
reinterpret_cast<mindspore::lite::DetectionPostProcess *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
detection_post_process_parameter->h_scale_ = param->GetHScale();
detection_post_process_parameter->w_scale_ = param->GetWScale();
detection_post_process_parameter->x_scale_ = param->GetXScale();
detection_post_process_parameter->y_scale_ = param->GetYScale();
detection_post_process_parameter->nms_iou_threshold_ = param->GetNmsIouThreshold();
detection_post_process_parameter->nms_score_threshold_ = param->GetNmsScoreThreshold();
detection_post_process_parameter->max_detections_ = param->GetMaxDetections();
detection_post_process_parameter->detections_per_class_ = param->GetDetectionsPerClass();
detection_post_process_parameter->max_classes_per_detection_ = param->GetMaxClassesPerDetection();
detection_post_process_parameter->num_classes_ = param->GetNumClasses();
detection_post_process_parameter->use_regular_nms_ = param->GetUseRegularNms();
return reinterpret_cast<OpParameter *>(detection_post_process_parameter);
}
PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_SparseToDense] = PopulateSparseToDenseParameter;
populate_parameter_funcs_[schema::PrimitiveType_SoftMax] = PopulateSoftmaxParameter;
@ -1640,6 +1664,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_EmbeddingLookup] = PopulateEmbeddingLookupParameter;
populate_parameter_funcs_[schema::PrimitiveType_Elu] = PopulateEluParameter;
populate_parameter_funcs_[schema::PrimitiveType_L2Norm] = PopulateL2NormParameter;
populate_parameter_funcs_[schema::PrimitiveType_DetectionPostProcess] = PopulateDetectionPostProcessParameter;
}
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {

View File

@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/detection_post_process.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DetectionPostProcess;
namespace mindspore::kernel {
int DetectionPostProcessCPUKernel::Init() { return RET_OK; }
int DetectionPostProcessCPUKernel::ReSize() { return RET_OK; }
int DetectionPostProcessCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto input_boxes = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
auto input_scores = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
auto input_anchors = reinterpret_cast<float *>(in_tensors_.at(2)->Data());
// output_classes and output_num use float type now
auto output_boxes = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
auto output_classes = reinterpret_cast<float *>(out_tensors_.at(1)->Data());
auto output_scores = reinterpret_cast<float *>(out_tensors_.at(2)->Data());
auto output_num = reinterpret_cast<float *>(out_tensors_.at(3)->Data());
MS_ASSERT(context_->allocator != nullptr);
const int num_boxes = in_tensors_.at(0)->shape()[1];
const int num_classes_with_bg = in_tensors_.at(1)->shape()[2];
DetectionPostProcessParameter *parameter = reinterpret_cast<DetectionPostProcessParameter *>(op_parameter_);
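// Scratch buffers for decoding and NMS bookkeeping, taken from the context allocator and freed below.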
parameter->decoded_boxes_ = context_->allocator->Malloc(num_boxes * 4 * sizeof(float));
parameter->nms_candidate_ = context_->allocator->Malloc(num_boxes * sizeof(uint8_t));
parameter->selected_ = context_->allocator->Malloc(num_boxes * sizeof(int));
parameter->score_with_class_ = context_->allocator->Malloc(num_boxes * sizeof(ScoreWithIndex));
if (parameter->use_regular_nms_) {
parameter->score_with_class_all_ =
context_->allocator->Malloc((num_boxes + parameter->max_detections_) * sizeof(ScoreWithIndex));
} else {
parameter->score_with_class_all_ =
context_->allocator->Malloc((num_boxes * parameter->num_classes_) * sizeof(ScoreWithIndex));
}
DetectionPostProcess(num_boxes, num_classes_with_bg, input_boxes, input_scores, input_anchors, output_boxes,
output_classes, output_scores, output_num, parameter);
context_->allocator->Free(parameter->decoded_boxes_);
context_->allocator->Free(parameter->nms_candidate_);
context_->allocator->Free(parameter->selected_);
context_->allocator->Free(parameter->score_with_class_);
context_->allocator->Free(parameter->score_with_class_all_);
return RET_OK;
}
kernel::LiteKernel *CpuDetectionPostProcessFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Create kernel failed, opParameter is nullptr, type: PrimitiveType_DetectionPostProcess. ";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_DetectionPostProcess);
auto *kernel = new (std::nothrow) DetectionPostProcessCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new DetectionPostProcessCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_DetectionPostProcess, CpuDetectionPostProcessFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "nnacl/fp32/detection_post_process.h"
using mindspore::lite::Context;
namespace mindspore::kernel {
class DetectionPostProcessCPUKernel : public LiteKernel {
public:
DetectionPostProcessCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<DetectionPostProcessParameter *>(parameter);
}
~DetectionPostProcessCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
private:
DetectionPostProcessParameter *param_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DETECTION_POST_PROCESS_H_

View File

@ -0,0 +1,159 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/fp32/detection_post_process.h"
#include "src/kernel_registry.h"
#include "src/lite_kernel.h"
#include "src/common/file_utils.h"
namespace mindspore {
class TestDetectionPostProcessFp32 : public mindspore::CommonTest {
public:
TestDetectionPostProcessFp32() {}
};
void DetectionPostProcessTestInit(std::vector<lite::tensor::Tensor *> *inputs_,
std::vector<lite::tensor::Tensor *> *outputs_, DetectionPostProcessParameter *param) {
std::string input_boxes_path = "./test_data/detectionPostProcess/input_boxes.bin";
size_t input_boxes_size;
auto input_boxes_data =
reinterpret_cast<float *>(mindspore::lite::ReadFile(input_boxes_path.c_str(), &input_boxes_size));
auto *input_boxes = new lite::tensor::Tensor;
input_boxes->set_data_type(kNumberTypeFloat32);
input_boxes->SetFormat(schema::Format_NHWC);
input_boxes->set_shape({1, 1917, 4});
input_boxes->MallocData();
memcpy(input_boxes->Data(), input_boxes_data, input_boxes_size);
inputs_->push_back(input_boxes);
std::string input_scores_path = "./test_data/detectionPostProcess/input_scores.bin";
size_t input_scores_size;
auto input_scores_data =
reinterpret_cast<float *>(mindspore::lite::ReadFile(input_scores_path.c_str(), &input_scores_size));
auto *input_scores = new lite::tensor::Tensor;
input_scores->set_data_type(kNumberTypeFloat32);
input_scores->SetFormat(schema::Format_NHWC);
input_scores->set_shape({1, 1917, 91});
input_scores->MallocData();
memcpy(input_scores->Data(), input_scores_data, input_scores_size);
inputs_->push_back(input_scores);
std::string input_anchors_path = "./test_data/detectionPostProcess/input_anchors.bin";
size_t input_anchors_size;
auto input_anchors_data =
reinterpret_cast<float *>(mindspore::lite::ReadFile(input_anchors_path.c_str(), &input_anchors_size));
auto *input_anchors = new lite::tensor::Tensor;
input_anchors->set_data_type(kNumberTypeFloat32);
input_anchors->SetFormat(schema::Format_NHWC);
input_anchors->set_shape({1917, 4});
input_anchors->MallocData();
memcpy(input_anchors->Data(), input_anchors_data, input_anchors_size);
inputs_->push_back(input_anchors);
auto *output_boxes = new lite::tensor::Tensor;
output_boxes->set_data_type(kNumberTypeFloat32);
output_boxes->set_shape({1, 10, 4});
output_boxes->SetFormat(schema::Format_NHWC);
output_boxes->MallocData();
memset(output_boxes->Data(), 0, output_boxes->ElementsNum() * sizeof(float));
auto *output_classes = new lite::tensor::Tensor;
output_classes->set_data_type(kNumberTypeFloat32);
output_classes->set_shape({1, 10});
output_classes->SetFormat(schema::Format_NHWC);
output_classes->MallocData();
memset(output_classes->Data(), 0, output_classes->ElementsNum() * sizeof(float));
auto *output_scores = new lite::tensor::Tensor;
output_scores->set_data_type(kNumberTypeFloat32);
output_scores->set_shape({1, 10});
output_scores->SetFormat(schema::Format_NHWC);
output_scores->MallocData();
memset(output_scores->Data(), 0, output_scores->ElementsNum() * sizeof(float));
auto *output_num_det = new lite::tensor::Tensor;
output_num_det->set_data_type(kNumberTypeFloat32);
output_num_det->set_shape({1});
output_num_det->SetFormat(schema::Format_NHWC);
output_num_det->MallocData();
memset(output_num_det->Data(), 0, output_num_det->ElementsNum() * sizeof(float));
outputs_->push_back(output_boxes);
outputs_->push_back(output_classes);
outputs_->push_back(output_scores);
outputs_->push_back(output_num_det);
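// Post-processing parameters for the test case: 90 classes plus a background class, fast NMS path.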
param->h_scale_ = 5;
param->w_scale_ = 5;
param->x_scale_ = 10;
param->y_scale_ = 10;
param->nms_iou_threshold_ = 0.6;
param->nms_score_threshold_ = 1e-8;
param->max_detections_ = 10;
param->detections_per_class_ = 100;
param->max_classes_per_detection_ = 1;
param->num_classes_ = 90;
param->use_regular_nms_ = false;
param->out_quantized_ = true;
}
TEST_F(TestDetectionPostProcessFp32, Fast) {
std::vector<lite::tensor::Tensor *> inputs_;
std::vector<lite::tensor::Tensor *> outputs_;
auto param = new DetectionPostProcessParameter();
DetectionPostProcessTestInit(&inputs_, &outputs_, param);
auto ctx = new lite::Context;
ctx->thread_num_ = 1;
kernel::DetectionPostProcessCPUKernel *op =
new kernel::DetectionPostProcessCPUKernel(reinterpret_cast<OpParameter *>(param), inputs_, outputs_, ctx, nullptr);
op->Init();
op->Run();
float *output_boxes = reinterpret_cast<float *>(outputs_[0]->Data());
size_t output_boxes_size;
std::string output_boxes_path = "./test_data/detectionPostProcess/output_0.bin";
auto correct_boxes =
reinterpret_cast<float *>(mindspore::lite::ReadFile(output_boxes_path.c_str(), &output_boxes_size));
CompareOutputData(output_boxes, correct_boxes, outputs_[0]->ElementsNum(), 0.0001);
float *output_classes = reinterpret_cast<float *>(outputs_[1]->Data());
size_t output_classes_size;
std::string output_classes_path = "./test_data/detectionPostProcess/output_1.bin";
auto correct_classes =
reinterpret_cast<float *>(mindspore::lite::ReadFile(output_classes_path.c_str(), &output_classes_size));
CompareOutputData(output_classes, correct_classes, outputs_[1]->ElementsNum(), 0.0001);
float *output_scores = reinterpret_cast<float *>(outputs_[2]->Data());
size_t output_scores_size;
std::string output_scores_path = "./test_data/detectionPostProcess/output_2.bin";
auto correct_scores =
reinterpret_cast<float *>(mindspore::lite::ReadFile(output_scores_path.c_str(), &output_scores_size));
CompareOutputData(output_scores, correct_scores, outputs_[2]->ElementsNum(), 0.0001);
float *output_num_det = reinterpret_cast<float *>(outputs_[3]->Data());
size_t output_num_det_size;
std::string output_num_det_path = "./test_data/detectionPostProcess/output_3.bin";
auto correct_num_det =
reinterpret_cast<float *>(mindspore::lite::ReadFile(output_num_det_path.c_str(), &output_num_det_size));
CompareOutputData(output_num_det, correct_num_det, outputs_[3]->ElementsNum(), 0.0001);
delete op;
for (auto t : inputs_) delete t;
for (auto t : outputs_) delete t;
}
} // namespace mindspore

File diff suppressed because one or more lines are too long

View File

@ -57,11 +57,11 @@ STATUS TfliteCustomParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
attr->NmsScoreThreshold = attr_map["nms_score_threshold"].AsFloat();
attr->MaxDetections = attr_map["max_detections"].AsInt32();
if (attr_map["detections_per_class"].IsNull()) {
- attr->DetectionsPreClass = 100;
+ attr->DetectionsPerClass = 100;
} else {
- attr->DetectionsPreClass = attr_map["detections_per_class"].AsInt32();
+ attr->DetectionsPerClass = attr_map["detections_per_class"].AsInt32();
}
- attr->MaxClassesPreDetection = attr_map["max_classes_per_detection"].AsInt32();
+ attr->MaxClassesPerDetection = attr_map["max_classes_per_detection"].AsInt32();
attr->NumClasses = attr_map["num_classes"].AsInt32();
if (attr_map["use_regular_nms"].IsNull()) {
attr->UseRegularNms = false;