diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
index 8a9c6aaafc2..851a71d69f4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h
@@ -63,6 +63,7 @@ constexpr auto kPadAndShift = "PadAndShift";
 constexpr auto kCpuRunApi = "RunCpuKernel";
 constexpr auto kDropout2D = "Dropout2D";
 constexpr auto kDropout3D = "Dropout3D";
+constexpr auto kNonMaxSuppressionV3 = "NonMaxSuppressionV3";
 constexpr auto kMaskedSelect = "MaskedSelect";
 constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
 constexpr auto kDynamicStitch = "DynamicStitch";
@@ -72,8 +73,8 @@ constexpr auto kResizeBilinearGrad = "ResizeBilinearGrad";
 constexpr auto kScatterElements = "ScatterElements";
 const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
                                           kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements};
-const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter,
-                                            kPadAndShift, kDropout3D, kDropout2D};
+const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift,
+                                            kDropout3D, kDropout2D, kNonMaxSuppressionV3};
 const std::set<std::string> kCpuKernelBaseOps{kGetNext, kInitData, kRandomChoiceWithMask};
 const std::set<std::string> kDynamicInputOps{
   kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName, kStackPushOpName, kStackPopOpName, kDynamicStitch};
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 0926e274a58..9d6ea97c61f 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -81,6 +81,7 @@ constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
 constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";
 constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2";
 constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3";
+constexpr auto kNonMaxSuppressionV3OpName = "NonMaxSuppressionV3";
 constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad";
 constexpr auto kMeanGradOpName = "MeanGrad";
 constexpr auto kSliceOpName = "Slice";
@@ -716,9 +717,10 @@ const std::set<std::string> kHWSpecialFormatSet = {
 
 const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
 
-const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
-                                              kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,
-                                              kMaskedSelectOpName, kDynamicStitchOpName, kGetNextOpName};
+const std::set<std::string> kComputeDepend = {
+  kUniqueOpName,           kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
+  kCTCGreedyDecoderOpName, kDropoutGenMaskOpName,        kMaskedSelectOpName, kDynamicStitchOpName,
+  kGetNextOpName,          kNonMaxSuppressionV3OpName};
 
 const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
                                             kOpFormat_NDHWC, kOpFormat_DHWCN,    kOpFormat_DHWNC};
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 6e9c39836b0..1f4094c3cc3 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -590,6 +590,9 @@ inline const PrimitivePtr kPrimComplex = std::make_shared<Primitive>("Complex");
 inline const PrimitivePtr kPrimXdivy = std::make_shared<Primitive>("Xdivy");
 inline const PrimitivePtr kPrimInv = std::make_shared<Primitive>("Inv");
 
+// Image
+inline const PrimitivePtr kPrimNonMaxSuppressionV3 = std::make_shared<Primitive>("NonMaxSuppressionV3");
+
 // Statements
 inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
 inline const PrimitivePtr kPrimUnroll = std::make_shared<Primitive>("Unroll");
diff --git a/mindspore/core/ops/non_max_suppression_v3.cc b/mindspore/core/ops/non_max_suppression_v3.cc
new file mode 100644
index 00000000000..3b42f97a771
--- /dev/null
+++ b/mindspore/core/ops/non_max_suppression_v3.cc
@@ -0,0 +1,115 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+#include "ops/non_max_suppression_v3.h"
+
+namespace mindspore {
+namespace ops {
+namespace {
+abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  const int input_num = 5;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
+  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
+  auto boxes_shape = std::make_shared<abstract::Shape>(
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape]);
+  auto scores_shape = std::make_shared<abstract::Shape>(
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape]);
+  auto max_output_size_shape = std::make_shared<abstract::Shape>(
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape]);
+  auto iou_threshold_shape = std::make_shared<abstract::Shape>(
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->GetShapeTrack())[kShape]);
+  auto score_threshold_shape = std::make_shared<abstract::Shape>(
+    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape]);
+  // boxes must be rank 2
+  (void)CheckAndConvertUtils::CheckInteger("boxes rank", boxes_shape->shape().size(), kEqual, 2, prim_name);
+  // boxes second dimension must equal 4
+  (void)CheckAndConvertUtils::CheckInteger("boxes second dimension", boxes_shape->shape()[1], kEqual, 4, prim_name);
+  // scores must be rank 1
+  (void)CheckAndConvertUtils::CheckInteger("scores rank", scores_shape->shape().size(), kEqual, 1, prim_name);
+  // scores length must equal boxes first dimension
+  (void)CheckAndConvertUtils::CheckInteger("scores length", scores_shape->shape()[0], kEqual, boxes_shape->shape()[0],
+                                           prim_name);
+  // max_output_size, iou_threshold and score_threshold must be scalars
+  (void)CheckAndConvertUtils::CheckInteger("max_output_size size", max_output_size_shape->shape().size(), kEqual, 0,
+                                           prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("iou_threshold size", iou_threshold_shape->shape().size(), kEqual, 0,
+                                           prim_name);
+  (void)CheckAndConvertUtils::CheckInteger("score_threshold size", score_threshold_shape->shape().size(), kEqual, 0,
+                                           prim_name);
+  auto scores_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
+  // calculate output shape
+  ShapeVector selected_indices_shape = {abstract::Shape::SHP_ANY};
+  ShapeVector selected_indices_min_shape = {0};
+  ShapeVector selected_indices_max_shape;
+  if (scores_shape_map[kShape].size() > 0 && scores_shape_map[kShape][0] == -1) {
+    selected_indices_max_shape = scores_shape_map[kMaxShape];
+    return std::make_shared<abstract::Shape>(selected_indices_shape, selected_indices_min_shape,
+                                             selected_indices_max_shape);
+  }
+  selected_indices_max_shape = scores_shape_map[kShape];
+  return std::make_shared<abstract::Shape>(selected_indices_shape, selected_indices_min_shape,
+                                           selected_indices_max_shape);
+}
+
+TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto prim_name = prim->name();
+  const int input_num = 5;
+  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
+  for (const auto &item : input_args) {
+    MS_EXCEPTION_IF_NULL(item);
+  }
+  auto boxes_type = input_args[0]->BuildType();
+  auto scores_type = input_args[1]->BuildType();
+  auto max_output_size_type = input_args[2]->BuildType();
+  auto iou_threshold_type = input_args[3]->BuildType();
+  auto score_threshold_type = input_args[4]->BuildType();
+  // boxes and scores must have the same float type
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
+  std::map<std::string, TypePtr> args;
+  args.insert({"boxes_type", boxes_type});
+  args.insert({"scores_type", scores_type});
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
+  // iou_threshold and score_threshold must have the same float type
+  std::map<std::string, TypePtr> args2;
+  args2.insert({"iou_threshold_type", iou_threshold_type});
+  args2.insert({"score_threshold_type", score_threshold_type});
+  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args2, valid_types, prim_name);
+  // max_output_size must be int32 or int64
+  const std::set<TypePtr> valid_types2 = {kInt32, kInt64};
+  std::map<std::string, TypePtr> args3;
+  args3.insert({"max_output_size_type", max_output_size_type});
+  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args3, valid_types2, prim_name);
+  return max_output_size_type;
+}
+}  // namespace
+AbstractBasePtr NonMaxSuppressionV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(NonMaxSuppressionV3, prim::kPrimNonMaxSuppressionV3, NonMaxSuppressionV3Infer, nullptr,
+                             true);
+}  // namespace ops
+}  // namespace mindspore
diff --git a/mindspore/core/ops/non_max_suppression_v3.h b/mindspore/core/ops/non_max_suppression_v3.h
new file mode 100644
index 00000000000..d7347745f3a
--- /dev/null
+++ b/mindspore/core/ops/non_max_suppression_v3.h
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_
+#define MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_
+
+#include
+#include
+#include
+#include
+#include
+#include "ops/op_utils.h"
+#include "ops/primitive_c.h"
+#include "abstract/primitive_infer_map.h"
+#include "abstract/abstract_value.h"
+#include "abstract/dshape.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace ops {
+constexpr auto kNameNonMaxSuppressionV3 = "NonMaxSuppressionV3";
+class NonMaxSuppressionV3 : public PrimitiveC {
+ public:
+  NonMaxSuppressionV3() : PrimitiveC(kNameNonMaxSuppressionV3) {
+    InitIOName({"boxes", "score", "max_output_size", "iou_threshold", "score_threshold"}, {"selected_indices"});
+  }
+  ~NonMaxSuppressionV3() = default;
+  MS_DECLARE_PARENT(NonMaxSuppressionV3, PrimitiveC);
+};
+AbstractBasePtr NonMaxSuppressionV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                         const std::vector<AbstractBasePtr> &input_args);
+using PrimNonMaxSuppressionV3Ptr = std::shared_ptr<NonMaxSuppressionV3>;
+}  // namespace ops
+}  // namespace mindspore
+#endif  // MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
index 3b88631eacd..ec757bdc6ed 100644
--- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
@@ -82,3 +82,4 @@ from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
 from .resize_bilinear import _resize_bilinear_aicpu
 from .resize_bilinear_grad import _resize_bilinear_grad_aicpu
 from .scatter_elements import _scatter_elements_aicpu
+from .non_max_suppression import _non_max_suppression_aicpu
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/non_max_suppression.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/non_max_suppression.py
new file mode 100644
index 00000000000..9a6745124a8
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/non_max_suppression.py
@@ -0,0 +1,36 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""NonMaxSuppressionV3 op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+non_max_suppression_op_info = AiCPURegOp("NonMaxSuppressionV3")\
+    .fusion_type("OPAQUE")\
+    .input(0, "boxes", "required")\
+    .input(1, "scores", "required")\
+    .input(2, "max_output_size", "required")\
+    .input(3, "iou_threshold", "required")\
+    .input(4, "score_threshold", "required")\
+    .output(0, "selected_indices", "required")\
+    .dtype_format(DataType.F32_Default, DataType.F32_Default,
+                  DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default)\
+    .dtype_format(DataType.F16_Default, DataType.F16_Default,
+                  DataType.I32_Default, DataType.F16_Default, DataType.F16_Default, DataType.I32_Default)\
+    .get_op_info()
+
+@op_info_register(non_max_suppression_op_info)
+def _non_max_suppression_aicpu():
+    """NonMaxSuppressionV3 AiCPU register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py
index c246da9df59..24a7af56852 100644
--- a/mindspore/python/mindspore/ops/operations/__init__.py
+++ b/mindspore/python/mindspore/ops/operations/__init__.py
@@ -19,7 +19,7 @@ Primitive operator classes.
 
 A collection of operators to build neural networks or to compute functions.
 """
-from .image_ops import (CropAndResize)
+from .image_ops import (CropAndResize, NonMaxSuppressionV3)
 from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
                         Diag, DiagPart, DType, ExpandDims, Eye,
                         Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
@@ -207,6 +207,7 @@ __all__ = [
     'UniqueWithPad',
     'Concat',
     'Pack',
+    'NonMaxSuppressionV3',
     'Stack',
     'Unpack',
     'Unstack',
diff --git a/mindspore/python/mindspore/ops/operations/image_ops.py b/mindspore/python/mindspore/ops/operations/image_ops.py
index 4ef5c1d9829..ab8816b9b81 100644
--- a/mindspore/python/mindspore/ops/operations/image_ops.py
+++ b/mindspore/python/mindspore/ops/operations/image_ops.py
@@ -18,7 +18,7 @@ from ... import context
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
-from ..primitive import PrimitiveWithInfer, prim_attr_register
+from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
 
 
 class CropAndResize(PrimitiveWithInfer):
@@ -144,3 +144,65 @@ class CropAndResize(PrimitiveWithInfer):
         return {'shape': out_shape,
                 'dtype': mstype.float32,
                 'value': None}
+
+
+class NonMaxSuppressionV3(Primitive):
+    r"""
+    Greedily selects a subset of bounding boxes in descending order of score.
+
+    .. warning::
+        When input "max_output_size" is negative, it will be treated as 0.
+
+    Note:
+        This algorithm is agnostic to where the origin is in the coordinate system.
+        This algorithm is invariant to orthogonal transformations and translations of the coordinate system;
+        thus translating or reflecting the coordinate system results in the same boxes being
+        selected by the algorithm.
+
+    Inputs:
+        - **boxes** (Tensor) - A 2-D Tensor of shape [num_boxes, 4].
+        - **scores** (Tensor) - A 1-D Tensor of shape [num_boxes] representing a single score
+          corresponding to each box (each row of boxes); the num_boxes of "scores" must be equal to
+          the num_boxes of "boxes".
+        - **max_output_size** (Union[Tensor, Number.Int]) - A scalar integer Tensor representing the maximum
+          number of boxes to be selected by non-max suppression.
+        - **iou_threshold** (Union[Tensor, Number.Float]) - A 0-D float Tensor representing the threshold for
+          deciding whether boxes overlap too much with respect to IOU. The value must be greater than or equal
+          to 0 and less than or equal to 1.
+        - **score_threshold** (Union[Tensor, Number.Float]) - A 0-D float Tensor representing the threshold for
+          deciding when to remove boxes based on score.
+
+    Outputs:
+        A 1-D integer Tensor of shape [M] representing the selected indices from the boxes tensor,
+        where M <= max_output_size.
+
+    Raises:
+        TypeError: If the dtypes of `boxes` and `scores` are different.
+        TypeError: If the dtypes of `iou_threshold` and `score_threshold` are different.
+        TypeError: If `boxes` is not a Tensor or its dtype is not float16 or float32.
+        TypeError: If `scores` is not a Tensor or its dtype is not float16 or float32.
+        TypeError: If `max_output_size` is not a Tensor or a scalar, or its dtype is not int32 or int64.
+        TypeError: If `iou_threshold` is not a Tensor or a scalar, or its dtype is not float16 or float32.
+        TypeError: If `score_threshold` is not a Tensor or a scalar, or its dtype is not float16 or float32.
+        ValueError: If the rank of `boxes` is not 2 or the second dimension of its shape is not 4.
+        ValueError: If the rank of `scores` is not 1.
+        ValueError: If any of `max_output_size`, `iou_threshold` and `score_threshold` is not a scalar (rank 0).
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> boxes = Tensor(np.array([[1, 2, 3, 4], [1, 3, 3, 4], [1, 3, 4, 4],
+        ...                          [1, 1, 4, 4], [1, 1, 3, 4]]), mstype.float32)
+        >>> scores = Tensor(np.array([0.4, 0.5, 0.72, 0.9, 0.45]), mstype.float32)
+        >>> max_output_size = Tensor(5, mstype.int32)
+        >>> iou_threshold = Tensor(0.5, mstype.float32)
+        >>> score_threshold = Tensor(0, mstype.float32)
+        >>> nonmaxsuppression = ops.NonMaxSuppressionV3()
+        >>> output = nonmaxsuppression(boxes, scores, max_output_size, iou_threshold, score_threshold)
+        >>> print(output)
+        [3 2 0]
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize NonMaxSuppressionV3"""
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 57023f7605d..f24d3b41af6 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -2660,6 +2660,22 @@ test_case_array_ops = [
         }),
 ]
 
+test_case_image_ops = [
+    ('NonMaxSuppressionV3', {
+        'block': P.NonMaxSuppressionV3(),
+        'desc_inputs': [Tensor(np.array([[20, 5, 200, 100],
+                                         [50, 50, 200, 200],
+                                         [20, 120, 150, 150],
+                                         [250, 250, 400, 350],
+                                         [90, 10, 300, 300],
+                                         [40, 220, 280, 380]]).astype(np.float32)),
+                        Tensor(np.array([0.353, 0.624, 0.667, 0.5, 0.3, 0.46]).astype(np.float32)),
+                        Tensor(4, mstype.int32),
+                        Tensor(0.1, mstype.float32),
+                        Tensor(0, mstype.float32)],
+        'skip': ['backward']}),
+]
+
 test_case_other_ops = [
     ('ScalarLog', {
         'block': F.scalar_log,
@@ -2983,7 +2999,7 @@ test_case_quant_ops = [
 ]
 
 test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops,
-                   test_case_other_ops, test_case_quant_ops]
+                   test_case_other_ops, test_case_quant_ops, test_case_image_ops]
 test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
 # use -k to select certain testcast
 # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm
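
For reference only, and not part of the patch above: a minimal NumPy sketch of the greedy selection that the NonMaxSuppressionV3 docstring describes. Boxes are visited in descending score order, and a box is kept only while fewer than max_output_size boxes have been selected and its IoU with every already-kept box does not exceed iou_threshold. The function name, the [y1, x1, y2, x2] corner convention, and the strictly-greater score_threshold comparison are illustrative assumptions; this is not the AICPU kernel's implementation.

import numpy as np


def nms_v3_reference(boxes, scores, max_output_size, iou_threshold, score_threshold):
    """Illustrative greedy NMS mirroring the op's documented contract (assumed, not the kernel)."""
    max_output_size = max(int(max_output_size), 0)   # a negative max_output_size behaves like 0
    order = np.argsort(-scores, kind="stable")       # visit boxes in descending score order
    order = order[scores[order] > score_threshold]   # assumed: drop boxes scoring at or below the threshold
    keep = []
    for i in order:
        if len(keep) == max_output_size:
            break
        suppressed = False
        for j in keep:
            # intersection-over-union between candidate i and already-selected box j,
            # assuming [y1, x1, y2, x2] corners with y2 >= y1 and x2 >= x1
            inter_h = min(boxes[i, 2], boxes[j, 2]) - max(boxes[i, 0], boxes[j, 0])
            inter_w = min(boxes[i, 3], boxes[j, 3]) - max(boxes[i, 1], boxes[j, 1])
            inter = max(inter_h, 0.0) * max(inter_w, 0.0)
            area_i = (boxes[i, 2] - boxes[i, 0]) * (boxes[i, 3] - boxes[i, 1])
            area_j = (boxes[j, 2] - boxes[j, 0]) * (boxes[j, 3] - boxes[j, 1])
            iou = inter / (area_i + area_j - inter) if inter > 0 else 0.0
            if iou > iou_threshold:
                suppressed = True
                break
        if not suppressed:
            keep.append(int(i))
    # analogous to selected_indices: at most max_output_size indices into boxes/scores
    return np.array(keep, dtype=np.int32)

On the docstring example above (five overlapping boxes, iou_threshold 0.5, score_threshold 0), this sketch selects [3, 2, 0], matching the expected output of the op.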