ChangeApiOfArgmin

fix when input shape is 1
This commit is contained in:
liangcanli 2022-11-11 16:38:46 +08:00
parent 66270a8f84
commit 8ca7ddbdef
9 changed files with 151 additions and 15 deletions

View File

@ -1,6 +1,6 @@
mindspore.Tensor.argmin
=======================
.. py:method:: mindspore.Tensor.argmin(axis=None)
.. py:method:: mindspore.Tensor.argmin(axis=None, keepdims=False)
详情请参考 :func:`mindspore.ops.argmin`

View File

@ -1,15 +1,16 @@
mindspore.ops.argmin
====================
.. py:function:: mindspore.ops.argmin(x, axis=-1)
.. py:function:: mindspore.ops.argmin(x, axis=None, keepdims=False)
返回输入Tensor在指定轴上的最小值索引。
如果输入Tensor的shape为 :math:`(x_1, ..., x_N)` 则输出Tensor的shape为 :math:`(x_1, ..., x_{axis-1}, x_{axis+1}, ..., x_N)`
参数:
- **x** (Tensor) - shape非空任意维度的Tensor。
- **axis** (int) - 指定计算轴。默认值:-1。
- **x** (Tensor) - 输入Tensor。
- **axis** (Union[int, None],可选) - 指定计算轴。如果是None将会返回扁平化Tensor在指定轴上的最小值索引。默认值None。
- **keepdims** (bool可选) - 输出Tensor是否保留指定轴。如果 `axis` 是None忽略该选项。默认值 False。
返回:
Tensor输出为指定轴上输入Tensor最小值的索引。

View File

@ -789,7 +789,7 @@ def argmax(x, axis=None, keepdims=False):
return out
def argmin(x, axis=None):
def argmin(x, axis=None, keepdims=False):
"""
Returns the indices of the minimum values along an axis.
@ -798,6 +798,8 @@ def argmin(x, axis=None):
axis (int, optional): By default, the index is into
the flattened array, otherwise along the specified axis.
Defaults to None.
keepdims (boolean, optional): Whether the output tensor retains the specified
dimension. Ignored if `axis` is None. Default: False.
Returns:
Tensor, array of indices into the array. It has the same
@ -816,15 +818,19 @@ def argmin(x, axis=None):
>>> print(a.argmin())
0
"""
# P.Argmax only supports float
# P.Argmin only supports float
x = x.astype(mstype.float32)
is_axis_none = False
if axis is None:
x = ravel(x)
axis = 0
is_axis_none = True
else:
axis = check_axis_in_range_const(axis, F.rank(x))
# P.Argmin is currently not supported
return P.Argmax(axis)(F.neg_tensor(x))
out = P.Argmin(axis)(x)
if keepdims and not is_axis_none:
out = expand_dims(out, axis)
return out
def argmax_with_value(x, axis=0, keep_dims=False):

View File

@ -1954,6 +1954,8 @@ class Tensor(Tensor_):
"""
For details, please refer to :func:`mindspore.ops.argmax`.
"""
if self.shape == ():
return Tensor(0)
a = self
is_axis_none = False
if axis is None:
@ -1965,11 +1967,14 @@ class Tensor(Tensor_):
out = out.expand_dims(axis)
return out
def argmin(self, axis=None):
def argmin(self, axis=None, keepdims=False):
"""
For details, please refer to :func:`mindspore.ops.argmin`.
"""
if self.shape == ():
return Tensor(0)
# P.Argmin only supports float
is_axis_none = False
a = self.astype(mstype.float32)
if axis is None:
a = a.ravel()
@ -1977,7 +1982,10 @@ class Tensor(Tensor_):
else:
axis = validator.check_axis_in_range(axis, a.ndim)
# P.Argmin is currently not supported
return tensor_operator_registry.get('argmax')(axis)(tensor_operator_registry.get('__neg__')(a))
out = tensor_operator_registry.get('argmin')(axis)(a)
if keepdims and not is_axis_none:
out = out.expand_dims(axis)
return out
def argmax_with_value(self, axis=0, keep_dims=False):
"""
@ -2023,6 +2031,8 @@ class Tensor(Tensor_):
>>> print(index, output)
[3] [0.7]
"""
if self.shape == ():
return (Tensor(0), self)
self._init_check()
return tensor_operator_registry.get('argmax_with_value')(self, axis, keep_dims)
@ -2068,6 +2078,8 @@ class Tensor(Tensor_):
>>> print(index, output)
[0] [0.0]
"""
if self.shape == ():
return (Tensor(0), self)
self._init_check()
return tensor_operator_registry.get('argmin_with_value')(self, axis, keep_dims)

View File

@ -4288,6 +4288,8 @@ def max(x, axis=0, keep_dims=False):
>>> print(index, output)
[3] [0.7]
"""
if x.shape == ():
return (Tensor(0), x)
argmax_with_value_op = ArgMaxWithValue(axis, keep_dims)
return argmax_with_value_op(x)
@ -4319,6 +4321,8 @@ def argmax(x, axis=None, keepdims=False):
>>> print(output)
[1 0 0]
"""
if x.shape == ():
return Tensor(0)
is_axis_none = False
if axis is None:
x = reshape_(x, (-1,))
@ -4377,6 +4381,8 @@ def min(x, axis=0, keep_dims=False):
>>> print(index, output)
[0] [0.0]
"""
if x.shape == ():
return (Tensor(0), x)
argmin_with_value_ = ArgMinWithValue(axis=axis, keep_dims=keep_dims)
return argmin_with_value_(x)

View File

