diff --git a/mindspore/ops/composite/multitype_ops/_compile_utils.py b/mindspore/ops/composite/multitype_ops/_compile_utils.py index 4848084f26e..e7a8d0d4283 100644 --- a/mindspore/ops/composite/multitype_ops/_compile_utils.py +++ b/mindspore/ops/composite/multitype_ops/_compile_utils.py @@ -72,6 +72,7 @@ def _broadcast(broadcast_shape, x): return x multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape) if multiples: + x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x))) return F.tile(x, multiples) return x @@ -794,29 +795,42 @@ def ignore_dim_expand(idx): def remove_ignored_dim(idx, value_shape, data_rank): """Removes dimensions in value that correspond to dimension expansion flags in index.""" has_ellipsis = False - has_true = False + has_leading_true = False + has_trailing_true = False + cnt_leading_expanded = 0 cnt_trailing_expanded = 0 cnt_not_dim_expand = 0 for i in idx: - if not i is True and not i is None: - cnt_not_dim_expand += 1 - if const_utils.is_ellipsis(i): - has_ellipsis = True - elif has_ellipsis: - if i is None: + if i is True: + if has_ellipsis: + has_trailing_true = True + else: + has_leading_true = True + elif i is None: + if has_ellipsis: cnt_trailing_expanded += 1 - elif i is True and not has_true: - has_true = True - if has_true and cnt_not_dim_expand + 1 < data_rank: - cnt_trailing_expanded += 1 + else: + cnt_leading_expanded += 1 + else: + if const_utils.is_ellipsis(i): + has_ellipsis = True + cnt_not_dim_expand += 1 + if cnt_not_dim_expand + 1 < data_rank: + if has_leading_true: + cnt_leading_expanded += 1 + elif has_trailing_true: + cnt_trailing_expanded += 1 + + value_starting_pos = 0 + while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1: + value_starting_pos += 1 + cnt_leading_expanded -= 1 - if cnt_trailing_expanded == 0: - return value_shape value_expanded_pos = len(value_shape) - cnt_trailing_expanded value_expanded_not_unit = False - for i in value_shape[value_expanded_pos:]: + for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None): if i != 1: value_expanded_not_unit = True if value_expanded_pos < 0 or value_expanded_not_unit: const_utils.raise_value_error('shape mismatch') - return value_shape[:value_expanded_pos] + return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos) diff --git a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py index 598df06c495..3215d4bea36 100644 --- a/mindspore/ops/composite/multitype_ops/_constexpr_utils.py +++ b/mindspore/ops/composite/multitype_ops/_constexpr_utils.py @@ -830,3 +830,14 @@ def normalize_stop(stop, dim_size): @constexpr def is_ellipsis(x): return x is Ellipsis + + +@constexpr +def tuple_slice(tup, start, end): + """get sliced tuple from start and end.""" + return tup[start:end] + + +@constexpr +def expanded_shape(shape, expand_size): + return (1,)*expand_size + shape diff --git a/tests/st/pynative/test_tensor_setitem.py b/tests/st/pynative/test_tensor_setitem.py index b272371ecdb..7d5a9cc0832 100644 --- a/tests/st/pynative/test_tensor_setitem.py +++ b/tests/st/pynative/test_tensor_setitem.py @@ -129,6 +129,7 @@ def test_setitem_by_tuple_with_list(): x[0, True, 0, None, True] = [-2, -2, -2, -2] x[0, ..., None] = [[-3], [-3], [-3], [-3]] x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]] + x[None, True, [1, 0], (False, True, True), [2]] = [[2, 3]] return x setup_testcase(x, cases)