diff --git a/mindspore/core/ops/sort.cc b/mindspore/core/ops/sort.cc new file mode 100644 index 00000000000..a1a609ae0e4 --- /dev/null +++ b/mindspore/core/ops/sort.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ops/sort.h" +#include +#include +#include +#include +#include +#include "ops/op_utils.h" +#include "utils/check_convert_utils.h" +#include "abstract/primitive_infer_map.h" + +namespace mindspore { +namespace ops { +namespace { +abstract::TupleShapePtr SortInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto prim_name = primitive->name(); + const int64_t input_num = 1; + (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kGreaterEqual, input_num, + prim_name); + CheckAndConvertUtils::CheckArgs(prim_name, input_args, 0); + auto x = input_args[0]->BuildShape(); + MS_EXCEPTION_IF_NULL(x); + auto shape_element = x->cast(); + MS_EXCEPTION_IF_NULL(shape_element); + return std::make_shared(std::vector{shape_element, shape_element}); +} + +TuplePtr SortInferType(const PrimitivePtr &prim, const std::vector &input_args) { + auto infer_type = input_args[0]->BuildType(); + MS_EXCEPTION_IF_NULL(infer_type); + const std::set valid_types = {kFloat16, kFloat32}; + auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputx", infer_type, valid_types, prim->name()); + std::vector type_tuple; + type_tuple.push_back(type); + type_tuple.push_back(kInt32); + return std::make_shared(type_tuple); +} +} // namespace + +AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + auto infertype = SortInferType(primitive, input_args); + auto infershape = SortInferShape(primitive, input_args); + return abstract::MakeAbstract(infershape, infertype); +} +REGISTER_PRIMITIVE_EVAL_IMPL(Sort, prim::kPrimSort, SortInfer, nullptr, true); +} // namespace ops +} // namespace mindspore diff --git a/mindspore/core/ops/sort.h b/mindspore/core/ops/sort.h new file mode 100644 index 00000000000..8c417344ce3 --- /dev/null +++ b/mindspore/core/ops/sort.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPS_SORT_H_ +#define MINDSPORE_CORE_OPS_SORT_H_ +#include +#include +#include +#include +#include "ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" +#include "ops/op_utils.h" + +namespace mindspore { +namespace ops { +constexpr auto kNameSort = "Sort"; +class Sort : public PrimitiveC { + public: + Sort() : PrimitiveC(kNameSort) { InitIOName({"x"}, {"y1", "y2"}); } + ~Sort() = default; + MS_DECLARE_PARENT(Sort, PrimitiveC); +}; +AbstractBasePtr SortInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace ops +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPS_SORT_H_ diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 4bc02553e66..f706a78a273 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -266,6 +266,7 @@ from .depth_to_space import _depth_to_space_tbe from .space_to_depth import _space_to_depth_tbe from .extract_image_patches import _extract_image_patches_tbe from .sort import _sort_tbe +from .sort_ds import _sort_ds_tbe from .floor import _floor_tbe from .ceil import _ceil_tbe from .log1p import _log1p_tbe diff --git a/mindspore/ops/_op_impl/tbe/sort_ds.py b/mindspore/ops/_op_impl/tbe/sort_ds.py new file mode 100644 index 00000000000..3b82945b82e --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sort_ds.py @@ -0,0 +1,39 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Sort op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sort_op_info = TBERegOp("Sort") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("sort.so") \ + .compute_cost(10) \ + .kernel_name("sort") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .attr("axis", "optional", "int", "all", "-1") \ + .attr("descending", "optional", "bool", "all", "false") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y1", False, "required", "all") \ + .output(1, "y2", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default) \ + .get_op_info() + + +@op_info_register(sort_op_info) +def _sort_ds_tbe(): + """Sort TBE register""" + return diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index d4fd17ad046..ab78320e762 100755 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -5868,7 +5868,7 @@ class TransShape(PrimitiveWithInfer): 'value': None} -class Sort(PrimitiveWithInfer): +class Sort(Primitive): """ Sorts the elements of the input tensor along a given dimension in ascending order by value. @@ -5877,6 +5877,10 @@ class Sort(PrimitiveWithInfer): descending (bool): Controls the sorting order. If descending is True then the elements are sorted in descending order by value. Default: False. + .. warning:: + Currently, only the data type of Float16 is supported. If use Float32, it may cause loss + of accuracy. + Inputs: - **x** (Tensor) - The input to sort, with float16 or float32 data type. The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions. @@ -5906,19 +5910,12 @@ class Sort(PrimitiveWithInfer): [2, 0, 1], [0, 1, 2]])) """ - @prim_attr_register def __init__(self, axis=-1, descending=False): """Initialize Sort""" self.axis = validator.check_value_type("axis", axis, [int], self.name) self.descending = validator.check_value_type("descending", descending, [bool], self.name) - - def infer_shape(self, x_shape): - return x_shape, x_shape - - def infer_dtype(self, x_dtype): - validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float32, mstype.float16], self.name) - return x_dtype, mstype.tensor_type(mstype.int32) + self.init_prim_io_names(inputs=['x'], outputs=['y1', 'y2']) class EmbeddingLookup(PrimitiveWithCheck):