add tensor.where

This commit is contained in:
shaojunsong 2022-11-14 21:47:12 +08:00
parent 62d24d2b2a
commit bc783dd49d
16 changed files with 174 additions and 5 deletions

View File

@ -442,6 +442,7 @@ Array操作
mindspore.ops.unsorted_segment_sum
mindspore.ops.unsqueeze
mindspore.ops.unstack
mindspore.ops.where
类型转换
^^^^^^^^^^^^^^^^

View File

@ -0,0 +1,6 @@
mindspore.Tensor.where
=======================
.. py:method:: mindspore.Tensor.where(condition, y)
详情请参考 :func:`mindspore.ops.where`

View File

@ -250,5 +250,6 @@ mindspore.Tensor
mindspore.Tensor.unsqueeze
mindspore.Tensor.var
mindspore.Tensor.view
mindspore.Tensor.where
mindspore.Tensor.xdivy
mindspore.Tensor.xlogy

View File

@ -0,0 +1,21 @@
mindspore.ops.where
====================
.. py:function:: mindspore.ops.where(condition, x, y)
返回一个TensorTensor的元素从 `x``y` 中根据 `condition` 选择。
.. math::
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
参数:
- **condition** (Union[Bool Tensor, bool, scalar]) - 如果是True选取 `x` 中的元素,否则选取 `y` 中的元素。
- **x** (Union[Tensor, Scalar]) - 在 `condition` 为True的索引处选择的值。
- **y** (Union[Tensor, Scalar]) - 当 `condition` 为False的索引处选择的值。
返回:
Tensor其中的元素从 `x``y` 中选取。
异常:
-**ValueError** - `condition` 不可以被广播成 `x` 的shape。

View File

@ -256,6 +256,7 @@
mindspore.Tensor.unsqueeze
mindspore.Tensor.var
mindspore.Tensor.view
mindspore.Tensor.where
mindspore.Tensor.xdivy
mindspore.Tensor.xlogy

View File

@ -442,6 +442,7 @@ Array Operation
mindspore.ops.unsorted_segment_sum
mindspore.ops.unsqueeze
mindspore.ops.unstack
mindspore.ops.where
Type Conversion
^^^^^^^^^^^^^^^

View File

@ -413,6 +413,7 @@ BuiltInTypeMap &GetMethodMap() {
{"sinh", std::string("sinh")}, // sinh()
{"sort", std::string("sort")}, // sort()
{"trunc", std::string("trunc")}, // trunc()
{"where", std::string("where")}, // where()
{"imag", std::string("imag")}, // imag()
}},
{kObjectTypeRowTensorType,

View File

@ -1001,7 +1001,8 @@ def copy(x):
return x
def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin
def max(x, axis=None, keepdims=False, initial=None, # pylint: disable=redefined-builtin
where=True): # pylint: disable=redefined-outer-name
"""
Returns the maximum of a tensor or maximum along an axis.
@ -1046,7 +1047,8 @@ def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disa
axis=axis, keepdims=keepdims, initial=initial, where=where)
def min(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin
def min(x, axis=None, keepdims=False, initial=None, # pylint: disable=redefined-builtin
where=True): # pylint: disable=redefined-outer-name
"""
Returns the minimum of a tensor or minimum along an axis.
@ -3977,6 +3979,14 @@ def trunc(input):
return F.trunc(input)
def where(x, condition, y):
r"""
Returns a tensor whose elements are selected from either `x` or `y` depending on `condition`.
Please refer to :func:`mindspore.ops.where`.
"""
return F.where(condition, x, y)
def imag(input):
r"""
Returns a new tensor containing imaginary value of the input.

View File

@ -4488,6 +4488,13 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('trunc')(self)
def where(self, condition, y):
r"""
For details, please refer to :func:`mindspore.ops.where`
"""
self._init_check()
return tensor_operator_registry.get('where')(condition, self, y)
def imag(self):
r"""
Returns a new tensor containing imaginary value of the input tensor.

View File

@ -711,8 +711,7 @@ def where(condition, x=None, y=None):
x = F.cast(x, dtype)
if not _check_same_type(dtype2, dtype):
y = F.cast(y, dtype)
is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type(
dtype2, mstype.bool_)
is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type(dtype2, mstype.bool_)
if is_bool:
# select does not support bool type for x or y
x = F.cast(x, mstype.float32)

View File

@ -105,6 +105,7 @@ from .array_func import (
matrix_set_diag,
diag,
masked_select,
where,
meshgrid,
affine_grid,
fills,

View File

@ -268,6 +268,40 @@ def eye(n, m, t):
return eye_(n, m, t)
def where(condition, x, y):
r"""
Returns a tensor whose elements are selected from either `x` or `y` depending on `condition`.
..math::
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
Args:
condition (Bool Tensor, bool, scalar): If True, yield `x` otherwise yield `y`.
x (Union[Tensor, Scalar]): Value (if `x` is a scalar) or values selected at indices where condition is True.
y (Union[Tensor, Scalar]): Value (if `y` is a scalar) or values selected at indices where condition is False.
Returns:
Tensor, elements are selected from `x` and `y`.
Raises:
ValueError: If `condition` can not be broadcast to the shape of `x`.
Examples:
>>> a = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
>>> b = Tensor(np.ones((2, 2)), mstype.float16)
>>> condition = a < 3
>>> output = ops.where(condition, a, b)
>>> print(output)
[[0. 1.]
[2. 1.]]
"""
not_op = P.LogicalNot()
x_mask = cast_(condition, mstype.bool_)
y_mask = not_op(x_mask)
output = x * x_mask + y * y_mask
return output
def reverse(x, axis):
"""
Reverses specific dimensions of a tensor.
@ -5226,6 +5260,7 @@ __all__ = [
'one_hot',
'masked_fill',
'masked_select',
'where',
'narrow',
'scatter_add',
'scatter_mul',

View File

@ -712,7 +712,8 @@ def subtract(x, other, *, alpha=1):
output[i] = x[i] - alpha * y[i]
Args:
other (Union[Tensor, number.Number]): The tensor or number to be subtracted.
x (Union[Tensor, number.Number]): The tensor or number to be subtracted.
other (Union[Tensor, number.Number]): The tensor or number to subtract.
Keyword Args:
alpha (Number): The multiplier for `other`. Default: 1.

View File

@ -356,6 +356,7 @@ tensor_operator_registry.register('ne', ne)
tensor_operator_registry.register('sinh', sinh)
tensor_operator_registry.register('sort', P.Sort)
tensor_operator_registry.register('trunc', trunc)
tensor_operator_registry.register('where', where)
tensor_operator_registry.register('imag', imag)
tensor_operator_registry.register('repeat_interleave', repeat_interleave)
tensor_operator_registry.register('rad2deg', rad2deg)

View File

@ -0,0 +1,36 @@
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, condition, x, y):
return ops.where(condition, x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_ops_where(mode):
"""
Feature: ops.where
Description: Verify the result of ops.where
Expectation: success
"""
context.set_context(mode=mode)
net = Net()
x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
y = Tensor(np.ones((2, 2)), mstype.float32)
condition = x < 3
output = net(condition, x, y)
expected = np.array([[0, 1], [2, 1]], dtype=np.float32)
assert np.allclose(output.asnumpy(), expected)

View File

@ -0,0 +1,47 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
class Net(nn.Cell):
def construct(self, condition, x, y):
return x.where(condition, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_tensor_where(mode):
"""
Feature: tensor.where
Description: Verify the result of where
Expectation: success
"""
context.set_context(mode=mode)
x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
y = Tensor(np.ones((2, 2)), mstype.float32)
condition = x < 3
net = Net()
output = net(condition, x, y)
expected = np.array([[0, 1], [2, 1]], dtype=np.float32)
assert np.allclose(output.asnumpy(), expected)