forked from mindspore-Ecosystem/mindspore
add Tensor and function api argsort
This commit is contained in:
parent
a4a1dbf9df
commit
bc23bf6543
|
@ -405,6 +405,7 @@ Array操作
|
|||
mindspore.ops.adaptive_max_pool2d
|
||||
mindspore.ops.affine_grid
|
||||
mindspore.ops.arange
|
||||
mindspore.ops.argsort
|
||||
mindspore.ops.batch_to_space_nd
|
||||
mindspore.ops.bincount
|
||||
mindspore.ops.broadcast_to
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.argsort
|
||||
=========================
|
||||
|
||||
.. py:method:: mindspore.Tensor.argsort(axis=-1, descending=False)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.argsort`。
|
|
@ -48,6 +48,7 @@ mindspore.Tensor
|
|||
mindspore.Tensor.argmax_with_value
|
||||
mindspore.Tensor.argmin
|
||||
mindspore.Tensor.argmin_with_value
|
||||
mindspore.Tensor.argsort
|
||||
mindspore.Tensor.asin
|
||||
mindspore.Tensor.asinh
|
||||
mindspore.Tensor.asnumpy
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
mindspore.ops.argsort
|
||||
======================
|
||||
|
||||
.. py:function:: mindspore.ops.argsort(input_x, axis=-1, descending=False)
|
||||
|
||||
返回输入Tensor沿轴按特定顺序排序索引。
|
||||
|
||||
参数:
|
||||
- **input_x** (Tensor) - 待排序的输入Tensor。
|
||||
- **axis** (int) - 指定排序轴。默认值:-1。
|
||||
- **descending** (bool) - 控制输出顺序。如果 `descending` 为True,按照元素值升序排序,否则降顺排序。默认值:False。
|
||||
|
||||
返回:
|
||||
Tensor,排序后输入Tensor的索引。数据类型为int32。
|
||||
|
|
@ -54,6 +54,7 @@
|
|||
mindspore.Tensor.argmax_with_value
|
||||
mindspore.Tensor.argmin
|
||||
mindspore.Tensor.argmin_with_value
|
||||
mindspore.Tensor.argsort
|
||||
mindspore.Tensor.asin
|
||||
mindspore.Tensor.asinh
|
||||
mindspore.Tensor.asnumpy
|
||||
|
|
|
@ -405,6 +405,7 @@ Array Operation
|
|||
mindspore.ops.adaptive_max_pool2d
|
||||
mindspore.ops.affine_grid
|
||||
mindspore.ops.arange
|
||||
mindspore.ops.argsort
|
||||
mindspore.ops.batch_to_space_nd
|
||||
mindspore.ops.bincount
|
||||
mindspore.ops.broadcast_to
|
||||
|
|
|
@ -447,6 +447,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"signbit", std::string("signbit")}, // signbit()
|
||||
{"sinh", std::string("sinh")}, // sinh()
|
||||
{"sort", std::string("sort")}, // sort()
|
||||
{"argsort", std::string("argsort")}, // argsort()
|
||||
{"trunc", std::string("trunc")}, // trunc()
|
||||
{"where", std::string("where")}, // where()
|
||||
{"imag", std::string("imag")}, // imag()
|
||||
|
|
|
@ -4127,6 +4127,13 @@ def sort(input, dim=-1, descending=False):
|
|||
return P.Sort(axis=dim, descending=descending)(input)
|
||||
|
||||
|
||||
def argsort(input_x, axis=-1, descending=False):
|
||||
r"""
|
||||
Return the indices that sort the input tensor along the given dimension in the specified order.
|
||||
"""
|
||||
return F.argsort(input_x, axis, descending)
|
||||
|
||||
|
||||
def trunc(input):
|
||||
r"""
|
||||
Returns a new tensor with the truncated integer values of the elements of input.
|
||||
|
|
|
@ -4676,6 +4676,13 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('sort')(axis=dim, descending=descending)(self)
|
||||
|
||||
def argsort(self, axis=-1, descending=False):
|
||||
"""
|
||||
For details, please refer to :func:`mindspore.ops.argsort`.
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('argsort')(self, axis, descending)
|
||||
|
||||
def trunc(self):
|
||||
r"""
|
||||
Returns a new tensor with the truncated integer values of the elements of input.
|
||||
|
|
|
@ -133,6 +133,7 @@ from .array_func import (
|
|||
diagonal,
|
||||
lstsq,
|
||||
mvlgamma,
|
||||
argsort,
|
||||
)
|
||||
from .parameter_func import (
|
||||
assign,
|
||||
|
|
|
@ -2936,6 +2936,33 @@ def sort(input_x, axis=-1, descending=False):
|
|||
return _sort(input_x)
|
||||
|
||||
|
||||
def argsort(input_x, axis=-1, descending=False):
|
||||
r"""
|
||||
Return the indices that sort the input tensor along the given dimension in the specified order.
|
||||
|
||||
Args:
|
||||
input_x(Tensor): The input tensor to sort.
|
||||
axis (int): The dimension to sort along. Default: -1.
|
||||
descending (bool): Controls the sort order. If `descending` is True then the elements
|
||||
are sorted in descending order by value. Otherwise sort in descending order. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor, the indices of sorted input tensor. Data type is int32.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.array([[8, 2, 1], [5, 9, 3], [4, 6, 7]]), mindspore.float16)
|
||||
>>> sort = ops.argsort(x)
|
||||
>>> output = sort(x)
|
||||
>>> print(output)
|
||||
"""
|
||||
_sort = _get_cache_prim(P.Sort)(axis, descending)
|
||||
_, arg_sort = _sort(input_x)
|
||||
return arg_sort
|
||||
|
||||
|
||||
def gather(input_params, input_indices, axis):
|
||||
r"""
|
||||
Returns the slice of the input tensor corresponding to the elements of `input_indices` on the specified `axis`.
|
||||
|
@ -5774,5 +5801,6 @@ __all__ = [
|
|||
'diagonal',
|
||||
'lstsq',
|
||||
'mvlgamma',
|
||||
'argsort'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -320,6 +320,7 @@ tensor_operator_registry.register('dense_to_sparse_coo', dense_to_sparse_coo)
|
|||
tensor_operator_registry.register('csr_to_dense', csr_to_dense)
|
||||
tensor_operator_registry.register('narrow', narrow)
|
||||
tensor_operator_registry.register('sort', sort)
|
||||
tensor_operator_registry.register('argsort', argsort)
|
||||
tensor_operator_registry.register('msort', msort)
|
||||
tensor_operator_registry.register('mm', mm)
|
||||
tensor_operator_registry.register('nan_to_num', nan_to_num)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, input_x, axis=-1, descending=False):
|
||||
return ops.argsort(input_x, axis, descending)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_argsort(mode):
|
||||
"""
|
||||
Feature: argsort
|
||||
Description: Verify the result of argsort.
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
net = Net()
|
||||
a = [[0.0785, 1.5267, -0.8521, 0.4065],
|
||||
[0.1598, 0.0788, -0.0745, -1.2700],
|
||||
[1.2208, 1.0722, -0.7064, 1.2564],
|
||||
[0.0669, -0.2318, -0.8229, -0.9280]]
|
||||
x = ms.Tensor(a)
|
||||
out = net(x)
|
||||
expect = [[2, 0, 3, 1],
|
||||
[3, 2, 1, 0],
|
||||
[2, 1, 0, 3],
|
||||
[3, 2, 1, 0]]
|
||||
assert np.allclose(out.asnumpy(), np.array(expect))
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class Argsort(nn.Cell):
|
||||
def construct(self, input_x, axis=-1, descending=False):
|
||||
return input_x.argsort(axis, descending)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
|
||||
def test_tensor_argsort(mode):
|
||||
"""
|
||||
Feature: tensor.argsort
|
||||
Description: Verify the result of argsort
|
||||
Expectation: success
|
||||
"""
|
||||
ms.set_context(mode=mode)
|
||||
a = [[0.0785, 1.5267, -0.8521, 0.4065],
|
||||
[0.1598, 0.0788, -0.0745, -1.2700],
|
||||
[1.2208, 1.0722, -0.7064, 1.2564],
|
||||
[0.0669, -0.2318, -0.8229, -0.9280]]
|
||||
x = ms.Tensor(a)
|
||||
net = Argsort()
|
||||
out = net(x)
|
||||
expect = [[2, 0, 3, 1],
|
||||
[3, 2, 1, 0],
|
||||
[2, 1, 0, 3],
|
||||
[3, 2, 1, 0]]
|
||||
assert np.allclose(out.asnumpy(), np.array(expect))
|
Loading…
Reference in New Issue