From bc783dd49db14322e552994159fad40e65f492df Mon Sep 17 00:00:00 2001 From: shaojunsong Date: Mon, 14 Nov 2022 21:47:12 +0800 Subject: [PATCH] add tensor.where --- docs/api/api_python/mindspore.ops.rst | 1 + .../Tensor/mindspore.Tensor.where.rst | 6 +++ .../api_python/mindspore/mindspore.Tensor.rst | 1 + .../ops/mindspore.ops.func_where.rst | 21 +++++++++ docs/api/api_python_en/Tensor_list.rst | 1 + docs/api/api_python_en/mindspore.ops.rst | 1 + mindspore/ccsrc/pipeline/jit/resource.cc | 1 + .../_extends/parse/standard_method.py | 14 +++++- mindspore/python/mindspore/common/tensor.py | 7 +++ mindspore/python/mindspore/numpy/array_ops.py | 3 +- .../python/mindspore/ops/function/__init__.py | 1 + .../mindspore/ops/function/array_func.py | 35 ++++++++++++++ .../mindspore/ops/function/math_func.py | 3 +- mindspore/python/mindspore/ops/functional.py | 1 + tests/st/ops/test_ops_where.py | 36 ++++++++++++++ tests/st/tensor/test_where.py | 47 +++++++++++++++++++ 16 files changed, 174 insertions(+), 5 deletions(-) create mode 100644 docs/api/api_python/mindspore/Tensor/mindspore.Tensor.where.rst create mode 100644 docs/api/api_python/ops/mindspore.ops.func_where.rst create mode 100644 tests/st/ops/test_ops_where.py create mode 100644 tests/st/tensor/test_where.py diff --git a/docs/api/api_python/mindspore.ops.rst b/docs/api/api_python/mindspore.ops.rst index fff115d1074..3c7b8827c3e 100644 --- a/docs/api/api_python/mindspore.ops.rst +++ b/docs/api/api_python/mindspore.ops.rst @@ -442,6 +442,7 @@ Array操作 mindspore.ops.unsorted_segment_sum mindspore.ops.unsqueeze mindspore.ops.unstack + mindspore.ops.where 类型转换 ^^^^^^^^^^^^^^^^ diff --git a/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.where.rst b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.where.rst new file mode 100644 index 00000000000..5ee0fbe990b --- /dev/null +++ b/docs/api/api_python/mindspore/Tensor/mindspore.Tensor.where.rst @@ -0,0 +1,6 @@ +mindspore.Tensor.where +======================= + +.. py:method:: mindspore.Tensor.where(condition, y) + + 详情请参考 :func:`mindspore.ops.where`。 diff --git a/docs/api/api_python/mindspore/mindspore.Tensor.rst b/docs/api/api_python/mindspore/mindspore.Tensor.rst index 6188d30b63b..45f9b408831 100644 --- a/docs/api/api_python/mindspore/mindspore.Tensor.rst +++ b/docs/api/api_python/mindspore/mindspore.Tensor.rst @@ -250,5 +250,6 @@ mindspore.Tensor mindspore.Tensor.unsqueeze mindspore.Tensor.var mindspore.Tensor.view + mindspore.Tensor.where mindspore.Tensor.xdivy mindspore.Tensor.xlogy diff --git a/docs/api/api_python/ops/mindspore.ops.func_where.rst b/docs/api/api_python/ops/mindspore.ops.func_where.rst new file mode 100644 index 00000000000..fde53b95370 --- /dev/null +++ b/docs/api/api_python/ops/mindspore.ops.func_where.rst @@ -0,0 +1,21 @@ +mindspore.ops.where +==================== + +.. py:function:: mindspore.ops.where(condition, x, y) + + 返回一个Tensor,Tensor的元素从 `x` 或 `y` 中根据 `condition` 选择。 + + .. math:: + + output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases} + + 参数: + - **condition** (Union[Bool Tensor, bool, scalar]) - 如果是True,选取 `x` 中的元素,否则选取 `y` 中的元素。 + - **x** (Union[Tensor, Scalar]) - 在 `condition` 为True的索引处选择的值。 + - **y** (Union[Tensor, Scalar]) - 当 `condition` 为False的索引处选择的值。 + + 返回: + Tensor,其中的元素从 `x` 和 `y` 中选取。 + + 异常: + -**ValueError** - `condition` 不可以被广播成 `x` 的shape。 diff --git a/docs/api/api_python_en/Tensor_list.rst b/docs/api/api_python_en/Tensor_list.rst index 0d0ad0ad356..80e0272af45 100644 --- a/docs/api/api_python_en/Tensor_list.rst +++ b/docs/api/api_python_en/Tensor_list.rst @@ -256,6 +256,7 @@ mindspore.Tensor.unsqueeze mindspore.Tensor.var mindspore.Tensor.view + mindspore.Tensor.where mindspore.Tensor.xdivy mindspore.Tensor.xlogy diff --git a/docs/api/api_python_en/mindspore.ops.rst b/docs/api/api_python_en/mindspore.ops.rst index b7f728b1fb1..4c8e81d42a6 100644 --- a/docs/api/api_python_en/mindspore.ops.rst +++ b/docs/api/api_python_en/mindspore.ops.rst @@ -442,6 +442,7 @@ Array Operation mindspore.ops.unsorted_segment_sum mindspore.ops.unsqueeze mindspore.ops.unstack + mindspore.ops.where Type Conversion ^^^^^^^^^^^^^^^ diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index 3ec5cbd7f88..9b56e3141aa 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -413,6 +413,7 @@ BuiltInTypeMap &GetMethodMap() { {"sinh", std::string("sinh")}, // sinh() {"sort", std::string("sort")}, // sort() {"trunc", std::string("trunc")}, // trunc() + {"where", std::string("where")}, // where() {"imag", std::string("imag")}, // imag() }}, {kObjectTypeRowTensorType, diff --git a/mindspore/python/mindspore/_extends/parse/standard_method.py b/mindspore/python/mindspore/_extends/parse/standard_method.py index be286742628..1f79b96f10c 100644 --- a/mindspore/python/mindspore/_extends/parse/standard_method.py +++ b/mindspore/python/mindspore/_extends/parse/standard_method.py @@ -1001,7 +1001,8 @@ def copy(x): return x -def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin +def max(x, axis=None, keepdims=False, initial=None, # pylint: disable=redefined-builtin + where=True): # pylint: disable=redefined-outer-name """ Returns the maximum of a tensor or maximum along an axis. @@ -1046,7 +1047,8 @@ def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disa axis=axis, keepdims=keepdims, initial=initial, where=where) -def min(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin +def min(x, axis=None, keepdims=False, initial=None, # pylint: disable=redefined-builtin + where=True): # pylint: disable=redefined-outer-name """ Returns the minimum of a tensor or minimum along an axis. @@ -3977,6 +3979,14 @@ def trunc(input): return F.trunc(input) +def where(x, condition, y): + r""" + Returns a tensor whose elements are selected from either `x` or `y` depending on `condition`. + Please refer to :func:`mindspore.ops.where`. + """ + return F.where(condition, x, y) + + def imag(input): r""" Returns a new tensor containing imaginary value of the input. diff --git a/mindspore/python/mindspore/common/tensor.py b/mindspore/python/mindspore/common/tensor.py index a5cd8acb298..c8e0acabb77 100644 --- a/mindspore/python/mindspore/common/tensor.py +++ b/mindspore/python/mindspore/common/tensor.py @@ -4488,6 +4488,13 @@ class Tensor(Tensor_): self._init_check() return tensor_operator_registry.get('trunc')(self) + def where(self, condition, y): + r""" + For details, please refer to :func:`mindspore.ops.where` + """ + self._init_check() + return tensor_operator_registry.get('where')(condition, self, y) + def imag(self): r""" Returns a new tensor containing imaginary value of the input tensor. diff --git a/mindspore/python/mindspore/numpy/array_ops.py b/mindspore/python/mindspore/numpy/array_ops.py index f300e8ce27f..5100411a4ad 100644 --- a/mindspore/python/mindspore/numpy/array_ops.py +++ b/mindspore/python/mindspore/numpy/array_ops.py @@ -711,8 +711,7 @@ def where(condition, x=None, y=None): x = F.cast(x, dtype) if not _check_same_type(dtype2, dtype): y = F.cast(y, dtype) - is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type( - dtype2, mstype.bool_) + is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type(dtype2, mstype.bool_) if is_bool: # select does not support bool type for x or y x = F.cast(x, mstype.float32) diff --git a/mindspore/python/mindspore/ops/function/__init__.py b/mindspore/python/mindspore/ops/function/__init__.py index 9f1bd66d173..d054f308b48 100644 --- a/mindspore/python/mindspore/ops/function/__init__.py +++ b/mindspore/python/mindspore/ops/function/__init__.py @@ -105,6 +105,7 @@ from .array_func import ( matrix_set_diag, diag, masked_select, + where, meshgrid, affine_grid, fills, diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index f3a037495d3..b927f6f3957 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -268,6 +268,40 @@ def eye(n, m, t): return eye_(n, m, t) +def where(condition, x, y): + r""" + Returns a tensor whose elements are selected from either `x` or `y` depending on `condition`. + + ..math:: + output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases} + + Args: + condition (Bool Tensor, bool, scalar): If True, yield `x` otherwise yield `y`. + x (Union[Tensor, Scalar]): Value (if `x` is a scalar) or values selected at indices where condition is True. + y (Union[Tensor, Scalar]): Value (if `y` is a scalar) or values selected at indices where condition is False. + + Returns: + Tensor, elements are selected from `x` and `y`. + + Raises: + ValueError: If `condition` can not be broadcast to the shape of `x`. + + Examples: + >>> a = Tensor(np.arange(4).reshape((2, 2)), mstype.float32) + >>> b = Tensor(np.ones((2, 2)), mstype.float16) + >>> condition = a < 3 + >>> output = ops.where(condition, a, b) + >>> print(output) + [[0. 1.] + [2. 1.]] + """ + not_op = P.LogicalNot() + x_mask = cast_(condition, mstype.bool_) + y_mask = not_op(x_mask) + output = x * x_mask + y * y_mask + return output + + def reverse(x, axis): """ Reverses specific dimensions of a tensor. @@ -5226,6 +5260,7 @@ __all__ = [ 'one_hot', 'masked_fill', 'masked_select', + 'where', 'narrow', 'scatter_add', 'scatter_mul', diff --git a/mindspore/python/mindspore/ops/function/math_func.py b/mindspore/python/mindspore/ops/function/math_func.py index 92b9be48aed..8a2846b8495 100644 --- a/mindspore/python/mindspore/ops/function/math_func.py +++ b/mindspore/python/mindspore/ops/function/math_func.py @@ -712,7 +712,8 @@ def subtract(x, other, *, alpha=1): output[i] = x[i] - alpha * y[i] Args: - other (Union[Tensor, number.Number]): The tensor or number to be subtracted. + x (Union[Tensor, number.Number]): The tensor or number to be subtracted. + other (Union[Tensor, number.Number]): The tensor or number to subtract. Keyword Args: alpha (Number): The multiplier for `other`. Default: 1. diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py index bc41535a4f3..05ab6703a3e 100644 --- a/mindspore/python/mindspore/ops/functional.py +++ b/mindspore/python/mindspore/ops/functional.py @@ -356,6 +356,7 @@ tensor_operator_registry.register('ne', ne) tensor_operator_registry.register('sinh', sinh) tensor_operator_registry.register('sort', P.Sort) tensor_operator_registry.register('trunc', trunc) +tensor_operator_registry.register('where', where) tensor_operator_registry.register('imag', imag) tensor_operator_registry.register('repeat_interleave', repeat_interleave) tensor_operator_registry.register('rad2deg', rad2deg) diff --git a/tests/st/ops/test_ops_where.py b/tests/st/ops/test_ops_where.py new file mode 100644 index 00000000000..e38671ba1d7 --- /dev/null +++ b/tests/st/ops/test_ops_where.py @@ -0,0 +1,36 @@ +import numpy as np +import pytest +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindspore import context + + +class Net(nn.Cell): + def construct(self, condition, x, y): + return ops.where(condition, x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_arm_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_ops_where(mode): + """ + Feature: ops.where + Description: Verify the result of ops.where + Expectation: success + """ + context.set_context(mode=mode) + net = Net() + x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32) + y = Tensor(np.ones((2, 2)), mstype.float32) + condition = x < 3 + output = net(condition, x, y) + expected = np.array([[0, 1], [2, 1]], dtype=np.float32) + assert np.allclose(output.asnumpy(), expected) diff --git a/tests/st/tensor/test_where.py b/tests/st/tensor/test_where.py new file mode 100644 index 00000000000..b67db4e4edc --- /dev/null +++ b/tests/st/tensor/test_where.py @@ -0,0 +1,47 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.common.dtype as mstype +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context + + +class Net(nn.Cell): + def construct(self, condition, x, y): + return x.where(condition, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE]) +def test_tensor_where(mode): + """ + Feature: tensor.where + Description: Verify the result of where + Expectation: success + """ + context.set_context(mode=mode) + x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32) + y = Tensor(np.ones((2, 2)), mstype.float32) + condition = x < 3 + net = Net() + output = net(condition, x, y) + expected = np.array([[0, 1], [2, 1]], dtype=np.float32) + assert np.allclose(output.asnumpy(), expected)