[feat][assistant][I40FG0] add new Ascend operator NonMaxSuppression

This commit is contained in:
韩峥嵘 2021-12-07 13:57:25 +08:00
parent 9558ba49d8
commit efe0fce473
10 changed files with 293 additions and 8 deletions

View File

@ -63,6 +63,7 @@ constexpr auto kPadAndShift = "PadAndShift";
constexpr auto kCpuRunApi = "RunCpuKernel"; constexpr auto kCpuRunApi = "RunCpuKernel";
constexpr auto kDropout2D = "Dropout2D"; constexpr auto kDropout2D = "Dropout2D";
constexpr auto kDropout3D = "Dropout3D"; constexpr auto kDropout3D = "Dropout3D";
constexpr auto kNonMaxSuppressionV3 = "NonMaxSuppressionV3";
constexpr auto kMaskedSelect = "MaskedSelect"; constexpr auto kMaskedSelect = "MaskedSelect";
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad"; constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
constexpr auto kDynamicStitch = "DynamicStitch"; constexpr auto kDynamicStitch = "DynamicStitch";
@ -72,8 +73,8 @@ constexpr auto kResizeBilinearGrad = "ResizeBilinearGrad";
constexpr auto kScatterElements = "ScatterElements"; constexpr auto kScatterElements = "ScatterElements";
const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch, const std::set<std::string> kCpuKernelOps{kIdentity, kMaskedSelect, kMaskedSelectGrad, kDynamicStitch,
kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements}; kSearchSorted, kResizeBilinear, kResizeBilinearGrad, kScatterElements};
const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, const std::set<std::string> kCacheKernelOps{kUpdateCache, kCacheSwapTable, kSubAndFilter, kPadAndShift,
kPadAndShift, kDropout3D, kDropout2D}; kDropout3D, kDropout2D, kNonMaxSuppressionV3};
const std::set<std::string> kCpuKernelBaseOps{kGetNext, kInitData, kRandomChoiceWithMask}; const std::set<std::string> kCpuKernelBaseOps{kGetNext, kInitData, kRandomChoiceWithMask};
const std::set<std::string> kDynamicInputOps{ const std::set<std::string> kDynamicInputOps{
kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName, kStackPushOpName, kStackPopOpName, kDynamicStitch}; kPrint, kPack, kMeshgrid, kStackInitOpName, kStackDestroyOpName, kStackPushOpName, kStackPopOpName, kDynamicStitch};

View File

@ -81,6 +81,7 @@ constexpr auto kBNTrainingReduceOpName = "BNTrainingReduce";
constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate"; constexpr auto kBNTrainingUpdateOpName = "BNTrainingUpdate";
constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2"; constexpr auto kBNTrainingUpdateV2OpName = "BNTrainingUpdateV2";
constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3"; constexpr auto kBNTrainingUpdateV3OpName = "BNTrainingUpdateV3";
constexpr auto kNonMaxSuppressionV3OpName = "NonMaxSuppressionV3";
constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad"; constexpr auto kSimpleMeanGradOpName = "SimpleMeanGrad";
constexpr auto kMeanGradOpName = "MeanGrad"; constexpr auto kMeanGradOpName = "MeanGrad";
constexpr auto kSliceOpName = "Slice"; constexpr auto kSliceOpName = "Slice";
@ -716,9 +717,10 @@ const std::set<std::string> kHWSpecialFormatSet = {
const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32}; const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, const std::set<std::string> kComputeDepend = {
kPadAndShiftOpName, kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName, kPadAndShiftOpName,
kMaskedSelectOpName, kDynamicStitchOpName, kGetNextOpName}; kCTCGreedyDecoderOpName, kDropoutGenMaskOpName, kMaskedSelectOpName, kDynamicStitchOpName,
kGetNextOpName, kNonMaxSuppressionV3OpName};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D, const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC}; kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

View File

