add tensor.where
This commit is contained in:
parent
62d24d2b2a
commit
bc783dd49d
|
@ -442,6 +442,7 @@ Array操作
|
|||
mindspore.ops.unsorted_segment_sum
|
||||
mindspore.ops.unsqueeze
|
||||
mindspore.ops.unstack
|
||||
mindspore.ops.where
|
||||
|
||||
类型转换
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
mindspore.Tensor.where
|
||||
=======================
|
||||
|
||||
.. py:method:: mindspore.Tensor.where(condition, y)
|
||||
|
||||
详情请参考 :func:`mindspore.ops.where`。
|
|
@ -250,5 +250,6 @@ mindspore.Tensor
|
|||
mindspore.Tensor.unsqueeze
|
||||
mindspore.Tensor.var
|
||||
mindspore.Tensor.view
|
||||
mindspore.Tensor.where
|
||||
mindspore.Tensor.xdivy
|
||||
mindspore.Tensor.xlogy
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
mindspore.ops.where
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.where(condition, x, y)
|
||||
|
||||
返回一个Tensor,Tensor的元素从 `x` 或 `y` 中根据 `condition` 选择。
|
||||
|
||||
.. math::
|
||||
|
||||
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
|
||||
|
||||
参数:
|
||||
- **condition** (Union[Bool Tensor, bool, scalar]) - 如果是True,选取 `x` 中的元素,否则选取 `y` 中的元素。
|
||||
- **x** (Union[Tensor, Scalar]) - 在 `condition` 为True的索引处选择的值。
|
||||
- **y** (Union[Tensor, Scalar]) - 当 `condition` 为False的索引处选择的值。
|
||||
|
||||
返回:
|
||||
Tensor,其中的元素从 `x` 和 `y` 中选取。
|
||||
|
||||
异常:
|
||||
-**ValueError** - `condition` 不可以被广播成 `x` 的shape。
|
|
@ -256,6 +256,7 @@
|
|||
mindspore.Tensor.unsqueeze
|
||||
mindspore.Tensor.var
|
||||
mindspore.Tensor.view
|
||||
mindspore.Tensor.where
|
||||
mindspore.Tensor.xdivy
|
||||
mindspore.Tensor.xlogy
|
||||
|
||||
|
|
|
@ -442,6 +442,7 @@ Array Operation
|
|||
mindspore.ops.unsorted_segment_sum
|
||||
mindspore.ops.unsqueeze
|
||||
mindspore.ops.unstack
|
||||
mindspore.ops.where
|
||||
|
||||
Type Conversion
|
||||
^^^^^^^^^^^^^^^
|
||||
|
|
|
@ -413,6 +413,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"sinh", std::string("sinh")}, // sinh()
|
||||
{"sort", std::string("sort")}, // sort()
|
||||
{"trunc", std::string("trunc")}, // trunc()
|
||||
{"where", std::string("where")}, // where()
|
||||
{"imag", std::string("imag")}, // imag()
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
|
|
|
@ -1001,7 +1001,8 @@ def copy(x):
|
|||
return x
|
||||
|
||||
|
||||
def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin
|
||||
def max(x, axis=None, keepdims=False, initial=None, # pylint: disable=redefined-builtin
|
||||
where=True): # pylint: disable=redefined-outer-name
|
||||
"""
|
||||
Returns the maximum of a tensor or maximum along an axis.
|
||||
|
||||
|
@ -1046,7 +1047,8 @@ def max(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disa
|
|||
axis=axis, keepdims=keepdims, initial=initial, where=where)
|
||||
|
||||
|
||||
def min(x, axis=None, keepdims=False, initial=None, where=True): # pylint: disable=redefined-builtin
|
||||
def min(x, axis=None, keepdims=False, initial=None, # pylint: disable=redefined-builtin
|
||||
where=True): # pylint: disable=redefined-outer-name
|
||||
"""
|
||||
Returns the minimum of a tensor or minimum along an axis.
|
||||
|
||||
|
@ -3977,6 +3979,14 @@ def trunc(input):
|
|||
return F.trunc(input)
|
||||
|
||||
|
||||
def where(x, condition, y):
|
||||
r"""
|
||||
Returns a tensor whose elements are selected from either `x` or `y` depending on `condition`.
|
||||
Please refer to :func:`mindspore.ops.where`.
|
||||
"""
|
||||
return F.where(condition, x, y)
|
||||
|
||||
|
||||
def imag(input):
|
||||
r"""
|
||||
Returns a new tensor containing imaginary value of the input.
|
||||
|
|
|
@ -4488,6 +4488,13 @@ class Tensor(Tensor_):
|
|||
self._init_check()
|
||||
return tensor_operator_registry.get('trunc')(self)
|
||||
|
||||
def where(self, condition, y):
|
||||
r"""
|
||||
For details, please refer to :func:`mindspore.ops.where`
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('where')(condition, self, y)
|
||||
|
||||
def imag(self):
|
||||
r"""
|
||||
Returns a new tensor containing imaginary value of the input tensor.
|
||||
|
|
|
@ -711,8 +711,7 @@ def where(condition, x=None, y=None):
|
|||
x = F.cast(x, dtype)
|
||||
if not _check_same_type(dtype2, dtype):
|
||||
y = F.cast(y, dtype)
|
||||
is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type(
|
||||
dtype2, mstype.bool_)
|
||||
is_bool = _check_same_type(dtype1, mstype.bool_) and _check_same_type(dtype2, mstype.bool_)
|
||||
if is_bool:
|
||||
# select does not support bool type for x or y
|
||||
x = F.cast(x, mstype.float32)
|
||||
|
|
|
@ -105,6 +105,7 @@ from .array_func import (
|
|||
matrix_set_diag,
|
||||
diag,
|
||||
masked_select,
|
||||
where,
|
||||
meshgrid,
|
||||
affine_grid,
|
||||
fills,
|
||||
|
|
|
@ -268,6 +268,40 @@ def eye(n, m, t):
|
|||
return eye_(n, m, t)
|
||||
|
||||
|
||||
def where(condition, x, y):
|
||||
r"""
|
||||
Returns a tensor whose elements are selected from either `x` or `y` depending on `condition`.
|
||||
|
||||
..math::
|
||||
output_i = \begin{cases} x_i,\quad &if\ condition_i \\ y_i,\quad &otherwise \end{cases}
|
||||
|
||||
Args:
|
||||
condition (Bool Tensor, bool, scalar): If True, yield `x` otherwise yield `y`.
|
||||
x (Union[Tensor, Scalar]): Value (if `x` is a scalar) or values selected at indices where condition is True.
|
||||
y (Union[Tensor, Scalar]): Value (if `y` is a scalar) or values selected at indices where condition is False.
|
||||
|
||||
Returns:
|
||||
Tensor, elements are selected from `x` and `y`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `condition` can not be broadcast to the shape of `x`.
|
||||
|
||||
Examples:
|
||||
>>> a = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
|
||||
>>> b = Tensor(np.ones((2, 2)), mstype.float16)
|
||||
>>> condition = a < 3
|
||||
>>> output = ops.where(condition, a, b)
|
||||
>>> print(output)
|
||||
[[0. 1.]
|
||||
[2. 1.]]
|
||||
"""
|
||||
not_op = P.LogicalNot()
|
||||
x_mask = cast_(condition, mstype.bool_)
|
||||
y_mask = not_op(x_mask)
|
||||
output = x * x_mask + y * y_mask
|
||||
return output
|
||||
|
||||
|
||||
def reverse(x, axis):
|
||||
"""
|
||||
Reverses specific dimensions of a tensor.
|
||||
|
@ -5226,6 +5260,7 @@ __all__ = [
|
|||
'one_hot',
|
||||
'masked_fill',
|
||||
'masked_select',
|
||||
'where',
|
||||
'narrow',
|
||||
'scatter_add',
|
||||
'scatter_mul',
|
||||
|
|
|
@ -712,7 +712,8 @@ def subtract(x, other, *, alpha=1):
|
|||
output[i] = x[i] - alpha * y[i]
|
||||
|
||||
Args:
|
||||
other (Union[Tensor, number.Number]): The tensor or number to be subtracted.
|
||||
x (Union[Tensor, number.Number]): The tensor or number to be subtracted.
|
||||
other (Union[Tensor, number.Number]): The tensor or number to subtract.
|
||||
|
||||
Keyword Args:
|
||||
alpha (Number): The multiplier for `other`. Default: 1.
|
||||
|
|
|
@ -356,6 +356,7 @@ tensor_operator_registry.register('ne', ne)
|
|||
tensor_operator_registry.register('sinh', sinh)
|
||||
tensor_operator_registry.register('sort', P.Sort)
|
||||
tensor_operator_registry.register('trunc', trunc)
|
||||
tensor_operator_registry.register('where', where)
|
||||
tensor_operator_registry.register('imag', imag)
|
||||
tensor_operator_registry.register('repeat_interleave', repeat_interleave)
|
||||
tensor_operator_registry.register('rad2deg', rad2deg)
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, condition, x, y):
|
||||
return ops.where(condition, x, y)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_ops_where(mode):
|
||||
"""
|
||||
Feature: ops.where
|
||||
Description: Verify the result of ops.where
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
net = Net()
|
||||
x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
|
||||
y = Tensor(np.ones((2, 2)), mstype.float32)
|
||||
condition = x < 3
|
||||
output = net(condition, x, y)
|
||||
expected = np.array([[0, 1], [2, 1]], dtype=np.float32)
|
||||
assert np.allclose(output.asnumpy(), expected)
|
|
@ -0,0 +1,47 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self, condition, x, y):
|
||||
return x.where(condition, y)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_tensor_where(mode):
|
||||
"""
|
||||
Feature: tensor.where
|
||||
Description: Verify the result of where
|
||||
Expectation: success
|
||||
"""
|
||||
context.set_context(mode=mode)
|
||||
x = Tensor(np.arange(4).reshape((2, 2)), mstype.float32)
|
||||
y = Tensor(np.ones((2, 2)), mstype.float32)
|
||||
condition = x < 3
|
||||
net = Net()
|
||||
output = net(condition, x, y)
|
||||
expected = np.array([[0, 1], [2, 1]], dtype=np.float32)
|
||||
assert np.allclose(output.asnumpy(), expected)
|
Loading…
Reference in New Issue