forked from mindspore-Ecosystem/mindspore
!43087 [assistant][ops][aicpu][I4XJI9]Add Sort operator
Merge pull request !43087 from 李定维/Sort
This commit is contained in:
commit
8edb949d35
|
@ -128,7 +128,17 @@ std::vector<std::pair<KernelAttr, SortCpuKernelMod::SortFunc>> SortCpuKernelMod:
|
|||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
|
||||
&SortCpuKernelMod::LaunchKernel<float16>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
|
||||
&SortCpuKernelMod::LaunchKernel<float>}};
|
||||
&SortCpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32),
|
||||
&SortCpuKernelMod::LaunchKernel<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32),
|
||||
&SortCpuKernelMod::LaunchKernel<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32),
|
||||
&SortCpuKernelMod::LaunchKernel<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&SortCpuKernelMod::LaunchKernel<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
|
||||
&SortCpuKernelMod::LaunchKernel<int64_t>}};
|
||||
|
||||
std::vector<KernelAttr> SortCpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
|
|
|
@ -26,13 +26,31 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
void Sort::Init(int64_t axis, bool descending) {
|
||||
this->set_axis(axis);
|
||||
this->set_descending(descending);
|
||||
}
|
||||
|
||||
void Sort::set_axis(int64_t axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
|
||||
|
||||
int64_t Sort::get_axis() const {
|
||||
auto axis = this->GetAttr(kAxis);
|
||||
MS_EXCEPTION_IF_NULL(axis);
|
||||
return GetValue<int64_t>(axis);
|
||||
}
|
||||
|
||||
void Sort::set_descending(bool descending) { (void)this->AddAttr(kDescending, api::MakeValue(descending)); }
|
||||
|
||||
bool Sort::get_descending() const {
|
||||
auto descending = this->GetAttr(kDescending);
|
||||
MS_EXCEPTION_IF_NULL(descending);
|
||||
return GetValue<bool>(descending);
|
||||
}
|
||||
|
||||
namespace {
|
||||
abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 1;
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num,
|
||||
prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto x_rank = SizeToLong(x_shape.size());
|
||||
auto axis = GetValue<int64_t>(primitive->GetAttr("axis"));
|
||||
|
@ -45,11 +63,11 @@ abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std:
|
|||
return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{shape_element, shape_element});
|
||||
}
|
||||
|
||||
TuplePtr SortInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
TuplePtr SortInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto infer_type = input_args[0]->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(infer_type);
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputx", infer_type, valid_types, prim->name());
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kUInt8, kInt8, kInt16, kInt32, kInt64};
|
||||
auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputx", infer_type, valid_types, primitive->name());
|
||||
std::vector<TypePtr> type_tuple;
|
||||
type_tuple.push_back(type);
|
||||
type_tuple.push_back(kInt32);
|
||||
|
@ -57,24 +75,17 @@ TuplePtr SortInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
|
|||
}
|
||||
} // namespace
|
||||
|
||||
MIND_API_OPERATOR_IMPL(Sort, BaseOperator);
|
||||
AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto infertype = SortInferType(primitive, input_args);
|
||||
auto infershape = SortInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infershape, infertype);
|
||||
}
|
||||
|
||||
int64_t Sort::get_axis() const {
|
||||
auto value_ptr = this->GetAttr(kAxis);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
|
||||
bool Sort::get_descending() const {
|
||||
auto value_ptr = this->GetAttr(kDescending);
|
||||
return GetValue<bool>(value_ptr);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
const int64_t kSortInputsNum = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, kSortInputsNum, primitive->name());
|
||||
auto infer_type = SortInferType(primitive, input_args);
|
||||
auto infer_shape = SortInferShape(primitive, input_args);
|
||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(Sort, BaseOperator);
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Sort, prim::kPrimSort, SortInfer, nullptr, true);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,11 +16,11 @@
|
|||
|
||||
#ifndef MINDSPORE_CORE_OPS_SORT_H_
|
||||
#define MINDSPORE_CORE_OPS_SORT_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "mindapi/base/types.h"
|
||||
|
||||
|
@ -31,8 +31,10 @@ class MIND_API Sort : public BaseOperator {
|
|||
public:
|
||||
MIND_API_BASE_MEMBER(Sort);
|
||||
Sort() : BaseOperator(kNameSort) { InitIOName({"x"}, {"y1", "y2"}); }
|
||||
|
||||
void Init(int64_t axis = -1, bool descending = false);
|
||||
void set_axis(int64_t axis);
|
||||
int64_t get_axis() const;
|
||||
void set_descending(bool descending);
|
||||
bool get_descending() const;
|
||||
};
|
||||
abstract::AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Sort op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||
|
||||
sort_op_info = AiCPURegOp("Sort") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.attr("axis", "int", "-1") \
|
||||
.attr("descending", "bool", "False") \
|
||||
.input(0, "x", "required") \
|
||||
.output(0, "y1", "required") \
|
||||
.output(0, "y2", "required") \
|
||||
.dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(sort_op_info)
|
||||
def _sort_aicpu():
|
||||
"""Sort AiCPU register"""
|
||||
return
|
|
@ -2611,6 +2611,52 @@ def scatter_nd_min(input_x, indices, updates, use_locking=False):
|
|||
return scatter_nd_min_inner(input_x, indices, updates)
|
||||
|
||||
|
||||
def sort(input_x, axis=-1, descending=False):
|
||||
r"""
|
||||
Sorts the elements of the input tensor along the given dimension in the specified order.
|
||||
|
||||
Args:
|
||||
input_x(Tensor): The input tensor to sort.
|
||||
The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
|
||||
axis (int): The dimension to sort along. Default: -1.
|
||||
descending (bool): Controls the sort order. If descending is True then the elements
|
||||
are sorted in descending order by value. Default: False.
|
||||
|
||||
.. warning::
|
||||
Currently, the data types of Float16, UInt8, Int8, Int16, Int32, Int64 are supported.
|
||||
If use Float32, it may cause loss of accuracy.
|
||||
|
||||
Returns:
|
||||
y1(Tensor) - A tensor whose values are the sorted values, with the same shape and data type as input.
|
||||
y2(Tensor) - The indices of the elements in the original input tensor. Data type is int32.
|
||||
|
||||
Raises:
|
||||
TypeError: If `axis` is not an int.
|
||||
TypeError: If `descending` is not a bool.
|
||||
TypeError: If dtype of `x` is neither float16, float32, uint8, int8, int16, int32, int64.
|
||||
ValueError: If `axis` is not in range of [-len(x.shape), len(x.shape)).
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
|
||||
>>> sort = ops.Sort()
|
||||
>>> output = sort(x)
|
||||
>>> # The output below is based on the Ascend platform.
|
||||
>>> print(output)
|
||||
(Tensor(shape=[3, 3], dtype=Float16, value=
|
||||
[[ 1.0000e+00, 2.0000e+00, 8.0000e+00],
|
||||
[ 3.0000e+00, 5.0000e+00, 9.0000e+00],
|
||||
[ 4.0000e+00, 6.0000e+00, 7.0000e+00]]), Tensor(shape=[3, 3], dtype=Int32, value=
|
||||
[[2, 1, 0],
|
||||
[2, 0, 1],
|
||||
[0, 1, 2]]))
|
||||
"""
|
||||
_sort = _get_cache_prim(P.Sort)(axis, descending)
|
||||
return _sort(input_x)
|
||||
|
||||
|
||||
def gather(input_params, input_indices, axis):
|
||||
r"""
|
||||
Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
|
||||
|
|
|
@ -5712,30 +5712,9 @@ class TransShape(PrimitiveWithInfer):
|
|||
|
||||
class Sort(Primitive):
|
||||
"""
|
||||
Sorts the elements of the input tensor along a given dimension in ascending order by value.
|
||||
Sorts the elements of the input tensor along the given dimension in the specified order.
|
||||
|
||||
Args:
|
||||
axis (int): The dimension to sort along. Default: -1.
|
||||
descending (bool): Controls the sorting order. If descending is True then the elements
|
||||
are sorted in descending order by value. Default: False.
|
||||
|
||||
.. warning::
|
||||
Currently, only the data type of Float16 is supported. If use Float32, it may cause loss
|
||||
of accuracy.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - The input to sort, with float16 or float32 data type.
|
||||
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
|
||||
|
||||
Outputs:
|
||||
- **y1** (Tensor) - A tensor whose values are the sorted values, with the same shape and data type as input.
|
||||
- **y2** (Tensor) - The indices of the elements in the original input tensor. Data type is int32.
|
||||
|
||||
Raises:
|
||||
TypeError: If `axis` is not an int.
|
||||
TypeError: If `descending` is not a bool.
|
||||
TypeError: If dtype of `x` is neither float16 nor float32.
|
||||
ValueError: If `axis` is not in range of [-len(x.shape), len(x.shape)).
|
||||
Refer to :func:'mindspore.ops.sort' for more details.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
|
Loading…
Reference in New Issue