From 6a7b9e72743068b529d1f7e92d962cdb807ca53c Mon Sep 17 00:00:00 2001 From: liangzhibo Date: Mon, 6 Mar 2023 19:28:42 +0800 Subject: [PATCH] Enable select broadcast --- .../mindspore/ops/function/array_func.py | 54 +++++++++++++++++++ tests/st/ops/cpu/test_select_op.py | 23 +++++++- 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/mindspore/python/mindspore/ops/function/array_func.py b/mindspore/python/mindspore/ops/function/array_func.py index ac9f4b10b2f..6227b461a04 100644 --- a/mindspore/python/mindspore/ops/function/array_func.py +++ b/mindspore/python/mindspore/ops/function/array_func.py @@ -1600,6 +1600,47 @@ def _check_select_type(is_cond_tensor, is_x_scalar, is_y_scalar, is_x_tensor, is f"then the input[x] must be a Tensor.") +@constexpr +def _check_select_shape_same(cond_shape, x_shape, y_shape): + """Check if input of select has same shape.""" + return cond_shape == x_shape and x_shape == y_shape and cond_shape == y_shape + + +@constexpr +def get_max_value(x, y, z): + if x >= y and x >= z: + return x + if y >= x and y >= z: + return y + return z + + +@constexpr +def _calc_broadcast_shape(cond_shape, x_shape, y_shape): + """Calculate broadcast shape for select""" + converted_shape = [] + cond_reverse = cond_shape[::-1] + x_reverse = x_shape[::-1] + y_reverse = y_shape[::-1] + max_len = get_max_value(len(cond_reverse), len(x_reverse), len(y_reverse)) + i = 0 + while i < max_len: + cond_element = 1 if i >= len(cond_reverse) else cond_reverse[i] + x_element = 1 if i >= len(x_reverse) else x_reverse[i] + y_element = 1 if i >= len(y_reverse) else y_reverse[i] + broadcast_element = get_max_value(cond_element, x_element, y_element) + if cond_element not in (1, broadcast_element): + raise ValueError(f"For select, condition input can not broadcast at index {i}") + if x_element not in (1, broadcast_element): + raise ValueError(f"For select, x input can not broadcast at index {i}") + if y_element not in (1, broadcast_element): + raise ValueError(f"For select, y input can not broadcast at index {i}") + converted_shape.append(broadcast_element) + i = i + 1 + converted_shape.reverse() + return tuple(converted_shape) + + def select(cond, x, y): r""" The conditional tensor determines whether the corresponding element in the output must be @@ -1675,6 +1716,19 @@ def select(cond, x, y): input_y = cast_(input_y, mstype.int32) else: input_y = cast_(input_y, mstype.float32) + + if is_x_tensor and is_y_tensor and is_cond_tensor: + x_shape = F.shape(x) + y_shape = F.shape(y) + cond_shape = F.shape(cond) + all_constant = F.isconstant(cond_shape) and F.isconstant(x_shape) and F.isconstant(y_shape) + if all_constant and not _check_select_shape_same(cond_shape, x_shape, y_shape): + broadcast_shape = _calc_broadcast_shape(cond_shape, x_shape, y_shape) + new_cond = F.broadcast_to(cond, broadcast_shape) + new_x = F.broadcast_to(x, broadcast_shape) + new_y = F.broadcast_to(y, broadcast_shape) + return tensor_select_(new_cond, new_x, new_y) + return tensor_select_(cond, input_x, input_y) diff --git a/tests/st/ops/cpu/test_select_op.py b/tests/st/ops/cpu/test_select_op.py index a02f2a6d4e7..f5a613bcdd2 100644 --- a/tests/st/ops/cpu/test_select_op.py +++ b/tests/st/ops/cpu/test_select_op.py @@ -22,6 +22,9 @@ import mindspore.nn as nn from mindspore import Tensor import mindspore.ops as ops from mindspore.ops import operations as P +from mindspore.common.api import jit +from mindspore.ops import functional as F +from mindspore import dtype as mstype class Net(nn.Cell): @@ -87,7 +90,6 @@ def test_select_int32(): assert np.all(-diff < error) - @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -109,6 +111,25 @@ def test_functional_select_scalar(): assert np.all(-diff < error) +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_functional_select_broadcast(): + """ + Feature: Test functional select operator support broadcast input. + Description: Operator select's support broadcast input. + Expectation: Assert result. + """ + cond = Tensor(np.random.rand(1, 65, 54, 12, 5, 2), dtype=mstype.bool_) + x = Tensor(np.random.rand(5, 5, 65, 1, 12, 5, 2).astype(np.float32)) + y = Tensor(np.random.rand(65, 54, 1, 5, 2).astype(np.float32)) + @jit + def foo(a, b, c): + return F.select(a, b, c) + ret = foo(cond, x, y) + assert ret.shape == (5, 5, 65, 54, 12, 5, 2) + + @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard