Enable select broadcast
This commit is contained in:
parent
070ef4807f
commit
6a7b9e7274
|
@ -1600,6 +1600,47 @@ def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is
|
|||
f"then the input[x] must be a Tensor.")
|
||||
|
||||
|
||||
@constexpr
|
||||
def _check_select_shape_same(cond_shape, x_shape, y_shape):
|
||||
"""Check if input of select has same shape."""
|
||||
return cond_shape == x_shape and x_shape == y_shape and cond_shape == y_shape
|
||||
|
||||
|
||||
@constexpr
|
||||
def get_max_value(x, y, z):
|
||||
if x >= y and x >= z:
|
||||
return x
|
||||
if y >= x and y >= z:
|
||||
return y
|
||||
return z
|
||||
|
||||
|
||||
@constexpr
|
||||
def _calc_broadcast_shape(cond_shape, x_shape, y_shape):
|
||||
"""Calculate broadcast shape for select"""
|
||||
converted_shape = []
|
||||
cond_reverse = cond_shape[::-1]
|
||||
x_reverse = x_shape[::-1]
|
||||
y_reverse = y_shape[::-1]
|
||||
max_len = get_max_value(len(cond_reverse), len(x_reverse), len(y_reverse))
|
||||
i = 0
|
||||
while i < max_len:
|
||||
cond_element = 1 if i >= len(cond_reverse) else cond_reverse[i]
|
||||
x_element = 1 if i >= len(x_reverse) else x_reverse[i]
|
||||
y_element = 1 if i >= len(y_reverse) else y_reverse[i]
|
||||
broadcast_element = get_max_value(cond_element, x_element, y_element)
|
||||
if cond_element not in (1, broadcast_element):
|
||||
raise ValueError(f"For select, condition input can not broadcast at index {i}")
|
||||
if x_element not in (1, broadcast_element):
|
||||
raise ValueError(f"For select, x input can not broadcast at index {i}")
|
||||
if y_element not in (1, broadcast_element):
|
||||
raise ValueError(f"For select, y input can not broadcast at index {i}")
|
||||
converted_shape.append(broadcast_element)
|
||||
i = i + 1
|
||||
converted_shape.reverse()
|
||||
return tuple(converted_shape)
|
||||
|
||||
|
||||
def select(cond, x, y):
|
||||
r"""
|
||||
The conditional tensor determines whether the corresponding element in the output must be
|
||||
|
@ -1675,6 +1716,19 @@ def select(cond, x, y):
|
|||
input_y = cast_(input_y, mstype.int32)
|
||||
else:
|
||||
input_y = cast_(input_y, mstype.float32)
|
||||
|
||||
if is_x_tensor and is_y_tensor and is_cond_tensor:
|
||||
x_shape = F.shape(x)
|
||||
y_shape = F.shape(y)
|
||||
cond_shape = F.shape(cond)
|
||||
all_constant = F.isconstant(cond_shape) and F.isconstant(x_shape) and F.isconstant(y_shape)
|
||||
if all_constant and not _check_select_shape_same(cond_shape, x_shape, y_shape):
|
||||
broadcast_shape = _calc_broadcast_shape(cond_shape, x_shape, y_shape)
|
||||
new_cond = F.broadcast_to(cond, broadcast_shape)
|
||||
new_x = F.broadcast_to(x, broadcast_shape)
|
||||
new_y = F.broadcast_to(y, broadcast_shape)
|
||||
return tensor_select_(new_cond, new_x, new_y)
|
||||
|
||||
return tensor_select_(cond, input_x, input_y)
|
||||
|
||||
|
||||
|
|
|
@ -22,6 +22,9 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
import mindspore.ops as ops
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.api import jit
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -87,7 +90,6 @@ def test_select_int32():
|
|||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -109,6 +111,25 @@ def test_functional_select_scalar():
|
|||
assert np.all(-diff < error)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_functional_select_broadcast():
|
||||
"""
|
||||
Feature: Test functional select operator support broadcast input.
|
||||
Description: Operator select's support broadcast input.
|
||||
Expectation: Assert result.
|
||||
"""
|
||||
cond = Tensor(np.random.rand(1, 65, 54, 12, 5, 2), dtype=mstype.bool_)
|
||||
x = Tensor(np.random.rand(5, 5, 65, 1, 12, 5, 2).astype(np.float32))
|
||||
y = Tensor(np.random.rand(65, 54, 1, 5, 2).astype(np.float32))
|
||||
@jit
|
||||
def foo(a, b, c):
|
||||
return F.select(a, b, c)
|
||||
ret = foo(cond, x, y)
|
||||
assert ret.shape == (5, 5, 65, 54, 12, 5, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
|
Loading…
Reference in New Issue