forked from mindspore-Ecosystem/mindspore
fix dim expand and broadcast
This commit is contained in:
parent
84f5085a20
commit
6cb17501d5
|
@ -72,6 +72,7 @@ def _broadcast(broadcast_shape, x):
|
|||
return x
|
||||
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
|
||||
if multiples:
|
||||
x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x)))
|
||||
return F.tile(x, multiples)
|
||||
return x
|
||||
|
||||
|
@ -794,29 +795,42 @@ def ignore_dim_expand(idx):
|
|||
def remove_ignored_dim(idx, value_shape, data_rank):
|
||||
"""Removes dimensions in value that correspond to dimension expansion flags in index."""
|
||||
has_ellipsis = False
|
||||
has_true = False
|
||||
has_leading_true = False
|
||||
has_trailing_true = False
|
||||
cnt_leading_expanded = 0
|
||||
cnt_trailing_expanded = 0
|
||||
cnt_not_dim_expand = 0
|
||||
for i in idx:
|
||||
if not i is True and not i is None:
|
||||
cnt_not_dim_expand += 1
|
||||
if const_utils.is_ellipsis(i):
|
||||
has_ellipsis = True
|
||||
elif has_ellipsis:
|
||||
if i is None:
|
||||
if i is True:
|
||||
if has_ellipsis:
|
||||
has_trailing_true = True
|
||||
else:
|
||||
has_leading_true = True
|
||||
elif i is None:
|
||||
if has_ellipsis:
|
||||
cnt_trailing_expanded += 1
|
||||
elif i is True and not has_true:
|
||||
has_true = True
|
||||
if has_true and cnt_not_dim_expand + 1 < data_rank:
|
||||
cnt_trailing_expanded += 1
|
||||
else:
|
||||
cnt_leading_expanded += 1
|
||||
else:
|
||||
if const_utils.is_ellipsis(i):
|
||||
has_ellipsis = True
|
||||
cnt_not_dim_expand += 1
|
||||
if cnt_not_dim_expand + 1 < data_rank:
|
||||
if has_leading_true:
|
||||
cnt_leading_expanded += 1
|
||||
elif has_trailing_true:
|
||||
cnt_trailing_expanded += 1
|
||||
|
||||
value_starting_pos = 0
|
||||
while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1:
|
||||
value_starting_pos += 1
|
||||
cnt_leading_expanded -= 1
|
||||
|
||||
if cnt_trailing_expanded == 0:
|
||||
return value_shape
|
||||
value_expanded_pos = len(value_shape) - cnt_trailing_expanded
|
||||
value_expanded_not_unit = False
|
||||
for i in value_shape[value_expanded_pos:]:
|
||||
for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None):
|
||||
if i != 1:
|
||||
value_expanded_not_unit = True
|
||||
if value_expanded_pos < 0 or value_expanded_not_unit:
|
||||
const_utils.raise_value_error('shape mismatch')
|
||||
return value_shape[:value_expanded_pos]
|
||||
return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos)
|
||||
|
|
|
@ -830,3 +830,14 @@ def normalize_stop(stop, dim_size):
|
|||
@constexpr
|
||||
def is_ellipsis(x):
|
||||
return x is Ellipsis
|
||||
|
||||
|
||||
@constexpr
|
||||
def tuple_slice(tup, start, end):
|
||||
"""get sliced tuple from start and end."""
|
||||
return tup[start:end]
|
||||
|
||||
|
||||
@constexpr
|
||||
def expanded_shape(shape, expand_size):
|
||||
return (1,)*expand_size + shape
|
||||
|
|
|
@ -129,6 +129,7 @@ def test_setitem_by_tuple_with_list():
|
|||
x[0, True, 0, None, True] = [-2, -2, -2, -2]
|
||||
x[0, ..., None] = [[-3], [-3], [-3], [-3]]
|
||||
x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
|
||||
x[None, True, [1, 0], (False, True, True), [2]] = [[2, 3]]
|
||||
return x
|
||||
setup_testcase(x, cases)
|
||||
|
||||
|
|
Loading…
Reference in New Issue