!11205 master_tensor_getitem_debug

From: @yepei6
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
This commit is contained in:
mindspore-ci-bot 2021-01-13 14:58:17 +08:00 committed by Gitee
commit 0018070b54
1 changed files with 57 additions and 47 deletions

View File

@ -147,13 +147,15 @@ def judge_index_type(index_type, target_type):
@constexpr
def check_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise TypeError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
raise TypeError(
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
@constexpr
def check_index_type_valid(dtype, target_type, op_name):
if dtype != target_type and (isinstance(target_type, (list, tuple)) and dtype not in target_type):
raise IndexError(f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
raise IndexError(
f"The '{op_name}' doesn't supoort {dtype}' and expecte to receive {target_type}.")
@constexpr
@ -189,7 +191,8 @@ def get_pos_of_indexes_types(indexes_types, op_name):
raise IndexError(f"For '{op_name}', the index elements only support "
f"'Tensor', 'int32', 'int64', 'Slice', 'Ellipsis', but got {index_type}.")
if len(ellipsis_positions) > 1:
raise IndexError(f"For '{op_name}, an index can only have a single ellipsis('...')")
raise IndexError(
f"For '{op_name}, an index can only have a single ellipsis('...')")
return slice_positions, ellipsis_positions, none_positions, int_positions, bool_positions, \
tensor_positions, sequence_positions
@ -341,7 +344,8 @@ def tuple_index_int_cnt(types, op_name):
def tuple_index_type_cnt(types, op_name):
"""count the tensor type of types which contains the tuple elements' type."""
tensor_cnt = sum(isinstance(ele, mstype.tensor_type) for ele in types)
basic_cnt = sum(isinstance(ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
basic_cnt = sum(isinstance(
ele, (mstype.Int, mstype.Ellipsis_, mstype.Slice)) for ele in types)
if tensor_cnt == len(types):
return ALL_TENSOR
if basic_cnt == len(types):
@ -380,7 +384,7 @@ def check_value_elements(data_dtype, types):
@constexpr
def get_index_tensor_dtype(dtype):
"""Check a tuple of tensor data type."""
if dtype == mstype.int32:
if dtype in mstype.int_type:
return INT_
if dtype == mstype.bool_:
return BOOL_
@ -428,7 +432,8 @@ def generate_broadcast_shape(shapes, op_name):
for i, shape in enumerate(shapes):
logger.debug(f"Broadcasts the {i}th tensor, the shape is {shape}.")
try:
broadcast_shape = op_utils.get_broadcast_shape(broadcast_shape, shape, op_name)
broadcast_shape = op_utils.get_broadcast_shape(
broadcast_shape, shape, op_name)
except ValueError as ex:
raise IndexError(ex)
return tuple(broadcast_shape)
@ -473,9 +478,10 @@ def compute_new_shape(origin_shape, indexes_shapes_info):
@constexpr
def check_sequence_index_type(sequence_index, op_name):
"""check if the item's type of list_index is bool or int"""
if not all([isinstance(index, (int, bool)) for index in sequence_index]):
raise IndexError(f"In the {op_name} operation, only support 'integer' or 'boolean' array(list/tuple), "
f"but got {type(index)} in array")
for index in sequence_index:
if not isinstance(index, int):
raise IndexError(f"In the {op_name} operation, only support 'inter' or 'boolean' array(list/tuple), "
f"but got {type(index)} in array.")
@constexpr
@ -497,13 +503,17 @@ def check_and_transform_int_index(index, shape, op_name):
@constexpr
def transform_sequence_index(sequence_index, shape, op_name):
"""transform list or tuple with integer and boolean to tuple with integer index"""
bool_count = len(list(filter(lambda index: isinstance(index, bool), sequence_index)))
int_count = len(list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
bool_count = len(
list(filter(lambda index: isinstance(index, bool), sequence_index)))
int_count = len(
list(filter(lambda index: isinstance(index, int), sequence_index)))-bool_count
if int_count == 0:
if bool_count == shape:
list_index = list(filter(lambda i: sequence_index[i], range(bool_count)))
list_index = list(
filter(lambda i: sequence_index[i], range(bool_count)))
else:
raise IndexError("The boolean array should have the same length with the corresponding dimensiton")
raise IndexError(
"The boolean array should have the same length with the corresponding dimensiton")
else:
list_index = [int(index) for index in sequence_index]
for i, index in enumerate(list_index):