add Tensor and function api argsort

This commit is contained in:
ZhidanLiu 2022-12-14 11:12:40 +08:00
parent a4a1dbf9df
commit bc23bf6543
14 changed files with 177 additions and 0 deletions

View File

@ -405,6 +405,7 @@ Array操作
mindspore.ops.adaptive_max_pool2d
mindspore.ops.affine_grid
mindspore.ops.arange
mindspore.ops.argsort
mindspore.ops.batch_to_space_nd
mindspore.ops.bincount
mindspore.ops.broadcast_to

View File

@ -0,0 +1,6 @@
mindspore.Tensor.argsort
=========================
.. py:method:: mindspore.Tensor.argsort(axis=-1, descending=False)
详情请参考 :func:`mindspore.ops.argsort`

View File

@ -48,6 +48,7 @@ mindspore.Tensor
mindspore.Tensor.argmax_with_value
mindspore.Tensor.argmin
mindspore.Tensor.argmin_with_value
mindspore.Tensor.argsort
mindspore.Tensor.asin
mindspore.Tensor.asinh
mindspore.Tensor.asnumpy

View File

@ -0,0 +1,15 @@
mindspore.ops.argsort
======================
.. py:function:: mindspore.ops.argsort(input_x, axis=-1, descending=False)
返回输入Tensor沿轴按特定顺序排序索引。
参数:
- **input_x** (Tensor) - 待排序的输入Tensor。
- **axis** (int) - 指定排序轴。默认值:-1。
- **descending** (bool) - 控制输出顺序。如果 `descending` 为True按照元素值升序排序否则降顺排序。默认值False。
返回:
Tensor排序后输入Tensor的索引。数据类型为int32。

View File

@ -54,6 +54,7 @@
mindspore.Tensor.argmax_with_value
mindspore.Tensor.argmin
mindspore.Tensor.argmin_with_value
mindspore.Tensor.argsort
mindspore.Tensor.asin
mindspore.Tensor.asinh
mindspore.Tensor.asnumpy

View File

@ -405,6 +405,7 @@ Array Operation
mindspore.ops.adaptive_max_pool2d
mindspore.ops.affine_grid
mindspore.ops.arange
mindspore.ops.argsort
mindspore.ops.batch_to_space_nd
mindspore.ops.bincount
mindspore.ops.broadcast_to

View File

@ -447,6 +447,7 @@ BuiltInTypeMap &GetMethodMap() {
{"signbit", std::string("signbit")}, // signbit()
{"sinh", std::string("sinh")}, // sinh()
{"sort", std::string("sort")}, // sort()
{"argsort", std::string("argsort")}, // argsort()
{"trunc", std::string("trunc")}, // trunc()
{"where", std::string("where")}, // where()
{"imag", std::string("imag")}, // imag()

View File

@ -4127,6 +4127,13 @@ def sort(input, dim=-1, descending=False):
return P.Sort(axis=dim, descending=descending)(input)
def argsort(input_x, axis=-1, descending=False):
r"""
Return the indices that sort the input tensor along the given dimension in the specified order.
"""
return F.argsort(input_x, axis, descending)
def trunc(input):
r"""
Returns a new tensor with the truncated integer values of the elements of input.

View File

@ -4676,6 +4676,13 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('sort')(axis=dim, descending=descending)(self)
def argsort(self, axis=-1, descending=False):
"""
For details, please refer to :func:`mindspore.ops.argsort`.
"""
self._init_check()
return tensor_operator_registry.get('argsort')(self, axis, descending)
def trunc(self):
r"""
Returns a new tensor with the truncated integer values of the elements of input.

View File

@ -133,6 +133,7 @@ from .array_func import (
diagonal,
lstsq,
mvlgamma,
argsort,
)
from .parameter_func import (
assign,

View File

@ -2936,6 +2936,33 @@ def sort(input_x, axis=-1, descending=False):
return _sort(input_x)
def argsort(input_x, axis=-1, descending=False):
r"""
Return the indices that sort the input tensor along the given dimension in the specified order.
Args:
input_x(Tensor): The input tensor to sort.
axis (int): The dimension to sort along. Default: -1.
descending (bool): Controls the sort order. If `descending` is True then the elements
are sorted in descending order by value. Otherwise sort in descending order. Default: False.
Returns:
Tensor, the indices of sorted input tensor. Data type is int32.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
>>> sort = ops.argsort(x)
>>> output = sort(x)
>>> print(output)
"""
_sort = _get_cache_prim(P.Sort)(axis, descending)
_, arg_sort = _sort(input_x)
return arg_sort
def gather(input_params, input_indices, axis):
r"""
Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
@ -5774,5 +5801,6 @@ __all__ = [
'diagonal',
'lstsq',
'mvlgamma',
'argsort'
]
__all__.sort()

View File

@ -320,6 +320,7 @@ tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
tensor_operator_registry.register('csr_to_dense', csr_to_dense)
tensor_operator_registry.register('narrow', narrow)
tensor_operator_registry.register('sort', sort)
tensor_operator_registry.register('argsort', argsort)
tensor_operator_registry.register('msort', msort)
tensor_operator_registry.register('mm', mm)
tensor_operator_registry.register('nan_to_num', nan_to_num)

View File

@ -0,0 +1,55 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
class Net(nn.Cell):
def construct(self, input_x, axis=-1, descending=False):
return ops.argsort(input_x, axis, descending)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_argsort(mode):
"""
Feature: argsort
Description: Verify the result of argsort.
Expectation: success
"""
ms.set_context(mode=mode)
net = Net()
a = [[0.0785, 1.5267, -0.8521, 0.4065],
[0.1598, 0.0788, -0.0745, -1.2700],
[1.2208, 1.0722, -0.7064, 1.2564],
[0.0669, -0.2318, -0.8229, -0.9280]]
x = ms.Tensor(a)
out = net(x)
expect = [[2, 0, 3, 1],
[3, 2, 1, 0],
[2, 1, 0, 3],
[3, 2, 1, 0]]
assert np.allclose(out.asnumpy(), np.array(expect))

View File

@ -0,0 +1,52 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
class Argsort(nn.Cell):
def construct(self, input_x, axis=-1, descending=False):
return input_x.argsort(axis, descending)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_tensor_argsort(mode):
"""
Feature: tensor.argsort
Description: Verify the result of argsort
Expectation: success
"""
ms.set_context(mode=mode)
a = [[0.0785, 1.5267, -0.8521, 0.4065],
[0.1598, 0.0788, -0.0745, -1.2700],
[1.2208, 1.0722, -0.7064, 1.2564],
[0.0669, -0.2318, -0.8229, -0.9280]]
x = ms.Tensor(a)
net = Argsort()
out = net(x)
expect = [[2, 0, 3, 1],
[3, 2, 1, 0],
[2, 1, 0, 3],
[3, 2, 1, 0]]
assert np.allclose(out.asnumpy(), np.array(expect))