forked from mindspore-Ecosystem/mindspore
!28330 functional select support x or y is a scalar
Merge pull request !28330 from wangnan39/support_functional_select
This commit is contained in:
commit
221246711a
|
@ -20,6 +20,7 @@
|
|||
from mindspore.common._register_for_tensor import tensor_operator_registry
|
||||
from mindspore.common import ms_function
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore._checkparam import Validator as validator
|
||||
from mindspore._checkparam import Rel
|
||||
from mindspore.nn.grad.cell_grad import _JvpInner
|
||||
|
@ -45,7 +46,6 @@ isinstance_ = P.IsInstance()
|
|||
eye = P.Eye()
|
||||
fill = P.Fill()
|
||||
tile = P.Tile()
|
||||
select = P.Select()
|
||||
size = P.Size()
|
||||
ones_like = P.OnesLike()
|
||||
shape = P.Shape()
|
||||
|
@ -153,6 +153,7 @@ csr_mul = _csr_ops.CSRMul()
|
|||
csr_mv = _csr_ops.CSRMV()
|
||||
csr_reduce_sum = _csr_ops.CSRReduceSum()
|
||||
|
||||
_select = P.Select()
|
||||
|
||||
def pack(x):
|
||||
"""Call stack in this pack function."""
|
||||
|
@ -397,6 +398,117 @@ def _raise_type_error():
|
|||
raise TypeError("The inputs type should be a Tensor, tuple or list of Tensor.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_select_type(scalar, tensor_type, scalar_name, tensor_name):
|
||||
if isinstance(scalar, int) and tensor_type != mstype.int32:
|
||||
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is int, "
|
||||
f"then the input[{tensor_name}] must be a Tensor of int32.")
|
||||
if isinstance(scalar, float) and tensor_type != mstype.float32:
|
||||
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is float, "
|
||||
f"then the input[{tensor_name}] must be a Tensor of float32.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_select_input(is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
|
||||
if is_x_scalar and not is_y_tensor:
|
||||
raise TypeError(f"For functional operator[select], the input[x] is int or float, "
|
||||
f"then the input[y] must be a Tensor.")
|
||||
if is_y_scalar and not is_x_tensor:
|
||||
raise TypeError(f"For functional operator[select], the input[y] is int or float, "
|
||||
f"then the input[x] must be a Tensor.")
|
||||
|
||||
def select(cond, x, y):
|
||||
r"""
|
||||
Returns the selected elements, either from input :math:`x` or input :math:`y`, depending on the condition `cond`.
|
||||
|
||||
Given a tensor as input, this operation inserts a dimension of 1 at the dimension,
|
||||
it was invalid when both math: 'x' and math: 'y' are none.
|
||||
Keep in mind that the shape of the output tensor can vary depending
|
||||
on how many true values are in the input. Indexes are output in row-first
|
||||
order.
|
||||
|
||||
The conditional tensor acts as an optional compensation (mask), which
|
||||
determines whether the corresponding element / row in the output must be
|
||||
selected from :math:`x` (if true) or :math:`y` (if false) based on the value of each
|
||||
element.
|
||||
|
||||
It can be defined as:
|
||||
|
||||
.. math::
|
||||
out_i = \begin{cases}
|
||||
x_i, & \text{if } condition_i \\
|
||||
y_i, & \text{otherwise}
|
||||
\end{cases}
|
||||
|
||||
If condition is a vector, then :math:`x` and :math:`y` are higher-dimensional matrices, then it
|
||||
chooses to copy that row (external dimensions) from :math:`x` and :math:`y`. If condition has
|
||||
the same shape as :math:`x` and :math:`y`, you can choose to copy these elements from :math:`x`
|
||||
and :math:`y`.
|
||||
|
||||
Inputs:
|
||||
- **cond** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
The condition tensor, decides which element is chosen.
|
||||
- **x** (Union[Tensor, int, float]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
The first input tensor. If x is int or float, it will be cast to the type of int32 or float32, and broadcast
|
||||
to the same shape as y. One of x and y must be a Tensor.
|
||||
- **y** (Union[Tensor, int, float]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
The second input tensor. If y is int or float, it will be cast to the type of int32 or float32, and broadcast
|
||||
to the same shape as x. One of x and y must be a Tensor.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape as `cond`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
|
||||
|
||||
Raises:
|
||||
TypeError: If `x` or `y` is not a Tensor, int or float.
|
||||
ValueError: The shapes of inputs not equal.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> # 1) Both inputs are Tensor
|
||||
>>> import mindspore
|
||||
>>> from mindspore import Tensor, ops
|
||||
>>>
|
||||
>>> cond = Tensor([True, False])
|
||||
>>> x = Tensor([2,3], mindspore.float32)
|
||||
>>> y = Tensor([1,2], mindspore.float32)
|
||||
>>> output = ops.select(cond, x, y)
|
||||
>>> print(output)
|
||||
[2. 2.]
|
||||
>>> # 2) y is a float
|
||||
>>> cond = Tensor([True, False])
|
||||
>>> x = Tensor([2,3], mindspore.float32)
|
||||
>>> y = 2.0
|
||||
>>> output = ops.select(cond, x, y)
|
||||
>>> print(output)
|
||||
[2. 2.]
|
||||
"""
|
||||
is_x_scalar = isinstance(x, (int, float))
|
||||
is_y_scalar = isinstance(y, (int, float))
|
||||
is_x_tensor = isinstance(x, Tensor)
|
||||
is_y_tensor = isinstance(y, Tensor)
|
||||
_check_select_input(is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor)
|
||||
input_x = x
|
||||
input_y = y
|
||||
if is_x_scalar:
|
||||
_check_select_type(x, y.dtype, "x", "y")
|
||||
input_x = zeros_like(y) + x
|
||||
if isinstance(x, int):
|
||||
input_x = cast(input_x, mstype.int32)
|
||||
else:
|
||||
input_x = cast(input_x, mstype.float32)
|
||||
|
||||
if is_y_scalar:
|
||||
_check_select_type(y, x.dtype, "y", "x")
|
||||
input_y = zeros_like(x) + y
|
||||
if isinstance(y, int):
|
||||
input_y = cast(input_y, mstype.int32)
|
||||
else:
|
||||
input_y = cast(input_y, mstype.float32)
|
||||
return _select(cond, input_x, input_y)
|
||||
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive(_constants.kTupleGetItem)
|
||||
list_getitem = Primitive('list_getitem')
|
||||
|
|
|
@ -12,12 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.train.model import Model
|
||||
|
||||
|
@ -63,3 +65,24 @@ def test_select_2_2():
|
|||
inputa = np.random.randn(2, 2).astype(np.float32)
|
||||
inputb = np.random.randn(2, 2).astype(np.float32)
|
||||
cmp_select(input_cond, inputa, inputb)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_functional_select_scalar():
|
||||
"""
|
||||
Feature: Test functional select operator. Support x or y is a int/float.
|
||||
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
cond = np.array([[True, False], [True, False]]).astype(np.bool)
|
||||
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
|
||||
y = 2
|
||||
output = ops.select(Tensor(cond), Tensor(x), y)
|
||||
expect = [[12, 2], [1, 2]]
|
||||
error = np.ones(shape=[2, 2]) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
|
|
@ -16,9 +16,11 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
|
@ -83,3 +85,45 @@ def test_select_int32():
|
|||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_functional_select_scalar():
|
||||
"""
|
||||
Feature: Test functional select operator. Support x or y is a int/float.
|
||||
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
cond = np.array([[True, False], [True, False]]).astype(np.bool)
|
||||
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
|
||||
y = 2
|
||||
output = ops.select(Tensor(cond), Tensor(x), y)
|
||||
print(output.asnumpy())
|
||||
expect = [[12, 2], [1, 2]]
|
||||
error = np.ones(shape=[2, 2]) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_functional_select_type_error():
|
||||
"""
|
||||
Feature: Functional select support scalar.
|
||||
Description: If y is a int, x must be a Tensor with int32 type. If y is a float, x must be a Tensor with float32.
|
||||
Expectation: TypeError.
|
||||
"""
|
||||
input_cond = Tensor([True, True])
|
||||
input_x_int = Tensor([2, 3], mindspore.int32)
|
||||
input_x_float = Tensor([2, 3], mindspore.float32)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
ops.select(input_cond, input_x_int, 2.0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
ops.select(input_cond, input_x_float, 2)
|
||||
|
|
|
@ -19,8 +19,10 @@ import pytest
|
|||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -59,3 +61,24 @@ def test_select():
|
|||
output = select(Tensor(cond), Tensor(x), Tensor(y))
|
||||
expect = np.array([[1, 0], [1, 1]]).astype(np.bool)
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_functional_select_scalar():
|
||||
"""
|
||||
Feature: Test functional select operator. Support x or y is a int/float.
|
||||
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
context.set_context(device_target="GPU")
|
||||
cond = np.array([[True, False], [True, False]]).astype(np.bool)
|
||||
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
|
||||
y = 2
|
||||
output = ops.select(Tensor(cond), Tensor(x), y)
|
||||
expect = [[12, 2], [1, 2]]
|
||||
error = np.ones(shape=[2, 2]) * 1.0e-6
|
||||
diff = output.asnumpy() - expect
|
||||
assert np.all(diff < error)
|
||||
assert np.all(-diff < error)
|
||||
|
|
Loading…
Reference in New Issue