@ -406,7 +406,7 @@ def exp2(x):
return exp2_(tensor_2, x)
def argmin(x, axis=-1):
def argmin(x, axis=-1, keepdims=False):
"""
Returns the indices of the minimum value of a tensor across the axis.
@ -416,6 +416,8 @@ def argmin(x, axis=-1):
Args:
x (Tensor): Input tensor. The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
axis (int): Axis where the Argmin operation applies to. Default: -1.
keepdims (boolean, optional): Whether the output tensor retains the specified
dimension. Ignored if `axis` is None. Default: False.
Returns:
Tensor, indices of the min value of input tensor across the axis.
@ -432,8 +434,17 @@ def argmin(x, axis=-1):
>>> print(index)
2
"""
_argmin = _get_cache_prim(P.Argmin)(axis)
return _argmin(x)
if x.shape == ():
return Tensor(0)
is_axis_none = False
if axis is None:
x = P.Reshape()(x, (-1,))
axis = 0
is_axis_none = True
out = _get_cache_prim(P.Argmin)(axis)(x)
if keepdims and not is_axis_none:
out = P.ExpandDims()(out, axis)
return out
neg_tensor = P.Neg()

View File

@ -155,6 +155,7 @@ tensor_operator_registry.register('broadcast_to', P.BroadcastTo)
tensor_operator_registry.register('matmul', matmul)
tensor_operator_registry.register('xdivy', P.Xdivy)
tensor_operator_registry.register('argmax', P.Argmax)
tensor_operator_registry.register('argmin', P.Argmin)
tensor_operator_registry.register('cumsum', P.CumSum)
tensor_operator_registry.register('cummin', cummin)
tensor_operator_registry.register('cummax', cummax)

View File

@ -20,9 +20,8 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import Tensor, ops
from mindspore.common import dtype as mstype
import mindspore.ops as ops
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@ -148,3 +147,53 @@ def test_argmin_vmap_basic_axis_negative():
outputs = ops.vmap(cal_argmin_axis_negative, in_axes=0, out_axes=0)(x)
expect = np.array([[1, 0, 1], [2, 0, 0]]).astype(np.int32)
assert np.allclose(outputs.asnumpy(), expect)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_argmin_functional():
"""
Feature: test ops.argmin.
Description: test ops.argmin functional api.
Expectation: the result match with expected result.
"""
x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.int32)
out_dim_none = ops.argmin(x, axis=None, keepdims=False)
out_dim_0 = ops.argmin(x, axis=0, keepdims=False)
out_dim_1 = ops.argmin(x, axis=1, keepdims=False)
out_dim_none_keepdim = ops.argmin(x, axis=None, keepdims=True)
out_dim_0_keepdim = ops.argmin(x, axis=0, keepdims=True)
out_dim_1_keepdim = ops.argmin(x, axis=1, keepdims=True)
assert out_dim_none.asnumpy() == 7
assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1]))
assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1]))
assert out_dim_none_keepdim.asnumpy() == 7
assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]]))
assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]]))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_argmin_tensor():
"""
Feature: test tensor.argmin.
Description: test argmin tensor api.
Expectation: the result match with expected result.
"""
x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.int32)
out_dim_none = x.argmin(axis=None, keepdims=False)
out_dim_0 = x.argmin(axis=0, keepdims=False)
out_dim_1 = x.argmin(axis=1, keepdims=False)
out_dim_none_keepdim = x.argmin(axis=None, keepdims=True)
out_dim_0_keepdim = x.argmin(axis=0, keepdims=True)
out_dim_1_keepdim = x.argmin(axis=1, keepdims=True)
assert out_dim_none.asnumpy() == 7
assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1]))
assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1]))
assert out_dim_none_keepdim.asnumpy() == 7
assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]]))
assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]]))

View File

@ -104,3 +104,53 @@ def test_argmin_high_dims():
ms_output = argmin(Tensor(x))
np_output = np.argmin(x, axis=rnd_axis)
assert (ms_output.asnumpy() == np_output).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_argmin_functional():
"""
Feature: test ops.argmin.
Description: test ops.argmin functional api.
Expectation: the result match with expected result.
"""
x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.float32)
out_dim_none = ops.argmin(x, axis=None, keepdims=False)
out_dim_0 = ops.argmin(x, axis=0, keepdims=False)
out_dim_1 = ops.argmin(x, axis=1, keepdims=False)
out_dim_none_keepdim = ops.argmin(x, axis=None, keepdims=True)
out_dim_0_keepdim = ops.argmin(x, axis=0, keepdims=True)
out_dim_1_keepdim = ops.argmin(x, axis=1, keepdims=True)
assert out_dim_none.asnumpy() == 7
assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1]))
assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1]))
assert out_dim_none_keepdim.asnumpy() == 7
assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]]))
assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu
@pytest.mark.env_onecard
def test_argmin_tensor():
"""
Feature: test tensor.argmin.
Description: test argmin tensor api.
Expectation: the result match with expected result.
"""
x = Tensor([[5., 3., 4.], [2., 4., 3.], [3., 1., 4.]], mstype.float32)
out_dim_none = x.argmin(axis=None, keepdims=False)
out_dim_0 = x.argmin(axis=0, keepdims=False)
out_dim_1 = x.argmin(axis=1, keepdims=False)
out_dim_none_keepdim = x.argmin(axis=None, keepdims=True)
out_dim_0_keepdim = x.argmin(axis=0, keepdims=True)
out_dim_1_keepdim = x.argmin(axis=1, keepdims=True)
assert out_dim_none.asnumpy() == 7
assert np.all(out_dim_0.asnumpy() == np.array([1, 2, 1]))
assert np.all(out_dim_1.asnumpy() == np.array([1, 0, 1]))
assert out_dim_none_keepdim.asnumpy() == 7
assert np.all(out_dim_0_keepdim.asnumpy() == np.array([[1, 2, 1]]))
assert np.all(out_dim_1_keepdim.asnumpy() == np.array([[1], [0], [1]]))