Enable select broadcast

This commit is contained in:
liangzhibo 2023-03-06 19:28:42 +08:00
parent 070ef4807f
commit 6a7b9e7274
2 changed files with 76 additions and 1 deletions

View File

@ -1600,6 +1600,47 @@ def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is
f"then the input[x] must be a Tensor.")
@constexpr
def _check_select_shape_same(cond_shape, x_shape, y_shape):
"""Check if input of select has same shape."""
return cond_shape == x_shape and x_shape == y_shape and cond_shape == y_shape
@constexpr
def get_max_value(x, y, z):
if x >= y and x >= z:
return x
if y >= x and y >= z:
return y
return z
@constexpr
def _calc_broadcast_shape(cond_shape, x_shape, y_shape):
"""Calculate broadcast shape for select"""
converted_shape = []
cond_reverse = cond_shape[::-1]
x_reverse = x_shape[::-1]
y_reverse = y_shape[::-1]
max_len = get_max_value(len(cond_reverse), len(x_reverse), len(y_reverse))
i = 0
while i < max_len:
cond_element = 1 if i >= len(cond_reverse) else cond_reverse[i]
x_element = 1 if i >= len(x_reverse) else x_reverse[i]
y_element = 1 if i >= len(y_reverse) else y_reverse[i]
broadcast_element = get_max_value(cond_element, x_element, y_element)
if cond_element not in (1, broadcast_element):
raise ValueError(f"For select, condition input can not broadcast at index {i}")
if x_element not in (1, broadcast_element):
raise ValueError(f"For select, x input can not broadcast at index {i}")
if y_element not in (1, broadcast_element):
raise ValueError(f"For select, y input can not broadcast at index {i}")
converted_shape.append(broadcast_element)
i = i + 1
converted_shape.reverse()
return tuple(converted_shape)
def select(cond, x, y):
r"""
The conditional tensor determines whether the corresponding element in the output must be
@ -1675,6 +1716,19 @@ def select(cond, x, y):
input_y = cast_(input_y, mstype.int32)
else:
input_y = cast_(input_y, mstype.float32)
if is_x_tensor and is_y_tensor and is_cond_tensor:
x_shape = F.shape(x)
y_shape = F.shape(y)
cond_shape = F.shape(cond)
all_constant = F.isconstant(cond_shape) and F.isconstant(x_shape) and F.isconstant(y_shape)
if all_constant and not _check_select_shape_same(cond_shape, x_shape, y_shape):
broadcast_shape = _calc_broadcast_shape(cond_shape, x_shape, y_shape)
new_cond = F.broadcast_to(cond, broadcast_shape)
new_x = F.broadcast_to(x, broadcast_shape)
new_y = F.broadcast_to(y, broadcast_shape)
return tensor_select_(new_cond, new_x, new_y)
return tensor_select_(cond, input_x, input_y)

View File

@ -22,6 +22,9 @@ import mindspore.nn as nn
from mindspore import Tensor
import mindspore.ops as ops
from mindspore.ops import operations as P
from mindspore.common.api import jit
from mindspore.ops import functional as F
from mindspore import dtype as mstype
class Net(nn.Cell):
@ -87,7 +90,6 @@ def test_select_int32():
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@ -109,6 +111,25 @@ def test_functional_select_scalar():
assert np.all(-diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_functional_select_broadcast():
"""
Feature: Test functional select operator support broadcast input.
Description: Operator select's support broadcast input.
Expectation: Assert result.
"""
cond = Tensor(np.random.rand(1, 65, 54, 12, 5, 2), dtype=mstype.bool_)
x = Tensor(np.random.rand(5, 5, 65, 1, 12, 5, 2).astype(np.float32))
y = Tensor(np.random.rand(65, 54, 1, 5, 2).astype(np.float32))
@jit
def foo(a, b, c):
return F.select(a, b, c)
ret = foo(cond, x, y)
assert ret.shape == (5, 5, 65, 54, 12, 5, 2)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard