fix dim expand and broadcast

This commit is contained in:
huangmengxi 2021-04-15 14:25:13 +08:00
parent 84f5085a20
commit 6cb17501d5
3 changed files with 41 additions and 15 deletions

View File

@ -72,6 +72,7 @@ def _broadcast(broadcast_shape, x):
return x
multiples = const_utils.compute_multiples(F.shape(x), broadcast_shape)
if multiples:
x = F.reshape(x, const_utils.expanded_shape(F.shape(x), len(multiples) - F.rank(x)))
return F.tile(x, multiples)
return x
@ -794,29 +795,42 @@ def ignore_dim_expand(idx):
def remove_ignored_dim(idx, value_shape, data_rank):
"""Removes dimensions in value that correspond to dimension expansion flags in index."""
has_ellipsis = False
has_true = False
has_leading_true = False
has_trailing_true = False
cnt_leading_expanded = 0
cnt_trailing_expanded = 0
cnt_not_dim_expand = 0
for i in idx:
if not i is True and not i is None:
cnt_not_dim_expand += 1
if const_utils.is_ellipsis(i):
has_ellipsis = True
elif has_ellipsis:
if i is None:
if i is True:
if has_ellipsis:
has_trailing_true = True
else:
has_leading_true = True
elif i is None:
if has_ellipsis:
cnt_trailing_expanded += 1
elif i is True and not has_true:
has_true = True
if has_true and cnt_not_dim_expand + 1 < data_rank:
cnt_trailing_expanded += 1
else:
cnt_leading_expanded += 1
else:
if const_utils.is_ellipsis(i):
has_ellipsis = True
cnt_not_dim_expand += 1
if cnt_not_dim_expand + 1 < data_rank:
if has_leading_true:
cnt_leading_expanded += 1
elif has_trailing_true:
cnt_trailing_expanded += 1
value_starting_pos = 0
while cnt_leading_expanded > 0 and value_shape[value_starting_pos] == 1:
value_starting_pos += 1
cnt_leading_expanded -= 1
if cnt_trailing_expanded == 0:
return value_shape
value_expanded_pos = len(value_shape) - cnt_trailing_expanded
value_expanded_not_unit = False
for i in value_shape[value_expanded_pos:]:
for i in const_utils.tuple_slice(value_shape, value_expanded_pos, None):
if i != 1:
value_expanded_not_unit = True
if value_expanded_pos < 0 or value_expanded_not_unit:
const_utils.raise_value_error('shape mismatch')
return value_shape[:value_expanded_pos]
return const_utils.tuple_slice(value_shape, value_starting_pos, value_expanded_pos)

View File

@ -830,3 +830,14 @@ def normalize_stop(stop, dim_size):
@constexpr
def is_ellipsis(x):
return x is Ellipsis
@constexpr
def tuple_slice(tup, start, end):
"""get sliced tuple from start and end."""
return tup[start:end]
@constexpr
def expanded_shape(shape, expand_size):
return (1,)*expand_size + shape

View File

@ -129,6 +129,7 @@ def test_setitem_by_tuple_with_list():
x[0, True, 0, None, True] = [-2, -2, -2, -2]
x[0, ..., None] = [[-3], [-3], [-3], [-3]]
x[..., 0, None, 1, True, True, None] = [[[-4]], [[-4]]]
x[None, True, [1, 0], (False, True, True), [2]] = [[2, 3]]
return x
setup_testcase(x, cases)