diff --git a/mindspore/python/mindspore/ops/functional.py b/mindspore/python/mindspore/ops/functional.py
index 3c1f233de11..e9b73418399 100644
--- a/mindspore/python/mindspore/ops/functional.py
+++ b/mindspore/python/mindspore/ops/functional.py
@@ -20,6 +20,7 @@
 from mindspore.common._register_for_tensor import tensor_operator_registry
 from mindspore.common import ms_function
 from mindspore.common import Tensor
+from mindspore.common import dtype as mstype
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.nn.grad.cell_grad import _JvpInner
@@ -45,7 +46,6 @@ isinstance_ = P.IsInstance()
 eye = P.Eye()
 fill = P.Fill()
 tile = P.Tile()
-select = P.Select()
 size = P.Size()
 ones_like = P.OnesLike()
 shape = P.Shape()
@@ -153,6 +153,7 @@ csr_mul = _csr_ops.CSRMul()
 csr_mv = _csr_ops.CSRMV()
 csr_reduce_sum = _csr_ops.CSRReduceSum()
+_select = P.Select()
 
 
 def pack(x):
     """Call stack in this pack function."""
@@ -397,6 +398,117 @@ def _raise_type_error():
     raise TypeError("The inputs type should be a Tensor, tuple or list of Tensor.")
 
 
+@constexpr
+def _check_select_type(scalar, tensor_type, scalar_name, tensor_name):
+    if isinstance(scalar, int) and tensor_type != mstype.int32:
+        raise TypeError(f"For functional operator[select], the input[{scalar_name}] is int, "
+                        f"then the input[{tensor_name}] must be a Tensor of int32.")
+    if isinstance(scalar, float) and tensor_type != mstype.float32:
+        raise TypeError(f"For functional operator[select], the input[{scalar_name}] is float, "
+                        f"then the input[{tensor_name}] must be a Tensor of float32.")
+
+
+@constexpr
+def _check_select_input(is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
+    if is_x_scalar and not is_y_tensor:
+        raise TypeError(f"For functional operator[select], the input[x] is int or float, "
+                        f"then the input[y] must be a Tensor.")
+    if is_y_scalar and not is_x_tensor:
+        raise TypeError(f"For functional operator[select], the input[y] is int or float, "
+                        f"then the input[x] must be a Tensor.")
+
+
+def select(cond, x, y):
+    r"""
+    Returns the selected elements, either from input :math:`x` or input :math:`y`, depending on the condition `cond`.
+
+    The `cond` tensor acts as a mask that determines whether the corresponding
+    element of the output is taken from :math:`x` (if True) or from :math:`y`
+    (if False). `cond`, :math:`x` and :math:`y` must have the same shape, and
+    at least one of :math:`x` and :math:`y` must be a Tensor.
+
+    It can be defined as:
+
+    .. math::
+        out_i = \begin{cases}
+        x_i, & \text{if } cond_i \\
+        y_i, & \text{otherwise}
+        \end{cases}
+
+    Inputs:
+        - **cond** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
+          The condition tensor, which decides which element is chosen.
+        - **x** (Union[Tensor, int, float]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
+          The first input tensor.
+          If `x` is an int or a float, it will be cast to int32 or float32
+          respectively, and broadcast to the same shape as `y`. One of `x` and `y` must be a Tensor.
+        - **y** (Union[Tensor, int, float]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
+          The second input tensor. If `y` is an int or a float, it will be cast to int32 or float32
+          respectively, and broadcast to the same shape as `x`. One of `x` and `y` must be a Tensor.
+
+    Outputs:
+        Tensor, has the same shape as `cond`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
+
+    Raises:
+        TypeError: If `x` or `y` is not a Tensor, an int or a float.
+        ValueError: If the shapes of `cond`, `x` and `y` are not the same.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> # 1) Both inputs are Tensors
+        >>> import mindspore
+        >>> from mindspore import Tensor, ops
+        >>>
+        >>> cond = Tensor([True, False])
+        >>> x = Tensor([2, 3], mindspore.float32)
+        >>> y = Tensor([1, 2], mindspore.float32)
+        >>> output = ops.select(cond, x, y)
+        >>> print(output)
+        [2. 2.]
+        >>> # 2) y is a float
+        >>> cond = Tensor([True, False])
+        >>> x = Tensor([2, 3], mindspore.float32)
+        >>> y = 2.0
+        >>> output = ops.select(cond, x, y)
+        >>> print(output)
+        [2. 2.]
+    """
+    is_x_scalar = isinstance(x, (int, float))
+    is_y_scalar = isinstance(y, (int, float))
+    is_x_tensor = isinstance(x, Tensor)
+    is_y_tensor = isinstance(y, Tensor)
+    _check_select_input(is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor)
+    input_x = x
+    input_y = y
+    if is_x_scalar:
+        _check_select_type(x, y.dtype, "x", "y")
+        # Broadcast the scalar x to the shape of y, then fix the dtype expected by Select.
+        input_x = zeros_like(y) + x
+        if isinstance(x, int):
+            input_x = cast(input_x, mstype.int32)
+        else:
+            input_x = cast(input_x, mstype.float32)
+
+    if is_y_scalar:
+        _check_select_type(y, x.dtype, "y", "x")
+        # Broadcast the scalar y to the shape of x, then fix the dtype expected by Select.
+        input_y = zeros_like(x) + y
+        if isinstance(y, int):
+            input_y = cast(input_y, mstype.int32)
+        else:
+            input_y = cast(input_y, mstype.float32)
+    return _select(cond, input_x, input_y)
+
+
 tuple_setitem = Primitive('tuple_setitem')
 tuple_getitem = Primitive(_constants.kTupleGetItem)
 list_getitem = Primitive('list_getitem')
diff --git a/tests/st/ops/ascend/test_tbe_ops/test_select.py b/tests/st/ops/ascend/test_tbe_ops/test_select.py
index 56fd8ba0a1a..3cae87cefde 100644
--- a/tests/st/ops/ascend/test_tbe_ops/test_select.py
+++ b/tests/st/ops/ascend/test_tbe_ops/test_select.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+import pytest
 import numpy as np
 
 import mindspore as ms
 import mindspore.context as context
 from mindspore import Tensor
 from mindspore.nn import Cell
+import mindspore.ops as ops
 from mindspore.ops import operations as P
 from mindspore.train.model import Model
 
@@ -63,3 +65,24 @@ def test_select_2_2():
     inputa = np.random.randn(2, 2).astype(np.float32)
     inputb = np.random.randn(2, 2).astype(np.float32)
     cmp_select(input_cond, inputa, inputb)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_functional_select_scalar():
+    """
+    Feature: Test the functional select operator. Support x or y being an int/float.
+    Description: Operator select's input `x` is a Tensor of int32 type, input `y` is an int.
+    Expectation: The output equals the expected result.
+    """
+    cond = np.array([[True, False], [True, False]]).astype(np.bool)
+    x = np.array([[12, 1], [1, 0]]).astype(np.int32)
+    y = 2
+    output = ops.select(Tensor(cond), Tensor(x), y)
+    expect = [[12, 2], [1, 2]]
+    error = np.ones(shape=[2, 2]) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
diff --git a/tests/st/ops/cpu/test_select_op.py b/tests/st/ops/cpu/test_select_op.py
index f8bc373150d..a02f2a6d4e7 100644
--- a/tests/st/ops/cpu/test_select_op.py
+++ b/tests/st/ops/cpu/test_select_op.py
@@ -16,9 +16,11 @@
 import numpy as np
 import pytest
 
+import mindspore
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+import mindspore.ops as ops
 from mindspore.ops import operations as P
 
 
@@ -83,3 +85,45 @@ def test_select_int32():
     diff = output.asnumpy() - expect
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_functional_select_scalar():
+    """
+    Feature: Test the functional select operator. Support x or y being an int/float.
+    Description: Operator select's input `x` is a Tensor of int32 type, input `y` is an int.
+    Expectation: The output equals the expected result.
+    """
+    cond = np.array([[True, False], [True, False]]).astype(np.bool)
+    x = np.array([[12, 1], [1, 0]]).astype(np.int32)
+    y = 2
+    output = ops.select(Tensor(cond), Tensor(x), y)
+    print(output.asnumpy())
+    expect = [[12, 2], [1, 2]]
+    error = np.ones(shape=[2, 2]) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_functional_select_type_error():
+    """
+    Feature: Functional select supports scalar inputs.
+    Description: If y is an int, x must be a Tensor of int32 type. If y is a float, x must be a Tensor of float32.
+    Expectation: TypeError is raised.
+    """
+    input_cond = Tensor([True, True])
+    input_x_int = Tensor([2, 3], mindspore.int32)
+    input_x_float = Tensor([2, 3], mindspore.float32)
+
+    with pytest.raises(TypeError):
+        ops.select(input_cond, input_x_int, 2.0)
+
+    with pytest.raises(TypeError):
+        ops.select(input_cond, input_x_float, 2)
diff --git a/tests/st/ops/gpu/test_select_op.py b/tests/st/ops/gpu/test_select_op.py
index 0bc35116b45..614b83e8bc5 100644
--- a/tests/st/ops/gpu/test_select_op.py
+++ b/tests/st/ops/gpu/test_select_op.py
@@ -19,8 +19,10 @@ import pytest
 import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
+import mindspore.ops as ops
 from mindspore.ops import operations as P
 
+
 class Net(nn.Cell):
     def __init__(self):
         super(Net, self).__init__()
@@ -59,3 +61,24 @@ def test_select():
     output = select(Tensor(cond), Tensor(x), Tensor(y))
     expect = np.array([[1, 0], [1, 1]]).astype(np.bool)
     assert np.all(output.asnumpy() == expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_functional_select_scalar():
+    """
+    Feature: Test the functional select operator. Support x or y being an int/float.
+    Description: Operator select's input `x` is a Tensor of int32 type, input `y` is an int.
+    Expectation: The output equals the expected result.
+    """
+    context.set_context(device_target="GPU")
+    cond = np.array([[True, False], [True, False]]).astype(np.bool)
+    x = np.array([[12, 1], [1, 0]]).astype(np.int32)
+    y = 2
+    output = ops.select(Tensor(cond), Tensor(x), y)
+    expect = [[12, 2], [1, 2]]
+    error = np.ones(shape=[2, 2]) * 1.0e-6
+    diff = output.asnumpy() - expect
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
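
For reference, a minimal usage sketch of the behaviour this patch introduces (not part of the diff). It assumes the patch has been applied and MindSpore is importable; the tensor values and the results quoted in the comments are illustrative only. The names mirror the patch's own tests, and the error case exercises the check added in _check_select_type.

    import numpy as np

    import mindspore
    import mindspore.ops as ops
    from mindspore import Tensor

    cond = Tensor(np.array([True, False, True]))

    # An int scalar is broadcast to the shape of the other input and cast to int32,
    # so the other input must be an int32 Tensor.
    x_int32 = Tensor(np.array([1, 2, 3]), mindspore.int32)
    print(ops.select(cond, x_int32, 7))       # expected: [1 7 3]

    # A float scalar pairs with a float32 Tensor in the same way.
    y_float32 = Tensor(np.array([0.5, 1.5, 2.5]), mindspore.float32)
    print(ops.select(cond, -1.0, y_float32))  # expected: [-1.   1.5 -1. ]

    # Mixing an int scalar with a float32 Tensor raises the TypeError
    # enforced by _check_select_type.
    try:
        ops.select(cond, y_float32, 7)
    except TypeError as err:
        print(err)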