[feat][assistant][I40FG0] add new Ascend operator NonMaxSuppressionV3

韩峥嵘 2021-12-07 13:57:25 +08:00
parent 9558ba49d8
commit efe0fce473
10 changed files with 293 additions and 8 deletions

View File

@@ -63,6 +63,7 @@ constexpr auto kPadAndShift = "PadAndShift";
constexpr auto kCpuRunApi = "RunCpuKernel";
constexpr auto kDropout2D = "Dropout2D";
constexpr auto kDropout3D = "Dropout3D";
constexpr auto kNonMaxSuppressionV3 = "NonMaxSuppressionV3";
constexpr auto kMaskedSelect = "MaskedSelect";
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
constexpr auto kDynamicStitch = "DynamicStitch";
@@ -72,8 +73,8 @@ constexpr auto kResizeBilinearGrad = "ResizeBilinearGrad";
constexpr auto kScatterElements = "ScatterElements";
const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
kPadAndShift, kDropout3D, kDropout2D};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift,
kDropout3D, kDropout2D, kNonMaxSuppressionV3};
const std::set<std::string> kCpuKernelBaseOps{kGetNext, kInitData, kRandomChoiceWithMask};
const std::set<std::string> kDynamicInputOps{
kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName, kStackPushOpName, kStackPopOpName, kDynamicStitch};

View File

@@ -81,6 +81,7 @@ constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";
constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2";
constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3";
constexpr auto kNonMaxSuppressionV3OpName = "NonMaxSuppressionV3";
constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad";
constexpr auto kMeanGradOpName = "MeanGrad";
constexpr auto kSliceOpName = "Slice";
@@ -716,9 +717,10 @@ const std::set<std::string> kHWSpecialFormatSet = {
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
kMaskedSelectOpName, kDynamicStitchOpName, kGetNextOpName};
const std::set<std::string> kComputeDepend = {
kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName,
kGetNextOpName, kNonMaxSuppressionV3OpName};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

View File

@@ -590,6 +590,9 @@ inline const PrimitivePtr kPrimComplex = std::make_shared<Primitive>("Complex");
inline const PrimitivePtr kPrimXdivy = std::make_shared<Primitive>("Xdivy");
inline const PrimitivePtr kPrimInv = std::make_shared<Primitive>("Inv");
// Image
inline const PrimitivePtr kPrimNonMaxSuppressionV3 = std::make_shared<Primitive>("NonMaxSuppressionV3");
// Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
inline const PrimitivePtr kPrimUnroll = std::make_shared<Primitive>("Unroll");

View File

@@ -0,0 +1,115 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include "ops/non_max_suppression_v3.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int input_num = 5;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
auto boxes_shape = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]);
auto scores_shape = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]);
auto max_output_size_shape = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape]);
auto iou_threshold_shape = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->GetShapeTrack())[kShape]);
auto score_threshold_shape = std::make_shared<abstract::Shape>(
CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape]);
// boxes must be rank 2
(void)CheckAndConvertUtils::CheckInteger("boxes rank", boxes_shape->shape().size(), kEqual, 2, prim_name);
// boxes second dimension must equal 4
(void)CheckAndConvertUtils::CheckInteger("boxes second dimension", boxes_shape->shape()[1], kEqual, 4, prim_name);
// scores must be rank 1
(void)CheckAndConvertUtils::CheckInteger("scores rank", scores_shape->shape().size(), kEqual, 1, prim_name);
// scores length must equal the first dimension of boxes
(void)CheckAndConvertUtils::CheckInteger("scores length", scores_shape->shape()[0], kEqual, boxes_shape->shape()[0],
prim_name);
// max_output_size, iou_threshold and score_threshold must be scalars
(void)CheckAndConvertUtils::CheckInteger("max_output_size size", max_output_size_shape->shape().size(), kEqual, 0,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("iou_threshold size", iou_threshold_shape->shape().size(), kEqual, 0,
prim_name);
(void)CheckAndConvertUtils::CheckInteger("score_threshold size", score_threshold_shape->shape().size(), kEqual, 0,
prim_name);
auto scores_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
// calculate output shape
ShapeVector selected_indices_shape = {abstract::Shape::SHP_ANY};
ShapeVector selected_indices_min_shape = {0};
ShapeVector selected_indices_max_shape;
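// If the first dimension of scores is dynamic (-1), the recorded max shape of scores bounds the output
// length; otherwise the static scores length itself is the upper bound.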
if (scores_shape_map[kShape].size() > 0 && scores_shape_map[kShape][0] == -1) {
selected_indices_max_shape = scores_shape_map[kMaxShape];
return std::make_shared<abstract::Shape>(selected_indices_shape, selected_indices_min_shape,
selected_indices_max_shape);
}
selected_indices_max_shape = scores_shape_map[kShape];
return std::make_shared<abstract::Shape>(selected_indices_shape, selected_indices_min_shape,
selected_indices_max_shape);
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_name = prim->name();
const int input_num = 5;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto boxes_type = input_args[0]->BuildType();
auto scores_type = input_args[1]->BuildType();
auto max_output_size_type = input_args[2]->BuildType();
auto iou_threshold_type = input_args[3]->BuildType();
auto score_threshold_type = input_args[4]->BuildType();
// boxes and scores must have the same float type
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> args;
args.insert({"boxes_type", boxes_type});
args.insert({"scores_type", scores_type});
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
// iou_threshold and score_threshold must have the same float type
std::map<std::string, TypePtr> args2;
args2.insert({"iou_threshold_type", iou_threshold_type});
args2.insert({"score_threshold_type", score_threshold_type});
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args2, valid_types, prim_name);
// max_output_size must be int32 or int64
const std::set<TypePtr> valid_types2 = {kInt32, kInt64};
std::map<std::string, TypePtr> args3;
args3.insert({"max_output_size_type", max_output_size_type});
(void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types2, prim_name);
return max_output_size_type;
}
} // namespace
AbstractBasePtr NonMaxSuppressionV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(NonMaxSuppressionV3, prim::kPrimNonMaxSuppressionV3, NonMaxSuppressionV3Infer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_
#define MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameNonMaxSuppressionV3 = "NonMaxSuppressionV3";
class NonMaxSuppressionV3 : public PrimitiveC {
public:
NonMaxSuppressionV3() : PrimitiveC(kNameNonMaxSuppressionV3) {
InitIOName({"boxes", "score", "max_output_size", "iou_threshold", "score_threshold"}, {"selected_indices"});
}
~NonMaxSuppressionV3() = default;
MS_DECLARE_PARENT(NonMaxSuppressionV3, PrimitiveC);
};
AbstractBasePtr NonMaxSuppressionV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimNonMaxSuppressionV3Ptr = std::shared_ptr<NonMaxSuppressionV3>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_

View File

@@ -82,3 +82,4 @@ from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
from .resize_bilinear import _resize_bilinear_aicpu
from .resize_bilinear_grad import _resize_bilinear_grad_aicpu
from .scatter_elements import _scatter_elements_aicpu
from .non_max_suppression import _non_max_suppression_aicpu

View File

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NonMaxSuppressionV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
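# Each dtype_format entry lists the operand dtypes positionally:
# boxes, scores, max_output_size, iou_threshold, score_threshold, selected_indices.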
non_max_suppression_op_info = AiCPURegOp("NonMaxSuppressionV3")\
.fusion_type("OPAQUE")\
.input(0, "boxes", "required")\
.input(1, "scores", "required")\
.input(2, "max_output_size", "required")\
.input(3, "iou_threshold", "required")\
.input(4, "score_threshold", "required")\
.output(0, "selected_indices", "required")\
.dtype_format(DataType.F32_Default, DataType.F32_Default, \
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default)\
.dtype_format(DataType.F16_Default, DataType.F16_Default, \
DataType.I32_Default, DataType.F16_Default, DataType.F16_Default, DataType.I32_Default)\
.get_op_info()
@op_info_register(non_max_suppression_op_info)
def _non_max_suppression_aicpu():
"""NonMaxSuppression AiCPU register"""
return

View File

@@ -19,7 +19,7 @@ Primitive operator classes.
A collection of operators to build neural networks or to compute functions.
"""
from .image_ops import (CropAndResize)
from .image_ops import (CropAndResize, NonMaxSuppressionV3)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
Diag, DiagPart, DType, ExpandDims, Eye,
Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
@@ -207,6 +207,7 @@ __all__ = [
'UniqueWithPad',
'Concat',
'Pack',
'NonMaxSuppressionV3',
'Stack',
'Unpack',
'Unstack',

View File

@@ -18,7 +18,7 @@ from ... import context
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
class CropAndResize(PrimitiveWithInfer):
@@ -144,3 +144,65 @@ class CropAndResize(PrimitiveWithInfer):
return {'shape': out_shape,
'dtype': mstype.float32,
'value': None}
class NonMaxSuppressionV3(Primitive):
r"""
Greedily selects a subset of bounding boxes in descending order of score.
.. warning::
When input "max_output_size" is negative, it will be treated as 0.
Note:
This algorithm is agnostic to where the origin is in the coordinate system.
This algorithm is invariant to orthogonal transformations and translations of the coordinate system;
thus translating or reflecting the coordinate system results in the same boxes being
selected by the algorithm.
Inputs:
- **boxes** (Tensor) - A 2-D Tensor of shape [num_boxes, 4].
- **scores** (Tensor) - A 1-D Tensor of shape [num_boxes] representing a single score
corresponding to each box (each row of `boxes`); the num_boxes of `scores` must equal
the num_boxes of `boxes`.
- **max_output_size** (Union[Tensor, Number.Int]) - A scalar integer Tensor representing the maximum
number of boxes to be selected by non-max suppression.
- **iou_threshold** (Union[Tensor, Number.Float]) - A 0-D float tensor representing the threshold for
deciding whether boxes overlap too much with respect to IoU; `iou_threshold` must be greater than or
equal to 0 and less than or equal to 1.
- **score_threshold** (Union[Tensor, Number.Float]) - A 0-D float tensor representing the threshold for
deciding when to remove boxes based on score.
Outputs:
A 1-D integer Tensor of shape [M] representing the selected indices from the boxes tensor,
where M <= max_output_size.
Raises:
TypeError: If the dtype of `boxes` and `scores` is different.
TypeError: If the dtype of `iou_threshold` and `score_threshold` is different.
TypeError: If `boxes` is not a Tensor or its dtype is neither float16 nor float32.
TypeError: If `scores` is not a Tensor or its dtype is neither float16 nor float32.
TypeError: If `max_output_size` is not a Tensor or a scalar, or its dtype is neither int32 nor int64.
TypeError: If `iou_threshold` is not a Tensor or a scalar, or its dtype is neither float16 nor float32.
TypeError: If `score_threshold` is not a Tensor or a scalar, or its dtype is neither float16 nor float32.
ValueError: If the rank of `boxes` is not 2 or its second dimension is not 4.
ValueError: If the rank of `scores` is not 1.
ValueError: If any of `max_output_size`, `iou_threshold` and `score_threshold` is not a scalar (rank 0).
Supported Platforms:
``Ascend``
Examples:
>>> boxes = Tensor(np.array([[1, 2, 3, 4], [1, 3, 3, 4], [1, 3, 4, 4],
... [1, 1, 4, 4], [1, 1, 3, 4]]), mstype.float32)
>>> scores = Tensor(np.array([0.4, 0.5, 0.72, 0.9, 0.45]), mstype.float32)
>>> max_output_size = Tensor(5, mstype.int32)
>>> iou_threshold = Tensor(0.5, mstype.float32)
>>> score_threshold = Tensor(0, mstype.float32)
>>> nonmaxsuppression = ops.NonMaxSuppressionV3()
>>> output = nonmaxsuppression(boxes, scores, max_output_size, iou_threshold, score_threshold)
>>> print(output)
[3 2 0]
"""
@prim_attr_register
def __init__(self):
"""Initialize NonMaxSuppressionV3"""

View File

@@ -2660,6 +2660,22 @@ test_case_array_ops = [
}),
]
test_case_image_ops = [
('NonMaxSuppressionV3', {
'block': P.NonMaxSuppressionV3(),
'desc_inputs': [Tensor(np.array([[20, 5, 200, 100],
[50, 50, 200, 200],
[20, 120, 150, 150],
[250, 250, 400, 350],
[90, 10, 300, 300],
[40, 220, 280, 380]]).astype(np.float32)),
Tensor(np.array([0.353, 0.624, 0.667, 0.5, 0.3, 0.46]).astype(np.float32)),
Tensor(4, mstype.int32),
Tensor(0.1, mstype.float32),
Tensor(0, mstype.float32)],
'skip': ['backward']}),
]
test_case_other_ops = [
('ScalarLog', {
'block': F.scalar_log,
@@ -2983,7 +2999,7 @@ test_case_quant_ops = [
]
test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops,
test_case_other_ops, test_case_quant_ops]
test_case_other_ops, test_case_quant_ops, test_case_image_ops]
test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm