!28330 functional select support x or y is a scalar

Merge pull request !28330 from wangnan39/support_functional_select
This commit is contained in:
i-robot 2022-01-05 07:13:49 +00:00 committed by Gitee
commit 221246711a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 203 additions and 1 deletions

View File

@ -20,6 +20,7 @@
from mindspore.common._register_for_tensor import tensor_operator_registry
from mindspore.common import ms_function
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore.nn.grad.cell_grad import _JvpInner
@ -45,7 +46,6 @@ isinstance_ = P.IsInstance()
eye = P.Eye()
fill = P.Fill()
tile = P.Tile()
select = P.Select()
size = P.Size()
ones_like = P.OnesLike()
shape = P.Shape()
@ -153,6 +153,7 @@ csr_mul = _csr_ops.CSRMul()
csr_mv = _csr_ops.CSRMV()
csr_reduce_sum = _csr_ops.CSRReduceSum()
_select = P.Select()
def pack(x):
"""Call stack in this pack function."""
@ -397,6 +398,117 @@ def _raise_type_error():
raise TypeError("The inputs type should be a Tensor, tuple or list of Tensor.")
@constexpr
def _check_select_type(scalar, tensor_type, scalar_name, tensor_name):
if isinstance(scalar, int) and tensor_type != mstype.int32:
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is int, "
f"then the input[{tensor_name}] must be a Tensor of int32.")
if isinstance(scalar, float) and tensor_type != mstype.float32:
raise TypeError(f"For functional operator[select], the input[{scalar_name}] is float, "
f"then the input[{tensor_name}] must be a Tensor of float32.")
@constexpr
def _check_select_input(is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor):
if is_x_scalar and not is_y_tensor:
raise TypeError(f"For functional operator[select], the input[x] is int or float, "
f"then the input[y] must be a Tensor.")
if is_y_scalar and not is_x_tensor:
raise TypeError(f"For functional operator[select], the input[y] is int or float, "
f"then the input[x] must be a Tensor.")
def select(cond, x, y):
r"""
Returns the selected elements, either from input :math:`x` or input :math:`y`, depending on the condition `cond`.
Given a tensor as input, this operation inserts a dimension of 1 at the dimension,
it was invalid when both math: 'x' and math: 'y' are none.
Keep in mind that the shape of the output tensor can vary depending
on how many true values are in the input. Indexes are output in row-first
order.
The conditional tensor acts as an optional compensation (mask), which
determines whether the corresponding element / row in the output must be
selected from :math:`x` (if true) or :math:`y` (if false) based on the value of each
element.
It can be defined as:
.. math::
out_i = \begin{cases}
x_i, & \text{if } condition_i \\
y_i, & \text{otherwise}
\end{cases}
If condition is a vector, then :math:`x` and :math:`y` are higher-dimensional matrices, then it
chooses to copy that row (external dimensions) from :math:`x` and :math:`y`. If condition has
the same shape as :math:`x` and :math:`y`, you can choose to copy these elements from :math:`x`
and :math:`y`.
Inputs:
- **cond** (Tensor[bool]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
The condition tensor, decides which element is chosen.
- **x** (Union[Tensor, int, float]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
The first input tensor. If x is int or float, it will be cast to the type of int32 or float32, and broadcast
to the same shape as y. One of x and y must be a Tensor.
- **y** (Union[Tensor, int, float]) - The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
The second input tensor. If y is int or float, it will be cast to the type of int32 or float32, and broadcast
to the same shape as x. One of x and y must be a Tensor.
Outputs:
Tensor, has the same shape as `cond`. The shape is :math:`(x_1, x_2, ..., x_N, ..., x_R)`.
Raises:
TypeError: If `x` or `y` is not a Tensor, int or float.
ValueError: The shapes of inputs not equal.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # 1) Both inputs are Tensor
>>> import mindspore
>>> from mindspore import Tensor, ops
>>>
>>> cond = Tensor([True, False])
>>> x = Tensor([2,3], mindspore.float32)
>>> y = Tensor([1,2], mindspore.float32)
>>> output = ops.select(cond, x, y)
>>> print(output)
[2. 2.]
>>> # 2) y is a float
>>> cond = Tensor([True, False])
>>> x = Tensor([2,3], mindspore.float32)
>>> y = 2.0
>>> output = ops.select(cond, x, y)
>>> print(output)
[2. 2.]
"""
is_x_scalar = isinstance(x, (int, float))
is_y_scalar = isinstance(y, (int, float))
is_x_tensor = isinstance(x, Tensor)
is_y_tensor = isinstance(y, Tensor)
_check_select_input(is_x_scalar, is_y_scalar, is_x_tensor, is_y_tensor)
input_x = x
input_y = y
if is_x_scalar:
_check_select_type(x, y.dtype, "x", "y")
input_x = zeros_like(y) + x
if isinstance(x, int):
input_x = cast(input_x, mstype.int32)
else:
input_x = cast(input_x, mstype.float32)
if is_y_scalar:
_check_select_type(y, x.dtype, "y", "x")
input_y = zeros_like(x) + y
if isinstance(y, int):
input_y = cast(input_y, mstype.int32)
else:
input_y = cast(input_y, mstype.float32)
return _select(cond, input_x, input_y)
tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive(_constants.kTupleGetItem)
list_getitem = Primitive('list_getitem')

View File

@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops as ops
from mindspore.ops import operations as P
from mindspore.train.model import Model
@ -63,3 +65,24 @@ def test_select_2_2():
inputa = np.random.randn(2, 2).astype(np.float32)
inputb = np.random.randn(2, 2).astype(np.float32)
cmp_select(input_cond, inputa, inputb)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_functional_select_scalar():
"""
Feature: Test functional select operator. Support x or y is a int/float.
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
Expectation: Assert result.
"""
cond = np.array([[True, False], [True, False]]).astype(np.bool)
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
y = 2
output = ops.select(Tensor(cond), Tensor(x), y)
expect = [[12, 2], [1, 2]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)

View File

@ -16,9 +16,11 @@
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops as ops
from mindspore.ops import operations as P
@ -83,3 +85,45 @@ def test_select_int32():
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_functional_select_scalar():
"""
Feature: Test functional select operator. Support x or y is a int/float.
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
Expectation: Assert result.
"""
cond = np.array([[True, False], [True, False]]).astype(np.bool)
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
y = 2
output = ops.select(Tensor(cond), Tensor(x), y)
print(output.asnumpy())
expect = [[12, 2], [1, 2]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_functional_select_type_error():
"""
Feature: Functional select support scalar.
Description: If y is a int, x must be a Tensor with int32 type. If y is a float, x must be a Tensor with float32.
Expectation: TypeError.
"""
input_cond = Tensor([True, True])
input_x_int = Tensor([2, 3], mindspore.int32)
input_x_float = Tensor([2, 3], mindspore.float32)
with pytest.raises(TypeError):
ops.select(input_cond, input_x_int, 2.0)
with pytest.raises(TypeError):
ops.select(input_cond, input_x_float, 2)

View File

@ -19,8 +19,10 @@ import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops as ops
from mindspore.ops import operations as P
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -59,3 +61,24 @@ def test_select():
output = select(Tensor(cond), Tensor(x), Tensor(y))
expect = np.array([[1, 0], [1, 1]]).astype(np.bool)
assert np.all(output.asnumpy() == expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_functional_select_scalar():
"""
Feature: Test functional select operator. Support x or y is a int/float.
Description: Operator select's input `x` is a Tensor with int32 type, input `y` is a int.
Expectation: Assert result.
"""
context.set_context(device_target="GPU")
cond = np.array([[True, False], [True, False]]).astype(np.bool)
x = np.array([[12, 1], [1, 0]]).astype(np.int32)
y = 2
output = ops.select(Tensor(cond), Tensor(x), y)
expect = [[12, 2], [1, 2]]
error = np.ones(shape=[2, 2]) * 1.0e-6
diff = output.asnumpy() - expect
assert np.all(diff < error)
assert np.all(-diff < error)