@ -590,6 +590,9 @@ inline const PrimitivePtr kPrimComplex = std::make_shared<Primitive>("Complex");
inline const PrimitivePtr kPrimXdivy = std::make_shared<Primitive>("Xdivy"); inline const PrimitivePtr kPrimXdivy = std::make_shared<Primitive>("Xdivy");
inline const PrimitivePtr kPrimInv = std::make_shared<Primitive>("Inv"); inline const PrimitivePtr kPrimInv = std::make_shared<Primitive>("Inv");
// Image
inline const PrimitivePtr kPrimNonMaxSuppressionV3 = std::make_shared<Primitive>("NonMaxSuppressionV3");
// Statements // Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return"); inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("Return");
inline const PrimitivePtr kPrimUnroll = std::make_shared<Primitive>("Unroll"); inline const PrimitivePtr kPrimUnroll = std::make_shared<Primitive>("Unroll");

View File

@ -0,0 +1,115 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include "ops/non_max_suppression_v3.h"
namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of NonMaxSuppressionV3.
//
// Expects exactly five non-null inputs, in this order: boxes, scores, max_output_size, iou_threshold,
// score_threshold. boxes must be rank 2 with a second dimension of 4, scores must be rank 1 with the same length
// as the first dimension of boxes, and the remaining three inputs must be rank 0. Violations are reported through
// CheckAndConvertUtils::CheckInteger.
//
// Returns a rank-1 shape whose only dimension is abstract::Shape::SHP_ANY, with min shape {0} and a max shape
// taken from scores (its max shape when its first dimension is -1, otherwise its shape).
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  // Validate the pointer before it is dereferenced to read the operator name.
  MS_EXCEPTION_IF_NULL(primitive);
  const auto prim_name = primitive->name();
  const int input_num = 5;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);

  const ShapeVector boxes_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
  const ShapeVector scores_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->GetShapeTrack())[kShape];
  const ShapeVector max_output_size_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->GetShapeTrack())[kShape];
  const ShapeVector iou_threshold_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[3]->GetShapeTrack())[kShape];
  const ShapeVector score_threshold_shape =
    CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[4]->GetShapeTrack())[kShape];

  const int64_t boxes_rank = 2;
  const int64_t boxes_second_dim = 4;
  const int64_t scores_rank = 1;
  const int64_t scalar_rank = 0;
  // boxes must be rank 2; this is checked first so that indexing dimension 1 below is always in bounds
  (void)CheckAndConvertUtils::CheckInteger("boxes rank", SizeToLong(boxes_shape.size()), kEqual, boxes_rank, prim_name);
  // boxes second dimension must equal 4
  (void)CheckAndConvertUtils::CheckInteger("boxes second dimension", boxes_shape[1], kEqual, boxes_second_dim,
                                           prim_name);
  // scores must be rank 1
  (void)CheckAndConvertUtils::CheckInteger("scores rank", SizeToLong(scores_shape.size()), kEqual, scores_rank,
                                           prim_name);
  // scores length must be equal to the first dimension of boxes
  (void)CheckAndConvertUtils::CheckInteger("scores length", scores_shape[0], kEqual, boxes_shape[0], prim_name);
  // max_output_size, iou_threshold and score_threshold must be rank 0
  (void)CheckAndConvertUtils::CheckInteger("max_output_size size", SizeToLong(max_output_size_shape.size()), kEqual,
                                           scalar_rank, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("iou_threshold size", SizeToLong(iou_threshold_shape.size()), kEqual,
                                           scalar_rank, prim_name);
  (void)CheckAndConvertUtils::CheckInteger("score_threshold size", SizeToLong(score_threshold_shape.size()), kEqual,
                                           scalar_rank, prim_name);

  // calculate output shape: the number of selected indices is unknown here, bounded above by the scores length
  auto scores_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape());
  const ShapeVector selected_indices_shape = {abstract::Shape::SHP_ANY};
  const ShapeVector selected_indices_min_shape = {0};
  const bool scores_is_dynamic = !scores_shape_map[kShape].empty() && scores_shape_map[kShape][0] == -1;
  const ShapeVector selected_indices_max_shape =
    scores_is_dynamic ? scores_shape_map[kMaxShape] : scores_shape_map[kShape];
  return std::make_shared<abstract::Shape>(selected_indices_shape, selected_indices_min_shape,
                                           selected_indices_max_shape);
}
// Infers the output type of NonMaxSuppressionV3.
//
// Expects exactly five non-null inputs, in this order: boxes, scores, max_output_size, iou_threshold,
// score_threshold. boxes and scores are checked as tensors of the same float16/float32 type, iou_threshold and
// score_threshold as scalars or tensors of the same float16/float32 type, and max_output_size as a scalar or
// tensor of int32/int64 type.
//
// Returns the type built from the max_output_size input.
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  // Validate the pointer before it is dereferenced to read the operator name.
  MS_EXCEPTION_IF_NULL(prim);
  const auto prim_name = prim->name();
  const int input_num = 5;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto boxes_type = input_args[0]->BuildType();
  auto scores_type = input_args[1]->BuildType();
  auto max_output_size_type = input_args[2]->BuildType();
  auto iou_threshold_type = input_args[3]->BuildType();
  auto score_threshold_type = input_args[4]->BuildType();
  // boxes and scores must be tensors of the same float type
  const std::set<TypePtr> valid_float_types = {kFloat16, kFloat32};
  std::map<std::string, TypePtr> box_score_args;
  box_score_args.insert({"boxes_type", boxes_type});
  box_score_args.insert({"scores_type", scores_type});
  (void)CheckAndConvertUtils::CheckTensorTypeSame(box_score_args, valid_float_types, prim_name);
  // iou_threshold and score_threshold must share the same float type (scalar or tensor)
  std::map<std::string, TypePtr> threshold_args;
  threshold_args.insert({"iou_threshold_type", iou_threshold_type});
  threshold_args.insert({"score_threshold_type", score_threshold_type});
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(threshold_args, valid_float_types, prim_name);
  // max_output_size must be an int32 or int64 scalar or tensor
  const std::set<TypePtr> valid_int_types = {kInt32, kInt64};
  std::map<std::string, TypePtr> max_output_size_args;
  max_output_size_args.insert({"max_output_size_type", max_output_size_type});
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(max_output_size_args, valid_int_types, prim_name);
  // NOTE(review): the output type follows max_output_size, so it can be int64 (or a scalar type); the AiCPU
  // registration for this op only lists int32 outputs -- confirm this is intended.
  return max_output_size_type;
}
} // namespace
// Infer entry point for NonMaxSuppressionV3: builds the output abstract from the inferred shape and type.
// The analysis-engine argument is unused. Raises through MS_EXCEPTION_IF_NULL when primitive is null.
// Note: InferShape and InferType are both evaluated as call arguments, so C++ does not specify which runs first.
AbstractBasePtr NonMaxSuppressionV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
// Registers NonMaxSuppressionV3Infer as the eval implementation for prim::kPrimNonMaxSuppressionV3.
// NOTE(review): the trailing `nullptr, true` arguments are presumably "no value-infer function" and
// "usable in the white list" -- confirm against the macro definition.
REGISTER_PRIMITIVE_EVAL_IMPL(NonMaxSuppressionV3, prim::kPrimNonMaxSuppressionV3, NonMaxSuppressionV3Infer, nullptr,
true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_
#define MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "abstract/primitive_infer_map.h"
#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameNonMaxSuppressionV3 = "NonMaxSuppressionV3";
// Primitive definition of the NonMaxSuppressionV3 operator.
// Declares five inputs (boxes, scores, max_output_size, iou_threshold, score_threshold) and a single output
// (selected_indices).
class NonMaxSuppressionV3 : public PrimitiveC {
 public:
  NonMaxSuppressionV3() : PrimitiveC(kNameNonMaxSuppressionV3) {
    // The second input is named "scores" to match the AiCPU op registration and the Python API documentation.
    InitIOName({"boxes", "scores", "max_output_size", "iou_threshold", "score_threshold"}, {"selected_indices"});
  }
  ~NonMaxSuppressionV3() = default;
  MS_DECLARE_PARENT(NonMaxSuppressionV3, PrimitiveC);
};
// Infers the output abstract of NonMaxSuppressionV3 from the primitive and its input abstracts.
// The first (analysis engine) parameter is unnamed in this declaration.
AbstractBasePtr NonMaxSuppressionV3Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
// Shared-pointer alias for the NonMaxSuppressionV3 primitive class.
using PrimNonMaxSuppressionV3Ptr = std::shared_ptr<NonMaxSuppressionV3>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_NON_MAX_SUPPRESSION_V3_H_

View File

@ -82,3 +82,4 @@ from .ctc_greedy_decoder import _ctc_greedy_decoder_aicpu
from .resize_bilinear import _resize_bilinear_aicpu from .resize_bilinear import _resize_bilinear_aicpu
from .resize_bilinear_grad import _resize_bilinear_grad_aicpu from .resize_bilinear_grad import _resize_bilinear_grad_aicpu
from .scatter_elements import _scatter_elements_aicpu from .scatter_elements import _scatter_elements_aicpu
from .non_max_suppression import _non_max_suppression_aicpu

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""NonMaxSuppressionV3 op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration info for NonMaxSuppressionV3: five required inputs and one required output.
# Two dtype combinations are registered: float32 or float16 for boxes, scores and both thresholds;
# max_output_size and selected_indices are int32 in both combinations.
non_max_suppression_op_info = AiCPURegOp("NonMaxSuppressionV3")\
.fusion_type("OPAQUE")\
.input(0, "boxes", "required")\
.input(1, "scores", "required")\
.input(2, "max_output_size", "required")\
.input(3, "iou_threshold", "required")\
.input(4, "score_threshold", "required")\
.output(0, "selected_indices", "required")\
.dtype_format(DataType.F32_Default, DataType.F32_Default, \
DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.I32_Default)\
.dtype_format(DataType.F16_Default, DataType.F16_Default, \
DataType.I32_Default, DataType.F16_Default, DataType.F16_Default, DataType.I32_Default)\
.get_op_info()
@op_info_register(non_max_suppression_op_info)
def _non_max_suppression_aicpu():
    """NonMaxSuppressionV3 AiCPU register"""
    return

View File

@ -19,7 +19,7 @@ Primitive operator classes.
A collection of operators to build neural networks or to compute functions. A collection of operators to build neural networks or to compute functions.
""" """
from .image_ops import (CropAndResize) from .image_ops import (CropAndResize, NonMaxSuppressionV3)
from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack, from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unstack,
Diag, DiagPart, DType, ExpandDims, Eye, Diag, DiagPart, DType, ExpandDims, Eye,
Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation, Fill, Ones, Zeros, GatherNd, GatherV2, Gather, SparseGatherV2, InvertPermutation,
@ -207,6 +207,7 @@ __all__ = [
'UniqueWithPad', 'UniqueWithPad',
'Concat', 'Concat',
'Pack', 'Pack',
'NonMaxSuppressionV3',
'Stack', 'Stack',
'Unpack', 'Unpack',
'Unstack', 'Unstack',

View File

@ -18,7 +18,7 @@ from ... import context
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register from ..primitive import PrimitiveWithInfer, prim_attr_register, Primitive
class CropAndResize(PrimitiveWithInfer): class CropAndResize(PrimitiveWithInfer):
@ -144,3 +144,65 @@ class CropAndResize(PrimitiveWithInfer):
return {'shape': out_shape, return {'shape': out_shape,
'dtype': mstype.float32, 'dtype': mstype.float32,
'value': None} 'value': None}
class NonMaxSuppressionV3(Primitive):
    r"""
    Greedily selects a subset of bounding boxes in descending order of score.

    .. warning::
        When input "max_output_size" is negative, it will be treated as 0.

    Note:
        This algorithm is agnostic to where the origin is in the coordinate system.
        This algorithm is invariant to orthogonal transformations and translations of the coordinate system;
        thus translations or reflections of the coordinate system result in the same boxes being
        selected by the algorithm.

    Inputs:
        - **boxes** (Tensor) - A 2-D Tensor of shape [num_boxes, 4].
        - **scores** (Tensor) - A 1-D Tensor of shape [num_boxes] representing a single score
          corresponding to each box (each row of boxes), the num_boxes of "scores" must be equal to
          the num_boxes of "boxes".
        - **max_output_size** (Union[Tensor, Number.Int]) - A scalar integer Tensor representing the maximum
          number of boxes to be selected by non max suppression.
        - **iou_threshold** (Union[Tensor, Number.Float]) - A 0-D float tensor representing the threshold for
          deciding whether boxes overlap too much with respect to IOU, and iou_threshold must be equal or greater
          than 0 and be equal or smaller than 1.
        - **score_threshold** (Union[Tensor, Number.Float]) - A 0-D float tensor representing the threshold for
          deciding when to remove boxes based on score.

    Outputs:
        A 1-D integer Tensor of shape [M] representing the selected indices from the boxes tensor,
        where M <= max_output_size.

    Raises:
        TypeError: If the dtype of `boxes` and `scores` is different.
        TypeError: If the dtype of `iou_threshold` and `score_threshold` is different.
        TypeError: If `boxes` is not tensor or its dtype is not float16 or float32.
        TypeError: If `scores` is not tensor or its dtype is not float16 or float32.
        TypeError: If `max_output_size` is not tensor or scalar. If `max_output_size` is not int32 or int64.
        TypeError: If `iou_threshold` is not tensor or scalar. If its type is not float16 or float32.
        TypeError: If `score_threshold` is not tensor or scalar. If its type is not float16 or float32.
        ValueError: If the size of shape of `boxes` is not 2 or the second value of its shape is not 4.
        ValueError: If the size of shape of `scores` is not 1.
        ValueError: If each of the size of shape of `max_output_size`, `iou_threshold`, `score_threshold` is not 0.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> boxes = Tensor(np.array([[1, 2, 3, 4], [1, 3, 3, 4], [1, 3, 4, 4],
        ...                          [1, 1, 4, 4], [1, 1, 3, 4]]), mstype.float32)
        >>> scores = Tensor(np.array([0.4, 0.5, 0.72, 0.9, 0.45]), mstype.float32)
        >>> max_output_size = Tensor(5, mstype.int32)
        >>> iou_threshold = Tensor(0.5, mstype.float32)
        >>> score_threshold = Tensor(0, mstype.float32)
        >>> nonmaxsuppression = ops.NonMaxSuppressionV3()
        >>> output = nonmaxsuppression(boxes, scores, max_output_size, iou_threshold, score_threshold)
        >>> print(output)
        [3 2 0]
    """

    @prim_attr_register
    def __init__(self):
        """Initialize NonMaxSuppressionV3"""

View File

@ -2660,6 +2660,22 @@ test_case_array_ops = [
}), }),
] ]
# Image operator test cases. NonMaxSuppressionV3 is exercised with six float32 boxes and scores,
# an int32 max_output_size of 4, and float32 thresholds of 0.1 and 0; the backward pass is skipped.
test_case_image_ops = [
('NonMaxSuppressionV3', {
'block': P.NonMaxSuppressionV3(),
'desc_inputs': [Tensor(np.array([[20, 5, 200, 100],
[50, 50, 200, 200],
[20, 120, 150, 150],
[250, 250, 400, 350],
[90, 10, 300, 300],
[40, 220, 280, 380]]).astype(np.float32)),
Tensor(np.array([0.353, 0.624, 0.667, 0.5, 0.3, 0.46]).astype(np.float32)),
Tensor(4, mstype.int32),
Tensor(0.1, mstype.float32),
Tensor(0, mstype.float32)],
'skip': ['backward']}),
]
test_case_other_ops = [ test_case_other_ops = [
('ScalarLog', { ('ScalarLog', {
'block': F.scalar_log, 'block': F.scalar_log,
@ -2983,7 +2999,7 @@ test_case_quant_ops = [
] ]
test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops, test_case_lists = [test_case_nn_ops, test_case_math_ops, test_case_array_ops,
test_case_other_ops, test_case_quant_ops] test_case_other_ops, test_case_quant_ops, test_case_image_ops]
test_case = functools.reduce(lambda x, y: x + y, test_case_lists) test_case = functools.reduce(lambda x, y: x + y, test_case_lists)
# use -k to select certain testcast # use -k to select certain testcast
# pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm # pytest tests/python/ops/test_ops.py::test_backward -k LayerNorm