!43087 [assistant][ops][aicpu][I4XJI9]Add Sort operator

Merge pull request !43087 from 李定维/Sort
This commit is contained in:
i-robot 2022-11-17 02:45:06 +00:00 committed by Gitee
commit 8edb949d35
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 134 additions and 47 deletions

View File

@ -128,7 +128,17 @@ std::vector<std::pair<KernelAttr, SortCpuKernelMod::SortFunc>> SortCpuKernelMod:
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
&SortCpuKernelMod::LaunchKernel<float16>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
&SortCpuKernelMod::LaunchKernel<float>}};
&SortCpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
&SortCpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
&SortCpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
&SortCpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&SortCpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
&SortCpuKernelMod::LaunchKernel<int64_t>}};
std::vector<KernelAttr> SortCpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;

View File

@ -26,13 +26,31 @@
namespace mindspore {
namespace ops {
void Sort::Init(int64_t axis, bool descending) {
this->set_axis(axis);
this->set_descending(descending);
}
void Sort::set_axis(int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
// Return the `axis` attribute; raises if the attribute was never set.
int64_t Sort::get_axis() const {
  const auto axis_value = this->GetAttr(kAxis);
  MS_EXCEPTION_IF_NULL(axis_value);
  return GetValue<int64_t>(axis_value);
}
void Sort::set_descending(bool descending) { (void)this->AddAttr(kDescending, api::MakeValue(descending)); }
bool Sort::get_descending() const {
auto descending = this->GetAttr(kDescending);
MS_EXCEPTION_IF_NULL(descending);
return GetValue<bool>(descending);
}
namespace {
abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 1;
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto x_rank = SizeToLong(x_shape.size());
auto axis = GetValue<int64_t>(primitive->GetAttr("axis"));
@ -45,11 +63,11 @@ abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std:
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{shape_element, shape_element});
}
TuplePtr SortInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
TuplePtr SortInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto infer_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(infer_type);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputx", infer_type, valid_types, prim->name());
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kUInt8, kInt8, kInt16, kInt32, kInt64};
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputx", infer_type, valid_types, primitive->name());
std::vector<TypePtr> type_tuple;
type_tuple.push_back(type);
type_tuple.push_back(kInt32);
@ -57,24 +75,17 @@ TuplePtr SortInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
}
} // namespace
MIND_API_OPERATOR_IMPL(Sort, BaseOperator);
AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infertype = SortInferType(primitive, input_args);
auto infershape = SortInferShape(primitive, input_args);
return abstract::MakeAbstract(infershape, infertype);
}
int64_t Sort::get_axis() const {
auto value_ptr = this->GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
bool Sort::get_descending() const {
auto value_ptr = this->GetAttr(kDescending);
return GetValue<bool>(value_ptr);
MS_EXCEPTION_IF_NULL(primitive);
const int64_t kSortInputsNum = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSortInputsNum, primitive->name());
auto infer_type = SortInferType(primitive, input_args);
auto infer_shape = SortInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(Sort, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(Sort, prim::kPrimSort, SortInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -16,11 +16,11 @@
#ifndef MINDSPORE_CORE_OPS_SORT_H_
#define MINDSPORE_CORE_OPS_SORT_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
@ -31,8 +31,10 @@ class MIND_API Sort : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Sort);
Sort() : BaseOperator(kNameSort) { InitIOName({"x"}, {"y1", "y2"}); }
void Init(int64_t axis = -1, bool descending = false);
void set_axis(int64_t axis);
int64_t get_axis() const;
void set_descending(bool descending);
bool get_descending() const;
};
abstract::AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -0,0 +1,39 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sort op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
# AiCPU registration for Sort.
# Inputs:  x (the tensor to sort).
# Outputs: y1 (sorted values, same dtype as x), y2 (original indices, int32).
# NOTE(review): the second output must be registered at index 1, not 0 —
# otherwise both y1 and y2 claim output slot 0 and the indices tensor is
# never bound to the kernel's second output.
sort_op_info = AiCPURegOp("Sort") \
    .fusion_type("OPAQUE") \
    .attr("axis", "int", "-1") \
    .attr("descending", "bool", "False") \
    .input(0, "x", "required") \
    .output(0, "y1", "required") \
    .output(1, "y2", "required") \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default) \
    .get_op_info()


@op_info_register(sort_op_info)
def _sort_aicpu():
    """Sort AiCPU register"""
    return

View File

@ -2611,6 +2611,52 @@ def scatter_nd_min(input_x, indices, updates, use_locking=False):
return scatter_nd_min_inner(input_x, indices, updates)
def sort(input_x, axis=-1, descending=False):
    r"""
    Sorts the elements of the input tensor along the given dimension in the specified order.

    Args:
        input_x (Tensor): The input tensor to sort.
            The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
        axis (int): The dimension to sort along. Default: -1.
        descending (bool): Controls the sort order. If descending is True then the elements
            are sorted in descending order by value. Default: False.

    .. warning::
        Currently, the data types of float16, uint8, int8, int16, int32 and int64 are supported.
        Using float32 may cause loss of accuracy.

    Returns:
        - **y1** (Tensor) - A tensor whose values are the sorted values, with the same shape and
          data type as input.
        - **y2** (Tensor) - The indices of the elements in the original input tensor. Data type is int32.

    Raises:
        TypeError: If `axis` is not an int.
        TypeError: If `descending` is not a bool.
        TypeError: If dtype of `input_x` is none of: float16, float32, uint8, int8, int16, int32, int64.
        ValueError: If `axis` is not in range of [-len(input_x.shape), len(input_x.shape)).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
        >>> output = ops.sort(x)
        >>> # The output below is based on the Ascend platform.
        >>> print(output)
        (Tensor(shape=[3, 3], dtype=Float16, value=
        [[ 1.0000e+00, 2.0000e+00, 8.0000e+00],
        [ 3.0000e+00, 5.0000e+00, 9.0000e+00],
        [ 4.0000e+00, 6.0000e+00, 7.0000e+00]]), Tensor(shape=[3, 3], dtype=Int32, value=
        [[2, 1, 0],
        [2, 0, 1],
        [0, 1, 2]]))
    """
    # Cache the Sort primitive per (axis, descending) so repeated functional
    # calls do not re-create it.
    _sort = _get_cache_prim(P.Sort)(axis, descending)
    return _sort(input_x)
def gather(input_params, input_indices, axis):
r"""
Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.

View File

@ -5712,30 +5712,9 @@ class TransShape(PrimitiveWithInfer):
class Sort(Primitive):
"""
Sorts the elements of the input tensor along a given dimension in ascending order by value.
Sorts the elements of the input tensor along the given dimension in the specified order.
Args:
axis (int): The dimension to sort along. Default: -1.
descending (bool): Controls the sorting order. If descending is True then the elements
are sorted in descending order by value. Default: False.
.. warning::
Currently, only the data type of Float16 is supported. If use Float32, it may cause loss
of accuracy.
Inputs:
- **x** (Tensor) - The input to sort, with float16 or float32 data type.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
Outputs:
- **y1** (Tensor) - A tensor whose values are the sorted values, with the same shape and data type as input.
- **y2** (Tensor) - The indices of the elements in the original input tensor. Data type is int32.
Raises:
TypeError: If `axis` is not an int.
TypeError: If `descending` is not a bool.
TypeError: If dtype of `x` is neither float16 nor float32.
ValueError: If `axis` is not in range of [-len(x.shape), len(x.shape)).
Refer to :func:'mindspore.ops.sort' for more details.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``